# /// script
# requires-python = ">=3.10"
# dependencies = [
# "torch>=2.0.0",
# "transformers>=4.30.0",
# "datasets>=2.14.0",
# "click>=8.0.0",
# "tqdm>=4.60.0",
# "wandb>=0.15.0",
# "python-dotenv>=1.0.0",
# ]
# ///
"""
Training script for PhoBERT-based Vietnamese Dependency Parser.
This script trains a transformer-based dependency parser using PhoBERT as the
encoder, following the Trankit approach for Vietnamese dependency parsing.
Architecture:
PhoBERT -> Word-level pooling -> Biaffine attention -> MST decoding
Usage:
    uv run src/train_phobert.py
    uv run src/train_phobert.py --output models/bamboo-1-phobert --epochs 100
    uv run src/train_phobert.py --encoder vinai/phobert-large
    uv run src/train_phobert.py --dataset ud-vtb   # Use UD Vietnamese VTB
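
After training, the best checkpoint can be reloaded for evaluation or inference.
A minimal sketch, using the same load() call as the final-evaluation step below:

    from bamboo1.models.transformer_parser import PhoBERTDependencyParser
    parser = PhoBERTDependencyParser.load("models/bamboo-1-phobert", device="cuda")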
"""
import sys
from pathlib import Path
from collections import Counter
from dataclasses import dataclass
from typing import List
# Load environment variables from .env file
from dotenv import load_dotenv
load_dotenv()
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from tqdm import tqdm
import click
sys.path.insert(0, str(Path(__file__).parent.parent))
from bamboo1.corpus import UDD1Corpus
from bamboo1.ud_corpus import UDVietnameseVTB
from bamboo1.models.transformer_parser import PhoBERTDependencyParser
from scripts.cost_estimate import CostTracker, detect_hardware
# ============================================================================
# Data Processing
# ============================================================================
@dataclass
class Sentence:
"""A dependency-parsed sentence."""
words: List[str]
heads: List[int]
rels: List[str]
def read_conllu(path: str) -> List[Sentence]:
"""Read CoNLL-U file and return list of sentences."""
sentences = []
words, heads, rels = [], [], []
with open(path, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if not line:
if words:
sentences.append(Sentence(words, heads, rels))
words, heads, rels = [], [], []
elif line.startswith('#'):
continue
else:
parts = line.split('\t')
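                # Skip multiword token ranges (IDs like "1-2") and empty nodes (IDs like "1.1")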
if '-' in parts[0] or '.' in parts[0]:
continue
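                # CoNLL-U columns used here: 1 = FORM (word), 6 = HEAD, 7 = DEPREL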
words.append(parts[1])
heads.append(int(parts[6]))
rels.append(parts[7])
if words:
sentences.append(Sentence(words, heads, rels))
return sentences
class Vocabulary:
"""Vocabulary for relations."""
def __init__(self):
self.rel2idx = {}
self.idx2rel = {}
def build(self, sentences: List[Sentence]):
"""Build vocabulary from sentences."""
rel_counts = Counter()
for sent in sentences:
for rel in sent.rels:
rel_counts[rel] += 1
for rel in sorted(rel_counts.keys()):
if rel not in self.rel2idx:
idx = len(self.rel2idx)
self.rel2idx[rel] = idx
self.idx2rel[idx] = rel
@property
def n_rels(self) -> int:
return len(self.rel2idx)
class PhoBERTDependencyDataset(Dataset):
"""Dataset for PhoBERT dependency parsing."""
def __init__(
self,
sentences: List[Sentence],
vocab: Vocabulary,
tokenizer,
max_length: int = 256,
):
self.sentences = sentences
self.vocab = vocab
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self):
return len(self.sentences)
def __getitem__(self, idx):
sent = self.sentences[idx]
# Tokenize with word boundary tracking
word_starts = []
subword_ids = [self.tokenizer.cls_token_id]
for word in sent.words:
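            # Record the index of this word's first subword so subword embeddings can be pooled back to word level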
word_starts.append(len(subword_ids))
word_tokens = self.tokenizer.encode(word, add_special_tokens=False)
if not word_tokens:
word_tokens = [self.tokenizer.unk_token_id]
subword_ids.extend(word_tokens)
subword_ids.append(self.tokenizer.sep_token_id)
# Truncate if needed
if len(subword_ids) > self.max_length:
subword_ids = subword_ids[:self.max_length-1] + [self.tokenizer.sep_token_id]
# Keep words that fit
valid_words = sum(1 for ws in word_starts if ws < self.max_length - 1)
word_starts = word_starts[:valid_words]
heads = sent.heads[:valid_words]
rels = sent.rels[:valid_words]
else:
heads = sent.heads
rels = sent.rels
# Encode relations
rel_ids = [self.vocab.rel2idx.get(r, 0) for r in rels]
return {
'input_ids': subword_ids,
'word_starts': word_starts,
'heads': heads,
'rels': rel_ids,
}
def collate_fn(batch):
"""Collate function for DataLoader."""
# Get max lengths
max_subword_len = max(len(item['input_ids']) for item in batch)
max_word_len = max(len(item['word_starts']) for item in batch)
batch_size = len(batch)
# Initialize tensors
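    # All tensors are zero-padded to the batch max; attention_mask/word_mask flag the real (non-pad) positions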
input_ids = torch.zeros(batch_size, max_subword_len, dtype=torch.long)
attention_mask = torch.zeros(batch_size, max_subword_len, dtype=torch.long)
word_starts = torch.zeros(batch_size, max_word_len, dtype=torch.long)
word_mask = torch.zeros(batch_size, max_word_len, dtype=torch.bool)
heads = torch.zeros(batch_size, max_word_len, dtype=torch.long)
rels = torch.zeros(batch_size, max_word_len, dtype=torch.long)
for i, item in enumerate(batch):
# Subwords
seq_len = len(item['input_ids'])
input_ids[i, :seq_len] = torch.tensor(item['input_ids'])
attention_mask[i, :seq_len] = 1
# Words
word_len = len(item['word_starts'])
word_starts[i, :word_len] = torch.tensor(item['word_starts'])
word_mask[i, :word_len] = True
heads[i, :word_len] = torch.tensor(item['heads'])
rels[i, :word_len] = torch.tensor(item['rels'])
return {
'input_ids': input_ids,
'attention_mask': attention_mask,
'word_starts': word_starts,
'word_mask': word_mask,
'heads': heads,
'rels': rels,
}
# ============================================================================
# Training
# ============================================================================
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
"""Create scheduler with linear warmup and linear decay."""
def lr_lambda(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
return max(
0.0,
float(num_training_steps - current_step) /
float(max(1, num_training_steps - num_warmup_steps))
)
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
def evaluate(model, dataloader, device):
"""Evaluate model and return UAS/LAS."""
model.eval()
total_arcs = 0
correct_arcs = 0
correct_rels = 0
with torch.no_grad():
for batch in dataloader:
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
word_starts = batch['word_starts'].to(device)
word_mask = batch['word_mask'].to(device)
heads = batch['heads'].to(device)
rels = batch['rels'].to(device)
arc_scores, rel_scores = model(
input_ids, attention_mask, word_starts, word_mask
)
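            # decode() returns predicted head indices and relation ids per word (MST decoding if enabled)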
arc_preds, rel_preds = model.decode(arc_scores, rel_scores, word_mask)
# Count correct
arc_correct = (arc_preds == heads) & word_mask
rel_correct = (rel_preds == rels) & word_mask & arc_correct
total_arcs += word_mask.sum().item()
correct_arcs += arc_correct.sum().item()
correct_rels += rel_correct.sum().item()
uas = correct_arcs / total_arcs * 100 if total_arcs > 0 else 0
las = correct_rels / total_arcs * 100 if total_arcs > 0 else 0
return uas, las
@click.command()
@click.option('--output', '-o', default='models/bamboo-1-phobert', help='Output directory')
@click.option('--encoder', default='vinai/phobert-base', help='PhoBERT encoder model')
@click.option('--dataset', type=click.Choice(['udd1', 'ud-vtb']), default='udd1',
help='Dataset to use: udd1 (UDD-1) or ud-vtb (UD Vietnamese VTB)')
@click.option('--data-dir', default=None, help='Directory for dataset cache (default: ./data/<dataset>)')
@click.option('--epochs', default=100, type=int, help='Number of epochs')
@click.option('--batch-size', default=32, type=int, help='Batch size')
@click.option('--bert-lr', default=1e-5, type=float, help='Learning rate for BERT layers')
@click.option('--head-lr', default=1e-3, type=float, help='Learning rate for parser head')
@click.option('--warmup-steps', default=1000, type=int, help='Warmup steps')
@click.option('--weight-decay', default=0.01, type=float, help='Weight decay')
@click.option('--max-grad-norm', default=5.0, type=float, help='Max gradient norm for clipping')
@click.option('--arc-hidden', default=500, type=int, help='Arc MLP hidden size')
@click.option('--rel-hidden', default=100, type=int, help='Relation MLP hidden size')
@click.option('--dropout', default=0.33, type=float, help='Dropout rate')
@click.option('--patience', default=10, type=int, help='Early stopping patience')
@click.option('--use-mst/--no-mst', default=True, help='Use MST decoding')
@click.option('--force-download', is_flag=True, help='Force re-download dataset')
@click.option('--gpu-type', default='RTX_A4000', help='GPU type for cost estimation')
@click.option('--cost-interval', default=300, type=int, help='Cost report interval in seconds')
@click.option('--wandb', 'use_wandb', is_flag=True, help='Enable W&B logging')
@click.option('--wandb-project', default='bamboo-1-phobert', help='W&B project name')
@click.option('--max-time', default=0, type=int, help='Max training time in minutes (0=unlimited)')
@click.option('--sample', default=0, type=int, help='Sample N sentences from each split (0=all)')
@click.option('--freeze-bert', default=0, type=int, help='Freeze BERT for first N epochs')
@click.option('--fp16/--no-fp16', default=True, help='Use mixed precision training (FP16)')
@click.option('--num-workers', default=4, type=int, help='DataLoader workers')
@click.option('--grad-accum', default=1, type=int, help='Gradient accumulation steps')
@click.option('--compile/--no-compile', default=False, help='Use torch.compile for faster training')
@click.option('--eval-every', default=1, type=int, help='Evaluate every N epochs (default: 1)')
def train(
output, encoder, dataset, data_dir, epochs, batch_size, bert_lr, head_lr, warmup_steps,
weight_decay, max_grad_norm, arc_hidden, rel_hidden, dropout, patience,
use_mst, force_download, gpu_type, cost_interval, use_wandb, wandb_project,
max_time, sample, freeze_bert, fp16, num_workers, grad_accum, compile, eval_every
):
"""Train PhoBERT-based Vietnamese Dependency Parser."""
# Detect hardware
hardware = detect_hardware()
detected_gpu_type = hardware.get_gpu_type()
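    # If the GPU type was left at its default, prefer the auto-detected type for cost estimation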
if gpu_type == "RTX_A4000":
gpu_type = detected_gpu_type
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
click.echo(f"Using device: {device}")
click.echo(f"Hardware: {hardware}")
# Performance optimizations
if torch.cuda.is_available():
torch.backends.cudnn.benchmark = True # Auto-tune convolutions
torch.backends.cuda.matmul.allow_tf32 = True # Allow TF32 for faster matmul
torch.backends.cudnn.allow_tf32 = True
# Mixed precision training
use_amp = fp16 and torch.cuda.is_available()
scaler = torch.amp.GradScaler('cuda') if use_amp else None
if use_amp:
click.echo(f"Mixed precision (FP16): enabled")
# Initialize wandb with system metrics
if use_wandb:
import wandb
import os
os.environ["WANDB__REQUIRE_LEGACY_SERVICE"] = "true" # Enable legacy service for system metrics
wandb.init(
project=wandb_project,
config={
"encoder": encoder,
"dataset": dataset,
"data_dir": data_dir,
"epochs": epochs,
"batch_size": batch_size,
"bert_lr": bert_lr,
"head_lr": head_lr,
"warmup_steps": warmup_steps,
"weight_decay": weight_decay,
"arc_hidden": arc_hidden,
"rel_hidden": rel_hidden,
"dropout": dropout,
"patience": patience,
"use_mst": use_mst,
"gpu_type": gpu_type,
"hardware": hardware.to_dict(),
"eval_every": eval_every,
"compile": compile,
},
)
click.echo(f"W&B logging enabled: {wandb.run.url}")
click.echo("=" * 60)
click.echo("Bamboo-1: PhoBERT Vietnamese Dependency Parser")
click.echo("=" * 60)
# Load corpus
click.echo(f"\nLoading {dataset.upper()} corpus...")
if dataset == 'udd1':
corpus = UDD1Corpus(data_dir=data_dir, force_download=force_download)
else:
corpus = UDVietnameseVTB(data_dir=data_dir, force_download=force_download)
train_sents = read_conllu(corpus.train)
dev_sents = read_conllu(corpus.dev)
test_sents = read_conllu(corpus.test)
if sample > 0:
train_sents = train_sents[:sample]
dev_sents = dev_sents[:min(sample // 2, len(dev_sents))]
test_sents = test_sents[:min(sample // 2, len(test_sents))]
click.echo(f" Sampling {sample} sentences...")
click.echo(f" Train: {len(train_sents)} sentences")
click.echo(f" Dev: {len(dev_sents)} sentences")
click.echo(f" Test: {len(test_sents)} sentences")
# Build vocabulary
click.echo("\nBuilding vocabulary...")
vocab = Vocabulary()
vocab.build(train_sents)
click.echo(f" Relations: {vocab.n_rels}")
# Create model
click.echo(f"\nInitializing model with {encoder}...")
model = PhoBERTDependencyParser(
encoder_name=encoder,
n_rels=vocab.n_rels,
arc_hidden=arc_hidden,
rel_hidden=rel_hidden,
dropout=dropout,
use_mst=use_mst,
).to(device)
# Set relation mappings
model.rel2idx = vocab.rel2idx
model.idx2rel = vocab.idx2rel
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_bert_params = sum(p.numel() for p in model.encoder.parameters() if p.requires_grad)
n_head_params = n_params - n_bert_params
click.echo(f" Total parameters: {n_params:,}")
click.echo(f" BERT parameters: {n_bert_params:,}")
click.echo(f" Head parameters: {n_head_params:,}")
# torch.compile optimization (PyTorch 2.0+)
if compile:
click.echo(" Compiling model with torch.compile...")
model = torch.compile(model, mode="reduce-overhead")
# Create datasets
train_dataset = PhoBERTDependencyDataset(train_sents, vocab, model.tokenizer)
dev_dataset = PhoBERTDependencyDataset(dev_sents, vocab, model.tokenizer)
test_dataset = PhoBERTDependencyDataset(test_sents, vocab, model.tokenizer)
# DataLoader with optimizations
loader_kwargs = {
'collate_fn': collate_fn,
'num_workers': num_workers,
'pin_memory': torch.cuda.is_available(),
'persistent_workers': num_workers > 0,
'prefetch_factor': 4 if num_workers > 0 else None, # Prefetch batches for better overlap
}
train_loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True, **loader_kwargs
)
dev_loader = DataLoader(
dev_dataset, batch_size=batch_size, **loader_kwargs
)
test_loader = DataLoader(
test_dataset, batch_size=batch_size, **loader_kwargs
)
# Effective batch size with gradient accumulation
effective_batch_size = batch_size * grad_accum
if grad_accum > 1:
click.echo(f" Effective batch size: {effective_batch_size} (batch={batch_size} x accum={grad_accum})")
# Optimizer with differential learning rates
no_decay = ['bias', 'LayerNorm.weight', 'LayerNorm.bias']
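    # Bias and LayerNorm parameters (no_decay) are excluded from weight decay, as is standard for BERT fine-tuning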
optimizer_grouped_parameters = [
# BERT parameters with weight decay
{
'params': [p for n, p in model.encoder.named_parameters()
if not any(nd in n for nd in no_decay)],
'lr': bert_lr,
'weight_decay': weight_decay,
},
# BERT parameters without weight decay
{
'params': [p for n, p in model.encoder.named_parameters()
if any(nd in n for nd in no_decay)],
'lr': bert_lr,
'weight_decay': 0.0,
},
# Parser head parameters
{
'params': [p for n, p in model.named_parameters()
if not n.startswith('encoder.')],
'lr': head_lr,
'weight_decay': weight_decay,
},
]
optimizer = AdamW(optimizer_grouped_parameters)
    # Learning rate scheduler with warmup (step count must match optimizer steps when using gradient accumulation)
    steps_per_epoch = (len(train_loader) + grad_accum - 1) // grad_accum
    total_steps = steps_per_epoch * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps)
# Training
click.echo(f"\nTraining for {epochs} epochs...")
if freeze_bert > 0:
click.echo(f" Freezing BERT for first {freeze_bert} epochs")
if max_time > 0:
click.echo(f" Time limit: {max_time} minutes")
output_path = Path(output)
output_path.mkdir(parents=True, exist_ok=True)
# Cost tracking
cost_tracker = CostTracker(gpu_type=gpu_type)
cost_tracker.report_interval = cost_interval
cost_tracker.start()
click.echo(f"Cost tracking: {gpu_type} @ ${cost_tracker.hourly_rate}/hr")
best_las = -1
no_improve = 0
time_limit_seconds = max_time * 60 if max_time > 0 else float('inf')
for epoch in range(1, epochs + 1):
# Check time limit
if cost_tracker.elapsed_seconds() >= time_limit_seconds:
click.echo(f"\nTime limit reached ({max_time} minutes)")
break
# Freeze/unfreeze BERT
if epoch <= freeze_bert:
for param in model.encoder.parameters():
param.requires_grad = False
elif epoch == freeze_bert + 1:
click.echo(" Unfreezing BERT parameters...")
for param in model.encoder.parameters():
param.requires_grad = True
model.train()
total_loss = 0
optimizer.zero_grad()
pbar = tqdm(train_loader, desc=f"Epoch {epoch:3d}", leave=False)
for step, batch in enumerate(pbar):
input_ids = batch['input_ids'].to(device, non_blocking=True)
attention_mask = batch['attention_mask'].to(device, non_blocking=True)
word_starts = batch['word_starts'].to(device, non_blocking=True)
word_mask = batch['word_mask'].to(device, non_blocking=True)
heads = batch['heads'].to(device, non_blocking=True)
rels = batch['rels'].to(device, non_blocking=True)
# Mixed precision forward pass
with torch.amp.autocast('cuda', enabled=use_amp):
arc_scores, rel_scores = model(
input_ids, attention_mask, word_starts, word_mask
)
loss = model.loss(arc_scores, rel_scores, heads, rels, word_mask)
loss = loss / grad_accum # Scale for gradient accumulation
# Backward pass with gradient scaling
if use_amp:
scaler.scale(loss).backward()
else:
loss.backward()
# Optimizer step (every grad_accum steps)
if (step + 1) % grad_accum == 0 or (step + 1) == len(train_loader):
if use_amp:
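                    # Unscale gradients before clipping so the norm is computed on true (unscaled) values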
scaler.unscale_(optimizer)
nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
scaler.step(optimizer)
scaler.update()
else:
nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
total_loss += loss.item() * grad_accum
pbar.set_postfix({'loss': f'{loss.item() * grad_accum:.4f}'})
avg_loss = total_loss / len(train_loader)
current_lr = scheduler.get_last_lr()[0]
# Evaluate every N epochs or at last epoch
if epoch % eval_every == 0 or epoch == epochs:
dev_uas, dev_las = evaluate(model, dev_loader, device)
# Cost update
progress = epoch / epochs
current_cost = cost_tracker.current_cost()
estimated_total_cost = cost_tracker.estimate_total_cost(progress)
elapsed_minutes = cost_tracker.elapsed_seconds() / 60
cost_status = cost_tracker.update(epoch, epochs)
if cost_status:
click.echo(f" [{cost_status}]")
click.echo(f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | "
f"Dev UAS: {dev_uas:.2f}% | Dev LAS: {dev_las:.2f}% | "
f"LR: {current_lr:.2e}")
# Log to wandb
if use_wandb:
wandb.log({
"epoch": epoch,
"train/loss": avg_loss,
"dev/uas": dev_uas,
"dev/las": dev_las,
"lr": current_lr,
"cost/current_usd": current_cost,
"cost/estimated_total_usd": estimated_total_cost,
"cost/elapsed_minutes": elapsed_minutes,
})
# Save best model
if dev_las >= best_las:
best_las = dev_las
no_improve = 0
model.save(
str(output_path),
vocab={'rel2idx': vocab.rel2idx, 'idx2rel': vocab.idx2rel}
)
click.echo(f" -> Saved best model (LAS: {best_las:.2f}%)")
else:
no_improve += 1
if no_improve >= patience:
click.echo(f"\nEarly stopping after {patience} epochs without improvement")
break
else:
# Just log training loss without evaluation
click.echo(f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | LR: {current_lr:.2e}")
if use_wandb:
wandb.log({
"epoch": epoch,
"train/loss": avg_loss,
"lr": current_lr,
})
# Final evaluation
click.echo("\nLoading best model for final evaluation...")
model = PhoBERTDependencyParser.load(str(output_path), device=str(device))
test_uas, test_las = evaluate(model, test_loader, device)
click.echo(f"\nTest Results:")
click.echo(f" UAS: {test_uas:.2f}%")
click.echo(f" LAS: {test_las:.2f}%")
click.echo(f"\nModel saved to: {output_path}")
# Final cost summary
final_cost = cost_tracker.current_cost()
click.echo(f"\n{cost_tracker.summary(epoch, epochs)}")
# Log final metrics to wandb
if use_wandb:
wandb.log({
"test/uas": test_uas,
"test/las": test_las,
"cost/final_usd": final_cost,
})
wandb.finish()
if __name__ == '__main__':
train()