# tre-1 / scripts/evaluate.py
# Provenance: rain1024 — "Add word segmentation support and underthesea-core
# integration" (commit 5d8bdc8). Web-page chrome converted to comments so the
# file is runnable.
# /// script
# requires-python = ">=3.9"
# dependencies = [
# "python-crfsuite>=0.9.11",
# "datasets>=4.5.0",
# "scikit-learn>=1.6.1",
# "matplotlib>=3.5.0",
# "seaborn>=0.12.0",
# "click>=8.0.0",
# ]
# ///
"""
Evaluation script for Vietnamese POS Tagger (TRE-1).
Usage:
uv run scripts/evaluate.py
uv run scripts/evaluate.py --version v1.0.0
uv run scripts/evaluate.py --model models/pos_tagger/v1.0.0/model.crfsuite
uv run scripts/evaluate.py --save-plots
"""
import re
from collections import Counter
from pathlib import Path
import click
import pycrfsuite
from datasets import load_dataset
# Repository root: this file lives in <root>/scripts/, so go up two levels.
PROJECT_ROOT = Path(__file__).parent.parent
from sklearn.metrics import (
accuracy_score,
precision_recall_fscore_support,
classification_report,
confusion_matrix,
)
# CRF feature templates, parsed by parse_template():
#   "T[i]"          -> raw token at relative offset i from the current position
#   "T[i].attr"     -> attribute of that token (lower, istitle, prefix2, ...)
#   "T[i,j]"        -> tokens at offsets i and j joined with "|" (bigram)
#   ".is_in_dict"   -> membership of the token (or space-joined bigram) in an
#                      external word list, when one is supplied
FEATURE_TEMPLATES = [
    # Current token and its surface-form attributes.
    "T[0]", "T[0].lower", "T[0].istitle", "T[0].isupper",
    "T[0].isdigit", "T[0].isalpha", "T[0].prefix2", "T[0].prefix3",
    "T[0].suffix2", "T[0].suffix3", "T[-1]", "T[-1].lower",
    # Left and right context windows (up to 2 tokens each side).
    "T[-1].istitle", "T[-1].isupper", "T[-2]", "T[-2].lower",
    "T[1]", "T[1].lower", "T[1].istitle", "T[1].isupper",
    "T[2]", "T[2].lower", "T[-1,0]", "T[0,1]",
    # Dictionary-membership features for unigram and adjacent bigrams.
    "T[0].is_in_dict", "T[-1,0].is_in_dict", "T[0,1].is_in_dict",
]
def get_token_value(tokens, position, index):
    """Return the token at relative offset *index* from *position*.

    Positions before the sequence yield "__BOS__" and positions past the
    end yield "__EOS__", so feature templates never index out of range.
    """
    absolute = position + index
    if 0 <= absolute < len(tokens):
        return tokens[absolute]
    return "__BOS__" if absolute < 0 else "__EOS__"
def apply_attribute(value, attribute, dictionary=None):
    """Transform a token *value* according to a template *attribute*.

    Boundary markers pass through untouched, and an unknown attribute (or
    None) returns the value unchanged. Boolean string-predicate attributes
    are rendered as "True"/"False"; prefix/suffix attributes take the first
    or last n characters (n parsed from the attribute name, default 2).
    """
    if value in ("__BOS__", "__EOS__") or attribute is None:
        return value
    if attribute == "lower":
        return value.lower()
    if attribute == "upper":
        return value.upper()
    # Predicate attributes share the names of str methods; dispatch directly.
    if attribute in ("istitle", "isupper", "islower", "isdigit", "isalpha"):
        return str(getattr(value, attribute)())
    if attribute == "is_in_dict":
        # Without a dictionary the feature is a constant "False".
        return str(value in dictionary) if dictionary else "False"
    if attribute.startswith("prefix") or attribute.startswith("suffix"):
        length = int(attribute[6:]) if len(attribute) > 6 else 2
        # Slicing a short string just returns the whole string, matching
        # the "use the value itself when shorter than n" contract.
        return value[:length] if attribute.startswith("prefix") else value[-length:]
    return value
def parse_template(template):
    """Split a template like "T[-1,0].lower" into (offset list, attribute).

    Returns (None, None) when the string does not match the template
    syntax; the attribute part is optional and may come back as None.
    """
    parsed = re.match(r"T\[([^\]]+)\](?:\.(\w+))?", template)
    if parsed is None:
        return None, None
    offsets = [int(part.strip()) for part in parsed.group(1).split(",")]
    return offsets, parsed.group(2)
def extract_features(tokens, position, dictionary=None):
    """Build the CRF feature dict for the token at *position*.

    Each entry in FEATURE_TEMPLATES maps to one feature keyed by the
    template string itself. Single-offset templates apply their attribute
    to one token; multi-offset templates either join the tokens with "|"
    or, for ``is_in_dict``, test the space-joined phrase against the
    optional *dictionary*.
    """
    features = {}
    for template in FEATURE_TEMPLATES:
        offsets, attribute = parse_template(template)
        if offsets is None:
            # Malformed template: skip rather than fail.
            continue
        if len(offsets) > 1:
            words = [get_token_value(tokens, position, off) for off in offsets]
            if attribute == "is_in_dict":
                phrase = " ".join(words)
                features[template] = str(phrase in dictionary) if dictionary else "False"
            else:
                features[template] = "|".join(words)
        else:
            raw = get_token_value(tokens, position, offsets[0])
            features[template] = apply_attribute(raw, attribute, dictionary)
    return features
def sentence_to_features(tokens, dictionary=None):
    """Convert a sentence's tokens into per-token CRFsuite feature lists.

    Generalization: the original never forwarded a word list, so the three
    ``is_in_dict`` templates were constant "False". An optional *dictionary*
    (a set of words/phrases) is now passed through to extract_features;
    the default of None preserves the previous behavior exactly.

    Returns a list (one element per token) of "key=value" feature strings.
    """
    return [
        [f"{key}={val}" for key, val in extract_features(tokens, i, dictionary).items()]
        for i in range(len(tokens))
    ]
def load_test_data():
    """Load the UDD-1 test split as a list of (tokens, upos-tags) pairs.

    Sentences with empty token or tag sequences are dropped.
    """
    click.echo("Loading UDD-1 dataset...")
    corpus = load_dataset("undertheseanlp/UDD-1")
    sentences = [
        (example["tokens"], example["upos"])
        for example in corpus["test"]
        if example["tokens"] and example["upos"]
    ]
    click.echo(f"Test set: {len(sentences)} sentences")
    return sentences
def plot_confusion_matrix(y_true, y_pred, labels, output_path):
    """Render the tag confusion matrix as a heatmap and save it to disk."""
    # Plotting libraries are imported lazily so the script runs without
    # them unless --save-plots is requested.
    import matplotlib.pyplot as plt
    import seaborn as sns

    matrix = confusion_matrix(y_true, y_pred, labels=labels)
    plt.figure(figsize=(12, 10))
    sns.heatmap(
        matrix,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=labels,
        yticklabels=labels,
    )
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix - Vietnamese POS Tagger (TRE-1)")
    plt.tight_layout()
    plt.savefig(output_path, dpi=150)
    plt.close()
    click.echo(f"Confusion matrix saved to {output_path}")
def plot_per_tag_metrics(report_dict, output_path):
    """Plot grouped precision/recall/F1 bars per POS tag and save to disk.

    *report_dict* is the output of sklearn's classification_report with
    output_dict=True; the aggregate rows are excluded from the plot.
    """
    import matplotlib.pyplot as plt

    aggregate_rows = ("accuracy", "macro avg", "weighted avg")
    tags = [tag for tag in report_dict if tag not in aggregate_rows]
    positions = list(range(len(tags)))
    width = 0.25

    fig, ax = plt.subplots(figsize=(14, 6))
    # (legend label, report key, horizontal offset, bar color)
    series = [
        ("Precision", "precision", -width, "#2ecc71"),
        ("Recall", "recall", 0.0, "#3498db"),
        ("F1-Score", "f1-score", width, "#e74c3c"),
    ]
    for label, key, offset, color in series:
        values = [report_dict[tag][key] for tag in tags]
        ax.bar([p + offset for p in positions], values, width, label=label, color=color)

    ax.set_xlabel("POS Tag")
    ax.set_ylabel("Score")
    ax.set_title("Per-Tag Performance Metrics - Vietnamese POS Tagger (TRE-1)")
    ax.set_xticks(positions)
    ax.set_xticklabels(tags, rotation=45)
    ax.legend()
    ax.set_ylim(0, 1.1)
    ax.grid(axis="y", alpha=0.3)
    plt.tight_layout()
    plt.savefig(output_path, dpi=150)
    plt.close()
    click.echo(f"Per-tag metrics saved to {output_path}")
def analyze_errors(y_true, y_pred, tokens_flat, top_n=10):
    """Print the most frequent (gold, predicted) confusion pairs.

    For each mistagged token, counts the (true tag, predicted tag) pair and
    remembers the first token that exhibited it as an example.
    """
    confusion_counts = Counter()
    first_example = {}
    for gold, guess, token in zip(y_true, y_pred, tokens_flat):
        if gold == guess:
            continue
        pair = (gold, guess)
        confusion_counts[pair] += 1
        first_example.setdefault(pair, token)

    divider = "-" * 60
    click.echo(f"\nTop {top_n} Error Patterns:")
    click.echo(divider)
    click.echo(f"{'True':<10} {'Predicted':<10} {'Count':<8} {'Example'}")
    click.echo(divider)
    for (gold, guess), count in confusion_counts.most_common(top_n):
        sample = first_example.get((gold, guess), "")
        click.echo(f"{gold:<10} {guess:<10} {count:<8} {sample}")
def get_latest_version(task="pos_tagger"):
    """Return the newest version directory name under models/<task>, or None.

    "Newest" means last in lexicographic order, which matches
    timestamp-style directory names. NOTE(review): semver-style names
    (e.g. "v1.10.0" vs "v1.9.0") would sort incorrectly — confirm the
    naming scheme used by training.
    """
    task_dir = PROJECT_ROOT / "models" / task
    if not task_dir.exists():
        return None
    names = sorted(entry.name for entry in task_dir.iterdir() if entry.is_dir())
    return names[-1] if names else None
@click.command()
@click.option(
    "--version", "-v",
    default=None,
    help="Model version to evaluate (default: latest)",
)
@click.option(
    "--model", "-m",
    default=None,
    help="Custom model path (overrides version-based path)",
)
@click.option(
    "--save-plots",
    is_flag=True,
    help="Save confusion matrix and per-tag metrics plots",
)
def evaluate(version, model, save_plots):
    """Evaluate Vietnamese POS Tagger on UDD-1 test set."""
    # Resolve the model to load: explicit --model path wins, then --version,
    # then the latest version directory found on disk.
    if version is None and model is None:
        version = get_latest_version("pos_tagger")
        if version is None:
            raise click.ClickException("No models found in models/pos_tagger/")
    if model:
        model_path = Path(model)
        # Fix: previously, --model without --version left `version` as None,
        # so plot files were named e.g. "confusion_matrix_None.png". Derive a
        # label from the model's directory name (or "custom") instead.
        if version is None:
            version = model_path.parent.name or "custom"
    else:
        model_path = PROJECT_ROOT / "models" / "pos_tagger" / version / "model.crfsuite"
    # Create the plot output directory up front so failures surface early.
    if save_plots:
        results_dir = PROJECT_ROOT / "results" / "pos_tagger"
        results_dir.mkdir(parents=True, exist_ok=True)

    click.echo(f"Loading model from {model_path}...")
    tagger = pycrfsuite.Tagger()
    tagger.open(str(model_path))

    test_data = load_test_data()
    click.echo("Extracting features and predicting...")
    X_test = [sentence_to_features(tokens) for tokens, _ in test_data]
    y_test = [tags for _, tags in test_data]
    tokens_test = [tokens for tokens, _ in test_data]
    y_pred = [tagger.tag(xseq) for xseq in X_test]

    # Flatten sentence-level sequences to token-level lists for sklearn.
    y_test_flat = [tag for tags in y_test for tag in tags]
    y_pred_flat = [tag for tags in y_pred for tag in tags]
    tokens_flat = [token for tokens in tokens_test for token in tokens]
    labels = sorted(set(y_test_flat))

    # Overall token-level metrics.
    accuracy = accuracy_score(y_test_flat, y_pred_flat)
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        y_test_flat, y_pred_flat, average="macro"
    )
    _, _, f1_weighted, _ = precision_recall_fscore_support(
        y_test_flat, y_pred_flat, average="weighted"
    )
    click.echo("\n" + "=" * 60)
    click.echo("EVALUATION RESULTS")
    click.echo("=" * 60)
    click.echo("\nOverall Metrics:")
    click.echo(f" Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    click.echo(f" Precision (macro): {precision_macro:.4f}")
    click.echo(f" Recall (macro): {recall_macro:.4f}")
    click.echo(f" F1 (macro): {f1_macro:.4f}")
    click.echo(f" F1 (weighted): {f1_weighted:.4f}")
    click.echo("\nPer-Tag Classification Report:")
    report = classification_report(y_test_flat, y_pred_flat, digits=4)
    click.echo(report)

    # Most frequent confusion pairs with example tokens.
    analyze_errors(y_test_flat, y_pred_flat, tokens_flat)

    # Test-set tag distribution, for context on the per-tag numbers.
    tag_counts = Counter(y_test_flat)
    total_tokens = len(y_test_flat)
    click.echo("\nTest Set Tag Distribution:")
    click.echo("-" * 40)
    for tag in labels:
        count = tag_counts[tag]
        pct = count / total_tokens * 100
        click.echo(f" {tag:<8} {count:>6} ({pct:>5.2f}%)")

    if save_plots:
        cm_path = results_dir / f"confusion_matrix_{version}.png"
        plot_confusion_matrix(
            y_test_flat, y_pred_flat, labels,
            str(cm_path)
        )
        report_dict = classification_report(
            y_test_flat, y_pred_flat, output_dict=True
        )
        metrics_path = results_dir / f"per_tag_metrics_{version}.png"
        plot_per_tag_metrics(report_dict, str(metrics_path))
    return accuracy
if __name__ == "__main__":
    # Click parses sys.argv, dispatches options, and handles exit codes.
    evaluate()