Forrest Wargo committed on
Commit ·
7cc210d
1
Parent(s): 34b89db
Ensure single-device placement (cuda or cpu) to avoid index device mismatch
Browse files
- handler.py +9 -3
handler.py
CHANGED
|
@@ -53,12 +53,18 @@ class EndpointHandler:
|
|
| 53 |
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
|
| 54 |
|
| 55 |
# Load local repo (or remote if MODEL_ID points to hub id)
|
| 56 |
-
# Pass token when accessing gated repos
|
| 57 |
hub_token = os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("HF_HUB_TOKEN") or os.environ.get("HF_TOKEN")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
load_kwargs = {
|
| 59 |
"trust_remote_code": True,
|
| 60 |
-
"torch_dtype":
|
| 61 |
-
"device_map":
|
| 62 |
}
|
| 63 |
if hub_token:
|
| 64 |
load_kwargs["token"] = hub_token
|
|
|
|
| 53 |
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
|
| 54 |
|
| 55 |
# Load local repo (or remote if MODEL_ID points to hub id)
|
| 56 |
+
# Pass token when accessing gated repos and ensure consistent device placement
|
| 57 |
hub_token = os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("HF_HUB_TOKEN") or os.environ.get("HF_TOKEN")
|
| 58 |
+
if torch.cuda.is_available():
|
| 59 |
+
device_map = {"": "cuda"}
|
| 60 |
+
dtype = torch.bfloat16
|
| 61 |
+
else:
|
| 62 |
+
device_map = {"": "cpu"}
|
| 63 |
+
dtype = torch.float32
|
| 64 |
load_kwargs = {
|
| 65 |
"trust_remote_code": True,
|
| 66 |
+
"torch_dtype": dtype,
|
| 67 |
+
"device_map": device_map,
|
| 68 |
}
|
| 69 |
if hub_token:
|
| 70 |
load_kwargs["token"] = hub_token
|