Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import torch.nn as nn | |
| from huggingface_hub import hf_hub_download | |
| import sys | |
| from pathlib import Path | |
| # Add RETFound repo to path for imports | |
| REPO_DIR = Path(__file__).parent / "RETFound_MAE" | |
| sys.path.append(str(REPO_DIR)) | |
| from models_vit import RETFound_mae # architecture builder | |
| from util.pos_embed import interpolate_pos_embed | |
| from timm.models.layers import trunc_normal_ | |
def build_classifier(num_classes: int,
                     base_repo: str,
                     base_filename: str,
                     global_pool: bool = True,
                     drop_path_rate: float = 0.2,
                     device: str | torch.device = "cpu") -> nn.Module:
    """Build a RETFound classifier initialised from a Hub checkpoint.

    Downloads ``base_filename`` from ``base_repo``, builds the RETFound
    backbone, replaces its head with a ``num_classes``-way linear layer and
    loads every checkpoint tensor that fits the model (non-strict load).

    Args:
        num_classes: Output size of the new linear classification head.
        base_repo: Hugging Face Hub repository id holding the checkpoint.
        base_filename: Name of the checkpoint file inside ``base_repo``.
        global_pool: Forwarded to ``RETFound_mae``.
        drop_path_rate: Forwarded to ``RETFound_mae``.
        device: Device the finished model is moved to.

    Returns:
        The model with the new head, moved to ``device``.

    Raises:
        RuntimeError: If the checkpoint cannot be downloaded; the original
            exception is attached as ``__cause__``.
    """
    device = torch.device(device)

    # 1) Download the checkpoint from the Hub. HF_TOKEN is optional: when the
    #    variable is unset, ``None`` is passed as the token.
    hf_token = os.getenv("HF_TOKEN")
    try:
        ckpt_path = hf_hub_download(
            repo_id=base_repo,
            filename=base_filename,
            token=hf_token,
            cache_dir="/tmp/hf_cache",  # Spaces-friendly cache location
        )
    except Exception as e:
        # Chain the cause so the underlying Hub/network error stays visible.
        raise RuntimeError(
            f"Failed to download model from {base_repo}: {e}"
        ) from e
    print(f"Loading RETFound MAE weights from {ckpt_path}...")

    # 2) Build the backbone.
    model = RETFound_mae(global_pool=global_pool, drop_path_rate=drop_path_rate)

    # 3) Swap in a classification head of the requested width.
    model.head = nn.Linear(model.head.in_features, num_classes)

    # 4) Read the checkpoint on CPU; the model is moved to ``device`` at the end.
    # NOTE(review): torch.load unpickles the downloaded file, so ``base_repo``
    # should only ever point at a trusted repository.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    # Accept both {"model": state_dict} wrappers and bare state dicts.
    ckpt = checkpoint.get("model", checkpoint)

    # Drop checkpoint head tensors whose shape does not match the new head;
    # load_state_dict raises on size mismatches even with strict=False.
    state_dict = model.state_dict()
    for k in ("head.weight", "head.bias"):
        if k in ckpt and k in state_dict and ckpt[k].shape != state_dict[k].shape:
            del ckpt[k]

    interpolate_pos_embed(model, ckpt)

    # Record which head tensors the checkpoint will actually supply, so that a
    # compatible, already-trained head is not overwritten below.
    head_weight_loaded = "head.weight" in ckpt
    head_bias_loaded = "head.bias" in ckpt

    # strict=False: keys present on only one side (e.g. the removed head) are
    # skipped instead of raising.
    model.load_state_dict(ckpt, strict=False)

    # 5) Initialise only the head parameters that were NOT loaded. Previously
    #    the head was re-initialised unconditionally, which discarded a
    #    shape-compatible head that had just been loaded from the checkpoint.
    if not head_weight_loaded:
        trunc_normal_(model.head.weight, std=2e-5)
    if model.head.bias is not None and not head_bias_loaded:
        nn.init.zeros_(model.head.bias)

    model.to(device)
    return model