Emily Xie committed on
Commit 205758b · 1 Parent(s): aa6bc6b

MedGemma FastAPI tool integration

README.md CHANGED
@@ -22,6 +22,7 @@ MedRAX is built on a robust technical foundation:
 
 ### Integrated Tools
 - **Visual QA**: Utilizes CheXagent and LLaVA-Med for complex visual understanding and medical reasoning
+- **MedGemma VQA**: Advanced medical visual question answering using Google's MedGemma 4B model for comprehensive medical image analysis across multiple modalities
 - **Segmentation**: Employs MedSAM2 (advanced medical image segmentation) and PSPNet model trained on ChestX-Det for precise anatomical structure identification
 - **Grounding**: Uses Maira-2 for localizing specific findings in medical images
 - **Report Generation**: Implements SwinV2 Transformer trained on CheXpert Plus for detailed medical reporting
@@ -130,6 +131,10 @@ PINECONE_API_KEY=
 # Requires Google Custom Search API credentials.
 GOOGLE_SEARCH_API_KEY=
 GOOGLE_SEARCH_ENGINE_ID=
+
+# MedGemma VQA Tool (Optional)
+# URL for the MedGemma FastAPI service
+MEDGEMMA_API_URL=http://127.0.0.1:8002
 ```
 
 ### Getting Started
@@ -232,6 +237,21 @@ XRayVQATool(
 ```
 - CheXagent weights download automatically
 
+### MedGemma VQA Tool
+```python
+MedGemmaAPIClientTool(
+    device=device,
+    cache_dir=model_dir,
+    api_url=MEDGEMMA_API_URL,
+)
+```
+- **Advanced Medical VQA**: Uses Google's MedGemma 4B instruction-tuned model for comprehensive medical image analysis
+- **Multi-modal Capabilities**: Specialized for chest X-rays, dermatology, ophthalmology, and pathology images
+- **Expert-level Analysis**: Provides radiologist-level medical reasoning and diagnosis assistance
+- **High Performance**: Supports up to 128K context length and 896x896 image resolution
+- **Memory Efficient**: 4-bit quantization available (~4GB VRAM) with full precision option (~8GB VRAM)
+- **Automatic Setup**: Model weights download automatically when service starts
+
 ### MedSAM2 Tool
 ```python
 MedSAM2Tool(
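A minimal sketch of exercising the new endpoint directly, assuming the FastAPI service from `medrax/tools/vqa/medgemma/medgemma.py` is already running at the `MEDGEMMA_API_URL` above and that `chest.png` is a hypothetical local test image; the field names match the `/analyze-images/` form definition in this commit:

```python
import httpx

# Assumed to match MEDGEMMA_API_URL in .env; adjust if the service runs elsewhere.
API_URL = "http://127.0.0.1:8002"

with open("chest.png", "rb") as f:  # hypothetical local test image
    response = httpx.post(
        f"{API_URL}/analyze-images/",
        files=[("images", ("chest.png", f.read(), "image/png"))],
        data={
            "prompt": "Describe any abnormalities in this chest X-ray.",
            "system_prompt": "You are an expert radiologist.",
            "max_new_tokens": 300,
        },
        timeout=300.0,  # first request can be slow while the model warms up
    )
print(response.json()["response"])
```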
main.py CHANGED
@@ -73,7 +73,7 @@ def initialize_agent(
         "ArcPlusClassifierTool": lambda: ArcPlusClassifierTool(cache_dir=model_dir, device=device),
         "ChestXRaySegmentationTool": lambda: ChestXRaySegmentationTool(device=device),
         "LlavaMedTool": lambda: LlavaMedTool(cache_dir=model_dir, device=device, load_in_8bit=True),
-        "XRayVQATool": lambda: XRayVQATool(cache_dir=model_dir, device=device),
+        "CheXagentXRayVQATool": lambda: CheXagentXRayVQATool(cache_dir=model_dir, device=device),
         "ChestXRayReportGeneratorTool": lambda: ChestXRayReportGeneratorTool(
             cache_dir=model_dir, device=device
         ),
@@ -90,7 +90,7 @@ def initialize_agent(
         "MedSAM2Tool": lambda: MedSAM2Tool(
             device=device, cache_dir=model_dir, temp_dir=temp_dir
         ),
-        "MedGemmaVQATool": lambda: MedGemmaAPIClientTool(api_url=MEDGEMMA_API_URL)
+        "MedGemmaVQATool": lambda: MedGemmaAPIClientTool(cache_dir=model_dir, device=device, api_url=MEDGEMMA_API_URL)
     }
 
     try:
@@ -157,9 +157,13 @@ if __name__ == "__main__":
         # "WebBrowserTool", # For web browsing and search capabilities
         # "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
         # "PythonSandboxTool", # Add the Python sandbox tool
-        "MedGemmaVQATool" # For visual question answering on medical images
+        "MedGemmaVQATool" # Google MedGemma VQA tool
     ]
 
+    # Set up the MedGemma environment if the MedGemmaVQATool is selected
+    if "MedGemmaVQATool" in selected_tools:
+        setup_medgemma_env()
+
     # Configure the Retrieval Augmented Generation (RAG) system
     # This allows the agent to access and use medical knowledge documents
     rag_config = RAGConfig(
@@ -185,7 +189,7 @@ if __name__ == "__main__":
         model_dir="model-weights",
         temp_dir="temp",  # Change this to the path of the temporary directory
         device="cuda",
-        model="grok-4",  # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
+        model="gpt-4.1-2025-04-14",  # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
         temperature=0.7,
         top_p=0.95,
         model_kwargs=model_kwargs,
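For context, the registry above maps tool names to zero-argument lambdas so that heavyweight models are only constructed for names that appear in `selected_tools`; a minimal sketch of that deferred-construction pattern (names and bodies hypothetical, since `initialize_agent`'s implementation is outside this diff):

```python
# Hypothetical illustration of the deferred-construction pattern:
# each value is a zero-argument factory, invoked only for selected names.
all_tools = {
    "MedGemmaVQATool": lambda: object(),  # stands in for MedGemmaAPIClientTool(...)
    "LlavaMedTool": lambda: object(),     # stands in for a heavy model load
}
selected_tools = ["MedGemmaVQATool"]
tools_dict = {name: all_tools[name]() for name in selected_tools}
```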
medrax/tools/__init__.py CHANGED
@@ -3,8 +3,7 @@
 from .classification import *
 from .report_generation import *
 from .segmentation import *
-from .xray_vqa import *
-from .llava_med import *
+from .vqa import *
 from .grounding import *
 from .generation import *
 from .dicom import *
@@ -13,4 +12,3 @@ from .rag import *
 from .web_browser import *
 from .python_tool import *
 from .medsam2 import *
-from .medgemma_client import *
medrax/tools/vqa/__init__.py ADDED
@@ -0,0 +1,16 @@
+"""Visual Question Answering tools for medical images."""
+
+from .llava_med import LlavaMedTool, LlavaMedInput
+from .xray_vqa import CheXagentXRayVQATool, XRayVQAToolInput
+from .medgemma.medgemma_client import MedGemmaAPIClientTool, MedGemmaVQAInput
+from .medgemma.medgemma_setup import setup_medgemma_env
+
+__all__ = [
+    "LlavaMedTool",
+    "LlavaMedInput",
+    "CheXagentXRayVQATool",
+    "XRayVQAToolInput",
+    "MedGemmaAPIClientTool",
+    "MedGemmaVQAInput",
+    "setup_medgemma_env",
+]
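Because `medrax/tools/__init__.py` now does `from .vqa import *`, everything listed in `__all__` above remains importable from the old top-level location; a quick sketch (assuming the medrax package is installed):

```python
# Both import paths resolve to the same class after the vqa subpackage refactor.
from medrax.tools import MedGemmaAPIClientTool
from medrax.tools.vqa import MedGemmaAPIClientTool as VqaMedGemmaTool

assert MedGemmaAPIClientTool is VqaMedGemmaTool
```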
medrax/tools/vqa/llava_med.py ADDED
@@ -0,0 +1,186 @@
+from typing import Any, Dict, Optional, Tuple, Type
+from pydantic import BaseModel, Field
+
+import torch
+
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForToolRun,
+    CallbackManagerForToolRun,
+)
+from langchain_core.tools import BaseTool
+
+from PIL import Image
+
+
+from medrax.llava.conversation import conv_templates
+from medrax.llava.model.builder import load_pretrained_model
+from medrax.llava.mm_utils import tokenizer_image_token, process_images
+from medrax.llava.constants import (
+    IMAGE_TOKEN_INDEX,
+    DEFAULT_IMAGE_TOKEN,
+    DEFAULT_IM_START_TOKEN,
+    DEFAULT_IM_END_TOKEN,
+)
+
+
+class LlavaMedInput(BaseModel):
+    """Input for the LLaVA-Med Visual QA tool. Only supports JPG or PNG images."""
+
+    question: str = Field(..., description="The question to ask about the medical image")
+    image_path: Optional[str] = Field(
+        None,
+        description="Path to the medical image file (optional), only supports JPG or PNG images",
+    )
+
+
+class LlavaMedTool(BaseTool):
+    """Tool that performs medical visual question answering using LLaVA-Med.
+
+    This tool uses a large language model fine-tuned on medical images to answer
+    questions about medical images. It can handle both image-based questions and
+    general medical questions without images.
+    """
+
+    name: str = "llava_med_qa"
+    description: str = (
+        "A tool that answers questions about biomedical images and general medical questions using LLaVA-Med. "
+        "While it can process chest X-rays, it may not be as reliable for detailed chest X-ray analysis. "
+        "Input should be a question and optionally a path to a medical image file."
+    )
+    args_schema: Type[BaseModel] = LlavaMedInput
+    tokenizer: Any = None
+    model: Any = None
+    image_processor: Any = None
+    context_len: int = 200000
+
+    def __init__(
+        self,
+        model_path: str = "microsoft/llava-med-v1.5-mistral-7b",
+        cache_dir: str = "/model-weights",
+        low_cpu_mem_usage: bool = True,
+        torch_dtype: torch.dtype = torch.bfloat16,
+        device: str = "cuda",
+        load_in_4bit: bool = False,
+        load_in_8bit: bool = False,
+        **kwargs,
+    ):
+        super().__init__()
+        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
+            model_path=model_path,
+            model_base=None,
+            model_name=model_path,
+            load_in_4bit=load_in_4bit,
+            load_in_8bit=load_in_8bit,
+            cache_dir=cache_dir,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            torch_dtype=torch_dtype,
+            device=device,
+            **kwargs,
+        )
+        self.model.eval()
+
+    def _process_input(
+        self, question: str, image_path: Optional[str] = None
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        if self.model.config.mm_use_im_start_end:
+            question = (
+                DEFAULT_IM_START_TOKEN
+                + DEFAULT_IMAGE_TOKEN
+                + DEFAULT_IM_END_TOKEN
+                + "\n"
+                + question
+            )
+        else:
+            question = DEFAULT_IMAGE_TOKEN + "\n" + question
+
+        conv = conv_templates["vicuna_v1"].copy()
+        conv.append_message(conv.roles[0], question)
+        conv.append_message(conv.roles[1], None)
+        prompt = conv.get_prompt()
+
+        input_ids = (
+            tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
+            .unsqueeze(0)
+            .cuda()
+        )
+
+        image_tensor = None
+        if image_path:
+            image = Image.open(image_path)
+            image_tensor = process_images([image], self.image_processor, self.model.config)[0]
+            image_tensor = image_tensor.unsqueeze(0).half().cuda()
+
+        return input_ids, image_tensor
+
+    def _run(
+        self,
+        question: str,
+        image_path: Optional[str] = None,
+        run_manager: Optional[CallbackManagerForToolRun] = None,
+    ) -> Tuple[str, Dict]:
+        """Answer a medical question, optionally based on an input image.
+
+        Args:
+            question (str): The medical question to answer.
+            image_path (Optional[str]): The path to the medical image file (if applicable).
+            run_manager (Optional[CallbackManagerForToolRun]): The callback manager for the tool run.
+
+        Returns:
+            Tuple[str, Dict]: A tuple containing the model's answer and any additional metadata.
+
+        Raises:
+            Exception: If there's an error processing the input or generating the answer.
+        """
+        try:
+            input_ids, image_tensor = self._process_input(question, image_path)
+            input_ids = input_ids.to(device=self.model.device)
+            # image_tensor is None for text-only questions; only move it when present
+            if image_tensor is not None:
+                image_tensor = image_tensor.to(device=self.model.device, dtype=self.model.dtype)
+
+            with torch.inference_mode():
+                output_ids = self.model.generate(
+                    input_ids,
+                    images=image_tensor,
+                    do_sample=False,
+                    temperature=0.2,
+                    max_new_tokens=500,
+                    use_cache=True,
+                )
+
+            output = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
+            metadata = {
+                "question": question,
+                "image_path": image_path,
+                "analysis_status": "completed",
+            }
+            return output, metadata
+        except Exception as e:
+            return f"Error generating answer: {str(e)}", {
+                "question": question,
+                "image_path": image_path,
+                "analysis_status": "failed",
+            }
+
+    async def _arun(
+        self,
+        question: str,
+        image_path: Optional[str] = None,
+        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+    ) -> Tuple[str, Dict]:
+        """Asynchronously answer a medical question, optionally based on an input image.
+
+        This method currently calls the synchronous version, as the model inference
+        is not inherently asynchronous. For true asynchronous behavior, consider
+        using a separate thread or process.
+
+        Args:
+            question (str): The medical question to answer.
+            image_path (Optional[str]): The path to the medical image file (if applicable).
+            run_manager (Optional[AsyncCallbackManagerForToolRun]): The async callback manager for the tool run.
+
+        Returns:
+            Tuple[str, Dict]: A tuple containing the model's answer and any additional metadata.
+
+        Raises:
+            Exception: If there's an error processing the input or generating the answer.
+        """
+        return self._run(question, image_path)
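A minimal usage sketch for the relocated tool, assuming the LLaVA-Med weights are downloadable and a CUDA device is present (the `_process_input` helper hardcodes `.cuda()`); the image path is hypothetical:

```python
from medrax.tools.vqa import LlavaMedTool

# Loads microsoft/llava-med-v1.5-mistral-7b; requires a CUDA-capable GPU.
tool = LlavaMedTool(cache_dir="model-weights", device="cuda", load_in_8bit=True)
answer, metadata = tool._run(
    question="Is there evidence of consolidation?",
    image_path="demo/chest/example.jpg",  # hypothetical path
)
print(answer, metadata["analysis_status"])
```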
medrax/tools/vqa/medgemma/medgemma.py ADDED
@@ -0,0 +1,431 @@
+import asyncio
+import os
+from pathlib import Path
+import sys
+import traceback
+from typing import Any, Dict, List, Optional, Tuple
+import uuid
+
+from PIL import Image
+
+from fastapi import FastAPI, File, Form, HTTPException, UploadFile
+from pydantic import BaseModel, Field
+import torch
+import transformers
+from transformers import BitsAndBytesConfig, pipeline
+import uvicorn
+
+#TODO: delete this
+print("ENVIRONMENT CHECK")
+print(f"Python Executable: {sys.executable}")
+print(f"PyTorch version: {torch.__version__}")
+print(f"Transformers version: {transformers.__version__}")
+
+# Configuration
+UPLOAD_DIR = "./medgemma_images"
+
+# Create directories if they don't exist
+os.makedirs(UPLOAD_DIR, exist_ok=True)
+
+# Pydantic Models for API
+class VQAInput(BaseModel):
+    """Input schema for the MedGemma VQA API endpoint.
+
+    Defines the structure for requests to the /analyze-images/ endpoint.
+    Used for validating incoming API requests and generating OpenAPI documentation.
+    """
+    prompt: str = Field(..., description="Question or instruction about the medical images")
+    system_prompt: Optional[str] = Field(
+        "You are an expert radiologist.",
+        description="System prompt to set the context for the model",
+    )
+    max_new_tokens: int = Field(
+        300, description="Maximum number of tokens to generate in the response"
+    )
+
+class VQAResponse(BaseModel):
+    """Response schema for successful MedGemma VQA API requests.
+
+    Defines the structure of successful responses from the /analyze-images/ endpoint.
+    Used for response validation and OpenAPI documentation.
+    """
+    response: str = Field(..., description="Generated medical analysis response from MedGemma model")
+    metadata: Dict[str, Any] = Field(..., description="Additional metadata about the analysis request and results")
+
+class ErrorResponse(BaseModel):
+    """Error response schema for failed MedGemma VQA API requests.
+
+    Defines the structure of error responses from the /analyze-images/ endpoint.
+    Used for error response validation and OpenAPI documentation.
+    """
+    error: str = Field(..., description="Human-readable error message describing what went wrong")
+    metadata: Dict[str, Any] = Field(..., description="Additional metadata about the error and request context")
+
+# MedGemma Model Handling
+class MedGemmaModel:
+    """Medical visual question answering model using Google's MedGemma 4B model.
+
+    MedGemma is a specialized multimodal AI model trained on medical images and text.
+    It provides expert-level analysis for chest X-rays, dermatology images,
+    ophthalmology images, and histopathology slides.
+
+    Key capabilities:
+    - Medical image classification and analysis across multiple modalities
+    - Visual question answering for radiology, dermatology, pathology, ophthalmology
+    - Clinical reasoning and medical knowledge integration
+    - Multi-modal medical understanding (text + images)
+    - Support for up to 128K context length
+
+    Performance:
+    - Full precision (bfloat16): ~8GB VRAM, recommended for medical applications
+    - 4-bit quantization (default): Available but may affect quality on some systems
+
+    This class implements a singleton pattern to ensure only one model instance
+    is loaded in memory, optimizing resource usage for the FastAPI service.
+    """
+
+    _instance = None
+
+    def __new__(cls, *args, **kwargs):
+        """Create or return the singleton instance of MedGemmaModel.
+
+        Ensures only one model instance exists in memory, preventing
+        multiple model loads and conserving GPU memory.
+
+        Returns:
+            MedGemmaModel: The singleton instance
+        """
+        if not cls._instance:
+            cls._instance = super(MedGemmaModel, cls).__new__(cls)
+        return cls._instance
+
+    def __init__(
+        self,
+        model_name: str = "google/medgemma-4b-it",
+        device: Optional[str] = "cuda",
+        dtype: torch.dtype = torch.bfloat16,
+        cache_dir: Optional[str] = None,
+        load_in_4bit: bool = True,
+        **kwargs: Any,
+    ) -> None:
+        """Initialize the MedGemmaModel.
+
+        Args:
+            model_name: Name of the MedGemma model to use (default: "google/medgemma-4b-it")
+            device: Device to run model on - "cuda" or "cpu" (default: "cuda")
+            dtype: Data type for model weights - bfloat16 recommended for efficiency (default: torch.bfloat16)
+            cache_dir: Directory to cache downloaded models (default: None)
+            load_in_4bit: Whether to load model in 4-bit quantization for memory efficiency (default: True)
+            **kwargs: Additional arguments passed to the model pipeline
+
+        Raises:
+            RuntimeError: If model initialization fails (e.g., insufficient GPU memory)
+        """
+        # Re-initialization guard
+        if hasattr(self, 'pipe') and self.pipe is not None:
+            return
+
+        self.device = device if device and torch.cuda.is_available() else "cpu"
+        self.dtype = dtype
+        self.cache_dir = cache_dir
+
+        # Setup model configuration
+        model_kwargs = {
+            "torch_dtype": self.dtype,
+        }
+
+        if cache_dir:
+            model_kwargs["cache_dir"] = cache_dir
+
+        # Handle device mapping and quantization
+        pipeline_kwargs = {
+            "model": model_name,
+            "model_kwargs": model_kwargs,
+            "trust_remote_code": True,
+            "use_cache": True,
+        }
+
+        if load_in_4bit:
+            model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
+            model_kwargs["device_map"] = {"": self.device}
+
+        try:
+            self.pipe = pipeline("image-text-to-text", **pipeline_kwargs)
+        except Exception as e:
+            raise RuntimeError(f"Failed to initialize MedGemma pipeline: {str(e)}")
+
+    def _prepare_messages(
+        self, image_paths: List[str], prompt: str, system_prompt: str
+    ) -> Tuple[List[Dict[str, Any]], List[Image.Image]]:
+        """Prepare chat messages in the format expected by MedGemma.
+
+        Converts image paths to PIL Image objects and formats them into the
+        chat message structure that MedGemma expects for multimodal input.
+
+        Args:
+            image_paths: List of file paths to medical images
+            prompt: User's question or instruction about the images
+            system_prompt: System context message to set the model's role
+
+        Returns:
+            Tuple containing:
+            - List of formatted chat messages for MedGemma
+            - List of loaded PIL Image objects
+
+        Raises:
+            FileNotFoundError: If any image file cannot be found
+        """
+        images = []
+        for path in image_paths:
+            if not Path(path).is_file():
+                raise FileNotFoundError(f"Image file not found: {path}")
+
+            image = Image.open(path)
+            if image.mode != "RGB":
+                image = image.convert("RGB")
+            images.append(image)
+
+        # Create messages in chat format
+        messages = [
+            {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
+            {
+                "role": "user",
+                "content": [{"type": "text", "text": prompt}]
+                + [{"type": "image", "image": img} for img in images],
+            },
+        ]
+
+        return messages, images
+
+    def _generate_response(self, messages: List[Dict[str, Any]], max_new_tokens: int) -> str:
+        """Generate response using MedGemma pipeline.
+
+        Processes the formatted messages through the MedGemma model to generate
+        a medical analysis response.
+
+        Args:
+            messages: Formatted chat messages with images and text
+            max_new_tokens: Maximum number of tokens to generate in response
+
+        Returns:
+            Generated response text from MedGemma model
+        """
+        # Generate using pipeline
+        output = self.pipe(
+            text=messages,
+            max_new_tokens=max_new_tokens,
+            do_sample=False,
+        )
+
+        # Extract generated text from pipeline output
+        if (
+            isinstance(output, list)
+            and output
+            and isinstance(output[0].get("generated_text"), list)
+        ):
+            generated_text = output[0]["generated_text"]
+            if generated_text:
+                return generated_text[-1].get("content", "").strip()
+
+        return "No response generated"
+
+    def _create_error_response(
+        self,
+        image_paths: List[str],
+        prompt: str,
+        error_message: str,
+        error_type: str,
+        error_details: str,
+    ) -> Dict[str, Any]:
+        """Create standardized error response metadata.
+
+        Generates consistent error metadata structure for logging and debugging
+        purposes across different error scenarios.
+
+        Args:
+            image_paths: List of image paths that were being processed
+            prompt: User prompt that was being processed
+            error_message: Human-readable error message
+            error_type: Categorization of the error (e.g., "memory_error", "file_not_found")
+            error_details: Detailed technical error information
+
+        Returns:
+            Dictionary containing standardized error metadata
+        """
+        return {
+            "image_paths": image_paths,
+            "prompt": prompt,
+            "analysis_status": "failed",
+            "error_type": error_type,
+            "error_details": error_details,
+        }
+
+    async def aget_response(self, image_paths: List[str], prompt: str, system_prompt: str, max_new_tokens: int) -> str:
+        """Async method to get response from MedGemma model.
+
+        Main entry point for generating medical analysis responses. Handles
+        the complete pipeline from image loading to response generation
+        in an asynchronous manner.
+
+        Args:
+            image_paths: List of file paths to medical images
+            prompt: User's question or instruction about the images
+            system_prompt: System context message to set the model's role
+            max_new_tokens: Maximum number of tokens to generate in response
+
+        Returns:
+            Generated medical analysis response as a string
+
+        Raises:
+            FileNotFoundError: If any image file cannot be found
+            RuntimeError: If model inference fails
+        """
+        loop = asyncio.get_event_loop()
+        messages, _ = await loop.run_in_executor(None, self._prepare_messages, image_paths, prompt, system_prompt)
+
+        def _generate():
+            return self._generate_response(messages, max_new_tokens)
+
+        return await loop.run_in_executor(None, _generate)
+
+# FastAPI Application
+app = FastAPI(
+    title="MedGemma VQA API",
+    description="API for medical visual question answering using Google's MedGemma model."
+)
+
+medgemma_model: Optional[MedGemmaModel] = None
+
+@app.on_event("startup")
+async def startup_event():
+    """Load the MedGemma model at application startup.
+
+    This function is called when the FastAPI application starts up.
+    It initializes the MedGemma model as a global singleton instance,
+    ensuring the model is loaded and ready to handle requests.
+
+    The model is loaded with default settings optimized for medical
+    image analysis, including 4-bit quantization for memory efficiency.
+
+    Raises:
+        SystemExit: If model loading fails, the application will exit
+        to prevent serving requests with an unavailable model.
+    """
+    global medgemma_model
+    try:
+        medgemma_model = MedGemmaModel()
+        print("MedGemma model loaded successfully.")
+    except RuntimeError as e:
+        print(f"Error loading MedGemma model: {e}")
+        sys.exit(1)
+
+@app.post("/analyze-images/",
+    response_model=VQAResponse,
+    responses={
+        500: {"model": ErrorResponse, "description": "Internal server error or model inference failure"},
+        404: {"model": ErrorResponse, "description": "Image file not found"},
+        400: {"description": "Invalid request format or unsupported image type"},
+        503: {"description": "Model not available or not loaded"}
+    },
+    summary="Analyze one or more medical images",
+    description="Upload medical images and receive AI-powered analysis using Google's MedGemma model.")
+async def analyze_images(
+    images: List[UploadFile] = File(..., description="List of medical image files to analyze (JPG or PNG)."),
+    prompt: str = Form(..., description="Question or instruction about the medical images."),
+    system_prompt: Optional[str] = Form("You are an expert radiologist.", description="System prompt to set the context for the model."),
+    max_new_tokens: int = Form(100, description="Maximum number of tokens to generate in the response.")
+):
+    """Analyze medical images using MedGemma AI model.
+
+    This endpoint accepts one or more medical images along with a prompt
+    and returns AI-generated medical analysis.
+
+    The endpoint handles the complete pipeline:
+    1. Validates uploaded image files
+    2. Saves images temporarily to disk
+    3. Processes images through MedGemma model
+    4. Returns structured analysis with metadata
+    5. Cleans up temporary files
+
+    Args:
+        images: List of uploaded image files (JPG/PNG format)
+        prompt: Medical question or instruction about the images
+        system_prompt: Context setting for the AI model (default: radiologist role)
+        max_new_tokens: Maximum response length (default: 100)
+
+    Returns:
+        VQAResponse: Contains the AI-generated analysis and request metadata
+
+    Raises:
+        HTTPException 400: Invalid image format or request structure
+        HTTPException 404: Image file not found during processing
+        HTTPException 500: Model inference error or memory issues
+        HTTPException 503: Model not available for processing
+    """
+    # Check if model is available
+    if medgemma_model is None or medgemma_model.pipe is None:
+        raise HTTPException(status_code=503, detail="Model is not available. Please try again later.")
+
+    # Process uploaded images
+    image_paths = []
+    for image in images:
+        # Validate image format
+        if image.content_type not in ["image/jpeg", "image/png"]:
+            raise HTTPException(status_code=400, detail=f"Unsupported image format: {image.content_type}. Only JPG and PNG are supported.")
+
+        # Generate unique filename to avoid conflicts
+        unique_filename = f"{uuid.uuid4()}_{image.filename}"
+        file_path = os.path.join(UPLOAD_DIR, unique_filename)
+
+        try:
+            # Save uploaded image to disk
+            with open(file_path, "wb") as buffer:
+                buffer.write(await image.read())
+            image_paths.append(file_path)
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=f"Failed to save uploaded image: {str(e)}")
+
+    try:
+        # Generate AI analysis
+        response_text = await medgemma_model.aget_response(image_paths, prompt, system_prompt, max_new_tokens)
+
+        # Prepare success response
+        metadata = {
+            "image_paths": image_paths,
+            "prompt": prompt,
+            "system_prompt": system_prompt,
+            "max_new_tokens": max_new_tokens,
+            "num_images": len(image_paths),
+            "analysis_status": "completed",
+        }
+        return VQAResponse(response=response_text, metadata=metadata)
+
+    except FileNotFoundError as e:
+        raise HTTPException(status_code=404, detail=f"Image file not found: {str(e)}")
+    except torch.cuda.OutOfMemoryError as e:
+        error_message = "GPU memory exhausted. Try reducing image resolution or max_new_tokens."
+        metadata = medgemma_model._create_error_response(
+            image_paths, prompt, error_message, "memory_error", str(e)
+        )
+        raise HTTPException(status_code=500, detail=error_message)
+    except Exception as e:
+        traceback.print_exc()
+        metadata = medgemma_model._create_error_response(
+            image_paths, prompt, f"Analysis failed: {str(e)}", "general_error", str(e)
+        )
+        raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
+    finally:
+        # Clean up temporary image files
+        for path in image_paths:
+            try:
+                os.remove(path)
+            except OSError:
+                pass
+
+if __name__ == "__main__":
+    """Launch the MedGemma VQA API server.
+
+    Starts the FastAPI application with uvicorn server, binding to all
+    network interfaces on port 8002.
+    """
+    uvicorn.run(app, host="0.0.0.0", port=8002)
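For reference, `_prepare_messages` above builds the chat structure that the `image-text-to-text` pipeline expects; a small sketch of that shape, with a placeholder image standing in for a loaded medical image:

```python
from PIL import Image

# Illustrative only: the shape of the messages built by _prepare_messages.
img = Image.new("RGB", (896, 896))  # stand-in for a loaded medical image
messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are an expert radiologist."}]},
    {
        "role": "user",
        "content": [{"type": "text", "text": "Classify this X-ray."}]
        + [{"type": "image", "image": img}],
    },
]
```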
medrax/tools/vqa/medgemma/medgemma_client.py ADDED
@@ -0,0 +1,290 @@
+import os
+from typing import Any, Dict, List, Optional, Tuple, Type
+
+import httpx
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForToolRun,
+    CallbackManagerForToolRun,
+)
+from langchain_core.tools import BaseTool
+from pydantic import BaseModel, Field
+
+class MedGemmaVQAInput(BaseModel):
+    """Input schema for the MedGemma VQA Tool. Only supports JPG or PNG images."""
+    image_paths: List[str] = Field(
+        ...,
+        description="List of paths to medical image files to analyze, only supports JPG or PNG images",
+    )
+    prompt: str = Field(..., description="Question or instruction about the medical images")
+    system_prompt: Optional[str] = Field(
+        "You are an expert radiologist.",
+        description="System prompt to set the context for the model",
+    )
+    max_new_tokens: int = Field(
+        300, description="Maximum number of tokens to generate in the response"
+    )
+
+class MedGemmaAPIClientTool(BaseTool):
+    """Medical visual question answering tool using Google's MedGemma 4B model via API.
+
+    MedGemma is a specialized multimodal AI model trained on medical images and text.
+    It provides expert-level analysis for chest X-rays, dermatology images,
+    ophthalmology images, and histopathology slides.
+
+    Key capabilities:
+    - Medical image classification and analysis across multiple modalities
+    - Visual question answering for radiology, dermatology, pathology, ophthalmology
+    - Clinical reasoning and medical knowledge integration
+    - Multi-modal medical understanding (text + images)
+    - Support for up to 128K context length
+
+    Performance:
+    - Full precision (bfloat16): ~8GB VRAM, recommended for medical applications
+    - 4-bit quantization (default): Available but may affect quality on some systems
+    """
+
+    name: str = "medgemma_medical_vqa"
+    description: str = (
+        "Advanced medical visual question answering tool using Google's MedGemma 4B instruction-tuned model via API. "
+        "Specialized for comprehensive medical image analysis across multiple modalities including chest X-rays, "
+        "dermatology images, ophthalmology images, and histopathology slides. Provides expert-level medical "
+        "reasoning, diagnosis assistance, and detailed image interpretation with radiologist-level expertise. "
+        "Input: List of medical image paths and medical question/prompt with optional custom system prompt. "
+        "Output: Comprehensive medical analysis and answers based on visual content with detailed reasoning. "
+        "Supports multi-image analysis, comparative studies, and complex medical reasoning tasks. "
+        "Model handles images up to 896x896 resolution and supports context up to 128K tokens."
+    )
+    args_schema: Type[BaseModel] = MedGemmaVQAInput
+    return_direct: bool = True
+
+    # API configuration
+    api_url: str  # The URL of the running FastAPI service
+
+    def __init__(self, api_url: str, **kwargs: Any):
+        """Initialize the MedGemmaAPIClientTool.
+
+        Args:
+            api_url: The URL of the running MedGemma FastAPI service
+            **kwargs: Additional arguments passed to BaseTool
+        """
+        super().__init__(api_url=api_url, **kwargs)
+
+    def _prepare_request_data(
+        self, image_paths: List[str], prompt: str, system_prompt: str, max_new_tokens: int
+    ) -> Tuple[List, Dict, List]:
+        """Prepare multipart form data for API request.
+
+        Args:
+            image_paths: List of paths to medical images
+            prompt: Question or instruction about the images
+            system_prompt: System context for the model
+            max_new_tokens: Maximum number of tokens to generate
+
+        Returns:
+            Tuple of files list, data dictionary, and opened file handles
+            (always empty here, since each file is read eagerly and closed at once)
+        """
+        files_to_send = []
+        opened_files = []
+
+        for path in image_paths:
+            # Pick the MIME type from the file extension; the API accepts JPG and PNG only
+            media_type = "image/png" if path.lower().endswith(".png") else "image/jpeg"
+            with open(path, "rb") as f:
+                files_to_send.append(("images", (os.path.basename(path), f.read(), media_type)))
+
+        data = {
+            "prompt": prompt,
+            "system_prompt": system_prompt,
+            "max_new_tokens": max_new_tokens,
+        }
+
+        return files_to_send, data, opened_files
+
+    def _create_error_response(
+        self,
+        image_paths: List[str],
+        prompt: str,
+        error_message: str,
+        error_type: str,
+        error_details: str,
+    ) -> Tuple[Dict[str, Any], Dict]:
+        """Create standardized error response.
+
+        Args:
+            image_paths: List of image paths
+            prompt: User prompt
+            error_message: Human-readable error message
+            error_type: Type of error
+            error_details: Detailed error information
+
+        Returns:
+            Tuple of error output and metadata
+        """
+        output = {"error": error_message}
+        metadata = {
+            "image_paths": image_paths,
+            "prompt": prompt,
+            "analysis_status": "failed",
+            "error_type": error_type,
+            "error_details": error_details,
+        }
+        return output, metadata
+
+    def _run(
+        self,
+        image_paths: List[str],
+        prompt: str,
+        system_prompt: str = "You are an expert radiologist.",
+        max_new_tokens: int = 300,
+        run_manager: Optional[CallbackManagerForToolRun] = None,
+    ) -> Tuple[Dict[str, Any], Dict]:
+        """Execute medical visual question answering via API.
+
+        Args:
+            image_paths: List of paths to medical images
+            prompt: Question or instruction about the images
+            system_prompt: System context for the model
+            max_new_tokens: Maximum number of tokens to generate
+            run_manager: Optional callback manager
+
+        Returns:
+            Tuple of output dictionary and metadata
+        """
+        # httpx is a modern HTTP client that supports sync and async
+        timeout_config = httpx.Timeout(300.0, connect=10.0)
+        client = httpx.Client(timeout=timeout_config)
+
+        try:
+            # Prepare the multipart form data
+            files_to_send, data, opened_files = self._prepare_request_data(
+                image_paths, prompt, system_prompt, max_new_tokens
+            )
+
+            response = client.post(
+                f"{self.api_url}/analyze-images/",
+                data=data,
+                files=files_to_send,
+            )
+            response.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx)
+
+            response_data = response.json()
+            output = {"response": response_data["response"]}
+
+            metadata = {
+                "image_paths": image_paths,
+                "prompt": prompt,
+                "system_prompt": system_prompt,
+                "max_new_tokens": max_new_tokens,
+                "num_images": len(image_paths),
+                "analysis_status": "completed",
+            }
+
+            return output, metadata
+
+        except httpx.TimeoutException as e:
+            return self._create_error_response(
+                image_paths,
+                prompt,
+                f"Error: The request to the MedGemma API timed out after {timeout_config.read} seconds. The server might be overloaded or the model is taking too long to load. Try again later.",
+                "timeout_error",
+                str(e)
+            )
+        except httpx.ConnectError as e:
+            return self._create_error_response(
+                image_paths,
+                prompt,
+                f"Error: Could not connect to the MedGemma API. Check if the server address '{self.api_url}' is correct and running.",
+                "connection_error",
+                str(e)
+            )
+        except httpx.HTTPStatusError as e:
+            return self._create_error_response(
+                image_paths,
+                prompt,
+                f"Error: The MedGemma API returned an error (Status {e.response.status_code}): {e.response.text}",
+                "http_error",
+                f"Status {e.response.status_code}: {e.response.text}"
+            )
+        except Exception as e:
+            return self._create_error_response(
+                image_paths,
+                prompt,
+                f"An unexpected error occurred in the MedGemma client tool: {str(e)}",
+                "general_error",
+                str(e)
+            )
+        finally:
+            client.close()
+            # Ensure all opened files are closed (a no-op today: files are read eagerly above)
+            if 'opened_files' in locals():
+                for f in opened_files:
+                    f.close()
+
+    async def _arun(
+        self,
+        image_paths: List[str],
+        prompt: str,
+        system_prompt: str = "You are an expert radiologist.",
+        max_new_tokens: int = 300,
+        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+    ) -> Tuple[Dict[str, Any], Dict]:
+        """Execute the tool asynchronously."""
+        async with httpx.AsyncClient() as client:
+            try:
+                # Prepare the multipart form data
+                files_to_send, data, opened_files = self._prepare_request_data(
+                    image_paths, prompt, system_prompt, max_new_tokens
+                )
+
+                response = await client.post(
+                    f"{self.api_url}/analyze-images/",
+                    data=data,
+                    files=files_to_send,
+                    timeout=120.0
+                )
+                response.raise_for_status()
+
+                response_data = response.json()
+                output = {"response": response_data["response"]}
+
+                metadata = {
+                    "image_paths": image_paths,
+                    "prompt": prompt,
+                    "system_prompt": system_prompt,
+                    "max_new_tokens": max_new_tokens,
+                    "num_images": len(image_paths),
+                    "analysis_status": "completed",
+                }
+
+                return output, metadata
+
+            except httpx.HTTPStatusError as e:
+                return self._create_error_response(
+                    image_paths,
+                    prompt,
+                    f"Error calling MedGemma API: {e.response.status_code} - {e.response.text}",
+                    "http_error",
+                    f"Status {e.response.status_code}: {e.response.text}"
+                )
+            except Exception as e:
+                return self._create_error_response(
+                    image_paths,
+                    prompt,
+                    f"An unexpected error occurred: {str(e)}",
+                    "general_error",
+                    str(e)
+                )
+            finally:
+                # Ensure all opened files are closed (a no-op today: files are read eagerly)
+                if 'opened_files' in locals():
+                    for f in opened_files:
+                        f.close()
+
+#TODO: delete this
+if __name__ == "__main__":
+    tool = MedGemmaAPIClientTool(api_url="http://kn045:8002")
+    output, metadata = tool._run(
+        image_paths=["/home/emxie/scratch/MedRAX2/demo/chest/pneumonia1.jpg"],
+        prompt="Classify the xray",
+        system_prompt="You are a radiologist.",
+        max_new_tokens=300
+    )
+    print(output)
+    print(metadata)
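The async path mirrors the sync one; a minimal sketch of driving `_arun` from an event loop (server URL and image path hypothetical):

```python
import asyncio

from medrax.tools.vqa import MedGemmaAPIClientTool

async def main():
    tool = MedGemmaAPIClientTool(api_url="http://127.0.0.1:8002")
    output, metadata = await tool._arun(
        image_paths=["demo/chest/example.jpg"],  # hypothetical path
        prompt="Classify the X-ray",
    )
    print(output, metadata["analysis_status"])

asyncio.run(main())
```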
medrax/tools/vqa/medgemma/medgemma_requirements.txt ADDED
@@ -0,0 +1,55 @@
+accelerate==1.9.0
+annotated_types==0.7.0+computecanada
+anyio==4.9.0+computecanada
+bitsandbytes==0.46.0+computecanada
+certifi==2025.7.14+computecanada
+charset_normalizer==3.4.2+computecanada
+click==8.2.1+computecanada
+fastapi==0.116.1+computecanada
+filelock==3.18.0+computecanada
+fsspec==2025.7.0+computecanada
+h11==0.16.0+computecanada
+hf_xet==1.1.3+computecanada
+httpcore==1.0.9+computecanada
+httpx==0.28.1+computecanada
+huggingface-hub==0.34.3
+idna==3.10+computecanada
+inquirerpy==0.3.4+computecanada
+jinja2==3.1.6+computecanada
+jsonpatch==1.33+computecanada
+jsonpointer==3.0.0+computecanada
+langchain-core==0.3.72
+langsmith==0.4.8+computecanada
+MarkupSafe==2.1.5+computecanada
+mpmath==1.3.0+computecanada
+networkx==3.5+computecanada
+numpy==2.2.2+computecanada
+orjson==3.10.5+computecanada
+packaging==25.0+computecanada
+pfzy==0.3.4+computecanada
+pillow==11.1.0+computecanada
+prompt_toolkit==3.0.51+computecanada
+psutil==6.1.1+computecanada
+pydantic==2.11.7+computecanada
+pydantic_core==2.33.2+computecanada
+python_multipart==0.0.20+computecanada
+PyYAML==6.0.2+computecanada
+regex==2024.11.6+computecanada
+requests==2.32.4+computecanada
+requests_toolbelt==1.0.0+computecanada
+safetensors==0.5.3+computecanada
+sniffio==1.3.1+computecanada
+sshuttle==1.3.1
+starlette==0.47.2
+sympy==1.14.0+computecanada
+tenacity==9.1.2+computecanada
+tokenizers==0.21.1+computecanada
+torch==2.7.1+computecanada
+tqdm==4.67.1+computecanada
+transformers==4.54.1
+typing_extensions==4.14.1+computecanada
+typing_inspection==0.4.1+computecanada
+urllib3==2.5.0+computecanada
+uvicorn==0.35.0+computecanada
+wcwidth==0.2.13+computecanada
+zstandard==0.23.0+computecanada
medrax/tools/vqa/medgemma/medgemma_setup.py ADDED
@@ -0,0 +1,64 @@
+import os
+from pathlib import Path
+import subprocess
+import venv
+
+def setup_medgemma_env():
+    """Set up MedGemma virtual environment and launch the FastAPI service.
+
+    This function performs the following steps:
+    1. Creates a virtual environment for MedGemma if it doesn't exist
+    2. Installs MedGemma-specific dependencies from requirements.txt
+    3. Launches the MedGemma FastAPI service in the isolated environment
+
+    Returns:
+        None: Launches MedGemma service as a background process
+
+    Raises:
+        subprocess.CalledProcessError: If pip installation fails
+        FileNotFoundError: If required files are missing
+        OSError: If virtual environment creation fails
+    """
+    # Get the directory containing this script
+    current_dir = Path(__file__).resolve().parent
+
+    # Define paths for MedGemma components
+    medgemma_path = current_dir / "medgemma.py"
+    requirements_path = current_dir / "medgemma_requirements.txt"
+    env_dir = current_dir / "medgemma_env"
+
+    # Determine executable paths based on operating system
+    if os.name == "nt":  # Windows
+        pip_executable = env_dir / "Scripts" / "pip"
+        python_executable = env_dir / "Scripts" / "python"
+    else:  # Unix/Linux/macOS
+        pip_executable = env_dir / "bin" / "pip"
+        python_executable = env_dir / "bin" / "python"
+
+    # Create virtual environment if it doesn't exist
+    if not env_dir.exists():
+        print("Creating MedGemma virtual environment...")
+        venv.create(env_dir, with_pip=True)
+
+    # Ensure the environment exists before using its executables
+    if not env_dir.exists():
+        raise RuntimeError("Failed to create MedGemma virtual environment")
+
+    # Install MedGemma dependencies
+    print("Installing MedGemma dependencies...")
+    subprocess.check_call([
+        str(pip_executable),
+        "install",
+        "-r",
+        str(requirements_path)
+    ])
+
+    # Launch MedGemma FastAPI service
+    print("Launching MedGemma FastAPI service...")
+    subprocess.Popen([
+        str(python_executable),
+        str(medgemma_path)
+    ])
+    # Note: stdout and stderr redirection commented out for debugging
+    # stdout=subprocess.DEVNULL,
+    # stderr=subprocess.DEVNULL,
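`setup_medgemma_env` returns as soon as the service process is spawned, while model loading can take minutes; a sketch of waiting for readiness before issuing requests (polling FastAPI's default `/docs` page is an assumption, since the service defines no dedicated health endpoint):

```python
import time

import httpx

from medrax.tools.vqa import setup_medgemma_env

setup_medgemma_env()

# Poll until the FastAPI app responds; /docs is served by FastAPI by default,
# and uvicorn only starts serving after the startup event has loaded the model.
for _ in range(60):
    try:
        if httpx.get("http://127.0.0.1:8002/docs", timeout=5.0).status_code == 200:
            break
    except httpx.HTTPError:
        pass
    time.sleep(10)
```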
medrax/tools/vqa/xray_vqa.py ADDED
@@ -0,0 +1,186 @@
+from typing import Dict, List, Optional, Tuple, Type, Any
+from pathlib import Path
+from pydantic import BaseModel, Field
+
+import torch
+import transformers
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForToolRun,
+    CallbackManagerForToolRun,
+)
+from langchain_core.tools import BaseTool
+
+
+class XRayVQAToolInput(BaseModel):
+    """Input schema for the CheXagent Tool."""
+
+    image_paths: List[str] = Field(
+        ..., description="List of paths to chest X-ray images to analyze"
+    )
+    prompt: str = Field(..., description="Question or instruction about the chest X-ray images")
+    max_new_tokens: int = Field(
+        512, description="Maximum number of tokens to generate in the response"
+    )
+
+
+class CheXagentXRayVQATool(BaseTool):
+    """Tool that leverages CheXagent for comprehensive chest X-ray analysis."""
+
+    name: str = "chexagent_xray_vqa"
+    description: str = (
+        "A versatile tool for analyzing chest X-rays. "
+        "Can perform multiple tasks including: visual question answering, report generation, "
+        "abnormality detection, comparative analysis, anatomical description, "
+        "and clinical interpretation. Input should be paths to X-ray images "
+        "and a natural language prompt describing the analysis needed."
+    )
+    args_schema: Type[BaseModel] = XRayVQAToolInput
+    return_direct: bool = True
+    cache_dir: Optional[str] = None
+    device: Optional[str] = None
+    dtype: torch.dtype = torch.bfloat16
+    tokenizer: Optional[AutoTokenizer] = None
+    model: Optional[AutoModelForCausalLM] = None
+
+    def __init__(
+        self,
+        model_name: str = "StanfordAIMI/CheXagent-2-3b",
+        device: Optional[str] = "cuda",
+        dtype: torch.dtype = torch.bfloat16,
+        cache_dir: Optional[str] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Initialize the CheXagentXRayVQATool.
+
+        Args:
+            model_name: Name of the CheXagent model to use
+            device: Device to run model on (cuda/cpu)
+            dtype: Data type for model weights
+            cache_dir: Directory to cache downloaded models
+            **kwargs: Additional arguments
+        """
+        super().__init__(**kwargs)
+
+        # Dangerous code, but works for now
+        import transformers
+
+        original_transformers_version = transformers.__version__
+        transformers.__version__ = "4.40.0"
+
+        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        self.dtype = dtype
+        self.cache_dir = cache_dir
+
+        # Load tokenizer and model
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_name,
+            trust_remote_code=True,
+            cache_dir=cache_dir,
+        )
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            device_map=self.device,
+            trust_remote_code=True,
+            cache_dir=cache_dir,
+        )
+        self.model = self.model.to(dtype=self.dtype)
+        self.model.eval()
+
+        transformers.__version__ = original_transformers_version
+
+    def _generate_response(self, image_paths: List[str], prompt: str, max_new_tokens: int) -> str:
+        """Generate response using CheXagent model.
+
+        Args:
+            image_paths: List of paths to chest X-ray images
+            prompt: Question or instruction about the images
+            max_new_tokens: Maximum number of tokens to generate
+        Returns:
+            str: Model's response
+        """
+        query = self.tokenizer.from_list_format(
+            [*[{"image": path} for path in image_paths], {"text": prompt}]
+        )
+        conv = [
+            {"from": "system", "value": "You are a helpful assistant."},
+            {"from": "human", "value": query},
+        ]
+        input_ids = self.tokenizer.apply_chat_template(
+            conv, add_generation_prompt=True, return_tensors="pt"
+        ).to(device=self.device)
+
+        # Run inference
+        with torch.inference_mode():
+            output = self.model.generate(
+                input_ids,
+                do_sample=False,
+                num_beams=1,
+                temperature=1.0,
+                top_p=1.0,
+                use_cache=True,
+                max_new_tokens=max_new_tokens,
+            )[0]
+        response = self.tokenizer.decode(output[input_ids.size(1) : -1])
+
+        return response
+
+    def _run(
+        self,
+        image_paths: List[str],
+        prompt: str,
+        max_new_tokens: int = 512,
+        run_manager: Optional[CallbackManagerForToolRun] = None,
+    ) -> Tuple[Dict[str, Any], Dict]:
+        """Execute the chest X-ray analysis.
+
+        Args:
+            image_paths: List of paths to chest X-ray images
+            prompt: Question or instruction about the images
+            max_new_tokens: Maximum number of tokens to generate
+            run_manager: Optional callback manager
+
+        Returns:
+            Tuple[Dict[str, Any], Dict]: Output dictionary and metadata dictionary
+        """
+        try:
+            # Verify image paths
+            for path in image_paths:
+                if not Path(path).is_file():
+                    raise FileNotFoundError(f"Image file not found: {path}")
+
+            response = self._generate_response(image_paths, prompt, max_new_tokens)
+
+            output = {
+                "response": response,
+            }
+
+            metadata = {
+                "image_paths": image_paths,
+                "prompt": prompt,
+                "max_new_tokens": max_new_tokens,
+                "analysis_status": "completed",
+            }
+
+            return output, metadata
+
+        except Exception as e:
+            output = {"error": str(e)}
+            metadata = {
+                "image_paths": image_paths,
+                "prompt": prompt,
+                "max_new_tokens": max_new_tokens,
+                "analysis_status": "failed",
+                "error_details": str(e),
+            }
+            return output, metadata
+
+    async def _arun(
+        self,
+        image_paths: List[str],
+        prompt: str,
+        max_new_tokens: int = 512,
+        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+    ) -> Tuple[Dict[str, Any], Dict]:
+        """Async version of _run."""
+        return self._run(image_paths, prompt, max_new_tokens)
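A matching usage sketch for the renamed CheXagent tool (image path hypothetical; the constructor downloads StanfordAIMI/CheXagent-2-3b on first use):

```python
from medrax.tools.vqa import CheXagentXRayVQATool

tool = CheXagentXRayVQATool(cache_dir="model-weights", device="cuda")
output, metadata = tool._run(
    image_paths=["demo/chest/example.jpg"],  # hypothetical path
    prompt="Generate a findings section for this chest X-ray.",
)
print(output["response"])
```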