# Retina_Training / model_loader.py
# Author: Habeeb Okunade
# Commit d8770e8 — "Updating model"
import os
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
import sys
from pathlib import Path
# Add RETFound repo to path for imports
REPO_DIR = Path(__file__).parent / "RETFound_MAE"
sys.path.append(str(REPO_DIR))
from models_vit import RETFound_mae # architecture builder
from util.pos_embed import interpolate_pos_embed
from timm.models.layers import trunc_normal_
def build_classifier(num_classes: int,
                     base_repo: str,
                     base_filename: str,
                     global_pool: bool = True,
                     drop_path_rate: float = 0.2,
                     device: str | torch.device = "cpu") -> nn.Module:
    """Load the RETFound MAE backbone, attach a linear classification head,
    and load pre-trained weights (excluding a shape-mismatched head).

    Args:
        num_classes: Output dimension of the new classification head.
        base_repo: Hugging Face Hub repo id holding the pretrained checkpoint.
        base_filename: Checkpoint filename within that repo.
        global_pool: Passed through to the RETFound backbone builder.
        drop_path_rate: Stochastic-depth rate for the backbone.
        device: Target device for the returned model.

    Returns:
        The assembled ``nn.Module`` moved to ``device``.

    Raises:
        RuntimeError: If the checkpoint download from the Hub fails.
    """
    device = torch.device(device)

    # 1) Download pretrained MAE weights from the Hub.
    #    HF_TOKEN (if set) enables private repos; /tmp cache is Spaces-friendly.
    hf_token = os.getenv("HF_TOKEN")
    try:
        ckpt_path = hf_hub_download(
            repo_id=base_repo,
            filename=base_filename,
            token=hf_token,
            cache_dir="/tmp/hf_cache",
        )
    except Exception as e:
        # Chain the cause so the original HTTP/auth error is not lost.
        raise RuntimeError(f"Failed to download model from {base_repo}: {e}") from e

    print(f"Loading RETFound MAE weights from {ckpt_path}...")

    # 2) Build backbone.
    model = RETFound_mae(global_pool=global_pool, drop_path_rate=drop_path_rate)

    # 3) Replace the head with a fresh linear layer sized for this task.
    in_features = model.head.in_features
    model.head = nn.Linear(in_features, num_classes)

    # 4) Load checkpoint with position-embedding interpolation and head
    #    removal when shapes don't match.
    # NOTE(review): torch >= 2.6 defaults weights_only=True, which can reject
    # MAE checkpoints; pass weights_only=False explicitly if on that version.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    ckpt = checkpoint.get("model", checkpoint)  # handle wrapped and raw formats

    # Drop pretrained head weights whose shape differs from the new head,
    # so load_state_dict doesn't fail on the size mismatch.
    state_dict = model.state_dict()
    for k in ("head.weight", "head.bias"):
        if k in ckpt and k in state_dict and ckpt[k].shape != state_dict[k].shape:
            del ckpt[k]

    # Resize position embeddings if the checkpoint grid differs from the model's.
    interpolate_pos_embed(model, ckpt)
    msg = model.load_state_dict(ckpt, strict=False)
    # Surface missing/unexpected keys instead of silently discarding them.
    print(f"Checkpoint load result: {msg}")

    # Re-initialize the head with a small std (RETFound fine-tuning recipe).
    trunc_normal_(model.head.weight, std=2e-5)
    if hasattr(model.head, 'bias') and model.head.bias is not None:
        nn.init.zeros_(model.head.bias)

    return model.to(device)