File size: 2,521 Bytes
d8770e8
39ec591
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d8770e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39ec591
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
import sys
from pathlib import Path

# Add RETFound repo to path for imports
REPO_DIR = Path(__file__).parent / "RETFound_MAE"
sys.path.append(str(REPO_DIR))
from models_vit import RETFound_mae  # architecture builder
from util.pos_embed import interpolate_pos_embed
from timm.models.layers import trunc_normal_


def build_classifier(num_classes: int,
                     base_repo: str,
                     base_filename: str,
                     global_pool: bool = True,
                     drop_path_rate: float = 0.2,
                     device: str | torch.device = "cpu") -> nn.Module:
    """Build a RETFound MAE backbone with a fresh linear classification head.

    Downloads the pretrained MAE checkpoint from the Hugging Face Hub,
    instantiates the RETFound ViT backbone, replaces its head with a new
    ``nn.Linear`` sized for ``num_classes``, loads the pretrained weights
    (interpolating positional embeddings and skipping shape-mismatched head
    weights), and re-initializes the head for fine-tuning.

    Args:
        num_classes: Number of output classes for the new linear head.
        base_repo: Hugging Face Hub repo id holding the checkpoint.
        base_filename: Checkpoint filename within that repo.
        global_pool: Whether the backbone uses global average pooling.
        drop_path_rate: Stochastic-depth rate passed to the backbone.
        device: Target device for the returned model.

    Returns:
        The assembled model, moved to ``device``.

    Raises:
        RuntimeError: If the checkpoint cannot be downloaded from the Hub.
    """
    device = torch.device(device)

    # 1) Download pretrained MAE weights from the Hub.
    # A token from the env enables private repos; None is fine for public ones.
    hf_token = os.getenv("HF_TOKEN")
    try:
        ckpt_path = hf_hub_download(
            repo_id=base_repo,
            filename=base_filename,
            token=hf_token,
            cache_dir="/tmp/hf_cache",  # Spaces-friendly cache location
        )
    except Exception as e:
        # Chain the original exception so the real download failure is visible.
        raise RuntimeError(f"Failed to download model from {base_repo}: {e}") from e

    print(f"Loading RETFound MAE weights from {ckpt_path}...")

    # 2) Build backbone.
    model = RETFound_mae(global_pool=global_pool, drop_path_rate=drop_path_rate)

    # 3) Attach a fresh linear head sized for the target classes.
    in_features = model.head.in_features
    model.head = nn.Linear(in_features, num_classes)

    # 4) Load the checkpoint; handle both {"model": state_dict} and raw formats.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    ckpt = checkpoint.get("model", checkpoint)

    # Drop pretrained head weights whose shape no longer matches the new head.
    state_dict = model.state_dict()
    for k in ("head.weight", "head.bias"):
        if k in ckpt and k in state_dict and ckpt[k].shape != state_dict[k].shape:
            del ckpt[k]

    # Interpolate positional embeddings for any resolution mismatch, then load
    # non-strictly (the freshly created head keys are expected to be missing).
    interpolate_pos_embed(model, ckpt)
    msg = model.load_state_dict(ckpt, strict=False)
    # Surface missing/unexpected keys instead of silently discarding them.
    print(f"Checkpoint load result: {msg}")

    # Re-initialize the head for classification fine-tuning (RETFound recipe).
    trunc_normal_(model.head.weight, std=2e-5)
    if hasattr(model.head, 'bias') and model.head.bias is not None:
        nn.init.zeros_(model.head.bias)

    model.to(device)
    return model