"""
Image Encoder using pre-trained ResNet50.
Implements the visual feature extraction module from the paper.
"""

import torch
import torch.nn as nn
from torchvision.models import resnet50

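# Sketch (hypothetical helper, not from the paper or the original code): why the
# zero-initialized projection below does not stall training. With weight = 0 and
# bias = 0 the layer outputs zeros, so the visual branch starts as a no-op. The
# weight gradient, however, is the upstream gradient times the input features,
# which is nonzero, so the weights move away from zero after the first optimizer
# step. The gradient reaching the backbone features is upstream @ weight, which
# is exactly zero at that first step. The 2048 -> 8 sizes are illustrative.
import torch
import torch.nn as nn


def _zero_init_projection_demo(in_dim: int = 2048, out_dim: int = 8):
    """Return (output, weight_grad, feature_grad) for a zero-initialized nn.Linear."""
    torch.manual_seed(0)
    proj = nn.Linear(in_dim, out_dim)
    nn.init.zeros_(proj.weight)
    nn.init.zeros_(proj.bias)

    # Stand-in for pooled ResNet50 features; requires_grad=True lets us inspect
    # the gradient that would flow back into the backbone.
    features = torch.randn(4, in_dim, requires_grad=True)
    out = proj(features)  # all zeros at initialization
    target = torch.randn(4, out_dim)  # arbitrary target for a toy MSE loss
    loss = ((out - target) ** 2).mean()
    loss.backward()
    return out.detach(), proj.weight.grad, features.grad
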

class ImageEncoder(nn.Module):
    """
    Image encoder: a ResNet50 backbone followed by a linear projection to hidden_dim.
    The projection layer is zero-initialized, as specified in the paper.
    """
    
    def __init__(self, config, pretrained_weights_path: str = None):
        """
        Initialize image encoder.
        
        Args:
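            config: Configuration object providing resnet_out_dim (2048 for
                ResNet50) and hidden_dim
            pretrained_weights_path: Optional path to a ResNet50 state_dict file;
                if None, the backbone is randomly initialized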
        """
        super().__init__()
        self.config = config
        
        # Load ResNet50
        self.resnet = resnet50(weights=None)
        
        # Load pretrained weights if provided
        if pretrained_weights_path:
            # A state_dict contains only tensors, so the safer weights_only=True suffices
            state_dict = torch.load(pretrained_weights_path, map_location="cpu", weights_only=True)
            self.resnet.load_state_dict(state_dict)
            print(f"Loaded ResNet50 weights from {pretrained_weights_path}")
        
        # Remove original FC layer
        self.resnet.fc = nn.Identity()
        
        # Add custom linear layer to map to hidden_dim
        # CRITICAL: zero-initialize so the visual branch outputs zeros at the start
        # and does not interfere with early training
        self.projection = nn.Linear(config.resnet_out_dim, config.hidden_dim)
        nn.init.zeros_(self.projection.weight)
        nn.init.zeros_(self.projection.bias)
        
        print("Initialized image encoder final layer with zeros")
    
    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Encode each mask image with ResNet50 and project it to hidden_dim.

        Args:
            images: Input images [batch_size, num_masks, C, H, W], where C is
                1 (grayscale, repeated to 3 channels) or 3 (RGB)

        Returns:
            Visual features [batch_size, num_masks, hidden_dim]
        """
        batch_size, num_masks, C, H, W = images.shape
        
        # Flatten batch and num_masks dimensions (reshape also handles non-contiguous inputs)
        images_flat = images.reshape(batch_size * num_masks, C, H, W)
        
        # Convert grayscale to RGB by repeating channels
        if C == 1:
            images_flat = images_flat.repeat(1, 3, 1, 1)
        
        # Extract features with ResNet50
        features = self.resnet(images_flat)  # [batch_size * num_masks, 2048]
        
        # Project to hidden_dim
        features = self.projection(features)  # [batch_size * num_masks, hidden_dim]
        
        # Reshape back
        features = features.view(batch_size, num_masks, self.config.hidden_dim)
        
        return features
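

# Sketch (hypothetical helper, not from the paper or the original code): the
# tensor shapes in ImageEncoder.forward, traced with a tiny stand-in backbone
# instead of ResNet50 so it runs instantly on CPU. The stand-in's 16-dim output
# plays the role of the 2048-dim pooled ResNet50 features; all sizes here are
# illustrative assumptions.
import torch
import torch.nn as nn


def _shape_flow_demo(batch_size: int = 2, num_masks: int = 3,
                     size: int = 32, hidden_dim: int = 8) -> torch.Tensor:
    """Return features of shape [batch_size, num_masks, hidden_dim]."""
    backbone = nn.Sequential(  # stand-in for resnet50 with fc = nn.Identity()
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),  # -> [N, 16]
    )
    projection = nn.Linear(16, hidden_dim)

    images = torch.rand(batch_size, num_masks, 1, size, size)  # grayscale masks
    b, m, c, h, w = images.shape
    flat = images.reshape(b * m, c, h, w)   # [B*M, 1, H, W]
    if c == 1:
        flat = flat.repeat(1, 3, 1, 1)      # [B*M, 3, H, W]
    feats = projection(backbone(flat))      # [B*M, hidden_dim]
    return feats.reshape(b, m, hidden_dim)  # [B, M, hidden_dim]
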