import json
import torch
from torch import nn, Tensor
from loguru import logger
from pathlib import Path
from torchvision.transforms import ToTensor
from torchvision.transforms.v2 import CenterCrop, Compose, Normalize
import vits


def _clean_moco_state_dict(state_dict: dict[str, Tensor], linear_keyword: str) -> dict[str, Tensor]:
    """
    Filters and renames keys from a MoCo state_dict.

    It keeps keys belonging to the 'base_encoder' backbone, drops keys belonging to the
    linear head named by `linear_keyword`, and strips the 'module.base_encoder.' prefix
    from the kept keys.
    """
    for key in list(state_dict.keys()):
        # Check if the key belongs to the base encoder's backbone
        if key.startswith('module.base_encoder') and not key.startswith(f'module.base_encoder.{linear_keyword}'):
            # Create a new key by stripping the prefix
            new_key = key[len("module.base_encoder."):]
            state_dict[new_key] = state_dict[key]
        # Delete the old key (either renamed or unused)
        del state_dict[key]
    return state_dict
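

# Illustrative key handling (assuming standard MoCo v3 checkpoint naming):
#   'module.base_encoder.blocks.0.norm1.weight' -> 'blocks.0.norm1.weight'  (kept, prefix stripped)
#   'module.base_encoder.head.weight'           -> dropped (matches linear_keyword='head')
#   'module.momentum_encoder.*'                 -> dropped (not part of the base encoder)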


def load_moco_encoder(
    model: nn.Module,
    weight_path: Path,
    linear_keyword: str,
) -> nn.Module:
    """
    Loads pre-trained MoCo weights into a given model instance (ResNet, ViT, etc.).

    This function handles loading the checkpoint, cleaning the state dictionary keys,
    and loading the weights into the model's backbone. It finishes by replacing
    the model's linear head with an Identity layer to turn it into a feature extractor.

    Args:
        model: An instantiated PyTorch model (e.g., from timm or a custom module).
        weight_path: Path to the .pth or .pt MoCo checkpoint file.
        linear_keyword: The name of the final linear layer to exclude (e.g., 'fc' or 'head').

    Returns:
        The same model, with pre-trained backbone weights and the head replaced
        by nn.Identity(), ready for feature extraction.
    """
    assert weight_path.exists(), f"Checkpoint not found at '{weight_path}'"
    logger.info(f"=> Loading MoCo checkpoint from '{weight_path}'")

    # Use weights_only=True for added security if the checkpoint doesn't contain pickled code
    checkpoint = torch.load(weight_path, map_location="cpu", weights_only=True)

    # Extract the state dictionary containing the model weights
    state_dict = checkpoint["state_dict"]

    # Clean the state_dict to match the model's architecture
    cleaned_state_dict = _clean_moco_state_dict(state_dict, linear_keyword)

    # Load the cleaned weights into the model
    msg = model.load_state_dict(cleaned_state_dict, strict=False)
    logger.info(msg)
    logger.info("=> Successfully loaded pre-trained model backbone.")

    # Replace the model's head to turn it into a feature extractor
    if hasattr(model, linear_keyword):
        setattr(model, linear_keyword, nn.Identity())
        logger.info(f"=> Model's '{linear_keyword}' layer replaced with nn.Identity for feature extraction.")

    return model


def get_vit_feature_extractor(weight_path: Path, model_name: str = "vits8", img_size: int = 40) -> nn.Module:
    """Creates a ViT feature extractor using the unified loader."""
    # 1. Create the model architecture shell
    vit_model = vits.__dict__[model_name](img_size=img_size, num_classes=0)

    # 2. Use the unified function to load weights and prepare for feature extraction
    feature_extractor = load_moco_encoder(
        model=vit_model,
        weight_path=weight_path,
        linear_keyword='head',
    )
    return feature_extractor
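

# Illustrative sketch (not part of the original pipeline): the same unified loader can wrap a
# MoCo-trained ResNet backbone, whose classification layer is named 'fc' (see load_moco_encoder's
# docstring). The torchvision architecture choice below is an assumption for demonstration.
def get_resnet_feature_extractor(weight_path: Path) -> nn.Module:
    """Creates a ResNet-50 feature extractor using the unified loader."""
    from torchvision.models import resnet50

    # Random init; the backbone weights come from the MoCo checkpoint
    resnet_model = resnet50()
    return load_moco_encoder(
        model=resnet_model,
        weight_path=weight_path,
        linear_keyword='fc',
    )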


def prepare_transform(
    stats_path: Path,
    size: int = 40,
) -> Compose:
    """Builds the evaluation transform from saved normalisation statistics."""
    # Get normalisation stats
    with open(stats_path, "r") as f:
        norm_dict = json.load(f)
    mean = norm_dict["mean"]
    std = norm_dict["std"]

    # Prepare transform
    list_transform = [
        ToTensor(),
        Normalize(mean=mean, std=std),
        CenterCrop(size=size),
    ]
    transform = Compose(list_transform)
    return transform
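

# Minimal end-to-end usage sketch. The checkpoint and stats paths below are placeholders,
# not files shipped with this module; the dummy image only demonstrates the expected input shape.
if __name__ == "__main__":
    weights = Path("checkpoints/moco_vits8.pth")   # placeholder path
    stats = Path("checkpoints/norm_stats.json")    # placeholder path

    extractor = get_vit_feature_extractor(weight_path=weights, model_name="vits8", img_size=40)
    extractor.eval()
    transform = prepare_transform(stats_path=stats, size=40)

    # Run a single dummy image through the pipeline: HWC uint8 array -> normalised crop -> features
    import numpy as np

    dummy_image = np.zeros((40, 40, 3), dtype=np.uint8)
    batch = transform(dummy_image).unsqueeze(0)    # shape: (1, 3, 40, 40)
    with torch.no_grad():
        features = extractor(batch)
    logger.info(f"Feature shape: {tuple(features.shape)}")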