|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Evaluation script for Bamboo-1 Vietnamese Dependency Parser. |
|
|
|
|
|
Supports both BiLSTM and PhoBERT-based models, and multiple datasets: |
|
|
- UDD-1: Main Vietnamese dependency dataset (~18K sentences) |
|
|
- UD Vietnamese VTB: Universal Dependencies benchmark (~3.3K sentences) |
|
|
|
|
|
Usage: |
|
|
uv run scripts/evaluate.py --model models/bamboo-1 |
|
|
uv run scripts/evaluate.py --model models/bamboo-1-phobert --model-type phobert |
|
|
uv run scripts/evaluate.py --model models/bamboo-1-phobert --dataset ud-vtb |
|
|
uv run scripts/evaluate.py --model models/bamboo-1 --split test --detailed |
|
|
""" |
|
|
|
|
|
import sys |
|
|
from pathlib import Path |
|
|
from collections import Counter |
|
|
|
|
|
import click |
|
|
|
|
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent)) |
|
|
|
|
|
from bamboo1.corpus import UDD1Corpus |
|
|
from bamboo1.ud_corpus import UDVietnameseVTB |
|
|
|
|
|
|
|
|
def read_conll_sentences(filepath: str) -> list:
    """Read sentences from a CoNLL-U file.

    Comment lines (starting with "#") are skipped, as are multiword-token
    ranges (IDs like "1-2") and empty nodes (IDs like "1.1"), keeping only
    regular token rows.

    Args:
        filepath: Path to a CoNLL-U formatted file.

    Returns:
        A list of sentences; each sentence is a list of token dicts with
        keys "id", "form", "upos", "head", and "deprel".
    """
    sentences = []
    current_sentence = []

    with open(filepath, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line.startswith("#"):
                continue  # metadata/comment line
            if not line:
                # A blank line terminates the current sentence.
                if current_sentence:
                    sentences.append(current_sentence)
                    current_sentence = []
            else:
                parts = line.split("\t")
                # Need at least 8 columns (through DEPREL); skip multiword
                # ranges ("1-2") and empty nodes ("1.1"), whose HEAD column
                # is "_" and would break int() anyway.
                if len(parts) >= 8 and "-" not in parts[0] and "." not in parts[0]:
                    current_sentence.append({
                        "id": int(parts[0]),
                        "form": parts[1],
                        "upos": parts[3],
                        "head": int(parts[6]),
                        "deprel": parts[7],
                    })

    # Handle a file that does not end with a trailing blank line.
    if current_sentence:
        sentences.append(current_sentence)

    return sentences
|
|
|
|
|
|
|
|
def calculate_attachment_scores(gold_sentences, pred_sentences):
    """Compute unlabeled (UAS) and labeled (LAS) attachment scores.

    Gold and predicted sentences are compared pairwise and are assumed
    to be aligned (same sentence order and token counts).

    Returns:
        A dict with "uas", "las", raw counts ("total_tokens",
        "correct_heads", "correct_labels"), and "per_deprel" — accuracy
        broken down by the gold relation label.
    """
    n_tokens = 0
    n_head_hits = 0
    n_label_hits = 0

    seen_by_rel = Counter()
    hits_by_rel = Counter()

    for gold_sent, pred_sent in zip(gold_sentences, pred_sentences):
        for g_tok, p_tok in zip(gold_sent, pred_sent):
            n_tokens += 1
            rel = g_tok["deprel"]
            seen_by_rel[rel] += 1

            if g_tok["head"] == p_tok["head"]:
                n_head_hits += 1
                # A label only counts when the head is also correct (LAS).
                if g_tok["deprel"] == p_tok["deprel"]:
                    n_label_hits += 1
                    hits_by_rel[rel] += 1

    per_deprel = {
        rel: {
            "total": count,
            "correct": hits_by_rel[rel],
            "accuracy": hits_by_rel[rel] / count,
        }
        for rel, count in seen_by_rel.items()
        if count > 0
    }

    return {
        "uas": n_head_hits / n_tokens if n_tokens > 0 else 0,
        "las": n_label_hits / n_tokens if n_tokens > 0 else 0,
        "total_tokens": n_tokens,
        "correct_heads": n_head_hits,
        "correct_labels": n_label_hits,
        "per_deprel": per_deprel,
    }
|
|
|
|
|
|
|
|
def load_phobert_model(model_path, device='cuda'):
    """Load a PhoBERT-based dependency parser from disk.

    Silently falls back to CPU when CUDA is unavailable, regardless of
    the requested device.
    """
    import torch
    from bamboo1.models.transformer_parser import PhoBERTDependencyParser

    # Downgrade to CPU if no GPU is present on this machine.
    device = device if torch.cuda.is_available() else 'cpu'

    return PhoBERTDependencyParser.load(model_path, device=device)
|
|
|
|
|
|
|
|
def predict_phobert(parser, words):
    """Predict heads and dependency relations for one tokenized sentence.

    Args:
        parser: A loaded PhoBERT dependency parser.
        words: List of word-segmented tokens for a single sentence.

    Returns:
        List of (word, head, deprel) triples, one per input word.
    """
    import torch

    parser.eval()
    device = next(parser.parameters()).device

    # Subword-tokenize a single-sentence batch and move every tensor to
    # the model's device.
    batch = parser.tokenize_with_alignment([words])
    tensors = {
        key: batch[key].to(device)
        for key in ("input_ids", "attention_mask", "word_starts", "word_mask")
    }

    with torch.no_grad():
        arc_scores, rel_scores = parser.forward(
            tensors["input_ids"],
            tensors["attention_mask"],
            tensors["word_starts"],
            tensors["word_mask"],
        )
        arc_preds, rel_preds = parser.decode(
            arc_scores, rel_scores, tensors["word_mask"]
        )

    # Strip the batch dimension (batch size is 1) and move to host.
    heads = arc_preds[0].cpu().tolist()
    rels = rel_preds[0].cpu().tolist()

    # Map relation indices back to label strings; unknown indices fall
    # back to the generic "dep" relation.
    return [
        (word, heads[i], parser.idx2rel.get(rels[i], "dep"))
        for i, word in enumerate(words)
    ]
|
|
|
|
|
|
|
|
@click.command()
@click.option(
    "--model", "-m",
    required=True,
    help="Path to trained model directory",
)
@click.option(
    "--model-type",
    type=click.Choice(["bilstm", "phobert"]),
    default="bilstm",
    help="Model type: bilstm (underthesea) or phobert (transformer)",
    show_default=True,
)
@click.option(
    "--dataset",
    type=click.Choice(["udd1", "ud-vtb"]),
    default="udd1",
    help="Dataset: udd1 (UDD-1) or ud-vtb (UD Vietnamese VTB)",
    show_default=True,
)
@click.option(
    "--split",
    type=click.Choice(["dev", "test", "both"]),
    default="test",
    help="Dataset split to evaluate on",
    show_default=True,
)
@click.option(
    "--detailed",
    is_flag=True,
    help="Show detailed per-relation scores",
)
@click.option(
    "--output", "-o",
    help="Save predictions to file (CoNLL-U format)",
)
def evaluate(model, model_type, dataset, split, detailed, output):
    """Evaluate Bamboo-1 Vietnamese Dependency Parser.

    Supports both BiLSTM (underthesea) and PhoBERT-based models,
    and evaluation on UDD-1 or UD Vietnamese VTB datasets.
    """
    click.echo("=" * 60)
    click.echo("Bamboo-1: Vietnamese Dependency Parser Evaluation")
    click.echo("=" * 60)

    # --- Model loading -------------------------------------------------
    # Both branches expose the same interface: predict_fn(words) returns
    # an iterable of (word, head, deprel) triples.
    click.echo(f"\nLoading {model_type} model from {model}...")
    if model_type == "phobert":
        parser = load_phobert_model(model)

        def predict_fn(words):
            return predict_phobert(parser, words)
    else:
        from underthesea.models.dependency_parser import DependencyParser
        parser = DependencyParser.load(model)

        def predict_fn(words):
            # underthesea expects a whitespace-joined sentence string.
            return parser.predict(" ".join(words))

    # --- Corpus selection ----------------------------------------------
    click.echo(f"Loading {dataset.upper()} corpus...")
    corpus = UDD1Corpus() if dataset == "udd1" else UDVietnameseVTB()

    if split == "both":
        splits_to_eval = [("dev", corpus.dev), ("test", corpus.test)]
    elif split == "dev":
        splits_to_eval = [("dev", corpus.dev)]
    else:
        splits_to_eval = [("test", corpus.test)]

    # --- Per-split evaluation loop -------------------------------------
    for split_name, split_path in splits_to_eval:
        click.echo(f"\n{'=' * 40}")
        click.echo(f"Evaluating on {split_name} set: {split_path}")
        click.echo("=" * 40)

        # Gold trees come straight from the CoNLL-U file for this split.
        gold_sentences = read_conll_sentences(split_path)
        click.echo(f" Sentences: {len(gold_sentences)}")
        click.echo(f" Tokens: {sum(len(s) for s in gold_sentences)}")

        # Parse each sentence using the gold tokenization, so scores
        # reflect attachment quality only (not segmentation quality).
        click.echo("\nMaking predictions...")
        pred_sentences = []

        for gold_sent in gold_sentences:
            tokens = [tok["form"] for tok in gold_sent]
            result = predict_fn(tokens)

            pred_sent = []
            for i, (word, head, deprel) in enumerate(result):
                pred_sent.append({
                    "id": i + 1,
                    "form": word,
                    "head": head,
                    "deprel": deprel,
                })
            pred_sentences.append(pred_sent)

        scores = calculate_attachment_scores(gold_sentences, pred_sentences)

        click.echo("\nResults:")
        click.echo(f" UAS: {scores['uas']:.4f} ({scores['uas']*100:.2f}%)")
        click.echo(f" LAS: {scores['las']:.4f} ({scores['las']*100:.2f}%)")
        click.echo(f" Total tokens: {scores['total_tokens']}")
        click.echo(f" Correct heads: {scores['correct_heads']}")
        click.echo(f" Correct labels: {scores['correct_labels']}")

        if detailed:
            click.echo("\nPer-relation scores:")
            click.echo("-" * 50)
            click.echo(f"{'Relation':<15} {'Count':>8} {'Correct':>8} {'Accuracy':>10}")
            click.echo("-" * 50)

            for deprel in sorted(scores["per_deprel"].keys()):
                stats = scores["per_deprel"][deprel]
                click.echo(
                    f"{deprel:<15} {stats['total']:>8} {stats['correct']:>8} "
                    f"{stats['accuracy']*100:>9.2f}%"
                )

        if output:
            out_path = Path(output)
            # Disambiguate the filename when evaluating multiple splits
            # (only the test split keeps the name exactly as given).
            if split_name != "test":
                out_path = out_path.with_stem(f"{out_path.stem}_{split_name}")

            # Write predictions in CoNLL-U: gold id/form/upos, predicted
            # head/deprel; unused columns are "_".
            click.echo(f"\nSaving predictions to {out_path}...")
            with open(out_path, "w", encoding="utf-8") as f:
                for i, (gold_sent, pred_sent) in enumerate(zip(gold_sentences, pred_sentences)):
                    f.write(f"# sent_id = {i + 1}\n")
                    for gold_tok, pred_tok in zip(gold_sent, pred_sent):
                        f.write(
                            f"{gold_tok['id']}\t{gold_tok['form']}\t_\t{gold_tok['upos']}\t_\t_\t"
                            f"{pred_tok['head']}\t{pred_tok['deprel']}\t_\t_\n"
                        )
                    f.write("\n")

    click.echo("\nEvaluation complete!")
|
|
|
|
|
|
|
|
# Entry point: click handles CLI argument parsing when run as a script.
if __name__ == "__main__":
    evaluate()
|
|
|