|
|
""" |
|
|
Context Encoder using pre-trained GuwenBERT RoBERTa. |
|
|
Implements the textual feature extraction module from the paper. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import AutoModel, AutoTokenizer |
|
|
|
|
|
|
|
|
class ContextEncoder(nn.Module):
    """Textual context encoder built on GuwenBERT RoBERTa large.

    Wraps a pre-trained HuggingFace model and exposes a helper for pulling
    hidden states out at masked token positions, plus freeze/unfreeze
    switches for staged training.
    """

    def __init__(self, config, pretrained_model_name: str = None):
        """Load the pre-trained RoBERTa backbone.

        Args:
            config: Configuration object; supplies ``roberta_model`` when no
                explicit model name is given.
            pretrained_model_name: Optional HuggingFace model identifier
                overriding the one in ``config``.
        """
        super().__init__()
        self.config = config

        model_name = (
            config.roberta_model
            if pretrained_model_name is None
            else pretrained_model_name
        )

        # Temporarily silence transformers warnings while the checkpoint is
        # loaded; the try/finally guarantees verbosity is restored even if
        # loading fails.
        from transformers import logging as transformers_logging

        transformers_logging.set_verbosity_error()
        try:
            self.encoder = AutoModel.from_pretrained(
                model_name, tie_word_embeddings=False
            )
        finally:
            transformers_logging.set_verbosity_warning()

        # Feature width reported by the loaded backbone's own config.
        self.hidden_dim = self.encoder.config.hidden_size

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Run the backbone and return per-token hidden states.

        Args:
            input_ids: Token IDs [batch_size, seq_len].
            attention_mask: Attention mask [batch_size, seq_len].

        Returns:
            Hidden states [batch_size, seq_len, hidden_dim].
        """
        return self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state

    def extract_mask_features(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        mask_positions: torch.Tensor
    ) -> torch.Tensor:
        """Gather hidden states at the masked token positions.

        Args:
            input_ids: Token IDs [batch_size, seq_len].
            attention_mask: Attention mask [batch_size, seq_len].
            mask_positions: Indices of mask tokens [batch_size, num_masks].

        Returns:
            Features at mask positions [batch_size, num_masks, hidden_dim].
        """
        hidden_states = self.forward(input_ids, attention_mask)

        # Broadcast the position indices across the feature dimension so a
        # single torch.gather picks out one hidden vector per mask position.
        gather_index = mask_positions.unsqueeze(-1).expand(
            -1, -1, self.hidden_dim
        )
        return torch.gather(hidden_states, 1, gather_index)

    def freeze(self):
        """Freeze all parameters (for Phase 2 training)."""
        # nn.Module.requires_grad_ flips requires_grad on every parameter.
        self.requires_grad_(False)

    def unfreeze(self):
        """Unfreeze all parameters."""
        self.requires_grad_(True)
|
|
|