samwell Claude committed on
Commit 6a7c30f · 1 Parent(s): 7a6a9a6

fix: Upgrade transformers to 4.56.0 to fix NV-Reason-CXR dtype error


Root cause: We were using transformers 4.51.3, while NVIDIA's official
demo uses 4.56.0. In the older version, passing the 'dtype' parameter to
from_pretrained triggers JSON serialization errors.

Solution:
- Upgrade transformers from 4.51.3 to 4.56.0 (matches NVIDIA demo)
- Use NVIDIA's EXACT loading code with dtype=torch.bfloat16
- This is the same version and code that works in NVIDIA's official Space

This should finally resolve the persistent NV-Reason-CXR loading error.

Reference: https://huggingface.co/spaces/nvidia/nv-reason-cxr
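Since the fix hinges on the installed transformers version, a startup guard can fail fast with a clear message instead of a cryptic serialization error. A minimal sketch (the helper names here are illustrative, not part of the repo):

```python
# Sketch: compare the installed transformers version against the 4.56.0 pin
# before attempting to load NV-Reason-CXR. Helper names are illustrative.

def parse_version(v: str) -> tuple:
    """Parse a 'major.minor.patch' string into a comparable tuple of ints."""
    return tuple(int(part) for part in v.split("."))

def meets_pin(installed: str, required: str = "4.56.0") -> bool:
    """True if the installed version is at least the pinned version."""
    return parse_version(installed) >= parse_version(required)

print(meets_pin("4.51.3"))  # False - the old pin predates the fix
print(meets_pin("4.56.0"))  # True  - matches NVIDIA's demo
```

In practice the check would read `transformers.__version__` and raise a descriptive error before `from_pretrained` is ever called.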

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

medrax/tools/nv_reason_cxr.py CHANGED
@@ -66,19 +66,19 @@ class NVReasonCXRTool(BaseTool):
         super().__init__()
         self.device = device
 
-        # Load model following NVIDIA's official demo code
+        # Load model following NVIDIA's official demo code EXACTLY
+        # Requires transformers==4.56.0 (same as NVIDIA's demo)
         try:
             print(f"Loading NV-Reason-CXR model from {model_path}...")
             print(f"Using device: {self.device}")
-            print("Following NVIDIA's exact loading pattern from official demo")
+            print("Using NVIDIA's exact loading pattern with transformers 4.56.0")
 
-            # NVIDIA's approach but adapted for our environment
-            # The 'dtype' parameter works in their Gradio Space environment
-            # but causes JSON serialization issues in our setup.
-            # Solution: Load with default dtype, then convert to bfloat16 manually
+            # Match NVIDIA's demo exactly - requires transformers 4.56.0
+            # The dtype parameter works correctly in newer transformers versions
             self.model = AutoModelForImageTextToText.from_pretrained(
-                model_path,
-            ).to(dtype=torch.bfloat16).eval().to(self.device)
+                pretrained_model_name_or_path=model_path,
+                dtype=torch.bfloat16,
+            ).eval().to(self.device)
 
             self.processor = AutoProcessor.from_pretrained(
                 model_path,
requirements.txt CHANGED
@@ -21,7 +21,7 @@ Pillow>=8.0.0
 PyPDF2>=3.0.0
 pdfplumber>=0.10.0
 torchxrayvision>=0.0.37
-transformers==4.51.3
+transformers==4.56.0
 datasets>=2.15.0
 tokenizers>=0.21,<0.22
 sentencepiece>=0.1.95