Spaces:
Sleeping
Sleeping
Habeeb Okunade commited on
Commit ·
d8770e8
1
Parent(s): b66d70d
Updating model
Browse files- model_loader.py +17 -1
model_loader.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
from huggingface_hub import hf_hub_download
|
|
@@ -24,7 +25,22 @@ def build_classifier(num_classes: int,
|
|
| 24 |
device = torch.device(device)
|
| 25 |
|
| 26 |
# 1) Download pretrained MAE weights from the Hub
|
| 27 |
-
ckpt_path = hf_hub_download(repo_id=base_repo, filename=base_filename)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
# 2) Build backbone
|
| 30 |
model = RETFound_mae(global_pool=global_pool, drop_path_rate=drop_path_rate)
|
|
|
|
| 1 |
+
import os
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
from huggingface_hub import hf_hub_download
|
|
|
|
| 25 |
device = torch.device(device)
|
| 26 |
|
| 27 |
# 1) Download pretrained MAE weights from the Hub
|
| 28 |
+
#ckpt_path = hf_hub_download(repo_id=base_repo, filename=base_filename)
|
| 29 |
+
# Read token from env (if set)
|
| 30 |
+
hf_token = os.getenv("HF_TOKEN")
|
| 31 |
+
|
| 32 |
+
try:
|
| 33 |
+
ckpt_path = hf_hub_download(
|
| 34 |
+
repo_id=base_repo,
|
| 35 |
+
filename=base_filename,
|
| 36 |
+
token=hf_token, # Works for private if token exists
|
| 37 |
+
cache_dir="/tmp/hf_cache" # Spaces-friendly cache
|
| 38 |
+
)
|
| 39 |
+
except Exception as e:
|
| 40 |
+
raise RuntimeError(f"Failed to download model from {base_repo}: {e}")
|
| 41 |
+
|
| 42 |
+
# Load model weights
|
| 43 |
+
print(f"Loading RETFound MAE weights from {ckpt_path}...")
|
| 44 |
|
| 45 |
# 2) Build backbone
|
| 46 |
model = RETFound_mae(global_pool=global_pool, drop_path_rate=drop_path_rate)
|