# bamboo-1/scripts/evaluate.py
# Commit b39f0e3: "Add PhoBERT-based dependency parser for Trankit reproduction" (rain1024)
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "underthesea[deep]>=6.8.0",
# "datasets>=2.14.0",
# "click>=8.0.0",
# "torch>=2.0.0",
# "transformers>=4.30.0",
# ]
# ///
"""
Evaluation script for Bamboo-1 Vietnamese Dependency Parser.
Supports both BiLSTM and PhoBERT-based models, and multiple datasets:
- UDD-1: Main Vietnamese dependency dataset (~18K sentences)
- UD Vietnamese VTB: Universal Dependencies benchmark (~3.3K sentences)
Usage:
uv run scripts/evaluate.py --model models/bamboo-1
uv run scripts/evaluate.py --model models/bamboo-1-phobert --model-type phobert
uv run scripts/evaluate.py --model models/bamboo-1-phobert --dataset ud-vtb
uv run scripts/evaluate.py --model models/bamboo-1 --split test --detailed
"""
import sys
from pathlib import Path
from collections import Counter
import click
# Add parent directory to path for bamboo1 module
sys.path.insert(0, str(Path(__file__).parent.parent))
from bamboo1.corpus import UDD1Corpus
from bamboo1.ud_corpus import UDVietnameseVTB
def read_conll_sentences(filepath: str) -> list[list[dict]]:
    """Read sentences from a CoNLL-U file.

    Comment lines are skipped, as are multiword-token rows (IDs like
    "1-2") and empty-node rows (IDs like "1.1"), so only regular token
    lines are kept.

    Args:
        filepath: Path to a CoNLL-U formatted file.

    Returns:
        A list of sentences; each sentence is a list of token dicts with
        keys "id", "form", "upos", "head", and "deprel".
    """
    sentences = []
    current_sentence = []
    with open(filepath, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line.startswith("#"):
                continue
            if not line:
                # Blank line terminates the current sentence.
                if current_sentence:
                    sentences.append(current_sentence)
                    current_sentence = []
                continue
            parts = line.split("\t")
            # Regular token rows carry at least the 8 columns we read
            # (ID..DEPREL) and a plain integer ID.
            if len(parts) >= 8 and "-" not in parts[0] and "." not in parts[0]:
                current_sentence.append({
                    "id": int(parts[0]),
                    "form": parts[1],
                    "upos": parts[3],
                    "head": int(parts[6]),
                    "deprel": parts[7],
                })
    # Flush the last sentence if the file lacks a trailing blank line.
    if current_sentence:
        sentences.append(current_sentence)
    return sentences
def calculate_attachment_scores(gold_sentences, pred_sentences):
    """Compute UAS, LAS, and per-relation accuracy.

    A token counts toward LAS only when both its head and its relation
    label match the gold annotation; UAS requires only the head.

    Args:
        gold_sentences: Gold sentences (lists of token dicts with
            "head" and "deprel").
        pred_sentences: Predicted sentences in the same format.

    Returns:
        Dict with "uas", "las", "total_tokens", "correct_heads",
        "correct_labels", and a "per_deprel" breakdown keyed by the
        gold relation label.
    """
    n_tokens = 0
    n_head_hits = 0
    n_label_hits = 0
    rel_totals = Counter()
    rel_hits = Counter()
    for gold, pred in zip(gold_sentences, pred_sentences):
        for g_tok, p_tok in zip(gold, pred):
            n_tokens += 1
            rel = g_tok["deprel"]
            rel_totals[rel] += 1
            if g_tok["head"] == p_tok["head"]:
                n_head_hits += 1
                # Labels only count when the head is also correct.
                if g_tok["deprel"] == p_tok["deprel"]:
                    n_label_hits += 1
                    rel_hits[rel] += 1
    per_rel = {
        rel: {
            "total": count,
            "correct": rel_hits[rel],
            "accuracy": rel_hits[rel] / count,
        }
        for rel, count in rel_totals.items()
        if count > 0
    }
    return {
        "uas": n_head_hits / n_tokens if n_tokens > 0 else 0,
        "las": n_label_hits / n_tokens if n_tokens > 0 else 0,
        "total_tokens": n_tokens,
        "correct_heads": n_head_hits,
        "correct_labels": n_label_hits,
        "per_deprel": per_rel,
    }
def load_phobert_model(model_path, device='cuda'):
    """Load a PhoBERT-based dependency parser from disk.

    Silently falls back to CPU when CUDA is not available, so the
    requested device is only a preference.
    """
    import torch
    from bamboo1.models.transformer_parser import PhoBERTDependencyParser

    resolved_device = device if torch.cuda.is_available() else 'cpu'
    return PhoBERTDependencyParser.load(model_path, device=resolved_device)
def predict_phobert(parser, words):
    """Parse one pre-tokenized sentence with the PhoBERT model.

    Args:
        parser: A loaded PhoBERT dependency parser in or switched to
            eval mode.
        words: List of word tokens for a single sentence.

    Returns:
        A list of (word, head, relation) triples, one per input word.
    """
    import torch

    parser.eval()
    device = next(parser.parameters()).device

    # Subword-tokenize the single-sentence batch and stage tensors on
    # the model's device.
    batch = parser.tokenize_with_alignment([words])
    tensors = {
        name: batch[name].to(device)
        for name in ("input_ids", "attention_mask", "word_starts", "word_mask")
    }

    with torch.no_grad():
        arc_scores, rel_scores = parser.forward(
            tensors["input_ids"],
            tensors["attention_mask"],
            tensors["word_starts"],
            tensors["word_mask"],
        )
        arc_preds, rel_preds = parser.decode(
            arc_scores, rel_scores, tensors["word_mask"]
        )

    # Single sentence in the batch -> take row 0 and move to host lists.
    heads = arc_preds[0].cpu().tolist()
    rels = rel_preds[0].cpu().tolist()

    # Map relation indices back to labels, defaulting to "dep" for
    # indices missing from the vocabulary.
    return [
        (word, heads[i], parser.idx2rel.get(rels[i], "dep"))
        for i, word in enumerate(words)
    ]
@click.command()
@click.option(
    "--model", "-m",
    required=True,
    help="Path to trained model directory",
)
@click.option(
    "--model-type",
    type=click.Choice(["bilstm", "phobert"]),
    default="bilstm",
    help="Model type: bilstm (underthesea) or phobert (transformer)",
    show_default=True,
)
@click.option(
    "--dataset",
    type=click.Choice(["udd1", "ud-vtb"]),
    default="udd1",
    help="Dataset: udd1 (UDD-1) or ud-vtb (UD Vietnamese VTB)",
    show_default=True,
)
@click.option(
    "--split",
    type=click.Choice(["dev", "test", "both"]),
    default="test",
    help="Dataset split to evaluate on",
    show_default=True,
)
@click.option(
    "--detailed",
    is_flag=True,
    help="Show detailed per-relation scores",
)
@click.option(
    "--output", "-o",
    help="Save predictions to file (CoNLL-U format)",
)
def evaluate(model, model_type, dataset, split, detailed, output):
    """Evaluate Bamboo-1 Vietnamese Dependency Parser.
    Supports both BiLSTM (underthesea) and PhoBERT-based models,
    and evaluation on UDD-1 or UD Vietnamese VTB datasets.
    """
    click.echo("=" * 60)
    click.echo("Bamboo-1: Vietnamese Dependency Parser Evaluation")
    click.echo("=" * 60)

    # Load the model. Both branches expose the same interface:
    # predict_fn(words) -> list of (word, head, deprel) triples.
    # (Named functions instead of assigned lambdas, per PEP 8 E731.)
    click.echo(f"\nLoading {model_type} model from {model}...")
    if model_type == "phobert":
        parser = load_phobert_model(model)

        def predict_fn(words):
            return predict_phobert(parser, words)
    else:
        # Imported lazily so the phobert path doesn't require underthesea's
        # BiLSTM stack.
        from underthesea.models.dependency_parser import DependencyParser
        parser = DependencyParser.load(model)

        def predict_fn(words):
            return parser.predict(" ".join(words))

    # Load the corpus that provides dev/test split file paths.
    click.echo(f"Loading {dataset.upper()} corpus...")
    if dataset == "udd1":
        corpus = UDD1Corpus()
    else:
        corpus = UDVietnameseVTB()

    if split == "both":
        splits_to_eval = [("dev", corpus.dev), ("test", corpus.test)]
    elif split == "dev":
        splits_to_eval = [("dev", corpus.dev)]
    else:
        splits_to_eval = [("test", corpus.test)]

    for split_name, split_path in splits_to_eval:
        click.echo(f"\n{'=' * 40}")
        click.echo(f"Evaluating on {split_name} set: {split_path}")
        click.echo("=" * 40)

        # Read gold data
        gold_sentences = read_conll_sentences(split_path)
        click.echo(f"  Sentences: {len(gold_sentences)}")
        click.echo(f"  Tokens: {sum(len(s) for s in gold_sentences)}")

        # Predict on the gold tokenization so token counts line up with gold.
        click.echo("\nMaking predictions...")
        pred_sentences = []
        for gold_sent in gold_sentences:
            tokens = [tok["form"] for tok in gold_sent]
            result = predict_fn(tokens)
            # Convert (word, head, deprel) triples to the gold dict format.
            pred_sent = []
            for i, (word, head, deprel) in enumerate(result):
                pred_sent.append({
                    "id": i + 1,
                    "form": word,
                    "head": head,
                    "deprel": deprel,
                })
            pred_sentences.append(pred_sent)

        # Calculate and report attachment scores.
        scores = calculate_attachment_scores(gold_sentences, pred_sentences)
        click.echo("\nResults:")
        click.echo(f"  UAS: {scores['uas']:.4f} ({scores['uas']*100:.2f}%)")
        click.echo(f"  LAS: {scores['las']:.4f} ({scores['las']*100:.2f}%)")
        click.echo(f"  Total tokens: {scores['total_tokens']}")
        click.echo(f"  Correct heads: {scores['correct_heads']}")
        click.echo(f"  Correct labels: {scores['correct_labels']}")

        if detailed:
            click.echo("\nPer-relation scores:")
            click.echo("-" * 50)
            click.echo(f"{'Relation':<15} {'Count':>8} {'Correct':>8} {'Accuracy':>10}")
            click.echo("-" * 50)
            for deprel in sorted(scores["per_deprel"].keys()):
                stats = scores["per_deprel"][deprel]
                click.echo(
                    f"{deprel:<15} {stats['total']:>8} {stats['correct']:>8} "
                    f"{stats['accuracy']*100:>9.2f}%"
                )

        # Save predictions if requested; when evaluating multiple splits,
        # non-test splits get a "_<split>" suffix so files aren't clobbered.
        if output:
            out_path = Path(output)
            if split_name != "test":
                out_path = out_path.with_stem(f"{out_path.stem}_{split_name}")
            click.echo(f"\nSaving predictions to {out_path}...")
            with open(out_path, "w", encoding="utf-8") as f:
                for i, (gold_sent, pred_sent) in enumerate(zip(gold_sentences, pred_sentences)):
                    f.write(f"# sent_id = {i + 1}\n")
                    # FORM/UPOS come from gold; HEAD/DEPREL from the prediction.
                    for gold_tok, pred_tok in zip(gold_sent, pred_sent):
                        f.write(
                            f"{gold_tok['id']}\t{gold_tok['form']}\t_\t{gold_tok['upos']}\t_\t_\t"
                            f"{pred_tok['head']}\t{pred_tok['deprel']}\t_\t_\n"
                        )
                    f.write("\n")

    click.echo("\nEvaluation complete!")


if __name__ == "__main__":
    evaluate()