samwell commited on
Commit
1f83b1b
Β·
1 Parent(s): 3b51954

Enable all medical imaging tools (VQA, classification, report generation, DICOM, web browsing)

Browse files
Files changed (1) hide show
  1. app.py +54 -0
app.py CHANGED
@@ -33,6 +33,7 @@ print(f"Using device: {device}")
33
  tools = []
34
 
35
  if device == "cuda":
 
36
  try:
37
  from medrax.tools import XRayPhraseGroundingTool
38
  grounding_tool = XRayPhraseGroundingTool(
@@ -45,6 +46,59 @@ if device == "cuda":
45
  except Exception as e:
46
  print(f"βœ— Failed to load grounding tool: {e}")
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  checkpointer = MemorySaver()
49
 
50
  llm = ModelFactory.create_model(
 
33
  tools = []
34
 
35
  if device == "cuda":
36
+ # Load GPU-based tools
37
  try:
38
  from medrax.tools import XRayPhraseGroundingTool
39
  grounding_tool = XRayPhraseGroundingTool(
 
46
  except Exception as e:
47
  print(f"βœ— Failed to load grounding tool: {e}")
48
 
49
+ try:
50
+ from medrax.tools import XRayVQATool
51
+ vqa_tool = XRayVQATool(
52
+ device=device,
53
+ temp_dir="temp",
54
+ load_in_4bit=True
55
+ )
56
+ tools.append(vqa_tool)
57
+ print("βœ“ Loaded VQA tool")
58
+ except Exception as e:
59
+ print(f"βœ— Failed to load VQA tool: {e}")
60
+
61
+ try:
62
+ from medrax.tools import XRayClassificationTool
63
+ classification_tool = XRayClassificationTool(
64
+ device=device,
65
+ temp_dir="temp",
66
+ load_in_4bit=True
67
+ )
68
+ tools.append(classification_tool)
69
+ print("βœ“ Loaded classification tool")
70
+ except Exception as e:
71
+ print(f"βœ— Failed to load classification tool: {e}")
72
+
73
+ try:
74
+ from medrax.tools import XRayReportGenerationTool
75
+ report_tool = XRayReportGenerationTool(
76
+ device=device,
77
+ temp_dir="temp",
78
+ load_in_4bit=True
79
+ )
80
+ tools.append(report_tool)
81
+ print("βœ“ Loaded report generation tool")
82
+ except Exception as e:
83
+ print(f"βœ— Failed to load report generation tool: {e}")
84
+
85
+ # Load non-GPU tools
86
+ try:
87
+ from medrax.tools import DICOMTool
88
+ dicom_tool = DICOMTool(temp_dir="temp")
89
+ tools.append(dicom_tool)
90
+ print("βœ“ Loaded DICOM tool")
91
+ except Exception as e:
92
+ print(f"βœ— Failed to load DICOM tool: {e}")
93
+
94
+ try:
95
+ from medrax.tools import WebBrowsingTool
96
+ browsing_tool = WebBrowsingTool()
97
+ tools.append(browsing_tool)
98
+ print("βœ“ Loaded web browsing tool")
99
+ except Exception as e:
100
+ print(f"βœ— Failed to load web browsing tool: {e}")
101
+
102
  checkpointer = MemorySaver()
103
 
104
  llm = ModelFactory.create_model(