""" Feature extraction using pretrained ResNet backbone """ import torch import torch.nn as nn from torchvision.models import resnet18, ResNet18_Weights from typing import List, Dict import config class FeatureExtractor(nn.Module): """Extract multi-scale features from ResNet backbone""" def __init__(self, backbone: str = "resnet18", layers: List[str] = None): """ Args: backbone: Backbone architecture name layers: List of layer names to extract features from """ super().__init__() if layers is None: layers = config.FEATURE_LAYERS self.layers = layers # Load pretrained ResNet if backbone == "resnet18": model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1) else: raise ValueError(f"Unsupported backbone: {backbone}") # Register hooks to extract intermediate features self.feature_maps = {} self.hooks = [] for name, module in model.named_children(): if name in self.layers: hook = module.register_forward_hook(self._get_hook(name)) self.hooks.append(hook) # Keep only the feature extraction part (remove classifier) self.model = model self.model.eval() # Freeze all parameters for param in self.model.parameters(): param.requires_grad = False def _get_hook(self, name: str): """Create forward hook to capture intermediate features""" def hook(module, input, output): self.feature_maps[name] = output return hook def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: """ Extract features from multiple layers Args: x: Input tensor [B, 3, H, W] Returns: Dictionary mapping layer names to feature tensors """ self.feature_maps.clear() with torch.no_grad(): _ = self.model(x) return self.feature_maps def get_feature_dimensions(self) -> Dict[str, int]: """Get output dimensions for each layer""" dummy_input = torch.randn(1, 3, *config.IMAGE_SIZE) features = self.forward(dummy_input) dimensions = {} for name, feat in features.items(): dimensions[name] = { 'channels': feat.shape[1], 'height': feat.shape[2], 'width': feat.shape[3] } return dimensions def __del__(self): """Remove hooks when object is destroyed""" for hook in self.hooks: hook.remove() def extract_embeddings(extractor: FeatureExtractor, x: torch.Tensor) -> torch.Tensor: """ Extract and concatenate multi-scale embeddings Args: extractor: Feature extractor model x: Input tensor [B, 3, H, W] Returns: Concatenated embeddings [B, D, H', W'] where D is total channels """ feature_maps = extractor(x) # Get target spatial dimensions from the largest feature map target_size = None for name in extractor.layers: feat = feature_maps[name] if target_size is None or (feat.shape[2] > target_size[0]): target_size = (feat.shape[2], feat.shape[3]) # Resize all features to same spatial dimensions and concatenate aligned_features = [] for name in extractor.layers: feat = feature_maps[name] if feat.shape[2:] != target_size: feat = torch.nn.functional.interpolate( feat, size=target_size, mode='bilinear', align_corners=False ) aligned_features.append(feat) # Concatenate along channel dimension embeddings = torch.cat(aligned_features, dim=1) return embeddings