"""
error_analysis.py
─────────────────
Detailed analysis of model errors on the test set.
Generates confidence distributions, per-class accuracy bars,
and a CSV of all misclassified examples.

Usage
─────
    python error_analysis.py --model roberta-base
    python error_analysis.py --model lr
    python error_analysis.py --model svm
"""

import argparse
import logging
import os
from typing import List, Optional, Tuple

import matplotlib
matplotlib.use('Agg')  # non-interactive backend; must be selected before pyplot is imported
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from sklearn.metrics import accuracy_score

from config import CFG
from data_loader import load_test_only
import traditional_model as tm
import transformer_model as trm

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)-8s %(message)s",
    datefmt="%H:%M:%S",
)
logger = logging.getLogger(__name__)


# ── Probability extraction ─────────────────────────────────────────────────────

def _proba_sklearn(text_list: List[str], pipeline) -> np.ndarray:
    """Return a (n_samples, n_classes) probability matrix from *pipeline*.

    If the final pipeline step exposes ``predict_proba`` it is used directly.
    Otherwise (e.g. LinearSVC) the decision scores are turned into
    pseudo-probabilities with a numerically stable softmax.
    """
    clf = list(pipeline.named_steps.values())[-1]
    if hasattr(clf, "predict_proba"):
        return pipeline.predict_proba(text_list)

    scores = np.asarray(pipeline.decision_function(text_list), dtype=float)
    if scores.ndim == 1:
        # Binary problems yield one score per sample (for the positive
        # class); pair it with a zero score so the softmax below reduces
        # to a sigmoid and the result still has one column per class.
        scores = np.column_stack([np.zeros_like(scores), scores])

    # Subtract the row maximum before exponentiating to avoid overflow.
    shifted = scores - scores.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def _proba_transformer(
    text_list: List[str],
    model,
    tokenizer,
    batch_size: int = 32,
) -> np.ndarray:
    """Return softmax probabilities for *text_list*, computed in batches.

    Parameters
    ----------
    text_list : texts to classify.
    model : callable returning an object with a ``logits`` tensor.
    tokenizer : callable producing a mapping of input tensors for *model*.
    batch_size : number of texts per forward pass.

    Returns
    -------
    Array with one row per input text and one column per logit.
    """
    # Run on whatever device the model already lives on; fall back to CPU
    # for a model without parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # Disable dropout etc. so the probabilities are deterministic.
    model.eval()

    all_probs = []
    with torch.no_grad():
        for start in range(0, len(text_list), batch_size):
            batch = text_list[start:start + batch_size]
            enc = tokenizer(
                batch,
                truncation=True,
                max_length=CFG.max_length,
                padding=True,
                return_tensors="pt",
            )
            enc = {key: value.to(device) for key, value in enc.items()}
            logits = model(**enc).logits
            # Move back to CPU before converting: numpy cannot read
            # accelerator memory.
            all_probs.append(torch.softmax(logits, dim=-1).cpu().numpy())
    return np.vstack(all_probs)


# ── Main analysis ─────────────────────────────────────────────────────────────

def analyse(model_name: str, save_dir: Optional[str] = None) -> pd.DataFrame:
    """
    Full error analysis pipeline.

    Loads the test set, predicts with the requested model, prints a console
    report, draws the analysis figure and, when *save_dir* is given, writes
    the misclassified rows to a CSV file in that directory.

    Returns
    -------
    DataFrame of all misclassified examples.
    """
    logger.info("Loading test set …")
    X_test, y_test = load_test_only()

    logger.info("Running predictions with: %s", model_name)
    if model_name in ("lr", "svm"):
        pipeline = tm.load_model(model_name)
        proba = _proba_sklearn(X_test, pipeline)
    else:
        model, tokenizer = trm.load_model(model_name)
        proba = _proba_transformer(X_test, model, tokenizer)
    preds = proba.argmax(axis=1).tolist()

    acc = accuracy_score(y_test, preds)
    logger.info("Test accuracy: %.2f%%", acc * 100)

    # Build analysis DataFrame
    df = pd.DataFrame({
        "text": X_test,
        "true_label": [CFG.label_names[y] for y in y_test],
        "pred_label": [CFG.label_names[p] for p in preds],
        "confidence": proba.max(axis=1),
        "correct": [int(y) == int(p) for y, p in zip(y_test, preds)],
    })
    for i, name in enumerate(CFG.label_names):
        df[f"prob_{name}"] = proba[:, i]

    errors = df[~df["correct"].astype(bool)]

    _print_report(df, errors, model_name)
    _plot_analysis(df, model_name, save_dir)

    # ── Save CSV ─────────────────────────────────────────────────────────────
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        csv_path = os.path.join(save_dir, f"errors_{model_name.replace('-','_')}.csv")
        errors.to_csv(csv_path, index=False)
        logger.info("Error CSV → %s", csv_path)

    return errors


def _print_report(df: pd.DataFrame, errors: pd.DataFrame, model_name: str) -> None:
    """Print the console summary: totals, per-class errors, confused pairs, hardest errors."""
    total = len(df)
    n_errors = len(errors)
    n_correct = total - n_errors

    def pct(count: int) -> float:
        # Guard against an empty test set instead of dividing by zero.
        return count / total * 100 if total else 0.0

    print("\n" + "═" * 60)
    print(f" ERROR ANALYSIS — {model_name.upper()}")
    print("═" * 60)
    print(f" Total : {total:,}")
    print(f" Correct : {n_correct:,} ({pct(n_correct):.2f}%)")
    print(f" Errors : {n_errors:,} ({pct(n_errors):.2f}%)")

    print("\n Errors by true class:")
    for label in CFG.label_names:
        n = len(errors[errors["true_label"] == label])
        print(f" {label:<12} {n:>4} errors")

    print("\n Top confused pairs (True → Predicted):")
    confused = (
        errors.groupby(["true_label", "pred_label"])
        .size()
        .sort_values(ascending=False)
        .head(6)
    )
    for (true, pred), count in confused.items():
        print(f" {true:<12} → {pred:<12} {count:>4} times")

    print("\n 5 Hardest Errors (lowest confidence):")
    for _, row in errors.nsmallest(5, "confidence").iterrows():
        text = str(row["text"])
        # Only mark the snippet as truncated when something was cut off.
        snippet = text[:75] + "…" if len(text) > 75 else text
        print(f" [{row['true_label']} → {row['pred_label']} conf={row['confidence']:.3f}]")
        print(f" {snippet}\n")


def _plot_analysis(df: pd.DataFrame, model_name: str, save_dir: Optional[str] = None) -> None:
    """Two-panel figure: confidence distribution + per-class accuracy bars.

    The figure is written to ``analysis_<model>.png`` inside *save_dir* when
    *save_dir* is given, and is always closed before returning.
    """
    fig, axes = plt.subplots(1, 2, figsize=(13, 5))
    fig.suptitle(f"Error Analysis — {model_name}", fontsize=14, fontweight="bold")

    # Panel 1: Confidence histograms
    is_correct = df["correct"].astype(bool)
    correct_conf = df[is_correct]["confidence"]
    error_conf = df[~is_correct]["confidence"]

    axes[0].hist(correct_conf, bins=30, alpha=0.75, color="#27ae60",
                 label=f"Correct (n={len(correct_conf):,})")
    axes[0].hist(error_conf, bins=30, alpha=0.75, color="#e74c3c",
                 label=f"Incorrect (n={len(error_conf):,})")
    axes[0].set_xlabel("Prediction Confidence", fontsize=11)
    axes[0].set_ylabel("Count", fontsize=11)
    axes[0].set_title("Confidence Distribution", fontsize=12)
    # The mean of an empty group is NaN, so only draw a mean line when the
    # group has data.
    if len(correct_conf):
        axes[0].axvline(correct_conf.mean(), color="#27ae60", linestyle="--", linewidth=1.2,
                        label=f"Mean correct: {correct_conf.mean():.3f}")
    if len(error_conf):
        axes[0].axvline(error_conf.mean(), color="#e74c3c", linestyle="--", linewidth=1.2,
                        label=f"Mean error: {error_conf.mean():.3f}")
    # Build the legend last so it also lists the mean lines added above.
    axes[0].legend(fontsize=10)

    # Panel 2: Per-class accuracy
    colours = ["#3498db", "#27ae60", "#e67e22", "#9b59b6"]
    class_accs = [
        df[df["true_label"] == lbl]["correct"].astype(float).mean() * 100
        for lbl in CFG.label_names
    ]
    bars = axes[1].bar(CFG.label_names, class_accs, color=colours,
                       edgecolor="white", linewidth=1.5)

    # Keep the zoomed-in 80–100 range when every class fits in it; otherwise
    # lower the axis floor so bars below 80 % are not clipped out of view.
    finite_accs = [a for a in class_accs if not np.isnan(a)]
    lowest = min(finite_accs, default=100.0)
    y_min = 80 if lowest >= 80 else max(0.0, float(np.floor(lowest)) - 5)
    axes[1].set_ylim(y_min, 100)
    axes[1].set_xlabel("Class", fontsize=11)
    axes[1].set_ylabel("Accuracy (%)", fontsize=11)
    axes[1].set_title("Per-Class Accuracy", fontsize=12)
    for bar, acc in zip(bars, class_accs):
        if np.isnan(acc):
            # Class absent from the test set: nothing meaningful to label.
            continue
        axes[1].text(bar.get_x() + bar.get_width() / 2,
                     bar.get_height() + 0.3,
                     f"{acc:.1f}%", ha="center", va="bottom",
                     fontsize=11, fontweight="bold")

    plt.tight_layout()

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        path = os.path.join(save_dir, f"analysis_{model_name.replace('-','_')}.png")
        fig.savefig(path, dpi=150)
        logger.info("Plot → %s", path)
    # No plt.show(): the module selects the non-GUI Agg backend, which
    # cannot display a window.
    plt.close(fig)


def main() -> None:
    """Parse ``--model`` from the command line and run the error analysis."""
    parser = argparse.ArgumentParser(description="Document classifier error analysis")
    parser.add_argument(
        "--model",
        default="roberta-base",
        help="Model name: 'lr', 'svm', or transformer checkpoint (e.g. 'roberta-base')"
    )
    args = parser.parse_args()
    save_dir = os.path.join(CFG.outputs_dir, "error_analysis")
    analyse(args.model, save_dir=save_dir)


if __name__ == "__main__":
    main()