samwell Claude committed
Commit 7a6a9a6 · 1 parent: f49ba8b

fix: Load NV-Reason-CXR without dtype parameter to avoid JSON error


Passing the 'dtype' parameter to from_pretrained() causes a JSON serialization
TypeError in our environment. Instead, load the model with its default dtype,
then convert it to bfloat16 after loading.

This achieves the same result as NVIDIA's demo (a bfloat16 model) while
avoiding the serialization error.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (1)
  medrax/tools/nv_reason_cxr.py  +6 -5
medrax/tools/nv_reason_cxr.py CHANGED
```diff
@@ -72,12 +72,13 @@ class NVReasonCXRTool(BaseTool):
         print(f"Using device: {self.device}")
         print("Following NVIDIA's exact loading pattern from official demo")
 
-        # Follow NVIDIA's exact approach from their official Gradio demo
-        # Key: Use 'dtype' parameter (NOT 'torch_dtype' which is deprecated)
+        # NVIDIA's approach but adapted for our environment
+        # The 'dtype' parameter works in their Gradio Space environment
+        # but causes JSON serialization issues in our setup.
+        # Solution: Load with default dtype, then convert to bfloat16 manually
         self.model = AutoModelForImageTextToText.from_pretrained(
-            pretrained_model_name_or_path=model_path,
-            dtype=torch.bfloat16,
-        ).eval().to(self.device)
+            model_path,
+        ).to(dtype=torch.bfloat16).eval().to(self.device)
 
         self.processor = AutoProcessor.from_pretrained(
             model_path,
```
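
For reference, a minimal self-contained sketch of the resulting loading pattern; the checkpoint id, device selection, and variable names are illustrative assumptions, not part of this diff:

```python
# Minimal sketch of the new loading pattern (not the full tool class).
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor

model_path = "nvidia/NV-Reason-CXR-3B"  # assumption: substitute the real checkpoint id
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load with the default dtype (no dtype/torch_dtype kwarg), then cast to
# bfloat16 manually; this sidesteps the JSON serialization TypeError seen
# when passing dtype=torch.bfloat16 to from_pretrained() in this setup.
model = (
    AutoModelForImageTextToText.from_pretrained(model_path)
    .to(dtype=torch.bfloat16)
    .eval()
    .to(device)
)

processor = AutoProcessor.from_pretrained(model_path)
```

Note that .to(dtype=torch.bfloat16) casts every parameter and buffer in place, so the end state matches loading directly in bfloat16; the trade-off is that the weights are first materialized at their default precision, so peak memory during loading is higher.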