| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Training script for PhoBERT-based Vietnamese Dependency Parser. |
| |
| This script trains a transformer-based dependency parser using PhoBERT as the |
| encoder, following the Trankit approach for Vietnamese dependency parsing. |
| |
| Architecture: |
| PhoBERT -> Word-level pooling -> Biaffine attention -> MST decoding |
| |
| Usage: |
| uv run scripts/train_phobert.py |
| uv run scripts/train_phobert.py --output models/bamboo-1-phobert --epochs 100 |
| uv run scripts/train_phobert.py --encoder vinai/phobert-large |
| uv run scripts/train_phobert.py --dataset ud-vtb # Use UD Vietnamese VTB |
| """ |
|
|
| import sys |
| from pathlib import Path |
| from collections import Counter |
| from dataclasses import dataclass |
| from typing import List, Tuple, Optional, Dict |
|
|
| |
| from dotenv import load_dotenv |
| load_dotenv() |
|
|
| import torch |
| import torch.nn as nn |
| from torch.utils.data import Dataset, DataLoader |
| from torch.optim import AdamW |
| from tqdm import tqdm |
|
|
| import click |
|
|
| sys.path.insert(0, str(Path(__file__).parent.parent)) |
| from bamboo1.corpus import UDD1Corpus |
| from bamboo1.ud_corpus import UDVietnameseVTB |
| from bamboo1.models.transformer_parser import PhoBERTDependencyParser |
| from scripts.cost_estimate import CostTracker, detect_hardware |
|
|
|
|
| |
| |
| |
|
|
@dataclass
class Sentence:
    """A dependency-parsed sentence.

    Fields mirror the CoNLL-U columns this script consumes.
    """
    words: List[str]  # surface forms (FORM column), one per syntactic word
    heads: List[int]  # head indices per word: 1-based, 0 = sentence root (HEAD column)
    rels: List[str]   # dependency relation labels (DEPREL column)
|
|
|
|
def read_conllu(path: str) -> List[Sentence]:
    """Read a CoNLL-U file and return its sentences.

    Multi-word token ranges (IDs like ``4-5``) and empty nodes (IDs like
    ``4.1``) are skipped so that ``heads`` stays aligned with the basic
    dependency tree's word positions. Comment lines (``#``) are ignored.

    Args:
        path: Path to a CoNLL-U formatted file.

    Returns:
        List of ``Sentence`` objects (words, 0-rooted head indices, DEPRELs).

    Raises:
        ValueError: If a token row carries a non-integer HEAD column.
    """
    sentences = []
    words, heads, rels = [], [], []

    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # A blank line terminates the current sentence.
                if words:
                    sentences.append(Sentence(words, heads, rels))
                words, heads, rels = [], [], []
            elif line.startswith('#'):
                # Sentence-level metadata comment.
                continue
            else:
                parts = line.split('\t')
                # Robustness fix: skip rows lacking the 8 mandatory CoNLL-U
                # columns instead of crashing with IndexError mid-file.
                if len(parts) < 8:
                    continue
                # Skip multi-word token ranges (1-2) and empty nodes (1.1).
                if '-' in parts[0] or '.' in parts[0]:
                    continue
                words.append(parts[1])       # FORM
                heads.append(int(parts[6]))  # HEAD
                rels.append(parts[7])        # DEPREL

    # Flush a trailing sentence when the file lacks a final blank line.
    if words:
        sentences.append(Sentence(words, heads, rels))

    return sentences
|
|
|
|
class Vocabulary:
    """Vocabulary mapping dependency relation labels to contiguous ids.

    ``rel2idx`` and ``idx2rel`` are inverse mappings kept in sync by
    :meth:`build`.
    """

    def __init__(self):
        self.rel2idx: Dict[str, int] = {}
        self.idx2rel: Dict[int, str] = {}

    def build(self, sentences: List["Sentence"]):
        """Register every relation label observed in ``sentences``.

        Labels receive ids in sorted order for reproducibility across runs.
        Calling ``build`` again extends the vocabulary without renumbering
        previously assigned labels.
        """
        # A plain set suffices: the original tallied frequencies with a
        # Counter but never used the counts, only the sorted key set.
        seen = {rel for sent in sentences for rel in sent.rels}

        for rel in sorted(seen):
            if rel not in self.rel2idx:
                idx = len(self.rel2idx)
                self.rel2idx[rel] = idx
                self.idx2rel[idx] = rel

    @property
    def n_rels(self) -> int:
        """Number of distinct relation labels registered so far."""
        return len(self.rel2idx)
|
|
|
|
class PhoBERTDependencyDataset(Dataset):
    """Torch dataset mapping parsed sentences to PhoBERT subword inputs.

    Each item holds the subword ids of one sentence (wrapped in CLS/SEP),
    the subword offset where each word begins (for word-level pooling),
    and the gold heads and relation ids aligned to the surviving words.
    """

    def __init__(
        self,
        sentences: List[Sentence],
        vocab: Vocabulary,
        tokenizer,
        max_length: int = 256,
    ):
        self.sentences = sentences
        self.vocab = vocab
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        sent = self.sentences[idx]
        tok = self.tokenizer

        # Tokenize word by word, recording each word's first-subword offset
        # so the encoder output can be pooled back to word level.
        ids = [tok.cls_token_id]
        starts = []
        for word in sent.words:
            starts.append(len(ids))
            pieces = tok.encode(word, add_special_tokens=False)
            ids.extend(pieces if pieces else [tok.unk_token_id])
        ids.append(tok.sep_token_id)

        heads, rels = sent.heads, sent.rels
        if len(ids) > self.max_length:
            # Truncate the subword sequence (last slot reserved for SEP) and
            # keep only the words whose first subword survived the cut.
            ids = ids[:self.max_length - 1] + [tok.sep_token_id]
            cutoff = self.max_length - 1
            keep = sum(1 for s in starts if s < cutoff)
            starts = starts[:keep]
            heads = heads[:keep]
            rels = rels[:keep]

        # Unknown relation labels fall back to id 0.
        rel_ids = [self.vocab.rel2idx.get(r, 0) for r in rels]

        return {
            'input_ids': ids,
            'word_starts': starts,
            'heads': heads,
            'rels': rel_ids,
        }
|
|
|
|
def collate_fn(batch):
    """Pad a list of dataset items into dense batch tensors.

    Subword ids are right-padded with 0; ``attention_mask`` marks real
    subwords and ``word_mask`` marks real words. Padded head/rel slots
    hold 0 and must be excluded downstream via ``word_mask``.
    """
    rows = len(batch)
    sub_len = max(len(ex['input_ids']) for ex in batch)
    n_words = max(len(ex['word_starts']) for ex in batch)

    # Allocate zero-filled tensors at the batch maxima.
    out = {
        'input_ids': torch.zeros(rows, sub_len, dtype=torch.long),
        'attention_mask': torch.zeros(rows, sub_len, dtype=torch.long),
        'word_starts': torch.zeros(rows, n_words, dtype=torch.long),
        'word_mask': torch.zeros(rows, n_words, dtype=torch.bool),
        'heads': torch.zeros(rows, n_words, dtype=torch.long),
        'rels': torch.zeros(rows, n_words, dtype=torch.long),
    }

    for i, ex in enumerate(batch):
        # Subword-level fields.
        n_sub = len(ex['input_ids'])
        out['input_ids'][i, :n_sub] = torch.tensor(ex['input_ids'])
        out['attention_mask'][i, :n_sub] = 1

        # Word-level fields.
        n_w = len(ex['word_starts'])
        out['word_starts'][i, :n_w] = torch.tensor(ex['word_starts'])
        out['word_mask'][i, :n_w] = True
        out['heads'][i, :n_w] = torch.tensor(ex['heads'])
        out['rels'][i, :n_w] = torch.tensor(ex['rels'])

    return out
|
|
|
|
| |
| |
| |
|
|
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """Linear warmup from 0 to the base LR, then linear decay to 0.

    Args:
        optimizer: Optimizer whose learning rates are scheduled.
        num_warmup_steps: Steps over which the LR ramps up linearly.
        num_training_steps: Total scheduler steps; the LR reaches 0 here.

    Returns:
        A ``LambdaLR`` scheduler wrapping ``optimizer``.
    """
    # Hoist the (guarded-against-zero) denominators out of the lambda.
    warmup_denom = max(1, num_warmup_steps)
    decay_denom = max(1, num_training_steps - num_warmup_steps)

    def lr_lambda(step):
        if step < num_warmup_steps:
            # Ramp: fraction of warmup completed.
            return float(step) / float(warmup_denom)
        # Decay: remaining fraction of training, clamped at zero.
        return max(0.0, float(num_training_steps - step) / float(decay_denom))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
|
|
|
def evaluate(model, dataloader, device):
    """Compute UAS and LAS (in percent) of ``model`` over ``dataloader``.

    UAS counts words whose predicted head matches gold; LAS additionally
    requires the predicted relation to match. Padded positions are
    excluded via ``word_mask``. Returns ``(0, 0)`` on an empty loader.
    """
    model.eval()

    n_tokens = 0
    n_head_hits = 0
    n_labeled_hits = 0

    with torch.no_grad():
        for batch in dataloader:
            # Move every field of the batch to the target device.
            t = {key: batch[key].to(device)
                 for key in ('input_ids', 'attention_mask', 'word_starts',
                             'word_mask', 'heads', 'rels')}

            arc_scores, rel_scores = model(
                t['input_ids'], t['attention_mask'],
                t['word_starts'], t['word_mask'],
            )
            arc_preds, rel_preds = model.decode(
                arc_scores, rel_scores, t['word_mask']
            )

            # Head correct; relation correct only counts where the head is too.
            head_hits = (arc_preds == t['heads']) & t['word_mask']
            labeled_hits = (rel_preds == t['rels']) & t['word_mask'] & head_hits

            n_tokens += t['word_mask'].sum().item()
            n_head_hits += head_hits.sum().item()
            n_labeled_hits += labeled_hits.sum().item()

    if n_tokens > 0:
        return n_head_hits / n_tokens * 100, n_labeled_hits / n_tokens * 100
    return 0, 0
|
|
|
|
@click.command()
@click.option('--output', '-o', default='models/bamboo-1-phobert', help='Output directory')
@click.option('--encoder', default='vinai/phobert-base', help='PhoBERT encoder model')
@click.option('--dataset', type=click.Choice(['udd1', 'ud-vtb']), default='udd1',
              help='Dataset to use: udd1 (UDD-1) or ud-vtb (UD Vietnamese VTB)')
@click.option('--data-dir', default=None, help='Directory for dataset cache (default: ./data/<dataset>)')
@click.option('--epochs', default=100, type=int, help='Number of epochs')
@click.option('--batch-size', default=32, type=int, help='Batch size')
@click.option('--bert-lr', default=1e-5, type=float, help='Learning rate for BERT layers')
@click.option('--head-lr', default=1e-3, type=float, help='Learning rate for parser head')
@click.option('--warmup-steps', default=1000, type=int, help='Warmup steps')
@click.option('--weight-decay', default=0.01, type=float, help='Weight decay')
@click.option('--max-grad-norm', default=5.0, type=float, help='Max gradient norm for clipping')
@click.option('--arc-hidden', default=500, type=int, help='Arc MLP hidden size')
@click.option('--rel-hidden', default=100, type=int, help='Relation MLP hidden size')
@click.option('--dropout', default=0.33, type=float, help='Dropout rate')
@click.option('--patience', default=10, type=int, help='Early stopping patience')
@click.option('--use-mst/--no-mst', default=True, help='Use MST decoding')
@click.option('--force-download', is_flag=True, help='Force re-download dataset')
@click.option('--gpu-type', default='RTX_A4000', help='GPU type for cost estimation')
@click.option('--cost-interval', default=300, type=int, help='Cost report interval in seconds')
@click.option('--wandb', 'use_wandb', is_flag=True, help='Enable W&B logging')
@click.option('--wandb-project', default='bamboo-1-phobert', help='W&B project name')
@click.option('--max-time', default=0, type=int, help='Max training time in minutes (0=unlimited)')
@click.option('--sample', default=0, type=int, help='Sample N sentences from each split (0=all)')
@click.option('--freeze-bert', default=0, type=int, help='Freeze BERT for first N epochs')
@click.option('--fp16/--no-fp16', default=True, help='Use mixed precision training (FP16)')
@click.option('--num-workers', default=4, type=int, help='DataLoader workers')
@click.option('--grad-accum', default=1, type=int, help='Gradient accumulation steps')
@click.option('--compile/--no-compile', default=False, help='Use torch.compile for faster training')
@click.option('--eval-every', default=1, type=int, help='Evaluate every N epochs (default: 1)')
def train(
    output, encoder, dataset, data_dir, epochs, batch_size, bert_lr, head_lr, warmup_steps,
    weight_decay, max_grad_norm, arc_hidden, rel_hidden, dropout, patience,
    use_mst, force_download, gpu_type, cost_interval, use_wandb, wandb_project,
    max_time, sample, freeze_bert, fp16, num_workers, grad_accum, compile, eval_every
):
    """Train the PhoBERT-based Vietnamese dependency parser.

    Pipeline: load corpus -> build relation vocabulary -> fine-tune PhoBERT
    with a biaffine parsing head -> checkpoint the best dev-LAS model ->
    report final test UAS/LAS and accumulated training cost.
    """
    # --- Hardware / cost setup -------------------------------------------
    hardware = detect_hardware()
    detected_gpu_type = hardware.get_gpu_type()

    # Replace the placeholder default with the detected GPU so the cost
    # estimate matches the hardware actually in use.
    if gpu_type == "RTX_A4000":
        gpu_type = detected_gpu_type

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    click.echo(f"Using device: {device}")
    click.echo(f"Hardware: {hardware}")

    # cudnn autotuning + TF32 matmuls: free speedups on Ampere+ GPUs.
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # Mixed precision only applies on CUDA; the GradScaler guards against
    # FP16 gradient underflow.
    use_amp = fp16 and torch.cuda.is_available()
    scaler = torch.amp.GradScaler('cuda') if use_amp else None
    if use_amp:
        click.echo(f"Mixed precision (FP16): enabled")

    # --- Optional experiment tracking ------------------------------------
    if use_wandb:
        import wandb
        import os
        os.environ["WANDB__REQUIRE_LEGACY_SERVICE"] = "true"
        wandb.init(
            project=wandb_project,
            config={
                "encoder": encoder,
                "dataset": dataset,
                "data_dir": data_dir,
                "epochs": epochs,
                "batch_size": batch_size,
                "bert_lr": bert_lr,
                "head_lr": head_lr,
                "warmup_steps": warmup_steps,
                "weight_decay": weight_decay,
                "arc_hidden": arc_hidden,
                "rel_hidden": rel_hidden,
                "dropout": dropout,
                "patience": patience,
                "use_mst": use_mst,
                "gpu_type": gpu_type,
                "hardware": hardware.to_dict(),
                "eval_every": eval_every,
                "compile": compile,
            },
        )
        click.echo(f"W&B logging enabled: {wandb.run.url}")

    click.echo("=" * 60)
    click.echo("Bamboo-1: PhoBERT Vietnamese Dependency Parser")
    click.echo("=" * 60)

    # --- Data --------------------------------------------------------------
    click.echo(f"\nLoading {dataset.upper()} corpus...")
    if dataset == 'udd1':
        corpus = UDD1Corpus(data_dir=data_dir, force_download=force_download)
    else:
        corpus = UDVietnameseVTB(data_dir=data_dir, force_download=force_download)

    train_sents = read_conllu(corpus.train)
    dev_sents = read_conllu(corpus.dev)
    test_sents = read_conllu(corpus.test)

    # Optional subsampling for quick smoke-test runs.
    if sample > 0:
        train_sents = train_sents[:sample]
        dev_sents = dev_sents[:min(sample // 2, len(dev_sents))]
        test_sents = test_sents[:min(sample // 2, len(test_sents))]
        click.echo(f" Sampling {sample} sentences...")

    click.echo(f" Train: {len(train_sents)} sentences")
    click.echo(f" Dev: {len(dev_sents)} sentences")
    click.echo(f" Test: {len(test_sents)} sentences")

    # Relation vocabulary is built from the training split only.
    click.echo("\nBuilding vocabulary...")
    vocab = Vocabulary()
    vocab.build(train_sents)
    click.echo(f" Relations: {vocab.n_rels}")

    # --- Model -------------------------------------------------------------
    click.echo(f"\nInitializing model with {encoder}...")
    model = PhoBERTDependencyParser(
        encoder_name=encoder,
        n_rels=vocab.n_rels,
        arc_hidden=arc_hidden,
        rel_hidden=rel_hidden,
        dropout=dropout,
        use_mst=use_mst,
    ).to(device)

    # Attach the label maps so they are serialized alongside the weights.
    model.rel2idx = vocab.rel2idx
    model.idx2rel = vocab.idx2rel

    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    n_bert_params = sum(p.numel() for p in model.encoder.parameters() if p.requires_grad)
    n_head_params = n_params - n_bert_params
    click.echo(f" Total parameters: {n_params:,}")
    click.echo(f" BERT parameters: {n_bert_params:,}")
    click.echo(f" Head parameters: {n_head_params:,}")

    # NOTE: `compile` shadows the builtin, but the name is fixed by the CLI
    # option; torch.compile's wrapper still proxies attribute access
    # (tokenizer/encoder) to the underlying module.
    if compile:
        click.echo(" Compiling model with torch.compile...")
        model = torch.compile(model, mode="reduce-overhead")

    train_dataset = PhoBERTDependencyDataset(train_sents, vocab, model.tokenizer)
    dev_dataset = PhoBERTDependencyDataset(dev_sents, vocab, model.tokenizer)
    test_dataset = PhoBERTDependencyDataset(test_sents, vocab, model.tokenizer)

    loader_kwargs = {
        'collate_fn': collate_fn,
        'num_workers': num_workers,
        'pin_memory': torch.cuda.is_available(),
        'persistent_workers': num_workers > 0,
        'prefetch_factor': 4 if num_workers > 0 else None,
    }
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, **loader_kwargs
    )
    dev_loader = DataLoader(
        dev_dataset, batch_size=batch_size, **loader_kwargs
    )
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, **loader_kwargs
    )

    effective_batch_size = batch_size * grad_accum
    if grad_accum > 1:
        click.echo(f" Effective batch size: {effective_batch_size} (batch={batch_size} x accum={grad_accum})")

    # --- Optimizer: discriminative LRs; no weight decay on bias/LayerNorm ---
    no_decay = ['bias', 'LayerNorm.weight', 'LayerNorm.bias']
    optimizer_grouped_parameters = [
        # BERT weights, decayed.
        {
            'params': [p for n, p in model.encoder.named_parameters()
                       if not any(nd in n for nd in no_decay)],
            'lr': bert_lr,
            'weight_decay': weight_decay,
        },
        # BERT bias/LayerNorm, no decay.
        {
            'params': [p for n, p in model.encoder.named_parameters()
                       if any(nd in n for nd in no_decay)],
            'lr': bert_lr,
            'weight_decay': 0.0,
        },
        # Parser head with its own (higher) learning rate.
        {
            'params': [p for n, p in model.named_parameters()
                       if not n.startswith('encoder.')],
            'lr': head_lr,
            'weight_decay': weight_decay,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters)

    # BUGFIX: the scheduler is stepped once per *optimizer* step (i.e. per
    # accumulation boundary), not once per batch. The previous horizon of
    # `len(train_loader) * epochs` overestimated total steps by a factor of
    # `grad_accum`, so with accumulation enabled the LR never fully decayed.
    steps_per_epoch = (len(train_loader) + grad_accum - 1) // grad_accum
    total_steps = steps_per_epoch * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps)

    # --- Training loop ------------------------------------------------------
    click.echo(f"\nTraining for {epochs} epochs...")
    if freeze_bert > 0:
        click.echo(f" Freezing BERT for first {freeze_bert} epochs")
    if max_time > 0:
        click.echo(f" Time limit: {max_time} minutes")

    output_path = Path(output)
    output_path.mkdir(parents=True, exist_ok=True)

    cost_tracker = CostTracker(gpu_type=gpu_type)
    cost_tracker.report_interval = cost_interval
    cost_tracker.start()
    click.echo(f"Cost tracking: {gpu_type} @ ${cost_tracker.hourly_rate}/hr")

    best_las = -1
    no_improve = 0
    time_limit_seconds = max_time * 60 if max_time > 0 else float('inf')

    for epoch in range(1, epochs + 1):
        # Stop cleanly when the wall-clock budget is exhausted.
        if cost_tracker.elapsed_seconds() >= time_limit_seconds:
            click.echo(f"\nTime limit reached ({max_time} minutes)")
            break

        # Optionally train only the parser head for the first few epochs,
        # then unfreeze the encoder exactly once.
        if epoch <= freeze_bert:
            for param in model.encoder.parameters():
                param.requires_grad = False
        elif epoch == freeze_bert + 1:
            click.echo(" Unfreezing BERT parameters...")
            for param in model.encoder.parameters():
                param.requires_grad = True

        model.train()
        total_loss = 0
        optimizer.zero_grad()

        pbar = tqdm(train_loader, desc=f"Epoch {epoch:3d}", leave=False)
        for step, batch in enumerate(pbar):
            input_ids = batch['input_ids'].to(device, non_blocking=True)
            attention_mask = batch['attention_mask'].to(device, non_blocking=True)
            word_starts = batch['word_starts'].to(device, non_blocking=True)
            word_mask = batch['word_mask'].to(device, non_blocking=True)
            heads = batch['heads'].to(device, non_blocking=True)
            rels = batch['rels'].to(device, non_blocking=True)

            # Forward under autocast; the loss is scaled down so gradients
            # accumulated over `grad_accum` micro-batches average correctly.
            with torch.amp.autocast('cuda', enabled=use_amp):
                arc_scores, rel_scores = model(
                    input_ids, attention_mask, word_starts, word_mask
                )
                loss = model.loss(arc_scores, rel_scores, heads, rels, word_mask)
                loss = loss / grad_accum

            if use_amp:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            # Optimizer step on accumulation boundaries (and at epoch end).
            if (step + 1) % grad_accum == 0 or (step + 1) == len(train_loader):
                if use_amp:
                    # Unscale before clipping so the norm is in true units.
                    scaler.unscale_(optimizer)
                    nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                    optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            # Undo the accumulation scaling for reporting.
            total_loss += loss.item() * grad_accum
            pbar.set_postfix({'loss': f'{loss.item() * grad_accum:.4f}'})

        avg_loss = total_loss / len(train_loader)
        current_lr = scheduler.get_last_lr()[0]

        if epoch % eval_every == 0 or epoch == epochs:
            dev_uas, dev_las = evaluate(model, dev_loader, device)

            progress = epoch / epochs
            current_cost = cost_tracker.current_cost()
            estimated_total_cost = cost_tracker.estimate_total_cost(progress)
            elapsed_minutes = cost_tracker.elapsed_seconds() / 60

            cost_status = cost_tracker.update(epoch, epochs)
            if cost_status:
                click.echo(f" [{cost_status}]")

            click.echo(f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | "
                       f"Dev UAS: {dev_uas:.2f}% | Dev LAS: {dev_las:.2f}% | "
                       f"LR: {current_lr:.2e}")

            if use_wandb:
                wandb.log({
                    "epoch": epoch,
                    "train/loss": avg_loss,
                    "dev/uas": dev_uas,
                    "dev/las": dev_las,
                    "lr": current_lr,
                    "cost/current_usd": current_cost,
                    "cost/estimated_total_usd": estimated_total_cost,
                    "cost/elapsed_minutes": elapsed_minutes,
                })

            # Checkpoint on tied-or-improved dev LAS; otherwise count toward
            # early stopping.
            if dev_las >= best_las:
                best_las = dev_las
                no_improve = 0
                model.save(
                    str(output_path),
                    vocab={'rel2idx': vocab.rel2idx, 'idx2rel': vocab.idx2rel}
                )
                click.echo(f" -> Saved best model (LAS: {best_las:.2f}%)")
            else:
                no_improve += 1
                if no_improve >= patience:
                    click.echo(f"\nEarly stopping after {patience} epochs without improvement")
                    break
        else:
            # Non-eval epoch: report training loss only.
            click.echo(f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | LR: {current_lr:.2e}")
            if use_wandb:
                wandb.log({
                    "epoch": epoch,
                    "train/loss": avg_loss,
                    "lr": current_lr,
                })

    # --- Final evaluation with the best checkpoint --------------------------
    click.echo("\nLoading best model for final evaluation...")
    model = PhoBERTDependencyParser.load(str(output_path), device=str(device))

    test_uas, test_las = evaluate(model, test_loader, device)
    click.echo(f"\nTest Results:")
    click.echo(f" UAS: {test_uas:.2f}%")
    click.echo(f" LAS: {test_las:.2f}%")

    click.echo(f"\nModel saved to: {output_path}")

    final_cost = cost_tracker.current_cost()
    click.echo(f"\n{cost_tracker.summary(epoch, epochs)}")

    if use_wandb:
        wandb.log({
            "test/uas": test_uas,
            "test/las": test_las,
            "cost/final_usd": final_cost,
        })
        wandb.finish()
|
|
|
|
# CLI entry point: Click parses arguments and invokes the command.
if __name__ == '__main__':
    train()
|
|