Spaces:
Paused
Paused
fix: Load NV-Reason-CXR without dtype parameter to avoid JSON error
Browse files
The 'dtype' parameter in from_pretrained() causes JSON serialization
errors in our environment. Instead, load the model with default settings,
then manually convert to bfloat16 after loading.
This approach achieves the same result as NVIDIA's demo (bfloat16 model)
but avoids the JSON serialization TypeError.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
medrax/tools/nv_reason_cxr.py
CHANGED
@@ -72,12 +72,13 @@ class NVReasonCXRTool(BaseTool):

Before:

 72         print(f"Using device: {self.device}")
 73         print("Following NVIDIA's exact loading pattern from official demo")
 74
 75 -       # [removed comment — original text not captured in this extract]
 76 -       # [removed comment — original text not captured in this extract]
 77         self.model = AutoModelForImageTextToText.from_pretrained(
 78 -           [removed argument — original text not captured in this extract]
 79 -           [removed argument — original text not captured in this extract]
 80 -       ).eval().to(self.device)
 81
 82         self.processor = AutoProcessor.from_pretrained(
 83             model_path,

After:

 72         print(f"Using device: {self.device}")
 73         print("Following NVIDIA's exact loading pattern from official demo")
 74
 75 +       # NVIDIA's approach but adapted for our environment
 76 +       # The 'dtype' parameter works in their Gradio Space environment
 77 +       # but causes JSON serialization issues in our setup.
 78 +       # Solution: Load with default dtype, then convert to bfloat16 manually
 79         self.model = AutoModelForImageTextToText.from_pretrained(
 80 +           model_path,
 81 +       ).to(dtype=torch.bfloat16).eval().to(self.device)
 82
 83         self.processor = AutoProcessor.from_pretrained(
 84             model_path,