Forrest Wargo committed on
Commit
34b89db
·
1 Parent(s): 5daee26

Support gated model: pass HF token to from_pretrained

Browse files
Files changed (1) hide show
  1. handler.py +10 -3
handler.py CHANGED
@@ -53,11 +53,18 @@ class EndpointHandler:
53
  os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
54
 
55
  # Load local repo (or remote if MODEL_ID points to hub id)
 
 
 
 
 
 
 
 
 
56
  self.model = AutoModelForCausalLM.from_pretrained(
57
  model_id,
58
- trust_remote_code=True,
59
- torch_dtype=torch.bfloat16,
60
- device_map="auto",
61
  )
62
 
63
  # Optional compilation for speed if exposed by remote code
 
53
  os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
54
 
55
  # Load local repo (or remote if MODEL_ID points to hub id)
56
+ # Pass token when accessing gated repos
57
+ hub_token = os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("HF_HUB_TOKEN") or os.environ.get("HF_TOKEN")
58
+ load_kwargs = {
59
+ "trust_remote_code": True,
60
+ "torch_dtype": torch.bfloat16,
61
+ "device_map": "auto",
62
+ }
63
+ if hub_token:
64
+ load_kwargs["token"] = hub_token
65
  self.model = AutoModelForCausalLM.from_pretrained(
66
  model_id,
67
+ **load_kwargs,
 
 
68
  )
69
 
70
  # Optional compilation for speed if exposed by remote code