# /// script
# requires-python = ">=3.9"
# dependencies = [
#     "python-crfsuite>=0.9.11",
#     "datasets>=4.5.0",
#     "scikit-learn>=1.6.1",
#     "matplotlib>=3.5.0",
#     "seaborn>=0.12.0",
#     "click>=8.0.0",
# ]
# ///
"""
Evaluation script for Vietnamese POS Tagger (TRE-1).

Usage:
    uv run scripts/evaluate.py
    uv run scripts/evaluate.py --version v1.0.0
    uv run scripts/evaluate.py --model models/pos_tagger/v1.0.0/model.crfsuite
    uv run scripts/evaluate.py --save-plots
"""
import re
from collections import Counter
from pathlib import Path

import click
import pycrfsuite
from datasets import load_dataset
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    classification_report,
    confusion_matrix,
)

# Get project root directory (this script is assumed to live in scripts/)
PROJECT_ROOT = Path(__file__).parent.parent

# CRF feature templates. T[i] is the token at relative offset i from the
# current position; T[i,j] concatenates the tokens at offsets i and j.
# An optional ".attr" suffix applies a transformation (see apply_attribute).
FEATURE_TEMPLATES = [
    "T[0]",
    "T[0].lower",
    "T[0].istitle",
    "T[0].isupper",
    "T[0].isdigit",
    "T[0].isalpha",
    "T[0].prefix2",
    "T[0].prefix3",
    "T[0].suffix2",
    "T[0].suffix3",
    "T[-1]",
    "T[-1].lower",
    "T[-1].istitle",
    "T[-1].isupper",
    "T[-2]",
    "T[-2].lower",
    "T[1]",
    "T[1].lower",
    "T[1].istitle",
    "T[1].isupper",
    "T[2]",
    "T[2].lower",
    "T[-1,0]",
    "T[0,1]",
    "T[0].is_in_dict",
    "T[-1,0].is_in_dict",
    "T[0,1].is_in_dict",
]


def get_token_value(tokens, position, index):
    """Return the token at `position + index`, or a boundary sentinel.

    Positions before the sentence start yield "__BOS__"; positions past the
    end yield "__EOS__".
    """
    actual_pos = position + index
    if actual_pos < 0:
        return "__BOS__"
    elif actual_pos >= len(tokens):
        return "__EOS__"
    return tokens[actual_pos]


def apply_attribute(value, attribute, dictionary=None):
    """Apply a feature-template attribute (e.g. "lower", "prefix2") to a token.

    Boundary sentinels ("__BOS__"/"__EOS__") pass through untouched so that
    boundary features stay distinct. Boolean attributes are stringified
    ("True"/"False") because CRFsuite features are strings. `dictionary`
    backs the "is_in_dict" attribute; with no dictionary it is always "False".
    """
    if value in ("__BOS__", "__EOS__"):
        return value
    if attribute is None:
        return value
    elif attribute == "lower":
        return value.lower()
    elif attribute == "upper":
        return value.upper()
    elif attribute == "istitle":
        return str(value.istitle())
    elif attribute == "isupper":
        return str(value.isupper())
    elif attribute == "islower":
        return str(value.islower())
    elif attribute == "isdigit":
        return str(value.isdigit())
    elif attribute == "isalpha":
        return str(value.isalpha())
    elif attribute == "is_in_dict":
        return str(value in dictionary) if dictionary else "False"
    elif attribute.startswith("prefix"):
        # "prefixN" -> first N characters (default 2); short tokens pass whole.
        n = int(attribute[6:]) if len(attribute) > 6 else 2
        return value[:n] if len(value) >= n else value
    elif attribute.startswith("suffix"):
        # "suffixN" -> last N characters (default 2); short tokens pass whole.
        n = int(attribute[6:]) if len(attribute) > 6 else 2
        return value[-n:] if len(value) >= n else value
    return value


def parse_template(template):
    """Parse a template like "T[-1,0].lower" into ([-1, 0], "lower").

    Returns (None, None) for strings that do not match the template grammar.
    """
    match = re.match(r"T\[([^\]]+)\](?:\.(\w+))?", template)
    if not match:
        return None, None
    indices_str = match.group(1)
    attribute = match.group(2)
    indices = [int(i.strip()) for i in indices_str.split(",")]
    return indices, attribute


def extract_features(tokens, position, dictionary=None):
    """Build the feature dict for `tokens[position]` from FEATURE_TEMPLATES.

    Single-index templates apply the attribute to one token; multi-index
    templates join tokens with "|" (or with " " for the is_in_dict lookup,
    which tests the combined phrase against `dictionary`).
    """
    features = {}
    for template in FEATURE_TEMPLATES:
        indices, attribute = parse_template(template)
        if indices is None:
            continue
        if len(indices) == 1:
            value = get_token_value(tokens, position, indices[0])
            value = apply_attribute(value, attribute, dictionary)
            features[template] = value
        else:
            values = [get_token_value(tokens, position, idx) for idx in indices]
            if attribute == "is_in_dict":
                combined = " ".join(values)
                features[template] = str(combined in dictionary) if dictionary else "False"
            else:
                combined = "|".join(values)
                features[template] = combined
    return features


def sentence_to_features(tokens):
    """Convert a sentence into per-token "key=value" feature lists.

    NOTE(review): no dictionary is passed here, so all is_in_dict features
    are constant "False" at evaluation time — this must match how the model
    was trained; verify against the training script.
    """
    return [
        [f"{k}={v}" for k, v in extract_features(tokens, i).items()]
        for i in range(len(tokens))
    ]


def load_test_data():
    """Load the UDD-1 test split as a list of (tokens, upos_tags) pairs."""
    click.echo("Loading UDD-1 dataset...")
    dataset = load_dataset("undertheseanlp/UDD-1")
    sentences = []
    for item in dataset["test"]:
        tokens = item["tokens"]
        tags = item["upos"]
        if tokens and tags:
            sentences.append((tokens, tags))
    click.echo(f"Test set: {len(sentences)} sentences")
    return sentences


def plot_confusion_matrix(y_true, y_pred, labels, output_path):
    """Render and save a labelled confusion-matrix heatmap as a PNG."""
    # Imported lazily so plotting deps are only needed with --save-plots.
    import matplotlib.pyplot as plt
    import seaborn as sns

    cm = confusion_matrix(y_true, y_pred, labels=labels)
    plt.figure(figsize=(12, 10))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=labels,
        yticklabels=labels,
    )
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix - Vietnamese POS Tagger (TRE-1)")
    plt.tight_layout()
    plt.savefig(output_path, dpi=150)
    plt.close()
    click.echo(f"Confusion matrix saved to {output_path}")


def plot_per_tag_metrics(report_dict, output_path):
    """Save a grouped bar chart of per-tag precision/recall/F1 as a PNG.

    `report_dict` is the output of classification_report(..., output_dict=True);
    the aggregate rows ("accuracy", "macro avg", "weighted avg") are skipped.
    """
    import matplotlib.pyplot as plt

    tags = [k for k in report_dict.keys() if k not in ("accuracy", "macro avg", "weighted avg")]
    precision = [report_dict[t]["precision"] for t in tags]
    recall = [report_dict[t]["recall"] for t in tags]
    f1 = [report_dict[t]["f1-score"] for t in tags]

    x = range(len(tags))
    width = 0.25

    fig, ax = plt.subplots(figsize=(14, 6))
    ax.bar([i - width for i in x], precision, width, label="Precision", color="#2ecc71")
    ax.bar(x, recall, width, label="Recall", color="#3498db")
    ax.bar([i + width for i in x], f1, width, label="F1-Score", color="#e74c3c")
    ax.set_xlabel("POS Tag")
    ax.set_ylabel("Score")
    ax.set_title("Per-Tag Performance Metrics - Vietnamese POS Tagger (TRE-1)")
    ax.set_xticks(x)
    ax.set_xticklabels(tags, rotation=45)
    ax.legend()
    ax.set_ylim(0, 1.1)
    ax.grid(axis="y", alpha=0.3)
    plt.tight_layout()
    plt.savefig(output_path, dpi=150)
    plt.close()
    click.echo(f"Per-tag metrics saved to {output_path}")


def analyze_errors(y_true, y_pred, tokens_flat, top_n=10):
    """Analyze common error patterns."""
    errors = Counter()
    error_examples = {}
    for true, pred, token in zip(y_true, y_pred, tokens_flat):
        if true != pred:
            key = (true, pred)
            errors[key] += 1
            # Keep the first token seen for each (true, pred) confusion.
            if key not in error_examples:
                error_examples[key] = token

    click.echo(f"\nTop {top_n} Error Patterns:")
    click.echo("-" * 60)
    click.echo(f"{'True':<10} {'Predicted':<10} {'Count':<8} {'Example'}")
    click.echo("-" * 60)
    for (true, pred), count in errors.most_common(top_n):
        example = error_examples.get((true, pred), "")
        click.echo(f"{true:<10} {pred:<10} {count:<8} {example}")


def get_latest_version(task="pos_tagger"):
    """Get the latest model version (sorted by timestamp)."""
    models_dir = PROJECT_ROOT / "models" / task
    if not models_dir.exists():
        return None
    versions = [d.name for d in models_dir.iterdir() if d.is_dir()]
    if not versions:
        return None
    return sorted(versions)[-1]  # Latest timestamp


@click.command()
@click.option(
    "--version",
    "-v",
    default=None,
    help="Model version to evaluate (default: latest)",
)
@click.option(
    "--model",
    "-m",
    default=None,
    help="Custom model path (overrides version-based path)",
)
@click.option(
    "--save-plots",
    is_flag=True,
    help="Save confusion matrix and per-tag metrics plots",
)
def evaluate(version, model, save_plots):
    """Evaluate Vietnamese POS Tagger on UDD-1 test set."""
    # Use latest version if not specified
    if version is None and model is None:
        version = get_latest_version("pos_tagger")
        if version is None:
            raise click.ClickException("No models found in models/pos_tagger/")

    # Determine model path
    if model:
        model_path = Path(model)
    else:
        model_path = PROJECT_ROOT / "models" / "pos_tagger" / version / "model.crfsuite"

    # Determine output directory for plots
    if save_plots:
        results_dir = PROJECT_ROOT / "results" / "pos_tagger"
        results_dir.mkdir(parents=True, exist_ok=True)

    click.echo(f"Loading model from {model_path}...")
    tagger = pycrfsuite.Tagger()
    tagger.open(str(model_path))

    test_data = load_test_data()

    click.echo("Extracting features and predicting...")
    X_test = [sentence_to_features(tokens) for tokens, _ in test_data]
    y_test = [tags for _, tags in test_data]
    tokens_test = [tokens for tokens, _ in test_data]

    y_pred = [tagger.tag(xseq) for xseq in X_test]

    # Flatten per-sentence sequences into token-level lists for sklearn.
    y_test_flat = [tag for tags in y_test for tag in tags]
    y_pred_flat = [tag for tags in y_pred for tag in tags]
    tokens_flat = [token for tokens in tokens_test for token in tokens]

    # Get unique labels
    labels = sorted(set(y_test_flat))

    # Overall metrics
    accuracy = accuracy_score(y_test_flat, y_pred_flat)
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        y_test_flat, y_pred_flat, average="macro"
    )
    _, _, f1_weighted, _ = precision_recall_fscore_support(
        y_test_flat, y_pred_flat, average="weighted"
    )

    click.echo("\n" + "=" * 60)
    click.echo("EVALUATION RESULTS")
    click.echo("=" * 60)
    click.echo("\nOverall Metrics:")
    click.echo(f"  Accuracy:          {accuracy:.4f} ({accuracy*100:.2f}%)")
    click.echo(f"  Precision (macro): {precision_macro:.4f}")
    click.echo(f"  Recall (macro):    {recall_macro:.4f}")
    click.echo(f"  F1 (macro):        {f1_macro:.4f}")
    click.echo(f"  F1 (weighted):     {f1_weighted:.4f}")

    click.echo("\nPer-Tag Classification Report:")
    report = classification_report(y_test_flat, y_pred_flat, digits=4)
    click.echo(report)

    # Error analysis
    analyze_errors(y_test_flat, y_pred_flat, tokens_flat)

    # Dataset statistics
    tag_counts = Counter(y_test_flat)
    total_tokens = len(y_test_flat)
    click.echo("\nTest Set Tag Distribution:")
    click.echo("-" * 40)
    for tag in labels:
        count = tag_counts[tag]
        pct = count / total_tokens * 100
        click.echo(f"  {tag:<8} {count:>6} ({pct:>5.2f}%)")

    if save_plots:
        # BUGFIX: when only --model is given, `version` is None and the old
        # code produced filenames like "confusion_matrix_None.png". Label the
        # artifacts from the model's parent directory name in that case.
        plot_label = version if version is not None else model_path.parent.name
        cm_path = results_dir / f"confusion_matrix_{plot_label}.png"
        plot_confusion_matrix(
            y_test_flat, y_pred_flat, labels, str(cm_path)
        )
        report_dict = classification_report(
            y_test_flat, y_pred_flat, output_dict=True
        )
        metrics_path = results_dir / f"per_tag_metrics_{plot_label}.png"
        plot_per_tag_metrics(report_dict, str(metrics_path))

    return accuracy


if __name__ == "__main__":
    evaluate()