"""
Text and Image Decoders for MMRM.
"""
import torch
import torch.nn as nn
class TextDecoder(nn.Module):
    """
    Text decoder: a single linear projection over the vocabulary.

    Optionally ties its weight matrix to a shared input embedding and/or
    initializes from a pretrained RoBERTa LM head.
    """
    def __init__(self, config, roberta_lm_head: nn.Module = None, shared_embedding: nn.Module = None):
        """
        Build the decoder projection.

        Args:
            config: Configuration object exposing `hidden_dim` and `vocab_size`.
            roberta_lm_head: Optional pretrained LM head; its parameters seed
                this decoder (weight + bias, or bias only when tying weights).
            shared_embedding: Optional embedding layer whose weight this
                decoder shares (weight tying).
        """
        super().__init__()
        self.config = config
        # One linear layer: hidden_dim -> vocab_size logits.
        self.decoder = nn.Linear(config.hidden_dim, config.vocab_size)
        if shared_embedding is not None:
            # Weight tying: the projection reuses the embedding matrix object,
            # so gradients flow into a single shared parameter.
            self.decoder.weight = shared_embedding.weight
            print(" Tied TextDecoder weights to ContextEncoder input embeddings.")
            # Even when tied, the bias can still be seeded from the LM head.
            if roberta_lm_head is not None:
                with torch.no_grad():
                    self.decoder.bias.copy_(roberta_lm_head.bias)
        elif roberta_lm_head is not None:
            # Not tying: copy both weight and bias from the pretrained head.
            with torch.no_grad():
                self.decoder.weight.copy_(roberta_lm_head.weight)
                self.decoder.bias.copy_(roberta_lm_head.bias)
            print("Initialized text decoder with RoBERTa LM head parameters")

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Project fused features to vocabulary logits.

        Args:
            features: Fused features [batch_size, num_masks, hidden_dim].

        Returns:
            Logits over the vocabulary [batch_size, num_masks, vocab_size].
        """
        return self.decoder(features)
class ImageDecoder(nn.Module):
    """
    Image decoder: a linear projection to a 4x4 spatial grid followed by
    five transposed-convolution stages producing 64x64 single-channel images.

    Implements the image restoration task.
    """
    def __init__(self, config):
        """
        Build the decoder stack.

        Args:
            config: Configuration object exposing `hidden_dim`.
        """
        super().__init__()
        self.config = config
        # Project each feature vector to a 512-channel 4x4 spatial map.
        self.fc = nn.Linear(config.hidden_dim, 512 * 4 * 4)
        # Four upsampling stages (4->8->16->32->64), each halving channels,
        # then one stride-1 refinement layer down to a single channel.
        stages = []
        widths = [512, 256, 128, 64, 32]
        for c_in, c_out in zip(widths, widths[1:]):
            stages.extend([
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ])
        stages.extend([
            # Refinement at full resolution: 64x64 -> 64x64, 32 -> 1 channel.
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),  # Constrain pixel values to [0, 1].
        ])
        self.deconv_layers = nn.Sequential(*stages)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Generate restored images from fused features.

        Args:
            features: Fused features [batch_size, num_masks, hidden_dim].

        Returns:
            Reconstructed images [batch_size, num_masks, 1, 64, 64].
        """
        batch, masks, dim = features.shape
        # Collapse (batch, masks) so convolutions see a flat batch axis.
        flat = features.view(batch * masks, dim)
        grid = self.fc(flat).view(batch * masks, 512, 4, 4)
        decoded = self.deconv_layers(grid)  # [B*N, 1, 64, 64]
        # Restore the (batch, masks) split.
        return decoded.view(batch, masks, 1, 64, 64)