serenichron committed on
Commit
6d6c01e
·
1 Parent(s): 22dca62

Fix device handling: check GPU availability before device_map

Browse files
Files changed (1) hide show
  1. models.py +11 -4
models.py CHANGED
@@ -127,13 +127,20 @@ def load_model(
     model_kwargs = {
         "token": config.hf_token,
         "trust_remote_code": True,
-        "device_map": "auto",
     }
 
-    if quant_config is not None:
-        model_kwargs["quantization_config"] = quant_config
+    # On ZeroGPU, use device_map only when GPU is available
+    # Otherwise load to CPU for local testing
+    if torch.cuda.is_available():
+        model_kwargs["device_map"] = "auto"
+        if quant_config is not None:
+            model_kwargs["quantization_config"] = quant_config
+        else:
+            model_kwargs["torch_dtype"] = torch.bfloat16
     else:
-        model_kwargs["torch_dtype"] = torch.bfloat16
+        # CPU mode - no quantization, float32
+        model_kwargs["device_map"] = "cpu"
+        model_kwargs["torch_dtype"] = torch.float32
 
     model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
 