adibvafa commited on
Commit
ea89378
·
1 Parent(s): a1e6245

Add support for selecting tools.

Browse files
Files changed (1) hide show
  1. main.py +51 -14
main.py CHANGED
@@ -18,28 +18,49 @@ logging.set_verbosity_error()
18
  _ = load_dotenv()
19
 
20
 
21
- def initialize_agent(prompt_file, model_dir="/model-weights", temp_dir="temp", device="cuda"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  prompts = load_prompts_from_file(prompt_file)
23
  prompt = prompts["MEDICAL_ASSISTANT"]
24
 
25
- tools_dict = {
26
- "ChestXRayClassifierTool": ChestXRayClassifierTool(device=device),
27
- "ChestXRaySegmentationTool": ChestXRaySegmentationTool(device=device),
28
- "LlavaMedTool": LlavaMedTool(cache_dir=model_dir, device=device, load_in_8bit=True),
29
- "XRayVQATool": XRayVQATool(cache_dir=model_dir, device=device),
30
- "ChestXRayReportGeneratorTool": ChestXRayReportGeneratorTool(
31
  cache_dir=model_dir, device=device
32
  ),
33
- "XRayPhraseGroundingTool": XRayPhraseGroundingTool(
34
  cache_dir=model_dir, temp_dir=temp_dir, load_in_8bit=True, device=device
35
  ),
36
- "ChestXRayGeneratorTool": ChestXRayGeneratorTool(
37
  model_path=f"{model_dir}/roentgen", temp_dir=temp_dir, device=device
38
  ),
39
- "ImageVisualizerTool": ImageVisualizerTool(),
40
- "DicomProcessorTool": DicomProcessorTool(temp_dir=temp_dir),
41
  }
42
 
 
 
 
 
 
 
 
43
  checkpointer = MemorySaver()
44
  model = ChatOpenAI(model="gpt-4o", temperature=0.7, top_p=0.95)
45
  agent = Agent(
@@ -56,12 +77,28 @@ def initialize_agent(prompt_file, model_dir="/model-weights", temp_dir="temp", d
56
 
57
 
58
  if __name__ == "__main__":
 
 
 
 
59
  print("Starting server...")
60
 
61
- # Setup model_dir to where you want to download the weights
62
- # Some tools needs you to download the weights beforehand from Hugging Face
 
 
 
 
 
 
 
 
 
 
 
 
63
  agent, tools_dict = initialize_agent(
64
- "medrax/docs/system_prompts.txt", model_dir="/model-weights"
65
  )
66
  demo = create_demo(agent, tools_dict)
67
 
 
18
  _ = load_dotenv()
19
 
20
 
21
+ def initialize_agent(
22
+ prompt_file, tools_to_use=None, model_dir="/model-weights", temp_dir="temp", device="cuda"
23
+ ):
24
+ """Initialize the MedRAX agent with specified tools and configuration.
25
+
26
+ Args:
27
+ prompt_file (str): Path to file containing system prompts
28
+ tools_to_use (List[str], optional): List of tool names to initialize. If None, all tools are initialized.
29
+ model_dir (str, optional): Directory containing model weights. Defaults to "/model-weights".
30
+ temp_dir (str, optional): Directory for temporary files. Defaults to "temp".
31
+ device (str, optional): Device to run models on. Defaults to "cuda".
32
+
33
+ Returns:
34
+ Tuple[Agent, Dict[str, BaseTool]]: Initialized agent and dictionary of tool instances
35
+ """
36
  prompts = load_prompts_from_file(prompt_file)
37
  prompt = prompts["MEDICAL_ASSISTANT"]
38
 
39
+ all_tools = {
40
+ "ChestXRayClassifierTool": lambda: ChestXRayClassifierTool(device=device),
41
+ "ChestXRaySegmentationTool": lambda: ChestXRaySegmentationTool(device=device),
42
+ "LlavaMedTool": lambda: LlavaMedTool(cache_dir=model_dir, device=device, load_in_8bit=True),
43
+ "XRayVQATool": lambda: XRayVQATool(cache_dir=model_dir, device=device),
44
+ "ChestXRayReportGeneratorTool": lambda: ChestXRayReportGeneratorTool(
45
  cache_dir=model_dir, device=device
46
  ),
47
+ "XRayPhraseGroundingTool": lambda: XRayPhraseGroundingTool(
48
  cache_dir=model_dir, temp_dir=temp_dir, load_in_8bit=True, device=device
49
  ),
50
+ "ChestXRayGeneratorTool": lambda: ChestXRayGeneratorTool(
51
  model_path=f"{model_dir}/roentgen", temp_dir=temp_dir, device=device
52
  ),
53
+ "ImageVisualizerTool": lambda: ImageVisualizerTool(),
54
+ "DicomProcessorTool": lambda: DicomProcessorTool(temp_dir=temp_dir),
55
  }
56
 
57
+ # Initialize only selected tools or all if none specified
58
+ tools_dict = {}
59
+ tools_to_use = tools_to_use or all_tools.keys()
60
+ for tool_name in tools_to_use:
61
+ if tool_name in all_tools:
62
+ tools_dict[tool_name] = all_tools[tool_name]()
63
+
64
  checkpointer = MemorySaver()
65
  model = ChatOpenAI(model="gpt-4o", temperature=0.7, top_p=0.95)
66
  agent = Agent(
 
77
 
78
 
79
  if __name__ == "__main__":
80
+ """
81
+ This is the main entry point for the MedRAX application.
82
+ It initializes the agent with the selected tools and creates the demo.
83
+ """
84
  print("Starting server...")
85
 
86
+ # Example: initialize with only specific tools
87
+ # Here three tools are commented out, you can uncomment them to use them
88
+ selected_tools = [
89
+ "ImageVisualizerTool",
90
+ "DicomProcessorTool",
91
+ "ChestXRayClassifierTool",
92
+ "ChestXRaySegmentationTool",
93
+ "ChestXRayReportGeneratorTool",
94
+ "XRayVQATool",
95
+ # "LlavaMedTool",
96
+ # "XRayPhraseGroundingTool",
97
+ # "ChestXRayGeneratorTool",
98
+ ]
99
+
100
  agent, tools_dict = initialize_agent(
101
+ "medrax/docs/system_prompts.txt", tools_to_use=selected_tools, model_dir="/model-weights"
102
  )
103
  demo = create_demo(agent, tools_dict)
104