samwell Claude committed · Commit 37bdbfa · 1 Parent(s): e7a5afc

fix: Downgrade transformers to 4.51.3 to fix grounding tool


Problem: MAIRA-2 grounding tool requires transformers <4.52, but we
upgraded to 4.56.0 for NV-Reason-CXR, breaking grounding.

Solution: Prioritize grounding tool (more useful for visualization)
- Downgrade transformers from 4.56.0 → 4.51.3
- Revert tokenizers from >=0.22 → >=0.21,<0.22
- Revert grounding tool to use torch_dtype with transformers 4.51.3
- Disable NV-Reason-CXR tool (incompatible with transformers 4.51.3)

Now 7 tools will work:
✓ Grounding (MAIRA-2) - Now working
✓ Segmentation
✓ Classification
✓ VQA (CheXagent)
✓ Report Generation
✓ DICOM Processor
✓ Web Browser
⊗ NV-Reason-CXR - Disabled (transformers conflict)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (3)
  1. app.py +14 -12
  2. medrax/tools/grounding.py +5 -3
  3. requirements.txt +2 -2
app.py CHANGED
@@ -38,18 +38,20 @@ tools = []
 if device == "cuda":
     # Load GPU-based tools

-    # NV-Reason-CXR - Re-enabled for L40S (48GB VRAM)
-    # Quantization disabled due to dict.to_dict() error - runs in bfloat16 (~7GB)
-    try:
-        from medrax.tools import NVReasonCXRTool
-        nv_reason_tool = NVReasonCXRTool(
-            device=device,
-            load_in_4bit=False  # Disabled - causes BitsAndBytesConfig error
-        )
-        tools.append(nv_reason_tool)
-        print("✓ Loaded NV-Reason-CXR tool")
-    except Exception as e:
-        print(f"✗ Failed to load NV-Reason-CXR tool: {e}")
+    # NV-Reason-CXR - Disabled due to transformers version conflict
+    # Requires transformers 4.56.0, but MAIRA-2 (grounding) requires <4.52
+    # Prioritizing grounding tool for visualization
+    # try:
+    #     from medrax.tools import NVReasonCXRTool
+    #     nv_reason_tool = NVReasonCXRTool(
+    #         device=device,
+    #         load_in_4bit=False
+    #     )
+    #     tools.append(nv_reason_tool)
+    #     print("✓ Loaded NV-Reason-CXR tool")
+    # except Exception as e:
+    #     print(f"✗ Failed to load NV-Reason-CXR tool: {e}")
+    print("⊗ NV-Reason-CXR tool disabled (transformers version conflict)")

     # MAIRA-2 Grounding - Re-enabled for L40S (48GB VRAM)
     try:
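The per-tool try/except blocks above keep one broken tool from blocking the rest of the app. A minimal standalone sketch of that pattern (the names here are illustrative, not the MedRAX API):

```python
# Sketch of the optional-tool loading pattern used in app.py:
# each tool loads inside its own try/except, so a single failure
# only produces a log line instead of crashing startup.
def load_tools(loaders):
    """loaders: list of (name, zero-arg callable) pairs.
    Returns (loaded tools, status log)."""
    tools, log = [], []
    for name, loader in loaders:
        try:
            tools.append(loader())
            log.append(f"✓ Loaded {name}")
        except Exception as e:
            log.append(f"✗ Failed to load {name}: {e}")
    return tools, log
```

Disabling a tool then becomes a matter of dropping its entry from the list (or, as in this commit, replacing the try block with a single status print).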
medrax/tools/grounding.py CHANGED
@@ -67,13 +67,15 @@ class XRayPhraseGroundingTool(BaseTool):
         super().__init__()
         self.device = torch.device(device) if device else "cuda"

-        # Load model - convert to bfloat16 after loading to avoid dtype parameter issues
-        # Load with default dtype, then manually convert to bfloat16
+        # Load model without quantization - works with transformers 4.51.3
         self.model = AutoModelForCausalLM.from_pretrained(
             model_path,
+            device_map=self.device,
             cache_dir=cache_dir,
             trust_remote_code=True,
-        ).to(dtype=torch.bfloat16).eval().to(self.device)
+            torch_dtype=torch.bfloat16,
+        )
+        self.model = self.model.eval()

         self.processor = AutoProcessor.from_pretrained(
             model_path,
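The `torch_dtype` revert reflects a `from_pretrained` keyword change in newer transformers releases, where `dtype` is favored over the older `torch_dtype`. A hedged helper sketch that picks the keyword by installed version; the 4.56.0 cutoff is an assumption drawn from this commit, not from the transformers changelog:

```python
# Hypothetical helper: choose the from_pretrained dtype keyword for a
# given transformers version. Assumes the torch_dtype -> dtype switch
# applies from the 4.56 line, consistent with the revert in this commit.
from packaging.version import Version

def dtype_kwarg(transformers_version: str) -> str:
    """Return "dtype" for newer transformers, "torch_dtype" for 4.51.x."""
    if Version(transformers_version) >= Version("4.56.0"):
        return "dtype"
    return "torch_dtype"

print(dtype_kwarg("4.51.3"))
print(dtype_kwarg("4.56.0"))
```

With the pin at 4.51.3, passing `torch_dtype=torch.bfloat16` directly to `from_pretrained` avoids the post-load `.to(dtype=...)` conversion entirely.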
requirements.txt CHANGED
@@ -21,9 +21,9 @@ Pillow>=8.0.0
 PyPDF2>=3.0.0
 pdfplumber>=0.10.0
 torchxrayvision>=0.0.37
-transformers==4.56.0
+transformers==4.51.3
 datasets>=2.15.0
-tokenizers>=0.22,<0.24
+tokenizers>=0.21,<0.22
 sentencepiece>=0.1.95
 shortuuid>=1.0.0
 tqdm>=4.64.0
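The pin can be sanity-checked against both tools' stated constraints. A small sketch using `packaging`, where the constraint strings restate the commit message rather than any tool's published metadata:

```python
# Sketch: verify the transformers pin against the two tools' needs
# as described in this commit message.
from packaging.specifiers import SpecifierSet
from packaging.version import Version

maira2 = SpecifierSet("<4.52")        # MAIRA-2 grounding requirement
nv_reason = SpecifierSet("==4.56.0")  # NV-Reason-CXR requirement

pinned = Version("4.51.3")
print(pinned in maira2)     # grounding satisfied
print(pinned in nv_reason)  # NV-Reason-CXR not satisfied
```

No single version satisfies both specifiers, which is why the commit disables NV-Reason-CXR rather than juggling two transformers installs.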