| """ |
| Training script for Bamboo-1 Vietnamese Dependency Parser. |
| |
| Supports multiple methods: |
| - baseline: BiLSTM + Biaffine (Dozat & Manning, 2017) |
| - trankit: XLM-RoBERTa + Biaffine (Nguyen et al., 2021) |
| |
| Usage: |
| uv run scripts/train.py # Default baseline |
| uv run scripts/train.py --method trankit # Reproduce Trankit |
| uv run scripts/train.py --method trankit --dataset ud-vtb # Trankit on VTB |
| """ |
|
|
| import sys |
| from pathlib import Path |
| from collections import Counter |
| from dataclasses import dataclass |
| from typing import List, Tuple, Optional |
|
|
| |
| from dotenv import load_dotenv |
| load_dotenv() |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence |
| from torch.utils.data import Dataset, DataLoader |
| from torch.optim import Adam, AdamW |
| from torch.optim.lr_scheduler import ExponentialLR |
| from tqdm import tqdm |
|
|
| import click |
|
|
| sys.path.insert(0, str(Path(__file__).parent.parent)) |
| from src.corpus import UDD1Corpus |
| from src.ud_corpus import UDVietnameseVTB |
| from src.vndt_corpus import VnDTCorpus |
| from src.cost_estimate import CostTracker, detect_hardware |
|
|
|
|
| |
| |
| |
|
|
| @dataclass |
| class Sentence: |
| """A dependency-parsed sentence.""" |
| words: List[str] |
| heads: List[int] |
| rels: List[str] |
|
|
|
|
ROOT = '<root>'


def read_conllu(path: str) -> List[Sentence]:
    """Read a CoNLL-U file and return a list of sentences.

    Only the FORM (column 2), HEAD (column 7) and DEPREL (column 8) fields are
    used. A synthetic ROOT token is prepended to every sentence so that the
    1-based CoNLL-U head indices (0 = root) line up with token positions.
    """
    sentences = []
    words, heads, rels = [], [], []

    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                if words:
                    sentences.append(Sentence([ROOT] + words, [0] + heads, ['root'] + rels))
                    words, heads, rels = [], [], []
            elif line.startswith('#'):
                continue
            else:
                parts = line.split('\t')
                # Skip multiword tokens ("1-2") and empty nodes ("1.1").
                if '-' in parts[0] or '.' in parts[0]:
                    continue
                words.append(parts[1])
                heads.append(int(parts[6]))
                rels.append(parts[7])

    if words:
        sentences.append(Sentence([ROOT] + words, [0] + heads, ['root'] + rels))

    return sentences
|
|
|
|
| class Vocabulary: |
| """Vocabulary for words, characters, and relations.""" |
| PAD = '<pad>' |
| UNK = '<unk>' |
|
|
| def __init__(self, min_freq: int = 2): |
| self.min_freq = min_freq |
| self.word2idx = {self.PAD: 0, self.UNK: 1} |
| self.char2idx = {self.PAD: 0, self.UNK: 1} |
| self.rel2idx = {} |
| self.idx2rel = {} |
|
|
| def build(self, sentences: List[Sentence]): |
| """Build vocabulary from sentences.""" |
| word_counts = Counter() |
| char_counts = Counter() |
| rel_counts = Counter() |
|
|
| for sent in sentences: |
| for word in sent.words: |
| word_counts[word.lower()] += 1 |
| for char in word: |
| char_counts[char] += 1 |
| for rel in sent.rels: |
| rel_counts[rel] += 1 |
|
|
| |
| for word, count in word_counts.items(): |
| if count >= self.min_freq and word not in self.word2idx: |
| self.word2idx[word] = len(self.word2idx) |
|
|
| |
| for char, count in char_counts.items(): |
| if char not in self.char2idx: |
| self.char2idx[char] = len(self.char2idx) |
|
|
| |
| for rel in rel_counts: |
| if rel not in self.rel2idx: |
| idx = len(self.rel2idx) |
| self.rel2idx[rel] = idx |
| self.idx2rel[idx] = rel |
|
|
| def encode_word(self, word: str) -> int: |
| return self.word2idx.get(word.lower(), self.word2idx[self.UNK]) |
|
|
| def encode_char(self, char: str) -> int: |
| return self.char2idx.get(char, self.char2idx[self.UNK]) |
|
|
| def encode_rel(self, rel: str) -> int: |
| return self.rel2idx.get(rel, 0) |
|
|
| @property |
| def n_words(self) -> int: |
| return len(self.word2idx) |
|
|
| @property |
| def n_chars(self) -> int: |
| return len(self.char2idx) |
|
|
| @property |
| def n_rels(self) -> int: |
| return len(self.rel2idx) |
|
|
|
|
| class DependencyDataset(Dataset): |
| """Dataset for dependency parsing.""" |
|
|
| def __init__(self, sentences: List[Sentence], vocab: Vocabulary): |
| self.sentences = sentences |
| self.vocab = vocab |
|
|
| def __len__(self): |
| return len(self.sentences) |
|
|
| def __getitem__(self, idx): |
| sent = self.sentences[idx] |
|
|
| |
| word_ids = [self.vocab.encode_word(w) for w in sent.words] |
|
|
| |
| char_ids = [[self.vocab.encode_char(c) for c in w] for w in sent.words] |
|
|
| |
| heads = sent.heads |
| rels = [self.vocab.encode_rel(r) for r in sent.rels] |
|
|
| return word_ids, char_ids, heads, rels |
|
|
|
|
| def collate_fn(batch): |
| """Collate function for DataLoader.""" |
| word_ids, char_ids, heads, rels = zip(*batch) |
|
|
| |
| lengths = [len(w) for w in word_ids] |
| max_len = max(lengths) |
|
|
| |
| word_ids_padded = torch.zeros(len(batch), max_len, dtype=torch.long) |
| for i, wids in enumerate(word_ids): |
| word_ids_padded[i, :len(wids)] = torch.tensor(wids) |
|
|
| |
| max_word_len = max(max(len(c) for c in chars) for chars in char_ids) |
| char_ids_padded = torch.zeros(len(batch), max_len, max_word_len, dtype=torch.long) |
| for i, chars in enumerate(char_ids): |
| for j, c in enumerate(chars): |
| char_ids_padded[i, j, :len(c)] = torch.tensor(c) |
|
|
| |
| heads_padded = torch.zeros(len(batch), max_len, dtype=torch.long) |
| for i, h in enumerate(heads): |
| heads_padded[i, :len(h)] = torch.tensor(h) |
|
|
| |
| rels_padded = torch.zeros(len(batch), max_len, dtype=torch.long) |
| for i, r in enumerate(rels): |
| rels_padded[i, :len(r)] = torch.tensor(r) |
|
|
| |
    mask = torch.zeros(len(batch), max_len, dtype=torch.bool)
    for i, l in enumerate(lengths):
        mask[i, :l] = True
    # Position 0 is the synthetic ROOT token: exclude it from the loss and metrics.
    mask[:, 0] = False
|
|
| lengths = torch.tensor(lengths) |
|
|
| return word_ids_padded, char_ids_padded, heads_padded, rels_padded, mask, lengths |
|
|
|
|
| |
| |
| |
|
|
| class CharLSTM(nn.Module): |
| """Character-level LSTM embeddings.""" |
|
|
| def __init__(self, n_chars: int, char_dim: int = 50, hidden_dim: int = 100): |
| super().__init__() |
| self.embed = nn.Embedding(n_chars, char_dim, padding_idx=0) |
| self.lstm = nn.LSTM(char_dim, hidden_dim // 2, batch_first=True, bidirectional=True) |
| self.hidden_dim = hidden_dim |
|
|
| def forward(self, chars): |
| """ |
| Args: |
| chars: (batch, seq_len, max_word_len) |
| Returns: |
| (batch, seq_len, hidden_dim) |
| """ |
| batch, seq_len, max_word_len = chars.shape |
|
|
| |
| chars_flat = chars.view(-1, max_word_len) |
|
|
| |
| word_lens = (chars_flat != 0).sum(dim=1) |
| word_lens = word_lens.clamp(min=1) |
|
|
| |
| char_embeds = self.embed(chars_flat) |
|
|
| |
| packed = pack_padded_sequence(char_embeds, word_lens.cpu(), batch_first=True, enforce_sorted=False) |
| _, (hidden, _) = self.lstm(packed) |
|
|
| |
| hidden = torch.cat([hidden[0], hidden[1]], dim=-1) |
|
|
| return hidden.view(batch, seq_len, self.hidden_dim) |
|
|
|
|
| class MLP(nn.Module): |
| """Multi-layer perceptron.""" |
|
|
| def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.33): |
| super().__init__() |
| self.linear = nn.Linear(input_dim, hidden_dim) |
| self.activation = nn.LeakyReLU(0.1) |
| self.dropout = nn.Dropout(dropout) |
|
|
| def forward(self, x): |
| return self.dropout(self.activation(self.linear(x))) |
|
|
|
|
| class Biaffine(nn.Module): |
| """Biaffine attention layer.""" |
|
|
| def __init__(self, input_dim: int, output_dim: int = 1, bias_x: bool = True, bias_y: bool = True): |
| super().__init__() |
| self.input_dim = input_dim |
| self.output_dim = output_dim |
| self.bias_x = bias_x |
| self.bias_y = bias_y |
|
|
| self.weight = nn.Parameter(torch.zeros(output_dim, input_dim + bias_x, input_dim + bias_y)) |
| nn.init.xavier_uniform_(self.weight) |
|
|
| def forward(self, x, y): |
| """ |
| Args: |
| x: (batch, seq_len, input_dim) - dependent |
| y: (batch, seq_len, input_dim) - head |
| Returns: |
| (batch, seq_len, seq_len, output_dim) or (batch, seq_len, seq_len) if output_dim=1 |
| """ |
| if self.bias_x: |
| x = torch.cat([x, torch.ones_like(x[..., :1])], dim=-1) |
| if self.bias_y: |
| y = torch.cat([y, torch.ones_like(y[..., :1])], dim=-1) |
|
|
| |
| x = torch.einsum('bxi,oij->bxoj', x, self.weight) |
| |
| scores = torch.einsum('bxoj,byj->bxyo', x, y) |
|
|
| if self.output_dim == 1: |
| scores = scores.squeeze(-1) |
|
|
| return scores |
|
|
|
|
| class BiaffineDependencyParser(nn.Module): |
| """Biaffine Dependency Parser (Dozat & Manning, 2017).""" |
|
|
| def __init__( |
| self, |
| n_words: int, |
| n_chars: int, |
| n_rels: int, |
| word_dim: int = 100, |
| char_dim: int = 50, |
| char_hidden: int = 100, |
| lstm_hidden: int = 400, |
| lstm_layers: int = 3, |
| arc_hidden: int = 500, |
| rel_hidden: int = 100, |
| dropout: float = 0.33, |
| ): |
| super().__init__() |
|
|
| self.word_embed = nn.Embedding(n_words, word_dim, padding_idx=0) |
| self.char_lstm = CharLSTM(n_chars, char_dim, char_hidden) |
|
|
| input_dim = word_dim + char_hidden |
|
|
| self.lstm = nn.LSTM( |
| input_dim, lstm_hidden // 2, |
| num_layers=lstm_layers, |
| batch_first=True, |
| bidirectional=True, |
| dropout=dropout if lstm_layers > 1 else 0 |
| ) |
|
|
| self.mlp_arc_dep = MLP(lstm_hidden, arc_hidden, dropout) |
| self.mlp_arc_head = MLP(lstm_hidden, arc_hidden, dropout) |
| self.mlp_rel_dep = MLP(lstm_hidden, rel_hidden, dropout) |
| self.mlp_rel_head = MLP(lstm_hidden, rel_hidden, dropout) |
|
|
| self.arc_attn = Biaffine(arc_hidden, 1, bias_x=True, bias_y=False) |
| self.rel_attn = Biaffine(rel_hidden, n_rels, bias_x=True, bias_y=True) |
|
|
| self.dropout = nn.Dropout(dropout) |
| self.n_rels = n_rels |
|
|
| def forward(self, words, chars, mask): |
| """ |
| Args: |
| words: (batch, seq_len) |
| chars: (batch, seq_len, max_word_len) |
| mask: (batch, seq_len) |
| Returns: |
| arc_scores: (batch, seq_len, seq_len) |
| rel_scores: (batch, seq_len, seq_len, n_rels) |
| """ |
| |
| word_embeds = self.word_embed(words) |
| char_embeds = self.char_lstm(chars) |
| embeds = torch.cat([word_embeds, char_embeds], dim=-1) |
| embeds = self.dropout(embeds) |
|
|
| |
        # True sequence lengths; the mask excludes ROOT, so count non-pad word ids instead.
        lengths = words.ne(0).sum(dim=1).cpu()
| packed = pack_padded_sequence(embeds, lengths, batch_first=True, enforce_sorted=False) |
| lstm_out, _ = self.lstm(packed) |
| lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True, total_length=mask.size(1)) |
| lstm_out = self.dropout(lstm_out) |
|
|
| |
| arc_dep = self.mlp_arc_dep(lstm_out) |
| arc_head = self.mlp_arc_head(lstm_out) |
| rel_dep = self.mlp_rel_dep(lstm_out) |
| rel_head = self.mlp_rel_head(lstm_out) |
|
|
| |
| arc_scores = self.arc_attn(arc_dep, arc_head) |
| rel_scores = self.rel_attn(rel_dep, rel_head) |
|
|
| return arc_scores, rel_scores |
|
|
    def loss(self, arc_scores, rel_scores, heads, rels, mask):
        """Compute arc + relation cross-entropy loss over non-ROOT, non-pad tokens."""
        batch_size, seq_len = mask.shape

        # Rule out padding positions as head candidates; column 0 (ROOT) stays valid.
        head_mask = mask.clone()
        head_mask[:, 0] = True
        arc_scores = arc_scores.masked_fill(~head_mask.unsqueeze(1), float('-inf'))
        arc_loss = F.cross_entropy(
            arc_scores[mask],
            heads[mask],
            reduction='mean'
        )

        # Relation scores at the gold head of each token.
        rel_scores_gold = rel_scores[
            torch.arange(batch_size, device=heads.device).unsqueeze(1),
            torch.arange(seq_len, device=heads.device),
            heads,
        ]
        rel_loss = F.cross_entropy(
            rel_scores_gold[mask],
            rels[mask],
            reduction='mean'
        )

        return arc_loss + rel_loss
|
|
| def decode(self, arc_scores, rel_scores, mask): |
| """Decode predictions.""" |
| |
| arc_preds = arc_scores.argmax(dim=-1) |
|
|
| batch_size, seq_len = mask.shape |
| rel_scores_pred = rel_scores[torch.arange(batch_size).unsqueeze(1), torch.arange(seq_len), arc_preds] |
| rel_preds = rel_scores_pred.argmax(dim=-1) |
|
|
| return arc_preds, rel_preds |
|
|
|
|
| |
| |
| |
|
|
| class TransformerDependencyParser(nn.Module): |
| """ |
| Trankit-style dependency parser using XLM-RoBERTa. |
| |
| Architecture follows Nguyen et al. 2021 EACL: |
| - XLM-RoBERTa encoder |
| - Word-level pooling (first subword) |
| - Biaffine attention for arc/rel prediction |
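
    Rough usage with one batch from transformer_collate_fn (sketch):

        model = TransformerDependencyParser(n_rels=vocab.n_rels)
        word_hidden, word_mask = model.encode_pretokenized(
            input_ids, attention_mask, word_starts, word_mask)
        arc_scores, rel_scores = model(word_hidden, word_mask)
        arc_preds, rel_preds = model.decode(arc_scores, rel_scores, word_mask)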
| """ |
|
|
| def __init__( |
| self, |
| n_rels: int, |
| encoder: str = "xlm-roberta-base", |
| arc_hidden: int = 500, |
| rel_hidden: int = 100, |
| dropout: float = 0.33, |
| ): |
| super().__init__() |
| from transformers import AutoModel, AutoTokenizer |
|
|
| self.encoder_name = encoder |
| self.tokenizer = AutoTokenizer.from_pretrained(encoder) |
| self.encoder = AutoModel.from_pretrained(encoder) |
| self.hidden_size = self.encoder.config.hidden_size |
|
|
| |
| self.mlp_arc_dep = MLP(self.hidden_size, arc_hidden, dropout) |
| self.mlp_arc_head = MLP(self.hidden_size, arc_hidden, dropout) |
| self.mlp_rel_dep = MLP(self.hidden_size, rel_hidden, dropout) |
| self.mlp_rel_head = MLP(self.hidden_size, rel_hidden, dropout) |
|
|
| self.arc_attn = Biaffine(arc_hidden, 1, bias_x=True, bias_y=False) |
| self.rel_attn = Biaffine(rel_hidden, n_rels, bias_x=True, bias_y=True) |
|
|
| self.dropout = nn.Dropout(dropout) |
| self.n_rels = n_rels |
|
|
| def encode_pretokenized(self, input_ids, attention_mask, word_starts, word_mask): |
| """Encode pre-tokenized batch (fast path - no tokenization overhead).""" |
| outputs = self.encoder(input_ids, attention_mask=attention_mask) |
| hidden = outputs.last_hidden_state |
|
|
| |
| word_starts_exp = word_starts.unsqueeze(-1).expand(-1, -1, self.hidden_size) |
| word_starts_exp = word_starts_exp.clamp(0, hidden.size(1) - 1) |
| word_hidden = torch.gather(hidden, 1, word_starts_exp) |
|
|
| return word_hidden, word_mask |
|
|
| def forward(self, word_hidden, word_mask): |
| """Compute arc and relation scores from word representations.""" |
| word_hidden = self.dropout(word_hidden) |
|
|
| |
| arc_dep = self.mlp_arc_dep(word_hidden) |
| arc_head = self.mlp_arc_head(word_hidden) |
| rel_dep = self.mlp_rel_dep(word_hidden) |
| rel_head = self.mlp_rel_head(word_hidden) |
|
|
| arc_scores = self.arc_attn(arc_dep, arc_head) |
| rel_scores = self.rel_attn(rel_dep, rel_head) |
|
|
| return arc_scores, rel_scores |
|
|
    def loss(self, arc_scores, rel_scores, heads, rels, mask):
        """Compute arc + relation cross-entropy loss over non-ROOT, non-pad tokens."""
        batch_size, seq_len = mask.shape

        # Rule out padding positions as head candidates; column 0 (ROOT) stays valid.
        head_mask = mask.clone()
        head_mask[:, 0] = True
        arc_scores = arc_scores.masked_fill(~head_mask.unsqueeze(1), float('-inf'))
        arc_loss = F.cross_entropy(
            arc_scores[mask],
            heads[mask],
            reduction='mean'
        )

        # Relation scores at the gold head of each token.
        rel_scores_gold = rel_scores[
            torch.arange(batch_size, device=mask.device).unsqueeze(1),
            torch.arange(seq_len, device=mask.device),
            heads,
        ]
        rel_loss = F.cross_entropy(
            rel_scores_gold[mask],
            rels[mask],
            reduction='mean'
        )

        return arc_loss + rel_loss
|
|
| def decode(self, arc_scores, rel_scores, mask): |
| """Greedy decoding.""" |
| arc_preds = arc_scores.argmax(dim=-1) |
|
|
| batch_size, seq_len = mask.shape |
| rel_scores_pred = rel_scores[torch.arange(batch_size, device=mask.device).unsqueeze(1), |
| torch.arange(seq_len, device=mask.device), arc_preds] |
| rel_preds = rel_scores_pred.argmax(dim=-1) |
|
|
| return arc_preds, rel_preds |
|
|
|
|
| class TransformerDataset(Dataset): |
| """Pre-tokenized dataset - tokenizes once at creation for fast training.""" |
|
|
| def __init__(self, sentences: List[Sentence], vocab, tokenizer): |
| self.data = [] |
| for sent in tqdm(sentences, desc="Pre-tokenizing", leave=False): |
| input_ids = [tokenizer.cls_token_id] |
| word_starts = [] |
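            # word_starts[i] = index of the first subword of word i (used for first-subword pooling).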
| for word in sent.words: |
| word_starts.append(len(input_ids)) |
| tokens = tokenizer.encode(word, add_special_tokens=False) |
| input_ids.extend(tokens if tokens else [tokenizer.unk_token_id]) |
| input_ids.append(tokenizer.sep_token_id) |
| self.data.append(( |
| torch.tensor(input_ids, dtype=torch.long), |
| torch.tensor(word_starts, dtype=torch.long), |
| torch.tensor(sent.heads, dtype=torch.long), |
| torch.tensor([vocab.encode_rel(r) for r in sent.rels], dtype=torch.long), |
| len(sent.words), |
| )) |
|
|
| def __len__(self): |
| return len(self.data) |
|
|
| def __getitem__(self, idx): |
| return self.data[idx] |
|
|
|
|
| def transformer_collate_fn(batch): |
| """Collate for pre-tokenized transformer data.""" |
| input_ids, word_starts, heads, rels, n_words = zip(*batch) |
|
|
| batch_size = len(batch) |
| max_subwords = max(ids.size(0) for ids in input_ids) |
| max_words = max(n_words) |
|
|
| padded_ids = torch.zeros(batch_size, max_subwords, dtype=torch.long) |
| attention_mask = torch.zeros(batch_size, max_subwords, dtype=torch.long) |
| padded_starts = torch.zeros(batch_size, max_words, dtype=torch.long) |
| padded_heads = torch.zeros(batch_size, max_words, dtype=torch.long) |
| padded_rels = torch.zeros(batch_size, max_words, dtype=torch.long) |
| word_mask = torch.zeros(batch_size, max_words, dtype=torch.bool) |
|
|
| for i, (ids, starts, h, r, nw) in enumerate(zip(input_ids, word_starts, heads, rels, n_words)): |
| padded_ids[i, :ids.size(0)] = ids |
| attention_mask[i, :ids.size(0)] = 1 |
| padded_starts[i, :starts.size(0)] = starts |
| padded_heads[i, :nw] = h |
| padded_rels[i, :nw] = r |
        word_mask[i, :nw] = True

    # Position 0 is the synthetic ROOT token: exclude it from the loss and metrics.
    word_mask[:, 0] = False

    return padded_ids, attention_mask, padded_starts, padded_heads, padded_rels, word_mask
|
|
|
|
| def evaluate_transformer(model, dataloader, device): |
| """Evaluate transformer-based model.""" |
| model.eval() |
|
|
| total_arcs = 0 |
| correct_arcs = 0 |
| correct_rels = 0 |
|
|
| with torch.no_grad(): |
| for batch in dataloader: |
| input_ids, attn_mask, word_starts, heads, rels, mask = [x.to(device) for x in batch] |
|
|
| word_hidden, word_mask = model.encode_pretokenized(input_ids, attn_mask, word_starts, mask) |
| arc_scores, rel_scores = model(word_hidden, word_mask) |
| arc_preds, rel_preds = model.decode(arc_scores, rel_scores, word_mask) |
|
|
| arc_correct = (arc_preds == heads) & mask |
| rel_correct = (rel_preds == rels) & mask & arc_correct |
|
|
| total_arcs += mask.sum().item() |
| correct_arcs += arc_correct.sum().item() |
| correct_rels += rel_correct.sum().item() |
|
|
| uas = correct_arcs / total_arcs * 100 |
| las = correct_rels / total_arcs * 100 |
|
|
| return uas, las |
|
|
|
|
| |
| |
| |
|
|
| def evaluate(model, dataloader, device): |
| """Evaluate model and return UAS/LAS.""" |
| model.eval() |
|
|
| total_arcs = 0 |
| correct_arcs = 0 |
| correct_rels = 0 |
|
|
| with torch.no_grad(): |
| for batch in dataloader: |
| words, chars, heads, rels, mask, lengths = [x.to(device) for x in batch] |
|
|
| arc_scores, rel_scores = model(words, chars, mask) |
| arc_preds, rel_preds = model.decode(arc_scores, rel_scores, mask) |
|
|
| |
| arc_correct = (arc_preds == heads) & mask |
| rel_correct = (rel_preds == rels) & mask & arc_correct |
|
|
| total_arcs += mask.sum().item() |
| correct_arcs += arc_correct.sum().item() |
| correct_rels += rel_correct.sum().item() |
|
|
| uas = correct_arcs / total_arcs * 100 |
| las = correct_rels / total_arcs * 100 |
|
|
| return uas, las |
|
|
|
|
| @click.command() |
| @click.option('--method', type=click.Choice(['baseline', 'trankit']), default='baseline', |
| help='Parser method: baseline (BiLSTM) or trankit (XLM-RoBERTa)') |
| @click.option('--dataset', type=click.Choice(['udd1', 'ud-vtb', 'vndt']), default='udd1', |
| help='Dataset: udd1 (UDD-1), ud-vtb (UD Vietnamese VTB), or vndt (VnDT v1.1)') |
| @click.option('--encoder', default='xlm-roberta-base', |
| help='Transformer encoder for trankit method') |
| @click.option('--output', '-o', default='models/bamboo-1', help='Output directory') |
| @click.option('--epochs', default=100, type=int, help='Number of epochs') |
| @click.option('--batch-size', default=64, type=int, help='Batch size') |
| @click.option('--lr', default=2e-3, type=float, help='Learning rate for baseline') |
| @click.option('--bert-lr', default=2e-5, type=float, help='Encoder learning rate for trankit') |
| @click.option('--head-lr', default=2e-4, type=float, help='Head learning rate for trankit') |
| @click.option('--warmup-steps', default=200, type=int, help='Warmup steps for trankit') |
| @click.option('--lstm-hidden', default=400, type=int, help='LSTM hidden size (baseline)') |
| @click.option('--lstm-layers', default=3, type=int, help='LSTM layers (baseline)') |
| @click.option('--patience', default=5, type=int, help='Early stopping patience') |
| @click.option('--force-download', is_flag=True, help='Force re-download dataset') |
| @click.option('--data-dir', default=None, help='Custom data directory') |
| @click.option('--gpu-type', default='RTX_A4000', help='GPU type for cost estimation') |
| @click.option('--cost-interval', default=300, type=int, help='Cost report interval in seconds') |
| @click.option('--wandb', 'use_wandb', is_flag=True, help='Enable W&B logging') |
| @click.option('--wandb-project', default='bamboo-1', help='W&B project name') |
| @click.option('--max-time', default=0, type=int, help='Max training time in minutes (0=unlimited)') |
| @click.option('--sample', default=0, type=int, help='Sample N sentences from each split (0=all)') |
| @click.option('--eval-every', default=2, type=int, help='Evaluate every N epochs') |
@click.option('--fp16/--no-fp16', default=True, help='Use mixed precision training (disable with --no-fp16)')
| def train(method, dataset, encoder, output, epochs, batch_size, lr, bert_lr, head_lr, warmup_steps, |
| lstm_hidden, lstm_layers, patience, force_download, data_dir, gpu_type, cost_interval, |
| use_wandb, wandb_project, max_time, sample, eval_every, fp16): |
| """Train Bamboo-1 Vietnamese Dependency Parser.""" |
|
|
| |
| hardware = detect_hardware() |
| detected_gpu_type = hardware.get_gpu_type() |
|
|
| if gpu_type == "RTX_A4000": |
| gpu_type = detected_gpu_type |
|
|
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
| click.echo(f"Using device: {device}") |
| click.echo(f"Hardware: {hardware}") |
|
|
| |
| if torch.cuda.is_available(): |
| torch.backends.cudnn.benchmark = True |
| torch.backends.cuda.matmul.allow_tf32 = True |
| torch.backends.cudnn.allow_tf32 = True |
|
|
| |
| use_amp = fp16 and torch.cuda.is_available() |
| scaler = torch.amp.GradScaler('cuda') if use_amp else None |
| if use_amp: |
| click.echo("Mixed precision (FP16): enabled") |
|
|
| |
| if use_wandb: |
| import wandb |
| wandb.init( |
| project=wandb_project, |
| config={ |
| "method": method, |
| "dataset": dataset, |
| "encoder": encoder if method == "trankit" else "bilstm", |
| "epochs": epochs, |
| "batch_size": batch_size, |
| "lr": lr if method == "baseline" else bert_lr, |
| "head_lr": head_lr if method == "trankit" else None, |
| "lstm_hidden": lstm_hidden if method == "baseline" else None, |
| "lstm_layers": lstm_layers if method == "baseline" else None, |
| "patience": patience, |
| "gpu_type": gpu_type, |
| "hardware": hardware.to_dict(), |
| } |
| ) |
| click.echo(f"W&B logging enabled: {wandb.run.url}") |
|
|
| click.echo("=" * 60) |
| click.echo(f"Bamboo-1: Vietnamese Dependency Parser ({method.upper()})") |
| click.echo("=" * 60) |
|
|
| |
| click.echo(f"\nLoading {dataset.upper()} corpus...") |
| if dataset == 'udd1': |
| corpus = UDD1Corpus(data_dir=data_dir, force_download=force_download) |
| elif dataset == 'ud-vtb': |
| corpus = UDVietnameseVTB(data_dir=data_dir, force_download=force_download) |
| else: |
| corpus = VnDTCorpus(data_dir=data_dir, force_download=force_download) |
|
|
| train_sents = read_conllu(corpus.train) |
| dev_sents = read_conllu(corpus.dev) |
| test_sents = read_conllu(corpus.test) |
|
|
| |
| if sample > 0: |
| train_sents = train_sents[:sample] |
| dev_sents = dev_sents[:min(sample // 2, len(dev_sents))] |
| test_sents = test_sents[:min(sample // 2, len(test_sents))] |
| click.echo(f" Sampling {sample} sentences...") |
|
|
| click.echo(f" Train: {len(train_sents)} sentences") |
| click.echo(f" Dev: {len(dev_sents)} sentences") |
| click.echo(f" Test: {len(test_sents)} sentences") |
|
|
| |
| click.echo("\nBuilding vocabulary...") |
| vocab = Vocabulary(min_freq=2) |
| vocab.build(train_sents) |
| if method == "baseline": |
| click.echo(f" Words: {vocab.n_words}") |
| click.echo(f" Chars: {vocab.n_chars}") |
| click.echo(f" Relations: {vocab.n_rels}") |
|
|
| |
| if method == "trankit": |
| |
| click.echo(f"\nInitializing model with {encoder}...") |
| model = TransformerDependencyParser( |
| n_rels=vocab.n_rels, |
| encoder=encoder, |
| ).to(device) |
|
|
| n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) |
| encoder_params = sum(p.numel() for p in model.encoder.parameters()) |
| head_params = n_params - encoder_params |
| click.echo(f" Total parameters: {n_params:,}") |
| click.echo(f" Encoder parameters: {encoder_params:,}") |
| click.echo(f" Head parameters: {head_params:,}") |
|
|
| |
| click.echo("\nPre-tokenizing datasets...") |
| use_pin = torch.cuda.is_available() |
| train_dataset = TransformerDataset(train_sents, vocab, model.tokenizer) |
| dev_dataset = TransformerDataset(dev_sents, vocab, model.tokenizer) |
| test_dataset = TransformerDataset(test_sents, vocab, model.tokenizer) |
|
|
| train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, |
| collate_fn=transformer_collate_fn, pin_memory=use_pin) |
| dev_loader = DataLoader(dev_dataset, batch_size=batch_size, |
| collate_fn=transformer_collate_fn, pin_memory=use_pin) |
| test_loader = DataLoader(test_dataset, batch_size=batch_size, |
| collate_fn=transformer_collate_fn, pin_memory=use_pin) |
|
|
| |
| encoder_params_list = list(model.encoder.parameters()) |
| head_params_list = [p for n, p in model.named_parameters() if 'encoder' not in n] |
| optimizer = AdamW([ |
| {'params': encoder_params_list, 'lr': bert_lr}, |
| {'params': head_params_list, 'lr': head_lr}, |
| ], weight_decay=0.01) |
|
|
| |
| total_steps = len(train_loader) * epochs |
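        # Linear warmup for `warmup_steps` steps, then linear decay to zero at `total_steps`.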
| def lr_lambda(step): |
| if step < warmup_steps: |
| return step / warmup_steps |
| return max(0.0, (total_steps - step) / (total_steps - warmup_steps)) |
| scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) |
|
|
| eval_fn = evaluate_transformer |
| else: |
| |
| train_dataset = DependencyDataset(train_sents, vocab) |
| dev_dataset = DependencyDataset(dev_sents, vocab) |
| test_dataset = DependencyDataset(test_sents, vocab) |
|
|
| train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn) |
| dev_loader = DataLoader(dev_dataset, batch_size=batch_size, collate_fn=collate_fn) |
| test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_fn) |
|
|
| click.echo("\nInitializing BiLSTM model...") |
| model = BiaffineDependencyParser( |
| n_words=vocab.n_words, |
| n_chars=vocab.n_chars, |
| n_rels=vocab.n_rels, |
| lstm_hidden=lstm_hidden, |
| lstm_layers=lstm_layers, |
| ).to(device) |
|
|
| n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) |
| click.echo(f" Parameters: {n_params:,}") |
|
|
| optimizer = Adam(model.parameters(), lr=lr, betas=(0.9, 0.9)) |
| scheduler = ExponentialLR(optimizer, gamma=0.75 ** (1 / 5000)) |
|
|
| eval_fn = evaluate |
|
|
| |
| click.echo(f"\nTraining for {epochs} epochs...") |
| if max_time > 0: |
| click.echo(f"Time limit: {max_time} minutes") |
| output_path = Path(output) |
| output_path.mkdir(parents=True, exist_ok=True) |
|
|
| |
| cost_tracker = CostTracker(gpu_type=gpu_type) |
| cost_tracker.report_interval = cost_interval |
| cost_tracker.start() |
| click.echo(f"Cost tracking: {gpu_type} @ ${cost_tracker.hourly_rate}/hr") |
|
|
| best_las = -1 |
| no_improve = 0 |
| time_limit_seconds = max_time * 60 if max_time > 0 else float('inf') |
|
|
| for epoch in range(1, epochs + 1): |
| |
| if cost_tracker.elapsed_seconds() >= time_limit_seconds: |
| click.echo(f"\nTime limit reached ({max_time} minutes)") |
| break |
| model.train() |
| total_loss = 0 |
|
|
| pbar = tqdm(train_loader, desc=f"Epoch {epoch:3d}", leave=False) |
| for batch in pbar: |
| optimizer.zero_grad() |
|
|
| if method == "trankit": |
| input_ids, attn_mask, word_starts, heads, rels, mask = [x.to(device) for x in batch] |
|
|
| with torch.amp.autocast('cuda', enabled=use_amp): |
| word_hidden, word_mask = model.encode_pretokenized(input_ids, attn_mask, word_starts, mask) |
| arc_scores, rel_scores = model(word_hidden, word_mask) |
| loss = model.loss(arc_scores, rel_scores, heads, rels, mask) |
| else: |
| words, chars, heads, rels, mask, lengths = [x.to(device) for x in batch] |
| arc_scores, rel_scores = model(words, chars, mask) |
| loss = model.loss(arc_scores, rel_scores, heads, rels, mask) |
|
|
| if use_amp and scaler: |
| scaler.scale(loss).backward() |
| scaler.unscale_(optimizer) |
| nn.utils.clip_grad_norm_(model.parameters(), 5.0) |
| scaler.step(optimizer) |
| scaler.update() |
| else: |
| loss.backward() |
| nn.utils.clip_grad_norm_(model.parameters(), 5.0) |
| optimizer.step() |
|
|
| scheduler.step() |
| total_loss += loss.item() |
| pbar.set_postfix({'loss': f'{loss.item():.4f}'}) |
|
|
| |
| if epoch % eval_every != 0 and epoch != epochs: |
| avg_loss = total_loss / len(train_loader) |
| current_lr = optimizer.param_groups[0]['lr'] |
| click.echo(f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | LR: {current_lr:.2e}") |
| continue |
|
|
| dev_uas, dev_las = eval_fn(model, dev_loader, device) |
|
|
| |
| progress = epoch / epochs |
| current_cost = cost_tracker.current_cost() |
| estimated_total_cost = cost_tracker.estimate_total_cost(progress) |
| elapsed_minutes = cost_tracker.elapsed_seconds() / 60 |
|
|
| cost_status = cost_tracker.update(epoch, epochs) |
| if cost_status: |
| click.echo(f" [{cost_status}]") |
|
|
| avg_loss = total_loss / len(train_loader) |
| click.echo(f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | " |
| f"Dev UAS: {dev_uas:.2f}% | Dev LAS: {dev_las:.2f}%") |
|
|
| |
| if use_wandb: |
| wandb.log({ |
| "epoch": epoch, |
| "train/loss": avg_loss, |
| "dev/uas": dev_uas, |
| "dev/las": dev_las, |
| "cost/current_usd": current_cost, |
| "cost/estimated_total_usd": estimated_total_cost, |
| "cost/elapsed_minutes": elapsed_minutes, |
| }) |
|
|
| |
| if dev_las >= best_las: |
| best_las = dev_las |
| no_improve = 0 |
| if method == "trankit": |
| config = { |
| 'method': 'trankit', |
| 'encoder': encoder, |
| 'n_rels': vocab.n_rels, |
| } |
| else: |
| config = { |
| 'method': 'baseline', |
| 'n_words': vocab.n_words, |
| 'n_chars': vocab.n_chars, |
| 'n_rels': vocab.n_rels, |
| 'lstm_hidden': lstm_hidden, |
| 'lstm_layers': lstm_layers, |
| } |
| |
| import tempfile, shutil |
| with tempfile.NamedTemporaryFile(suffix='.pt', delete=False) as tmp: |
| torch.save({ |
| 'model': model.state_dict(), |
| 'vocab': vocab, |
| 'config': config, |
| }, tmp.name) |
| shutil.move(tmp.name, output_path / 'model.pt') |
| click.echo(f" -> Saved best model (LAS: {best_las:.2f}%)") |
| else: |
| no_improve += 1 |
| if no_improve >= patience: |
| click.echo(f"\nEarly stopping after {patience} epochs without improvement") |
| break |
|
|
| |
| click.echo("\nLoading best model for final evaluation...") |
| checkpoint = torch.load(output_path / 'model.pt', weights_only=False) |
| model.load_state_dict(checkpoint['model']) |
|
|
| test_uas, test_las = eval_fn(model, test_loader, device) |
| click.echo(f"\nTest Results:") |
| click.echo(f" UAS: {test_uas:.2f}%") |
| click.echo(f" LAS: {test_las:.2f}%") |
|
|
| click.echo(f"\nModel saved to: {output_path}") |
|
|
| |
| final_cost = cost_tracker.current_cost() |
| click.echo(f"\n{cost_tracker.summary(epoch, epochs)}") |
|
|
| |
| if use_wandb: |
| wandb.log({ |
| "test/uas": test_uas, |
| "test/las": test_las, |
| "cost/final_usd": final_cost, |
| }) |
| wandb.finish() |
|
|
|
|
| if __name__ == '__main__': |
| train() |
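

# Reloading a saved checkpoint later (sketch; which parser class to rebuild depends on config['method']):
#     ckpt = torch.load('models/bamboo-1/model.pt', weights_only=False)
#     vocab, config = ckpt['vocab'], ckpt['config']
#     ...instantiate the matching parser, then call model.load_state_dict(ckpt['model'])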
|
|