VictorLJZ committed on
Commit
aff69d7
·
2 Parent(s): 1516987 7b3e756

Merge pull request #23 from bowang-lab/victor/benchmarking

Browse files
.gitignore CHANGED
@@ -179,4 +179,6 @@ model-weights/
179
 
180
  .DS_Store
181
 
182
- benchmarking/data/
 
 
 
179
 
180
  .DS_Store
181
 
182
+ benchmarking/data/
183
+ model_cache/
184
+ medgemma/
benchmarking/benchmarks/rexvqa_benchmark.py CHANGED
@@ -34,20 +34,20 @@ class ReXVQABenchmark(Benchmark):
34
  data_dir (str): Directory to store/cache downloaded data
35
  **kwargs: Additional configuration parameters
36
  split (str): Dataset split to use (default: 'test')
37
- cache_dir (str): Directory for caching HuggingFace datasets
38
  trust_remote_code (bool): Whether to trust remote code (default: False)
39
  max_questions (int): Maximum number of questions to load (default: None, load all)
40
  images_dir (str): Directory containing extracted PNG images (default: None)
41
  """
42
  self.split = kwargs.get("split", "test")
43
- self.cache_dir = kwargs.get("cache_dir", None)
44
  self.trust_remote_code = kwargs.get("trust_remote_code", False)
45
  self.max_questions = kwargs.get("max_questions", None)
46
- self.images_dir = "benchmarking/data/rexvqa/images/deid_png"
47
  self.image_dataset = None
48
  self.image_mapping = {} # Maps study_id to image data
49
 
50
  super().__init__(data_dir, **kwargs)
 
 
 
51
 
52
  @staticmethod
53
  def download_rexgradient_images(output_dir: str = "benchmarking/data/rexvqa", repo_id: str = "rajpurkarlab/ReXGradient-160K"):
@@ -166,8 +166,8 @@ class ReXVQABenchmark(Benchmark):
166
  """Load ReXVQA data from local JSON file."""
167
  try:
168
  # Check for images and test_vqa_data.json, download if missing
169
- self.download_test_vqa_data_json()
170
- self.download_rexgradient_images()
171
 
172
  # Construct path to the JSON file
173
  json_file_path = os.path.join("benchmarking", "data", "rexvqa", "metadata", "test_vqa_data.json")
@@ -197,7 +197,7 @@ class ReXVQABenchmark(Benchmark):
197
  self.image_dataset = load_dataset(
198
  "rajpurkarlab/ReXGradient-160K",
199
  split="test",
200
- cache_dir=self.cache_dir,
201
  trust_remote_code=self.trust_remote_code
202
  )
203
  print(f"Loaded {len(self.image_dataset)} image metadata entries from ReXGradient-160K")
 
34
  data_dir (str): Directory to store/cache downloaded data
35
  **kwargs: Additional configuration parameters
36
  split (str): Dataset split to use (default: 'test')
 
37
  trust_remote_code (bool): Whether to trust remote code (default: False)
38
  max_questions (int): Maximum number of questions to load (default: None, load all)
39
  images_dir (str): Directory containing extracted PNG images (default: None)
40
  """
41
  self.split = kwargs.get("split", "test")
 
42
  self.trust_remote_code = kwargs.get("trust_remote_code", False)
43
  self.max_questions = kwargs.get("max_questions", None)
 
44
  self.image_dataset = None
45
  self.image_mapping = {} # Maps study_id to image data
46
 
47
  super().__init__(data_dir, **kwargs)
48
+
49
+ # Set images_dir after parent initialization
50
+ self.images_dir = f"{self.data_dir}/images/deid_png"
51
 
52
  @staticmethod
53
  def download_rexgradient_images(output_dir: str = "benchmarking/data/rexvqa", repo_id: str = "rajpurkarlab/ReXGradient-160K"):
 
166
  """Load ReXVQA data from local JSON file."""
167
  try:
168
  # Check for images and test_vqa_data.json, download if missing
169
+ self.download_test_vqa_data_json(self.data_dir)
170
+ self.download_rexgradient_images(self.data_dir)
171
 
172
  # Construct path to the JSON file
173
  json_file_path = os.path.join("benchmarking", "data", "rexvqa", "metadata", "test_vqa_data.json")
 
197
  self.image_dataset = load_dataset(
198
  "rajpurkarlab/ReXGradient-160K",
199
  split="test",
200
+ cache_dir=self.data_dir,
201
  trust_remote_code=self.trust_remote_code
202
  )
203
  print(f"Loaded {len(self.image_dataset)} image metadata entries from ReXGradient-160K")
benchmarking/llm_providers/medrax_provider.py CHANGED
@@ -33,15 +33,15 @@ class MedRAXProvider(LLMProvider):
33
  print("Starting server...")
34
 
35
  selected_tools = [
36
- "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
37
- "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
38
- "WebBrowserTool", # For web browsing and search capabilities
39
  "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
40
  "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
41
- "DuckDuckGoSearchTool", # For privacy-focused web search using DuckDuckGo
42
- "XRayVQATool", # For visual question answering on X-rays
43
  "XRayPhraseGroundingTool", # For locating described features in X-rays
44
- "MedGemmaVQATool"
 
 
 
 
45
  ]
46
 
47
  rag_config = RAGConfig(
@@ -64,11 +64,11 @@ class MedRAXProvider(LLMProvider):
64
  agent, tools_dict = initialize_agent(
65
  prompt_file="medrax/docs/system_prompts.txt",
66
  tools_to_use=selected_tools,
67
- model_dir="/model-weights",
68
  temp_dir="temp", # Change this to the path of the temporary directory
69
  device="cuda:0",
70
  model=self.model_name, # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
71
- temperature=0.3,
72
  top_p=0.95,
73
  model_kwargs=model_kwargs,
74
  rag_config=rag_config,
 
33
  print("Starting server...")
34
 
35
  selected_tools = [
 
 
 
36
  "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
37
  "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
38
+ "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
 
39
  "XRayPhraseGroundingTool", # For locating described features in X-rays
40
+ "MedGemmaVQATool",
41
+ # "XRayVQATool", # For visual question answering on X-rays
42
+ # "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
43
+ # "WebBrowserTool", # For web browsing and search capabilities
44
+ # "DuckDuckGoSearchTool", # For privacy-focused web search using DuckDuckGo
45
  ]
46
 
47
  rag_config = RAGConfig(
 
64
  agent, tools_dict = initialize_agent(
65
  prompt_file="medrax/docs/system_prompts.txt",
66
  tools_to_use=selected_tools,
67
+ model_dir="/scratch/ssd004/scratch/victorli/model-weights",
68
  temp_dir="temp", # Change this to the path of the temporary directory
69
  device="cuda:0",
70
  model=self.model_name, # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
71
+ temperature=1.0,
72
  top_p=0.95,
73
  model_kwargs=model_kwargs,
74
  rag_config=rag_config,
main.py CHANGED
@@ -10,6 +10,7 @@ with different model weights, tools, and parameters.
10
  """
11
 
12
  import warnings
 
13
  from typing import Dict, List, Optional, Any
14
  from dotenv import load_dotenv
15
  from transformers import logging
@@ -33,11 +34,11 @@ _ = load_dotenv()
33
  def initialize_agent(
34
  prompt_file: str,
35
  tools_to_use: Optional[List[str]] = None,
36
- model_dir: str = "/model-weights",
37
  temp_dir: str = "temp",
38
  device: str = "cpu",
39
- model: str = "gpt-4.1-2025-04-14",
40
- temperature: float = 0.7,
41
  top_p: float = 0.95,
42
  rag_config: Optional[RAGConfig] = None,
43
  model_kwargs: Dict[str, Any] = {},
@@ -67,7 +68,7 @@ def initialize_agent(
67
  prompt = prompts[system_prompt]
68
 
69
  # Define the URL of the MedGemma FastAPI service.
70
- MEDGEMMA_API_URL = os.getenv("MEDGEMMA_API_URL", "http://127.0.0.1:8002")
71
 
72
  all_tools = {
73
  "TorchXRayVisionClassifierTool": lambda: TorchXRayVisionClassifierTool(device=device),
@@ -88,24 +89,29 @@ def initialize_agent(
88
  "DicomProcessorTool": lambda: DicomProcessorTool(temp_dir=temp_dir),
89
  "MedicalRAGTool": lambda: RAGTool(config=rag_config),
90
  "WebBrowserTool": lambda: WebBrowserTool(),
 
91
  "MedSAM2Tool": lambda: MedSAM2Tool(
92
  device=device, cache_dir=model_dir, temp_dir=temp_dir
93
  ),
94
  "MedGemmaVQATool": lambda: MedGemmaAPIClientTool(cache_dir=model_dir, device=device, api_url=MEDGEMMA_API_URL)
95
- }
96
-
97
- try:
98
- tools_dict["PythonSandboxTool"] = create_python_sandbox()
99
- except Exception as e:
100
- print(f"Error creating PythonSandboxTool: {e}")
101
- print("Skipping PythonSandboxTool")
102
 
103
  # Initialize only selected tools or all if none specified
104
  tools_dict: Dict[str, BaseTool] = {}
105
- tools_to_use = tools_to_use or all_tools.keys()
 
 
 
106
  for tool_name in tools_to_use:
 
 
 
 
 
 
107
  if tool_name in all_tools:
108
  tools_dict[tool_name] = all_tools[tool_name]()
 
109
 
110
  # Set up checkpointing for conversation state
111
  checkpointer = MemorySaver()
@@ -145,20 +151,20 @@ if __name__ == "__main__":
145
  selected_tools = [
146
  "ImageVisualizerTool", # For displaying images in the UI
147
  # "DicomProcessorTool", # For processing DICOM medical image files
 
 
148
  "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
149
  "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
150
- "ChestXRaySegmentationTool", # For segmenting anatomical regions in chest X-rays
151
- "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
152
  "XRayVQATool", # For visual question answering on X-rays
153
  # "LlavaMedTool", # For multimodal medical image understanding
154
  "XRayPhraseGroundingTool", # For locating described features in X-rays
155
- # "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
156
  # "MedSAM2Tool", # For advanced medical image segmentation using MedSAM2
157
  # "WebBrowserTool", # For web browsing and search capabilities
 
158
  # "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
159
  # "PythonSandboxTool", # Add the Python sandbox tool
160
- "MedGemmaVQATool" # Google MedGemma VQA tool
161
- "DuckDuckGoSearchTool", # For privacy-focused web search using DuckDuckGo
162
  ]
163
 
164
  # Setup the MedGemma environment if the MedGemmaVQATool is selected
@@ -187,11 +193,11 @@ if __name__ == "__main__":
187
  agent, tools_dict = initialize_agent(
188
  prompt_file="medrax/docs/system_prompts.txt",
189
  tools_to_use=selected_tools,
190
- model_dir="/model-weights",
191
  temp_dir="temp", # Change this to the path of the temporary directory
192
- device="cuda:0",
193
- model="gpt-4.1-2025-04-14", # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
194
- temperature=0.7,
195
  top_p=0.95,
196
  model_kwargs=model_kwargs,
197
  rag_config=rag_config,
 
10
  """
11
 
12
  import warnings
13
+ import os
14
  from typing import Dict, List, Optional, Any
15
  from dotenv import load_dotenv
16
  from transformers import logging
 
34
  def initialize_agent(
35
  prompt_file: str,
36
  tools_to_use: Optional[List[str]] = None,
37
+ model_dir: str = "model-weights",
38
  temp_dir: str = "temp",
39
  device: str = "cpu",
40
+ model: str = "gemini-2.5-pro",
41
+ temperature: float = 1.0,
42
  top_p: float = 0.95,
43
  rag_config: Optional[RAGConfig] = None,
44
  model_kwargs: Dict[str, Any] = {},
 
68
  prompt = prompts[system_prompt]
69
 
70
  # Define the URL of the MedGemma FastAPI service.
71
+ MEDGEMMA_API_URL = os.getenv("MEDGEMMA_API_URL", "http://172.17.8.141:8002")
72
 
73
  all_tools = {
74
  "TorchXRayVisionClassifierTool": lambda: TorchXRayVisionClassifierTool(device=device),
 
89
  "DicomProcessorTool": lambda: DicomProcessorTool(temp_dir=temp_dir),
90
  "MedicalRAGTool": lambda: RAGTool(config=rag_config),
91
  "WebBrowserTool": lambda: WebBrowserTool(),
92
+ "DuckDuckGoSearchTool": lambda: DuckDuckGoSearchTool(),
93
  "MedSAM2Tool": lambda: MedSAM2Tool(
94
  device=device, cache_dir=model_dir, temp_dir=temp_dir
95
  ),
96
  "MedGemmaVQATool": lambda: MedGemmaAPIClientTool(cache_dir=model_dir, device=device, api_url=MEDGEMMA_API_URL)
97
+ }
 
 
 
 
 
 
98
 
99
  # Initialize only selected tools or all if none specified
100
  tools_dict: Dict[str, BaseTool] = {}
101
+
102
+ if tools_to_use is None:
103
+ tools_to_use = []
104
+
105
  for tool_name in tools_to_use:
106
+ if tool_name == "PythonSandboxTool":
107
+ try:
108
+ tools_dict["PythonSandboxTool"] = create_python_sandbox()
109
+ except Exception as e:
110
+ print(f"Error creating PythonSandboxTool: {e}")
111
+ print("Skipping PythonSandboxTool")
112
  if tool_name in all_tools:
113
  tools_dict[tool_name] = all_tools[tool_name]()
114
+
115
 
116
  # Set up checkpointing for conversation state
117
  checkpointer = MemorySaver()
 
151
  selected_tools = [
152
  "ImageVisualizerTool", # For displaying images in the UI
153
  # "DicomProcessorTool", # For processing DICOM medical image files
154
+ # "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
155
+ "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
156
  "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
157
  "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
158
+ "MedGemmaVQATool" # Google MedGemma VQA tool
 
159
  "XRayVQATool", # For visual question answering on X-rays
160
  # "LlavaMedTool", # For multimodal medical image understanding
161
  "XRayPhraseGroundingTool", # For locating described features in X-rays
162
+ "ChestXRaySegmentationTool", # For segmenting anatomical regions in chest X-rays
163
  # "MedSAM2Tool", # For advanced medical image segmentation using MedSAM2
164
  # "WebBrowserTool", # For web browsing and search capabilities
165
+ "DuckDuckGoSearchTool", # For privacy-focused web search using DuckDuckGo
166
  # "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
167
  # "PythonSandboxTool", # Add the Python sandbox tool
 
 
168
  ]
169
 
170
  # Setup the MedGemma environment if the MedGemmaVQATool is selected
 
193
  agent, tools_dict = initialize_agent(
194
  prompt_file="medrax/docs/system_prompts.txt",
195
  tools_to_use=selected_tools,
196
+ model_dir="model-weights",
197
  temp_dir="temp", # Change this to the path of the temporary directory
198
+ device="cpu",
199
+ model="gemini-2.5-pro", # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
200
+ temperature=1.0,
201
  top_p=0.95,
202
  model_kwargs=model_kwargs,
203
  rag_config=rag_config,
medrax/docs/system_prompts.txt CHANGED
@@ -17,10 +17,9 @@ Examples:
17
  - "Based on clinical guidelines [3], the recommended treatment approach is..."
18
 
19
  [CHESTAGENTBENCH_PROMPT]
20
- You are an expert medical AI assistant who can answer any medical questions and analyze medical images similar to a doctor.
21
- Solve using your own vision and reasoning and use tools (if available) to complement your reasoning.
22
- You can make multiple tool calls in parallel or in sequence as needed for comprehensive answers.
23
- Think critically about and criticize the tool outputs.
24
- If you need to look up some information before asking a follow up question, you are allowed to do that.
25
  When encountering a multiple-choice question, your final response should end with "Final answer: \boxed{A}" from list of possible choices A, B, C, D, E, F.
26
- It is extremely important that you strictly answer in the format mentioned above.
 
17
  - "Based on clinical guidelines [3], the recommended treatment approach is..."
18
 
19
  [CHESTAGENTBENCH_PROMPT]
20
+ You are an expert medical assistant who can answer medical questions and analyze medical images with world-class accuracy.
21
+ Use your state-of-the-art reasoning and critical thinking skills to answer the questions that you are asked.
22
+ You may use tools (if available) to complement your reasoning and you are allowed to make multiple tool calls in parallel or in sequence as needed for comprehensive answers.
23
+ Think critically about how to best use the tools available to you and scrutinize the tool outputs.
 
24
  When encountering a multiple-choice question, your final response should end with "Final answer: \boxed{A}" from list of possible choices A, B, C, D, E, F.
25
+ It is extremely important that you answer strictly in the format described above.
medrax/tools/__init__.py CHANGED
@@ -11,4 +11,3 @@ from .utils import *
11
  from .rag import *
12
  from .browsing import *
13
  from .python_tool import *
14
- from .medsam2 import *
 
11
  from .rag import *
12
  from .browsing import *
13
  from .python_tool import *
 
medrax/tools/medgemma.py DELETED
@@ -1,225 +0,0 @@
1
- from fastapi import FastAPI, File, UploadFile, Form, HTTPException
2
- from pydantic import BaseModel, Field
3
- from typing import List, Optional, Any, Dict, Tuple
4
- from pathlib import Path
5
- import torch
6
- from PIL import Image
7
- from transformers import pipeline, BitsAndBytesConfig
8
- import asyncio
9
- import uvicorn
10
- import os
11
- import uuid
12
- import traceback
13
- import sys
14
- import transformers
15
-
16
- print("--- ENVIRONMENT CHECK ---")
17
- print(f"Python Executable: {sys.executable}")
18
- print(f"PyTorch version: {torch.__version__}")
19
- print(f"Transformers version: {transformers.__version__}")
20
- print("-----------------------")
21
-
22
- # --- Configuration ---
23
- CACHE_DIR = "./model_cache"
24
- UPLOAD_DIR = "./uploaded_images"
25
-
26
- # Create directories if they don't exist
27
- os.makedirs(CACHE_DIR, exist_ok=True)
28
- os.makedirs(UPLOAD_DIR, exist_ok=True)
29
-
30
- # --- Pydantic Models for API ---
31
- class VQAInput(BaseModel):
32
- prompt: str = Field(..., description="Question or instruction about the medical images")
33
- system_prompt: Optional[str] = Field(
34
- "You are an expert radiologist.",
35
- description="System prompt to set the context for the model",
36
- )
37
- max_new_tokens: int = Field(
38
- 300, description="Maximum number of tokens to generate in the response"
39
- )
40
-
41
- class VQAResponse(BaseModel):
42
- response: str
43
- metadata: Dict[str, Any]
44
-
45
- class ErrorResponse(BaseModel):
46
- error: str
47
- metadata: Dict[str, Any]
48
-
49
- # --- MedGemma Model Handling ---
50
- class MedGemmaModel:
51
- _instance = None
52
-
53
- def __new__(cls, *args, **kwargs):
54
- if not cls._instance:
55
- cls._instance = super(MedGemmaModel, cls).__new__(cls)
56
- return cls._instance
57
-
58
- def __init__(self,
59
- model_name: str = "google/medgemma-4b-it",
60
- device: Optional[str] = "cuda",
61
- dtype: torch.dtype = torch.bfloat16,
62
- load_in_4bit: bool = False):
63
- if hasattr(self, 'pipe') and self.pipe is not None:
64
- return
65
-
66
- self.device = device if device and torch.cuda.is_available() else "cpu"
67
- self.dtype = dtype
68
- self.pipe = None
69
-
70
- model_kwargs = {"torch_dtype": self.dtype, "cache_dir": CACHE_DIR}
71
-
72
- if load_in_4bit:
73
- model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
74
- model_kwargs["device_map"] = {"": self.device}
75
-
76
- try:
77
- self.pipe = pipeline("image-text-to-text",
78
- model=model_name,
79
- model_kwargs=model_kwargs,
80
- trust_remote_code=True,
81
- use_cache=True)
82
- except Exception as e:
83
- raise RuntimeError(f"Failed to initialize MedGemma pipeline: {str(e)}")
84
-
85
- def _prepare_messages(
86
- self, image_paths: List[str], prompt: str, system_prompt: str
87
- ) -> Tuple[List[Dict[str, Any]], List[Image.Image]]:
88
- images = []
89
- for path in image_paths:
90
- if not Path(path).is_file():
91
- raise FileNotFoundError(f"Image file not found: {path}")
92
-
93
- image = Image.open(path)
94
- if image.mode != "RGB":
95
- image = image.convert("RGB")
96
- images.append(image)
97
-
98
- messages = [
99
- {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
100
- {
101
- "role": "user",
102
- "content": [{"type": "text", "text": prompt}]
103
- + [{"type": "image", "image": img} for img in images],
104
- },
105
- ]
106
-
107
- return messages, images
108
-
109
- async def aget_response(self, image_paths: List[str], prompt: str, system_prompt: str, max_new_tokens: int) -> str:
110
- loop = asyncio.get_event_loop()
111
- messages, _ = await loop.run_in_executor(None, self._prepare_messages, image_paths, prompt, system_prompt)
112
-
113
- def _generate():
114
- return self.pipe(
115
- text=messages,
116
- max_new_tokens=max_new_tokens,
117
- do_sample=False,
118
- )
119
-
120
- output = await loop.run_in_executor(None, _generate)
121
-
122
- if (
123
- isinstance(output, list)
124
- and output
125
- and isinstance(output[0].get("generated_text"), list)
126
- ):
127
- generated_text = output[0]["generated_text"]
128
- if generated_text:
129
- return generated_text[-1].get("content", "").strip()
130
-
131
- return "No response generated"
132
-
133
- # --- FastAPI Application ---
134
- app = FastAPI(title="MedGemma VQA API",
135
- description="API for medical visual question answering using Google's MedGemma model.")
136
-
137
- medgemma_model: Optional[MedGemmaModel] = None
138
-
139
- @app.on_event("startup")
140
- async def startup_event():
141
- """Load the MedGemma model at application startup."""
142
- global medgemma_model
143
- try:
144
- medgemma_model = MedGemmaModel()
145
- print("MedGemma model loaded successfully.")
146
- except RuntimeError as e:
147
- print(f"Error loading MedGemma model: {e}")
148
- # Depending on the desired behavior, you might want to exit the application
149
- # if the model fails to load.
150
- # exit(1)
151
-
152
- @app.post("/analyze-images/",
153
- response_model=VQAResponse,
154
- responses={500: {"model": ErrorResponse},
155
- 404: {"model": ErrorResponse}},
156
- summary="Analyze one or more medical images")
157
- async def analyze_images(
158
- images: List[UploadFile] = File(..., description="List of medical image files to analyze (JPG or PNG)."),
159
- prompt: str = Form(..., description="Question or instruction about the medical images."),
160
- system_prompt: Optional[str] = Form("You are an expert radiologist.", description="System prompt to set the context for the model."),
161
- max_new_tokens: int = Form(100, description="Maximum number of tokens to generate in the response.")
162
- ):
163
- """
164
- Upload one or more medical images and a prompt to get an analysis from the MedGemma model.
165
- """
166
- if medgemma_model is None or medgemma_model.pipe is None:
167
- raise HTTPException(status_code=503, detail="Model is not available. Please try again later.")
168
-
169
- image_paths = []
170
- for image in images:
171
- if image.content_type not in ["image/jpeg", "image/png"]:
172
- raise HTTPException(status_code=400, detail=f"Unsupported image format: {image.content_type}. Only JPG and PNG are supported.")
173
-
174
- # Generate a unique filename to avoid overwrites
175
- unique_filename = f"{uuid.uuid4()}_{image.filename}"
176
- file_path = os.path.join(UPLOAD_DIR, unique_filename)
177
-
178
- try:
179
- with open(file_path, "wb") as buffer:
180
- buffer.write(await image.read())
181
- image_paths.append(file_path)
182
- except Exception as e:
183
- raise HTTPException(status_code=500, detail=f"Failed to save uploaded image: {str(e)}")
184
-
185
-
186
- try:
187
- response_text = await medgemma_model.aget_response(image_paths, prompt, system_prompt, max_new_tokens)
188
- metadata = {
189
- "image_paths": image_paths,
190
- "prompt": prompt,
191
- "system_prompt": system_prompt,
192
- "max_new_tokens": max_new_tokens,
193
- "num_images": len(image_paths),
194
- "analysis_status": "completed",
195
- }
196
- return VQAResponse(response=response_text, metadata=metadata)
197
- except FileNotFoundError as e:
198
- raise HTTPException(status_code=404, detail=f"Image file not found: {str(e)}")
199
- except Exception as e:
200
- print("--- AN EXCEPTION OCCURRED IN THE ENDPOINT ---")
201
- traceback.print_exc()
202
- # Catch potential CUDA out-of-memory errors and other exceptions
203
- error_message = "An unexpected error occurred during analysis."
204
- if "CUDA out of memory" in str(e):
205
- error_message = "GPU memory exhausted. Try reducing image resolution or max_new_tokens."
206
-
207
- metadata = {
208
- "image_paths": image_paths,
209
- "prompt": prompt,
210
- "analysis_status": "failed",
211
- "error_details": str(e),
212
- }
213
- raise HTTPException(status_code=500, detail=error_message)
214
- finally:
215
- # Clean up saved images
216
- for path in image_paths:
217
- try:
218
- os.remove(path)
219
- except OSError:
220
- # Log this error if needed, but don't let it crash the request
221
- pass
222
-
223
-
224
- if __name__ == "__main__":
225
- uvicorn.run(app, host="0.0.0.0", port=8002)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
medrax/tools/medgemma_client.py DELETED
@@ -1,145 +0,0 @@
1
- import httpx
2
- from typing import Dict, List, Optional, Type, Any
3
- from langchain_core.tools import BaseTool
4
- from langchain_core.callbacks import (
5
- AsyncCallbackManagerForToolRun,
6
- CallbackManagerForToolRun,
7
- )
8
- from pydantic import BaseModel, Field
9
- import os
10
-
11
- # This input schema should be identical to the one in your original tool
12
- class MedGemmaVQAInput(BaseModel):
13
- """Input schema for the MedGemma VQA Tool. The agent provides local paths to images."""
14
- image_paths: List[str] = Field(
15
- ...,
16
- description="List of paths to medical image files to analyze. These are local paths accessible to the agent.",
17
- )
18
- prompt: str = Field(..., description="Question or instruction about the medical images")
19
- system_prompt: Optional[str] = Field(
20
- "You are an expert radiologist.",
21
- description="System prompt to set the context for the model",
22
- )
23
- max_new_tokens: int = Field(
24
- 300, description="Maximum number of tokens to generate in the response"
25
- )
26
-
27
- class MedGemmaAPIClientTool(BaseTool):
28
- """
29
- A client tool to interact with a remote MedGemma VQA FastAPI service.
30
- This tool takes local image paths, reads them, and sends them to the API endpoint
31
- for analysis.
32
- """
33
- name: str = "medgemma_medical_vqa_service"
34
- description: str = (
35
- "Sends medical images and a prompt to a specialized MedGemma VQA service for analysis. "
36
- "Use this for expert-level reasoning, diagnosis assistance, and detailed image interpretation "
37
- "across modalities like chest X-rays, dermatology, etc. Input must be local image paths and a prompt."
38
- )
39
- args_schema: Type[BaseModel] = MedGemmaVQAInput
40
- api_url: str # The URL of the running FastAPI service
41
-
42
- def _run(
43
- self,
44
- image_paths: List[str],
45
- prompt: str,
46
- system_prompt: str = "You are an expert radiologist.",
47
- max_new_tokens: int = 300,
48
- run_manager: Optional[CallbackManagerForToolRun] = None,
49
- ) -> str:
50
- """Execute the tool synchronously."""
51
- # httpx is a modern HTTP client that supports sync and async
52
- timeout_config = httpx.Timeout(300.0, connect=10.0)
53
- client = httpx.Client(timeout=timeout_config)
54
-
55
- # Prepare the multipart form data
56
- files_to_send = []
57
- opened_files = []
58
- try:
59
- for path in image_paths:
60
- f = open(path, "rb")
61
- opened_files.append(f)
62
- # The key 'images' must match the parameter name in the FastAPI endpoint
63
- files_to_send.append(("images", (os.path.basename(path), f, "image/jpeg")))
64
-
65
- data = {
66
- "prompt": prompt,
67
- "system_prompt": system_prompt,
68
- "max_new_tokens": max_new_tokens,
69
- }
70
-
71
- response = client.post(
72
- f"{self.api_url}/analyze-images/",
73
- data=data,
74
- files=files_to_send,
75
- )
76
- response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx)
77
-
78
- # The agent expects a string response from a tool
79
- return response.json()["response"]
80
-
81
- # --- KEY FIX 3: More specific exception handling for clearer errors ---
82
- except httpx.TimeoutException:
83
- return f"Error: The request to the MedGemma API timed out after {timeout_config.read} seconds. The server might be overloaded or the model is taking too long to load. Try again later."
84
- except httpx.ConnectError:
85
- return f"Error: Could not connect to the MedGemma API. Check if the server address '{self.api_url}' is correct and running."
86
- except httpx.HTTPStatusError as e:
87
- return f"Error: The MedGemma API returned an error (Status {e.response.status_code}): {e.response.text}"
88
- except Exception as e:
89
- return f"An unexpected error occurred in the MedGemma client tool: {str(e)}"
90
- finally:
91
- # Important: Ensure all opened files are closed.
92
- for f in opened_files:
93
- f.close()
94
-
95
- async def _arun(
96
- self,
97
- image_paths: List[str],
98
- prompt: str,
99
- system_prompt: str = "You are an expert radiologist.",
100
- max_new_tokens: int = 300,
101
- run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
102
- ) -> str:
103
- """Execute the tool asynchronously."""
104
- async with httpx.AsyncClient() as client:
105
- files_to_send = []
106
- opened_files = []
107
- try:
108
- # Note: File I/O is blocking, for a truly async app you might use aiofiles
109
- # But for this use case, this is generally acceptable.
110
- for path in image_paths:
111
- f = open(path, "rb")
112
- opened_files.append(f)
113
- files_to_send.append(("images", (os.path.basename(path), f, "image/jpeg")))
114
-
115
- data = {
116
- "prompt": prompt,
117
- "system_prompt": system_prompt,
118
- "max_new_tokens": max_new_tokens,
119
- }
120
-
121
- response = await client.post(
122
- f"{self.api_url}/analyze-images/",
123
- data=data,
124
- files=files_to_send,
125
- timeout=120.0
126
- )
127
- response.raise_for_status()
128
-
129
- return response.json()["response"]
130
-
131
- except httpx.HTTPStatusError as e:
132
- return f"Error calling MedGemma API: {e.response.status_code} - {e.response.text}"
133
- except Exception as e:
134
- return f"An unexpected error occurred: {str(e)}"
135
- finally:
136
- for f in opened_files:
137
- f.close()
138
-
139
- if __name__ == "__main__":
140
- client_tool = MedGemmaAPIClientTool(api_url="http://localhost:8002")
141
- result = client_tool.run({
142
- "image_paths": ["demo/chest/pneumonia1.jpg"],
143
- "prompt": "What abnormality do you see?"
144
- })
145
- print(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pyproject.toml CHANGED
@@ -57,7 +57,6 @@ dependencies = [
57
  "torch>=2.2.0",
58
  "torchvision>=0.10.0",
59
  "scikit-image>=0.18.0",
60
- "gradio>=5.0.0",
61
  "opencv-python>=4.8.0",
62
  "matplotlib>=3.8.0",
63
  "diffusers>=0.20.0",
@@ -65,13 +64,11 @@ dependencies = [
65
  "pylibjpeg>=1.0.0",
66
  "jupyter>=1.0.0",
67
  "albumentations>=1.0.0",
68
- "pyarrow>=10.0.0",
69
  "chromadb>=0.0.10",
70
  "pinecone-client>=3.2.2",
71
  "langchain-pinecone>=0.0.1",
72
  "langchain-google-genai>=0.1.0",
73
  "ray>=2.9.0",
74
- "langchain-sandbox>=0.0.6",
75
  "seaborn>=0.12.0",
76
  "huggingface_hub>=0.17.0",
77
  "iopath>=0.1.10",
 
57
  "torch>=2.2.0",
58
  "torchvision>=0.10.0",
59
  "scikit-image>=0.18.0",
 
60
  "opencv-python>=4.8.0",
61
  "matplotlib>=3.8.0",
62
  "diffusers>=0.20.0",
 
64
  "pylibjpeg>=1.0.0",
65
  "jupyter>=1.0.0",
66
  "albumentations>=1.0.0",
 
67
  "chromadb>=0.0.10",
68
  "pinecone-client>=3.2.2",
69
  "langchain-pinecone>=0.0.1",
70
  "langchain-google-genai>=0.1.0",
71
  "ray>=2.9.0",
 
72
  "seaborn>=0.12.0",
73
  "huggingface_hub>=0.17.0",
74
  "iopath>=0.1.10",