Forrest Wargo committed on
Commit ·
34b89db
1
Parent(s): 5daee26
Support gated model: pass HF token to from_pretrained
Browse files- handler.py +10 -3
handler.py
CHANGED
|
@@ -53,11 +53,18 @@ class EndpointHandler:
|
|
| 53 |
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
|
| 54 |
|
| 55 |
# Load local repo (or remote if MODEL_ID points to hub id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
self.model = AutoModelForCausalLM.from_pretrained(
|
| 57 |
model_id,
|
| 58 |
-
|
| 59 |
-
torch_dtype=torch.bfloat16,
|
| 60 |
-
device_map="auto",
|
| 61 |
)
|
| 62 |
|
| 63 |
# Optional compilation for speed if exposed by remote code
|
|
|
|
| 53 |
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
|
| 54 |
|
| 55 |
# Load local repo (or remote if MODEL_ID points to hub id)
|
| 56 |
+
# Pass token when accessing gated repos
|
| 57 |
+
hub_token = os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("HF_HUB_TOKEN") or os.environ.get("HF_TOKEN")
|
| 58 |
+
load_kwargs = {
|
| 59 |
+
"trust_remote_code": True,
|
| 60 |
+
"torch_dtype": torch.bfloat16,
|
| 61 |
+
"device_map": "auto",
|
| 62 |
+
}
|
| 63 |
+
if hub_token:
|
| 64 |
+
load_kwargs["token"] = hub_token
|
| 65 |
self.model = AutoModelForCausalLM.from_pretrained(
|
| 66 |
model_id,
|
| 67 |
+
**load_kwargs,
|
|
|
|
|
|
|
| 68 |
)
|
| 69 |
|
| 70 |
# Optional compilation for speed if exposed by remote code
|