""" Context Encoder using pre-trained GuwenBERT RoBERTa. Implements the textual feature extraction module from the paper. """ import torch import torch.nn as nn from transformers import AutoModel, AutoTokenizer class ContextEncoder(nn.Module): """ Context encoder using GuwenBERT RoBERTa large. Extracts features from masked positions in the text. """ def __init__(self, config, pretrained_model_name: str = None): """ Initialize context encoder. Args: config: Configuration object pretrained_model_name: HuggingFace model identifier """ super().__init__() self.config = config if pretrained_model_name is None: pretrained_model_name = config.roberta_model # Load pre-trained GuwenBERT RoBERTa from transformers import logging as transformers_logging # Suppress warnings about unexpected keys (lm_head) as we only want the encoder transformers_logging.set_verbosity_error() try: self.encoder = AutoModel.from_pretrained(pretrained_model_name, tie_word_embeddings=False) finally: transformers_logging.set_verbosity_warning() self.hidden_dim = self.encoder.config.hidden_size # # Verify hidden dimension matches config # assert self.hidden_dim == config.hidden_dim, \ # f"Model hidden dim {self.hidden_dim} != config hidden dim {config.hidden_dim}" def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: """ Forward pass through RoBERTa. Args: input_ids: Token IDs [batch_size, seq_len] attention_mask: Attention mask [batch_size, seq_len] Returns: Hidden states [batch_size, seq_len, hidden_dim] """ outputs = self.encoder( input_ids=input_ids, attention_mask=attention_mask ) # Return sequence of hidden states return outputs.last_hidden_state def extract_mask_features( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, mask_positions: torch.Tensor ) -> torch.Tensor: """ Extract features at masked positions. Args: input_ids: Token IDs [batch_size, seq_len] attention_mask: Attention mask [batch_size, seq_len] mask_positions: Positions of masks [batch_size, num_masks] Returns: Features at mask positions [batch_size, num_masks, hidden_dim] """ # Get all hidden states hidden_states = self.forward(input_ids, attention_mask) # Extract features at mask positions batch_size, num_masks = mask_positions.shape # Expand mask_positions for gathering mask_positions_expanded = mask_positions.unsqueeze(-1).expand( batch_size, num_masks, self.hidden_dim ) # Gather features at mask positions mask_features = torch.gather(hidden_states, 1, mask_positions_expanded) return mask_features def freeze(self): """Freeze all parameters (for Phase 2 training).""" for param in self.parameters(): param.requires_grad = False def unfreeze(self): """Unfreeze all parameters.""" for param in self.parameters(): param.requires_grad = True