|
|
""" |
|
|
Text and Image Decoders for MMRM. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
class TextDecoder(nn.Module):
    """Character-prediction head: one linear projection onto the vocabulary.

    The projection can be weight-tied to a shared input embedding, or have
    its parameters copied from a pretrained RoBERTa LM head.
    """

    def __init__(self, config, roberta_lm_head: nn.Module = None, shared_embedding: nn.Module = None):
        """Build the projection and optionally tie or initialize its parameters.

        Args:
            config: Configuration object providing ``hidden_dim`` and ``vocab_size``.
            roberta_lm_head: Optional RoBERTa LM head; its weight/bias are copied in.
            shared_embedding: Optional embedding layer whose weight is shared (tied).
        """
        super().__init__()
        self.config = config
        self.decoder = nn.Linear(config.hidden_dim, config.vocab_size)

        tie_weights = shared_embedding is not None
        have_lm_head = roberta_lm_head is not None

        if tie_weights:
            # Share one Parameter between input embedding and output projection.
            self.decoder.weight = shared_embedding.weight
            print(" Tied TextDecoder weights to ContextEncoder input embeddings.")
        elif have_lm_head:
            # No tying requested: clone both tensors from the pretrained head.
            with torch.no_grad():
                self.decoder.weight.copy_(roberta_lm_head.weight)
                self.decoder.bias.copy_(roberta_lm_head.bias)
            print("Initialized text decoder with RoBERTa LM head parameters")

        if tie_weights and have_lm_head:
            # Weight is tied above, but the bias can still come from the LM head.
            with torch.no_grad():
                self.decoder.bias.copy_(roberta_lm_head.bias)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Map fused features to vocabulary logits.

        Args:
            features: Fused features of shape [batch_size, num_masks, hidden_dim].

        Returns:
            Logits over the vocabulary, shape [batch_size, num_masks, vocab_size].
        """
        return self.decoder(features)
|
|
|
|
|
|
|
|
class ImageDecoder(nn.Module):
    """Image decoder: upsamples fused features into 64x64 single-channel images.

    A linear layer maps each feature vector to a 512-channel 4x4 grid; five
    transposed-convolution layers then upsample it (4 -> 8 -> 16 -> 32 -> 64)
    and a sigmoid squashes the output into [0, 1]. Implements the image
    restoration task.
    """

    def __init__(self, config):
        """Set up the projection layer and the deconvolution stack.

        Args:
            config: Configuration object providing ``hidden_dim``.
        """
        super().__init__()
        self.config = config

        # Project each hidden vector onto a small 512-channel spatial grid.
        self.fc = nn.Linear(config.hidden_dim, 512 * 4 * 4)

        # Four stride-2 stages (channels halve, resolution doubles), then a
        # stride-1 layer mapping down to a single sigmoid-activated channel.
        layers = []
        in_ch = 512
        for out_ch in (256, 128, 64, 32):
            layers.extend([
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
            in_ch = out_ch
        layers.extend([
            nn.ConvTranspose2d(in_ch, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        ])
        self.deconv_layers = nn.Sequential(*layers)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Generate restored images from fused features.

        Args:
            features: Fused features of shape [batch_size, num_masks, hidden_dim].

        Returns:
            Reconstructed images of shape [batch_size, num_masks, 1, 64, 64].
        """
        batch, masks, dim = features.shape

        # Collapse the mask dimension so the conv stack sees a plain 4-D batch.
        grid = self.fc(features.view(batch * masks, dim))
        grid = grid.view(batch * masks, 512, 4, 4)

        out = self.deconv_layers(grid)

        # Restore the [batch, num_masks] leading dimensions.
        return out.view(batch, masks, 1, 64, 64)
|
|
|