Habeeb Okunade commited on
Commit
d8770e8
·
1 Parent(s): b66d70d

Updating model

Browse files
Files changed (1) hide show
  1. model_loader.py +17 -1
model_loader.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import torch
2
  import torch.nn as nn
3
  from huggingface_hub import hf_hub_download
@@ -24,7 +25,22 @@ def build_classifier(num_classes: int,
24
  device = torch.device(device)
25
 
26
  # 1) Download pretrained MAE weights from the Hub
27
- ckpt_path = hf_hub_download(repo_id=base_repo, filename=base_filename)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  # 2) Build backbone
30
  model = RETFound_mae(global_pool=global_pool, drop_path_rate=drop_path_rate)
 
1
+ import os
2
  import torch
3
  import torch.nn as nn
4
  from huggingface_hub import hf_hub_download
 
25
  device = torch.device(device)
26
 
27
  # 1) Download pretrained MAE weights from the Hub
28
+ #ckpt_path = hf_hub_download(repo_id=base_repo, filename=base_filename)
29
+ # Read token from env (if set)
30
+ hf_token = os.getenv("HF_TOKEN")
31
+
32
+ try:
33
+ ckpt_path = hf_hub_download(
34
+ repo_id=base_repo,
35
+ filename=base_filename,
36
+ token=hf_token, # Works for private if token exists
37
+ cache_dir="/tmp/hf_cache" # Spaces-friendly cache
38
+ )
39
+ except Exception as e:
40
+ raise RuntimeError(f"Failed to download model from {base_repo}: {e}")
41
+
42
+ # Load model weights
43
+ print(f"Loading RETFound MAE weights from {ckpt_path}...")
44
 
45
  # 2) Build backbone
46
  model = RETFound_mae(global_pool=global_pool, drop_path_rate=drop_path_rate)