"""
Text and Image Decoders for MMRM.
"""

from typing import Optional

import torch
import torch.nn as nn


class TextDecoder(nn.Module):
    """
    Text decoder: a single linear layer for character prediction.

    The projection can be tied to a shared embedding matrix and/or
    initialized from an existing LM head (e.g. RoBERTa's).
    """

    def __init__(
        self,
        config,
        roberta_lm_head: Optional[nn.Module] = None,
        shared_embedding: Optional[nn.Module] = None,
    ):
        """
        Initialize text decoder.

        Args:
            config: Configuration object; ``hidden_dim`` and ``vocab_size``
                are read from it.
            roberta_lm_head: Optional module exposing ``.weight`` and
                ``.bias`` tensors to copy parameters from.
            shared_embedding: Optional embedding layer whose ``.weight``
                is tied to (shared with) the decoder projection.
        """
        super().__init__()
        self.config = config

        # Single linear layer mapping hidden_dim to vocabulary
        self.decoder = nn.Linear(config.hidden_dim, config.vocab_size)

        if shared_embedding is not None:
            # Tie weights: the decoder reuses the embedding's Parameter
            # object, so both modules train the same matrix.
            self.decoder.weight = shared_embedding.weight
            print(" Tied TextDecoder weights to ContextEncoder input embeddings.")
            # The weight is owned by the embedding, so only the bias can
            # still be initialized from the LM head.
            if roberta_lm_head is not None:
                with torch.no_grad():
                    self.decoder.bias.copy_(roberta_lm_head.bias)
        elif roberta_lm_head is not None:
            # No tying: copy both weight and bias from the LM head.
            with torch.no_grad():
                self.decoder.weight.copy_(roberta_lm_head.weight)
                self.decoder.bias.copy_(roberta_lm_head.bias)
            print("Initialized text decoder with RoBERTa LM head parameters")

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Predict character logits.

        Args:
            features: Fused features [batch_size, num_masks, hidden_dim]

        Returns:
            Logits over vocabulary [batch_size, num_masks, vocab_size]
        """
        return self.decoder(features)


class ImageDecoder(nn.Module):
    """
    Image decoder: 5 transposed convolution layers to generate 64x64 images.

    Implements the image restoration task.
    """

    def __init__(self, config):
        """
        Initialize image decoder.

        Args:
            config: Configuration object; ``hidden_dim`` is read from it.
        """
        super().__init__()
        self.config = config

        # Map hidden_dim to spatial features: hidden_dim -> 512 * 4 * 4
        self.fc = nn.Linear(config.hidden_dim, 512 * 4 * 4)

        # 5 transposed convolution layers to generate a 64x64 image:
        # 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64 -> 64x64
        # (kernel 4 / stride 2 / padding 1 doubles the spatial size;
        #  kernel 3 / stride 1 / padding 1 preserves it).
        self.deconv_layers = nn.Sequential(
            # Layer 1: 4x4 -> 8x8
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            # Layer 2: 8x8 -> 16x16
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # Layer 3: 16x16 -> 32x32
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            # Layer 4: 32x32 -> 64x64
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            # Layer 5: 64x64 -> 64x64 (refinement)
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),  # Output in [0, 1]
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Generate restored images.

        Args:
            features: Fused features [batch_size, num_masks, hidden_dim].
                The tensor does not have to be contiguous.

        Returns:
            Reconstructed images [batch_size, num_masks, 1, 64, 64]
        """
        batch_size, num_masks, hidden_dim = features.shape
        flat_count = batch_size * num_masks

        # Flatten batch and num_masks. reshape() (unlike view()) also
        # accepts non-contiguous inputs such as transposed/permuted
        # tensors, copying only when a view is impossible.
        features_flat = features.reshape(flat_count, hidden_dim)

        # Project to spatial features
        spatial_features = self.fc(features_flat)  # [B*N, 512*4*4]
        spatial_features = spatial_features.reshape(flat_count, 512, 4, 4)

        # Apply deconvolution layers
        images = self.deconv_layers(spatial_features)  # [B*N, 1, 64, 64]

        # Restore the leading [batch_size, num_masks] dims; the trailing
        # dims are taken from the decoder output rather than hard-coded.
        return images.reshape(batch_size, num_masks, *images.shape[1:])