# MMRM/models/decoders.py
"""
Text and Image Decoders for MMRM.
"""
import torch
import torch.nn as nn
class TextDecoder(nn.Module):
    """
    Character-prediction head: a single linear projection onto the vocabulary.

    The projection can either tie its weight matrix to an input embedding
    table, or copy its parameters from a pretrained RoBERTa LM head.
    """

    def __init__(self, config, roberta_lm_head: nn.Module = None, shared_embedding: nn.Module = None):
        """
        Build the decoder head.

        Args:
            config: Configuration object exposing ``hidden_dim`` and ``vocab_size``.
            roberta_lm_head: Optional pretrained LM head whose parameters seed this layer.
            shared_embedding: Optional embedding table to tie the output weights to.
        """
        super().__init__()
        self.config = config

        # One linear layer: hidden_dim -> vocab_size logits.
        self.decoder = nn.Linear(config.hidden_dim, config.vocab_size)

        tied = shared_embedding is not None
        if tied:
            # Weight tying: the output projection shares the embedding matrix.
            self.decoder.weight = shared_embedding.weight
            print(" Tied TextDecoder weights to ContextEncoder input embeddings.")
        elif roberta_lm_head is not None:
            # Not tying: copy both weight and bias from the pretrained head.
            with torch.no_grad():
                self.decoder.weight.copy_(roberta_lm_head.weight)
                self.decoder.bias.copy_(roberta_lm_head.bias)
            print("Initialized text decoder with RoBERTa LM head parameters")

        # Even when tied, the bias (not part of the tied matrix) can still be
        # seeded from the LM head if one was provided.
        if tied and roberta_lm_head is not None:
            with torch.no_grad():
                self.decoder.bias.copy_(roberta_lm_head.bias)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Map fused features to character logits.

        Args:
            features: Fused features [batch_size, num_masks, hidden_dim].

        Returns:
            Logits over the vocabulary [batch_size, num_masks, vocab_size].
        """
        return self.decoder(features)
class ImageDecoder(nn.Module):
    """
    Image decoder: generates restored 64x64 grayscale images from fused features.

    A linear projection lifts each feature vector to a 512x4x4 spatial grid,
    followed by 5 transposed-convolution layers that upsample to 64x64.
    Implements the image restoration task.
    """

    def __init__(self, config):
        """
        Initialize image decoder.

        Args:
            config: Configuration object exposing ``hidden_dim``.
        """
        super().__init__()
        self.config = config

        # Map hidden_dim to spatial features: hidden_dim -> 512 * 4 * 4
        self.fc = nn.Linear(config.hidden_dim, 512 * 4 * 4)

        # 5 transposed convolution layers to generate a 64x64 image:
        # 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64 -> 64x64
        self.deconv_layers = nn.Sequential(
            # Layer 1: 4x4 -> 8x8
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            # Layer 2: 8x8 -> 16x16
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            # Layer 3: 16x16 -> 32x32
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            # Layer 4: 32x32 -> 64x64
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),

            # Layer 5: 64x64 -> 64x64 (stride-1 refinement pass)
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()  # Output in [0, 1]
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Generate restored images.

        Args:
            features: Fused features [batch_size, num_masks, hidden_dim].

        Returns:
            Reconstructed images [batch_size, num_masks, 1, 64, 64],
            values in [0, 1].
        """
        batch_size, num_masks, hidden_dim = features.shape

        # Collapse batch and mask dims. Use reshape (not view) so that
        # non-contiguous inputs (e.g. slices/permutes of a larger tensor)
        # are handled instead of raising a RuntimeError.
        features_flat = features.reshape(batch_size * num_masks, hidden_dim)

        # Project to spatial features and unfold into a 512-channel 4x4 grid.
        spatial_features = self.fc(features_flat)  # [B*N, 512*4*4]
        spatial_features = spatial_features.view(batch_size * num_masks, 512, 4, 4)

        # Upsample to the final image resolution.
        images = self.deconv_layers(spatial_features)  # [B*N, 1, 64, 64]

        # Restore the (batch, mask) leading dimensions.
        return images.view(batch_size, num_masks, 1, 64, 64)