| """ |
| Transformer-based Dependency Parser using PhoBERT. |
| |
| This module implements a Biaffine dependency parser with PhoBERT as the encoder, |
| following the Trankit approach but using Vietnamese-specific PhoBERT. |
| |
| Architecture: |
| Input → PhoBERT → Word-level pooling → MLP projections → Biaffine attention → MST decoding |
| |
| Reference: |
| - Dozat & Manning (2017): Deep Biaffine Attention for Neural Dependency Parsing |
| - Nguyen & Nguyen (2020): PhoBERT: Pre-trained language models for Vietnamese |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import List, Tuple, Optional, Dict, Any |
| import numpy as np |
|
|
| from bamboo1.models.mst import mst_decode, batch_mst_decode |
|
|
|
|
class MLP(nn.Module):
    """Single hidden-layer projection used ahead of the biaffine scorers.

    Applies ``Linear -> LeakyReLU(0.1) -> Dropout``.  The sub-module
    attribute names (``linear``, ``activation``, ``dropout``) are part of
    the de-facto interface: checkpoints key off them and other code reads
    e.g. ``mlp.linear.out_features`` — do not rename.
    """

    def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.33):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.LeakyReLU(0.1)
        self.linear = nn.Linear(input_dim, hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` of shape (..., input_dim) to (..., hidden_dim)."""
        projected = self.linear(x)
        activated = self.activation(projected)
        return self.dropout(activated)
|
|
|
|
class Biaffine(nn.Module):
    """Biaffine attention layer for dependency scoring.

    Scores every (dependent, head) pair as ``[x; 1] W [y; 1]^T`` per output
    channel, where the trailing 1 is appended only when the corresponding
    bias flag is set (Dozat & Manning, 2017).
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int = 1,
        bias_x: bool = True,
        bias_y: bool = True
    ):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.bias_x = bias_x
        self.bias_y = bias_y

        # Weight is (output_dim, dx, dy); each enabled bias flag grows the
        # matching side by one to absorb the affine term.
        weight_shape = (
            output_dim,
            input_dim + bias_x,
            input_dim + bias_y,
        )
        self.weight = nn.Parameter(torch.zeros(*weight_shape))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, input_dim) - dependent representations
            y: (batch, seq_len, input_dim) - head representations

        Returns:
            scores: (batch, seq_len, seq_len, output_dim), squeezed to
            (batch, seq_len, seq_len) when output_dim == 1
        """
        if self.bias_x:
            x = torch.cat([x, torch.ones_like(x[..., :1])], dim=-1)
        if self.bias_y:
            y = torch.cat([y, torch.ones_like(y[..., :1])], dim=-1)

        # Single three-operand contraction, equivalent to projecting x
        # through W and dotting with y: (b,x,i) x (o,i,j) x (b,y,j) -> (b,x,y,o)
        scores = torch.einsum('bxi,oij,byj->bxyo', x, self.weight, y)

        # Scalar arc scorers drop the trailing channel dimension.
        if self.output_dim == 1:
            scores = scores.squeeze(-1)

        return scores
|
|
|
|
class PhoBERTDependencyParser(nn.Module):
    """
    PhoBERT-based Biaffine Dependency Parser.

    Uses PhoBERT as encoder with first-subword pooling for word alignment,
    followed by biaffine attention for arc and relation prediction
    (Dozat & Manning, 2017).
    """

    def __init__(
        self,
        encoder_name: str = "vinai/phobert-base",
        n_rels: int = 50,
        arc_hidden: int = 500,
        rel_hidden: int = 100,
        dropout: float = 0.33,
        use_mst: bool = True,
    ):
        """
        Args:
            encoder_name: HuggingFace model name for PhoBERT
            n_rels: Number of dependency relations
            arc_hidden: Hidden dimension for arc MLPs
            rel_hidden: Hidden dimension for relation MLPs
            dropout: Dropout rate
            use_mst: Use MST decoding (True) or greedy decoding (False)
        """
        super().__init__()

        # Lazy import so merely importing this module does not require
        # `transformers` to be installed.
        from transformers import AutoModel, AutoTokenizer

        self.encoder_name = encoder_name
        self.n_rels = n_rels
        self.use_mst = use_mst

        # Relation label vocabulary (label <-> index).  Populated by `load()`
        # or by the training pipeline; initialized empty here so `predict()`
        # never raises AttributeError on a freshly constructed model.
        self.rel2idx: Dict[str, int] = {}
        self.idx2rel: Dict[int, str] = {}

        self.encoder = AutoModel.from_pretrained(encoder_name)
        self.tokenizer = AutoTokenizer.from_pretrained(encoder_name)
        self.hidden_size = self.encoder.config.hidden_size

        self.dropout = nn.Dropout(dropout)

        # Separate dep/head MLP views feeding the biaffine scorers.
        self.mlp_arc_dep = MLP(self.hidden_size, arc_hidden, dropout)
        self.mlp_arc_head = MLP(self.hidden_size, arc_hidden, dropout)
        self.mlp_rel_dep = MLP(self.hidden_size, rel_hidden, dropout)
        self.mlp_rel_head = MLP(self.hidden_size, rel_hidden, dropout)

        # Arc scorer is scalar-valued; relation scorer emits one score per label.
        self.arc_attn = Biaffine(arc_hidden, 1, bias_x=True, bias_y=False)
        self.rel_attn = Biaffine(rel_hidden, n_rels, bias_x=True, bias_y=True)

    def _get_word_embeddings(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        word_starts: torch.Tensor,
        word_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Get word-level embeddings from subword encoder output.

        Uses first-subword pooling: each word is represented by the hidden
        state of its first subword token.

        Args:
            input_ids: (batch, subword_seq_len) - Subword token IDs
            attention_mask: (batch, subword_seq_len) - Attention mask for subwords
            word_starts: (batch, word_seq_len) - Indices of first subword for each word
            word_mask: (batch, word_seq_len) - Mask for actual words
                (unused here; kept for interface symmetry with `forward` —
                padded word slots gather position 0 and are ignored downstream)

        Returns:
            word_embeddings: (batch, word_seq_len, hidden_size)
        """
        encoder_output = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        hidden_states = encoder_output.last_hidden_state

        hidden_states = self.dropout(hidden_states)

        # Gather the hidden state at each word's first-subword index.
        word_embeddings = torch.gather(
            hidden_states,
            dim=1,
            index=word_starts.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1))
        )

        return word_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        word_starts: torch.Tensor,
        word_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass computing arc and relation scores.

        Args:
            input_ids: (batch, subword_seq_len) - Subword token IDs
            attention_mask: (batch, subword_seq_len) - Attention mask for subwords
            word_starts: (batch, word_seq_len) - Indices of first subword for each word
            word_mask: (batch, word_seq_len) - Mask for actual words

        Returns:
            arc_scores: (batch, word_seq_len, word_seq_len) - Arc scores
            rel_scores: (batch, word_seq_len, word_seq_len, n_rels) - Relation scores
        """
        word_embeddings = self._get_word_embeddings(
            input_ids, attention_mask, word_starts, word_mask
        )

        # Project encoder output into dependent/head views.
        arc_dep = self.mlp_arc_dep(word_embeddings)
        arc_head = self.mlp_arc_head(word_embeddings)
        rel_dep = self.mlp_rel_dep(word_embeddings)
        rel_head = self.mlp_rel_head(word_embeddings)

        # Biaffine scoring over all (dependent, head) pairs.
        arc_scores = self.arc_attn(arc_dep, arc_head)
        rel_scores = self.rel_attn(rel_dep, rel_head)

        return arc_scores, rel_scores

    def loss(
        self,
        arc_scores: torch.Tensor,
        rel_scores: torch.Tensor,
        heads: torch.Tensor,
        rels: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute cross-entropy loss for arcs and relations.

        Args:
            arc_scores: (batch, seq_len, seq_len) - Arc scores
            rel_scores: (batch, seq_len, seq_len, n_rels) - Relation scores
            heads: (batch, seq_len) - Gold head indices
            rels: (batch, seq_len) - Gold relation indices
            mask: (batch, seq_len) - Token mask (1 for real tokens, 0 for padding)

        Returns:
            Total loss (arc_loss + rel_loss)
        """
        batch_size, seq_len = mask.shape

        # A padded position can never be a *head*: set its columns to -inf so
        # the softmax over candidate heads assigns them zero probability.
        # (Previously the *dependent* rows were masked, which was a no-op:
        # those rows are discarded by the boolean indexing right below.)
        # NOTE(review): assumes every gold head index points at a position
        # where `mask` is True (including the ROOT slot) — confirm in the
        # data pipeline, otherwise the arc loss becomes non-finite.
        arc_scores_masked = arc_scores.masked_fill(
            ~mask.unsqueeze(1), float('-inf')
        )

        arc_loss = F.cross_entropy(
            arc_scores_masked[mask].view(-1, seq_len),
            heads[mask],
            reduction='mean'
        )

        # Relation loss is evaluated only at each dependent's *gold* head.
        batch_indices = torch.arange(batch_size, device=rel_scores.device).unsqueeze(1)
        seq_indices = torch.arange(seq_len, device=rel_scores.device)
        rel_scores_gold = rel_scores[batch_indices, seq_indices, heads]

        rel_loss = F.cross_entropy(
            rel_scores_gold[mask],
            rels[mask],
            reduction='mean'
        )

        return arc_loss + rel_loss

    def decode(
        self,
        arc_scores: torch.Tensor,
        rel_scores: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decode predictions using MST or greedy decoding.

        Args:
            arc_scores: (batch, seq_len, seq_len) - Arc scores
            rel_scores: (batch, seq_len, seq_len, n_rels) - Relation scores
            mask: (batch, seq_len) - Token mask

        Returns:
            arc_preds: (batch, seq_len) - Predicted head indices
            rel_preds: (batch, seq_len) - Predicted relation indices
        """
        batch_size, seq_len = mask.shape
        device = arc_scores.device

        if self.use_mst:
            # MST (Chu-Liu/Edmonds) guarantees a well-formed tree per sentence.
            lengths = mask.sum(dim=1).cpu().numpy()
            arc_scores_np = arc_scores.cpu().numpy()
            arc_preds_np = batch_mst_decode(arc_scores_np, lengths)
            arc_preds = torch.from_numpy(arc_preds_np).to(device)
        else:
            # Greedy per-token argmax over candidate heads.  Padded head
            # columns are suppressed so padding can never be chosen as a head.
            arc_preds = arc_scores.masked_fill(
                ~mask.unsqueeze(1), float('-inf')
            ).argmax(dim=-1)

        # Pick each dependent's relation scores at its predicted head,
        # then argmax over relation labels.
        batch_indices = torch.arange(batch_size, device=device).unsqueeze(1)
        seq_indices = torch.arange(seq_len, device=device)
        rel_scores_pred = rel_scores[batch_indices, seq_indices, arc_preds]
        rel_preds = rel_scores_pred.argmax(dim=-1)

        return arc_preds, rel_preds

    def predict(
        self,
        words: List[str],
        return_probs: bool = False,
    ) -> List[Tuple[str, int, str]]:
        """
        Predict dependencies for a single sentence.

        Args:
            words: List of words (pre-tokenized)
            return_probs: Whether to return probability scores
                (currently unused; kept for interface stability)

        Returns:
            List of (word, head, deprel) tuples.  If the relation vocabulary
            has not been loaded, every deprel falls back to "dep".
        """
        self.eval()
        device = next(self.parameters()).device

        # Batch of one sentence; alignment info comes with the encoding.
        encoded = self.tokenize_with_alignment([words])

        input_ids = encoded['input_ids'].to(device)
        attention_mask = encoded['attention_mask'].to(device)
        word_starts = encoded['word_starts'].to(device)
        word_mask = encoded['word_mask'].to(device)

        with torch.no_grad():
            arc_scores, rel_scores = self.forward(
                input_ids, attention_mask, word_starts, word_mask
            )
            arc_preds, rel_preds = self.decode(arc_scores, rel_scores, word_mask)

        # Single-sentence batch: strip the batch dimension.
        arc_preds = arc_preds[0].cpu().tolist()
        rel_preds = rel_preds[0].cpu().tolist()

        results = []
        for i, word in enumerate(words):
            head = arc_preds[i]
            rel_idx = rel_preds[i]
            rel = self.idx2rel.get(rel_idx, "dep")
            results.append((word, head, rel))

        return results

    def tokenize_with_alignment(
        self,
        sentences: List[List[str]],
        max_length: int = 256,
    ) -> Dict[str, torch.Tensor]:
        """
        Tokenize sentences and track word-subword alignment.

        Args:
            sentences: Non-empty list of sentences, each a list of words
            max_length: Maximum subword sequence length

        Returns:
            Dictionary with input_ids, attention_mask, word_starts, word_mask
        """
        batch_input_ids = []
        batch_attention_mask = []
        batch_word_starts = []
        batch_word_mask = []

        for words in sentences:
            # Record the index of each word's first subword (after CLS).
            word_starts = []
            subword_ids = [self.tokenizer.cls_token_id]

            for word in words:
                word_starts.append(len(subword_ids))
                word_tokens = self.tokenizer.encode(word, add_special_tokens=False)
                if not word_tokens:
                    # Some inputs (e.g. control characters) can encode to
                    # nothing; fall back to <unk> so this word keeps its own
                    # first-subword slot instead of aliasing the next word's.
                    word_tokens = [self.tokenizer.unk_token_id]
                subword_ids.extend(word_tokens)

            subword_ids.append(self.tokenizer.sep_token_id)

            # Truncate over-long sentences, keeping a trailing SEP and
            # dropping words whose first subword fell past the cut.
            if len(subword_ids) > max_length:
                subword_ids = subword_ids[:max_length-1] + [self.tokenizer.sep_token_id]
                word_starts = [ws for ws in word_starts if ws < max_length - 1]

            attention_mask = [1] * len(subword_ids)

            batch_input_ids.append(subword_ids)
            batch_attention_mask.append(attention_mask)
            batch_word_starts.append(word_starts)
            batch_word_mask.append([1] * len(word_starts))

        # Pad to the longest subword / word sequence in the batch.
        max_subword_len = max(len(ids) for ids in batch_input_ids)
        max_word_len = max(len(ws) for ws in batch_word_starts)

        padded_input_ids = []
        padded_attention_mask = []
        padded_word_starts = []
        padded_word_mask = []

        for i in range(len(sentences)):
            # Subword-level padding.
            pad_len = max_subword_len - len(batch_input_ids[i])
            padded_input_ids.append(
                batch_input_ids[i] + [self.tokenizer.pad_token_id] * pad_len
            )
            padded_attention_mask.append(
                batch_attention_mask[i] + [0] * pad_len
            )

            # Word-level padding: padded slots point at subword index 0 and
            # are masked out via word_mask.
            word_pad_len = max_word_len - len(batch_word_starts[i])
            padded_word_starts.append(
                batch_word_starts[i] + [0] * word_pad_len
            )
            padded_word_mask.append(
                batch_word_mask[i] + [0] * word_pad_len
            )

        return {
            'input_ids': torch.tensor(padded_input_ids, dtype=torch.long),
            'attention_mask': torch.tensor(padded_attention_mask, dtype=torch.long),
            'word_starts': torch.tensor(padded_word_starts, dtype=torch.long),
            'word_mask': torch.tensor(padded_word_mask, dtype=torch.bool),
        }

    def save(self, path: str, vocab: Optional[Dict] = None):
        """
        Save model checkpoint.

        Args:
            path: Directory path to save the model
            vocab: Vocabulary dict with rel2idx and idx2rel mappings.
                When omitted, the instance's own rel2idx/idx2rel are saved
                if they have been populated.
        """
        import os
        os.makedirs(path, exist_ok=True)

        checkpoint = {
            'model_state_dict': self.state_dict(),
            'config': {
                'encoder_name': self.encoder_name,
                'n_rels': self.n_rels,
                'arc_hidden': self.mlp_arc_dep.linear.out_features,
                'rel_hidden': self.mlp_rel_dep.linear.out_features,
                'dropout': self.dropout.p,
                'use_mst': self.use_mst,
            },
        }

        # Prefer an explicitly supplied vocab; otherwise fall back to the
        # instance vocab so predictions survive a save/load round trip.
        if vocab is None and (self.rel2idx or self.idx2rel):
            vocab = {'rel2idx': self.rel2idx, 'idx2rel': self.idx2rel}

        if vocab is not None:
            checkpoint['vocab'] = vocab

        torch.save(checkpoint, os.path.join(path, 'model.pt'))

        # Tokenizer files live next to the weights for self-contained loading.
        self.tokenizer.save_pretrained(path)

    @classmethod
    def load(cls, path: str, device: str = 'cpu') -> 'PhoBERTDependencyParser':
        """
        Load model from checkpoint.

        Args:
            path: Directory path containing the saved model
            device: Device to load the model to

        Returns:
            Loaded PhoBERTDependencyParser model

        Note:
            Uses ``weights_only=False``, which unpickles arbitrary objects —
            only load checkpoints from trusted sources.
        """
        import os

        checkpoint = torch.load(
            os.path.join(path, 'model.pt'),
            map_location=device,
            weights_only=False
        )

        config = checkpoint['config']

        # Rebuild the architecture from the saved config, then restore weights.
        model = cls(
            encoder_name=config['encoder_name'],
            n_rels=config['n_rels'],
            arc_hidden=config['arc_hidden'],
            rel_hidden=config['rel_hidden'],
            dropout=config['dropout'],
            use_mst=config.get('use_mst', True),
        )

        model.load_state_dict(checkpoint['model_state_dict'])

        # Restore the relation vocabulary (empty dicts when absent).
        if 'vocab' in checkpoint:
            model.rel2idx = checkpoint['vocab'].get('rel2idx', {})
            model.idx2rel = checkpoint['vocab'].get('idx2rel', {})
        else:
            model.rel2idx = {}
            model.idx2rel = {}

        model.to(device)
        return model
|
|