"""
Training script for Bamboo-1 Vietnamese Dependency Parser.

Biaffine parser implementation from scratch (Dozat & Manning, 2017).

Usage:
    uv run scripts/train.py
    uv run scripts/train.py --output models/bamboo-1 --epochs 100
"""

import sys
from collections import Counter
from dataclasses import dataclass
from pathlib import Path
from typing import List

import click
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

sys.path.insert(0, str(Path(__file__).parent.parent))
from bamboo1.corpus import UDD1Corpus
from scripts.cost_estimate import CostTracker, detect_hardware


@dataclass
class Sentence:
    """A dependency-parsed sentence."""
    words: List[str]
    heads: List[int]
    rels: List[str]


def read_conllu(path: str) -> List[Sentence]:
    """Read a CoNLL-U file and return a list of sentences."""
    sentences = []
    words, heads, rels = [], [], []

    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                if words:
                    sentences.append(Sentence(words, heads, rels))
                    words, heads, rels = [], [], []
            elif line.startswith('#'):
                continue
            else:
                parts = line.split('\t')
                # Skip multiword tokens (e.g. "1-2") and empty nodes (e.g. "1.1").
                if '-' in parts[0] or '.' in parts[0]:
                    continue
                words.append(parts[1])       # FORM
                heads.append(int(parts[6]))  # HEAD (0 = root)
                rels.append(parts[7])        # DEPREL

    if words:
        sentences.append(Sentence(words, heads, rels))

    return sentences
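

# For reference, a CoNLL-U token line has 10 tab-separated columns:
#   ID  FORM  LEMMA  UPOS  XPOS  FEATS  HEAD  DEPREL  DEPS  MISC
# read_conllu() keeps only FORM, HEAD (0 = root), and DEPREL.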


class Vocabulary:
    """Vocabulary for words, characters, and relations."""
    PAD = '<pad>'
    UNK = '<unk>'
    ROOT = '<root>'  # artificial root token, prepended to every sentence

    def __init__(self, min_freq: int = 2):
        self.min_freq = min_freq
        self.word2idx = {self.PAD: 0, self.UNK: 1, self.ROOT: 2}
        self.char2idx = {self.PAD: 0, self.UNK: 1, self.ROOT: 2}
        self.rel2idx = {}
        self.idx2rel = {}

    def build(self, sentences: List[Sentence]):
        """Build vocabulary from sentences."""
        word_counts = Counter()
        char_counts = Counter()
        rel_counts = Counter()

        for sent in sentences:
            for word in sent.words:
                word_counts[word.lower()] += 1
                for char in word:
                    char_counts[char] += 1
            for rel in sent.rels:
                rel_counts[rel] += 1

        # Words below min_freq fall back to <unk>.
        for word, count in word_counts.items():
            if count >= self.min_freq and word not in self.word2idx:
                self.word2idx[word] = len(self.word2idx)

        # Keep every character; the character vocabulary is small anyway.
        for char in char_counts:
            if char not in self.char2idx:
                self.char2idx[char] = len(self.char2idx)

        for rel in rel_counts:
            if rel not in self.rel2idx:
                idx = len(self.rel2idx)
                self.rel2idx[rel] = idx
                self.idx2rel[idx] = rel

    def encode_word(self, word: str) -> int:
        return self.word2idx.get(word.lower(), self.word2idx[self.UNK])

    def encode_char(self, char: str) -> int:
        return self.char2idx.get(char, self.char2idx[self.UNK])

    def encode_rel(self, rel: str) -> int:
        return self.rel2idx.get(rel, 0)

    @property
    def n_words(self) -> int:
        return len(self.word2idx)

    @property
    def n_chars(self) -> int:
        return len(self.char2idx)

    @property
    def n_rels(self) -> int:
        return len(self.rel2idx)


class DependencyDataset(Dataset):
    """Dataset for dependency parsing."""

    def __init__(self, sentences: List[Sentence], vocab: Vocabulary):
        self.sentences = sentences
        self.vocab = vocab

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        sent = self.sentences[idx]

        # Prepend the artificial <root> token so that position 0 is the root
        # and gold head indices (1-based, 0 = root) line up with positions.
        # Without it, head index k points at the wrong token and a last-word
        # head can even index past the end of the sentence.
        root_word = self.vocab.word2idx[self.vocab.ROOT]
        root_char = self.vocab.char2idx[self.vocab.ROOT]

        word_ids = [root_word] + [self.vocab.encode_word(w) for w in sent.words]
        char_ids = [[root_char]] + [[self.vocab.encode_char(c) for c in w] for w in sent.words]

        # Dummy head/rel for the root position; masked out of the loss.
        heads = [0] + sent.heads
        rels = [0] + [self.vocab.encode_rel(r) for r in sent.rels]

        return word_ids, char_ids, heads, rels
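

# Illustrative __getitem__ output for a two-word sentence (ids are made up):
#   word_ids = [2, 57, 103]        # <root>, w1, w2
#   char_ids = [[2], [5, 6], [7]]  # <root> maps to the single <root> char id
#   heads    = [0, 2, 0]           # w1's head is w2; w2 attaches to root
#   rels     = [0, 4, 1]           # 0 at position 0 is a dummy root label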


def collate_fn(batch):
    """Collate function for DataLoader: pad everything to the batch max length."""
    word_ids, char_ids, heads, rels = zip(*batch)

    lengths = [len(w) for w in word_ids]
    max_len = max(lengths)

    # (batch, max_len) padded word ids
    word_ids_padded = torch.zeros(len(batch), max_len, dtype=torch.long)
    for i, wids in enumerate(word_ids):
        word_ids_padded[i, :len(wids)] = torch.tensor(wids)

    # (batch, max_len, max_word_len) padded character ids
    max_word_len = max(max(len(c) for c in chars) for chars in char_ids)
    char_ids_padded = torch.zeros(len(batch), max_len, max_word_len, dtype=torch.long)
    for i, chars in enumerate(char_ids):
        for j, c in enumerate(chars):
            char_ids_padded[i, j, :len(c)] = torch.tensor(c)

    heads_padded = torch.zeros(len(batch), max_len, dtype=torch.long)
    for i, h in enumerate(heads):
        heads_padded[i, :len(h)] = torch.tensor(h)

    rels_padded = torch.zeros(len(batch), max_len, dtype=torch.long)
    for i, r in enumerate(rels):
        rels_padded[i, :len(r)] = torch.tensor(r)

    # True for real tokens (including <root>), False for padding.
    mask = torch.zeros(len(batch), max_len, dtype=torch.bool)
    for i, l in enumerate(lengths):
        mask[i, :l] = True

    lengths = torch.tensor(lengths)

    return word_ids_padded, char_ids_padded, heads_padded, rels_padded, mask, lengths
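

# One collated batch, with B sentences, L the longest sentence in the batch
# (including <root>) and C the longest word in characters:
#   word_ids: (B, L)   char_ids: (B, L, C)   heads, rels: (B, L)
#   mask:     (B, L) bool (True = real token)   lengths: (B,)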


class CharLSTM(nn.Module):
    """Character-level LSTM embeddings."""

    def __init__(self, n_chars: int, char_dim: int = 50, hidden_dim: int = 100):
        super().__init__()
        self.embed = nn.Embedding(n_chars, char_dim, padding_idx=0)
        self.lstm = nn.LSTM(char_dim, hidden_dim // 2, batch_first=True, bidirectional=True)
        self.hidden_dim = hidden_dim

    def forward(self, chars):
        """
        Args:
            chars: (batch, seq_len, max_word_len)
        Returns:
            (batch, seq_len, hidden_dim)
        """
        batch, seq_len, max_word_len = chars.shape

        # Flatten so each word is one character sequence.
        chars_flat = chars.view(-1, max_word_len)

        # Character 0 is padding; clamp so fully-padded words still have length 1.
        word_lens = (chars_flat != 0).sum(dim=1)
        word_lens = word_lens.clamp(min=1)

        char_embeds = self.embed(chars_flat)

        packed = pack_padded_sequence(char_embeds, word_lens.cpu(), batch_first=True, enforce_sorted=False)
        _, (hidden, _) = self.lstm(packed)

        # Concatenate the final forward and backward hidden states.
        hidden = torch.cat([hidden[0], hidden[1]], dim=-1)

        return hidden.view(batch, seq_len, self.hidden_dim)
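

# Note: h_n from nn.LSTM has shape (num_layers * num_directions, batch, hidden),
# so for this single bidirectional layer h_n[0] / h_n[1] are the final forward /
# backward states; concatenating them yields hidden_dim features per word.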


class MLP(nn.Module):
    """Single feed-forward layer with activation and dropout."""

    def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.33):
        super().__init__()
        self.linear = nn.Linear(input_dim, hidden_dim)
        self.activation = nn.LeakyReLU(0.1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.activation(self.linear(x)))


class Biaffine(nn.Module):
    """Biaffine attention layer."""

    def __init__(self, input_dim: int, output_dim: int = 1, bias_x: bool = True, bias_y: bool = True):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.bias_x = bias_x
        self.bias_y = bias_y

        self.weight = nn.Parameter(torch.zeros(output_dim, input_dim + bias_x, input_dim + bias_y))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x, y):
        """
        Args:
            x: (batch, seq_len, input_dim) - dependent
            y: (batch, seq_len, input_dim) - head
        Returns:
            (batch, seq_len, seq_len, output_dim), or (batch, seq_len, seq_len) if output_dim=1
        """
        # Appending a constant 1 folds the affine bias terms into the bilinear
        # form: scores[b, i, j, o] = [x_i; 1]^T W_o [y_j; 1]
        if self.bias_x:
            x = torch.cat([x, torch.ones_like(x[..., :1])], dim=-1)
        if self.bias_y:
            y = torch.cat([y, torch.ones_like(y[..., :1])], dim=-1)

        # (batch, seq_len, output_dim, input_dim + bias_y)
        x = torch.einsum('bxi,oij->bxoj', x, self.weight)
        # (batch, seq_len_x, seq_len_y, output_dim)
        scores = torch.einsum('bxoj,byj->bxyo', x, y)

        if self.output_dim == 1:
            scores = scores.squeeze(-1)

        return scores
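

# Shape check (illustrative): Biaffine(input_dim=4, output_dim=3) applied to
# x, y of shape (2, 5, 4) returns scores of shape (2, 5, 5, 3); with
# output_dim=1 the trailing dimension is squeezed away.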


class BiaffineDependencyParser(nn.Module):
    """Biaffine Dependency Parser (Dozat & Manning, 2017)."""

    def __init__(
        self,
        n_words: int,
        n_chars: int,
        n_rels: int,
        word_dim: int = 100,
        char_dim: int = 50,
        char_hidden: int = 100,
        lstm_hidden: int = 400,
        lstm_layers: int = 3,
        arc_hidden: int = 500,
        rel_hidden: int = 100,
        dropout: float = 0.33,
    ):
        super().__init__()

        self.word_embed = nn.Embedding(n_words, word_dim, padding_idx=0)
        self.char_lstm = CharLSTM(n_chars, char_dim, char_hidden)

        input_dim = word_dim + char_hidden

        self.lstm = nn.LSTM(
            input_dim, lstm_hidden // 2,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if lstm_layers > 1 else 0
        )

        # Separate role-specific MLPs reduce dimensionality before scoring.
        self.mlp_arc_dep = MLP(lstm_hidden, arc_hidden, dropout)
        self.mlp_arc_head = MLP(lstm_hidden, arc_hidden, dropout)
        self.mlp_rel_dep = MLP(lstm_hidden, rel_hidden, dropout)
        self.mlp_rel_head = MLP(lstm_hidden, rel_hidden, dropout)

        self.arc_attn = Biaffine(arc_hidden, 1, bias_x=True, bias_y=False)
        self.rel_attn = Biaffine(rel_hidden, n_rels, bias_x=True, bias_y=True)

        self.dropout = nn.Dropout(dropout)
        self.n_rels = n_rels

    def forward(self, words, chars, mask):
        """
        Args:
            words: (batch, seq_len)
            chars: (batch, seq_len, max_word_len)
            mask: (batch, seq_len)
        Returns:
            arc_scores: (batch, seq_len, seq_len); [b, i, j] scores j as head of i
            rel_scores: (batch, seq_len, seq_len, n_rels)
        """
        word_embeds = self.word_embed(words)
        char_embeds = self.char_lstm(chars)
        embeds = torch.cat([word_embeds, char_embeds], dim=-1)
        embeds = self.dropout(embeds)

        lengths = mask.sum(dim=1).cpu()
        packed = pack_padded_sequence(embeds, lengths, batch_first=True, enforce_sorted=False)
        lstm_out, _ = self.lstm(packed)
        lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True, total_length=mask.size(1))
        lstm_out = self.dropout(lstm_out)

        arc_dep = self.mlp_arc_dep(lstm_out)
        arc_head = self.mlp_arc_head(lstm_out)
        rel_dep = self.mlp_rel_dep(lstm_out)
        rel_head = self.mlp_rel_head(lstm_out)

        arc_scores = self.arc_attn(arc_dep, arc_head)
        rel_scores = self.rel_attn(rel_dep, rel_head)

        return arc_scores, rel_scores
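
    # With bias_x=True and bias_y=False, the arc biaffine expands to
    #   s_arc(i, j) = x_i^T U y_j + b^T y_j
    # (x = dependent, y = head): a bilinear term plus a head-only bias that
    # models how likely word j is to be a head at all (Dozat & Manning, 2017).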

    def loss(self, arc_scores, rel_scores, heads, rels, mask):
        """Compute arc and relation cross-entropy loss."""
        batch_size, seq_len = mask.shape

        # Dependents to score: every real token except <root> at position 0,
        # which has no gold head.
        dep_mask = mask.clone()
        dep_mask[:, 0] = False

        # Padding positions must never compete as head candidates, so mask
        # the head (column) axis, not the dependent rows.
        arc_scores = arc_scores.masked_fill(~mask.unsqueeze(1), float('-inf'))

        arc_loss = F.cross_entropy(
            arc_scores[dep_mask],
            heads[dep_mask],
            reduction='mean'
        )

        # Gather relation scores at each token's gold head.
        rel_scores_gold = rel_scores[torch.arange(batch_size).unsqueeze(1), torch.arange(seq_len), heads]
        rel_loss = F.cross_entropy(
            rel_scores_gold[dep_mask],
            rels[dep_mask],
            reduction='mean'
        )

        return arc_loss + rel_loss

    def decode(self, arc_scores, rel_scores, mask):
        """Greedy decoding: argmax head per token (no tree constraint)."""
        # Never predict a padding position as head.
        arc_scores = arc_scores.masked_fill(~mask.unsqueeze(1), float('-inf'))
        arc_preds = arc_scores.argmax(dim=-1)

        batch_size, seq_len = mask.shape
        rel_scores_pred = rel_scores[torch.arange(batch_size).unsqueeze(1), torch.arange(seq_len), arc_preds]
        rel_preds = rel_scores_pred.argmax(dim=-1)

        return arc_preds, rel_preds
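
# Note: greedy argmax decoding may yield cycles or multiple roots. The usual
# remedy is maximum-spanning-tree decoding (Chu-Liu/Edmonds) over arc_scores,
# which guarantees a well-formed tree (not implemented here).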


def evaluate(model, dataloader, device):
    """Evaluate model and return UAS/LAS as percentages."""
    model.eval()

    total_arcs = 0
    correct_arcs = 0
    correct_rels = 0

    with torch.no_grad():
        for batch in dataloader:
            words, chars, heads, rels, mask, lengths = [x.to(device) for x in batch]

            arc_scores, rel_scores = model(words, chars, mask)
            arc_preds, rel_preds = model.decode(arc_scores, rel_scores, mask)

            # Score real tokens only; the <root> placeholder at position 0
            # has no gold head and is excluded.
            eval_mask = mask.clone()
            eval_mask[:, 0] = False

            # UAS: correct head. LAS: correct head and correct relation.
            arc_correct = (arc_preds == heads) & eval_mask
            rel_correct = (rel_preds == rels) & arc_correct

            total_arcs += eval_mask.sum().item()
            correct_arcs += arc_correct.sum().item()
            correct_rels += rel_correct.sum().item()

    uas = correct_arcs / total_arcs * 100
    las = correct_rels / total_arcs * 100

    return uas, las


@click.command()
@click.option('--output', '-o', default='models/bamboo-1', help='Output directory')
@click.option('--epochs', default=100, type=int, help='Number of epochs')
@click.option('--batch-size', default=32, type=int, help='Batch size')
@click.option('--lr', default=2e-3, type=float, help='Learning rate')
@click.option('--lstm-hidden', default=400, type=int, help='LSTM hidden size')
@click.option('--lstm-layers', default=3, type=int, help='LSTM layers')
@click.option('--patience', default=10, type=int, help='Early stopping patience')
@click.option('--force-download', is_flag=True, help='Force re-download dataset')
@click.option('--gpu-type', default='RTX_A4000', help='GPU type for cost estimation')
@click.option('--cost-interval', default=300, type=int, help='Cost report interval in seconds')
@click.option('--wandb', 'use_wandb', is_flag=True, help='Enable W&B logging')
@click.option('--wandb-project', default='bamboo-1', help='W&B project name')
@click.option('--max-time', default=0, type=int, help='Max training time in minutes (0=unlimited)')
@click.option('--sample', default=0, type=int, help='Sample N sentences from each split (0=all)')
def train(output, epochs, batch_size, lr, lstm_hidden, lstm_layers, patience, force_download,
          gpu_type, cost_interval, use_wandb, wandb_project, max_time, sample):
    """Train Bamboo-1 Vietnamese Dependency Parser."""

    # Use the detected GPU for cost estimation unless the user overrode the default.
    hardware = detect_hardware()
    detected_gpu_type = hardware.get_gpu_type()
    if gpu_type == "RTX_A4000":
        gpu_type = detected_gpu_type

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    click.echo(f"Using device: {device}")
    click.echo(f"Hardware: {hardware}")

    if use_wandb:
        import wandb
        wandb.init(
            project=wandb_project,
            config={
                "epochs": epochs,
                "batch_size": batch_size,
                "lr": lr,
                "lstm_hidden": lstm_hidden,
                "lstm_layers": lstm_layers,
                "patience": patience,
                "gpu_type": gpu_type,
                "hardware": hardware.to_dict(),
            }
        )
        click.echo(f"W&B logging enabled: {wandb.run.url}")

    click.echo("=" * 60)
    click.echo("Bamboo-1: Vietnamese Dependency Parser")
    click.echo("=" * 60)

    click.echo("\nLoading UDD-1 corpus...")
    corpus = UDD1Corpus(force_download=force_download)

    train_sents = read_conllu(corpus.train)
    dev_sents = read_conllu(corpus.dev)
    test_sents = read_conllu(corpus.test)

    # Optional subsampling for quick smoke tests.
    if sample > 0:
        train_sents = train_sents[:sample]
        dev_sents = dev_sents[:min(sample // 2, len(dev_sents))]
        test_sents = test_sents[:min(sample // 2, len(test_sents))]
        click.echo(f"  Sampling {sample} sentences...")

    click.echo(f"  Train: {len(train_sents)} sentences")
    click.echo(f"  Dev:   {len(dev_sents)} sentences")
    click.echo(f"  Test:  {len(test_sents)} sentences")

    click.echo("\nBuilding vocabulary...")
    vocab = Vocabulary(min_freq=2)
    vocab.build(train_sents)
    click.echo(f"  Words: {vocab.n_words}")
    click.echo(f"  Chars: {vocab.n_chars}")
    click.echo(f"  Relations: {vocab.n_rels}")

    train_dataset = DependencyDataset(train_sents, vocab)
    dev_dataset = DependencyDataset(dev_sents, vocab)
    test_dataset = DependencyDataset(test_sents, vocab)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    dev_loader = DataLoader(dev_dataset, batch_size=batch_size, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_fn)

    click.echo("\nInitializing model...")
    model = BiaffineDependencyParser(
        n_words=vocab.n_words,
        n_chars=vocab.n_chars,
        n_rels=vocab.n_rels,
        lstm_hidden=lstm_hidden,
        lstm_layers=lstm_layers,
    ).to(device)

    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    click.echo(f"  Parameters: {n_params:,}")

    # Adam with beta2 = 0.9 and the learning rate annealed by 0.75 every
    # 5000 steps, following Dozat & Manning (2017).
    optimizer = Adam(model.parameters(), lr=lr, betas=(0.9, 0.9))
    scheduler = ExponentialLR(optimizer, gamma=0.75 ** (1 / 5000))

    click.echo(f"\nTraining for {epochs} epochs...")
    if max_time > 0:
        click.echo(f"Time limit: {max_time} minutes")
    output_path = Path(output)
    output_path.mkdir(parents=True, exist_ok=True)

    cost_tracker = CostTracker(gpu_type=gpu_type)
    cost_tracker.report_interval = cost_interval
    cost_tracker.start()
    click.echo(f"Cost tracking: {gpu_type} @ ${cost_tracker.hourly_rate}/hr")

    best_las = -1
    no_improve = 0
    time_limit_seconds = max_time * 60 if max_time > 0 else float('inf')

    for epoch in range(1, epochs + 1):
        if cost_tracker.elapsed_seconds() >= time_limit_seconds:
            click.echo(f"\nTime limit reached ({max_time} minutes)")
            break

        model.train()
        total_loss = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch:3d}", leave=False)
        for batch in pbar:
            words, chars, heads, rels, mask, lengths = [x.to(device) for x in batch]

            optimizer.zero_grad()
            arc_scores, rel_scores = model(words, chars, mask)
            loss = model.loss(arc_scores, rel_scores, heads, rels, mask)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

        dev_uas, dev_las = evaluate(model, dev_loader, device)

        progress = epoch / epochs
        current_cost = cost_tracker.current_cost()
        estimated_total_cost = cost_tracker.estimate_total_cost(progress)
        elapsed_minutes = cost_tracker.elapsed_seconds() / 60

        cost_status = cost_tracker.update(epoch, epochs)
        if cost_status:
            click.echo(f"  [{cost_status}]")

        avg_loss = total_loss / len(train_loader)
        click.echo(f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | "
                   f"Dev UAS: {dev_uas:.2f}% | Dev LAS: {dev_las:.2f}%")

        if use_wandb:
            wandb.log({
                "epoch": epoch,
                "train/loss": avg_loss,
                "dev/uas": dev_uas,
                "dev/las": dev_las,
                "cost/current_usd": current_cost,
                "cost/estimated_total_usd": estimated_total_cost,
                "cost/elapsed_minutes": elapsed_minutes,
            })

        # Checkpoint on dev LAS; stop after `patience` epochs without improvement.
        if dev_las >= best_las:
            best_las = dev_las
            no_improve = 0
            torch.save({
                'model': model.state_dict(),
                'vocab': vocab,
                'config': {
                    'n_words': vocab.n_words,
                    'n_chars': vocab.n_chars,
                    'n_rels': vocab.n_rels,
                    'lstm_hidden': lstm_hidden,
                    'lstm_layers': lstm_layers,
                }
            }, output_path / 'model.pt')
            click.echo(f"  -> Saved best model (LAS: {best_las:.2f}%)")
        else:
            no_improve += 1
            if no_improve >= patience:
                click.echo(f"\nEarly stopping after {patience} epochs without improvement")
                break

    click.echo("\nLoading best model for final evaluation...")
    # weights_only=False: the checkpoint also pickles the Vocabulary object.
    checkpoint = torch.load(output_path / 'model.pt', weights_only=False)
    model.load_state_dict(checkpoint['model'])

    test_uas, test_las = evaluate(model, test_loader, device)
    click.echo("\nTest Results:")
    click.echo(f"  UAS: {test_uas:.2f}%")
    click.echo(f"  LAS: {test_las:.2f}%")

    click.echo(f"\nModel saved to: {output_path}")

    final_cost = cost_tracker.current_cost()
    click.echo(f"\n{cost_tracker.summary(epoch, epochs)}")

    if use_wandb:
        wandb.log({
            "test/uas": test_uas,
            "test/las": test_las,
            "cost/final_usd": final_cost,
        })
        wandb.finish()


if __name__ == '__main__':
    train()