sanjay7676 commited on
Commit
5a67c2d
·
1 Parent(s): 26e10b8

fix(custom_hf): prefer dtype over deprecated torch_dtype in from_pretrained

Browse files
Files changed (1) hide show
  1. forge/providers/hf_custom.py +8 -4
forge/providers/hf_custom.py CHANGED
@@ -2,6 +2,7 @@
2
 
3
  from __future__ import annotations
4
 
 
5
  import logging
6
  from typing import Any
7
 
@@ -48,10 +49,13 @@ class HFCustomProvider:
48
 
49
  dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
50
  device_map = "auto" if torch.cuda.is_available() else None
51
- model_kw: dict[str, Any] = {
52
- "trust_remote_code": True,
53
- "torch_dtype": dtype,
54
- }
 
 
 
55
  if self.hf_token:
56
  model_kw["token"] = self.hf_token
57
  if device_map:
 
2
 
3
  from __future__ import annotations
4
 
5
+ import inspect
6
  import logging
7
  from typing import Any
8
 
 
49
 
50
  dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
51
  device_map = "auto" if torch.cuda.is_available() else None
52
+ model_kw: dict[str, Any] = {"trust_remote_code": True}
53
+ # Newer transformers prefer `dtype`; older builds use `torch_dtype`.
54
+ _sig = inspect.signature(AutoModelForCausalLM.from_pretrained)
55
+ if "dtype" in _sig.parameters:
56
+ model_kw["dtype"] = dtype
57
+ else:
58
+ model_kw["torch_dtype"] = dtype
59
  if self.hf_token:
60
  model_kw["token"] = self.hf_token
61
  if device_map: