| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Evaluation script for Bamboo-1 Vietnamese Dependency Parser. |
| |
| Supports both BiLSTM and PhoBERT-based models, and multiple datasets: |
| - UDD-1: Main Vietnamese dependency dataset (~18K sentences) |
| - UD Vietnamese VTB: Universal Dependencies benchmark (~3.3K sentences) |
| |
| Usage: |
| uv run scripts/evaluate.py --model models/bamboo-1 |
| uv run scripts/evaluate.py --model models/bamboo-1-phobert --model-type phobert |
| uv run scripts/evaluate.py --model models/bamboo-1-phobert --dataset ud-vtb |
| uv run scripts/evaluate.py --model models/bamboo-1 --split test --detailed |
| """ |
|
|
| import sys |
| from pathlib import Path |
| from collections import Counter |
|
|
| import click |
|
|
| |
| sys.path.insert(0, str(Path(__file__).parent.parent)) |
|
|
| from src.corpus import UDD1Corpus |
| from src.ud_corpus import UDVietnameseVTB |
|
|
|
|
def read_conll_sentences(filepath: str) -> list[list[dict]]:
    """Read sentences from a CoNLL-U file.

    Comment lines (starting with ``#``) are skipped and blank lines separate
    sentences. Multiword-token ranges (IDs such as ``1-2``), empty nodes
    (IDs such as ``1.1``) and lines with fewer than 8 tab-separated columns
    are ignored, so only basic word tokens are returned.

    Args:
        filepath: Path to a UTF-8 encoded CoNLL-U file.

    Returns:
        A list of sentences. Each sentence is a list of token dicts with keys
        ``id`` (int), ``form``, ``upos``, ``head`` (int) and ``deprel``.

    Raises:
        ValueError: if the ID or HEAD column of a basic token is not an integer.
    """
    sentences = []
    current_sentence = []

    with open(filepath, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if line.startswith("#"):
                continue
            if not line:
                # A blank line closes the current sentence; repeated blanks are harmless.
                if current_sentence:
                    sentences.append(current_sentence)
                    current_sentence = []
                continue

            parts = line.split("\t")
            token_id = parts[0]
            # Skip malformed lines, multiword-token ranges ("1-2") and empty nodes ("1.1").
            # These checks run before int() so non-integer IDs/HEADs there never raise.
            if len(parts) < 8 or "-" in token_id or "." in token_id:
                continue

            current_sentence.append({
                "id": int(token_id),
                "form": parts[1],
                "upos": parts[3],
                "head": int(parts[6]),
                "deprel": parts[7],
            })

    # The file may not end with a blank line; flush the final sentence.
    if current_sentence:
        sentences.append(current_sentence)

    return sentences
|
|
|
|
def calculate_attachment_scores(gold_sentences, pred_sentences):
    """Calculate UAS and LAS scores.

    Every gold token contributes to the denominator. If a predicted sentence
    is shorter than its gold sentence, or ``pred_sentences`` has fewer
    sentences than ``gold_sentences``, the unpredicted gold tokens are scored
    as errors. A plain ``zip`` would silently drop them and inflate the
    scores. Predicted tokens beyond the gold length are ignored.

    Args:
        gold_sentences: Sentences of gold token dicts with ``head`` and ``deprel`` keys.
        pred_sentences: Sentences of predicted token dicts with ``head`` and
            ``deprel`` keys, aligned position-by-position with the gold tokens.

    Returns:
        A dict with:
          - ``uas``: fraction of gold tokens whose head was predicted correctly.
          - ``las``: fraction of gold tokens whose head and relation were both correct.
          - ``total_tokens``, ``correct_heads``, ``correct_labels``: the raw counts.
          - ``per_deprel``: maps each gold relation to ``total``, ``correct``
            (head and label both right) and ``accuracy``.
    """
    total_tokens = 0
    correct_heads = 0
    correct_labels = 0
    deprel_stats = Counter()
    deprel_correct = Counter()

    pred_iter = iter(pred_sentences)
    for gold_sent in gold_sentences:
        # A missing predicted sentence behaves like an empty prediction.
        pred_sent = list(next(pred_iter, ()))
        for i, gold_tok in enumerate(gold_sent):
            total_tokens += 1
            deprel = gold_tok["deprel"]
            deprel_stats[deprel] += 1

            if i >= len(pred_sent):
                continue  # no prediction for this gold token: counted as wrong
            pred_tok = pred_sent[i]

            if gold_tok["head"] == pred_tok["head"]:
                correct_heads += 1
                if deprel == pred_tok["deprel"]:
                    correct_labels += 1
                    deprel_correct[deprel] += 1

    uas = correct_heads / total_tokens if total_tokens else 0.0
    las = correct_labels / total_tokens if total_tokens else 0.0

    # Every key in deprel_stats has a count >= 1, so the division is safe.
    per_deprel_scores = {
        deprel: {
            "total": total,
            "correct": deprel_correct[deprel],
            "accuracy": deprel_correct[deprel] / total,
        }
        for deprel, total in deprel_stats.items()
    }

    return {
        "uas": uas,
        "las": las,
        "total_tokens": total_tokens,
        "correct_heads": correct_heads,
        "correct_labels": correct_labels,
        "per_deprel": per_deprel_scores,
    }
|
|
|
|
def load_phobert_model(model_path, device='cuda'):
    """Load a saved PhoBERT-based dependency parser.

    Args:
        model_path: Location passed straight to ``PhoBERTDependencyParser.load``.
        device: Requested device. It is used only when CUDA is available;
            otherwise the model is loaded on ``'cpu'``.

    Returns:
        Whatever ``PhoBERTDependencyParser.load`` returns for the given path and device.
    """
    # Heavy imports stay local so the BiLSTM path never needs torch or transformers.
    import torch
    from src.models.transformer_parser import PhoBERTDependencyParser

    target_device = device if torch.cuda.is_available() else 'cpu'
    return PhoBERTDependencyParser.load(model_path, device=target_device)
|
|
|
|
def predict_phobert(parser, words):
    """Parse one pre-tokenized sentence with a PhoBERT dependency parser.

    The parser is switched to eval mode. The sentence is encoded as a batch of
    one and run without gradient tracking, and the batch's single result is
    decoded.

    Args:
        parser: A loaded PhoBERT parser. It must provide ``eval``,
            ``parameters``, ``tokenize_with_alignment``, ``forward``,
            ``decode`` and an ``idx2rel`` mapping.
        words: The sentence as a list of word forms.

    Returns:
        A list with one ``(word, head, relation)`` tuple per input word. The
        relation is ``"dep"`` when the predicted index is missing from
        ``parser.idx2rel``.
    """
    import torch

    parser.eval()
    device = next(parser.parameters()).device

    # Encode a batch containing this single sentence and move inputs to the model's device.
    encoded = parser.tokenize_with_alignment([words])
    input_ids, attention_mask, word_starts, word_mask = (
        encoded[key].to(device)
        for key in ("input_ids", "attention_mask", "word_starts", "word_mask")
    )

    with torch.no_grad():
        arc_scores, rel_scores = parser.forward(
            input_ids, attention_mask, word_starts, word_mask
        )
        arc_preds, rel_preds = parser.decode(arc_scores, rel_scores, word_mask)

    # Only batch element 0 exists; convert it to plain Python ints.
    heads = arc_preds[0].cpu().tolist()
    rel_ids = rel_preds[0].cpu().tolist()

    return [
        (word, heads[position], parser.idx2rel.get(rel_ids[position], "dep"))
        for position, word in enumerate(words)
    ]
|
|
|
|
def _load_predictor(model, model_type):
    """Load the requested parser and return a ``words -> predictions`` callable.

    The callable takes a list of word forms and returns an iterable of
    ``(word, head, deprel)`` triples, one per predicted token.
    """
    if model_type == "phobert":
        parser = load_phobert_model(model)

        def predict(words):
            return predict_phobert(parser, words)
    else:
        from underthesea.models.dependency_parser import DependencyParser

        parser = DependencyParser.load(model)

        def predict(words):
            # NOTE(review): the BiLSTM parser receives a space-joined string, so gold
            # words that themselves contain spaces may be split into more tokens
            # than the gold sentence has.
            return parser.predict(" ".join(words))

    return predict


def _predict_sentences(predict_fn, gold_sentences):
    """Predict every gold sentence from its word forms.

    Returns:
        ``(pred_sentences, mismatched)``. ``pred_sentences`` holds one list of
        token dicts (``id``, ``form``, ``head``, ``deprel``) per gold sentence.
        ``mismatched`` counts sentences whose predicted token count differs
        from the gold one.
    """
    pred_sentences = []
    mismatched = 0
    for gold_sent in gold_sentences:
        tokens = [tok["form"] for tok in gold_sent]
        result = predict_fn(tokens)
        pred_sent = [
            {"id": idx, "form": word, "head": head, "deprel": deprel}
            for idx, (word, head, deprel) in enumerate(result, start=1)
        ]
        if len(pred_sent) != len(gold_sent):
            mismatched += 1
        pred_sentences.append(pred_sent)
    return pred_sentences, mismatched


def _print_scores(scores, detailed):
    """Echo overall attachment scores and, if ``detailed``, the per-relation table.

    ``scores`` is the dict produced by ``calculate_attachment_scores``.
    """
    click.echo("\nResults:")
    click.echo(f" UAS: {scores['uas']:.4f} ({scores['uas']*100:.2f}%)")
    click.echo(f" LAS: {scores['las']:.4f} ({scores['las']*100:.2f}%)")
    click.echo(f" Total tokens: {scores['total_tokens']}")
    click.echo(f" Correct heads: {scores['correct_heads']}")
    click.echo(f" Correct labels: {scores['correct_labels']}")

    if not detailed:
        return

    click.echo("\nPer-relation scores:")
    click.echo("-" * 50)
    click.echo(f"{'Relation':<15} {'Count':>8} {'Correct':>8} {'Accuracy':>10}")
    click.echo("-" * 50)
    for deprel in sorted(scores["per_deprel"]):
        stats = scores["per_deprel"][deprel]
        click.echo(
            f"{deprel:<15} {stats['total']:>8} {stats['correct']:>8} "
            f"{stats['accuracy']*100:>9.2f}%"
        )


def _write_predictions(out_path, gold_sentences, pred_sentences):
    """Write predictions as CoNLL-U.

    ID, FORM and UPOS come from the gold tokens; HEAD and DEPREL come from the
    prediction. All other columns are ``_``. The parent directory is created
    if it does not exist.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        for sent_id, (gold_sent, pred_sent) in enumerate(
            zip(gold_sentences, pred_sentences), start=1
        ):
            f.write(f"# sent_id = {sent_id}\n")
            for gold_tok, pred_tok in zip(gold_sent, pred_sent):
                f.write(
                    f"{gold_tok['id']}\t{gold_tok['form']}\t_\t{gold_tok['upos']}\t_\t_\t"
                    f"{pred_tok['head']}\t{pred_tok['deprel']}\t_\t_\n"
                )
            f.write("\n")


@click.command()
@click.option(
    "--model", "-m",
    required=True,
    help="Path to trained model directory",
)
@click.option(
    "--model-type",
    type=click.Choice(["bilstm", "phobert"]),
    default="bilstm",
    help="Model type: bilstm (underthesea) or phobert (transformer)",
    show_default=True,
)
@click.option(
    "--dataset",
    type=click.Choice(["udd1", "ud-vtb"]),
    default="udd1",
    help="Dataset: udd1 (UDD-1) or ud-vtb (UD Vietnamese VTB)",
    show_default=True,
)
@click.option(
    "--split",
    type=click.Choice(["dev", "test", "both"]),
    default="test",
    help="Dataset split to evaluate on",
    show_default=True,
)
@click.option(
    "--detailed",
    is_flag=True,
    help="Show detailed per-relation scores",
)
@click.option(
    "--output", "-o",
    help="Save predictions to file (CoNLL-U format)",
)
def evaluate(model, model_type, dataset, split, detailed, output):
    """Evaluate Bamboo-1 Vietnamese Dependency Parser.

    Supports both BiLSTM (underthesea) and PhoBERT-based models,
    and evaluation on UDD-1 or UD Vietnamese VTB datasets.
    """
    # NOTE: the docstring above is click's --help text; keep it user-facing.
    click.echo("=" * 60)
    click.echo("Bamboo-1: Vietnamese Dependency Parser Evaluation")
    click.echo("=" * 60)

    click.echo(f"\nLoading {model_type} model from {model}...")
    predict_fn = _load_predictor(model, model_type)

    click.echo(f"Loading {dataset.upper()} corpus...")
    corpus = UDD1Corpus() if dataset == "udd1" else UDVietnameseVTB()

    # Only touch the requested split attributes; "both" means dev, then test.
    split_names = ("dev", "test") if split == "both" else (split,)
    splits_to_eval = [(name, getattr(corpus, name)) for name in split_names]

    for split_name, split_path in splits_to_eval:
        click.echo(f"\n{'=' * 40}")
        click.echo(f"Evaluating on {split_name} set: {split_path}")
        click.echo("=" * 40)

        gold_sentences = read_conll_sentences(split_path)
        click.echo(f" Sentences: {len(gold_sentences)}")
        click.echo(f" Tokens: {sum(len(s) for s in gold_sentences)}")

        click.echo("\nMaking predictions...")
        pred_sentences, mismatched = _predict_sentences(predict_fn, gold_sentences)
        if mismatched:
            # Gold and predicted tokens are compared by position, so a length
            # mismatch means the comparison for that sentence is misaligned.
            click.echo(
                f" Warning: {mismatched} sentence(s) were predicted with a different "
                "number of tokens than the gold sentence; their scores may be unreliable.",
                err=True,
            )

        scores = calculate_attachment_scores(gold_sentences, pred_sentences)
        _print_scores(scores, detailed)

        if output:
            # The test split uses the given path as-is; other splits get a suffix.
            out_path = Path(output)
            if split_name != "test":
                out_path = out_path.with_stem(f"{out_path.stem}_{split_name}")

            click.echo(f"\nSaving predictions to {out_path}...")
            _write_predictions(out_path, gold_sentences, pred_sentences)

    click.echo("\nEvaluation complete!")
|
|
|
|
| if __name__ == "__main__": |
| evaluate() |
|
|