"""
Image Encoder using pre-trained ResNet50.

Implements the visual feature extraction module from the paper.
"""

from typing import Optional

import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights


class ImageEncoder(nn.Module):
    """
    Image encoder using ResNet50 with a custom final layer.

    A ResNet50 backbone, with its classification head replaced by an identity,
    produces one pooled feature vector per image. A linear projection then maps
    that vector to ``config.hidden_dim``.

    Critical: the projection layer is initialized with zeros as per the paper,
    so the encoder emits all-zero features until the projection has been
    trained.
    """

    def __init__(self, config, pretrained_weights_path: Optional[str] = None):
        """
        Initialize image encoder.

        Args:
            config: Configuration object. Must provide ``resnet_out_dim``
                (input size of the projection; must equal the backbone's pooled
                feature size, 2048 for ResNet50) and ``hidden_dim``.
            pretrained_weights_path: Optional path to a ResNet50 ``state_dict``
                file. When omitted (or empty), the backbone keeps torchvision's
                random initialization.

        Raises:
            ValueError: If ``config.resnet_out_dim`` does not match the
                backbone's feature size.
        """
        super().__init__()
        self.config = config

        # Build the ResNet50 architecture without downloading any weights.
        self.resnet = resnet50(weights=None)

        # Load pretrained weights if provided. The file must hold the full
        # ResNet50 state_dict (including the original ``fc`` head), because it
        # is loaded strictly before the head is removed below.
        if pretrained_weights_path:
            # weights_only=True restricts unpickling to tensors and plain
            # containers, so a tampered checkpoint cannot execute code.
            # map_location="cpu" lets GPU-saved weights load on CPU-only hosts;
            # the freshly built model is on CPU at this point anyway.
            state_dict = torch.load(
                pretrained_weights_path, map_location="cpu", weights_only=True
            )
            self.resnet.load_state_dict(state_dict)
            print(f"Loaded ResNet50 weights from {pretrained_weights_path}")

        # Record the pooled feature size before dropping the classifier head,
        # so a misconfigured resnet_out_dim fails here rather than at the
        # first forward pass.
        backbone_dim = self.resnet.fc.in_features
        if config.resnet_out_dim != backbone_dim:
            raise ValueError(
                f"config.resnet_out_dim={config.resnet_out_dim} does not match "
                f"the ResNet50 feature size {backbone_dim}"
            )

        # Remove original FC layer so the backbone returns pooled features.
        self.resnet.fc = nn.Identity()

        # Add custom linear layer to map to hidden_dim.
        # CRITICAL: Initialize with zeros to prevent interference during early
        # training (output is exactly zero, and no gradient reaches the
        # backbone through a zero weight matrix, until the projection learns).
        self.projection = nn.Linear(config.resnet_out_dim, config.hidden_dim)
        nn.init.zeros_(self.projection.weight)
        nn.init.zeros_(self.projection.bias)
        print("Initialized image encoder final layer with zeros")

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through ResNet50.

        Args:
            images: Input images [batch_size, num_masks, C, H, W] with C == 1
                (grayscale, replicated to 3 channels) or C == 3.

        Returns:
            Visual features [batch_size, num_masks, hidden_dim]

        Raises:
            ValueError: If ``images`` is not 5-D or its channel count is not
                1 or 3.
        """
        if images.dim() != 5:
            raise ValueError(
                "Expected images of shape [batch_size, num_masks, C, H, W], "
                f"got {tuple(images.shape)}"
            )
        batch_size, num_masks, C, H, W = images.shape
        if C not in (1, 3):
            raise ValueError(f"Expected 1 or 3 image channels, got {C}")

        # Fold the mask dimension into the batch so ResNet sees 4-D input.
        # reshape (not view) also accepts non-contiguous tensors, e.g. ones
        # produced by permute/transpose upstream.
        images_flat = images.reshape(batch_size * num_masks, C, H, W)

        # ResNet50's first convolution expects 3 input channels, so
        # grayscale is replicated across them.
        if C == 1:
            images_flat = images_flat.repeat(1, 3, 1, 1)

        # Extract features with ResNet50
        features = self.resnet(images_flat)  # [batch_size * num_masks, 2048]

        # Project to hidden_dim
        features = self.projection(features)  # [batch_size * num_masks, hidden_dim]

        # Restore the (batch, mask) split.
        return features.reshape(batch_size, num_masks, self.config.hidden_dim)