# bamboo-1 / src/models/transformer_parser.py
# (uploaded by rain1024 — "Add model checkpoints and src/models to repo", rev 7186695)
"""
Transformer-based Dependency Parser using PhoBERT.
This module implements a Biaffine dependency parser with PhoBERT as the encoder,
following the Trankit approach but using Vietnamese-specific PhoBERT.
Architecture:
Input → PhoBERT → Word-level pooling → MLP projections → Biaffine attention → MST decoding
Reference:
- Dozat & Manning (2017): Deep Biaffine Attention for Neural Dependency Parsing
- Nguyen & Nguyen (2020): PhoBERT: Pre-trained language models for Vietnamese
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Optional, Dict, Any
import numpy as np
from bamboo1.models.mst import mst_decode, batch_mst_decode
class MLP(nn.Module):
    """Projection head used before biaffine scoring.

    Applies Linear -> LeakyReLU(0.1) -> Dropout to the encoder output,
    reducing it to the arc/relation hidden size.
    """

    def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.33):
        super().__init__()
        # Attribute names are load-bearing: checkpoints key on them and
        # PhoBERTDependencyParser.save() reads `linear.out_features`.
        self.linear = nn.Linear(input_dim, hidden_dim)
        self.activation = nn.LeakyReLU(0.1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project `x`, then apply the activation and dropout in turn."""
        projected = self.linear(x)
        activated = self.activation(projected)
        return self.dropout(activated)
class Biaffine(nn.Module):
    """Biaffine attention layer for dependency scoring.

    Computes s(x, y) = [x; 1]^T U [y; 1] for every (dependent, head) pair
    and every output channel, with the bias-1 features controlled by
    `bias_x` / `bias_y`.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int = 1,
        bias_x: bool = True,
        bias_y: bool = True
    ):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.bias_x = bias_x
        self.bias_y = bias_y
        # bool bias flags add 1 to the corresponding weight dimension.
        # `weight` must keep its name for state_dict compatibility.
        self.weight = nn.Parameter(
            torch.zeros(output_dim, input_dim + bias_x, input_dim + bias_y)
        )
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Score every (dependent, head) pair.

        Args:
            x: (batch, seq_len, input_dim) - dependent representations
            y: (batch, seq_len, input_dim) - head representations
        Returns:
            scores: (batch, seq_len, seq_len, output_dim), squeezed to
                (batch, seq_len, seq_len) when output_dim == 1
        """
        # Appending a constant-1 feature; F.pad(..., value=1) on the last
        # dim is equivalent to concatenating a ones column.
        if self.bias_x:
            x = F.pad(x, (0, 1), value=1)
        if self.bias_y:
            y = F.pad(y, (0, 1), value=1)
        # Contract the dependent side with the weight tensor first:
        # (batch, seq_len, output_dim, input_dim+bias_y)
        left = torch.einsum('bxi,oij->bxoj', x, self.weight)
        # Then contract with the head side:
        # (batch, dep_seq, head_seq, output_dim)
        scores = torch.einsum('bxoj,byj->bxyo', left, y)
        return scores.squeeze(-1) if self.output_dim == 1 else scores
class PhoBERTDependencyParser(nn.Module):
    """
    PhoBERT-based Biaffine Dependency Parser.

    Uses PhoBERT as encoder with first-subword pooling for word alignment,
    followed by biaffine attention for arc and relation prediction.

    Pipeline:
        subwords -> PhoBERT -> first-subword pooling -> arc/rel MLPs
        -> Biaffine scoring -> (MST | greedy) decoding
    """

    def __init__(
        self,
        encoder_name: str = "vinai/phobert-base",
        n_rels: int = 50,
        arc_hidden: int = 500,
        rel_hidden: int = 100,
        dropout: float = 0.33,
        use_mst: bool = True,
    ):
        """
        Args:
            encoder_name: HuggingFace model name for PhoBERT
            n_rels: Number of dependency relations
            arc_hidden: Hidden dimension for arc MLPs
            rel_hidden: Hidden dimension for relation MLPs
            dropout: Dropout rate
            use_mst: Use MST decoding (True) or greedy decoding (False)
        """
        super().__init__()
        # Imported lazily so this module can be imported without transformers
        # installed (e.g. for tooling that only inspects the MLP/Biaffine parts).
        from transformers import AutoModel, AutoTokenizer

        self.encoder_name = encoder_name
        self.n_rels = n_rels
        self.use_mst = use_mst

        # Relation label vocabulary. Populated by `load()` (or training code);
        # initialized here so `predict()` on a freshly constructed model does
        # not crash with AttributeError and instead falls back to "dep".
        self.rel2idx: Dict[str, int] = {}
        self.idx2rel: Dict[int, str] = {}

        # Load PhoBERT encoder and its matching tokenizer.
        self.encoder = AutoModel.from_pretrained(encoder_name)
        self.tokenizer = AutoTokenizer.from_pretrained(encoder_name)
        self.hidden_size = self.encoder.config.hidden_size  # 768 for phobert-base

        # Dropout applied to the encoder's subword representations.
        self.dropout = nn.Dropout(dropout)

        # MLP projections: separate dependent/head views for arcs and relations.
        self.mlp_arc_dep = MLP(self.hidden_size, arc_hidden, dropout)
        self.mlp_arc_head = MLP(self.hidden_size, arc_hidden, dropout)
        self.mlp_rel_dep = MLP(self.hidden_size, rel_hidden, dropout)
        self.mlp_rel_head = MLP(self.hidden_size, rel_hidden, dropout)

        # Biaffine attention scorers.
        self.arc_attn = Biaffine(arc_hidden, 1, bias_x=True, bias_y=False)
        self.rel_attn = Biaffine(rel_hidden, n_rels, bias_x=True, bias_y=True)

    def _get_word_embeddings(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        word_starts: torch.Tensor,
        word_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Get word-level embeddings from subword encoder output.

        Uses first-subword pooling: each word is represented by the
        embedding of its first subword token.

        Args:
            input_ids: (batch, subword_seq_len) - Subword token IDs
            attention_mask: (batch, subword_seq_len) - Attention mask for subwords
            word_starts: (batch, word_seq_len) - Indices of first subword for each word
            word_mask: (batch, word_seq_len) - Mask for actual words.
                Not consumed here (padded word_starts point at index 0 and are
                masked downstream); kept for signature symmetry with forward().
        Returns:
            word_embeddings: (batch, word_seq_len, hidden_size)
        """
        # Encode subwords with PhoBERT.
        encoder_output = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        hidden_states = encoder_output.last_hidden_state  # (batch, subword_seq_len, hidden)

        # Regularize the contextual representations.
        hidden_states = self.dropout(hidden_states)

        # Select the first-subword vector for every word:
        # word_starts (batch, word_seq_len) -> index (batch, word_seq_len, hidden)
        word_embeddings = torch.gather(
            hidden_states,
            dim=1,
            index=word_starts.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1))
        )
        return word_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        word_starts: torch.Tensor,
        word_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass computing arc and relation scores.

        Args:
            input_ids: (batch, subword_seq_len) - Subword token IDs
            attention_mask: (batch, subword_seq_len) - Attention mask for subwords
            word_starts: (batch, word_seq_len) - Indices of first subword for each word
            word_mask: (batch, word_seq_len) - Mask for actual words
        Returns:
            arc_scores: (batch, word_seq_len, word_seq_len) - arc_scores[b, d, h]
                scores head h for dependent d
            rel_scores: (batch, word_seq_len, word_seq_len, n_rels) - Relation scores
        """
        # Word-level representations via first-subword pooling.
        word_embeddings = self._get_word_embeddings(
            input_ids, attention_mask, word_starts, word_mask
        )

        # Separate dependent/head projections for each scorer.
        arc_dep = self.mlp_arc_dep(word_embeddings)
        arc_head = self.mlp_arc_head(word_embeddings)
        rel_dep = self.mlp_rel_dep(word_embeddings)
        rel_head = self.mlp_rel_head(word_embeddings)

        # Biaffine attention.
        arc_scores = self.arc_attn(arc_dep, arc_head)   # (batch, seq, seq)
        rel_scores = self.rel_attn(rel_dep, rel_head)   # (batch, seq, seq, n_rels)

        return arc_scores, rel_scores

    def loss(
        self,
        arc_scores: torch.Tensor,
        rel_scores: torch.Tensor,
        heads: torch.Tensor,
        rels: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute cross-entropy loss for arcs and relations.

        Args:
            arc_scores: (batch, seq_len, seq_len) - Arc scores (dep x head)
            rel_scores: (batch, seq_len, seq_len, n_rels) - Relation scores
            heads: (batch, seq_len) - Gold head indices
            rels: (batch, seq_len) - Gold relation indices
            mask: (batch, seq_len) - Token mask (True for real tokens, False for padding)
        Returns:
            Total loss (arc_loss + rel_loss)
        """
        batch_size, seq_len = mask.shape

        # Exclude padding positions from the set of candidate HEADS so they
        # contribute no probability mass to the softmax. The head axis of
        # arc_scores is dim 2, hence the broadcast over dim 1 (dependents).
        # (Masking dim 1 instead would be a no-op: those rows are dropped by
        # the [mask] selection below anyway.) masked_fill is out-of-place, so
        # the caller's arc_scores are untouched.
        arc_scores_masked = arc_scores.masked_fill(~mask.unsqueeze(1), float('-inf'))

        # Arc loss: cross-entropy over candidate heads, real dependents only.
        arc_loss = F.cross_entropy(
            arc_scores_masked[mask].view(-1, seq_len),
            heads[mask],
            reduction='mean'
        )

        # Relation loss: cross-entropy conditioned on the GOLD heads.
        batch_indices = torch.arange(batch_size, device=rel_scores.device).unsqueeze(1)
        seq_indices = torch.arange(seq_len, device=rel_scores.device)
        rel_scores_gold = rel_scores[batch_indices, seq_indices, heads]
        rel_loss = F.cross_entropy(
            rel_scores_gold[mask],
            rels[mask],
            reduction='mean'
        )

        return arc_loss + rel_loss

    def decode(
        self,
        arc_scores: torch.Tensor,
        rel_scores: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decode predictions using MST or greedy decoding.

        Args:
            arc_scores: (batch, seq_len, seq_len) - Arc scores
            rel_scores: (batch, seq_len, seq_len, n_rels) - Relation scores
            mask: (batch, seq_len) - Token mask
        Returns:
            arc_preds: (batch, seq_len) - Predicted head indices
            rel_preds: (batch, seq_len) - Predicted relation indices
        """
        batch_size, seq_len = mask.shape
        device = arc_scores.device

        if self.use_mst:
            # MST decoding guarantees a valid tree; lengths restrict the
            # search to real tokens, so no extra masking is needed here.
            lengths = mask.sum(dim=1).cpu().numpy()
            arc_scores_np = arc_scores.cpu().numpy()
            arc_preds_np = batch_mst_decode(arc_scores_np, lengths)
            arc_preds = torch.from_numpy(arc_preds_np).to(device)
        else:
            # Greedy decoding: best-scoring head per token, but never a
            # padding position (pad head columns are forced to -inf first).
            arc_preds = arc_scores.masked_fill(
                ~mask.unsqueeze(1), float('-inf')
            ).argmax(dim=-1)

        # Relation predictions conditioned on the PREDICTED heads.
        batch_indices = torch.arange(batch_size, device=device).unsqueeze(1)
        seq_indices = torch.arange(seq_len, device=device)
        rel_scores_pred = rel_scores[batch_indices, seq_indices, arc_preds]
        rel_preds = rel_scores_pred.argmax(dim=-1)

        return arc_preds, rel_preds

    def predict(
        self,
        words: List[str],
        return_probs: bool = False,
    ) -> List[Tuple[str, int, str]]:
        """
        Predict dependencies for a single sentence.

        Args:
            words: List of words (pre-tokenized)
            return_probs: Accepted for API compatibility.
                NOTE(review): not implemented yet — probabilities are never
                returned regardless of this flag.
        Returns:
            List of (word, head, deprel) tuples
        """
        self.eval()
        device = next(self.parameters()).device

        # Tokenize with word boundary tracking (batch of one sentence).
        encoded = self.tokenize_with_alignment([words])

        input_ids = encoded['input_ids'].to(device)
        attention_mask = encoded['attention_mask'].to(device)
        word_starts = encoded['word_starts'].to(device)
        word_mask = encoded['word_mask'].to(device)

        with torch.no_grad():
            arc_scores, rel_scores = self.forward(
                input_ids, attention_mask, word_starts, word_mask
            )
            arc_preds, rel_preds = self.decode(arc_scores, rel_scores, word_mask)

        # Single sentence: take batch element 0.
        arc_preds = arc_preds[0].cpu().tolist()
        rel_preds = rel_preds[0].cpu().tolist()

        results = []
        for i, word in enumerate(words):
            head = arc_preds[i]
            rel_idx = rel_preds[i]
            # Unknown indices (e.g. empty vocab) fall back to generic "dep".
            rel = self.idx2rel.get(rel_idx, "dep")
            results.append((word, head, rel))

        return results

    def tokenize_with_alignment(
        self,
        sentences: List[List[str]],
        max_length: int = 256,
    ) -> Dict[str, torch.Tensor]:
        """
        Tokenize sentences and track word-subword alignment.

        Args:
            sentences: List of sentences, where each sentence is a list of words
            max_length: Maximum subword sequence length
        Returns:
            Dictionary with input_ids, attention_mask, word_starts, word_mask
        """
        batch_input_ids = []
        batch_attention_mask = []
        batch_word_starts = []
        batch_word_mask = []

        for words in sentences:
            # Tokenize each word separately so we know where it starts.
            word_starts = []
            subword_ids = [self.tokenizer.cls_token_id]
            for word in words:
                word_starts.append(len(subword_ids))
                word_tokens = self.tokenizer.encode(word, add_special_tokens=False)
                subword_ids.extend(word_tokens)
            subword_ids.append(self.tokenizer.sep_token_id)

            # Truncate to max_length, keeping a trailing SEP token.
            if len(subword_ids) > max_length:
                subword_ids = subword_ids[:max_length-1] + [self.tokenizer.sep_token_id]
                # Drop word starts that fell beyond the truncation point.
                word_starts = [ws for ws in word_starts if ws < max_length - 1]

            attention_mask = [1] * len(subword_ids)

            batch_input_ids.append(subword_ids)
            batch_attention_mask.append(attention_mask)
            batch_word_starts.append(word_starts)
            batch_word_mask.append([1] * len(word_starts))

        # Pad every sequence to the batch maximum.
        max_subword_len = max(len(ids) for ids in batch_input_ids)
        max_word_len = max(len(ws) for ws in batch_word_starts)

        padded_input_ids = []
        padded_attention_mask = []
        padded_word_starts = []
        padded_word_mask = []

        for i in range(len(sentences)):
            # Pad subwords.
            pad_len = max_subword_len - len(batch_input_ids[i])
            padded_input_ids.append(
                batch_input_ids[i] + [self.tokenizer.pad_token_id] * pad_len
            )
            padded_attention_mask.append(
                batch_attention_mask[i] + [0] * pad_len
            )

            # Pad words. 0 for padding word_starts points at the CLS token,
            # but those positions are zeroed out by word_mask downstream.
            word_pad_len = max_word_len - len(batch_word_starts[i])
            padded_word_starts.append(
                batch_word_starts[i] + [0] * word_pad_len
            )
            padded_word_mask.append(
                batch_word_mask[i] + [0] * word_pad_len
            )

        return {
            'input_ids': torch.tensor(padded_input_ids, dtype=torch.long),
            'attention_mask': torch.tensor(padded_attention_mask, dtype=torch.long),
            'word_starts': torch.tensor(padded_word_starts, dtype=torch.long),
            'word_mask': torch.tensor(padded_word_mask, dtype=torch.bool),
        }

    def save(self, path: str, vocab: Optional[Dict] = None):
        """
        Save model checkpoint (model.pt + tokenizer files) into a directory.

        Args:
            path: Directory path to save the model
            vocab: Vocabulary dict with rel2idx and idx2rel mappings
        """
        import os
        os.makedirs(path, exist_ok=True)

        # Persist weights together with the constructor config so load()
        # can rebuild an identical model.
        checkpoint = {
            'model_state_dict': self.state_dict(),
            'config': {
                'encoder_name': self.encoder_name,
                'n_rels': self.n_rels,
                'arc_hidden': self.mlp_arc_dep.linear.out_features,
                'rel_hidden': self.mlp_rel_dep.linear.out_features,
                'dropout': self.dropout.p,
                'use_mst': self.use_mst,
            },
        }
        if vocab is not None:
            checkpoint['vocab'] = vocab

        torch.save(checkpoint, os.path.join(path, 'model.pt'))
        # Save tokenizer alongside the weights.
        self.tokenizer.save_pretrained(path)

    @classmethod
    def load(cls, path: str, device: str = 'cpu') -> 'PhoBERTDependencyParser':
        """
        Load model from a checkpoint directory written by save().

        Args:
            path: Directory path containing the saved model
            device: Device to load the model to
        Returns:
            Loaded PhoBERTDependencyParser model
        """
        import os
        # weights_only=False: the checkpoint stores a plain dict with config
        # and vocab; only load checkpoints from trusted sources.
        checkpoint = torch.load(
            os.path.join(path, 'model.pt'),
            map_location=device,
            weights_only=False
        )

        config = checkpoint['config']

        # Rebuild the model from the stored config.
        model = cls(
            encoder_name=config['encoder_name'],
            n_rels=config['n_rels'],
            arc_hidden=config['arc_hidden'],
            rel_hidden=config['rel_hidden'],
            dropout=config['dropout'],
            use_mst=config.get('use_mst', True),
        )

        model.load_state_dict(checkpoint['model_state_dict'])

        # Restore the relation vocabulary if present (the constructor already
        # initialized empty mappings as the fallback).
        if 'vocab' in checkpoint:
            model.rel2idx = checkpoint['vocab'].get('rel2idx', {})
            model.idx2rel = checkpoint['vocab'].get('idx2rel', {})

        model.to(device)
        return model