| | import json |
| | import torch |
| | from torch import nn, Tensor |
| | from loguru import logger |
| | from pathlib import Path |
| |
|
| | from torchvision.transforms import ToTensor |
| | from torchvision.transforms.v2 import CenterCrop, Compose, Normalize |
| |
|
| |
|
| | import vits |
| |
|
| | def _clean_moco_state_dict(state_dict: dict[str, Tensor], linear_keyword: str) -> dict[str, Tensor]: |
| | """ |
| | Filters and renames keys from a MoCo state_dict. |
| | |
| | It selects keys from the 'base_encoder', removes the given linear layer keyword, |
| | and strips the 'module.base_encoder.' prefix. |
| | """ |
| | for key in list(state_dict.keys()): |
| | |
| | if key.startswith('module.base_encoder') and not key.startswith(f'module.base_encoder.{linear_keyword}'): |
| | |
| | new_key = key[len("module.base_encoder."):] |
| | state_dict[new_key] = state_dict[key] |
| |
|
| | |
| | del state_dict[key] |
| |
|
| | return state_dict |
| |
|
def load_moco_encoder(
    model: nn.Module,
    weight_path: Path,
    linear_keyword: str,
) -> nn.Module:
    """
    Loads pre-trained MoCo weights into a given model instance (ResNet, ViT, etc.).

    This function handles loading the checkpoint, cleaning the state dictionary keys,
    and loading the weights into the model's backbone. It finishes by replacing
    the model's linear head with an Identity layer to turn it into a feature extractor.

    Args:
        model: An instantiated PyTorch model (e.g., from timm or a custom module).
        weight_path: Path to the .pth or .pt MoCo checkpoint file.
        linear_keyword: The name of the final linear layer to exclude (e.g., 'fc' or 'head').

    Returns:
        The same model, with pre-trained backbone weights and the head replaced
        by nn.Identity(), ready for feature extraction.

    Raises:
        FileNotFoundError: If ``weight_path`` does not exist.
    """
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
    if not weight_path.exists():
        raise FileNotFoundError(f"Checkpoint not found at '{weight_path}'")
    logger.info(f"=> Loading MoCo checkpoint from '{weight_path}'")

    # weights_only=True restricts unpickling to tensors/containers — safe default.
    checkpoint = torch.load(weight_path, map_location="cpu", weights_only=True)
    state_dict = checkpoint["state_dict"]

    cleaned_state_dict = _clean_moco_state_dict(state_dict, linear_keyword)

    # strict=False: the head weights were deliberately excluded, so the
    # returned IncompatibleKeys message is expected to list them as missing.
    msg = model.load_state_dict(cleaned_state_dict, strict=False)
    logger.info(msg)
    logger.info("=> Successfully loaded pre-trained model backbone.")

    # Swap the classification head for Identity so the forward pass emits features.
    if hasattr(model, linear_keyword):
        setattr(model, linear_keyword, nn.Identity())
        logger.info(f"=> Model's '{linear_keyword}' layer replaced with nn.Identity for feature extraction.")

    return model
| |
|
def get_vit_feature_extractor(weight_path: Path, model_name: str = "vits8", img_size: int = 40) -> nn.Module:
    """Creates a ViT feature extractor using the unified loader."""
    # Resolve the ViT constructor by name and build a headless backbone.
    backbone = getattr(vits, model_name)(img_size=img_size, num_classes=0)

    # Load the MoCo backbone weights and strip the 'head' layer.
    return load_moco_encoder(
        model=backbone,
        weight_path=weight_path,
        linear_keyword='head',
    )
| |
|
| |
|
def prepare_transform(
    stats_path,
    size: int = 40,
) -> Compose:
    """
    Build the evaluation transform pipeline from stored normalization stats.

    Args:
        stats_path: Path (str or Path) to a JSON file containing "mean" and
            "std" lists for channel-wise normalization.
        size: Side length of the square center crop.

    Returns:
        A Compose pipeline: ToTensor -> Normalize(mean, std) -> CenterCrop(size).
    """
    # JSON is defined as UTF-8; be explicit rather than relying on the locale default.
    with open(stats_path, "r", encoding="utf-8") as f:
        norm_dict = json.load(f)
    mean = norm_dict["mean"]
    std = norm_dict["std"]

    return Compose(
        [
            ToTensor(),
            Normalize(mean=mean, std=std),
            CenterCrop(size=size),
        ]
    )
| |
|