samwell Claude commited on
Commit
b8918e3
Β·
1 Parent(s): 8bd044e

Fix tool loading errors in MedRAX2

Browse files

- Fix NV-Reason-CXR quantization config handling to avoid dict.to_dict() error
- Correct tool import paths to use actual class names:
- CheXagentXRayVQATool instead of XRayVQATool
- TorchXRayVisionClassifierTool instead of XRayClassificationTool
- ChestXRayReportGeneratorTool instead of XRayReportGenerationTool
- DicomProcessorTool instead of DICOMTool
- WebBrowserTool instead of WebBrowsingTool
- Add better error handling for model loading

πŸ€– Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (2) hide show
  1. app.py +10 -10
  2. medrax/tools/nv_reason_cxr.py +42 -27
app.py CHANGED
@@ -58,8 +58,8 @@ if device == "cuda":
58
  print(f"βœ— Failed to load grounding tool: {e}")
59
 
60
  try:
61
- from medrax.tools import XRayVQATool
62
- vqa_tool = XRayVQATool(
63
  device=device,
64
  temp_dir="temp",
65
  load_in_4bit=True
@@ -70,8 +70,8 @@ if device == "cuda":
70
  print(f"βœ— Failed to load VQA tool: {e}")
71
 
72
  try:
73
- from medrax.tools import XRayClassificationTool
74
- classification_tool = XRayClassificationTool(
75
  device=device,
76
  temp_dir="temp",
77
  load_in_4bit=True
@@ -82,8 +82,8 @@ if device == "cuda":
82
  print(f"βœ— Failed to load classification tool: {e}")
83
 
84
  try:
85
- from medrax.tools import XRayReportGenerationTool
86
- report_tool = XRayReportGenerationTool(
87
  device=device,
88
  temp_dir="temp",
89
  load_in_4bit=True
@@ -95,16 +95,16 @@ if device == "cuda":
95
 
96
  # Load non-GPU tools
97
  try:
98
- from medrax.tools import DICOMTool
99
- dicom_tool = DICOMTool(temp_dir="temp")
100
  tools.append(dicom_tool)
101
  print("βœ“ Loaded DICOM tool")
102
  except Exception as e:
103
  print(f"βœ— Failed to load DICOM tool: {e}")
104
 
105
  try:
106
- from medrax.tools import WebBrowsingTool
107
- browsing_tool = WebBrowsingTool()
108
  tools.append(browsing_tool)
109
  print("βœ“ Loaded web browsing tool")
110
  except Exception as e:
 
58
  print(f"βœ— Failed to load grounding tool: {e}")
59
 
60
  try:
61
+ from medrax.tools.vqa import CheXagentXRayVQATool
62
+ vqa_tool = CheXagentXRayVQATool(
63
  device=device,
64
  temp_dir="temp",
65
  load_in_4bit=True
 
70
  print(f"βœ— Failed to load VQA tool: {e}")
71
 
72
  try:
73
+ from medrax.tools.classification import TorchXRayVisionClassifierTool
74
+ classification_tool = TorchXRayVisionClassifierTool(
75
  device=device,
76
  temp_dir="temp",
77
  load_in_4bit=True
 
82
  print(f"βœ— Failed to load classification tool: {e}")
83
 
84
  try:
85
+ from medrax.tools.report_generation import ChestXRayReportGeneratorTool
86
+ report_tool = ChestXRayReportGeneratorTool(
87
  device=device,
88
  temp_dir="temp",
89
  load_in_4bit=True
 
95
 
96
  # Load non-GPU tools
97
  try:
98
+ from medrax.tools.dicom import DicomProcessorTool
99
+ dicom_tool = DicomProcessorTool(temp_dir="temp")
100
  tools.append(dicom_tool)
101
  print("βœ“ Loaded DICOM tool")
102
  except Exception as e:
103
  print(f"βœ— Failed to load DICOM tool: {e}")
104
 
105
  try:
106
+ from medrax.tools.browsing import WebBrowserTool
107
+ browsing_tool = WebBrowserTool()
108
  tools.append(browsing_tool)
109
  print("βœ“ Loaded web browsing tool")
110
  except Exception as e:
medrax/tools/nv_reason_cxr.py CHANGED
@@ -67,35 +67,50 @@ class NVReasonCXRTool(BaseTool):
67
  self.device = device
68
 
69
  # Setup quantization config
70
- if load_in_4bit:
71
- quantization_config = BitsAndBytesConfig(
72
- load_in_4bit=True,
73
- bnb_4bit_compute_dtype=torch.bfloat16,
74
- bnb_4bit_use_double_quant=True,
75
- bnb_4bit_quant_type="nf4",
76
- )
77
- else:
78
- quantization_config = None
 
 
 
79
 
80
  # Load model
81
- print(f"Loading NV-Reason-CXR model from {model_path}...")
82
- self.model = AutoModelForImageTextToText.from_pretrained(
83
- model_path,
84
- device_map=self.device,
85
- cache_dir=cache_dir,
86
- torch_dtype=torch.bfloat16,
87
- quantization_config=quantization_config,
88
- trust_remote_code=True,
89
- ).eval()
90
-
91
- self.processor = AutoProcessor.from_pretrained(
92
- model_path,
93
- cache_dir=cache_dir,
94
- trust_remote_code=True,
95
- use_fast=True,
96
- )
97
-
98
- print(f"βœ“ NV-Reason-CXR model loaded successfully")
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
  def _run(
101
  self,
 
67
  self.device = device
68
 
69
  # Setup quantization config
70
+ quantization_config = None
71
+ if load_in_4bit and device == "cuda":
72
+ try:
73
+ quantization_config = BitsAndBytesConfig(
74
+ load_in_4bit=True,
75
+ bnb_4bit_compute_dtype=torch.bfloat16,
76
+ bnb_4bit_use_double_quant=True,
77
+ bnb_4bit_quant_type="nf4",
78
+ )
79
+ except Exception as e:
80
+ print(f"Warning: Could not setup 4-bit quantization: {e}")
81
+ quantization_config = None
82
 
83
  # Load model
84
+ try:
85
+ print(f"Loading NV-Reason-CXR model from {model_path}...")
86
+
87
+ # Load without quantization config if it's causing issues
88
+ model_kwargs = {
89
+ "device_map": self.device,
90
+ "cache_dir": cache_dir,
91
+ "torch_dtype": torch.bfloat16,
92
+ "trust_remote_code": True,
93
+ }
94
+
95
+ if quantization_config is not None:
96
+ model_kwargs["quantization_config"] = quantization_config
97
+
98
+ self.model = AutoModelForImageTextToText.from_pretrained(
99
+ model_path,
100
+ **model_kwargs
101
+ ).eval()
102
+
103
+ self.processor = AutoProcessor.from_pretrained(
104
+ model_path,
105
+ cache_dir=cache_dir,
106
+ trust_remote_code=True,
107
+ use_fast=True,
108
+ )
109
+
110
+ print(f"βœ“ NV-Reason-CXR model loaded successfully")
111
+ except Exception as e:
112
+ print(f"Error loading NV-Reason-CXR model: {e}")
113
+ raise
114
 
115
  def _run(
116
  self,