Adibvafa commited on
Commit
bc86327
·
2 Parent(s): e116685 aff69d7

Fix merge conflicts

Browse files
.env.example CHANGED
@@ -6,4 +6,5 @@ GOOGLE_SEARCH_ENGINE_ID=
6
  OPENROUTER_API_KEY=
7
  OPENROUTER_BASE_URL=
8
  COHERE_API_KEY=
9
- PINECONE_API_KEY=
 
 
6
  OPENROUTER_API_KEY=
7
  OPENROUTER_BASE_URL=
8
  COHERE_API_KEY=
9
+ PINECONE_API_KEY=
10
+ MEDGEMMA_API_URL=
.gitignore CHANGED
@@ -179,4 +179,6 @@ model-weights/
179
 
180
  .DS_Store
181
 
182
- benchmarking/data/
 
 
 
179
 
180
  .DS_Store
181
 
182
+ benchmarking/data/
183
+ model_cache/
184
+ medgemma/
README.md CHANGED
@@ -22,12 +22,14 @@ MedRAX is built on a robust technical foundation:
22
 
23
  ### Integrated Tools
24
  - **Visual QA**: Utilizes CheXagent and LLaVA-Med for complex visual understanding and medical reasoning
 
25
  - **Segmentation**: Employs MedSAM2 (advanced medical image segmentation) and PSPNet model trained on ChestX-Det for precise anatomical structure identification
26
  - **Grounding**: Uses Maira-2 for localizing specific findings in medical images
27
  - **Report Generation**: Implements SwinV2 Transformer trained on CheXpert Plus for detailed medical reporting
28
  - **Disease Classification**: Leverages DenseNet-121 from TorchXRayVision for detecting 18 pathology classes
29
  - **X-ray Generation**: Utilizes RoentGen for synthetic CXR generation
30
  - **Web Browser**: Provides web search capabilities and URL content retrieval using Google Custom Search API
 
31
  - **Python Sandbox**: Executes Python code in a secure, stateful sandbox environment using `langchain-sandbox` and Pyodide. Supports custom data analysis, calculations, and dynamic package installations. Pre-configured with medical analysis packages including pandas, numpy, pydicom, SimpleITK, scikit-image, Pillow, scikit-learn, matplotlib, seaborn, and openpyxl. **Requires Deno runtime.**
32
  - **Utilities**: Includes DICOM processing, visualization tools, and custom plotting capabilities
33
  <br><br>
@@ -130,6 +132,10 @@ PINECONE_API_KEY=
130
  # Requires Google Custom Search API credentials.
131
  GOOGLE_SEARCH_API_KEY=
132
  GOOGLE_SEARCH_ENGINE_ID=
 
 
 
 
133
  ```
134
 
135
  ### Getting Started
@@ -159,6 +165,7 @@ selected_tools = [
159
  "ChestXRaySegmentationTool",
160
  "PythonSandboxTool", # Python code execution
161
  "WebBrowserTool", # Web search and URL access
 
162
  # Add or remove tools as needed
163
  ]
164
 
@@ -174,17 +181,10 @@ agent, tools_dict = initialize_agent(
174
 
175
  The following tools will automatically download their model weights when initialized:
176
 
177
- ### Classification Tools
178
  ```python
179
  # TorchXRayVision-based classifier (original)
180
  TorchXRayVisionClassifierTool(device=device)
181
-
182
- # ArcPlus SwinTransformer-based classifier (new)
183
- ArcPlusClassifierTool(
184
- model_path="/path/to/Ark6_swinLarge768_ep50.pth.tar", # Optional
185
- num_classes=18, # Default
186
- device=device
187
- )
188
  ```
189
 
190
  ### Segmentation Tool
@@ -232,6 +232,21 @@ XRayVQATool(
232
  ```
233
  - CheXagent weights download automatically
234
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  ### MedSAM2 Tool
236
  ```python
237
  MedSAM2Tool(
@@ -263,6 +278,7 @@ No additional model weights required:
263
  ImageVisualizerTool()
264
  DicomProcessorTool(temp_dir=temp_dir)
265
  WebBrowserTool() # Requires Google Search API credentials
 
266
  ```
267
  <br>
268
 
@@ -281,6 +297,25 @@ ChestXRayGeneratorTool(
281
  2. Place weights in `{model_dir}/roentgen`
282
  3. Optional tool, can be excluded if not needed
283
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
  ### Knowledge Base Setup (MedicalRAGTool)
285
 
286
  The `MedicalRAGTool` uses a Pinecone vector database to store and retrieve medical knowledge. To use this tool, you need to set up a Pinecone account and a Cohere account.
@@ -383,6 +418,8 @@ If you are running a local LLM using frameworks like [Ollama](https://ollama.com
383
 
384
  **WebBrowserTool**: Requires Google Custom Search API credentials, which can be set in the `.env` file.
385
 
 
 
386
  **PythonSandboxTool**: Requires Deno runtime installation:
387
  ```bash
388
  # Verify Deno is installed
 
22
 
23
  ### Integrated Tools
24
  - **Visual QA**: Utilizes CheXagent and LLaVA-Med for complex visual understanding and medical reasoning
25
+ - **MedGemma VQA**: Advanced medical visual question answering using Google's MedGemma 4B model for comprehensive medical image analysis across multiple modalities
26
  - **Segmentation**: Employs MedSAM2 (advanced medical image segmentation) and PSPNet model trained on ChestX-Det for precise anatomical structure identification
27
  - **Grounding**: Uses Maira-2 for localizing specific findings in medical images
28
  - **Report Generation**: Implements SwinV2 Transformer trained on CheXpert Plus for detailed medical reporting
29
  - **Disease Classification**: Leverages DenseNet-121 from TorchXRayVision for detecting 18 pathology classes
30
  - **X-ray Generation**: Utilizes RoentGen for synthetic CXR generation
31
  - **Web Browser**: Provides web search capabilities and URL content retrieval using Google Custom Search API
32
+ - **DuckDuckGo Search**: Offers privacy-focused web search capabilities using DuckDuckGo search engine for medical research, fact-checking, and accessing current medical information without API keys
33
  - **Python Sandbox**: Executes Python code in a secure, stateful sandbox environment using `langchain-sandbox` and Pyodide. Supports custom data analysis, calculations, and dynamic package installations. Pre-configured with medical analysis packages including pandas, numpy, pydicom, SimpleITK, scikit-image, Pillow, scikit-learn, matplotlib, seaborn, and openpyxl. **Requires Deno runtime.**
34
  - **Utilities**: Includes DICOM processing, visualization tools, and custom plotting capabilities
35
  <br><br>
 
132
  # Requires Google Custom Search API credentials.
133
  GOOGLE_SEARCH_API_KEY=
134
  GOOGLE_SEARCH_ENGINE_ID=
135
+
136
+ # MedGemma VQA Tool (Optional)
137
+ # URL for the MedGemma FastAPI service
138
+ MEDGEMMA_API_URL=
139
  ```
140
 
141
  ### Getting Started
 
165
  "ChestXRaySegmentationTool",
166
  "PythonSandboxTool", # Python code execution
167
  "WebBrowserTool", # Web search and URL access
168
+ "DuckDuckGoSearchTool", # Privacy-focused web search
169
  # Add or remove tools as needed
170
  ]
171
 
 
181
 
182
  The following tools will automatically download their model weights when initialized:
183
 
184
+ ### Classification Tool
185
  ```python
186
  # TorchXRayVision-based classifier (original)
187
  TorchXRayVisionClassifierTool(device=device)
 
 
 
 
 
 
 
188
  ```
189
 
190
  ### Segmentation Tool
 
232
  ```
233
  - CheXagent weights download automatically
234
 
235
+ ### MedGemma VQA Tool
236
+ ```python
237
+ MedGemmaAPIClientTool(
238
+ device=device,
239
+ cache_dir=model_dir,
240
+ api_url=MEDGEMMA_API_URL)
241
+ )
242
+ ```
243
+ - Uses Google's MedGemma 4B instruction-tuned model for comprehensive medical image analysis
244
+ - Specialized for chest X-rays, dermatology, ophthalmology, and pathology images
245
+ - Provides radiologist-level medical reasoning and diagnosis assistance
246
+ - Supports up to 128K context length and 896x896 image resolution
247
+ - 4-bit quantization available (~4GB VRAM) with full precision option (~8GB VRAM)
248
+ - Model weights download automatically when the service starts
249
+
250
  ### MedSAM2 Tool
251
  ```python
252
  MedSAM2Tool(
 
278
  ImageVisualizerTool()
279
  DicomProcessorTool(temp_dir=temp_dir)
280
  WebBrowserTool() # Requires Google Search API credentials
281
+ DuckDuckGoSearchTool() # No API key required, privacy-focused search
282
  ```
283
  <br>
284
 
 
297
  2. Place weights in `{model_dir}/roentgen`
298
  3. Optional tool, can be excluded if not needed
299
 
300
+ ### ArcPlus SwinTransformer-based Classifier
301
+ ```python
302
+ ArcPlusClassifierTool(
303
+ model_path="/path/to/Ark6_swinLarge768_ep50.pth.tar", # Optional
304
+ num_classes=18, # Default
305
+ device=device
306
+ )
307
+ ```
308
+
309
+ The ArcPlus classifier requires manual setup as the pre-trained model is not publicly available for automatic download:
310
+
311
+ 1. **Request Access**: Visit [https://github.com/jlianglab/Ark](https://github.com/jlianglab/Ark) and request the pretrained model through their Google Forms
312
+ 2. **Download Model**: Once approved, download the `Ark6_swinLarge768_ep50.pth.tar` file
313
+ 3. **Place in Directory**: Drag the downloaded file into your `model-weights` directory
314
+ 4. **Initialize Tool**: The tool will automatically look for the model file in the specified `cache_dir`
315
+
316
+ The ArcPlus model provides advanced chest X-ray classification across 6 medical datasets (MIMIC, CheXpert, NIH, RSNA, VinDr, Shenzhen) with 52+ pathology categories.
317
+ ```
318
+
319
  ### Knowledge Base Setup (MedicalRAGTool)
320
 
321
  The `MedicalRAGTool` uses a Pinecone vector database to store and retrieve medical knowledge. To use this tool, you need to set up a Pinecone account and a Cohere account.
 
418
 
419
  **WebBrowserTool**: Requires Google Custom Search API credentials, which can be set in the `.env` file.
420
 
421
+ **DuckDuckGoSearchTool**: No API key required. Uses DuckDuckGo's privacy-focused search engine for medical research and fact-checking.
422
+
423
  **PythonSandboxTool**: Requires Deno runtime installation:
424
  ```bash
425
  # Verify Deno is installed
benchmarking/benchmarks/base.py CHANGED
@@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
4
  from typing import Dict, List, Optional, Any, Iterator, Tuple
5
  from dataclasses import dataclass
6
  from pathlib import Path
 
7
 
8
 
9
  @dataclass
@@ -31,17 +32,31 @@ class Benchmark(ABC):
31
  Args:
32
  data_dir (str): Directory containing benchmark data
33
  **kwargs: Additional configuration parameters
 
34
  """
35
  self.data_dir = Path(data_dir)
36
  self.config = kwargs
37
  self.data_points = []
38
  self._load_data()
 
39
 
40
  @abstractmethod
41
  def _load_data(self) -> None:
42
  """Load benchmark data from the data directory."""
43
  pass
44
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  def get_data_point(self, index: int) -> BenchmarkDataPoint:
46
  """Get a specific data point by index.
47
 
 
4
  from typing import Dict, List, Optional, Any, Iterator, Tuple
5
  from dataclasses import dataclass
6
  from pathlib import Path
7
+ import random
8
 
9
 
10
  @dataclass
 
32
  Args:
33
  data_dir (str): Directory containing benchmark data
34
  **kwargs: Additional configuration parameters
35
+ random_seed (int): Random seed for shuffling data (default: None, no shuffling)
36
  """
37
  self.data_dir = Path(data_dir)
38
  self.config = kwargs
39
  self.data_points = []
40
  self._load_data()
41
+ self._shuffle_data()
42
 
43
  @abstractmethod
44
  def _load_data(self) -> None:
45
  """Load benchmark data from the data directory."""
46
  pass
47
 
48
+ def _shuffle_data(self) -> None:
49
+ """Shuffle the data points if a random seed is provided.
50
+
51
+ This method is called automatically after data loading to ensure
52
+ reproducible benchmark runs when a random seed is specified.
53
+ """
54
+ random_seed = self.config.get("random_seed", None)
55
+ if random_seed is not None:
56
+ random.seed(random_seed)
57
+ random.shuffle(self.data_points)
58
+ print(f"Shuffled {len(self.data_points)} data points with seed {random_seed}")
59
+
60
  def get_data_point(self, index: int) -> BenchmarkDataPoint:
61
  """Get a specific data point by index.
62
 
benchmarking/benchmarks/chestagentbench_benchmark.py CHANGED
@@ -1,5 +1,4 @@
1
  import json
2
- import random
3
  from pathlib import Path
4
  from typing import Dict, Optional, Any
5
  from .base import Benchmark, BenchmarkDataPoint
@@ -31,10 +30,6 @@ class ChestAgentBenchBenchmark(Benchmark):
31
  except Exception as e:
32
  print(f"Error loading item {i}: {e}")
33
  continue
34
-
35
- # Shuffle the final data
36
- random.seed(42)
37
- random.shuffle(self.data_points)
38
 
39
  def _parse_item(self, item: Dict[str, Any], index: int) -> Optional[BenchmarkDataPoint]:
40
  # Use full_question_id or question_id if available, else fallback
 
1
  import json
 
2
  from pathlib import Path
3
  from typing import Dict, Optional, Any
4
  from .base import Benchmark, BenchmarkDataPoint
 
30
  except Exception as e:
31
  print(f"Error loading item {i}: {e}")
32
  continue
 
 
 
 
33
 
34
  def _parse_item(self, item: Dict[str, Any], index: int) -> Optional[BenchmarkDataPoint]:
35
  # Use full_question_id or question_id if available, else fallback
benchmarking/benchmarks/rexvqa_benchmark.py CHANGED
@@ -34,20 +34,20 @@ class ReXVQABenchmark(Benchmark):
34
  data_dir (str): Directory to store/cache downloaded data
35
  **kwargs: Additional configuration parameters
36
  split (str): Dataset split to use (default: 'test')
37
- cache_dir (str): Directory for caching HuggingFace datasets
38
  trust_remote_code (bool): Whether to trust remote code (default: False)
39
  max_questions (int): Maximum number of questions to load (default: None, load all)
40
  images_dir (str): Directory containing extracted PNG images (default: None)
41
  """
42
  self.split = kwargs.get("split", "test")
43
- self.cache_dir = kwargs.get("cache_dir", None)
44
  self.trust_remote_code = kwargs.get("trust_remote_code", False)
45
  self.max_questions = kwargs.get("max_questions", None)
46
- self.images_dir = "benchmarking/data/rexvqa/images/deid_png"
47
  self.image_dataset = None
48
  self.image_mapping = {} # Maps study_id to image data
49
 
50
  super().__init__(data_dir, **kwargs)
 
 
 
51
 
52
  @staticmethod
53
  def download_rexgradient_images(output_dir: str = "benchmarking/data/rexvqa", repo_id: str = "rajpurkarlab/ReXGradient-160K"):
@@ -166,8 +166,8 @@ class ReXVQABenchmark(Benchmark):
166
  """Load ReXVQA data from local JSON file."""
167
  try:
168
  # Check for images and test_vqa_data.json, download if missing
169
- self.download_test_vqa_data_json()
170
- self.download_rexgradient_images()
171
 
172
  # Construct path to the JSON file
173
  json_file_path = os.path.join("benchmarking", "data", "rexvqa", "metadata", "test_vqa_data.json")
@@ -197,7 +197,7 @@ class ReXVQABenchmark(Benchmark):
197
  self.image_dataset = load_dataset(
198
  "rajpurkarlab/ReXGradient-160K",
199
  split="test",
200
- cache_dir=self.cache_dir,
201
  trust_remote_code=self.trust_remote_code
202
  )
203
  print(f"Loaded {len(self.image_dataset)} image metadata entries from ReXGradient-160K")
 
34
  data_dir (str): Directory to store/cache downloaded data
35
  **kwargs: Additional configuration parameters
36
  split (str): Dataset split to use (default: 'test')
 
37
  trust_remote_code (bool): Whether to trust remote code (default: False)
38
  max_questions (int): Maximum number of questions to load (default: None, load all)
39
  images_dir (str): Directory containing extracted PNG images (default: None)
40
  """
41
  self.split = kwargs.get("split", "test")
 
42
  self.trust_remote_code = kwargs.get("trust_remote_code", False)
43
  self.max_questions = kwargs.get("max_questions", None)
 
44
  self.image_dataset = None
45
  self.image_mapping = {} # Maps study_id to image data
46
 
47
  super().__init__(data_dir, **kwargs)
48
+
49
+ # Set images_dir after parent initialization
50
+ self.images_dir = f"{self.data_dir}/images/deid_png"
51
 
52
  @staticmethod
53
  def download_rexgradient_images(output_dir: str = "benchmarking/data/rexvqa", repo_id: str = "rajpurkarlab/ReXGradient-160K"):
 
166
  """Load ReXVQA data from local JSON file."""
167
  try:
168
  # Check for images and test_vqa_data.json, download if missing
169
+ self.download_test_vqa_data_json(self.data_dir)
170
+ self.download_rexgradient_images(self.data_dir)
171
 
172
  # Construct path to the JSON file
173
  json_file_path = os.path.join("benchmarking", "data", "rexvqa", "metadata", "test_vqa_data.json")
 
197
  self.image_dataset = load_dataset(
198
  "rajpurkarlab/ReXGradient-160K",
199
  split="test",
200
+ cache_dir=self.data_dir,
201
  trust_remote_code=self.trust_remote_code
202
  )
203
  print(f"Loaded {len(self.image_dataset)} image metadata entries from ReXGradient-160K")
benchmarking/cli.py CHANGED
@@ -73,6 +73,8 @@ def run_benchmark_command(args) -> None:
73
 
74
  # Create benchmark
75
  benchmark_kwargs = {}
 
 
76
 
77
  benchmark = create_benchmark(benchmark_name=args.benchmark, data_dir=args.data_dir, **benchmark_kwargs)
78
 
@@ -135,12 +137,14 @@ def main():
135
  help="Output directory for results (default: benchmark_results)")
136
  run_parser.add_argument("--max-questions", type=int,
137
  help="Maximum number of questions to process (default: all)")
138
- run_parser.add_argument("--temperature", type=float, default=0.7,
139
  help="Model temperature for response generation (default: 0.7)")
140
  run_parser.add_argument("--top-p", type=float, default=0.95,
141
  help="Top-p nucleus sampling parameter (default: 0.95)")
142
  run_parser.add_argument("--max-tokens", type=int, default=5000,
143
  help="Maximum tokens per model response (default: 5000)")
 
 
144
 
145
  run_parser.set_defaults(func=run_benchmark_command)
146
 
 
73
 
74
  # Create benchmark
75
  benchmark_kwargs = {}
76
+ if args.random_seed is not None:
77
+ benchmark_kwargs["random_seed"] = args.random_seed
78
 
79
  benchmark = create_benchmark(benchmark_name=args.benchmark, data_dir=args.data_dir, **benchmark_kwargs)
80
 
 
137
  help="Output directory for results (default: benchmark_results)")
138
  run_parser.add_argument("--max-questions", type=int,
139
  help="Maximum number of questions to process (default: all)")
140
+ run_parser.add_argument("--temperature", type=float, default=1,
141
  help="Model temperature for response generation (default: 0.7)")
142
  run_parser.add_argument("--top-p", type=float, default=0.95,
143
  help="Top-p nucleus sampling parameter (default: 0.95)")
144
  run_parser.add_argument("--max-tokens", type=int, default=5000,
145
  help="Maximum tokens per model response (default: 5000)")
146
+ run_parser.add_argument("--random-seed", type=int, default=42,
147
+ help="Random seed for shuffling benchmark data (enables reproducible runs, default: None)")
148
 
149
  run_parser.set_defaults(func=run_benchmark_command)
150
 
benchmarking/llm_providers/medrax_provider.py CHANGED
@@ -33,20 +33,36 @@ class MedRAXProvider(LLMProvider):
33
  print("Starting server...")
34
 
35
  selected_tools = [
 
36
  # "ImageVisualizerTool", # For displaying images in the UI
37
  # "DicomProcessorTool", # For processing DICOM medical image files
38
  # "ChestXRaySegmentationTool", # For segmenting anatomical regions in chest X-rays
39
- # "LlavaMedTool", # For multimodal medical image understanding
40
  # "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
41
- # "PythonSandboxTool", # Add the Python sandbox tool
42
-
 
 
 
 
43
  "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
 
 
 
 
 
 
 
 
 
 
44
  "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
 
 
45
  # "WebBrowserTool", # For web browsing and search capabilities
46
- # "XRayVQATool", # For visual question answering on X-rays
47
- "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
48
- "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
49
- # "XRayPhraseGroundingTool", # For locating described features in X-rays
50
  ]
51
 
52
  rag_config = RAGConfig(
@@ -69,11 +85,11 @@ class MedRAXProvider(LLMProvider):
69
  agent, tools_dict = initialize_agent(
70
  prompt_file="medrax/docs/system_prompts.txt",
71
  tools_to_use=selected_tools,
72
- model_dir="/model-weights",
73
  temp_dir="temp", # Change this to the path of the temporary directory
74
  device="cuda:1",
75
  model=self.model_name, # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
76
- temperature=0.3,
77
  top_p=0.95,
78
  model_kwargs=model_kwargs,
79
  rag_config=rag_config,
 
33
  print("Starting server...")
34
 
35
  selected_tools = [
36
+ # Image Processing Tools
37
  # "ImageVisualizerTool", # For displaying images in the UI
38
  # "DicomProcessorTool", # For processing DICOM medical image files
39
  # "ChestXRaySegmentationTool", # For segmenting anatomical regions in chest X-rays
 
40
  # "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
41
+
42
+ # Classification Tools
43
+ "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
44
+ "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
45
+
46
+ # Report Generation Tools
47
  "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
48
+
49
+ # Grounding Tools
50
+ "XRayPhraseGroundingTool", # For locating described features in X-rays
51
+
52
+ # VQA Tools
53
+ # "LlavaMedTool", # For multimodal medical image understanding
54
+ # "XRayVQATool", # For visual question answering on X-rays
55
+ "MedGemmaVQATool",
56
+
57
+ # RAG Tools
58
  "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
59
+
60
+ # Search Tools
61
  # "WebBrowserTool", # For web browsing and search capabilities
62
+ # "DuckDuckGoSearchTool", # For privacy-focused web search using DuckDuckGo
63
+
64
+ # Development Tools
65
+ # "PythonSandboxTool", # Add the Python sandbox tool
66
  ]
67
 
68
  rag_config = RAGConfig(
 
85
  agent, tools_dict = initialize_agent(
86
  prompt_file="medrax/docs/system_prompts.txt",
87
  tools_to_use=selected_tools,
88
+ model_dir="/scratch/ssd004/scratch/victorli/model-weights",
89
  temp_dir="temp", # Change this to the path of the temporary directory
90
  device="cuda:1",
91
  model=self.model_name, # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
92
+ temperature=1.0,
93
  top_p=0.95,
94
  model_kwargs=model_kwargs,
95
  rag_config=rag_config,
benchmarking/runner.py CHANGED
@@ -268,12 +268,6 @@ class BenchmarkRunner:
268
  if match:
269
  return match.group(1).upper()
270
 
271
- # Fallback: look for the '<|A|>' format (legacy code, will remove later on)
272
- legacy_pattern = r'\s*<\|([A-F])\|>'
273
- match = re.search(legacy_pattern, response_text)
274
- if match:
275
- return match.group(1).upper()
276
-
277
  # If no pattern matches, return the full response
278
  return response_text.strip()
279
 
 
268
  if match:
269
  return match.group(1).upper()
270
 
 
 
 
 
 
 
271
  # If no pattern matches, return the full response
272
  return response_text.strip()
273
 
interface.py CHANGED
@@ -193,7 +193,11 @@ class ChatInterface:
193
 
194
  # First, display the tool usage card
195
  try:
196
- tool_output_json = json.loads(msg.content)
 
 
 
 
197
  tool_output_str = json.dumps(tool_output_json, indent=2)
198
  except (json.JSONDecodeError, TypeError):
199
  tool_output_str = str(msg.content)
@@ -216,19 +220,19 @@ class ChatInterface:
216
  # Special handling for image_visualizer
217
  if tool_name == "image_visualizer":
218
  try:
219
- # Tool returns (output, metadata) tuple
220
- # msg.content should be the serialized version of this
221
- result = eval(msg.content) # Safe here as it's from our tool
222
- if isinstance(result, tuple) and len(result) >= 1:
223
- output_dict = result[0]
224
- if isinstance(output_dict, dict) and "image_path" in output_dict:
225
- self.display_file_path = output_dict["image_path"]
226
- chat_history.append(
227
- ChatMessage(
228
- role="assistant",
229
- content={"path": self.display_file_path},
230
- )
231
  )
 
232
  except Exception:
233
  pass
234
 
 
193
 
194
  # First, display the tool usage card
195
  try:
196
+ # Handle case where tool returns tuple (output, metadata)
197
+ content = msg.content
198
+ content_tuple = ast.literal_eval(content)
199
+ content = json.dumps(content_tuple[0])
200
+ tool_output_json = json.loads(content)
201
  tool_output_str = json.dumps(tool_output_json, indent=2)
202
  except (json.JSONDecodeError, TypeError):
203
  tool_output_str = str(msg.content)
 
220
  # Special handling for image_visualizer
221
  if tool_name == "image_visualizer":
222
  try:
223
+ # Handle case where tool returns tuple (output, metadata)
224
+ content = msg.content
225
+ content_tuple = ast.literal_eval(content)
226
+ result = content_tuple[0]
227
+
228
+ if isinstance(result, dict) and "image_path" in result:
229
+ self.display_file_path = result["image_path"]
230
+ chat_history.append(
231
+ ChatMessage(
232
+ role="assistant",
233
+ content={"path": self.display_file_path},
 
234
  )
235
+ )
236
  except Exception:
237
  pass
238
 
main.py CHANGED
@@ -10,6 +10,7 @@ with different model weights, tools, and parameters.
10
  """
11
 
12
  import warnings
 
13
  from typing import Dict, List, Optional, Any
14
  from dotenv import load_dotenv
15
  from transformers import logging
@@ -33,11 +34,11 @@ _ = load_dotenv()
33
  def initialize_agent(
34
  prompt_file: str,
35
  tools_to_use: Optional[List[str]] = None,
36
- model_dir: str = "/model-weights",
37
  temp_dir: str = "temp",
38
  device: str = "cpu",
39
- model: str = "gpt-4.1-2025-04-14",
40
- temperature: float = 0.7,
41
  top_p: float = 0.95,
42
  rag_config: Optional[RAGConfig] = None,
43
  model_kwargs: Dict[str, Any] = {},
@@ -66,12 +67,15 @@ def initialize_agent(
66
  prompts = load_prompts_from_file(prompt_file)
67
  prompt = prompts[system_prompt]
68
 
 
 
 
69
  all_tools = {
70
  "TorchXRayVisionClassifierTool": lambda: TorchXRayVisionClassifierTool(device=device),
71
  "ArcPlusClassifierTool": lambda: ArcPlusClassifierTool(cache_dir=model_dir, device=device),
72
  "ChestXRaySegmentationTool": lambda: ChestXRaySegmentationTool(device=device),
73
  "LlavaMedTool": lambda: LlavaMedTool(cache_dir=model_dir, device=device, load_in_8bit=True),
74
- "XRayVQATool": lambda: XRayVQATool(cache_dir=model_dir, device=device),
75
  "ChestXRayReportGeneratorTool": lambda: ChestXRayReportGeneratorTool(
76
  cache_dir=model_dir, device=device
77
  ),
@@ -85,23 +89,29 @@ def initialize_agent(
85
  "DicomProcessorTool": lambda: DicomProcessorTool(temp_dir=temp_dir),
86
  "MedicalRAGTool": lambda: RAGTool(config=rag_config),
87
  "WebBrowserTool": lambda: WebBrowserTool(),
 
88
  "MedSAM2Tool": lambda: MedSAM2Tool(
89
  device=device, cache_dir=model_dir, temp_dir=temp_dir
90
  ),
91
- }
92
-
93
- try:
94
- tools_dict["PythonSandboxTool"] = create_python_sandbox()
95
- except Exception as e:
96
- print(f"Error creating PythonSandboxTool: {e}")
97
- print("Skipping PythonSandboxTool")
98
 
99
  # Initialize only selected tools or all if none specified
100
  tools_dict: Dict[str, BaseTool] = {}
101
- tools_to_use = tools_to_use or all_tools.keys()
 
 
 
102
  for tool_name in tools_to_use:
 
 
 
 
 
 
103
  if tool_name in all_tools:
104
  tools_dict[tool_name] = all_tools[tool_name]()
 
105
 
106
  # Set up checkpointing for conversation state
107
  checkpointer = MemorySaver()
@@ -139,22 +149,47 @@ if __name__ == "__main__":
139
  # Example: initialize with only specific tools
140
  # Here three tools are commented out, you can uncomment them to use them
141
  selected_tools = [
 
142
  "ImageVisualizerTool", # For displaying images in the UI
143
  # "DicomProcessorTool", # For processing DICOM medical image files
 
 
 
 
 
 
 
 
 
144
  "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
145
  "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
146
- "ChestXRaySegmentationTool", # For segmenting anatomical regions in chest X-rays
 
147
  "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
 
 
 
 
 
 
148
  "XRayVQATool", # For visual question answering on X-rays
149
  # "LlavaMedTool", # For multimodal medical image understanding
150
- "XRayPhraseGroundingTool", # For locating described features in X-rays
151
- # "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
152
- "MedSAM2Tool", # For advanced medical image segmentation using MedSAM2
153
- "WebBrowserTool", # For web browsing and search capabilities
154
  "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
 
 
 
 
 
 
155
  # "PythonSandboxTool", # Add the Python sandbox tool
156
  ]
157
 
 
 
 
 
158
  # Configure the Retrieval Augmented Generation (RAG) system
159
  # This allows the agent to access and use medical knowledge documents
160
  rag_config = RAGConfig(
@@ -177,11 +212,11 @@ if __name__ == "__main__":
177
  agent, tools_dict = initialize_agent(
178
  prompt_file="medrax/docs/system_prompts.txt",
179
  tools_to_use=selected_tools,
180
- model_dir="/model-weights",
181
  temp_dir="temp", # Change this to the path of the temporary directory
182
  device="cuda:1",
183
- model="gpt-4.1-2025-04-14", # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
184
- temperature=0.7,
185
  top_p=0.95,
186
  model_kwargs=model_kwargs,
187
  rag_config=rag_config,
 
10
  """
11
 
12
  import warnings
13
+ import os
14
  from typing import Dict, List, Optional, Any
15
  from dotenv import load_dotenv
16
  from transformers import logging
 
34
  def initialize_agent(
35
  prompt_file: str,
36
  tools_to_use: Optional[List[str]] = None,
37
+ model_dir: str = "model-weights",
38
  temp_dir: str = "temp",
39
  device: str = "cpu",
40
+ model: str = "gemini-2.5-pro",
41
+ temperature: float = 1.0,
42
  top_p: float = 0.95,
43
  rag_config: Optional[RAGConfig] = None,
44
  model_kwargs: Dict[str, Any] = {},
 
67
  prompts = load_prompts_from_file(prompt_file)
68
  prompt = prompts[system_prompt]
69
 
70
+ # Define the URL of the MedGemma FastAPI service.
71
+ MEDGEMMA_API_URL = os.getenv("MEDGEMMA_API_URL", "http://172.17.8.141:8002")
72
+
73
  all_tools = {
74
  "TorchXRayVisionClassifierTool": lambda: TorchXRayVisionClassifierTool(device=device),
75
  "ArcPlusClassifierTool": lambda: ArcPlusClassifierTool(cache_dir=model_dir, device=device),
76
  "ChestXRaySegmentationTool": lambda: ChestXRaySegmentationTool(device=device),
77
  "LlavaMedTool": lambda: LlavaMedTool(cache_dir=model_dir, device=device, load_in_8bit=True),
78
+ "CheXagentXRayVQATool": lambda: CheXagentXRayVQATool(cache_dir=model_dir, device=device),
79
  "ChestXRayReportGeneratorTool": lambda: ChestXRayReportGeneratorTool(
80
  cache_dir=model_dir, device=device
81
  ),
 
89
  "DicomProcessorTool": lambda: DicomProcessorTool(temp_dir=temp_dir),
90
  "MedicalRAGTool": lambda: RAGTool(config=rag_config),
91
  "WebBrowserTool": lambda: WebBrowserTool(),
92
+ "DuckDuckGoSearchTool": lambda: DuckDuckGoSearchTool(),
93
  "MedSAM2Tool": lambda: MedSAM2Tool(
94
  device=device, cache_dir=model_dir, temp_dir=temp_dir
95
  ),
96
+ "MedGemmaVQATool": lambda: MedGemmaAPIClientTool(cache_dir=model_dir, device=device, api_url=MEDGEMMA_API_URL)
97
+ }
 
 
 
 
 
98
 
99
  # Initialize only selected tools or all if none specified
100
  tools_dict: Dict[str, BaseTool] = {}
101
+
102
+ if tools_to_use is None:
103
+ tools_to_use = []
104
+
105
  for tool_name in tools_to_use:
106
+ if tool_name == "PythonSandboxTool":
107
+ try:
108
+ tools_dict["PythonSandboxTool"] = create_python_sandbox()
109
+ except Exception as e:
110
+ print(f"Error creating PythonSandboxTool: {e}")
111
+ print("Skipping PythonSandboxTool")
112
  if tool_name in all_tools:
113
  tools_dict[tool_name] = all_tools[tool_name]()
114
+
115
 
116
  # Set up checkpointing for conversation state
117
  checkpointer = MemorySaver()
 
149
  # Example: initialize with only specific tools
150
  # Here three tools are commented out, you can uncomment them to use them
151
  selected_tools = [
152
+ # Image Processing Tools
153
  "ImageVisualizerTool", # For displaying images in the UI
154
  # "DicomProcessorTool", # For processing DICOM medical image files
155
+
156
+ # Segmentation Tools
157
+ "MedSAM2Tool", # For advanced medical image segmentation using MedSAM2
158
+ "ChestXRaySegmentationTool", # For segmenting anatomical regions in chest X-rays
159
+
160
+ # Generation Tools
161
+ # "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
162
+
163
+ # Classification Tools
164
  "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
165
  "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
166
+
167
+ # Report Generation Tools
168
  "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
169
+
170
+ # Grounding Tools
171
+ "XRayPhraseGroundingTool", # For locating described features in X-rays
172
+
173
+ # VQA Tools
174
+ "MedGemmaVQATool", # Google MedGemma VQA tool
175
  "XRayVQATool", # For visual question answering on X-rays
176
  # "LlavaMedTool", # For multimodal medical image understanding
177
+
178
+ # RAG Tools
 
 
179
  "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
180
+
181
+ # Search Tools
182
+ "WebBrowserTool", # For web browsing and search capabilities
183
+ "DuckDuckGoSearchTool", # For privacy-focused web search using DuckDuckGo
184
+
185
+ # Development Tools
186
  # "PythonSandboxTool", # Add the Python sandbox tool
187
  ]
188
 
189
+ # Setup the MedGemma environment if the MedGemmaVQATool is selected
190
+ if "MedGemmaVQATool" in selected_tools:
191
+ setup_medgemma_env()
192
+
193
  # Configure the Retrieval Augmented Generation (RAG) system
194
  # This allows the agent to access and use medical knowledge documents
195
  rag_config = RAGConfig(
 
212
  agent, tools_dict = initialize_agent(
213
  prompt_file="medrax/docs/system_prompts.txt",
214
  tools_to_use=selected_tools,
215
+ model_dir="model-weights",
216
  temp_dir="temp", # Change this to the path of the temporary directory
217
  device="cuda:1",
218
+ model="gpt-4.1-2025-04-14", # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro, gpt-5
219
+ temperature=1.0,
220
  top_p=0.95,
221
  model_kwargs=model_kwargs,
222
  rag_config=rag_config,
medrax/docs/system_prompts.txt CHANGED
@@ -17,10 +17,9 @@ Examples:
17
  - "Based on clinical guidelines [3], the recommended treatment approach is..."
18
 
19
  [CHESTAGENTBENCH_PROMPT]
20
- You are an expert medical AI assistant who can answer any medical questions and analyze medical images similar to a doctor.
21
- Solve using your own vision and reasoning and use tools (if available) to complement your reasoning.
22
- You can make multiple tool calls in parallel or in sequence as needed for comprehensive answers.
23
- Think critically about and criticize the tool outputs.
24
- If you need to look up some information before asking a follow up question, you are allowed to do that.
25
  When encountering a multiple-choice question, your final response should end with "Final answer: \boxed{A}" from list of possible choices A, B, C, D, E, F.
26
- It is extremely important that you strictly answer in the format mentioned above.
 
17
  - "Based on clinical guidelines [3], the recommended treatment approach is..."
18
 
19
  [CHESTAGENTBENCH_PROMPT]
20
+ You are an expert medical assistant who can answer medical questions and analyze medical images with world-class accuracy.
21
+ Use your state-of-the-art reasoning and critical thinking skills to answer the questions that you are asked.
22
+ You may use tools (if available) to complement your reasoning and you are allowed to make multiple tool calls in parallel or in sequence as needed for comprehensive answers.
23
+ Think critically about how to best use the tools available to you and scrutinize the tool outputs.
 
24
  When encountering a multiple-choice question, your final response should end with "Final answer: \boxed{A}" from list of possible choices A, B, C, D, E, F.
25
+ It is extremely important that you answer strictly in the format described above.
medrax/tools/__init__.py CHANGED
@@ -3,14 +3,11 @@
3
  from .classification import *
4
  from .report_generation import *
5
  from .segmentation import *
6
- from .xray_vqa import *
7
- from .llava_med import *
8
  from .grounding import *
9
- from .generation import *
10
  from .dicom import *
11
  from .utils import *
12
  from .rag import *
13
- from .web_browser import *
14
  from .python_tool import *
15
- from .medsam2 import *
16
-
 
3
  from .classification import *
4
  from .report_generation import *
5
  from .segmentation import *
6
+ from .vqa import *
 
7
  from .grounding import *
8
+ from .xray_generation import *
9
  from .dicom import *
10
  from .utils import *
11
  from .rag import *
12
+ from .browsing import *
13
  from .python_tool import *
 
 
medrax/tools/browsing/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Web browsing tools for MedRAX2 medical agents."""
2
+
3
+ from .duckduckgo import DuckDuckGoSearchTool, WebSearchInput
4
+ from .web_browser import WebBrowserTool, WebBrowserSchema, SearchQuerySchema, VisitUrlSchema
5
+
6
+ __all__ = [
7
+ "DuckDuckGoSearchTool",
8
+ "WebSearchInput",
9
+ "WebBrowserTool",
10
+ "WebBrowserSchema",
11
+ "SearchQuerySchema",
12
+ "VisitUrlSchema"
13
+ ]
medrax/tools/browsing/duckduckgo.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Web search tool for MedRAX2 medical agents.
3
+
4
+ Provides DuckDuckGo search capabilities for medical agents to retrieve
5
+ real-time information from the web with proper error handling
6
+ and result formatting. Designed specifically for medical research,
7
+ fact-checking, and accessing current medical information.
8
+ """
9
+
10
+ import asyncio
11
+ import logging
12
+ import time
13
+ from datetime import datetime
14
+ from typing import Dict, Any, Tuple
15
+
16
+ from langchain_core.callbacks import (
17
+ AsyncCallbackManagerForToolRun,
18
+ CallbackManagerForToolRun,
19
+ )
20
+ from langchain_core.tools import BaseTool
21
+ from pydantic import BaseModel, Field
22
+
23
+ try:
24
+ from duckduckgo_search import DDGS
25
+ except ImportError:
26
+ DDGS = None
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
class WebSearchInput(BaseModel):
    """Input schema for web search tool.

    Validated by pydantic before the tool runs; the Field constraints
    (min/max length, ge/le) reject malformed requests up front.
    """

    # Free-text query; length-bounded to keep requests well-formed.
    query: str = Field(
        ...,
        description="The search query to look up on the web. Be specific and include relevant medical keywords for better results.",
        min_length=1,
        max_length=500,
    )
    # Result cap, clamped to 1-10 by the ge/le validators.
    max_results: int = Field(
        default=5,
        description="Maximum number of search results to return (1-10)",
        ge=1,
        le=10,
    )
    # DuckDuckGo region code (e.g. "us-en") for localized results.
    region: str = Field(
        default="us-en",
        description="Region for search results (e.g., 'us-en', 'uk-en', 'ca-en')",
    )
50
+
51
+
52
class DuckDuckGoSearchTool(BaseTool):
    """
    Tool that performs web searches using DuckDuckGo search engine for medical research.

    This tool provides access to real-time web information through DuckDuckGo's
    search API, specifically designed for medical agents that need to retrieve current
    medical information, verify facts, or find resources on medical topics.

    Features:
        - Real-time web search capability for medical information
        - Configurable number of results (1-10)
        - Regional search support for localized medical results
        - Robust error handling for network issues
        - Structured result formatting for easy parsing
        - Privacy-focused (DuckDuckGo doesn't track users)
        - Medical-focused search optimization

    Use Cases:
        - Medical fact checking and verification
        - Finding current medical news and updates
        - Researching specific medical topics or questions
        - Gathering multiple perspectives on medical issues
        - Locating official medical resources and documentation
        - Accessing current clinical guidelines and research

    Rate Limiting:
        DuckDuckGo has rate limits. Avoid making too many rapid requests
        to prevent temporary blocking.
    """

    name: str = "duckduckgo_search"
    description: str = (
        "Search the web using DuckDuckGo to find current medical information, research, and resources. "
        "Input should be a clear search query with relevant medical keywords. The tool returns a list of relevant web results "
        "with titles, URLs, and brief snippets. Useful for medical fact-checking, finding current medical events, "
        "researching medical topics, and gathering information from reliable medical sources. "
        "Results are privacy-focused and don't track user searches. Optimized for medical research and clinical information."
    )
    args_schema: type[BaseModel] = WebSearchInput
    return_direct: bool = False

    def __init__(self, **kwargs):
        """Initialize the DuckDuckGo search tool.

        Raises:
            ImportError: If the duckduckgo-search package is not installed.
        """
        super().__init__(**kwargs)

        # DDGS is None when the optional import at module top failed.
        if DDGS is None:
            logger.error(
                "duckduckgo-search package not installed. Install with: pip install duckduckgo-search"
            )
            raise ImportError(
                "duckduckgo-search package is required for web search functionality"
            )

        logger.info("DuckDuckGo search tool initialized successfully")

    def _perform_search_sync(
        self, query: str, max_results: int = 5, region: str = "us-en"
    ) -> Dict[str, Any]:
        """
        Perform the actual web search using DuckDuckGo synchronously.

        Args:
            query (str): The search query.
            max_results (int): Maximum number of results to return.
            region (str): Region for localized results.

        Returns:
            Dict[str, Any]: Structured search results. On failure the dict
                contains an "error" key instead of raising, so callers can
                surface the failure to the agent as data.
        """
        logger.info(
            f"Performing web search: '{query}' (max_results={max_results}, region={region})"
        )

        try:
            # Initialize DDGS with error handling; the context manager
            # ensures the underlying HTTP session is closed.
            with DDGS() as ddgs:
                # Perform the search
                search_results = list(
                    ddgs.text(
                        keywords=query,
                        region=region,
                        safesearch="moderate",
                        timelimit=None,
                        max_results=max_results,
                    )
                )

                # Format results for the agent; .get() guards against
                # missing keys in the raw DDGS result dicts.
                formatted_results = []
                for i, result in enumerate(search_results, 1):
                    formatted_result = {
                        "rank": i,
                        "title": result.get("title", "No title"),
                        "url": result.get("href", "No URL"),
                        "snippet": result.get("body", "No description available"),
                        "source": "DuckDuckGo",
                    }
                    formatted_results.append(formatted_result)

                # Create summary for the agent
                if formatted_results:
                    summary = (
                        f"Found {len(formatted_results)} results for '{query}'. Top results include: "
                        + ", ".join([f"{r['title']}" for r in formatted_results[:3]])
                    )
                else:
                    summary = f"No results found for '{query}'"

                # Log successful completion
                logger.info(
                    f"Web search completed successfully: {len(formatted_results)} results"
                )

                return {
                    "query": query,
                    "results_count": len(formatted_results),
                    "results": formatted_results,
                    "summary": summary,
                    "search_engine": "DuckDuckGo",
                    "timestamp": datetime.now().isoformat(),
                }

        except Exception as e:
            error_msg = f"Web search failed for query '{query}': {str(e)}"
            logger.error(f"{error_msg}")

            return {
                "query": query,
                "results_count": 0,
                "results": [],
                "error": error_msg,
                "search_engine": "DuckDuckGo",
                "timestamp": datetime.now().isoformat(),
            }

    def _run(
        self,
        query: str,
        max_results: int = 5,
        region: str = "us-en",
        run_manager: CallbackManagerForToolRun | None = None,
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        Execute the web search synchronously.

        Args:
            query (str): Search query
            max_results (int): Maximum number of results
            region (str): Search region
            run_manager: Callback manager (unused)

        Returns:
            Tuple[Dict[str, Any], Dict[str, Any]]: A tuple containing:
                - output: Dictionary with search results
                - metadata: Dictionary with execution metadata
        """
        # Create metadata structure
        metadata = {
            "query": query,
            "max_results": max_results,
            "region": region,
            "timestamp": time.time(),
            "tool": "duckduckgo_search",
            "operation": "search",
        }

        try:
            result = self._perform_search_sync(query, max_results, region)

            # Check if search was successful; _perform_search_sync reports
            # failures via an "error" key rather than raising.
            if "error" in result:
                metadata["analysis_status"] = "failed"
                metadata["error_details"] = result["error"]
            else:
                metadata["analysis_status"] = "completed"
                metadata["results_count"] = result.get("results_count", 0)

            return result, metadata

        except Exception as e:
            error_result = {
                "query": query,
                "results_count": 0,
                "results": [],
                "error": str(e),
                "search_engine": "DuckDuckGo",
                "timestamp": datetime.now().isoformat(),
            }
            metadata["analysis_status"] = "failed"
            metadata["error_details"] = str(e)

            return error_result, metadata

    async def _arun(
        self,
        query: str,
        max_results: int = 5,
        region: str = "us-en",
        run_manager: AsyncCallbackManagerForToolRun | None = None,
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        Execute the web search asynchronously.

        Args:
            query (str): Search query
            max_results (int): Maximum number of results
            region (str): Search region
            run_manager: Callback manager (unused)

        Returns:
            Tuple[Dict[str, Any], Dict[str, Any]]: A tuple containing:
                - output: Dictionary with search results
                - metadata: Dictionary with execution metadata
        """
        # Try to get LangGraph stream writer for progress updates
        writer = None
        try:
            from langgraph.config import get_stream_writer

            writer = get_stream_writer()
        except Exception:
            # Stream writer not available (outside LangGraph context)
            pass

        if writer:
            writer(
                {
                    "tool_name": "DuckDuckGoSearchTool",
                    "status": "started",
                    "query": query,
                    "max_results": max_results,
                    "step": "Initiating web search",
                }
            )

        try:
            if writer:
                writer(
                    {
                        "tool_name": "DuckDuckGoSearchTool",
                        "status": "searching",
                        "step": "Fetching results from DuckDuckGo API",
                    }
                )

            # Run the blocking sync search in the default executor.
            # FIX: use get_running_loop() — asyncio.get_event_loop() is
            # deprecated inside coroutines since Python 3.10 and may return
            # the wrong loop; _arun is always called from a running loop.
            loop = asyncio.get_running_loop()
            result, metadata = await loop.run_in_executor(
                None, self._run, query, max_results, region
            )

            if writer:
                # Parse result to get count for progress update
                results_count = result.get("results_count", 0)
                writer(
                    {
                        "tool_name": "DuckDuckGoSearchTool",
                        "status": "completed",
                        "step": f"Search completed with {results_count} results",
                        "results_count": results_count,
                    }
                )

            return result, metadata

        except Exception as e:
            if writer:
                writer(
                    {
                        "tool_name": "DuckDuckGoSearchTool",
                        "status": "error",
                        "step": f"Search failed: {str(e)}",
                        "error": str(e),
                    }
                )

            error_result = {
                "query": query,
                "results_count": 0,
                "results": [],
                "error": str(e),
                "search_engine": "DuckDuckGo",
                "timestamp": datetime.now().isoformat(),
            }

            metadata = {
                "query": query,
                "max_results": max_results,
                "region": region,
                "timestamp": time.time(),
                "tool": "duckduckgo_search",
                "operation": "search",
                "analysis_status": "failed",
                "error_details": str(e),
            }

            return error_result, metadata

    def get_search_summary(
        self, query: str, max_results: int = 3
    ) -> dict[str, str | list[str]]:
        """
        Get a quick summary of search results for a given query.

        Args:
            query (str): The search query.
            max_results (int): Maximum number of results to summarize.

        Returns:
            Dict[str, Union[str, List[str]]]: Summary of search results.
        """
        try:
            result, _ = self._run(query, max_results)

            if "error" in result:
                return {
                    "query": query,
                    "status": "error",
                    "error": result["error"],
                    "results": [],
                }

            # Extract key information; snippets truncated to 100 chars.
            results = result.get("results", [])
            titles = [r["title"] for r in results]
            urls = [r["url"] for r in results]
            snippets = [
                (
                    r["snippet"][:100] + "..."
                    if len(r["snippet"]) > 100
                    else r["snippet"]
                )
                for r in results
            ]

            return {
                "query": query,
                "status": "success",
                "total_results": result.get("results_count", 0),
                "titles": titles,
                "urls": urls,
                "snippets": snippets,
            }

        except Exception as e:
            logger.error(f"Error getting search summary: {e}")
            return {
                "query": query,
                "status": "error",
                "error": str(e),
                "results": [],
            }
medrax/tools/{web_browser.py → browsing/web_browser.py} RENAMED
File without changes
medrax/tools/classification/arcplus.py CHANGED
@@ -345,7 +345,8 @@ class ArcPlusClassifierTool(BaseTool):
345
  predictions = predictions[: len(self.disease_list)]
346
 
347
  # Create output dictionary mapping disease names to probabilities
348
- output = dict(zip(self.disease_list, predictions.astype(float)))
 
349
 
350
  metadata = {
351
  "image_path": image_path,
 
345
  predictions = predictions[: len(self.disease_list)]
346
 
347
  # Create output dictionary mapping disease names to probabilities
348
+ # Convert numpy floats to native Python floats for proper serialization
349
+ output = dict(zip(self.disease_list, [float(pred) for pred in predictions]))
350
 
351
  metadata = {
352
  "image_path": image_path,
medrax/tools/segmentation/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Medical image segmentation tools for MedRAX2."""
2
+
3
+ from .segmentation import ChestXRaySegmentationTool, ChestXRaySegmentationInput, OrganMetrics
4
+ from .medsam2 import MedSAM2Tool, MedSAM2Input
5
+
6
+ __all__ = [
7
+ "ChestXRaySegmentationTool",
8
+ "ChestXRaySegmentationInput",
9
+ "OrganMetrics",
10
+ "MedSAM2Tool",
11
+ "MedSAM2Input"
12
+ ]
medrax/tools/{medsam2.py → segmentation/medsam2.py} RENAMED
@@ -15,7 +15,7 @@ from langchain_core.callbacks import (
15
  from langchain_core.tools import BaseTool
16
 
17
  # Add MedSAM2 to Python path for proper module resolution
18
- medsam2_path = str(Path(__file__).parent.parent.parent / "MedSAM2")
19
  if medsam2_path not in sys.path:
20
  sys.path.append(medsam2_path)
21
 
@@ -93,7 +93,7 @@ class MedSAM2Tool(BaseTool):
93
  if GlobalHydra.instance().is_initialized():
94
  GlobalHydra.instance().clear()
95
 
96
- config_dir = Path(__file__).parent.parent.parent / "MedSAM2" / "sam2" / "configs"
97
  initialize_config_dir(config_dir=str(config_dir), version_base="1.2")
98
 
99
  hf_hub_download(
 
15
  from langchain_core.tools import BaseTool
16
 
17
  # Add MedSAM2 to Python path for proper module resolution
18
+ medsam2_path = str(Path(__file__).parent.parent.parent.parent / "MedSAM2")
19
  if medsam2_path not in sys.path:
20
  sys.path.append(medsam2_path)
21
 
 
93
  if GlobalHydra.instance().is_initialized():
94
  GlobalHydra.instance().clear()
95
 
96
+ config_dir = Path(__file__).parent.parent.parent.parent / "MedSAM2" / "sam2" / "configs"
97
  initialize_config_dir(config_dir=str(config_dir), version_base="1.2")
98
 
99
  hf_hub_download(
medrax/tools/{segmentation.py → segmentation/segmentation.py} RENAMED
File without changes
medrax/tools/vqa/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Visual Question Answering tools for medical images."""
2
+
3
+ from .llava_med import LlavaMedTool, LlavaMedInput
4
+ from .xray_vqa import CheXagentXRayVQATool, XRayVQAToolInput
5
+ from .medgemma.medgemma_client import MedGemmaAPIClientTool, MedGemmaVQAInput
6
+ from .medgemma.medgemma_setup import setup_medgemma_env
7
+
8
+ __all__ = [
9
+ "LlavaMedTool",
10
+ "LlavaMedInput",
11
+ "CheXagentXRayVQATool",
12
+ "XRayVQAToolInput",
13
+ "MedGemmaAPIClientTool",
14
+ "MedGemmaVQAInput",
15
+ "setup_medgemma_env"
16
+ ]
medrax/tools/{llava_med.py → vqa/llava_med.py} RENAMED
@@ -151,7 +151,7 @@ class LlavaMedTool(BaseTool):
151
  output = {
152
  "answer": answer,
153
  }
154
-
155
  metadata = {
156
  "question": question,
157
  "image_path": image_path,
 
151
  output = {
152
  "answer": answer,
153
  }
154
+
155
  metadata = {
156
  "question": question,
157
  "image_path": image_path,
medrax/tools/vqa/medgemma/medgemma.py ADDED
@@ -0,0 +1,425 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import os
3
+ from pathlib import Path
4
+ import sys
5
+ import traceback
6
+ from typing import Any, Dict, List, Optional, Tuple
7
+ import uuid
8
+
9
+ from PIL import Image
10
+
11
+ from fastapi import FastAPI, File, Form, HTTPException, UploadFile
12
+ from pydantic import BaseModel, Field
13
+ import torch
14
+ import transformers
15
+ from transformers import BitsAndBytesConfig, pipeline
16
+ import uvicorn
17
+
18
+ # Configuration
19
+ UPLOAD_DIR = "./medgemma_images"
20
+
21
+ # Create directories if they don't exist
22
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
23
+
24
+ # Pydantic Models for API
25
class VQAInput(BaseModel):
    """Input schema for the MedGemma VQA API endpoint.

    Defines the structure for requests to the /analyze-images/ endpoint.
    Used for validating incoming API requests and generating OpenAPI documentation.
    """
    # Required free-text question/instruction about the uploaded images.
    prompt: str = Field(..., description="Question or instruction about the medical images")
    # Optional role-setting message; defaults to a radiologist persona.
    system_prompt: Optional[str] = Field(
        "You are an expert radiologist.",
        description="System prompt to set the context for the model",
    )
    # Generation length cap passed through to the model pipeline.
    max_new_tokens: int = Field(
        300, description="Maximum number of tokens to generate in the response"
    )
40
class VQAResponse(BaseModel):
    """Response schema for successful MedGemma VQA API requests.

    Defines the structure of successful responses from the /analyze-images/ endpoint.
    Used for response validation and OpenAPI documentation.
    """
    # Model-generated analysis text.
    response: str = Field(..., description="Generated medical analysis response from MedGemma model")
    # Request/result context (paths, prompt, status, etc.).
    metadata: Dict[str, Any] = Field(..., description="Additional metadata about the analysis request and results")
48
+
49
class ErrorResponse(BaseModel):
    """Error response schema for failed MedGemma VQA API requests.

    Defines the structure of error responses from the /analyze-images/ endpoint.
    Used for error response validation and OpenAPI documentation.
    """
    # Human-readable description of the failure.
    error: str = Field(..., description="Human-readable error message describing what went wrong")
    # Context about the failed request (error type, details, inputs).
    metadata: Dict[str, Any] = Field(..., description="Additional metadata about the error and request context")
57
+
58
+ # MedGemma Model Handling
59
class MedGemmaModel:
    """Medical visual question answering model using Google's MedGemma 4B model.

    MedGemma is a specialized multimodal AI model trained on medical images and text.
    It provides expert-level analysis for chest X-rays, dermatology images,
    ophthalmology images, and histopathology slides.

    Key capabilities:
        - Medical image classification and analysis across multiple modalities
        - Visual question answering for radiology, dermatology, pathology, ophthalmology
        - Clinical reasoning and medical knowledge integration
        - Multi-modal medical understanding (text + images)
        - Support for up to 128K context length

    Performance:
        - Full precision (bfloat16): ~8GB VRAM, recommended for medical applications
        - 4-bit quantization (default): Available but may affect quality on some systems

    This class implements a singleton pattern to ensure only one model instance
    is loaded in memory, optimizing resource usage for the FastAPI service.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        """Create or return the singleton instance of MedGemmaModel.

        Ensures only one model instance exists in memory, preventing
        multiple model loads and conserving GPU memory.

        Returns:
            MedGemmaModel: The singleton instance
        """
        if not cls._instance:
            cls._instance = super(MedGemmaModel, cls).__new__(cls)
        return cls._instance

    def __init__(
        self,
        model_name: str = "google/medgemma-4b-it",
        device: Optional[str] = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        cache_dir: Optional[str] = None,
        load_in_4bit: bool = True,
        **kwargs: Any,
    ) -> None:
        """Initialize the MedGemmaModel.

        Args:
            model_name: Name of the MedGemma model to use (default: "google/medgemma-4b-it")
            device: Device to run model on - "cuda" or "cpu" (default: "cuda")
            dtype: Data type for model weights - bfloat16 recommended for efficiency (default: torch.bfloat16)
            cache_dir: Directory to cache downloaded models (default: None)
            load_in_4bit: Whether to load model in 4-bit quantization for memory efficiency (default: True)
            **kwargs: Additional arguments passed to the model pipeline

        Raises:
            RuntimeError: If model initialization fails (e.g., insufficient GPU memory)
        """
        # Re-initialization guard: __init__ runs on every MedGemmaModel()
        # call even though __new__ returns the singleton, so skip if the
        # pipeline is already loaded.
        if hasattr(self, 'pipe') and self.pipe is not None:
            return

        # Fall back to CPU when CUDA is unavailable or no device was given.
        self.device = device if device and torch.cuda.is_available() else "cpu"
        self.dtype = dtype
        self.cache_dir = cache_dir

        # Setup model configuration
        model_kwargs = {
            "torch_dtype": self.dtype,
        }

        if cache_dir:
            model_kwargs["cache_dir"] = cache_dir

        # Handle device mapping and quantization. Note: pipeline_kwargs
        # holds a reference to model_kwargs, so the quantization entries
        # added below are picked up by the pipeline call.
        pipeline_kwargs = {
            "model": model_name,
            "model_kwargs": model_kwargs,
            "trust_remote_code": True,
            "use_cache": True,
        }

        if load_in_4bit:
            model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
            model_kwargs["device_map"] = {"": self.device}

        try:
            self.pipe = pipeline("image-text-to-text", **pipeline_kwargs)
        except Exception as e:
            raise RuntimeError(f"Failed to initialize MedGemma pipeline: {str(e)}")

    def _prepare_messages(
        self, image_paths: List[str], prompt: str, system_prompt: str
    ) -> Tuple[List[Dict[str, Any]], List[Image.Image]]:
        """Prepare chat messages in the format expected by MedGemma.

        Converts image paths to PIL Image objects and formats them into the
        chat message structure that MedGemma expects for multimodal input.

        Args:
            image_paths: List of file paths to medical images
            prompt: User's question or instruction about the images
            system_prompt: System context message to set the model's role

        Returns:
            Tuple containing:
                - List of formatted chat messages for MedGemma
                - List of loaded PIL Image objects

        Raises:
            FileNotFoundError: If any image file cannot be found
        """
        images = []
        for path in image_paths:
            if not Path(path).is_file():
                raise FileNotFoundError(f"Image file not found: {path}")

            image = Image.open(path)
            # Normalize to RGB (e.g. grayscale or RGBA inputs).
            if image.mode != "RGB":
                image = image.convert("RGB")
            images.append(image)

        # Create messages in chat format: one system turn, then a user turn
        # carrying the text prompt followed by all images.
        messages = [
            {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
            {
                "role": "user",
                "content": [{"type": "text", "text": prompt}]
                + [{"type": "image", "image": img} for img in images],
            },
        ]

        return messages, images

    def _generate_response(self, messages: List[Dict[str, Any]], max_new_tokens: int) -> str:
        """Generate response using MedGemma pipeline.

        Processes the formatted messages through the MedGemma model to generate
        a medical analysis response.

        Args:
            messages: Formatted chat messages with images and text
            max_new_tokens: Maximum number of tokens to generate in response

        Returns:
            Generated response text from MedGemma model
        """
        # Generate using pipeline; do_sample=False keeps output deterministic.
        output = self.pipe(
            text=messages,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )

        # Extract generated text from pipeline output; the pipeline returns
        # the full chat transcript, so the model's reply is the last entry.
        if (
            isinstance(output, list)
            and output
            and isinstance(output[0].get("generated_text"), list)
        ):
            generated_text = output[0]["generated_text"]
            if generated_text:
                return generated_text[-1].get("content", "").strip()

        return "No response generated"

    def _create_error_response(
        self,
        image_paths: List[str],
        prompt: str,
        error_message: str,
        error_type: str,
        error_details: str,
    ) -> Dict[str, Any]:
        """Create standardized error response metadata.

        Generates consistent error metadata structure for logging and debugging
        purposes across different error scenarios.

        Args:
            image_paths: List of image paths that were being processed
            prompt: User prompt that was being processed
            error_message: Human-readable error message
            error_type: Categorization of the error (e.g., "memory_error", "file_not_found")
            error_details: Detailed technical error information

        Returns:
            Dictionary containing standardized error metadata
        """
        # NOTE(review): error_message is accepted but not included in the
        # returned metadata — confirm whether callers rely on that.
        return {
            "image_paths": image_paths,
            "prompt": prompt,
            "analysis_status": "failed",
            "error_type": error_type,
            "error_details": error_details,
        }

    async def aget_response(self, image_paths: List[str], prompt: str, system_prompt: str, max_new_tokens: int) -> str:
        """Async method to get response from MedGemma model.

        Main entry point for generating medical analysis responses. Handles
        the complete pipeline from image loading to response generation
        in an asynchronous manner.

        Args:
            image_paths: List of file paths to medical images
            prompt: User's question or instruction about the images
            system_prompt: System context message to set the model's role
            max_new_tokens: Maximum number of tokens to generate in response

        Returns:
            Generated medical analysis response as a string

        Raises:
            FileNotFoundError: If any image file cannot be found
            RuntimeError: If model inference fails
        """
        # FIX: use get_running_loop() — asyncio.get_event_loop() is
        # deprecated inside coroutines since Python 3.10; this coroutine
        # always executes on a running loop (FastAPI's event loop).
        loop = asyncio.get_running_loop()
        # Both message preparation (disk I/O) and model inference are
        # blocking, so run them in the default executor.
        messages, _ = await loop.run_in_executor(None, self._prepare_messages, image_paths, prompt, system_prompt)

        def _generate():
            return self._generate_response(messages, max_new_tokens)

        return await loop.run_in_executor(None, _generate)
284
+
285
# FastAPI Application serving the MedGemma VQA endpoints.
app = FastAPI(
    title="MedGemma VQA API",
    description="API for medical visual question answering using Google's MedGemma model."
)

# Global singleton model handle; populated by the startup event handler
# and read by the request handlers.
medgemma_model: Optional[MedGemmaModel] = None
292
+
293
+ @app.on_event("startup")
294
+ async def startup_event():
295
+ """Load the MedGemma model at application startup.
296
+
297
+ This function is called when the FastAPI application starts up.
298
+ It initializes the MedGemma model as a global singleton instance,
299
+ ensuring the model is loaded and ready to handle requests.
300
+
301
+ The model is loaded with default settings optimized for medical
302
+ image analysis, including 4-bit quantization for memory efficiency.
303
+
304
+ Raises:
305
+ SystemExit: If model loading fails, the application will exit
306
+ to prevent serving requests with an unavailable model.
307
+ """
308
+ global medgemma_model
309
+ try:
310
+ medgemma_model = MedGemmaModel()
311
+ print("MedGemma model loaded successfully.")
312
+ except RuntimeError as e:
313
+ print(f"Error loading MedGemma model: {e}")
314
+ exit(1)
315
+
316
@app.post("/analyze-images/",
          response_model=VQAResponse,
          responses={
              500: {"model": ErrorResponse, "description": "Internal server error or model inference failure"},
              404: {"model": ErrorResponse, "description": "Image file not found"},
              400: {"description": "Invalid request format or unsupported image type"},
              503: {"description": "Model not available or not loaded"}
          },
          summary="Analyze one or more medical images",
          description="Upload medical images and receive AI-powered analysis using Google's MedGemma model.")
async def analyze_images(
    images: List[UploadFile] = File(..., description="List of medical image files to analyze (JPG or PNG)."),
    prompt: str = Form(..., description="Question or instruction about the medical images."),
    system_prompt: Optional[str] = Form("You are an expert radiologist.", description="System prompt to set the context for the model."),
    max_new_tokens: int = Form(100, description="Maximum number of tokens to generate in the response.")
):
    """Analyze medical images using the MedGemma AI model.

    Accepts one or more medical images along with a prompt and returns
    AI-generated medical analysis. The complete pipeline:

    1. Validates uploaded image files (JPEG/PNG only)
    2. Saves images temporarily under UPLOAD_DIR
    3. Processes images through the MedGemma model
    4. Returns structured analysis with metadata
    5. Cleans up temporary files

    Args:
        images: List of uploaded image files (JPG/PNG format)
        prompt: Medical question or instruction about the images
        system_prompt: Context setting for the AI model (default: radiologist role)
        max_new_tokens: Maximum response length (default: 100)

    Returns:
        VQAResponse: Contains the AI-generated analysis and request metadata

    Raises:
        HTTPException 400: Invalid image format or request structure
        HTTPException 404: Image file not found during processing
        HTTPException 500: Model inference error or memory issues
        HTTPException 503: Model not available for processing
    """
    # Refuse early if the model failed to load or is still loading.
    if medgemma_model is None or medgemma_model.pipe is None:
        raise HTTPException(status_code=503, detail="Model is not available. Please try again later.")

    # Persist each upload to disk so the model can read it by path.
    image_paths = []
    for image in images:
        # Validate declared content type before touching the filesystem.
        if image.content_type not in ["image/jpeg", "image/png"]:
            raise HTTPException(status_code=400, detail=f"Unsupported image format: {image.content_type}. Only JPG and PNG are supported.")

        # basename() strips any client-supplied directory components, so an
        # uploaded filename like "../../etc/x.png" cannot escape UPLOAD_DIR.
        # The uuid prefix avoids collisions between concurrent requests.
        safe_name = os.path.basename(image.filename or "upload")
        unique_filename = f"{uuid.uuid4()}_{safe_name}"
        file_path = os.path.join(UPLOAD_DIR, unique_filename)

        try:
            with open(file_path, "wb") as buffer:
                buffer.write(await image.read())
            image_paths.append(file_path)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to save uploaded image: {str(e)}")

    try:
        # Generate AI analysis (runs model inference in an executor).
        response_text = await medgemma_model.aget_response(image_paths, prompt, system_prompt, max_new_tokens)

        # Prepare success response
        metadata = {
            "image_paths": image_paths,
            "prompt": prompt,
            "system_prompt": system_prompt,
            "max_new_tokens": max_new_tokens,
            "num_images": len(image_paths),
            "analysis_status": "completed",
        }
        return VQAResponse(response=response_text, metadata=metadata)

    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=f"Image file not found: {str(e)}")
    except torch.cuda.OutOfMemoryError:
        raise HTTPException(status_code=500, detail="GPU memory exhausted. Try reducing image resolution or max_new_tokens.")
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
    finally:
        # Always clean up temporary image files, even on error paths.
        for path in image_paths:
            try:
                os.remove(path)
            except OSError:
                pass
418
+
419
if __name__ == "__main__":
    # Launch the MedGemma VQA API with uvicorn, bound to every network
    # interface on port 8002.
    uvicorn.run(app, host="0.0.0.0", port=8002)
medrax/tools/vqa/medgemma/medgemma_client.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Any, Dict, List, Optional, Tuple, Type
3
+
4
+ import httpx
5
+ from langchain_core.callbacks import (
6
+ AsyncCallbackManagerForToolRun,
7
+ CallbackManagerForToolRun,
8
+ )
9
+ from langchain_core.tools import BaseTool
10
+ from pydantic import BaseModel, Field
11
+
12
class MedGemmaVQAInput(BaseModel):
    """Input schema for the MedGemma VQA Tool. Only supports JPG or PNG images."""

    # Filesystem paths of the images to analyze; the backing API accepts JPG/PNG only.
    image_paths: List[str] = Field(
        ...,
        description="List of paths to medical image files to analyze, only supports JPG or PNG images",
    )
    # The medical question or instruction posed about the images.
    prompt: str = Field(..., description="Question or instruction about the medical images")
    # Optional persona/context message sent to the model ahead of the prompt.
    system_prompt: Optional[str] = Field(
        "You are an expert radiologist.",
        description="System prompt to set the context for the model",
    )
    # Upper bound on the length of the generated answer.
    max_new_tokens: int = Field(
        300, description="Maximum number of tokens to generate in the response"
    )
26
+
27
class MedGemmaAPIClientTool(BaseTool):
    """Medical visual question answering tool using Google's MedGemma 4B model via API.

    MedGemma is a specialized multimodal AI model trained on medical images and text.
    It provides expert-level analysis for chest X-rays, dermatology images,
    ophthalmology images, and histopathology slides.

    Key capabilities:
    - Medical image classification and analysis across multiple modalities
    - Visual question answering for radiology, dermatology, pathology, ophthalmology
    - Clinical reasoning and medical knowledge integration
    - Multi-modal medical understanding (text + images)
    - Support for up to 128K context length

    Performance:
    - Full precision (bfloat16): ~8GB VRAM, recommended for medical applications
    - 4-bit quantization (default): Available but may affect quality on some systems
    """

    name: str = "medgemma_medical_vqa"
    description: str = (
        "Advanced medical visual question answering tool using Google's MedGemma 4B instruction-tuned model via API. "
        "Specialized for comprehensive medical image analysis across multiple modalities including chest X-rays, "
        "dermatology images, ophthalmology images, and histopathology slides. Provides expert-level medical "
        "reasoning, diagnosis assistance, and detailed image interpretation with radiologist-level expertise. "
        "Input: List of medical image paths and medical question/prompt with optional custom system prompt. "
        "Output: Comprehensive medical analysis and answers based on visual content with detailed reasoning. "
        "Supports multi-image analysis, comparative studies, and complex medical reasoning tasks. "
        "Model handles images up to 896x896 resolution and supports context up to 128K tokens."
    )
    args_schema: Type[BaseModel] = MedGemmaVQAInput
    return_direct: bool = True

    # API configuration
    api_url: str  # Base URL of the running FastAPI service

    def __init__(self, api_url: str, **kwargs: Any):
        """Initialize the MedGemmaAPIClientTool.

        Args:
            api_url: The URL of the running MedGemma FastAPI service
            **kwargs: Additional arguments passed to BaseTool
        """
        super().__init__(api_url=api_url, **kwargs)

    def _prepare_request_data(
        self, image_paths: List[str], prompt: str, system_prompt: str, max_new_tokens: int
    ) -> Tuple[List, Dict]:
        """Prepare multipart form data for an API request.

        Each image is read fully into memory inside a ``with`` block, so no
        file handles remain open after this method returns. The MIME type of
        each part is guessed from the file extension (falling back to
        ``image/jpeg``) so PNG uploads are not mislabeled.

        Args:
            image_paths: List of paths to medical images
            prompt: Question or instruction about the images
            system_prompt: System context for the model
            max_new_tokens: Maximum number of tokens to generate

        Returns:
            Tuple of (files list for httpx's ``files=`` argument, form-data dict)

        Raises:
            FileNotFoundError: If any image path cannot be opened.
        """
        import mimetypes  # stdlib; local import keeps module-level deps unchanged

        files_to_send = []
        for path in image_paths:
            mime_type = mimetypes.guess_type(path)[0] or "image/jpeg"
            with open(path, "rb") as fh:
                files_to_send.append(("images", (os.path.basename(path), fh.read(), mime_type)))

        form_data = {
            "prompt": prompt,
            "system_prompt": system_prompt,
            "max_new_tokens": max_new_tokens,
        }
        return files_to_send, form_data

    def _create_error_response(
        self,
        image_paths: List[str],
        prompt: str,
        error_message: str,
        error_type: str,
        error_details: str,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Create a standardized error response.

        Args:
            image_paths: List of image paths
            prompt: User prompt
            error_message: Human-readable error message
            error_type: Type of error
            error_details: Detailed error information

        Returns:
            Tuple of error output and metadata
        """
        output = {"error": error_message}
        metadata = {
            "image_paths": image_paths,
            "prompt": prompt,
            "analysis_status": "failed",
            "error_type": error_type,
            "error_details": error_details,
        }
        return output, metadata

    def _success_metadata(
        self, image_paths: List[str], prompt: str, system_prompt: str, max_new_tokens: int
    ) -> Dict:
        """Build the metadata dict attached to a successful response."""
        return {
            "image_paths": image_paths,
            "prompt": prompt,
            "system_prompt": system_prompt,
            "max_new_tokens": max_new_tokens,
            "num_images": len(image_paths),
            "analysis_status": "completed",
        }

    def _run(
        self,
        image_paths: List[str],
        prompt: str,
        system_prompt: str = "You are an expert radiologist.",
        max_new_tokens: int = 300,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Execute medical visual question answering via API (synchronous).

        Args:
            image_paths: List of paths to medical images
            prompt: Question or instruction about the images
            system_prompt: System context for the model
            max_new_tokens: Maximum number of tokens to generate
            run_manager: Optional callback manager

        Returns:
            Tuple of output dictionary and metadata
        """
        # Generous read timeout: model inference (and lazy model loading on
        # the server) can take minutes; connecting should still fail fast.
        timeout_config = httpx.Timeout(300.0, connect=10.0)

        try:
            files_to_send, form_data = self._prepare_request_data(
                image_paths, prompt, system_prompt, max_new_tokens
            )

            # Context manager guarantees the connection pool is closed.
            with httpx.Client(timeout=timeout_config) as client:
                response = client.post(
                    f"{self.api_url}/analyze-images/",
                    data=form_data,
                    files=files_to_send,
                )
                response.raise_for_status()  # Raise for 4xx/5xx status codes
                response_data = response.json()

            output = {"response": response_data["response"]}
            metadata = self._success_metadata(image_paths, prompt, system_prompt, max_new_tokens)
            return output, metadata

        except httpx.TimeoutException as e:
            return self._create_error_response(
                image_paths,
                prompt,
                f"Error: The request to the MedGemma API timed out after {timeout_config.read} seconds. The server might be overloaded or the model is taking too long to load. Try again later.",
                "timeout_error",
                str(e),
            )
        except httpx.ConnectError as e:
            return self._create_error_response(
                image_paths,
                prompt,
                f"Error: Could not connect to the MedGemma API. Check if the server address '{self.api_url}' is correct and running.",
                "connection_error",
                str(e),
            )
        except httpx.HTTPStatusError as e:
            return self._create_error_response(
                image_paths,
                prompt,
                f"Error: The MedGemma API returned an error (Status {e.response.status_code}): {e.response.text}",
                "http_error",
                f"Status {e.response.status_code}: {e.response.text}",
            )
        except Exception as e:
            return self._create_error_response(
                image_paths,
                prompt,
                f"An unexpected error occurred in the MedGemma client tool: {str(e)}",
                "general_error",
                str(e),
            )

    async def _arun(
        self,
        image_paths: List[str],
        prompt: str,
        system_prompt: str = "You are an expert radiologist.",
        max_new_tokens: int = 300,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Execute the tool asynchronously.

        Mirrors :meth:`_run`, including the same timeout configuration and
        error taxonomy, so sync and async behavior stay consistent.
        """
        timeout_config = httpx.Timeout(300.0, connect=10.0)

        try:
            files_to_send, form_data = self._prepare_request_data(
                image_paths, prompt, system_prompt, max_new_tokens
            )

            async with httpx.AsyncClient(timeout=timeout_config) as client:
                response = await client.post(
                    f"{self.api_url}/analyze-images/",
                    data=form_data,
                    files=files_to_send,
                )
                response.raise_for_status()
                response_data = response.json()

            output = {"response": response_data["response"]}
            metadata = self._success_metadata(image_paths, prompt, system_prompt, max_new_tokens)
            return output, metadata

        except httpx.TimeoutException as e:
            return self._create_error_response(
                image_paths,
                prompt,
                f"Error: The request to the MedGemma API timed out after {timeout_config.read} seconds. The server might be overloaded or the model is taking too long to load. Try again later.",
                "timeout_error",
                str(e),
            )
        except httpx.ConnectError as e:
            return self._create_error_response(
                image_paths,
                prompt,
                f"Error: Could not connect to the MedGemma API. Check if the server address '{self.api_url}' is correct and running.",
                "connection_error",
                str(e),
            )
        except httpx.HTTPStatusError as e:
            return self._create_error_response(
                image_paths,
                prompt,
                f"Error calling MedGemma API: {e.response.status_code} - {e.response.text}",
                "http_error",
                f"Status {e.response.status_code}: {e.response.text}",
            )
        except Exception as e:
            return self._create_error_response(
                image_paths,
                prompt,
                f"An unexpected error occurred: {str(e)}",
                "general_error",
                str(e),
            )
+ f.close()
medrax/tools/vqa/medgemma/medgemma_requirements_standard.txt ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==1.9.0
2
+ annotated_types==0.7.0
3
+ anyio==4.9.0
4
+ bitsandbytes==0.46.0
5
+ certifi==2025.7.14
6
+ charset_normalizer==3.4.2
7
+ click==8.2.1
8
+ fastapi==0.116.1
9
+ filelock==3.18.0
10
+ fsspec==2025.7.0
11
+ h11==0.16.0
12
+ hf_xet==1.1.3
13
+ httpcore==1.0.9
14
+ httpx==0.28.1
15
+ huggingface-hub==0.34.3
16
+ idna==3.10
17
+ inquirerpy==0.3.4
18
+ jinja2==3.1.6
19
+ jsonpatch==1.33
20
+ jsonpointer==3.0.0
21
+ langchain-core==0.3.72
22
+ langsmith==0.4.8
23
+ MarkupSafe==2.1.5
24
+ mpmath==1.3.0
25
+ networkx==3.5
26
+ numpy==2.2.2
27
+ orjson==3.10.5
28
+ packaging==25.0
29
+ pfzy==0.3.4
30
+ pillow==11.1.0
31
+ prompt_toolkit==3.0.51
32
+ psutil==6.1.1
33
+ pydantic==2.11.7
34
+ pydantic_core==2.33.2
35
+ python_multipart==0.0.20
36
+ PyYAML==6.0.2
37
+ regex==2024.11.6
38
+ requests==2.32.4
39
+ requests_toolbelt==1.0.0
40
+ safetensors==0.5.3
41
+ sniffio==1.3.1
42
+ sshuttle==1.3.1
43
+ starlette==0.47.2
44
+ sympy==1.14.0
45
+ tenacity==9.1.2
46
+ tokenizers==0.21.1
47
+ torch==2.7.1
48
+ tqdm==4.67.1
49
+ transformers==4.54.1
50
+ typing_extensions==4.14.1
51
+ typing_inspection==0.4.1
52
+ urllib3==2.5.0
53
+ uvicorn==0.35.0
54
+ wcwidth==0.2.13
55
+ zstandard==0.23.0
medrax/tools/vqa/medgemma/medgemma_setup.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ import subprocess
4
+ import venv
5
+
6
def setup_medgemma_env():
    """Set up the MedGemma virtual environment and launch the FastAPI service.

    This function performs the following steps:
    1. Verifies the service script and requirements file exist
    2. Creates a virtual environment for MedGemma if it doesn't exist
    3. Installs MedGemma-specific dependencies from
       ``medgemma_requirements_standard.txt``
    4. Launches the MedGemma FastAPI service in the isolated environment

    Returns:
        subprocess.Popen: Handle to the launched background service process
            (callers may ignore it; previous versions returned None).

    Raises:
        FileNotFoundError: If the service script or requirements file is missing
        subprocess.CalledProcessError: If pip installation fails
        OSError: If virtual environment creation fails
    """
    # Directory containing this script; all MedGemma assets live beside it.
    current_dir = Path(__file__).resolve().parent

    medgemma_path = current_dir / "medgemma.py"
    requirements_path = current_dir / "medgemma_requirements_standard.txt"
    env_dir = current_dir / "medgemma_env"

    # Fail fast with the documented error if required files are absent,
    # instead of a confusing pip/Popen failure later.
    for required in (medgemma_path, requirements_path):
        if not required.exists():
            raise FileNotFoundError(f"Required MedGemma file is missing: {required}")

    # Virtual-environment executable layout differs between Windows and POSIX.
    bin_dir = env_dir / ("Scripts" if os.name == "nt" else "bin")
    pip_executable = bin_dir / "pip"
    python_executable = bin_dir / "python"

    # Create the virtual environment only if it doesn't already exist.
    if not env_dir.exists():
        print("Creating MedGemma virtual environment...")
        venv.create(env_dir, with_pip=True)

    # (Re-)install dependencies each run so the env tracks the requirements file.
    print("Installing MedGemma dependencies...")
    subprocess.check_call([
        str(pip_executable),
        "install",
        "-r",
        str(requirements_path),
    ])

    # Launch the MedGemma FastAPI service as a background process.
    # stdout/stderr are intentionally left attached for easier debugging.
    print("Launching MedGemma FastAPI service...")
    return subprocess.Popen([
        str(python_executable),
        str(medgemma_path),
    ])
medrax/tools/{xray_vqa.py → vqa/xray_vqa.py} RENAMED
@@ -24,10 +24,10 @@ class XRayVQAToolInput(BaseModel):
24
  )
25
 
26
 
27
- class XRayVQATool(BaseTool):
28
  """Tool that leverages CheXagent for comprehensive chest X-ray analysis."""
29
 
30
- name: str = "chest_xray_expert"
31
  description: str = (
32
  "A versatile tool for analyzing chest X-rays. "
33
  "Can perform multiple tasks including: visual question answering, report generation, "
@@ -51,7 +51,7 @@ class XRayVQATool(BaseTool):
51
  cache_dir: Optional[str] = None,
52
  **kwargs: Any,
53
  ) -> None:
54
- """Initialize the XRayVQATool.
55
 
56
  Args:
57
  model_name: Name of the CheXagent model to use
 
24
  )
25
 
26
 
27
+ class CheXagentXRayVQATool(BaseTool):
28
  """Tool that leverages CheXagent for comprehensive chest X-ray analysis."""
29
 
30
+ name: str = "chexagent_xray_vqa"
31
  description: str = (
32
  "A versatile tool for analyzing chest X-rays. "
33
  "Can perform multiple tasks including: visual question answering, report generation, "
 
51
  cache_dir: Optional[str] = None,
52
  **kwargs: Any,
53
  ) -> None:
54
+ """Initialize the CheXagentXRayVQATool.
55
 
56
  Args:
57
  model_name: Name of the CheXagent model to use
medrax/tools/{generation.py → xray_generation.py} RENAMED
File without changes
pyproject.toml CHANGED
@@ -57,7 +57,6 @@ dependencies = [
57
  "torch>=2.2.0",
58
  "torchvision>=0.10.0",
59
  "scikit-image>=0.18.0",
60
- "gradio>=5.0.0",
61
  "opencv-python>=4.8.0",
62
  "matplotlib>=3.8.0",
63
  "diffusers>=0.20.0",
@@ -65,16 +64,15 @@ dependencies = [
65
  "pylibjpeg>=1.0.0",
66
  "jupyter>=1.0.0",
67
  "albumentations>=1.0.0",
68
- "pyarrow>=10.0.0",
69
  "chromadb>=0.0.10",
70
  "pinecone-client>=3.2.2",
71
  "langchain-pinecone>=0.0.1",
72
  "langchain-google-genai>=0.1.0",
73
  "ray>=2.9.0",
74
- "langchain-sandbox>=0.0.6",
75
  "seaborn>=0.12.0",
76
  "huggingface_hub>=0.17.0",
77
  "iopath>=0.1.10",
 
78
  ]
79
 
80
  [project.optional-dependencies]
 
57
  "torch>=2.2.0",
58
  "torchvision>=0.10.0",
59
  "scikit-image>=0.18.0",
 
60
  "opencv-python>=4.8.0",
61
  "matplotlib>=3.8.0",
62
  "diffusers>=0.20.0",
 
64
  "pylibjpeg>=1.0.0",
65
  "jupyter>=1.0.0",
66
  "albumentations>=1.0.0",
 
67
  "chromadb>=0.0.10",
68
  "pinecone-client>=3.2.2",
69
  "langchain-pinecone>=0.0.1",
70
  "langchain-google-genai>=0.1.0",
71
  "ray>=2.9.0",
 
72
  "seaborn>=0.12.0",
73
  "huggingface_hub>=0.17.0",
74
  "iopath>=0.1.10",
75
+ "duckduckgo-search>=4.0.0",
76
  ]
77
 
78
  [project.optional-dependencies]