"""
Context Encoder using pre-trained GuwenBERT RoBERTa.
Implements the textual feature extraction module from the paper.
"""
from typing import Optional

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer


class ContextEncoder(nn.Module):
"""
Context encoder using GuwenBERT RoBERTa large.
Extracts features from masked positions in the text.
"""

    def __init__(self, config, pretrained_model_name: Optional[str] = None):
"""
Initialize context encoder.
Args:
config: Configuration object
pretrained_model_name: HuggingFace model identifier
"""
super().__init__()
self.config = config
if pretrained_model_name is None:
pretrained_model_name = config.roberta_model
# Load pre-trained GuwenBERT RoBERTa
from transformers import logging as transformers_logging
# Suppress warnings about unexpected keys (lm_head) as we only want the encoder
transformers_logging.set_verbosity_error()
try:
self.encoder = AutoModel.from_pretrained(pretrained_model_name, tie_word_embeddings=False)
finally:
transformers_logging.set_verbosity_warning()
self.hidden_dim = self.encoder.config.hidden_size
# # Verify hidden dimension matches config
# assert self.hidden_dim == config.hidden_dim, \
# f"Model hidden dim {self.hidden_dim} != config hidden dim {config.hidden_dim}"

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
"""
Forward pass through RoBERTa.
Args:
input_ids: Token IDs [batch_size, seq_len]
attention_mask: Attention mask [batch_size, seq_len]
Returns:
Hidden states [batch_size, seq_len, hidden_dim]
"""
outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask
)
# Return sequence of hidden states
return outputs.last_hidden_state

    def extract_mask_features(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
mask_positions: torch.Tensor
) -> torch.Tensor:
"""
Extract features at masked positions.
Args:
input_ids: Token IDs [batch_size, seq_len]
attention_mask: Attention mask [batch_size, seq_len]
mask_positions: Positions of masks [batch_size, num_masks]
Returns:
Features at mask positions [batch_size, num_masks, hidden_dim]
"""
# Get all hidden states
hidden_states = self.forward(input_ids, attention_mask)
# Extract features at mask positions
batch_size, num_masks = mask_positions.shape
# Expand mask_positions for gathering
mask_positions_expanded = mask_positions.unsqueeze(-1).expand(
batch_size, num_masks, self.hidden_dim
)
# Gather features at mask positions
mask_features = torch.gather(hidden_states, 1, mask_positions_expanded)
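        # Resulting shape: [batch_size, num_masks, hidden_dim] -- one context
        # vector per masked position, gathered along the sequence (dim=1) axis.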
return mask_features

    def freeze(self):
"""Freeze all parameters (for Phase 2 training)."""
for param in self.parameters():
param.requires_grad = False

    def unfreeze(self):
"""Unfreeze all parameters."""
for param in self.parameters():
param.requires_grad = True
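

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the module API). The checkpoint name
# "ethanyt/guwenbert-large" and the SimpleNamespace stand-in for the project
# config object are assumptions; swap in the config / model identifier the
# repo actually uses before running.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(roberta_model="ethanyt/guwenbert-large")
    encoder = ContextEncoder(config)
    encoder.eval()
    tokenizer = AutoTokenizer.from_pretrained(config.roberta_model)

    # One classical-Chinese sentence with a single masked character.
    text = "學而時習之，不亦[MASK]乎"
    batch = tokenizer(text, return_tensors="pt")
    # Position of the single [MASK] token, shaped [1, 1] ([batch_size, num_masks]).
    mask_positions = (batch["input_ids"] == tokenizer.mask_token_id).nonzero()[:, 1:2]

    with torch.no_grad():
        features = encoder.extract_mask_features(
            batch["input_ids"], batch["attention_mask"], mask_positions
        )
    print(features.shape)  # expected: [1, 1, hidden_dim]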