Emily Xie committed on
Commit
aa6bc6b
·
1 Parent(s): 35945d9

for test on gpu

Browse files
main.py CHANGED
@@ -65,6 +65,9 @@ def initialize_agent(
65
  prompts = load_prompts_from_file(prompt_file)
66
  prompt = prompts["MEDICAL_ASSISTANT"]
67
 
 
 
 
68
  all_tools = {
69
  "TorchXRayVisionClassifierTool": lambda: TorchXRayVisionClassifierTool(device=device),
70
  "ArcPlusClassifierTool": lambda: ArcPlusClassifierTool(cache_dir=model_dir, device=device),
@@ -87,6 +90,7 @@ def initialize_agent(
87
  "MedSAM2Tool": lambda: MedSAM2Tool(
88
  device=device, cache_dir=model_dir, temp_dir=temp_dir
89
  ),
 
90
  }
91
 
92
  try:
@@ -149,10 +153,11 @@ if __name__ == "__main__":
149
  # "LlavaMedTool", # For multimodal medical image understanding
150
  # "XRayPhraseGroundingTool", # For locating described features in X-rays
151
  # "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
152
- "MedSAM2Tool", # For advanced medical image segmentation using MedSAM2
153
- "WebBrowserTool", # For web browsing and search capabilities
154
- "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
155
  # "PythonSandboxTool", # Add the Python sandbox tool
 
156
  ]
157
 
158
  # Configure the Retrieval Augmented Generation (RAG) system
 
65
  prompts = load_prompts_from_file(prompt_file)
66
  prompt = prompts["MEDICAL_ASSISTANT"]
67
 
68
+ # Define the URL of the MedGemma FastAPI service.
69
+ MEDGEMMA_API_URL = os.getenv("MEDGEMMA_API_URL", "http://127.0.0.1:8002")
70
+
71
  all_tools = {
72
  "TorchXRayVisionClassifierTool": lambda: TorchXRayVisionClassifierTool(device=device),
73
  "ArcPlusClassifierTool": lambda: ArcPlusClassifierTool(cache_dir=model_dir, device=device),
 
90
  "MedSAM2Tool": lambda: MedSAM2Tool(
91
  device=device, cache_dir=model_dir, temp_dir=temp_dir
92
  ),
93
+ "MedGemmaVQATool": lambda: MedGemmaAPIClientTool(api_url=MEDGEMMA_API_URL)
94
  }
95
 
96
  try:
 
153
  # "LlavaMedTool", # For multimodal medical image understanding
154
  # "XRayPhraseGroundingTool", # For locating described features in X-rays
155
  # "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
156
+ # "MedSAM2Tool", # For advanced medical image segmentation using MedSAM2
157
+ # "WebBrowserTool", # For web browsing and search capabilities
158
+ # "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
159
  # "PythonSandboxTool", # Add the Python sandbox tool
160
+ "MedGemmaVQATool" # For visual question answering on medical images
161
  ]
162
 
163
  # Configure the Retrieval Augmented Generation (RAG) system
medrax/tools/__init__.py CHANGED
@@ -13,4 +13,4 @@ from .rag import *
13
  from .web_browser import *
14
  from .python_tool import *
15
  from .medsam2 import *
16
-
 
13
  from .web_browser import *
14
  from .python_tool import *
15
  from .medsam2 import *
16
+ from .medgemma_client import *
medrax/tools/medgemma.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException
2
+ from pydantic import BaseModel, Field
3
+ from typing import List, Optional, Any, Dict, Tuple
4
+ from pathlib import Path
5
+ import torch
6
+ from PIL import Image
7
+ from transformers import pipeline, BitsAndBytesConfig
8
+ import asyncio
9
+ import uvicorn
10
+ import os
11
+ import uuid
12
+ import traceback
13
+ import sys
14
+ import transformers
15
+
16
+ print("--- ENVIRONMENT CHECK ---")
17
+ print(f"Python Executable: {sys.executable}")
18
+ print(f"PyTorch version: {torch.__version__}")
19
+ print(f"Transformers version: {transformers.__version__}")
20
+ print("-----------------------")
21
+
22
+ # --- Configuration ---
23
+ CACHE_DIR = "./model_cache"
24
+ UPLOAD_DIR = "./uploaded_images"
25
+
26
+ # Create directories if they don't exist
27
+ os.makedirs(CACHE_DIR, exist_ok=True)
28
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
29
+
30
+ # --- Pydantic Models for API ---
31
class VQAInput(BaseModel):
    # Text-side parameters of a visual-question-answering request:
    # the question itself plus optional generation settings.

    # Required: the question or instruction to ask about the images.
    prompt: str = Field(
        default=...,
        description="Question or instruction about the medical images",
    )

    # Optional persona/context string; has a radiologist default.
    system_prompt: Optional[str] = Field(
        default="You are an expert radiologist.",
        description="System prompt to set the context for the model",
    )

    # Generation budget; defaults to 300 tokens.
    max_new_tokens: int = Field(
        default=300,
        description="Maximum number of tokens to generate in the response",
    )
40
+
41
class VQAResponse(BaseModel):
    # Successful analysis payload returned by the API.

    # The model's generated answer text.
    response: str

    # Free-form details about the request that produced the answer.
    metadata: Dict[str, Any]
44
+
45
class ErrorResponse(BaseModel):
    # Error payload schema referenced by the endpoint's `responses` mapping.

    # Human-readable description of what went wrong.
    error: str

    # Free-form details about the failed request.
    metadata: Dict[str, Any]
48
+
49
+ # --- MedGemma Model Handling ---
50
class MedGemmaModel:
    """Singleton wrapper around the MedGemma ``image-text-to-text`` pipeline.

    ``__new__`` always hands back the same instance, and ``__init__`` returns
    early once a pipeline has been built, so constructing this class more
    than once does not reload the model.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self,
                 model_name: str = "google/medgemma-4b-it",
                 device: Optional[str] = "cuda",
                 dtype: torch.dtype = torch.bfloat16,
                 load_in_4bit: bool = False):
        """Build the pipeline unless this singleton already holds one.

        Args:
            model_name: Model identifier handed to ``pipeline``.
            device: Requested device; replaced by ``"cpu"`` when falsy or
                when CUDA is not available.
            dtype: Torch dtype passed to the model as ``torch_dtype``.
            load_in_4bit: When true, load with a 4-bit ``BitsAndBytesConfig``
                and place the model through ``device_map``.

        Raises:
            RuntimeError: If the pipeline cannot be created; the underlying
                exception is chained as the cause.
        """
        # A previous successful __init__ already built the pipeline.
        if getattr(self, "pipe", None) is not None:
            return

        self.device = device if device and torch.cuda.is_available() else "cpu"
        self.dtype = dtype
        self.pipe = None

        model_kwargs = {"torch_dtype": self.dtype, "cache_dir": CACHE_DIR}
        placement_kwargs = {}

        if load_in_4bit:
            model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
            model_kwargs["device_map"] = {"": self.device}
        else:
            # Without this, the resolved device was never given to the
            # pipeline. `device` and `device_map` must not be combined, so it
            # is only passed on the non-quantized path.
            placement_kwargs["device"] = self.device

        try:
            self.pipe = pipeline("image-text-to-text",
                                 model=model_name,
                                 model_kwargs=model_kwargs,
                                 trust_remote_code=True,
                                 use_cache=True,
                                 **placement_kwargs)
        except Exception as e:
            raise RuntimeError(f"Failed to initialize MedGemma pipeline: {str(e)}") from e

    def _prepare_messages(
        self, image_paths: List[str], prompt: str, system_prompt: str
    ) -> Tuple[List[Dict[str, Any]], List[Image.Image]]:
        """Load the images as RGB and build the chat-style message list.

        Args:
            image_paths: Paths of the image files to load.
            prompt: User question placed in the user turn.
            system_prompt: Text placed in the system turn.

        Returns:
            ``(messages, images)`` where ``messages`` holds a system turn and
            a user turn (the prompt text followed by one image entry per
            loaded image) and ``images`` is the list of loaded RGB images.

        Raises:
            FileNotFoundError: If any path is not an existing file.
        """
        images = []
        for path in image_paths:
            if not Path(path).is_file():
                raise FileNotFoundError(f"Image file not found: {path}")

            # convert() returns a fully loaded copy, so the file handle can be
            # closed right away instead of staying open on a lazily loaded
            # image while the caller later deletes the file.
            with Image.open(path) as source:
                images.append(source.convert("RGB"))

        messages = [
            {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
            {
                "role": "user",
                "content": [{"type": "text", "text": prompt}]
                + [{"type": "image", "image": img} for img in images],
            },
        ]

        return messages, images

    async def aget_response(self, image_paths: List[str], prompt: str, system_prompt: str, max_new_tokens: int) -> str:
        """Run image loading and generation in the default executor.

        Args:
            image_paths: Paths of the image files to analyze.
            prompt: User question about the images.
            system_prompt: System-turn text.
            max_new_tokens: Generation budget passed to the pipeline.

        Returns:
            The stripped ``content`` of the last generated message, or
            ``"No response generated"`` when the pipeline output does not
            have the expected list-of-messages shape.

        Raises:
            FileNotFoundError: Propagated from ``_prepare_messages``.
        """
        # get_running_loop() is the supported way to obtain the loop from
        # inside a coroutine.
        loop = asyncio.get_running_loop()
        messages, _ = await loop.run_in_executor(
            None, self._prepare_messages, image_paths, prompt, system_prompt
        )

        def _generate():
            return self.pipe(
                text=messages,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )

        output = await loop.run_in_executor(None, _generate)

        if (
            isinstance(output, list)
            and output
            and isinstance(output[0].get("generated_text"), list)
        ):
            generated_text = output[0]["generated_text"]
            if generated_text:
                return generated_text[-1].get("content", "").strip()

        return "No response generated"
132
+
133
+ # --- FastAPI Application ---
134
+ app = FastAPI(title="MedGemma VQA API",
135
+ description="API for medical visual question answering using Google's MedGemma model.")
136
+
137
+ medgemma_model: Optional[MedGemmaModel] = None
138
+
139
@app.on_event("startup")
async def startup_event():
    """Create the module-level MedGemma model when the application starts.

    A load failure is reported on stdout and the application keeps running
    with ``medgemma_model`` left as it was.
    """
    global medgemma_model
    try:
        medgemma_model = MedGemmaModel()
    except RuntimeError as load_error:
        print(f"Error loading MedGemma model: {load_error}")
        # Depending on the desired behavior, the application could exit here
        # when the model fails to load, e.g. exit(1).
    else:
        print("MedGemma model loaded successfully.")
151
+
152
@app.post("/analyze-images/",
          response_model=VQAResponse,
          responses={500: {"model": ErrorResponse},
                     404: {"model": ErrorResponse}},
          summary="Analyze one or more medical images")
async def analyze_images(
    images: List[UploadFile] = File(..., description="List of medical image files to analyze (JPG or PNG)."),
    prompt: str = Form(..., description="Question or instruction about the medical images."),
    system_prompt: Optional[str] = Form("You are an expert radiologist.", description="System prompt to set the context for the model."),
    max_new_tokens: int = Form(100, description="Maximum number of tokens to generate in the response.")
):
    """
    Upload one or more medical images and a prompt to get an analysis from the MedGemma model.

    Each upload is written to UPLOAD_DIR under a unique name, analyzed, and
    removed again before the request finishes, whether it succeeds or fails.

    Raises:
        HTTPException: 503 when the model is not loaded, 400 for an
            unsupported content type, 404 when a saved image cannot be found,
            500 when saving or analysis fails.
    """
    # NOTE(review): errors are raised as HTTPException, so the error body is
    # {"detail": ...} rather than the ErrorResponse model declared above.
    if medgemma_model is None or medgemma_model.pipe is None:
        raise HTTPException(status_code=503, detail="Model is not available. Please try again later.")

    # Validate every upload before anything is written to disk.
    for image in images:
        if image.content_type not in ["image/jpeg", "image/png"]:
            raise HTTPException(status_code=400, detail=f"Unsupported image format: {image.content_type}. Only JPG and PNG are supported.")

    image_paths = []
    try:
        for image in images:
            # The client-supplied filename is untrusted: keep only its final
            # component so it cannot escape UPLOAD_DIR via "../" or "..\\".
            safe_name = os.path.basename((image.filename or "").replace("\\", "/"))
            # Generate a unique filename to avoid overwrites
            unique_filename = f"{uuid.uuid4()}_{safe_name}"
            file_path = os.path.join(UPLOAD_DIR, unique_filename)

            # Register the path before writing so a partially written file is
            # also removed by the cleanup below.
            image_paths.append(file_path)
            try:
                with open(file_path, "wb") as buffer:
                    buffer.write(await image.read())
            except Exception as e:
                raise HTTPException(status_code=500, detail=f"Failed to save uploaded image: {str(e)}") from e

        response_text = await medgemma_model.aget_response(image_paths, prompt, system_prompt, max_new_tokens)
        metadata = {
            "image_paths": image_paths,
            "prompt": prompt,
            "system_prompt": system_prompt,
            "max_new_tokens": max_new_tokens,
            "num_images": len(image_paths),
            "analysis_status": "completed",
        }
        return VQAResponse(response=response_text, metadata=metadata)
    except HTTPException:
        # Already a well-formed HTTP error (e.g. the save failure above);
        # do not let the generic handler rewrite it.
        raise
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=f"Image file not found: {str(e)}") from e
    except Exception as e:
        print("--- AN EXCEPTION OCCURRED IN THE ENDPOINT ---")
        traceback.print_exc()
        # Catch potential CUDA out-of-memory errors and other exceptions
        error_message = "An unexpected error occurred during analysis."
        if "CUDA out of memory" in str(e):
            error_message = "GPU memory exhausted. Try reducing image resolution or max_new_tokens."

        raise HTTPException(status_code=500, detail=error_message) from e
    finally:
        # Clean up saved images
        for path in image_paths:
            try:
                os.remove(path)
            except OSError:
                # Best effort: a file that is already gone or cannot be
                # removed must not fail the request.
                pass
222
+
223
+
224
+ if __name__ == "__main__":
225
+ uvicorn.run(app, host="0.0.0.0", port=8002)
medrax/tools/medgemma_client.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import httpx
2
+ from typing import Dict, List, Optional, Type, Any
3
+ from langchain_core.tools import BaseTool
4
+ from langchain_core.callbacks import (
5
+ AsyncCallbackManagerForToolRun,
6
+ CallbackManagerForToolRun,
7
+ )
8
+ from pydantic import BaseModel, Field
9
+ import os
10
+
11
+ # This input schema should be identical to the one in your original tool
12
class MedGemmaVQAInput(BaseModel):
    """Input schema for the MedGemma VQA Tool. The agent provides local paths to images."""

    # Required: files the client tool reads and uploads to the service.
    image_paths: List[str] = Field(
        default=...,
        description="List of paths to medical image files to analyze. These are local paths accessible to the agent.",
    )

    # Required: the question or instruction for the model.
    prompt: str = Field(
        default=...,
        description="Question or instruction about the medical images",
    )

    # Optional context string; has a radiologist default.
    system_prompt: Optional[str] = Field(
        default="You are an expert radiologist.",
        description="System prompt to set the context for the model",
    )

    # Generation budget; defaults to 300 tokens.
    max_new_tokens: int = Field(
        default=300,
        description="Maximum number of tokens to generate in the response",
    )
26
+
27
class MedGemmaAPIClientTool(BaseTool):
    """
    A client tool to interact with a remote MedGemma VQA FastAPI service.
    This tool takes local image paths, reads them, and sends them to the API endpoint
    for analysis.

    Failures are returned to the caller as explanatory strings rather than
    raised, in both the sync and async paths.
    """
    name: str = "medgemma_medical_vqa_service"
    description: str = (
        "Sends medical images and a prompt to a specialized MedGemma VQA service for analysis. "
        "Use this for expert-level reasoning, diagnosis assistance, and detailed image interpretation "
        "across modalities like chest X-rays, dermatology, etc. Input must be local image paths and a prompt."
    )
    args_schema: Type[BaseModel] = MedGemmaVQAInput
    api_url: str  # The URL of the running FastAPI service

    def _endpoint_url(self) -> str:
        """Return the analysis endpoint URL, tolerating a trailing slash in api_url."""
        return f"{self.api_url.rstrip('/')}/analyze-images/"

    def _timeout_config(self) -> httpx.Timeout:
        """Return the timeout shared by the sync and async paths (300s overall, 10s to connect)."""
        return httpx.Timeout(300.0, connect=10.0)

    def _run(
        self,
        image_paths: List[str],
        prompt: str,
        system_prompt: str = "You are an expert radiologist.",
        max_new_tokens: int = 300,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Execute the tool synchronously.

        Args:
            image_paths: Local paths of the images to upload.
            prompt: Question or instruction about the images.
            system_prompt: Context string forwarded to the service.
            max_new_tokens: Generation budget forwarded to the service.
            run_manager: Optional callback manager (unused here).

        Returns:
            The ``response`` field of the service's JSON reply, or an error
            description string if the request fails.
        """
        timeout_config = self._timeout_config()
        opened_files = []
        # The context manager closes the client's connection pool; the bare
        # Client() used previously was never closed.
        with httpx.Client(timeout=timeout_config) as client:
            try:
                # Prepare the multipart form data
                files_to_send = []
                for path in image_paths:
                    f = open(path, "rb")
                    opened_files.append(f)
                    # The key 'images' must match the parameter name in the FastAPI endpoint
                    # NOTE(review): every file is labelled image/jpeg regardless of its real type.
                    files_to_send.append(("images", (os.path.basename(path), f, "image/jpeg")))

                data = {
                    "prompt": prompt,
                    "system_prompt": system_prompt,
                    "max_new_tokens": max_new_tokens,
                }

                response = client.post(
                    self._endpoint_url(),
                    data=data,
                    files=files_to_send,
                )
                response.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx)

                # The agent expects a string response from a tool
                return response.json()["response"]

            except httpx.TimeoutException:
                return f"Error: The request to the MedGemma API timed out after {timeout_config.read} seconds. The server might be overloaded or the model is taking too long to load. Try again later."
            except httpx.ConnectError:
                return f"Error: Could not connect to the MedGemma API. Check if the server address '{self.api_url}' is correct and running."
            except httpx.HTTPStatusError as e:
                return f"Error: The MedGemma API returned an error (Status {e.response.status_code}): {e.response.text}"
            except Exception as e:
                return f"An unexpected error occurred in the MedGemma client tool: {str(e)}"
            finally:
                # Important: Ensure all opened files are closed.
                for f in opened_files:
                    f.close()

    async def _arun(
        self,
        image_paths: List[str],
        prompt: str,
        system_prompt: str = "You are an expert radiologist.",
        max_new_tokens: int = 300,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Execute the tool asynchronously.

        Takes the same arguments and returns the same kind of result as
        ``_run``. Uses the same timeout as the sync path, and reports
        timeouts and connection failures explicitly instead of letting them
        fall through to the generic handler.
        """
        timeout_config = self._timeout_config()
        async with httpx.AsyncClient(timeout=timeout_config) as client:
            files_to_send = []
            opened_files = []
            try:
                # Note: File I/O is blocking, for a truly async app you might use aiofiles
                # But for this use case, this is generally acceptable.
                for path in image_paths:
                    f = open(path, "rb")
                    opened_files.append(f)
                    files_to_send.append(("images", (os.path.basename(path), f, "image/jpeg")))

                data = {
                    "prompt": prompt,
                    "system_prompt": system_prompt,
                    "max_new_tokens": max_new_tokens,
                }

                response = await client.post(
                    self._endpoint_url(),
                    data=data,
                    files=files_to_send,
                )
                response.raise_for_status()

                return response.json()["response"]

            except httpx.TimeoutException:
                return f"Error: The request to the MedGemma API timed out after {timeout_config.read} seconds. The server might be overloaded or the model is taking too long to load. Try again later."
            except httpx.ConnectError:
                return f"Error: Could not connect to the MedGemma API. Check if the server address '{self.api_url}' is correct and running."
            except httpx.HTTPStatusError as e:
                return f"Error calling MedGemma API: {e.response.status_code} - {e.response.text}"
            except Exception as e:
                return f"An unexpected error occurred: {str(e)}"
            finally:
                for f in opened_files:
                    f.close()
138
+
139
+ if __name__ == "__main__":
140
+ client_tool = MedGemmaAPIClientTool(api_url="http://localhost:8002")
141
+ result = client_tool.run({
142
+ "image_paths": ["demo/chest/pneumonia1.jpg"],
143
+ "prompt": "What abnormality do you see?"
144
+ })
145
+ print(result)