|
|
import json |
|
|
import torch |
|
|
from torch import nn, Tensor |
|
|
from loguru import logger |
|
|
from pathlib import Path |
|
|
|
|
|
from torchvision.transforms import ToTensor |
|
|
from torchvision.transforms.v2 import CenterCrop, Compose, Normalize |
|
|
|
|
|
|
|
|
import vits |
|
|
|
|
|
def _clean_moco_state_dict(state_dict: dict[str, Tensor], linear_keyword: str) -> dict[str, Tensor]: |
|
|
""" |
|
|
Filters and renames keys from a MoCo state_dict. |
|
|
|
|
|
It selects keys from the 'base_encoder', removes the given linear layer keyword, |
|
|
and strips the 'module.base_encoder.' prefix. |
|
|
""" |
|
|
for key in list(state_dict.keys()): |
|
|
|
|
|
if key.startswith('module.base_encoder') and not key.startswith(f'module.base_encoder.{linear_keyword}'): |
|
|
|
|
|
new_key = key[len("module.base_encoder."):] |
|
|
state_dict[new_key] = state_dict[key] |
|
|
|
|
|
|
|
|
del state_dict[key] |
|
|
|
|
|
return state_dict |
|
|
|
|
|
def load_moco_encoder(
    model: nn.Module,
    weight_path: Path,
    linear_keyword: str,
) -> nn.Module:
    """
    Loads pre-trained MoCo weights into a given model instance (ResNet, ViT, etc.).

    This function handles loading the checkpoint, cleaning the state dictionary keys,
    and loading the weights into the model's backbone. It finishes by replacing
    the model's linear head with an Identity layer to turn it into a feature extractor.

    Args:
        model: An instantiated PyTorch model (e.g., from timm or a custom module).
        weight_path: Path to the .pth or .pt MoCo checkpoint file.
        linear_keyword: The name of the final linear layer to exclude (e.g., 'fc' or 'head').

    Returns:
        The same model, with pre-trained backbone weights and the head replaced
        by nn.Identity(), ready for feature extraction.

    Raises:
        FileNotFoundError: If ``weight_path`` does not exist.
        KeyError: If the checkpoint has no 'state_dict' entry.
    """
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
    # which would defer the failure to a less informative torch.load error.
    if not weight_path.exists():
        raise FileNotFoundError(f"Checkpoint not found at '{weight_path}'")
    logger.info(f"=> Loading MoCo checkpoint from '{weight_path}'")

    # weights_only=True blocks arbitrary-code execution from untrusted pickles.
    checkpoint = torch.load(weight_path, map_location="cpu", weights_only=True)
    state_dict = checkpoint["state_dict"]

    cleaned_state_dict = _clean_moco_state_dict(state_dict, linear_keyword)

    # strict=False: the checkpoint intentionally carries no weights for the head.
    msg = model.load_state_dict(cleaned_state_dict, strict=False)
    logger.info(msg)
    logger.info("=> Successfully loaded pre-trained model backbone.")

    # Swap the classifier for an identity so forward() yields raw features.
    if hasattr(model, linear_keyword):
        setattr(model, linear_keyword, nn.Identity())
        logger.info(f"=> Model's '{linear_keyword}' layer replaced with nn.Identity for feature extraction.")

    return model
|
|
|
|
|
def get_vit_feature_extractor(weight_path: Path, model_name: str = "vits8", img_size: int = 40) -> nn.Module:
    """Creates a ViT feature extractor using the unified loader."""
    # Instantiate the requested ViT architecture with no classification head.
    backbone = vits.__dict__[model_name](img_size=img_size, num_classes=0)
    # Delegate checkpoint loading and head replacement to the shared MoCo loader.
    return load_moco_encoder(
        model=backbone,
        weight_path=weight_path,
        linear_keyword='head',
    )
|
|
|
|
|
|
|
|
def prepare_transform(
    stats_path,
    size: int = 40,
) -> Compose:
    """Builds the evaluation transform pipeline from stored normalization stats.

    Args:
        stats_path: Path to a JSON file containing 'mean' and 'std' lists
            (per-channel normalization statistics).
        size: Side length of the final center crop.

    Returns:
        A Compose of ToTensor -> Normalize(mean, std) -> CenterCrop(size).

    Raises:
        KeyError: If the JSON file lacks a 'mean' or 'std' entry.
    """
    # Explicit encoding so the JSON parses identically across platforms.
    with open(stats_path, "r", encoding="utf-8") as f:
        norm_dict = json.load(f)
    mean = norm_dict["mean"]
    std = norm_dict["std"]

    # NOTE: Normalize runs before CenterCrop here — preserved as-is, since
    # normalization is per-pixel and the crop does not change the statistics used.
    return Compose([
        ToTensor(),
        Normalize(mean=mean, std=std),
        CenterCrop(size=size),
    ])
|
|
|