Forrest Wargo commited on
Commit
7cc210d
·
1 Parent(s): 34b89db

Ensure single-device placement (cuda or cpu) to avoid index device mismatch

Browse files
Files changed (1) hide show
  1. handler.py +9 -3
handler.py CHANGED
@@ -53,12 +53,18 @@ class EndpointHandler:
53
  os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
54
 
55
  # Load local repo (or remote if MODEL_ID points to hub id)
56
- # Pass token when accessing gated repos
57
  hub_token = os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("HF_HUB_TOKEN") or os.environ.get("HF_TOKEN")
 
 
 
 
 
 
58
  load_kwargs = {
59
  "trust_remote_code": True,
60
- "torch_dtype": torch.bfloat16,
61
- "device_map": "auto",
62
  }
63
  if hub_token:
64
  load_kwargs["token"] = hub_token
 
53
  os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
54
 
55
  # Load local repo (or remote if MODEL_ID points to hub id)
56
+ # Pass token when accessing gated repos and ensure consistent device placement
57
  hub_token = os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("HF_HUB_TOKEN") or os.environ.get("HF_TOKEN")
58
+ if torch.cuda.is_available():
59
+ device_map = {"": "cuda"}
60
+ dtype = torch.bfloat16
61
+ else:
62
+ device_map = {"": "cpu"}
63
+ dtype = torch.float32
64
  load_kwargs = {
65
  "trust_remote_code": True,
66
+ "torch_dtype": dtype,
67
+ "device_map": device_map,
68
  }
69
  if hub_token:
70
  load_kwargs["token"] = hub_token