# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "torch>=2.0.0",
#     "transformers>=4.30.0",
#     "datasets>=2.14.0",
#     "click>=8.0.0",
#     "tqdm>=4.60.0",
#     "wandb>=0.15.0",
#     "python-dotenv>=1.0.0",
# ]
# ///
"""Training script for Bamboo-1 Vietnamese Dependency Parser.

Supports multiple methods:
- baseline: BiLSTM + Biaffine (Dozat & Manning, 2017)
- trankit:  XLM-RoBERTa + Biaffine (Nguyen et al., 2021)

Usage:
    uv run scripts/train.py                                     # Default baseline
    uv run scripts/train.py --method trankit                    # Reproduce Trankit
    uv run scripts/train.py --method trankit --dataset ud-vtb   # Trankit on VTB
"""
import sys
from pathlib import Path
from collections import Counter
from dataclasses import dataclass
from typing import List, Tuple, Optional

# Load environment variables
from dotenv import load_dotenv

load_dotenv()

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam, AdamW
from torch.optim.lr_scheduler import ExponentialLR
from tqdm import tqdm
import click

sys.path.insert(0, str(Path(__file__).parent.parent))

from src.corpus import UDD1Corpus
from src.ud_corpus import UDVietnameseVTB
from src.vndt_corpus import VnDTCorpus
from src.cost_estimate import CostTracker, detect_hardware


# ============================================================================
# Data Processing
# ============================================================================


@dataclass
class Sentence:
    """A dependency-parsed sentence."""
    words: List[str]
    heads: List[int]
    rels: List[str]


def read_conllu(path: str) -> List[Sentence]:
    """Read CoNLL-U file and return list of sentences."""
    sentences = []
    words, heads, rels = [], [], []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                if words:
                    sentences.append(Sentence(words, heads, rels))
                    words, heads, rels = [], [], []
            elif line.startswith('#'):
                continue
            else:
                parts = line.split('\t')
                if '-' in parts[0] or '.' in parts[0]:
                    # Skip multi-word tokens
                    continue
                words.append(parts[1])       # FORM
                heads.append(int(parts[6]))  # HEAD
                rels.append(parts[7])        # DEPREL
    if words:
        sentences.append(Sentence(words, heads, rels))
    return sentences
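
# An illustrative CoNLL-U token line (tab-separated, made up rather than taken
# from any particular corpus), showing the three columns read_conllu() consumes:
#
#   ID  FORM  LEMMA  UPOS  XPOS  FEATS  HEAD  DEPREL  DEPS  MISC
#   1   Tôi   tôi    PRON  P     _      2     nsubj   _     _
#
# parts[1] -> FORM ("Tôi"), parts[6] -> HEAD (1-based, 0 = root),
# parts[7] -> DEPREL ("nsubj"). Range IDs like "1-2" (multi-word tokens) and
# decimal IDs like "1.1" (empty nodes) are skipped above.
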
class Vocabulary:
    """Vocabulary for words, characters, and relations."""

    # Special tokens for padding and unknown entries
    PAD = '<pad>'
    UNK = '<unk>'

    def __init__(self, min_freq: int = 2):
        self.min_freq = min_freq
        self.word2idx = {self.PAD: 0, self.UNK: 1}
        self.char2idx = {self.PAD: 0, self.UNK: 1}
        self.rel2idx = {}
        self.idx2rel = {}

    def build(self, sentences: List[Sentence]):
        """Build vocabulary from sentences."""
        word_counts = Counter()
        char_counts = Counter()
        rel_counts = Counter()
        for sent in sentences:
            for word in sent.words:
                word_counts[word.lower()] += 1
                for char in word:
                    char_counts[char] += 1
            for rel in sent.rels:
                rel_counts[rel] += 1

        # Words
        for word, count in word_counts.items():
            if count >= self.min_freq and word not in self.word2idx:
                self.word2idx[word] = len(self.word2idx)

        # Characters
        for char, count in char_counts.items():
            if char not in self.char2idx:
                self.char2idx[char] = len(self.char2idx)

        # Relations
        for rel in rel_counts:
            if rel not in self.rel2idx:
                idx = len(self.rel2idx)
                self.rel2idx[rel] = idx
                self.idx2rel[idx] = rel

    def encode_word(self, word: str) -> int:
        return self.word2idx.get(word.lower(), self.word2idx[self.UNK])

    def encode_char(self, char: str) -> int:
        return self.char2idx.get(char, self.char2idx[self.UNK])

    def encode_rel(self, rel: str) -> int:
        return self.rel2idx.get(rel, 0)

    @property
    def n_words(self) -> int:
        return len(self.word2idx)

    @property
    def n_chars(self) -> int:
        return len(self.char2idx)

    @property
    def n_rels(self) -> int:
        return len(self.rel2idx)


class DependencyDataset(Dataset):
    """Dataset for dependency parsing."""

    def __init__(self, sentences: List[Sentence], vocab: Vocabulary):
        self.sentences = sentences
        self.vocab = vocab

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        sent = self.sentences[idx]
        # Encode words
        word_ids = [self.vocab.encode_word(w) for w in sent.words]
        # Encode characters
        char_ids = [[self.vocab.encode_char(c) for c in w] for w in sent.words]
        # Heads and relations
        heads = sent.heads
        rels = [self.vocab.encode_rel(r) for r in sent.rels]
        return word_ids, char_ids, heads, rels


def collate_fn(batch):
    """Collate function for DataLoader."""
    word_ids, char_ids, heads, rels = zip(*batch)

    # Get lengths
    lengths = [len(w) for w in word_ids]
    max_len = max(lengths)

    # Pad words
    word_ids_padded = torch.zeros(len(batch), max_len, dtype=torch.long)
    for i, wids in enumerate(word_ids):
        word_ids_padded[i, :len(wids)] = torch.tensor(wids)

    # Pad characters
    max_word_len = max(max(len(c) for c in chars) for chars in char_ids)
    char_ids_padded = torch.zeros(len(batch), max_len, max_word_len, dtype=torch.long)
    for i, chars in enumerate(char_ids):
        for j, c in enumerate(chars):
            char_ids_padded[i, j, :len(c)] = torch.tensor(c)

    # Pad heads
    heads_padded = torch.zeros(len(batch), max_len, dtype=torch.long)
    for i, h in enumerate(heads):
        heads_padded[i, :len(h)] = torch.tensor(h)

    # Pad rels
    rels_padded = torch.zeros(len(batch), max_len, dtype=torch.long)
    for i, r in enumerate(rels):
        rels_padded[i, :len(r)] = torch.tensor(r)

    # Mask
    mask = torch.zeros(len(batch), max_len, dtype=torch.bool)
    for i, l in enumerate(lengths):
        mask[i, :l] = True

    lengths = torch.tensor(lengths)

    return word_ids_padded, char_ids_padded, heads_padded, rels_padded, mask, lengths
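
# Shape sketch for collate_fn (illustrative numbers): a batch of two sentences
# with lengths [3, 5] and a longest word of 4 characters collates to
#   word_ids_padded: (2, 5)    char_ids_padded: (2, 5, 4)
#   heads_padded:    (2, 5)    rels_padded:     (2, 5)
#   mask:            (2, 5)    lengths:         (2,)
# with padding positions left at index 0 and their mask entries set to False.
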
# ============================================================================
# Model
# ============================================================================


class CharLSTM(nn.Module):
    """Character-level LSTM embeddings."""

    def __init__(self, n_chars: int, char_dim: int = 50, hidden_dim: int = 100):
        super().__init__()
        self.embed = nn.Embedding(n_chars, char_dim, padding_idx=0)
        self.lstm = nn.LSTM(char_dim, hidden_dim // 2, batch_first=True, bidirectional=True)
        self.hidden_dim = hidden_dim

    def forward(self, chars):
        """
        Args:
            chars: (batch, seq_len, max_word_len)
        Returns:
            (batch, seq_len, hidden_dim)
        """
        batch, seq_len, max_word_len = chars.shape

        # Flatten
        chars_flat = chars.view(-1, max_word_len)  # (batch * seq_len, max_word_len)

        # Get word lengths
        word_lens = (chars_flat != 0).sum(dim=1)
        word_lens = word_lens.clamp(min=1)

        # Embed
        char_embeds = self.embed(chars_flat)  # (batch * seq_len, max_word_len, char_dim)

        # Pack and run LSTM
        packed = pack_padded_sequence(char_embeds, word_lens.cpu(), batch_first=True, enforce_sorted=False)
        _, (hidden, _) = self.lstm(packed)

        # Concatenate forward and backward hidden states
        hidden = torch.cat([hidden[0], hidden[1]], dim=-1)  # (batch * seq_len, hidden_dim)

        return hidden.view(batch, seq_len, self.hidden_dim)


class MLP(nn.Module):
    """Multi-layer perceptron."""

    def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.33):
        super().__init__()
        self.linear = nn.Linear(input_dim, hidden_dim)
        self.activation = nn.LeakyReLU(0.1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.activation(self.linear(x)))


class Biaffine(nn.Module):
    """Biaffine attention layer."""

    def __init__(self, input_dim: int, output_dim: int = 1, bias_x: bool = True, bias_y: bool = True):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.bias_x = bias_x
        self.bias_y = bias_y
        self.weight = nn.Parameter(torch.zeros(output_dim, input_dim + bias_x, input_dim + bias_y))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x, y):
        """
        Args:
            x: (batch, seq_len, input_dim) - dependent
            y: (batch, seq_len, input_dim) - head
        Returns:
            (batch, seq_len, seq_len, output_dim) or (batch, seq_len, seq_len) if output_dim=1
        """
        if self.bias_x:
            x = torch.cat([x, torch.ones_like(x[..., :1])], dim=-1)
        if self.bias_y:
            y = torch.cat([y, torch.ones_like(y[..., :1])], dim=-1)

        # (batch, seq_len, output_dim, input_dim+1)
        x = torch.einsum('bxi,oij->bxoj', x, self.weight)
        # (batch, seq_len, seq_len, output_dim)
        scores = torch.einsum('bxoj,byj->bxyo', x, y)

        if self.output_dim == 1:
            scores = scores.squeeze(-1)
        return scores
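
# The two einsums above implement the biaffine score of Dozat & Manning (2017):
# with bias columns of ones appended where bias_x / bias_y are set,
#   scores[b, i, j, o] = x[b, i]^T  W[o]  y[b, j]
# i.e. for arc scoring (output_dim=1) one score per (dependent i, head j) pair,
# and for relation scoring one score per pair per relation label.
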
class BiaffineDependencyParser(nn.Module):
    """Biaffine Dependency Parser (Dozat & Manning, 2017)."""

    def __init__(
        self,
        n_words: int,
        n_chars: int,
        n_rels: int,
        word_dim: int = 100,
        char_dim: int = 50,
        char_hidden: int = 100,
        lstm_hidden: int = 400,
        lstm_layers: int = 3,
        arc_hidden: int = 500,
        rel_hidden: int = 100,
        dropout: float = 0.33,
    ):
        super().__init__()
        self.word_embed = nn.Embedding(n_words, word_dim, padding_idx=0)
        self.char_lstm = CharLSTM(n_chars, char_dim, char_hidden)

        input_dim = word_dim + char_hidden
        self.lstm = nn.LSTM(
            input_dim, lstm_hidden // 2,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if lstm_layers > 1 else 0
        )

        self.mlp_arc_dep = MLP(lstm_hidden, arc_hidden, dropout)
        self.mlp_arc_head = MLP(lstm_hidden, arc_hidden, dropout)
        self.mlp_rel_dep = MLP(lstm_hidden, rel_hidden, dropout)
        self.mlp_rel_head = MLP(lstm_hidden, rel_hidden, dropout)

        self.arc_attn = Biaffine(arc_hidden, 1, bias_x=True, bias_y=False)
        self.rel_attn = Biaffine(rel_hidden, n_rels, bias_x=True, bias_y=True)

        self.dropout = nn.Dropout(dropout)
        self.n_rels = n_rels

    def forward(self, words, chars, mask):
        """
        Args:
            words: (batch, seq_len)
            chars: (batch, seq_len, max_word_len)
            mask: (batch, seq_len)
        Returns:
            arc_scores: (batch, seq_len, seq_len)
            rel_scores: (batch, seq_len, seq_len, n_rels)
        """
        # Embeddings
        word_embeds = self.word_embed(words)
        char_embeds = self.char_lstm(chars)
        embeds = torch.cat([word_embeds, char_embeds], dim=-1)
        embeds = self.dropout(embeds)

        # BiLSTM
        lengths = mask.sum(dim=1).cpu()
        packed = pack_padded_sequence(embeds, lengths, batch_first=True, enforce_sorted=False)
        lstm_out, _ = self.lstm(packed)
        lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True, total_length=mask.size(1))
        lstm_out = self.dropout(lstm_out)

        # MLP
        arc_dep = self.mlp_arc_dep(lstm_out)
        arc_head = self.mlp_arc_head(lstm_out)
        rel_dep = self.mlp_rel_dep(lstm_out)
        rel_head = self.mlp_rel_head(lstm_out)

        # Biaffine
        arc_scores = self.arc_attn(arc_dep, arc_head)  # (batch, seq_len, seq_len)
        rel_scores = self.rel_attn(rel_dep, rel_head)  # (batch, seq_len, seq_len, n_rels)

        return arc_scores, rel_scores

    def loss(self, arc_scores, rel_scores, heads, rels, mask):
        """Compute loss."""
        batch_size, seq_len = mask.shape

        # Arc loss
        arc_scores = arc_scores.masked_fill(~mask.unsqueeze(2), float('-inf'))
        arc_loss = F.cross_entropy(
            arc_scores[mask].view(-1, seq_len),
            heads[mask],
            reduction='mean'
        )

        # Rel loss - select scores for gold heads
        rel_scores_gold = rel_scores[torch.arange(batch_size).unsqueeze(1),
                                     torch.arange(seq_len),
                                     heads]
        rel_loss = F.cross_entropy(
            rel_scores_gold[mask],
            rels[mask],
            reduction='mean'
        )

        return arc_loss + rel_loss

    def decode(self, arc_scores, rel_scores, mask):
        """Decode predictions."""
        # Greedy decoding
        arc_preds = arc_scores.argmax(dim=-1)

        batch_size, seq_len = mask.shape
        rel_scores_pred = rel_scores[torch.arange(batch_size).unsqueeze(1),
                                     torch.arange(seq_len),
                                     arc_preds]
        rel_preds = rel_scores_pred.argmax(dim=-1)

        return arc_preds, rel_preds
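
# Dimensionality sketch for the baseline defaults above: each token is the
# concatenation of a 100-d word embedding and a 100-d CharLSTM vector (200-d
# input); the 3-layer BiLSTM with lstm_hidden=400 yields 400-d states (200 per
# direction); the MLPs project those to 500-d arc and 100-d relation
# representations before the biaffine layers score all head/dependent pairs.
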
# ============================================================================
# Trankit-style Transformer Parser (XLM-RoBERTa + Biaffine)
# ============================================================================


class TransformerDependencyParser(nn.Module):
    """
    Trankit-style dependency parser using XLM-RoBERTa.

    Architecture follows Nguyen et al. 2021 EACL:
    - XLM-RoBERTa encoder
    - Word-level pooling (first subword)
    - Biaffine attention for arc/rel prediction
    """

    def __init__(
        self,
        n_rels: int,
        encoder: str = "xlm-roberta-base",
        arc_hidden: int = 500,
        rel_hidden: int = 100,
        dropout: float = 0.33,
    ):
        super().__init__()
        from transformers import AutoModel, AutoTokenizer

        self.encoder_name = encoder
        self.tokenizer = AutoTokenizer.from_pretrained(encoder)
        self.encoder = AutoModel.from_pretrained(encoder)
        self.hidden_size = self.encoder.config.hidden_size

        # Biaffine layers
        self.mlp_arc_dep = MLP(self.hidden_size, arc_hidden, dropout)
        self.mlp_arc_head = MLP(self.hidden_size, arc_hidden, dropout)
        self.mlp_rel_dep = MLP(self.hidden_size, rel_hidden, dropout)
        self.mlp_rel_head = MLP(self.hidden_size, rel_hidden, dropout)

        self.arc_attn = Biaffine(arc_hidden, 1, bias_x=True, bias_y=False)
        self.rel_attn = Biaffine(rel_hidden, n_rels, bias_x=True, bias_y=True)

        self.dropout = nn.Dropout(dropout)
        self.n_rels = n_rels

    def encode_pretokenized(self, input_ids, attention_mask, word_starts, word_mask):
        """Encode pre-tokenized batch (fast path - no tokenization overhead)."""
        outputs = self.encoder(input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state

        # Gather word-level representations using pre-computed positions
        word_starts_exp = word_starts.unsqueeze(-1).expand(-1, -1, self.hidden_size)
        word_starts_exp = word_starts_exp.clamp(0, hidden.size(1) - 1)
        word_hidden = torch.gather(hidden, 1, word_starts_exp)

        return word_hidden, word_mask

    def forward(self, word_hidden, word_mask):
        """Compute arc and relation scores from word representations."""
        word_hidden = self.dropout(word_hidden)

        # Biaffine scoring
        arc_dep = self.mlp_arc_dep(word_hidden)
        arc_head = self.mlp_arc_head(word_hidden)
        rel_dep = self.mlp_rel_dep(word_hidden)
        rel_head = self.mlp_rel_head(word_hidden)

        arc_scores = self.arc_attn(arc_dep, arc_head)
        rel_scores = self.rel_attn(rel_dep, rel_head)

        return arc_scores, rel_scores

    def loss(self, arc_scores, rel_scores, heads, rels, mask):
        """Compute cross-entropy loss."""
        batch_size, seq_len = mask.shape

        # Arc loss
        arc_scores = arc_scores.masked_fill(~mask.unsqueeze(2), float('-inf'))
        arc_loss = F.cross_entropy(
            arc_scores[mask].view(-1, seq_len),
            heads[mask],
            reduction='mean'
        )

        # Rel loss
        rel_scores_gold = rel_scores[torch.arange(batch_size, device=mask.device).unsqueeze(1),
                                     torch.arange(seq_len, device=mask.device),
                                     heads]
        rel_loss = F.cross_entropy(
            rel_scores_gold[mask],
            rels[mask],
            reduction='mean'
        )

        return arc_loss + rel_loss

    def decode(self, arc_scores, rel_scores, mask):
        """Greedy decoding."""
        arc_preds = arc_scores.argmax(dim=-1)

        batch_size, seq_len = mask.shape
        rel_scores_pred = rel_scores[torch.arange(batch_size, device=mask.device).unsqueeze(1),
                                     torch.arange(seq_len, device=mask.device),
                                     arc_preds]
        rel_preds = rel_scores_pred.argmax(dim=-1)

        return arc_preds, rel_preds
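
# First-subword pooling, illustrated with a made-up segmentation:
#   words       = ["Hà_Nội", "đẹp"]
#   subwords    = [<s>] + ["Hà", "_Nội"] + ["đẹp"] + [</s>]
#   word_starts = [1, 3]
# encode_pretokenized() gathers the encoder states at positions 1 and 3, so each
# word is represented by its first subword, following the Trankit recipe.
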
class TransformerDataset(Dataset):
    """Pre-tokenized dataset - tokenizes once at creation for fast training."""

    def __init__(self, sentences: List[Sentence], vocab, tokenizer):
        self.data = []
        for sent in tqdm(sentences, desc="Pre-tokenizing", leave=False):
            input_ids = [tokenizer.cls_token_id]
            word_starts = []
            for word in sent.words:
                word_starts.append(len(input_ids))
                tokens = tokenizer.encode(word, add_special_tokens=False)
                input_ids.extend(tokens if tokens else [tokenizer.unk_token_id])
            input_ids.append(tokenizer.sep_token_id)

            self.data.append((
                torch.tensor(input_ids, dtype=torch.long),
                torch.tensor(word_starts, dtype=torch.long),
                torch.tensor(sent.heads, dtype=torch.long),
                torch.tensor([vocab.encode_rel(r) for r in sent.rels], dtype=torch.long),
                len(sent.words),
            ))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def transformer_collate_fn(batch):
    """Collate for pre-tokenized transformer data."""
    input_ids, word_starts, heads, rels, n_words = zip(*batch)

    batch_size = len(batch)
    max_subwords = max(ids.size(0) for ids in input_ids)
    max_words = max(n_words)

    padded_ids = torch.zeros(batch_size, max_subwords, dtype=torch.long)
    attention_mask = torch.zeros(batch_size, max_subwords, dtype=torch.long)
    padded_starts = torch.zeros(batch_size, max_words, dtype=torch.long)
    padded_heads = torch.zeros(batch_size, max_words, dtype=torch.long)
    padded_rels = torch.zeros(batch_size, max_words, dtype=torch.long)
    word_mask = torch.zeros(batch_size, max_words, dtype=torch.bool)

    for i, (ids, starts, h, r, nw) in enumerate(zip(input_ids, word_starts, heads, rels, n_words)):
        padded_ids[i, :ids.size(0)] = ids
        attention_mask[i, :ids.size(0)] = 1
        padded_starts[i, :starts.size(0)] = starts
        padded_heads[i, :nw] = h
        padded_rels[i, :nw] = r
        word_mask[i, :nw] = True

    return padded_ids, attention_mask, padded_starts, padded_heads, padded_rels, word_mask


def evaluate_transformer(model, dataloader, device):
    """Evaluate transformer-based model."""
    model.eval()
    total_arcs = 0
    correct_arcs = 0
    correct_rels = 0

    with torch.no_grad():
        for batch in dataloader:
            input_ids, attn_mask, word_starts, heads, rels, mask = [x.to(device) for x in batch]
            word_hidden, word_mask = model.encode_pretokenized(input_ids, attn_mask, word_starts, mask)
            arc_scores, rel_scores = model(word_hidden, word_mask)
            arc_preds, rel_preds = model.decode(arc_scores, rel_scores, word_mask)

            arc_correct = (arc_preds == heads) & mask
            rel_correct = (rel_preds == rels) & mask & arc_correct

            total_arcs += mask.sum().item()
            correct_arcs += arc_correct.sum().item()
            correct_rels += rel_correct.sum().item()

    uas = correct_arcs / total_arcs * 100
    las = correct_rels / total_arcs * 100
    return uas, las


# ============================================================================
# Training
# ============================================================================


def evaluate(model, dataloader, device):
    """Evaluate model and return UAS/LAS."""
    model.eval()
    total_arcs = 0
    correct_arcs = 0
    correct_rels = 0

    with torch.no_grad():
        for batch in dataloader:
            words, chars, heads, rels, mask, lengths = [x.to(device) for x in batch]
            arc_scores, rel_scores = model(words, chars, mask)
            arc_preds, rel_preds = model.decode(arc_scores, rel_scores, mask)

            # Count correct
            arc_correct = (arc_preds == heads) & mask
            rel_correct = (rel_preds == rels) & mask & arc_correct

            total_arcs += mask.sum().item()
            correct_arcs += arc_correct.sum().item()
            correct_rels += rel_correct.sum().item()

    uas = correct_arcs / total_arcs * 100
    las = correct_rels / total_arcs * 100
    return uas, las
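
# Both evaluate() and evaluate_transformer() report the standard parsing metrics:
#   UAS = % of non-padding tokens whose predicted head matches the gold head
#   LAS = % of tokens whose predicted head AND relation label are both correct
# computed over all tokens in the dataloader and expressed on a 0-100 scale.
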
@click.command()
@click.option('--method', type=click.Choice(['baseline', 'trankit']), default='baseline',
              help='Parser method: baseline (BiLSTM) or trankit (XLM-RoBERTa)')
@click.option('--dataset', type=click.Choice(['udd1', 'ud-vtb', 'vndt']), default='udd1',
              help='Dataset: udd1 (UDD-1), ud-vtb (UD Vietnamese VTB), or vndt (VnDT v1.1)')
@click.option('--encoder', default='xlm-roberta-base', help='Transformer encoder for trankit method')
@click.option('--output', '-o', default='models/bamboo-1', help='Output directory')
@click.option('--epochs', default=100, type=int, help='Number of epochs')
@click.option('--batch-size', default=64, type=int, help='Batch size')
@click.option('--lr', default=2e-3, type=float, help='Learning rate for baseline')
@click.option('--bert-lr', default=2e-5, type=float, help='Encoder learning rate for trankit')
@click.option('--head-lr', default=2e-4, type=float, help='Head learning rate for trankit')
@click.option('--warmup-steps', default=200, type=int, help='Warmup steps for trankit')
@click.option('--lstm-hidden', default=400, type=int, help='LSTM hidden size (baseline)')
@click.option('--lstm-layers', default=3, type=int, help='LSTM layers (baseline)')
@click.option('--patience', default=5, type=int, help='Early stopping patience')
@click.option('--force-download', is_flag=True, help='Force re-download dataset')
@click.option('--data-dir', default=None, help='Custom data directory')
@click.option('--gpu-type', default='RTX_A4000', help='GPU type for cost estimation')
@click.option('--cost-interval', default=300, type=int, help='Cost report interval in seconds')
@click.option('--wandb', 'use_wandb', is_flag=True, help='Enable W&B logging')
@click.option('--wandb-project', default='bamboo-1', help='W&B project name')
@click.option('--max-time', default=0, type=int, help='Max training time in minutes (0=unlimited)')
@click.option('--sample', default=0, type=int, help='Sample N sentences from each split (0=all)')
@click.option('--eval-every', default=2, type=int, help='Evaluate every N epochs')
@click.option('--fp16', is_flag=True, default=True, help='Use mixed precision training')
def train(method, dataset, encoder, output, epochs, batch_size, lr, bert_lr, head_lr,
          warmup_steps, lstm_hidden, lstm_layers, patience, force_download, data_dir,
          gpu_type, cost_interval, use_wandb, wandb_project, max_time, sample,
          eval_every, fp16):
    """Train Bamboo-1 Vietnamese Dependency Parser."""
    # Detect hardware
    hardware = detect_hardware()
    detected_gpu_type = hardware.get_gpu_type()
    if gpu_type == "RTX_A4000":
        gpu_type = detected_gpu_type

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    click.echo(f"Using device: {device}")
    click.echo(f"Hardware: {hardware}")

    # CUDA optimizations
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # Mixed precision
    use_amp = fp16 and torch.cuda.is_available()
    scaler = torch.amp.GradScaler('cuda') if use_amp else None
    if use_amp:
        click.echo("Mixed precision (FP16): enabled")

    # Initialize wandb
    if use_wandb:
        import wandb
        wandb.init(
            project=wandb_project,
            config={
                "method": method,
                "dataset": dataset,
                "encoder": encoder if method == "trankit" else "bilstm",
                "epochs": epochs,
                "batch_size": batch_size,
                "lr": lr if method == "baseline" else bert_lr,
                "head_lr": head_lr if method == "trankit" else None,
                "lstm_hidden": lstm_hidden if method == "baseline" else None,
                "lstm_layers": lstm_layers if method == "baseline" else None,
                "patience": patience,
                "gpu_type": gpu_type,
                "hardware": hardware.to_dict(),
            }
        )
        click.echo(f"W&B logging enabled: {wandb.run.url}")

    click.echo("=" * 60)
    click.echo(f"Bamboo-1: Vietnamese Dependency Parser ({method.upper()})")
    click.echo("=" * 60)
    # Load corpus
    click.echo(f"\nLoading {dataset.upper()} corpus...")
    if dataset == 'udd1':
        corpus = UDD1Corpus(data_dir=data_dir, force_download=force_download)
    elif dataset == 'ud-vtb':
        corpus = UDVietnameseVTB(data_dir=data_dir, force_download=force_download)
    else:  # vndt
        corpus = VnDTCorpus(data_dir=data_dir, force_download=force_download)

    train_sents = read_conllu(corpus.train)
    dev_sents = read_conllu(corpus.dev)
    test_sents = read_conllu(corpus.test)

    # Sample subset if requested
    if sample > 0:
        train_sents = train_sents[:sample]
        dev_sents = dev_sents[:min(sample // 2, len(dev_sents))]
        test_sents = test_sents[:min(sample // 2, len(test_sents))]
        click.echo(f"  Sampling {sample} sentences...")

    click.echo(f"  Train: {len(train_sents)} sentences")
    click.echo(f"  Dev:   {len(dev_sents)} sentences")
    click.echo(f"  Test:  {len(test_sents)} sentences")

    # Build vocabulary
    click.echo("\nBuilding vocabulary...")
    vocab = Vocabulary(min_freq=2)
    vocab.build(train_sents)
    if method == "baseline":
        click.echo(f"  Words: {vocab.n_words}")
        click.echo(f"  Chars: {vocab.n_chars}")
        click.echo(f"  Relations: {vocab.n_rels}")

    # Create datasets and model based on method
    if method == "trankit":
        # Trankit method: XLM-RoBERTa + Biaffine
        click.echo(f"\nInitializing model with {encoder}...")
        model = TransformerDependencyParser(
            n_rels=vocab.n_rels,
            encoder=encoder,
        ).to(device)

        n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        encoder_params = sum(p.numel() for p in model.encoder.parameters())
        head_params = n_params - encoder_params
        click.echo(f"  Total parameters:   {n_params:,}")
        click.echo(f"  Encoder parameters: {encoder_params:,}")
        click.echo(f"  Head parameters:    {head_params:,}")

        # Pre-tokenize datasets (tokenize once, not every epoch)
        click.echo("\nPre-tokenizing datasets...")
        use_pin = torch.cuda.is_available()
        train_dataset = TransformerDataset(train_sents, vocab, model.tokenizer)
        dev_dataset = TransformerDataset(dev_sents, vocab, model.tokenizer)
        test_dataset = TransformerDataset(test_sents, vocab, model.tokenizer)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                  collate_fn=transformer_collate_fn, pin_memory=use_pin)
        dev_loader = DataLoader(dev_dataset, batch_size=batch_size,
                                collate_fn=transformer_collate_fn, pin_memory=use_pin)
        test_loader = DataLoader(test_dataset, batch_size=batch_size,
                                 collate_fn=transformer_collate_fn, pin_memory=use_pin)

        # Differential learning rates
        encoder_params_list = list(model.encoder.parameters())
        head_params_list = [p for n, p in model.named_parameters() if 'encoder' not in n]
        optimizer = AdamW([
            {'params': encoder_params_list, 'lr': bert_lr},
            {'params': head_params_list, 'lr': head_lr},
        ], weight_decay=0.01)

        # Learning rate scheduler with warmup
        total_steps = len(train_loader) * epochs

        def lr_lambda(step):
            if step < warmup_steps:
                return step / warmup_steps
            return max(0.0, (total_steps - step) / (total_steps - warmup_steps))

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        eval_fn = evaluate_transformer
    else:
        # Baseline method: BiLSTM + Biaffine
        train_dataset = DependencyDataset(train_sents, vocab)
        dev_dataset = DependencyDataset(dev_sents, vocab)
        test_dataset = DependencyDataset(test_sents, vocab)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
        dev_loader = DataLoader(dev_dataset, batch_size=batch_size, collate_fn=collate_fn)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_fn)

        click.echo("\nInitializing BiLSTM model...")
        model = BiaffineDependencyParser(
            n_words=vocab.n_words,
            n_chars=vocab.n_chars,
            n_rels=vocab.n_rels,
            lstm_hidden=lstm_hidden,
            lstm_layers=lstm_layers,
        ).to(device)

        n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        click.echo(f"  Parameters: {n_params:,}")

        optimizer = Adam(model.parameters(), lr=lr, betas=(0.9, 0.9))
        scheduler = ExponentialLR(optimizer, gamma=0.75 ** (1 / 5000))
        eval_fn = evaluate
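
    # Scheduler behaviour, for orientation (the numbers follow from the code above):
    # the trankit schedule scales the LR linearly from 0 to its peak over
    # `warmup_steps` batches (e.g. factor 0.5 at step 100 with the default 200),
    # then decays linearly to 0 at `total_steps`; the baseline schedule multiplies
    # the LR by 0.75 every 5000 batches (gamma = 0.75 ** (1/5000) per step).
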
    # Training
    click.echo(f"\nTraining for {epochs} epochs...")
    if max_time > 0:
        click.echo(f"Time limit: {max_time} minutes")

    output_path = Path(output)
    output_path.mkdir(parents=True, exist_ok=True)

    # Cost tracking
    cost_tracker = CostTracker(gpu_type=gpu_type)
    cost_tracker.report_interval = cost_interval
    cost_tracker.start()
    click.echo(f"Cost tracking: {gpu_type} @ ${cost_tracker.hourly_rate}/hr")

    best_las = -1
    no_improve = 0
    time_limit_seconds = max_time * 60 if max_time > 0 else float('inf')

    for epoch in range(1, epochs + 1):
        # Check time limit
        if cost_tracker.elapsed_seconds() >= time_limit_seconds:
            click.echo(f"\nTime limit reached ({max_time} minutes)")
            break

        model.train()
        total_loss = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch:3d}", leave=False)
        for batch in pbar:
            optimizer.zero_grad()

            if method == "trankit":
                input_ids, attn_mask, word_starts, heads, rels, mask = [x.to(device) for x in batch]
                with torch.amp.autocast('cuda', enabled=use_amp):
                    word_hidden, word_mask = model.encode_pretokenized(input_ids, attn_mask, word_starts, mask)
                    arc_scores, rel_scores = model(word_hidden, word_mask)
                    loss = model.loss(arc_scores, rel_scores, heads, rels, mask)
            else:
                words, chars, heads, rels, mask, lengths = [x.to(device) for x in batch]
                arc_scores, rel_scores = model(words, chars, mask)
                loss = model.loss(arc_scores, rel_scores, heads, rels, mask)

            if use_amp and scaler:
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), 5.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 5.0)
                optimizer.step()

            scheduler.step()
            total_loss += loss.item()
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

        # Evaluate (skip if not eval epoch, unless last epoch)
        if epoch % eval_every != 0 and epoch != epochs:
            avg_loss = total_loss / len(train_loader)
            current_lr = optimizer.param_groups[0]['lr']
            click.echo(f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | LR: {current_lr:.2e}")
            continue

        dev_uas, dev_las = eval_fn(model, dev_loader, device)

        # Cost update
        progress = epoch / epochs
        current_cost = cost_tracker.current_cost()
        estimated_total_cost = cost_tracker.estimate_total_cost(progress)
        elapsed_minutes = cost_tracker.elapsed_seconds() / 60
        cost_status = cost_tracker.update(epoch, epochs)
        if cost_status:
            click.echo(f"  [{cost_status}]")

        avg_loss = total_loss / len(train_loader)
        click.echo(f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | "
                   f"Dev UAS: {dev_uas:.2f}% | Dev LAS: {dev_las:.2f}%")

        # Log to wandb
        if use_wandb:
            wandb.log({
                "epoch": epoch,
                "train/loss": avg_loss,
                "dev/uas": dev_uas,
                "dev/las": dev_las,
                "cost/current_usd": current_cost,
                "cost/estimated_total_usd": estimated_total_cost,
                "cost/elapsed_minutes": elapsed_minutes,
            })

        # Save best model
        if dev_las >= best_las:
            best_las = dev_las
            no_improve = 0

            if method == "trankit":
                config = {
                    'method': 'trankit',
                    'encoder': encoder,
                    'n_rels': vocab.n_rels,
                }
            else:
                config = {
                    'method': 'baseline',
                    'n_words': vocab.n_words,
                    'n_chars': vocab.n_chars,
                    'n_rels': vocab.n_rels,
                    'lstm_hidden': lstm_hidden,
                    'lstm_layers': lstm_layers,
                }

            # Save to local tmp first to avoid network filesystem issues
            import tempfile, shutil
            with tempfile.NamedTemporaryFile(suffix='.pt', delete=False) as tmp:
                torch.save({
                    'model': model.state_dict(),
                    'vocab': vocab,
                    'config': config,
                }, tmp.name)
            shutil.move(tmp.name, output_path / 'model.pt')
            click.echo(f"  -> Saved best model (LAS: {best_las:.2f}%)")
        else:
            no_improve += 1
            if no_improve >= patience:
                click.echo(f"\nEarly stopping after {patience} epochs without improvement")
                break
    # Final evaluation
    click.echo("\nLoading best model for final evaluation...")
    checkpoint = torch.load(output_path / 'model.pt', weights_only=False)
    model.load_state_dict(checkpoint['model'])

    test_uas, test_las = eval_fn(model, test_loader, device)
    click.echo(f"\nTest Results:")
    click.echo(f"  UAS: {test_uas:.2f}%")
    click.echo(f"  LAS: {test_las:.2f}%")
    click.echo(f"\nModel saved to: {output_path}")

    # Final cost summary
    final_cost = cost_tracker.current_cost()
    click.echo(f"\n{cost_tracker.summary(epoch, epochs)}")

    # Log final metrics to wandb
    if use_wandb:
        wandb.log({
            "test/uas": test_uas,
            "test/las": test_las,
            "cost/final_usd": final_cost,
        })
        wandb.finish()


if __name__ == '__main__':
    train()