File size: 3,964 Bytes
6a06696
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import json
import torch
from torch import nn, Tensor
from loguru import logger
from pathlib import Path

from torchvision.transforms import ToTensor
from torchvision.transforms.v2 import CenterCrop, Compose, Normalize


import vits

def _clean_moco_state_dict(state_dict: dict[str, Tensor], linear_keyword: str) -> dict[str, Tensor]:
    """
    Filters and renames keys from a MoCo state_dict.

    It selects keys from the 'base_encoder', removes the given linear layer keyword,
    and strips the 'module.base_encoder.' prefix.
    """
    for key in list(state_dict.keys()):
        # Check if the key belongs to the base encoder's backbone
        if key.startswith('module.base_encoder') and not key.startswith(f'module.base_encoder.{linear_keyword}'):
            # Create a new key by stripping the prefix
            new_key = key[len("module.base_encoder."):]
            state_dict[new_key] = state_dict[key]

        # Delete the old key (either renamed or unused)
        del state_dict[key]

    return state_dict

def load_moco_encoder(
        model: nn.Module,
        weight_path: Path,
        linear_keyword: str,
) -> nn.Module:
    """
    Loads pre-trained MoCo weights into a given model instance (ResNet, ViT, etc.).

    The checkpoint is read on CPU, its "state_dict" entry is reduced to the
    base encoder's backbone weights (head excluded, prefix stripped), and the
    result is loaded into `model` non-strictly. Finally, if `model` has an
    attribute named `linear_keyword`, it is replaced by `nn.Identity()` so the
    model acts as a feature extractor.

    Args:
        model: An instantiated PyTorch model (e.g., from timm or a custom module).
            It is modified in place.
        weight_path: Path to the .pth or .pt MoCo checkpoint file.
        linear_keyword: The name of the final linear layer to exclude (e.g., 'fc' or 'head').

    Returns:
        The same model, with pre-trained backbone weights and the head replaced
        by nn.Identity(), ready for feature extraction.

    Raises:
        FileNotFoundError: If `weight_path` does not exist.
        KeyError: If the loaded checkpoint has no "state_dict" entry.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not weight_path.exists():
        raise FileNotFoundError(f"Checkpoint not found at '{weight_path}'")
    logger.info(f"=> Loading MoCo checkpoint from '{weight_path}'")

    # Use weights_only=True for added security if the checkpoint doesn't contain pickled code
    checkpoint = torch.load(weight_path, map_location="cpu", weights_only=True)

    # Extract the state dictionary containing the model weights
    state_dict = checkpoint["state_dict"]

    # Clean the state_dict to match the model's architecture
    cleaned_state_dict = _clean_moco_state_dict(state_dict, linear_keyword)

    # strict=False below tolerates key mismatches, so an empty dict would load
    # nothing without any error; surface that instead of reporting success.
    has_backbone_weights = bool(cleaned_state_dict)
    if not has_backbone_weights:
        logger.warning(
            f"=> No 'module.base_encoder.*' backbone weights found in '{weight_path}'; "
            "the model keeps its existing weights."
        )

    # Load the cleaned weights into the model; the returned message lists
    # missing and unexpected keys.
    msg = model.load_state_dict(cleaned_state_dict, strict=False)
    logger.info(msg)
    if has_backbone_weights:
        logger.info("=> Successfully loaded pre-trained model backbone.")

    # Replace the model's head to turn it into a feature extractor
    if hasattr(model, linear_keyword):
        setattr(model, linear_keyword, nn.Identity())
        logger.info(f"=> Model's '{linear_keyword}' layer replaced with nn.Identity for feature extraction.")

    return model

def get_vit_feature_extractor(weight_path: Path, model_name: str = "vits8", img_size: int = 40) -> nn.Module:
    """
    Build a ViT from the `vits` module and load MoCo backbone weights into it.

    Args:
        weight_path: Path to the MoCo checkpoint passed on to `load_moco_encoder`.
        model_name: Name of the factory to look up in the `vits` module namespace.
        img_size: Forwarded to the factory as `img_size`.

    Returns:
        The model returned by `load_moco_encoder`, called with
        `linear_keyword='head'`.
    """
    # Resolve the architecture factory by name; an unknown name raises KeyError.
    build_vit = vars(vits)[model_name]

    # num_classes=0 is forwarded as-is; presumably it builds the model without
    # a classification head -- confirm against the `vits` module.
    backbone = build_vit(img_size=img_size, num_classes=0)

    # Checkpoint loading and head replacement are handled by the shared loader.
    return load_moco_encoder(
        model=backbone,
        weight_path=weight_path,
        linear_keyword='head',
    )


def prepare_transform(
        stats_path: "str | Path",
        size: int = 40,
) -> Compose:
    """
    Build the preprocessing pipeline: ToTensor -> Normalize -> CenterCrop.

    Args:
        stats_path: Path to a JSON file with "mean" and "std" entries, which
            are passed unchanged to `Normalize`.
        size: Output size passed to `CenterCrop`.

    Returns:
        A `Compose` applying `ToTensor`, `Normalize(mean, std)` and
        `CenterCrop(size)` in that order.

    Raises:
        OSError: If `stats_path` cannot be opened.
        KeyError: If the JSON object lacks a "mean" or "std" entry.
    """
    # Get normalisation stats. JSON is UTF-8 by specification, so read it
    # explicitly as such instead of relying on the platform's locale encoding.
    with open(stats_path, "r", encoding="utf-8") as f:
        norm_dict = json.load(f)
    mean = norm_dict["mean"]
    std = norm_dict["std"]

    # Prepare transform.
    # NOTE(review): keep Normalize before CenterCrop. Per the torchvision docs,
    # CenterCrop zero-pads inputs smaller than `size`; swapping the order would
    # change the values of such padded pixels.
    list_transform = [
        ToTensor(),
        Normalize(mean=mean, std=std),
        CenterCrop(size=size),
    ]
    return Compose(list_transform)