File size: 4,685 Bytes
87224ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""
Text and Image Decoders for MMRM.
"""

import torch
import torch.nn as nn


class TextDecoder(nn.Module):
    """
    Text decoder: a single linear layer mapping fused features to vocabulary logits.

    The layer can optionally tie its weight matrix to a shared input embedding
    (so both directions train one parameter) and/or copy parameters from a
    pretrained RoBERTa LM head.
    """

    def __init__(self, config, roberta_lm_head: nn.Module = None, shared_embedding: nn.Module = None):
        """
        Initialize text decoder.

        Args:
            config: Configuration object providing ``hidden_dim`` and ``vocab_size``.
            roberta_lm_head: Optional pretrained LM head to copy parameters from.
                NOTE(review): assumed to expose ``.weight`` of shape
                [vocab_size, hidden_dim] and ``.bias`` like an ``nn.Linear`` —
                confirm against the caller (a full HF RoBERTa LM head is a
                composite module, not a bare linear).
            shared_embedding: Optional embedding layer to tie weights with;
                its ``.weight`` must be [vocab_size, hidden_dim].
        """
        super().__init__()
        self.config = config

        # Single linear layer mapping hidden_dim to the vocabulary.
        self.decoder = nn.Linear(config.hidden_dim, config.vocab_size)

        if shared_embedding is not None:
            # Weight tying: share the embedding Parameter object itself so
            # gradients from both uses accumulate into one matrix.
            self.decoder.weight = shared_embedding.weight
            print(" Tied TextDecoder weights to ContextEncoder input embeddings.")
            # Even when tying weights, the bias can still be seeded from the
            # pretrained LM head when one is available.
            if roberta_lm_head is not None:
                with torch.no_grad():
                    self.decoder.bias.copy_(roberta_lm_head.bias)
        elif roberta_lm_head is not None:
            # No tying: copy both weight and bias from the pretrained head.
            with torch.no_grad():
                self.decoder.weight.copy_(roberta_lm_head.weight)
                self.decoder.bias.copy_(roberta_lm_head.bias)
            print("Initialized text decoder with RoBERTa LM head parameters")

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Predict character logits.

        Args:
            features: Fused features [batch_size, num_masks, hidden_dim]

        Returns:
            Logits over vocabulary [batch_size, num_masks, vocab_size]
        """
        return self.decoder(features)


class ImageDecoder(nn.Module):
    """
    Image decoder: 5 transposed convolution layers to generate 64x64 images.
    Implements image restoration task.
    """
    
    def __init__(self, config):
        """
        Initialize image decoder.
        
        Args:
            config: Configuration object
        """
        super().__init__()
        self.config = config
        
        # Map hidden_dim to spatial features
        # hidden_dim -> 512 * 4 * 4
        self.fc = nn.Linear(config.hidden_dim, 512 * 4 * 4)
        
        # 5 transposed convolution layers to generate 64x64 image
        # 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64 -> 64x64
        self.deconv_layers = nn.Sequential(
            # Layer 1: 4x4 -> 8x8
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            
            # Layer 2: 8x8 -> 16x16
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            
            # Layer 3: 16x16 -> 32x32
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            
            # Layer 4: 32x32 -> 64x64
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            
            # Layer 5: 64x64 -> 64x64 (refinement)
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()  # Output in [0, 1]
        )
    
    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Generate restored images.
        
        Args:
            features: Fused features [batch_size, num_masks, hidden_dim]
            
        Returns:
            Reconstructed images [batch_size, num_masks, 1, 64, 64]
        """
        batch_size, num_masks, hidden_dim = features.shape
        
        # Flatten batch and num_masks
        features_flat = features.view(batch_size * num_masks, hidden_dim)
        
        # Project to spatial features
        spatial_features = self.fc(features_flat)  # [B*N, 512*4*4]
        spatial_features = spatial_features.view(batch_size * num_masks, 512, 4, 4)
        
        # Apply deconvolution layers
        images = self.deconv_layers(spatial_features)  # [B*N, 1, 64, 64]
        
        # Reshape back
        images = images.view(batch_size, num_masks, 1, 64, 64)
        
        return images