Spaces:
Sleeping
Sleeping
Commit ·
5a67c2d
1
Parent(s): 26e10b8
fix(custom_hf): prefer dtype over deprecated torch_dtype in from_pretrained
Browse files
forge/providers/hf_custom.py
CHANGED
|
@@ -2,6 +2,7 @@
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
|
|
|
| 5 |
import logging
|
| 6 |
from typing import Any
|
| 7 |
|
|
@@ -48,10 +49,13 @@ class HFCustomProvider:
|
|
| 48 |
|
| 49 |
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
|
| 50 |
device_map = "auto" if torch.cuda.is_available() else None
|
| 51 |
-
model_kw: dict[str, Any] = {
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
| 55 |
if self.hf_token:
|
| 56 |
model_kw["token"] = self.hf_token
|
| 57 |
if device_map:
|
|
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
+
import inspect
|
| 6 |
import logging
|
| 7 |
from typing import Any
|
| 8 |
|
|
|
|
| 49 |
|
| 50 |
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
|
| 51 |
device_map = "auto" if torch.cuda.is_available() else None
|
| 52 |
+
model_kw: dict[str, Any] = {"trust_remote_code": True}
|
| 53 |
+
# Newer transformers prefer `dtype`; older builds use `torch_dtype`.
|
| 54 |
+
_sig = inspect.signature(AutoModelForCausalLM.from_pretrained)
|
| 55 |
+
if "dtype" in _sig.parameters:
|
| 56 |
+
model_kw["dtype"] = dtype
|
| 57 |
+
else:
|
| 58 |
+
model_kw["torch_dtype"] = dtype
|
| 59 |
if self.hf_token:
|
| 60 |
model_kw["token"] = self.hf_token
|
| 61 |
if device_map:
|