File size: 2,591 Bytes
87224ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
"""
Image Encoder using pre-trained ResNet50.
Implements the visual feature extraction module from the paper.
"""
from typing import Optional

import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights
class ImageEncoder(nn.Module):
    """
    Image encoder using ResNet50 with a custom zero-initialized projection.

    Critical: the final projection layer is initialized with zeros as per
    the paper, so the visual branch outputs zero vectors at the start of
    training and cannot interfere with early optimization.
    """

    def __init__(self, config, pretrained_weights_path: Optional[str] = None):
        """
        Initialize the image encoder.

        Args:
            config: Configuration object; must provide ``resnet_out_dim``
                (ResNet50 pooled feature size, 2048 for the stock model)
                and ``hidden_dim`` (projection output size).
            pretrained_weights_path: Optional path to a ResNet50 state-dict
                file. If None, the backbone keeps its random initialization.
        """
        super().__init__()
        self.config = config

        # Build the backbone without downloading any pretrained weights.
        self.resnet = resnet50(weights=None)

        if pretrained_weights_path:
            # weights_only=True restricts unpickling to tensors and plain
            # containers, preventing arbitrary code execution from a
            # malicious checkpoint file; a state dict needs nothing more.
            state_dict = torch.load(pretrained_weights_path, weights_only=True)
            self.resnet.load_state_dict(state_dict)
            print(f"Loaded ResNet50 weights from {pretrained_weights_path}")

        # Drop the classification head so the backbone returns pooled features.
        self.resnet.fc = nn.Identity()

        # CRITICAL: zero-init so the projection outputs zeros at step 0,
        # preventing interference during early training (per the paper).
        self.projection = nn.Linear(config.resnet_out_dim, config.hidden_dim)
        nn.init.zeros_(self.projection.weight)
        nn.init.zeros_(self.projection.bias)
        print("Initialized image encoder final layer with zeros")

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Extract per-mask visual features.

        Args:
            images: Input images of shape [batch_size, num_masks, C, H, W].
                Grayscale inputs (C == 1) are converted to RGB by channel
                replication before entering the backbone.

        Returns:
            Visual features of shape [batch_size, num_masks, hidden_dim].
        """
        batch_size, num_masks, channels, height, width = images.shape

        # Fold batch and mask dimensions together so ResNet sees a 4-D batch.
        flat = images.view(batch_size * num_masks, channels, height, width)

        # ResNet50 expects 3 input channels; replicate grayscale across RGB.
        if channels == 1:
            flat = flat.repeat(1, 3, 1, 1)

        features = self.resnet(flat)          # [B * M, resnet_out_dim]
        features = self.projection(features)  # [B * M, hidden_dim]

        # Restore the [batch, mask] structure.
        return features.view(batch_size, num_masks, self.config.hidden_dim)
|