import os
import sys
from pathlib import Path

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from timm.models.layers import trunc_normal_

# Add RETFound repo to path for imports
REPO_DIR = Path(__file__).parent / "RETFound_MAE"
sys.path.append(str(REPO_DIR))

from models_vit import RETFound_mae  # architecture builder
from util.pos_embed import interpolate_pos_embed


def build_classifier(num_classes: int,
                     base_repo: str,
                     base_filename: str,
                     global_pool: bool = True,
                     drop_path_rate: float = 0.2,
                     device: str | torch.device = "cpu") -> nn.Module:
    """Build a RETFound MAE backbone with a fresh linear classification head.

    The pretrained checkpoint is downloaded from the Hugging Face Hub, its
    position embeddings are passed through ``interpolate_pos_embed``, and its
    weights are loaded non-strictly into the model. Checkpoint head weights
    whose shape does not match the new head are dropped before loading. The
    head is then always re-initialized (truncated normal weights with
    ``std=2e-5``, zero bias), even if matching head weights were loaded.

    Args:
        num_classes: Output size of the new ``nn.Linear`` head.
        base_repo: Hub repository id holding the pretrained checkpoint.
        base_filename: Name of the checkpoint file inside ``base_repo``.
        global_pool: Forwarded to ``RETFound_mae``.
        drop_path_rate: Forwarded to ``RETFound_mae``.
        device: Device the finished model is moved to.

    Returns:
        The model with pretrained backbone weights and a re-initialized head,
        placed on ``device``.

    Raises:
        RuntimeError: If the checkpoint cannot be downloaded. The original
            exception is chained as ``__cause__``.
    """
    device = torch.device(device)

    # 1) Download pretrained MAE weights from the Hub.
    # The token is read from the environment; it is None when HF_TOKEN is unset.
    hf_token = os.getenv("HF_TOKEN")
    try:
        ckpt_path = hf_hub_download(
            repo_id=base_repo,
            filename=base_filename,
            token=hf_token,  # Works for private if token exists
            cache_dir="/tmp/hf_cache",  # Spaces-friendly cache
        )
    except Exception as e:
        # Chain the cause so the underlying download error stays in the traceback.
        raise RuntimeError(
            f"Failed to download model from {base_repo}: {e}"
        ) from e

    print(f"Loading RETFound MAE weights from {ckpt_path}...")

    # 2) Build backbone
    model = RETFound_mae(global_pool=global_pool, drop_path_rate=drop_path_rate)

    # 3) Replace the head with one sized for this task
    in_features = model.head.in_features
    model.head = nn.Linear(in_features, num_classes)

    # 4) Load checkpoint w/ position interpolation & head removal if mismatched
    # NOTE(review): torch.load deserializes via pickle unless weights_only=True
    # is in effect; only point base_repo at trusted repositories. Confirm the
    # checkpoint loads under the installed PyTorch's default before changing it.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    ckpt = checkpoint.get("model", checkpoint)  # handle both formats

    # Drop head weights whose shape does not fit the new head; a shape
    # mismatch would make load_state_dict raise even with strict=False.
    state_dict = model.state_dict()
    for k in ("head.weight", "head.bias"):
        if k in ckpt and k in state_dict and ckpt[k].shape != state_dict[k].shape:
            del ckpt[k]

    interpolate_pos_embed(model, ckpt)
    load_result = model.load_state_dict(ckpt, strict=False)
    # strict=False ignores key mismatches, so report them instead of loading
    # a partially initialized backbone without any trace.
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            f"Checkpoint keys not loaded - missing: {load_result.missing_keys}, "
            f"unexpected: {load_result.unexpected_keys}"
        )

    # Re-init head for classification
    trunc_normal_(model.head.weight, std=2e-5)
    if model.head.bias is not None:
        nn.init.zeros_(model.head.bias)

    model.to(device)
    return model