Spaces:
Sleeping
Sleeping
Commit
·
6d6c01e
1
Parent(s):
22dca62
Fix device handling: check GPU availability before device_map
Browse files
models.py
CHANGED
|
@@ -127,13 +127,20 @@ def load_model(
|
|
| 127 |
model_kwargs = {
|
| 128 |
"token": config.hf_token,
|
| 129 |
"trust_remote_code": True,
|
| 130 |
-
"device_map": "auto",
|
| 131 |
}
|
| 132 |
|
| 133 |
-
|
| 134 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
else:
|
| 136 |
-
|
|
|
|
|
|
|
| 137 |
|
| 138 |
model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
|
| 139 |
|
|
|
|
| 127 |
model_kwargs = {
|
| 128 |
"token": config.hf_token,
|
| 129 |
"trust_remote_code": True,
|
|
|
|
| 130 |
}
|
| 131 |
|
| 132 |
+
# On ZeroGPU, use device_map="auto" only when a GPU is available
|
| 133 |
+
# Otherwise load to CPU for local testing
|
| 134 |
+
if torch.cuda.is_available():
|
| 135 |
+
model_kwargs["device_map"] = "auto"
|
| 136 |
+
if quant_config is not None:
|
| 137 |
+
model_kwargs["quantization_config"] = quant_config
|
| 138 |
+
else:
|
| 139 |
+
model_kwargs["torch_dtype"] = torch.bfloat16
|
| 140 |
else:
|
| 141 |
+
# CPU mode - quant_config is not applied; weights load as float32
|
| 142 |
+
model_kwargs["device_map"] = "cpu"
|
| 143 |
+
model_kwargs["torch_dtype"] = torch.float32
|
| 144 |
|
| 145 |
model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
|
| 146 |
|