""" |
|
|
Evaluation script for Vietnamese POS Tagger (TRE-1). |
|
|
|
|
|
Usage: |
|
|
uv run scripts/evaluate.py |
|
|
uv run scripts/evaluate.py --version v1.0.0 |
|
|
uv run scripts/evaluate.py --model models/pos_tagger/v1.0.0/model.crfsuite |
|
|
uv run scripts/evaluate.py --save-plots |
|
|
""" |

import re
from collections import Counter
from pathlib import Path

import click
import pycrfsuite
from datasets import load_dataset
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    classification_report,
    confusion_matrix,
)

PROJECT_ROOT = Path(__file__).parent.parent
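
# Feature template mini-language: "T[i]" is the token at offset i from the
# current position (T[0] = current, T[-1] = previous), an optional ".attr"
# suffix applies a transformation, and "T[i,j]" joins several offsets into
# an n-gram. See parse_template() and extract_features() below.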
FEATURE_TEMPLATES = [
    "T[0]", "T[0].lower", "T[0].istitle", "T[0].isupper",
    "T[0].isdigit", "T[0].isalpha", "T[0].prefix2", "T[0].prefix3",
    "T[0].suffix2", "T[0].suffix3", "T[-1]", "T[-1].lower",
    "T[-1].istitle", "T[-1].isupper", "T[-2]", "T[-2].lower",
    "T[1]", "T[1].lower", "T[1].istitle", "T[1].isupper",
    "T[2]", "T[2].lower", "T[-1,0]", "T[0,1]",
    "T[0].is_in_dict", "T[-1,0].is_in_dict", "T[0,1].is_in_dict",
]
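

# Context lookups are padded with sentinel markers at sentence boundaries.
# For example, with tokens ["Tôi", "yêu"], get_token_value(tokens, 0, -1)
# returns "__BOS__" and get_token_value(tokens, 1, 1) returns "__EOS__".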
def get_token_value(tokens, position, index):
    """Return the token at `position + index`, padded with boundary markers."""
    actual_pos = position + index
    if actual_pos < 0:
        return "__BOS__"
    elif actual_pos >= len(tokens):
        return "__EOS__"
    return tokens[actual_pos]
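

# Attribute examples (plain Python string semantics):
# apply_attribute("Hà_Nội", "istitle") -> "True"
# apply_attribute("ngôn", "suffix2")   -> "ôn"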
def apply_attribute(value, attribute, dictionary=None):
    """Apply a feature-template attribute to a token value."""
    if value in ("__BOS__", "__EOS__"):
        return value
    if attribute is None:
        return value
    elif attribute == "lower":
        return value.lower()
    elif attribute == "upper":
        return value.upper()
    elif attribute == "istitle":
        return str(value.istitle())
    elif attribute == "isupper":
        return str(value.isupper())
    elif attribute == "islower":
        return str(value.islower())
    elif attribute == "isdigit":
        return str(value.isdigit())
    elif attribute == "isalpha":
        return str(value.isalpha())
    elif attribute == "is_in_dict":
        return str(value in dictionary) if dictionary else "False"
    elif attribute.startswith("prefix"):
        n = int(attribute[6:]) if len(attribute) > 6 else 2
        return value[:n] if len(value) >= n else value
    elif attribute.startswith("suffix"):
        n = int(attribute[6:]) if len(attribute) > 6 else 2
        return value[-n:] if len(value) >= n else value
    return value
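

# parse_template("T[0].lower") returns ([0], "lower");
# parse_template("T[-1,0].is_in_dict") returns ([-1, 0], "is_in_dict").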
def parse_template(template):
    """Split a template like "T[-1,0].lower" into (offsets, attribute)."""
    match = re.match(r"T\[([^\]]+)\](?:\.(\w+))?", template)
    if not match:
        return None, None
    indices_str = match.group(1)
    attribute = match.group(2)
    indices = [int(i.strip()) for i in indices_str.split(",")]
    return indices, attribute
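

# Example: extract_features(["Tôi", "yêu", "Việt_Nam"], 1) includes, among
# other entries, {"T[0]": "yêu", "T[-1]": "Tôi", "T[-1,0]": "Tôi|yêu"}.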
def extract_features(tokens, position, dictionary=None):
    """Build the feature dict for the token at `position`."""
    features = {}
    for template in FEATURE_TEMPLATES:
        indices, attribute = parse_template(template)
        if indices is None:
            continue
        if len(indices) == 1:
            value = get_token_value(tokens, position, indices[0])
            value = apply_attribute(value, attribute, dictionary)
            features[template] = value
        else:
            values = [get_token_value(tokens, position, idx) for idx in indices]
            if attribute == "is_in_dict":
                combined = " ".join(values)
                features[template] = str(combined in dictionary) if dictionary else "False"
            else:
                combined = "|".join(values)
                features[template] = combined
    return features
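

# Note: sentence_to_features() below passes no dictionary, so every
# *.is_in_dict feature is "False" at evaluation time. That is presumably how
# the model was trained, but it is worth verifying against the training code.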
def sentence_to_features(tokens):
    """Convert a sentence to pycrfsuite input: one "name=value" list per token."""
    return [
        [f"{k}={v}" for k, v in extract_features(tokens, i).items()]
        for i in range(len(tokens))
    ]


def load_test_data():
    """Load the UDD-1 test split as (tokens, tags) pairs."""
    click.echo("Loading UDD-1 dataset...")
    dataset = load_dataset("undertheseanlp/UDD-1")

    sentences = []
    for item in dataset["test"]:
        tokens = item["tokens"]
        tags = item["upos"]
        if tokens and tags:
            sentences.append((tokens, tags))

    click.echo(f"Test set: {len(sentences)} sentences")
    return sentences


def plot_confusion_matrix(y_true, y_pred, labels, output_path):
    """Plot and save a confusion matrix heatmap."""
    import matplotlib.pyplot as plt
    import seaborn as sns

    cm = confusion_matrix(y_true, y_pred, labels=labels)

    plt.figure(figsize=(12, 10))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=labels,
        yticklabels=labels,
    )
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix - Vietnamese POS Tagger (TRE-1)")
    plt.tight_layout()
    plt.savefig(output_path, dpi=150)
    plt.close()
    click.echo(f"Confusion matrix saved to {output_path}")


def plot_per_tag_metrics(report_dict, output_path):
    """Plot grouped precision/recall/F1 bars for each POS tag."""
    import matplotlib.pyplot as plt

    tags = [k for k in report_dict.keys() if k not in ("accuracy", "macro avg", "weighted avg")]

    precision = [report_dict[t]["precision"] for t in tags]
    recall = [report_dict[t]["recall"] for t in tags]
    f1 = [report_dict[t]["f1-score"] for t in tags]

    x = range(len(tags))
    width = 0.25

    fig, ax = plt.subplots(figsize=(14, 6))
    ax.bar([i - width for i in x], precision, width, label="Precision", color="#2ecc71")
    ax.bar(x, recall, width, label="Recall", color="#3498db")
    ax.bar([i + width for i in x], f1, width, label="F1-Score", color="#e74c3c")

    ax.set_xlabel("POS Tag")
    ax.set_ylabel("Score")
    ax.set_title("Per-Tag Performance Metrics - Vietnamese POS Tagger (TRE-1)")
    ax.set_xticks(x)
    ax.set_xticklabels(tags, rotation=45)
    ax.legend()
    ax.set_ylim(0, 1.1)
    ax.grid(axis="y", alpha=0.3)

    plt.tight_layout()
    plt.savefig(output_path, dpi=150)
    plt.close()
    click.echo(f"Per-tag metrics saved to {output_path}")


def analyze_errors(y_true, y_pred, tokens_flat, top_n=10):
    """Analyze common error patterns."""
    errors = Counter()
    error_examples = {}

    for true, pred, token in zip(y_true, y_pred, tokens_flat):
        if true != pred:
            key = (true, pred)
            errors[key] += 1
            if key not in error_examples:
                error_examples[key] = token

    click.echo(f"\nTop {top_n} Error Patterns:")
    click.echo("-" * 60)
    click.echo(f"{'True':<10} {'Predicted':<10} {'Count':<8} {'Example'}")
    click.echo("-" * 60)

    for (true, pred), count in errors.most_common(top_n):
        example = error_examples.get((true, pred), "")
        click.echo(f"{true:<10} {pred:<10} {count:<8} {example}")
def get_latest_version(task="pos_tagger"):
    """Get the latest model version (last directory name in sorted order)."""
    models_dir = PROJECT_ROOT / "models" / task
    if not models_dir.exists():
        return None
    versions = [d.name for d in models_dir.iterdir() if d.is_dir()]
    if not versions:
        return None
    return sorted(versions)[-1]


@click.command()
@click.option(
    "--version", "-v",
    default=None,
    help="Model version to evaluate (default: latest)",
)
@click.option(
    "--model", "-m",
    default=None,
    help="Custom model path (overrides version-based path)",
)
@click.option(
    "--save-plots",
    is_flag=True,
    help="Save confusion matrix and per-tag metrics plots",
)
def evaluate(version, model, save_plots):
    """Evaluate Vietnamese POS Tagger on UDD-1 test set."""

    if version is None and model is None:
        version = get_latest_version("pos_tagger")
        if version is None:
            raise click.ClickException("No models found in models/pos_tagger/")

    if model:
        model_path = Path(model)
        if version is None:
            # Label plot filenames by the model's parent directory when only a
            # custom path is given (otherwise they would be tagged "None").
            version = model_path.parent.name or "custom"
    else:
        model_path = PROJECT_ROOT / "models" / "pos_tagger" / version / "model.crfsuite"

    if save_plots:
        results_dir = PROJECT_ROOT / "results" / "pos_tagger"
        results_dir.mkdir(parents=True, exist_ok=True)

    click.echo(f"Loading model from {model_path}...")
    tagger = pycrfsuite.Tagger()
    tagger.open(str(model_path))

    test_data = load_test_data()

    click.echo("Extracting features and predicting...")
    X_test = [sentence_to_features(tokens) for tokens, _ in test_data]
    y_test = [tags for _, tags in test_data]
    tokens_test = [tokens for tokens, _ in test_data]
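
    # pycrfsuite decodes one sentence at a time; each xseq is the list of
    # per-token "name=value" feature lists built by sentence_to_features().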
    y_pred = [tagger.tag(xseq) for xseq in X_test]

    # Flatten to token level for sklearn metrics.
    y_test_flat = [tag for tags in y_test for tag in tags]
    y_pred_flat = [tag for tags in y_pred for tag in tags]
    tokens_flat = [token for tokens in tokens_test for token in tokens]

    labels = sorted(set(y_test_flat))

    accuracy = accuracy_score(y_test_flat, y_pred_flat)
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        y_test_flat, y_pred_flat, average="macro"
    )
    _, _, f1_weighted, _ = precision_recall_fscore_support(
        y_test_flat, y_pred_flat, average="weighted"
    )

    click.echo("\n" + "=" * 60)
    click.echo("EVALUATION RESULTS")
    click.echo("=" * 60)

    click.echo("\nOverall Metrics:")
    click.echo(f"  Accuracy:           {accuracy:.4f} ({accuracy*100:.2f}%)")
    click.echo(f"  Precision (macro):  {precision_macro:.4f}")
    click.echo(f"  Recall (macro):     {recall_macro:.4f}")
    click.echo(f"  F1 (macro):         {f1_macro:.4f}")
    click.echo(f"  F1 (weighted):      {f1_weighted:.4f}")

    click.echo("\nPer-Tag Classification Report:")
    report = classification_report(y_test_flat, y_pred_flat, digits=4)
    click.echo(report)

    analyze_errors(y_test_flat, y_pred_flat, tokens_flat)

    tag_counts = Counter(y_test_flat)
    total_tokens = len(y_test_flat)

    click.echo("\nTest Set Tag Distribution:")
    click.echo("-" * 40)
    for tag in labels:
        count = tag_counts[tag]
        pct = count / total_tokens * 100
        click.echo(f"  {tag:<8} {count:>6} ({pct:>5.2f}%)")

    if save_plots:
        cm_path = results_dir / f"confusion_matrix_{version}.png"
        plot_confusion_matrix(
            y_test_flat, y_pred_flat, labels,
            str(cm_path)
        )

        report_dict = classification_report(
            y_test_flat, y_pred_flat, output_dict=True
        )
        metrics_path = results_dir / f"per_tag_metrics_{version}.png"
        plot_per_tag_metrics(report_dict, str(metrics_path))

    return accuracy


if __name__ == "__main__":
    evaluate()