# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "torch>=2.0.0",
#     "datasets>=2.14.0",
#     "click>=8.0.0",
#     "tqdm>=4.60.0",
#     "wandb>=0.15.0",
# ]
# ///
"""
Training script for Bamboo-1 Vietnamese Dependency Parser.

Biaffine parser implementation from scratch (Dozat & Manning, 2017).

Usage:
    uv run scripts/train.py
    uv run scripts/train.py --output models/bamboo-1 --epochs 100
"""

import sys
from pathlib import Path
from collections import Counter
from dataclasses import dataclass
from typing import List, Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR
from tqdm import tqdm
import click

sys.path.insert(0, str(Path(__file__).parent.parent))

from bamboo1.corpus import UDD1Corpus
from scripts.cost_estimate import CostTracker, detect_hardware


# ============================================================================
# Data Processing
# ============================================================================

@dataclass
class Sentence:
    """A dependency-parsed sentence."""
    words: List[str]
    heads: List[int]
    rels: List[str]


def read_conllu(path: str) -> List[Sentence]:
    """Read a CoNLL-U file and return a list of sentences."""
    sentences = []
    words, heads, rels = [], [], []

    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                if words:
                    sentences.append(Sentence(words, heads, rels))
                    words, heads, rels = [], [], []
            elif line.startswith('#'):
                continue
            else:
                parts = line.split('\t')
                if '-' in parts[0] or '.' in parts[0]:
                    # Skip multi-word token ranges and empty nodes
                    continue
                words.append(parts[1])       # FORM
                heads.append(int(parts[6]))  # HEAD
                rels.append(parts[7])        # DEPREL

    if words:
        sentences.append(Sentence(words, heads, rels))

    return sentences
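
# Illustrative CoNLL-U fragment (hypothetical example sentence, not taken from
# UDD-1) and the fields read_conllu() keeps from it -- column 2 (FORM),
# column 7 (HEAD), column 8 (DEPREL); "..." stands for the remaining columns:
#
#   1   Tôi    ...   2   nsubj     ...
#   2   học    ...   0   root      ...
#   3   tiếng  ...   2   obj       ...
#   4   Việt   ...   3   compound  ...
#
# read_conllu() would yield:
#
#   Sentence(words=['Tôi', 'học', 'tiếng', 'Việt'],
#            heads=[2, 0, 2, 3],
#            rels=['nsubj', 'root', 'obj', 'compound'])
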

class Vocabulary:
    """Vocabulary for words, characters, and relations."""

    PAD = '<pad>'
    UNK = '<unk>'
    ROOT = '<root>'

    def __init__(self, min_freq: int = 2):
        self.min_freq = min_freq
        self.word2idx = {self.PAD: 0, self.UNK: 1, self.ROOT: 2}
        self.char2idx = {self.PAD: 0, self.UNK: 1, self.ROOT: 2}
        self.rel2idx = {}
        self.idx2rel = {}

    def build(self, sentences: List[Sentence]):
        """Build vocabulary from sentences."""
        word_counts = Counter()
        char_counts = Counter()
        rel_counts = Counter()

        for sent in sentences:
            for word in sent.words:
                word_counts[word.lower()] += 1
                for char in word:
                    char_counts[char] += 1
            for rel in sent.rels:
                rel_counts[rel] += 1

        # Words
        for word, count in word_counts.items():
            if count >= self.min_freq and word not in self.word2idx:
                self.word2idx[word] = len(self.word2idx)

        # Characters
        for char, count in char_counts.items():
            if char not in self.char2idx:
                self.char2idx[char] = len(self.char2idx)

        # Relations
        for rel in rel_counts:
            if rel not in self.rel2idx:
                idx = len(self.rel2idx)
                self.rel2idx[rel] = idx
                self.idx2rel[idx] = rel

    def encode_word(self, word: str) -> int:
        return self.word2idx.get(word.lower(), self.word2idx[self.UNK])

    def encode_char(self, char: str) -> int:
        return self.char2idx.get(char, self.char2idx[self.UNK])

    def encode_rel(self, rel: str) -> int:
        return self.rel2idx.get(rel, 0)

    @property
    def n_words(self) -> int:
        return len(self.word2idx)

    @property
    def n_chars(self) -> int:
        return len(self.char2idx)

    @property
    def n_rels(self) -> int:
        return len(self.rel2idx)


class DependencyDataset(Dataset):
    """Dataset for dependency parsing."""

    def __init__(self, sentences: List[Sentence], vocab: Vocabulary):
        self.sentences = sentences
        self.vocab = vocab

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        sent = self.sentences[idx]

        # Prepend a synthetic <root> token at position 0 so that CoNLL-U head
        # indices (1-based, with 0 meaning the root) line up with sequence positions.
        root_word = self.vocab.word2idx[self.vocab.ROOT]
        root_char = self.vocab.char2idx[self.vocab.ROOT]

        # Encode words
        word_ids = [root_word] + [self.vocab.encode_word(w) for w in sent.words]

        # Encode characters
        char_ids = [[root_char]] + [[self.vocab.encode_char(c) for c in w] for w in sent.words]

        # Heads and relations
        heads = [0] + sent.heads
        rels = [0] + [self.vocab.encode_rel(r) for r in sent.rels]

        return word_ids, char_ids, heads, rels


def collate_fn(batch):
    """Collate function for DataLoader."""
    word_ids, char_ids, heads, rels = zip(*batch)

    # Get lengths
    lengths = [len(w) for w in word_ids]
    max_len = max(lengths)

    # Pad words
    word_ids_padded = torch.zeros(len(batch), max_len, dtype=torch.long)
    for i, wids in enumerate(word_ids):
        word_ids_padded[i, :len(wids)] = torch.tensor(wids)

    # Pad characters
    max_word_len = max(max(len(c) for c in chars) for chars in char_ids)
    char_ids_padded = torch.zeros(len(batch), max_len, max_word_len, dtype=torch.long)
    for i, chars in enumerate(char_ids):
        for j, c in enumerate(chars):
            char_ids_padded[i, j, :len(c)] = torch.tensor(c)

    # Pad heads
    heads_padded = torch.zeros(len(batch), max_len, dtype=torch.long)
    for i, h in enumerate(heads):
        heads_padded[i, :len(h)] = torch.tensor(h)

    # Pad rels
    rels_padded = torch.zeros(len(batch), max_len, dtype=torch.long)
    for i, r in enumerate(rels):
        rels_padded[i, :len(r)] = torch.tensor(r)

    # Mask
    mask = torch.zeros(len(batch), max_len, dtype=torch.bool)
    for i, l in enumerate(lengths):
        mask[i, :l] = True

    lengths = torch.tensor(lengths)

    return word_ids_padded, char_ids_padded, heads_padded, rels_padded, mask, lengths
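
# Shape sketch for a hypothetical batch of two sentences with 4 and 6 tokens
# (token counts include the prepended <root>). collate_fn pads to max_len=6
# and returns:
#
#   word_ids_padded : (2, 6)                long
#   char_ids_padded : (2, 6, max_word_len)  long
#   heads_padded    : (2, 6)                long
#   rels_padded     : (2, 6)                long
#   mask            : (2, 6)                bool, True on real tokens
#   lengths         : (2,)                  = tensor([4, 6])
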

# ============================================================================
# Model
# ============================================================================

class CharLSTM(nn.Module):
    """Character-level LSTM embeddings."""

    def __init__(self, n_chars: int, char_dim: int = 50, hidden_dim: int = 100):
        super().__init__()
        self.embed = nn.Embedding(n_chars, char_dim, padding_idx=0)
        self.lstm = nn.LSTM(char_dim, hidden_dim // 2, batch_first=True, bidirectional=True)
        self.hidden_dim = hidden_dim

    def forward(self, chars):
        """
        Args:
            chars: (batch, seq_len, max_word_len)
        Returns:
            (batch, seq_len, hidden_dim)
        """
        batch, seq_len, max_word_len = chars.shape

        # Flatten
        chars_flat = chars.view(-1, max_word_len)  # (batch * seq_len, max_word_len)

        # Get word lengths
        word_lens = (chars_flat != 0).sum(dim=1)
        word_lens = word_lens.clamp(min=1)

        # Embed
        char_embeds = self.embed(chars_flat)  # (batch * seq_len, max_word_len, char_dim)

        # Pack and run LSTM
        packed = pack_padded_sequence(char_embeds, word_lens.cpu(), batch_first=True,
                                      enforce_sorted=False)
        _, (hidden, _) = self.lstm(packed)

        # Concatenate forward and backward hidden states
        hidden = torch.cat([hidden[0], hidden[1]], dim=-1)  # (batch * seq_len, hidden_dim)

        return hidden.view(batch, seq_len, self.hidden_dim)


class MLP(nn.Module):
    """Multi-layer perceptron."""

    def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.33):
        super().__init__()
        self.linear = nn.Linear(input_dim, hidden_dim)
        self.activation = nn.LeakyReLU(0.1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.activation(self.linear(x)))


class Biaffine(nn.Module):
    """Biaffine attention layer."""

    def __init__(self, input_dim: int, output_dim: int = 1,
                 bias_x: bool = True, bias_y: bool = True):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.bias_x = bias_x
        self.bias_y = bias_y
        self.weight = nn.Parameter(torch.zeros(output_dim, input_dim + bias_x,
                                               input_dim + bias_y))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x, y):
        """
        Args:
            x: (batch, seq_len, input_dim) - dependent
            y: (batch, seq_len, input_dim) - head
        Returns:
            (batch, seq_len, seq_len, output_dim) or
            (batch, seq_len, seq_len) if output_dim=1
        """
        if self.bias_x:
            x = torch.cat([x, torch.ones_like(x[..., :1])], dim=-1)
        if self.bias_y:
            y = torch.cat([y, torch.ones_like(y[..., :1])], dim=-1)

        # (batch, seq_len, output_dim, input_dim+1)
        x = torch.einsum('bxi,oij->bxoj', x, self.weight)
        # (batch, seq_len, seq_len, output_dim)
        scores = torch.einsum('bxoj,byj->bxyo', x, y)

        if self.output_dim == 1:
            scores = scores.squeeze(-1)

        return scores
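
# The biaffine layer scores every (dependent i, head candidate j) pair as
#   s(i, j) = [x_i ; 1]^T  W  [y_j ; 1]
# with the bias term appended only where bias_x / bias_y are set. A rough
# shape check, assuming arc_hidden=500 and a batch of 2 sentences of length 6
# (kept as a comment; illustrative only):
#
#   arc_attn = Biaffine(500, 1, bias_x=True, bias_y=False)
#   x = torch.randn(2, 6, 500)   # dependent MLP outputs
#   y = torch.randn(2, 6, 500)   # head MLP outputs
#   arc_attn(x, y).shape         # torch.Size([2, 6, 6])
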

class BiaffineDependencyParser(nn.Module):
    """Biaffine Dependency Parser (Dozat & Manning, 2017)."""

    def __init__(
        self,
        n_words: int,
        n_chars: int,
        n_rels: int,
        word_dim: int = 100,
        char_dim: int = 50,
        char_hidden: int = 100,
        lstm_hidden: int = 400,
        lstm_layers: int = 3,
        arc_hidden: int = 500,
        rel_hidden: int = 100,
        dropout: float = 0.33,
    ):
        super().__init__()

        self.word_embed = nn.Embedding(n_words, word_dim, padding_idx=0)
        self.char_lstm = CharLSTM(n_chars, char_dim, char_hidden)

        input_dim = word_dim + char_hidden

        self.lstm = nn.LSTM(
            input_dim, lstm_hidden // 2,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if lstm_layers > 1 else 0
        )

        self.mlp_arc_dep = MLP(lstm_hidden, arc_hidden, dropout)
        self.mlp_arc_head = MLP(lstm_hidden, arc_hidden, dropout)
        self.mlp_rel_dep = MLP(lstm_hidden, rel_hidden, dropout)
        self.mlp_rel_head = MLP(lstm_hidden, rel_hidden, dropout)

        self.arc_attn = Biaffine(arc_hidden, 1, bias_x=True, bias_y=False)
        self.rel_attn = Biaffine(rel_hidden, n_rels, bias_x=True, bias_y=True)

        self.dropout = nn.Dropout(dropout)
        self.n_rels = n_rels

    def forward(self, words, chars, mask):
        """
        Args:
            words: (batch, seq_len)
            chars: (batch, seq_len, max_word_len)
            mask: (batch, seq_len)
        Returns:
            arc_scores: (batch, seq_len, seq_len)
            rel_scores: (batch, seq_len, seq_len, n_rels)
        """
        # Embeddings
        word_embeds = self.word_embed(words)
        char_embeds = self.char_lstm(chars)
        embeds = torch.cat([word_embeds, char_embeds], dim=-1)
        embeds = self.dropout(embeds)

        # BiLSTM
        lengths = mask.sum(dim=1).cpu()
        packed = pack_padded_sequence(embeds, lengths, batch_first=True, enforce_sorted=False)
        lstm_out, _ = self.lstm(packed)
        lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True, total_length=mask.size(1))
        lstm_out = self.dropout(lstm_out)

        # MLP
        arc_dep = self.mlp_arc_dep(lstm_out)
        arc_head = self.mlp_arc_head(lstm_out)
        rel_dep = self.mlp_rel_dep(lstm_out)
        rel_head = self.mlp_rel_head(lstm_out)

        # Biaffine
        arc_scores = self.arc_attn(arc_dep, arc_head)  # (batch, seq_len, seq_len)
        rel_scores = self.rel_attn(rel_dep, rel_head)  # (batch, seq_len, seq_len, n_rels)

        return arc_scores, rel_scores

    def loss(self, arc_scores, rel_scores, heads, rels, mask):
        """Compute arc and relation loss over real tokens."""
        batch_size, seq_len = mask.shape

        # Do not train on the synthetic <root> token at position 0
        mask = mask.clone()
        mask[:, 0] = False

        # Arc loss
        arc_scores = arc_scores.masked_fill(~mask.unsqueeze(2), float('-inf'))
        arc_loss = F.cross_entropy(
            arc_scores[mask].view(-1, seq_len),
            heads[mask],
            reduction='mean'
        )

        # Rel loss - select scores for gold heads
        rel_scores_gold = rel_scores[torch.arange(batch_size).unsqueeze(1),
                                     torch.arange(seq_len),
                                     heads]
        rel_loss = F.cross_entropy(
            rel_scores_gold[mask],
            rels[mask],
            reduction='mean'
        )

        return arc_loss + rel_loss

    def decode(self, arc_scores, rel_scores, mask):
        """Decode predictions."""
        # Greedy decoding
        arc_preds = arc_scores.argmax(dim=-1)

        batch_size, seq_len = mask.shape
        rel_scores_pred = rel_scores[torch.arange(batch_size).unsqueeze(1),
                                     torch.arange(seq_len),
                                     arc_preds]
        rel_preds = rel_scores_pred.argmax(dim=-1)

        return arc_preds, rel_preds
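
# Minimal smoke-test sketch for the parser (hypothetical vocabulary sizes and
# random inputs; kept as a comment so the script's behavior is unchanged):
#
#   model = BiaffineDependencyParser(n_words=100, n_chars=50, n_rels=10)
#   words = torch.randint(3, 100, (2, 6))    # (batch, seq_len), ids >= 3 skip pad/unk/root
#   chars = torch.randint(3, 50, (2, 6, 4))  # (batch, seq_len, max_word_len)
#   mask = torch.ones(2, 6, dtype=torch.bool)
#   arc_scores, rel_scores = model(words, chars, mask)
#   arc_scores.shape   # torch.Size([2, 6, 6])
#   rel_scores.shape   # torch.Size([2, 6, 6, 10])
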

# ============================================================================
# Training
# ============================================================================

def evaluate(model, dataloader, device):
    """Evaluate model and return UAS/LAS."""
    model.eval()
    total_arcs = 0
    correct_arcs = 0
    correct_rels = 0

    with torch.no_grad():
        for batch in dataloader:
            words, chars, heads, rels, mask, lengths = [x.to(device) for x in batch]

            arc_scores, rel_scores = model(words, chars, mask)
            arc_preds, rel_preds = model.decode(arc_scores, rel_scores, mask)

            # Score only real tokens, not the synthetic <root> at position 0
            eval_mask = mask.clone()
            eval_mask[:, 0] = False

            # Count correct
            arc_correct = (arc_preds == heads) & eval_mask
            rel_correct = (rel_preds == rels) & eval_mask & arc_correct

            total_arcs += eval_mask.sum().item()
            correct_arcs += arc_correct.sum().item()
            correct_rels += rel_correct.sum().item()

    uas = correct_arcs / total_arcs * 100
    las = correct_rels / total_arcs * 100

    return uas, las
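
# Worked example for the metrics above (hypothetical counts): out of 100
# scored tokens, if 85 receive the correct head then UAS = 85 / 100 * 100
# = 85.00%; if 80 of those 85 also carry the correct relation label then
# LAS = 80.00%.
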

@click.command()
@click.option('--output', '-o', default='models/bamboo-1', help='Output directory')
@click.option('--epochs', default=100, type=int, help='Number of epochs')
@click.option('--batch-size', default=32, type=int, help='Batch size')
@click.option('--lr', default=2e-3, type=float, help='Learning rate')
@click.option('--lstm-hidden', default=400, type=int, help='LSTM hidden size')
@click.option('--lstm-layers', default=3, type=int, help='LSTM layers')
@click.option('--patience', default=10, type=int, help='Early stopping patience')
@click.option('--force-download', is_flag=True, help='Force re-download dataset')
@click.option('--gpu-type', default='RTX_A4000', help='GPU type for cost estimation')
@click.option('--cost-interval', default=300, type=int, help='Cost report interval in seconds')
@click.option('--wandb', 'use_wandb', is_flag=True, help='Enable W&B logging')
@click.option('--wandb-project', default='bamboo-1', help='W&B project name')
@click.option('--max-time', default=0, type=int, help='Max training time in minutes (0=unlimited)')
@click.option('--sample', default=0, type=int, help='Sample N sentences from each split (0=all)')
def train(output, epochs, batch_size, lr, lstm_hidden, lstm_layers, patience,
          force_download, gpu_type, cost_interval, use_wandb, wandb_project,
          max_time, sample):
    """Train Bamboo-1 Vietnamese Dependency Parser."""
    # Detect hardware
    hardware = detect_hardware()
    detected_gpu_type = hardware.get_gpu_type()

    # Use detected GPU type if not explicitly specified or if using default
    if gpu_type == "RTX_A4000":  # default value
        gpu_type = detected_gpu_type

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    click.echo(f"Using device: {device}")
    click.echo(f"Hardware: {hardware}")

    # Initialize wandb
    if use_wandb:
        import wandb
        wandb.init(
            project=wandb_project,
            config={
                "epochs": epochs,
                "batch_size": batch_size,
                "lr": lr,
                "lstm_hidden": lstm_hidden,
                "lstm_layers": lstm_layers,
                "patience": patience,
                "gpu_type": gpu_type,
                "hardware": hardware.to_dict(),
            }
        )
        click.echo(f"W&B logging enabled: {wandb.run.url}")

    click.echo("=" * 60)
    click.echo("Bamboo-1: Vietnamese Dependency Parser")
    click.echo("=" * 60)

    # Load corpus
    click.echo("\nLoading UDD-1 corpus...")
    corpus = UDD1Corpus(force_download=force_download)
    train_sents = read_conllu(corpus.train)
    dev_sents = read_conllu(corpus.dev)
    test_sents = read_conllu(corpus.test)

    # Sample subset if requested
    if sample > 0:
        train_sents = train_sents[:sample]
        dev_sents = dev_sents[:min(sample // 2, len(dev_sents))]
        test_sents = test_sents[:min(sample // 2, len(test_sents))]
        click.echo(f"  Sampling {sample} sentences...")

    click.echo(f"  Train: {len(train_sents)} sentences")
    click.echo(f"  Dev: {len(dev_sents)} sentences")
    click.echo(f"  Test: {len(test_sents)} sentences")

    # Build vocabulary
    click.echo("\nBuilding vocabulary...")
    vocab = Vocabulary(min_freq=2)
    vocab.build(train_sents)
    click.echo(f"  Words: {vocab.n_words}")
    click.echo(f"  Chars: {vocab.n_chars}")
    click.echo(f"  Relations: {vocab.n_rels}")

    # Create datasets
    train_dataset = DependencyDataset(train_sents, vocab)
    dev_dataset = DependencyDataset(dev_sents, vocab)
    test_dataset = DependencyDataset(test_sents, vocab)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              collate_fn=collate_fn)
    dev_loader = DataLoader(dev_dataset, batch_size=batch_size, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_fn)

    # Create model
    click.echo("\nInitializing model...")
    model = BiaffineDependencyParser(
        n_words=vocab.n_words,
        n_chars=vocab.n_chars,
        n_rels=vocab.n_rels,
        lstm_hidden=lstm_hidden,
        lstm_layers=lstm_layers,
    ).to(device)

    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    click.echo(f"  Parameters: {n_params:,}")

    # Optimizer
    optimizer = Adam(model.parameters(), lr=lr, betas=(0.9, 0.9))
    scheduler = ExponentialLR(optimizer, gamma=0.75 ** (1 / 5000))

    # Training
    click.echo(f"\nTraining for {epochs} epochs...")
    if max_time > 0:
        click.echo(f"Time limit: {max_time} minutes")

    output_path = Path(output)
    output_path.mkdir(parents=True, exist_ok=True)

    # Cost tracking
    cost_tracker = CostTracker(gpu_type=gpu_type)
    cost_tracker.report_interval = cost_interval
    cost_tracker.start()
    click.echo(f"Cost tracking: {gpu_type} @ ${cost_tracker.hourly_rate}/hr")

    best_las = -1
    no_improve = 0
    time_limit_seconds = max_time * 60 if max_time > 0 else float('inf')

    for epoch in range(1, epochs + 1):
        # Check time limit
        if cost_tracker.elapsed_seconds() >= time_limit_seconds:
            click.echo(f"\nTime limit reached ({max_time} minutes)")
            break

        model.train()
        total_loss = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch:3d}", leave=False)
        for batch in pbar:
            words, chars, heads, rels, mask, lengths = [x.to(device) for x in batch]

            optimizer.zero_grad()
            arc_scores, rel_scores = model(words, chars, mask)
            loss = model.loss(arc_scores, rel_scores, heads, rels, mask)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

        # Evaluate
        dev_uas, dev_las = evaluate(model, dev_loader, device)

        # Cost update
        progress = epoch / epochs
        current_cost = cost_tracker.current_cost()
        estimated_total_cost = cost_tracker.estimate_total_cost(progress)
        elapsed_minutes = cost_tracker.elapsed_seconds() / 60

        cost_status = cost_tracker.update(epoch, epochs)
        if cost_status:
            click.echo(f"  [{cost_status}]")

        avg_loss = total_loss / len(train_loader)
        click.echo(f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | "
                   f"Dev UAS: {dev_uas:.2f}% | Dev LAS: {dev_las:.2f}%")

        # Log to wandb
        if use_wandb:
            wandb.log({
                "epoch": epoch,
                "train/loss": avg_loss,
                "dev/uas": dev_uas,
                "dev/las": dev_las,
                "cost/current_usd": current_cost,
                "cost/estimated_total_usd": estimated_total_cost,
                "cost/elapsed_minutes": elapsed_minutes,
            })

        # Save best model
        if dev_las >= best_las:
            best_las = dev_las
            no_improve = 0
            torch.save({
                'model': model.state_dict(),
                'vocab': vocab,
                'config': {
                    'n_words': vocab.n_words,
                    'n_chars': vocab.n_chars,
                    'n_rels': vocab.n_rels,
                    'lstm_hidden': lstm_hidden,
                    'lstm_layers': lstm_layers,
                }
            }, output_path / 'model.pt')
            click.echo(f"  -> Saved best model (LAS: {best_las:.2f}%)")
        else:
            no_improve += 1
            if no_improve >= patience:
                click.echo(f"\nEarly stopping after {patience} epochs without improvement")
                break

    # Final evaluation
    click.echo("\nLoading best model for final evaluation...")
    checkpoint = torch.load(output_path / 'model.pt', weights_only=False)
    model.load_state_dict(checkpoint['model'])

    test_uas, test_las = evaluate(model, test_loader, device)
    click.echo("\nTest Results:")
    click.echo(f"  UAS: {test_uas:.2f}%")
    click.echo(f"  LAS: {test_las:.2f}%")
    click.echo(f"\nModel saved to: {output_path}")

    # Final cost summary
    final_cost = cost_tracker.current_cost()
    click.echo(f"\n{cost_tracker.summary(epoch, epochs)}")

    # Log final metrics to wandb
    if use_wandb:
        wandb.log({
            "test/uas": test_uas,
            "test/las": test_las,
            "cost/final_usd": final_cost,
        })
        wandb.finish()


if __name__ == '__main__':
    train()