"""
Transformer-based Dependency Parser using PhoBERT.

This module implements a Biaffine dependency parser with PhoBERT as the encoder,
following the Trankit approach but using Vietnamese-specific PhoBERT.

Architecture:
    Input → PhoBERT → Word-level pooling → MLP projections → Biaffine attention → MST decoding

Reference:
    - Dozat & Manning (2017): Deep Biaffine Attention for Neural Dependency Parsing
    - Nguyen & Nguyen (2020): PhoBERT: Pre-trained language models for Vietnamese
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Optional, Dict, Any
import numpy as np

from bamboo1.models.mst import mst_decode, batch_mst_decode


class MLP(nn.Module):
    """Single-layer perceptron (Linear -> LeakyReLU -> Dropout) for biaffine scoring."""

    def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.33):
        """
        Args:
            input_dim: Size of the last dimension of the input
            hidden_dim: Size of the last dimension of the output
            dropout: Dropout probability applied after the activation
        """
        super().__init__()
        self.linear = nn.Linear(input_dim, hidden_dim)
        self.activation = nn.LeakyReLU(0.1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` (..., input_dim) to (..., hidden_dim)."""
        return self.dropout(self.activation(self.linear(x)))


class Biaffine(nn.Module):
    """Biaffine attention layer for dependency scoring."""

    def __init__(
        self,
        input_dim: int,
        output_dim: int = 1,
        bias_x: bool = True,
        bias_y: bool = True
    ):
        """
        Args:
            input_dim: Feature size of both inputs
            output_dim: Number of score channels (1 for arcs, n_rels for labels)
            bias_x: Append a constant 1 feature to ``x`` before scoring
            bias_y: Append a constant 1 feature to ``y`` before scoring
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.bias_x = bias_x
        self.bias_y = bias_y

        # bool + int is intentional: each enabled bias adds one feature column.
        self.weight = nn.Parameter(
            torch.zeros(output_dim, input_dim + bias_x, input_dim + bias_y)
        )
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, input_dim) - dependent representations
            y: (batch, seq_len, input_dim) - head representations

        Returns:
            scores: (batch, seq_len, seq_len, output_dim) or
                (batch, seq_len, seq_len) if output_dim=1.
                Axis 1 indexes ``x`` (dependents), axis 2 indexes ``y`` (heads).
        """
        if self.bias_x:
            x = torch.cat([x, torch.ones_like(x[..., :1])], dim=-1)
        if self.bias_y:
            y = torch.cat([y, torch.ones_like(y[..., :1])], dim=-1)

        # (batch, seq_len, output_dim, input_dim + bias_y)
        x = torch.einsum('bxi,oij->bxoj', x, self.weight)
        # (batch, seq_len, seq_len, output_dim)
        scores = torch.einsum('bxoj,byj->bxyo', x, y)

        if self.output_dim == 1:
            scores = scores.squeeze(-1)

        return scores


class PhoBERTDependencyParser(nn.Module):
    """
    PhoBERT-based Biaffine Dependency Parser.

    Uses PhoBERT as encoder with first-subword pooling for word alignment,
    followed by biaffine attention for arc and relation prediction.
    """

    def __init__(
        self,
        encoder_name: str = "vinai/phobert-base",
        n_rels: int = 50,
        arc_hidden: int = 500,
        rel_hidden: int = 100,
        dropout: float = 0.33,
        use_mst: bool = True,
    ):
        """
        Args:
            encoder_name: HuggingFace model name for PhoBERT
            n_rels: Number of dependency relations
            arc_hidden: Hidden dimension for arc MLPs
            rel_hidden: Hidden dimension for relation MLPs
            dropout: Dropout rate
            use_mst: Use MST decoding (True) or greedy decoding (False)
        """
        super().__init__()

        # Imported lazily so the module can be imported without transformers.
        from transformers import AutoModel, AutoTokenizer

        self.encoder_name = encoder_name
        self.n_rels = n_rels
        self.use_mst = use_mst

        # Label vocabulary. Empty until populated by load() or by the caller;
        # initialised here so predict() works on a freshly constructed model
        # (falling back to the "dep" label) instead of raising AttributeError.
        self.rel2idx: Dict[Any, Any] = {}
        self.idx2rel: Dict[Any, Any] = {}

        # Load PhoBERT encoder
        self.encoder = AutoModel.from_pretrained(encoder_name)
        self.tokenizer = AutoTokenizer.from_pretrained(encoder_name)
        self.hidden_size = self.encoder.config.hidden_size  # 768 for phobert-base

        # Dropout
        self.dropout = nn.Dropout(dropout)

        # MLP projections
        self.mlp_arc_dep = MLP(self.hidden_size, arc_hidden, dropout)
        self.mlp_arc_head = MLP(self.hidden_size, arc_hidden, dropout)
        self.mlp_rel_dep = MLP(self.hidden_size, rel_hidden, dropout)
        self.mlp_rel_head = MLP(self.hidden_size, rel_hidden, dropout)

        # Biaffine attention
        self.arc_attn = Biaffine(arc_hidden, 1, bias_x=True, bias_y=False)
        self.rel_attn = Biaffine(rel_hidden, n_rels, bias_x=True, bias_y=True)

    def _get_word_embeddings(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        word_starts: torch.Tensor,
        word_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Get word-level embeddings from subword encoder output.

        Uses first-subword pooling strategy: each word is represented by
        the embedding of its first subword token.

        Args:
            input_ids: (batch, subword_seq_len) - Subword token IDs
            attention_mask: (batch, subword_seq_len) - Attention mask for subwords
            word_starts: (batch, word_seq_len) - Indices of first subword for each word
            word_mask: (batch, word_seq_len) - Mask for actual words
                (not used here; padded slots are gathered too)

        Returns:
            word_embeddings: (batch, word_seq_len, hidden_size)
        """
        # Get encoder output
        encoder_output = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        hidden_states = encoder_output.last_hidden_state  # (batch, subword_seq_len, hidden)

        # Apply dropout
        hidden_states = self.dropout(hidden_states)

        # Gather the hidden state at each word's first-subword index.
        # word_starts: (batch, word_seq_len) -> (batch, word_seq_len, hidden)
        gather_index = word_starts.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1))
        word_embeddings = torch.gather(hidden_states, dim=1, index=gather_index)

        return word_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        word_starts: torch.Tensor,
        word_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass computing arc and relation scores.

        Args:
            input_ids: (batch, subword_seq_len) - Subword token IDs
            attention_mask: (batch, subword_seq_len) - Attention mask for subwords
            word_starts: (batch, word_seq_len) - Indices of first subword for each word
            word_mask: (batch, word_seq_len) - Mask for actual words

        Returns:
            arc_scores: (batch, word_seq_len, word_seq_len) - Arc scores,
                indexed [batch, dependent, head]
            rel_scores: (batch, word_seq_len, word_seq_len, n_rels) - Relation scores
        """
        # Get word-level embeddings
        word_embeddings = self._get_word_embeddings(
            input_ids, attention_mask, word_starts, word_mask
        )

        # MLP projections
        arc_dep = self.mlp_arc_dep(word_embeddings)
        arc_head = self.mlp_arc_head(word_embeddings)
        rel_dep = self.mlp_rel_dep(word_embeddings)
        rel_head = self.mlp_rel_head(word_embeddings)

        # Biaffine attention
        arc_scores = self.arc_attn(arc_dep, arc_head)  # (batch, seq, seq)
        rel_scores = self.rel_attn(rel_dep, rel_head)  # (batch, seq, seq, n_rels)

        return arc_scores, rel_scores

    def loss(
        self,
        arc_scores: torch.Tensor,
        rel_scores: torch.Tensor,
        heads: torch.Tensor,
        rels: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute cross-entropy loss for arcs and relations.

        Args:
            arc_scores: (batch, seq_len, seq_len) - Arc scores [batch, dependent, head]
            rel_scores: (batch, seq_len, seq_len, n_rels) - Relation scores
            heads: (batch, seq_len) - Gold head indices
            rels: (batch, seq_len) - Gold relation indices
            mask: (batch, seq_len) - Token mask (1 for real tokens, 0 for padding)

        Returns:
            Total loss (arc_loss + rel_loss)
        """
        # `~` must be a logical not, so make sure the mask is boolean.
        mask = mask.bool()
        batch_size, seq_len = mask.shape

        # Exclude padded positions as *candidate heads*. The head axis is the
        # last one, so the mask is broadcast over the dependent axis (dim 1).
        # Padded dependents are removed below by the boolean row selection.
        arc_scores_masked = arc_scores.masked_fill(~mask.unsqueeze(1), float('-inf'))

        # Arc loss: cross-entropy over possible heads, one row per real token
        arc_loss = F.cross_entropy(
            arc_scores_masked[mask],  # (n_real_tokens, seq_len)
            heads[mask],
            reduction='mean'
        )

        # Relation loss: cross-entropy conditioned on gold heads
        batch_indices = torch.arange(batch_size, device=rel_scores.device).unsqueeze(1)
        seq_indices = torch.arange(seq_len, device=rel_scores.device)
        # (batch, seq_len, n_rels): label scores at each token's gold head
        rel_scores_gold = rel_scores[batch_indices, seq_indices, heads]

        rel_loss = F.cross_entropy(
            rel_scores_gold[mask],
            rels[mask],
            reduction='mean'
        )

        return arc_loss + rel_loss

    def decode(
        self,
        arc_scores: torch.Tensor,
        rel_scores: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decode predictions using MST or greedy decoding.

        Args:
            arc_scores: (batch, seq_len, seq_len) - Arc scores
            rel_scores: (batch, seq_len, seq_len, n_rels) - Relation scores
            mask: (batch, seq_len) - Token mask

        Returns:
            arc_preds: (batch, seq_len) - Predicted head indices
            rel_preds: (batch, seq_len) - Predicted relation indices
        """
        mask = mask.bool()
        batch_size, seq_len = mask.shape
        device = arc_scores.device

        if self.use_mst:
            # MST decoding for valid tree structure
            lengths = mask.sum(dim=1).cpu().numpy()
            # detach() so decoding also works outside torch.no_grad()
            arc_scores_np = arc_scores.detach().cpu().numpy()
            arc_preds_np = batch_mst_decode(arc_scores_np, lengths)
            # long dtype is required for the advanced indexing below
            arc_preds = torch.from_numpy(arc_preds_np).to(device).long()
        else:
            # Greedy decoding; padded positions can never be chosen as heads
            arc_preds = arc_scores.masked_fill(
                ~mask.unsqueeze(1), float('-inf')
            ).argmax(dim=-1)

        # Get relation predictions for predicted heads
        batch_indices = torch.arange(batch_size, device=device).unsqueeze(1)
        seq_indices = torch.arange(seq_len, device=device)
        rel_scores_pred = rel_scores[batch_indices, seq_indices, arc_preds]
        rel_preds = rel_scores_pred.argmax(dim=-1)

        return arc_preds, rel_preds

    def predict(
        self,
        words: List[str],
        return_probs: bool = False,
    ) -> List[Tuple[str, int, str]]:
        """
        Predict dependencies for a single sentence.

        Switches the module to eval mode as a side effect.

        Args:
            words: List of words (pre-tokenized)
            return_probs: Whether to return probability scores
                (accepted for API compatibility; currently unused)

        Returns:
            List of (word, head, deprel) tuples. Relation indices missing
            from ``self.idx2rel`` are reported as "dep".

        Raises:
            IndexError: If the sentence was truncated during tokenization so
                that not every word received a prediction.
        """
        self.eval()

        # Nothing to parse; avoid running the encoder/decoder on empty input.
        if not words:
            return []

        device = next(self.parameters()).device

        # Tokenize with word boundary tracking
        encoded = self.tokenize_with_alignment([words])

        # Move to device
        input_ids = encoded['input_ids'].to(device)
        attention_mask = encoded['attention_mask'].to(device)
        word_starts = encoded['word_starts'].to(device)
        word_mask = encoded['word_mask'].to(device)

        with torch.no_grad():
            arc_scores, rel_scores = self.forward(
                input_ids, attention_mask, word_starts, word_mask
            )
            arc_preds, rel_preds = self.decode(arc_scores, rel_scores, word_mask)

        # Convert to list of tuples
        arc_preds = arc_preds[0].cpu().tolist()
        rel_preds = rel_preds[0].cpu().tolist()

        if len(arc_preds) < len(words):
            # tokenize_with_alignment dropped trailing words that did not fit
            raise IndexError(
                f"sentence has {len(words)} words but only {len(arc_preds)} "
                "fit within the subword length limit"
            )

        results = []
        for word, head, rel_idx in zip(words, arc_preds, rel_preds):
            rel = self.idx2rel.get(rel_idx, "dep")
            results.append((word, head, rel))

        return results

    def tokenize_with_alignment(
        self,
        sentences: List[List[str]],
        max_length: int = 256,
    ) -> Dict[str, torch.Tensor]:
        """
        Tokenize sentences and track word-subword alignment.

        Each word is encoded separately so the index of its first subword is
        known. Sequences longer than ``max_length`` subwords are cut and
        re-terminated with SEP; words starting beyond the cut are dropped.

        Args:
            sentences: List of sentences, where each sentence is a list of words
            max_length: Maximum subword sequence length

        Returns:
            Dictionary with input_ids, attention_mask, word_starts, word_mask

        Raises:
            ValueError: If ``sentences`` is empty.
        """
        if not sentences:
            raise ValueError("tokenize_with_alignment() requires at least one sentence")

        cls_id = self.tokenizer.cls_token_id
        sep_id = self.tokenizer.sep_token_id
        pad_id = self.tokenizer.pad_token_id
        unk_id = getattr(self.tokenizer, 'unk_token_id', None)

        batch_input_ids = []
        batch_word_starts = []

        for words in sentences:
            # Tokenize each word separately to track boundaries
            word_starts = []
            subword_ids = [cls_id]

            for word in words:
                word_starts.append(len(subword_ids))
                word_tokens = self.tokenizer.encode(word, add_special_tokens=False)
                if not word_tokens and unk_id is not None:
                    # A word with no subwords (e.g. an empty string) would
                    # otherwise share its start index with the next token.
                    word_tokens = [unk_id]
                subword_ids.extend(word_tokens)

            subword_ids.append(sep_id)

            # Truncate if needed
            if len(subword_ids) > max_length:
                subword_ids = subword_ids[:max_length - 1] + [sep_id]
                # Drop words whose first subword fell beyond the cut
                word_starts = [ws for ws in word_starts if ws < max_length - 1]

            batch_input_ids.append(subword_ids)
            batch_word_starts.append(word_starts)

        # Pad sequences
        max_subword_len = max(len(ids) for ids in batch_input_ids)
        max_word_len = max(len(ws) for ws in batch_word_starts)

        padded_input_ids = []
        padded_attention_mask = []
        padded_word_starts = []
        padded_word_mask = []

        for subword_ids, word_starts in zip(batch_input_ids, batch_word_starts):
            # Pad subwords
            pad_len = max_subword_len - len(subword_ids)
            padded_input_ids.append(subword_ids + [pad_id] * pad_len)
            padded_attention_mask.append([1] * len(subword_ids) + [0] * pad_len)

            # Pad words
            word_pad_len = max_word_len - len(word_starts)
            # Use 0 for padding word_starts (points to CLS token, but masked)
            padded_word_starts.append(word_starts + [0] * word_pad_len)
            padded_word_mask.append([1] * len(word_starts) + [0] * word_pad_len)

        return {
            'input_ids': torch.tensor(padded_input_ids, dtype=torch.long),
            'attention_mask': torch.tensor(padded_attention_mask, dtype=torch.long),
            'word_starts': torch.tensor(padded_word_starts, dtype=torch.long),
            'word_mask': torch.tensor(padded_word_mask, dtype=torch.bool),
        }

    def save(self, path: str, vocab: Optional[Dict] = None):
        """
        Save model checkpoint.

        Writes ``model.pt`` (weights, config and optional vocabulary) and the
        tokenizer files into ``path``.

        Args:
            path: Directory path to save the model
            vocab: Vocabulary dict with rel2idx and idx2rel mappings. When
                omitted, the model's own ``rel2idx``/``idx2rel`` are saved if
                they are non-empty, so a load -> save round trip keeps labels.
        """
        import os
        os.makedirs(path, exist_ok=True)

        if vocab is None:
            rel2idx = getattr(self, 'rel2idx', None)
            idx2rel = getattr(self, 'idx2rel', None)
            if rel2idx or idx2rel:
                vocab = {'rel2idx': rel2idx or {}, 'idx2rel': idx2rel or {}}

        # Save model state
        checkpoint = {
            'model_state_dict': self.state_dict(),
            'config': {
                'encoder_name': self.encoder_name,
                'n_rels': self.n_rels,
                'arc_hidden': self.mlp_arc_dep.linear.out_features,
                'rel_hidden': self.mlp_rel_dep.linear.out_features,
                'dropout': self.dropout.p,
                'use_mst': self.use_mst,
            },
        }

        if vocab is not None:
            checkpoint['vocab'] = vocab

        torch.save(checkpoint, os.path.join(path, 'model.pt'))

        # Save tokenizer
        self.tokenizer.save_pretrained(path)

    @classmethod
    def load(cls, path: str, device: str = 'cpu') -> 'PhoBERTDependencyParser':
        """
        Load model from checkpoint.

        Args:
            path: Directory path containing the saved model
            device: Device to load the model to

        Returns:
            Loaded PhoBERTDependencyParser model
        """
        import os

        # SECURITY: weights_only=False performs full unpickling, which can
        # execute arbitrary code. Only load checkpoints from trusted sources.
        checkpoint = torch.load(
            os.path.join(path, 'model.pt'),
            map_location=device,
            weights_only=False
        )
        config = checkpoint['config']

        # Create model (encoder/tokenizer are re-instantiated from
        # config['encoder_name']; weights are then overwritten below)
        model = cls(
            encoder_name=config['encoder_name'],
            n_rels=config['n_rels'],
            arc_hidden=config['arc_hidden'],
            rel_hidden=config['rel_hidden'],
            dropout=config['dropout'],
            use_mst=config.get('use_mst', True),
        )

        # Load state dict
        model.load_state_dict(checkpoint['model_state_dict'])

        # Load vocabulary
        vocab = checkpoint.get('vocab', {})
        model.rel2idx = vocab.get('rel2idx', {})
        model.idx2rel = vocab.get('idx2rel', {})

        model.to(device)
        return model