|
|
""" |
|
|
Complete Multimodal Multitask Restoring Model (MMRM). |
|
|
Combines context encoder, image encoder, fusion, and decoders. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from typing import Dict, Tuple |
|
|
|
|
|
from models.context_encoder import ContextEncoder |
|
|
from models.image_encoder import ImageEncoder |
|
|
from models.decoders import TextDecoder, ImageDecoder |
|
|
|
|
|
|
|
|
class MMRM(nn.Module):
    """
    Multimodal Multitask Restoring Model.

    Architecture:
        1. Context Encoder (RoBERTa) extracts textual features
        2. Image Encoder (ResNet50) extracts visual features
        3. Additive Fusion combines features
        4. Text Decoder predicts missing characters
        5. Image Decoder generates restored images
    """

    def __init__(self, config, pretrained_roberta_path: str = None):
        """
        Initialize MMRM.

        Args:
            config: Configuration object. ``config.resnet_weights`` and
                ``config.roberta_model`` are read here; the object itself is
                forwarded to every sub-module.
            pretrained_roberta_path: Path to fine-tuned RoBERTa checkpoint
                (Phase 1). When falsy, no checkpoint is loaded.

        Raises:
            KeyError: If the loaded checkpoint has no ``'model_state_dict'``
                entry.
        """
        super().__init__()
        self.config = config

        self.context_encoder = ContextEncoder(config)

        if pretrained_roberta_path:
            # map_location="cpu" lets a checkpoint saved on an accelerator be
            # restored on a host without one; load_state_dict copies the
            # values into the existing parameters afterwards.
            # NOTE(security): weights_only=False unpickles arbitrary objects,
            # so only load checkpoints from a trusted source.
            checkpoint = torch.load(
                pretrained_roberta_path,
                map_location="cpu",
                weights_only=False,
            )
            self.context_encoder.load_state_dict(checkpoint['model_state_dict'])
            print(f"Loaded fine-tuned RoBERTa from {pretrained_roberta_path}")

        self.image_encoder = ImageEncoder(config, config.resnet_weights)

        # The pretrained masked-LM is loaded only to borrow its output
        # projection for the text decoder.
        roberta_mlm = self._load_pretrained_mlm(config.roberta_model)
        lm_decoder = self._extract_lm_decoder(roberta_mlm)
        self.text_decoder = TextDecoder(config, lm_decoder)

        self.image_decoder = ImageDecoder(config)

    @staticmethod
    def _load_pretrained_mlm(model_name):
        """Load the masked-LM checkpoint quietly, then restore the caller's log level."""
        from transformers import AutoModelForMaskedLM, logging as transformers_logging

        # Remember the current level so it can be put back exactly, rather
        # than forcing WARNING on callers who configured something else.
        previous_verbosity = transformers_logging.get_verbosity()
        transformers_logging.set_verbosity_error()
        try:
            return AutoModelForMaskedLM.from_pretrained(
                model_name, tie_word_embeddings=False
            )
        finally:
            transformers_logging.set_verbosity(previous_verbosity)

    @staticmethod
    def _extract_lm_decoder(mlm_model):
        """Return the output projection of a masked-LM head, or None if no known head exists."""
        if hasattr(mlm_model, "lm_head"):
            head = mlm_model.lm_head
        elif hasattr(mlm_model, "cls"):
            head = mlm_model.cls.predictions
        else:
            return None

        decoder = head.decoder
        # Some heads keep the output bias on the head rather than on the
        # decoder layer; attach it so the decoder is usable on its own.
        if getattr(decoder, "bias", None) is None:
            decoder.bias = head.bias
        return decoder

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        mask_positions: torch.Tensor,
        damaged_images: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through MMRM.

        Args:
            input_ids: Token IDs [batch_size, seq_len]
            attention_mask: Attention mask [batch_size, seq_len]
            mask_positions: Positions of masks [batch_size, num_masks]
            damaged_images: Damaged images [batch_size, num_masks, 1, 64, 64]

        Returns:
            Tuple of (text_logits, restored_images)
            - text_logits: [batch_size, num_masks, vocab_size]
            - restored_images: [batch_size, num_masks, 1, 64, 64]
        """
        # Textual features at the masked positions.
        text_features = self.context_encoder.extract_mask_features(
            input_ids, attention_mask, mask_positions
        )

        # Visual features of the damaged character images.
        image_features = self.image_encoder(damaged_images)

        # Additive fusion: both encoders must emit broadcast-compatible shapes.
        fused_features = text_features + image_features

        # Both decoders consume the same fused representation.
        text_logits = self.text_decoder(fused_features)
        restored_images = self.image_decoder(fused_features)

        return text_logits, restored_images

    def freeze_context_encoder(self):
        """Freeze context encoder parameters (for Phase 2)."""
        self.context_encoder.freeze()

    def unfreeze_context_encoder(self):
        """Unfreeze context encoder parameters."""
        self.context_encoder.unfreeze()
|
|
|
|
|
|
|
|
class BaselineImageModel(nn.Module):
    """
    Image-only baseline: a ResNet50 image encoder followed by a linear
    character classifier.

    Used as 'Img' baseline in the paper.
    """

    def __init__(self, config):
        """
        Build the image-only baseline.

        Args:
            config: Configuration object; ``resnet_weights``, ``hidden_dim``
                and ``vocab_size`` are read from it.
        """
        super().__init__()
        self.config = config

        # Visual feature extractor.
        encoder = ImageEncoder(config, config.resnet_weights)
        self.image_encoder = encoder

        # Linear projection from encoder features to vocabulary logits.
        in_features, out_features = config.hidden_dim, config.vocab_size
        self.classifier = nn.Linear(in_features, out_features)

    def forward(self, damaged_images: torch.Tensor) -> torch.Tensor:
        """
        Predict characters from images only.

        Args:
            damaged_images: [batch_size, num_masks, 1, 64, 64]

        Returns:
            Logits [batch_size, num_masks, vocab_size]
        """
        return self.classifier(self.image_encoder(damaged_images))
|
|
|
|
|
|
|
|
class BaselineLanguageModel(nn.Module):
    """
    Baseline model: Text-only (RoBERTa) for masked language modeling.

    Used as 'LM' and 'LM ft' baselines in the paper.
    """

    def __init__(self, config, fine_tuned: bool = False):
        """
        Initialize language model baseline.

        Args:
            config: Configuration object. ``config.roberta_model`` is read
                here; the object itself is forwarded to the sub-modules.
            fine_tuned: If True, this is the fine-tuned version: the
                classifier is a ``TextDecoder`` built around the pretrained
                output projection. If False, the pretrained masked-LM head is
                used as the classifier unchanged.
        """
        super().__init__()
        self.config = config
        self.fine_tuned = fine_tuned

        self.context_encoder = ContextEncoder(config)

        roberta_mlm = self._load_pretrained_mlm(config.roberta_model)
        mlm_head = self._find_mlm_head(roberta_mlm)

        if self.fine_tuned:
            lm_decoder = self._decoder_with_bias(mlm_head)
            self.classifier = TextDecoder(config, lm_decoder)
        elif mlm_head is not None:
            # Reuse the complete pretrained head as-is.
            self.classifier = mlm_head
        else:
            # No recognised head on the checkpoint, so there is no pretrained
            # decoder to hand over.
            self.classifier = TextDecoder(config, None)

    @staticmethod
    def _load_pretrained_mlm(model_name):
        """Load the masked-LM checkpoint quietly, then restore the caller's log level."""
        from transformers import AutoModelForMaskedLM, logging as transformers_logging

        # Remember the current level so it can be put back exactly, rather
        # than forcing WARNING on callers who configured something else.
        previous_verbosity = transformers_logging.get_verbosity()
        transformers_logging.set_verbosity_error()
        try:
            return AutoModelForMaskedLM.from_pretrained(
                model_name, tie_word_embeddings=False
            )
        finally:
            transformers_logging.set_verbosity(previous_verbosity)

    @staticmethod
    def _find_mlm_head(mlm_model):
        """Return the masked-LM head (`lm_head` or `cls.predictions`), or None if absent."""
        if hasattr(mlm_model, "lm_head"):
            return mlm_model.lm_head
        if hasattr(mlm_model, "cls"):
            return mlm_model.cls.predictions
        return None

    @staticmethod
    def _decoder_with_bias(mlm_head):
        """Return the head's output projection with a bias attached, or None for a missing head."""
        if mlm_head is None:
            return None
        decoder = mlm_head.decoder
        # Some heads keep the output bias on the head rather than on the
        # decoder layer; attach it so the decoder is usable on its own.
        if getattr(decoder, "bias", None) is None:
            decoder.bias = mlm_head.bias
        return decoder

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        mask_positions: torch.Tensor
    ) -> torch.Tensor:
        """
        Predict characters from context only.

        Args:
            input_ids: Token IDs [batch_size, seq_len]
            attention_mask: Attention mask [batch_size, seq_len]
            mask_positions: Positions of masks [batch_size, num_masks]

        Returns:
            Logits [batch_size, num_masks, vocab_size]
        """
        text_features = self.context_encoder.extract_mask_features(
            input_ids, attention_mask, mask_positions
        )
        logits = self.classifier(text_features)
        return logits
|
|
|