Adibvafa committed on
Commit
b9f142a
·
1 Parent(s): f1b994a

Setup argparse

Browse files
Files changed (2) hide show
  1. main.py +239 -60
  2. pyproject.toml +2 -1
main.py CHANGED
@@ -11,6 +11,9 @@ with different model weights, tools, and parameters.
11
 
12
  import warnings
13
  import os
 
 
 
14
  from typing import Dict, List, Optional, Any
15
  from dotenv import load_dotenv
16
  from transformers import logging
@@ -19,6 +22,7 @@ from langgraph.checkpoint.memory import MemorySaver
19
  from medrax.models import ModelFactory
20
 
21
  from interface import create_demo
 
22
  from medrax.agent import *
23
  from medrax.tools import *
24
  from medrax.utils import *
@@ -37,7 +41,7 @@ def initialize_agent(
37
  model_dir: str = "/model-weights",
38
  temp_dir: str = "temp",
39
  device: str = "cuda",
40
- model: str = "gemini-2.5-pro",
41
  temperature: float = 1.0,
42
  rag_config: Optional[RAGConfig] = None,
43
  model_kwargs: Dict[str, Any] = {},
@@ -137,56 +141,216 @@ def initialize_agent(
137
  return agent, tools_dict
138
 
139
 
140
- if __name__ == "__main__":
141
  """
142
- This is the main entry point for the MedRAX application.
143
- It initializes the agent with the selected tools and creates the demo.
 
 
 
 
 
144
  """
145
- print("Starting server...")
146
-
147
- # Example: initialize with only specific tools
148
- # Here three tools are commented out, you can uncomment them to use them
149
- selected_tools = [
150
- # Image Processing Tools
151
- "ImageVisualizerTool", # For displaying images in the UI
152
- # "DicomProcessorTool", # For processing DICOM medical image files
153
-
154
- # Segmentation Tools
155
- "MedSAM2Tool", # For advanced medical image segmentation using MedSAM2
156
- "ChestXRaySegmentationTool", # For segmenting anatomical regions in chest X-rays
157
-
158
- # Generation Tools
159
- # "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
160
-
161
- # Classification Tools
162
- "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
163
- "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
164
-
165
- # Report Generation Tools
166
- "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
167
 
168
- # Grounding Tools
169
- "XRayPhraseGroundingTool", # For locating described features in X-rays
170
 
171
- # VQA Tools
172
- "MedGemmaVQATool", # Google MedGemma VQA tool
173
- "XRayVQATool", # For visual question answering on X-rays
174
- # "LlavaMedTool", # For multimodal medical image understanding
 
 
 
 
 
 
 
 
 
175
 
176
- # RAG Tools
177
- "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
178
 
179
- # Search Tools
180
- "WebBrowserTool", # For web browsing and search capabilities
181
- "DuckDuckGoSearchTool", # For privacy-focused web search using DuckDuckGo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
- # Development Tools
184
- # "PythonSandboxTool", # Add the Python sandbox tool
185
- ]
186
 
187
- # Share a single cache directory and device across tools
188
- model_dir = os.getenv("MODEL_WEIGHTS_DIR", "/model-weights")
189
- device = os.getenv("MEDRAX_DEVICE", "cuda:0")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
  # Setup the MedGemma environment if the MedGemmaVQATool is selected
192
  if "MedGemmaVQATool" in selected_tools:
@@ -195,15 +359,15 @@ if __name__ == "__main__":
195
  # Configure the Retrieval Augmented Generation (RAG) system
196
  # This allows the agent to access and use medical knowledge documents
197
  rag_config = RAGConfig(
198
- model="command-a-03-2025", # Chat model for generating responses
199
- embedding_model="embed-v4.0", # Embedding model for the RAG system
200
- rerank_model="rerank-v3.5", # Reranking model for the RAG system
201
- temperature=0.3,
202
- pinecone_index_name="medrax2", # Name for the Pinecone index
203
- chunk_size=1500,
204
- chunk_overlap=300,
205
- retriever_k=3,
206
- local_docs_dir="rag_docs", # Change this to the path of the documents for RAG
207
  huggingface_datasets=["VictorLJZ/medrax2"], # List of HuggingFace datasets to load
208
  dataset_split="train", # Which split of the datasets to use
209
  )
@@ -212,18 +376,33 @@ if __name__ == "__main__":
212
  model_kwargs = {}
213
 
214
  agent, tools_dict = initialize_agent(
215
- prompt_file="medrax/docs/system_prompts.txt",
216
  tools_to_use=selected_tools,
217
  model_dir=model_dir,
218
- temp_dir="temp2", # Change this to the path of the temporary directory
219
  device=device,
220
- model="gpt-4.1", # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro, gpt-5
221
- temperature=1.0,
222
  model_kwargs=model_kwargs,
223
  rag_config=rag_config,
224
- system_prompt="MEDICAL_ASSISTANT",
225
  )
226
 
227
- # Create and launch the web interface
228
- demo = create_demo(agent, tools_dict)
229
- demo.launch(server_name="0.0.0.0", server_port=8686, share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  import warnings
13
  import os
14
+ import argparse
15
+ import threading
16
+ import uvicorn
17
  from typing import Dict, List, Optional, Any
18
  from dotenv import load_dotenv
19
  from transformers import logging
 
22
  from medrax.models import ModelFactory
23
 
24
  from interface import create_demo
25
+ from api import create_api
26
  from medrax.agent import *
27
  from medrax.tools import *
28
  from medrax.utils import *
 
41
  model_dir: str = "/model-weights",
42
  temp_dir: str = "temp",
43
  device: str = "cuda",
44
+ model: str = "gpt-4.1",
45
  temperature: float = 1.0,
46
  rag_config: Optional[RAGConfig] = None,
47
  model_kwargs: Dict[str, Any] = {},
 
141
  return agent, tools_dict
142
 
143
 
144
def run_gradio_interface(agent, tools_dict, host="0.0.0.0", port=8686, share=True):
    """
    Run the Gradio web interface.

    Builds the demo with ``create_demo(agent, tools_dict)`` and launches it
    on the given host and port.

    Args:
        agent: The initialized MedRAX agent
        tools_dict: Dictionary of available tools
        host (str): Host to bind the server to
        port (int): Port to run the server on
        share (bool): Forwarded unchanged to ``demo.launch(share=...)``.
            Defaults to True to preserve the previous hard-coded behaviour;
            pass False to skip requesting a Gradio share link (see the Gradio
            ``launch`` documentation for what sharing exposes).
    """
    print(f"Starting Gradio interface on {host}:{port}")
    demo = create_demo(agent, tools_dict)
    demo.launch(server_name=host, server_port=port, share=share)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
 
 
158
 
159
def run_api_server(agent, tools_dict, host="0.0.0.0", port=8000):
    """
    Serve the MedRAX API application with uvicorn.

    Args:
        agent: The initialized MedRAX agent
        tools_dict: Dictionary of available tools
        host (str): Address the server binds to
        port (int): Port the server listens on
    """
    print(f"Starting API server on {host}:{port}")
    # Build the application first, then hand it to uvicorn together with the
    # bind settings.
    bind_settings = {"host": host, "port": port}
    uvicorn.run(create_api(agent, tools_dict), **bind_settings)
172
 
 
 
173
 
174
def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """
    Parse command line arguments.

    Args:
        argv: Argument strings to parse. When None (the default), argparse
            reads them from ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace holding the server, model, RAG and tool options.
        ``model_dir`` and ``device`` are None when not given on the command
        line, so a caller written as ``args.model_dir or os.getenv(...)``
        actually reaches its environment-variable fallback.
    """
    parser = argparse.ArgumentParser(description="MedRAX - Medical Reasoning Agent for Chest X-ray")

    # Server configuration
    parser.add_argument(
        "--mode",
        choices=["gradio", "api", "both"],
        default="gradio",
        help="Run mode: 'gradio' for web interface, 'api' for REST API, 'both' for both services",
    )
    parser.add_argument("--gradio-host", default="0.0.0.0", help="Gradio host address")
    parser.add_argument("--gradio-port", type=int, default=8686, help="Gradio port")
    parser.add_argument("--api-host", default="0.0.0.0", help="API host address")
    parser.add_argument("--api-port", type=int, default=8000, help="API port")

    # Model and system configuration
    # NOTE: --model-dir and --device default to None on purpose. A non-empty
    # default here would always be truthy and would shadow the environment
    # variables that the help text promises are consulted.
    parser.add_argument(
        "--model-dir",
        default=None,
        help="Directory containing model weights (default: uses MODEL_WEIGHTS_DIR env var or '/model-weights')",
    )
    parser.add_argument(
        "--device",
        default=None,
        help="Device to run models on (default: uses MEDRAX_DEVICE env var or 'cuda:0')",
    )
    parser.add_argument(
        "--model",
        default="gpt-4.1",
        help="Model to use (default: gpt-4.1). Examples: gpt-4.1-2025-04-14, gemini-2.5-pro, gpt-5",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="Temperature for the model (default: 1.0)",
    )
    parser.add_argument(
        "--temp-dir",
        default="temp2",
        help="Directory for temporary files (default: temp2)",
    )
    parser.add_argument(
        "--prompt-file",
        default="medrax/docs/system_prompts.txt",
        help="Path to file containing system prompts (default: medrax/docs/system_prompts.txt)",
    )
    parser.add_argument(
        "--system-prompt",
        default="MEDICAL_ASSISTANT",
        help="System prompt to use (default: MEDICAL_ASSISTANT)",
    )

    # RAG configuration
    parser.add_argument(
        "--rag-model",
        default="command-a-03-2025",
        help="Chat model for RAG responses (default: command-a-03-2025)",
    )
    parser.add_argument(
        "--rag-embedding-model",
        default="embed-v4.0",
        help="Embedding model for RAG system (default: embed-v4.0)",
    )
    parser.add_argument(
        "--rag-rerank-model",
        default="rerank-v3.5",
        help="Reranking model for RAG system (default: rerank-v3.5)",
    )
    parser.add_argument(
        "--rag-temperature",
        type=float,
        default=0.3,
        help="Temperature for RAG model (default: 0.3)",
    )
    parser.add_argument(
        "--pinecone-index",
        default="medrax2",
        help="Pinecone index name (default: medrax2)",
    )
    parser.add_argument(
        "--chunk-size",
        type=int,
        default=1500,
        help="RAG chunk size (default: 1500)",
    )
    parser.add_argument(
        "--chunk-overlap",
        type=int,
        default=300,
        help="RAG chunk overlap (default: 300)",
    )
    parser.add_argument(
        "--retriever-k",
        type=int,
        default=3,
        help="Number of documents to retrieve (default: 3)",
    )
    parser.add_argument(
        "--rag-docs-dir",
        default="rag_docs",
        help="Directory for RAG documents (default: rag_docs)",
    )

    # Tools configuration
    # With nargs="*" and no default, the attribute is None when --tools is
    # omitted and a (possibly empty) list when it is given.
    parser.add_argument(
        "--tools",
        nargs="*",
        help="Specific tools to enable (if not provided, uses default set). Available tools: "
        "ImageVisualizerTool, DicomProcessorTool, MedSAM2Tool, ChestXRaySegmentationTool, "
        "ChestXRayGeneratorTool, TorchXRayVisionClassifierTool, ArcPlusClassifierTool, "
        "ChestXRayReportGeneratorTool, XRayPhraseGroundingTool, MedGemmaVQATool, "
        "XRayVQATool, LlavaMedTool, MedicalRAGTool, WebBrowserTool, DuckDuckGoSearchTool, "
        "PythonSandboxTool",
    )

    return parser.parse_args(argv)
292
 
 
 
 
293
 
294
+ if __name__ == "__main__":
295
+ """
296
+ This is the main entry point for the MedRAX application.
297
+ It initializes the agent with the selected tools and creates the demo/API.
298
+ """
299
+ args = parse_arguments()
300
+ print(f"Starting MedRAX in {args.mode} mode...")
301
+
302
+ # Configure tools based on arguments
303
+ if args.tools is not None:
304
+ # Use tools specified via command line
305
+ selected_tools = args.tools
306
+ else:
307
+ # Use default tools selection
308
+ selected_tools = [
309
+ # Image Processing Tools
310
+ "ImageVisualizerTool", # For displaying images in the UI
311
+ # "DicomProcessorTool", # For processing DICOM medical image files
312
+
313
+ # Segmentation Tools
314
+ "MedSAM2Tool", # For advanced medical image segmentation using MedSAM2
315
+ "ChestXRaySegmentationTool", # For segmenting anatomical regions in chest X-rays
316
+
317
+ # Generation Tools
318
+ # "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
319
+
320
+ # Classification Tools
321
+ "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
322
+ "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
323
+
324
+ # Report Generation Tools
325
+ "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
326
+
327
+ # Grounding Tools
328
+ "XRayPhraseGroundingTool", # For locating described features in X-rays
329
+
330
+ # VQA Tools
331
+ "MedGemmaVQATool", # Google MedGemma VQA tool
332
+ "XRayVQATool", # For visual question answering on X-rays
333
+ # "LlavaMedTool", # For multimodal medical image understanding
334
+
335
+ # RAG Tools
336
+ "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
337
+
338
+ # Search Tools
339
+ "WebBrowserTool", # For web browsing and search capabilities
340
+ "DuckDuckGoSearchTool", # For privacy-focused web search using DuckDuckGo
341
+
342
+ # Development Tools
343
+ # "PythonSandboxTool", # Add the Python sandbox tool
344
+ ]
345
+
346
+ # Configure model directory and device
347
+ model_dir = args.model_dir or os.getenv("MODEL_WEIGHTS_DIR", "/model-weights")
348
+ device = args.device or os.getenv("MEDRAX_DEVICE", "cuda:0")
349
+
350
+ print(f"Using model directory: {model_dir}")
351
+ print(f"Using device: {device}")
352
+ print(f"Using model: {args.model}")
353
+ print(f"Selected tools: {selected_tools}")
354
 
355
  # Setup the MedGemma environment if the MedGemmaVQATool is selected
356
  if "MedGemmaVQATool" in selected_tools:
 
359
  # Configure the Retrieval Augmented Generation (RAG) system
360
  # This allows the agent to access and use medical knowledge documents
361
  rag_config = RAGConfig(
362
+ model=args.rag_model,
363
+ embedding_model=args.rag_embedding_model,
364
+ rerank_model=args.rag_rerank_model,
365
+ temperature=args.rag_temperature,
366
+ pinecone_index_name=args.pinecone_index,
367
+ chunk_size=args.chunk_size,
368
+ chunk_overlap=args.chunk_overlap,
369
+ retriever_k=args.retriever_k,
370
+ local_docs_dir=args.rag_docs_dir,
371
  huggingface_datasets=["VictorLJZ/medrax2"], # List of HuggingFace datasets to load
372
  dataset_split="train", # Which split of the datasets to use
373
  )
 
376
  model_kwargs = {}
377
 
378
  agent, tools_dict = initialize_agent(
379
+ prompt_file=args.prompt_file,
380
  tools_to_use=selected_tools,
381
  model_dir=model_dir,
382
+ temp_dir=args.temp_dir,
383
  device=device,
384
+ model=args.model,
385
+ temperature=args.temperature,
386
  model_kwargs=model_kwargs,
387
  rag_config=rag_config,
388
+ system_prompt=args.system_prompt,
389
  )
390
 
391
+ # Launch based on selected mode
392
+ if args.mode == "gradio":
393
+ run_gradio_interface(agent, tools_dict, args.gradio_host, args.gradio_port)
394
+
395
+ elif args.mode == "api":
396
+ run_api_server(agent, tools_dict, args.api_host, args.api_port)
397
+
398
+ elif args.mode == "both":
399
+ # Run both services in separate threads
400
+ api_thread = threading.Thread(
401
+ target=run_api_server,
402
+ args=(agent, tools_dict, args.api_host, args.api_port)
403
+ )
404
+ api_thread.daemon = True
405
+ api_thread.start()
406
+
407
+ # Run Gradio in main thread
408
+ run_gradio_interface(agent, tools_dict, args.gradio_host, args.gradio_port)
pyproject.toml CHANGED
@@ -46,8 +46,9 @@ dependencies = [
46
  "gradio>=3.0.0",
47
  "gradio_client>=0.2.0",
48
  "httpx>=0.23.0",
49
- "uvicorn>=0.15.0",
50
  "fastapi>=0.68.0",
 
51
  "einops>=0.3.0",
52
  "einops-exts>=0.0.4",
53
  "timm==0.5.4",
 
46
  "gradio>=3.0.0",
47
  "gradio_client>=0.2.0",
48
  "httpx>=0.23.0",
49
+ "uvicorn[standard]>=0.15.0",
50
  "fastapi>=0.68.0",
51
+ "python-multipart>=0.0.6",
52
  "einops>=0.3.0",
53
  "einops-exts>=0.0.4",
54
  "timm==0.5.4",