samwell and Claude committed
Commit e7a5afc · 1 Parent(s): 99d91bc

fix: Load grounding model without dtype parameter to avoid errors


Changed from passing the dtype parameter to from_pretrained to manually
converting the dtype after loading. This avoids potential JSON serialization
and compatibility issues with that parameter.

Pattern: load() → .to(dtype=bfloat16) → .eval() → .to(device)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (1)
  1. medrax/tools/grounding.py +3 −4
medrax/tools/grounding.py
@@ -67,14 +67,13 @@ class XRayPhraseGroundingTool(BaseTool):
         super().__init__()
         self.device = torch.device(device) if device else "cuda"
 
-        # Load model following transformers 4.56.0 API
-        # Use 'dtype' instead of deprecated 'torch_dtype'
+        # Load model - convert to bfloat16 after loading to avoid dtype parameter issues
+        # Load with default dtype, then manually convert to bfloat16
         self.model = AutoModelForCausalLM.from_pretrained(
             model_path,
             cache_dir=cache_dir,
             trust_remote_code=True,
-            dtype=torch.bfloat16,
-        ).eval().to(self.device)
+        ).to(dtype=torch.bfloat16).eval().to(self.device)
 
         self.processor = AutoProcessor.from_pretrained(
             model_path,
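The load() → .to(dtype=bfloat16) → .eval() → .to(device) pattern from the commit can be sketched in isolation. This is a minimal sketch using a plain nn.Linear as a hypothetical stand-in for the grounding model (no transformers download required); the chain of calls is the same one applied to from_pretrained's return value in the diff:

```python
import torch
import torch.nn as nn

# Stand-in for the loaded model: like from_pretrained with no dtype
# argument, a fresh nn.Module starts in the default dtype (float32).
model = nn.Linear(4, 2)
assert model.weight.dtype == torch.float32

# Pattern from the commit: load() -> .to(dtype=bfloat16) -> .eval() -> .to(device)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(dtype=torch.bfloat16).eval().to(device)

assert model.weight.dtype == torch.bfloat16  # weights converted after loading
assert not model.training                    # .eval() sets inference mode
```

Converting after loading sidesteps passing dtype through from_pretrained's kwargs, at the cost of briefly materializing the weights in their default dtype before the in-place conversion.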