Spaces:
Running
Running
| """ | |
| error_analysis.py | |
| ───────────────── | |
| Detailed analysis of model errors on the test set. | |
| Generates confidence distributions, per-class accuracy bars, | |
| and a CSV of the hardest misclassified examples. | |
| Usage | |
| ───── | |
| python error_analysis.py --model roberta-base | |
| python error_analysis.py --model lr | |
| python error_analysis.py --model svm | |
| """ | |
| import argparse | |
| import logging | |
| import os | |
| from typing import List, Tuple | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import pandas as pd | |
| import seaborn as sns | |
| import torch | |
| from sklearn.metrics import accuracy_score | |
| from config import CFG | |
| from data_loader import load_test_only | |
| import traditional_model as tm | |
| import transformer_model as trm | |
# Root logging configuration: INFO and above, with a short wall-clock
# timestamp and a fixed-width level column.
logging.basicConfig(
    datefmt="%H:%M:%S",
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=logging.INFO,
)

# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
| # ── Probability extraction ───────────────────────────────────────────────────── | |
| def _proba_sklearn(text_list: List[str], pipeline) -> np.ndarray: | |
| clf = list(pipeline.named_steps.values())[-1] | |
| if hasattr(clf, "predict_proba"): | |
| return pipeline.predict_proba(text_list) | |
| # LinearSVC: convert decision scores to pseudo-probabilities via softmax | |
| scores = pipeline.decision_function(text_list) | |
| scores -= scores.max(axis=1, keepdims=True) | |
| exp = np.exp(scores) | |
| return exp / exp.sum(axis=1, keepdims=True) | |
def _proba_transformer(
    text_list: List[str], model, tokenizer, batch_size: int = 32
) -> np.ndarray:
    """Return softmax class probabilities for *text_list* from a transformer.

    Texts are tokenised in batches (truncated to ``CFG.max_length`` tokens and
    padded to the longest item in the batch), run through *model* without
    gradient tracking, and the logits are soft-maxed over the last dimension.

    Parameters
    ----------
    text_list
        Raw documents to score.
    model
        Sequence-classification model whose output exposes ``.logits``.
        Assumed to be a ``torch.nn.Module`` (it is switched to eval mode and
        its parameters are used to find the target device).
    tokenizer
        Callable tokenizer returning a mapping of tensors when called with
        ``return_tensors="pt"``.
    batch_size
        Number of texts per forward pass (default 32, the previous fixed
        value).

    Returns
    -------
    np.ndarray
        Array with one row per input text.  For empty input an array with
        zero rows is returned instead of raising.

    Raises
    ------
    ValueError
        If *batch_size* is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    texts = list(text_list)
    if not texts:
        # np.vstack([]) would raise; return an empty (0, n_labels) array.
        n_labels = getattr(getattr(model, "config", None), "num_labels", 0)
        return np.empty((0, n_labels), dtype=np.float32)

    # Inference only: disable dropout so probabilities are deterministic.
    model.eval()

    # Inputs must live on the same device as the model weights.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    batch_probs = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            enc = tokenizer(
                texts[start : start + batch_size],
                truncation=True,
                max_length=CFG.max_length,
                padding=True,
                return_tensors="pt",
            )
            enc = {key: tensor.to(device) for key, tensor in enc.items()}
            logits = model(**enc).logits
            # .float() guards against half-precision softmax; .cpu() is
            # required before .numpy() when the model runs on an accelerator.
            probs = torch.softmax(logits.float(), dim=-1)
            batch_probs.append(probs.cpu().numpy())
    return np.vstack(batch_probs)
| # ── Main analysis ───────────────────────────────────────────────────────────── | |
def _print_error_report(df: pd.DataFrame, errors: pd.DataFrame, model_name: str) -> None:
    """Print the console summary: totals, per-class errors, confused pairs, hardest errors."""
    rule = "─" * 60
    total = len(df)
    n_errors = len(errors)
    n_correct = total - n_errors

    print("\n" + rule)
    print(f" ERROR ANALYSIS — {model_name.upper()}")
    print(rule)
    print(f" Total : {total:,}")
    print(f" Correct : {n_correct:,} ({n_correct / total * 100:.2f}%)")
    print(f" Errors : {n_errors:,} ({n_errors / total * 100:.2f}%)")

    print("\n Errors by true class:")
    # One pass over the error frame instead of one boolean filter per label.
    errors_per_class = errors["true_label"].value_counts()
    for label in CFG.label_names:
        print(f" {label:<12} {errors_per_class.get(label, 0):>4} errors")

    print("\n Top confused pairs (True → Predicted):")
    confused = (
        errors.groupby(["true_label", "pred_label"])
        .size()
        .sort_values(ascending=False)
        .head(6)
    )
    for (true_lbl, pred_lbl), count in confused.items():
        print(f" {true_lbl:<12} → {pred_lbl:<12} {count:>4} times")

    print("\n 5 Hardest Errors (lowest confidence):")
    for _, row in errors.nsmallest(5, "confidence").iterrows():
        text = row["text"]
        # Only mark the snippet as truncated when it actually was.
        snippet = text if len(text) <= 75 else text[:75] + "…"
        print(f" [{row['true_label']} → {row['pred_label']} conf={row['confidence']:.3f}]")
        print(f" {snippet}\n")


def analyse(model_name: str, save_dir: str = None) -> pd.DataFrame:
    """
    Full error analysis pipeline.

    Loads the test set, scores it with the requested model, prints a console
    report, draws the analysis figure and, when *save_dir* is given, writes
    the misclassified rows to ``errors_<model_name>.csv`` in that directory
    (dashes in the model name are replaced by underscores).

    Parameters
    ----------
    model_name
        ``"lr"`` or ``"svm"`` select the scikit-learn pipeline loaded through
        ``traditional_model``; any other value is passed to
        ``transformer_model.load_model``.
    save_dir
        Directory for the CSV and the plot.  ``None`` skips writing files.

    Returns
    -------
    DataFrame of all misclassified examples.

    Raises
    ------
    ValueError
        If the loaded test set is empty.
    """
    logger.info("Loading test set …")
    X_test, y_test = load_test_only()
    if len(X_test) == 0:
        # Fail early with a clear message instead of a ZeroDivisionError or a
        # library-specific error further down.
        raise ValueError("Test set is empty; nothing to analyse.")

    logger.info("Running predictions with: %s", model_name)
    if model_name in ("lr", "svm"):
        pipeline = tm.load_model(model_name)
        proba = _proba_sklearn(X_test, pipeline)
    else:
        model, tokenizer = trm.load_model(model_name)
        proba = _proba_transformer(X_test, model, tokenizer)
    preds = proba.argmax(axis=1).tolist()

    acc = accuracy_score(y_test, preds)
    logger.info("Test accuracy: %.2f%%", acc * 100)

    # Build analysis DataFrame: one row per test example.
    df = pd.DataFrame({
        "text": list(X_test),
        "true_label": [CFG.label_names[y] for y in y_test],
        "pred_label": [CFG.label_names[p] for p in preds],
        "confidence": proba.max(axis=1),
        "correct": [int(y) == int(p) for y, p in zip(y_test, preds)],
    })
    for i, name in enumerate(CFG.label_names):
        df[f"prob_{name}"] = proba[:, i]

    # Copy so callers can modify the result without touching `df`.
    errors = df[~df["correct"]].copy()

    # ── Console report ──
    _print_error_report(df, errors, model_name)

    # ── Plots ──
    _plot_analysis(df, model_name, save_dir)

    # ── Save CSV ──
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        csv_path = os.path.join(save_dir, f"errors_{model_name.replace('-', '_')}.csv")
        errors.to_csv(csv_path, index=False)
        logger.info("Error CSV → %s", csv_path)

    return errors
def _plot_analysis(df: pd.DataFrame, model_name: str, save_dir: str = None) -> None:
    """Two-panel figure: confidence distribution + per-class accuracy bars.

    Parameters
    ----------
    df
        Analysis frame with at least the columns ``correct``, ``confidence``
        and ``true_label``.
    model_name
        Used in the figure title and in the output file name.
    save_dir
        When given, the figure is written to
        ``<save_dir>/analysis_<model_name>.png`` (dashes replaced by
        underscores).  ``None`` skips saving.
    """
    fig, (ax_conf, ax_acc) = plt.subplots(1, 2, figsize=(13, 5))
    try:
        fig.suptitle(f"Error Analysis — {model_name}", fontsize=14, fontweight="bold")

        # Panel 1: confidence histograms for correct vs. incorrect predictions.
        is_correct = df["correct"].astype(bool)
        groups = (
            ("Correct", "Mean correct", df.loc[is_correct, "confidence"], "#27ae60"),
            ("Incorrect", "Mean error", df.loc[~is_correct, "confidence"], "#e74c3c"),
        )
        for hist_label, _, conf, colour in groups:
            ax_conf.hist(conf, bins=30, alpha=0.75, color=colour,
                         label=f"{hist_label} (n={len(conf):,})")
        for _, mean_label, conf, colour in groups:
            if conf.empty:
                # The mean of an empty subset is NaN; skip the marker line.
                continue
            mean_conf = conf.mean()
            ax_conf.axvline(mean_conf, color=colour, linestyle="--", linewidth=1.2,
                            label=f"{mean_label}: {mean_conf:.3f}")
        ax_conf.set_xlabel("Prediction Confidence", fontsize=11)
        ax_conf.set_ylabel("Count", fontsize=11)
        ax_conf.set_title("Confidence Distribution", fontsize=12)
        # The legend must be built after every labelled artist exists,
        # otherwise the mean lines are left out of it.
        ax_conf.legend(fontsize=10)

        # Panel 2: per-class accuracy.
        colours = ["#3498db", "#27ae60", "#e67e22", "#9b59b6"]
        class_accs = [
            df.loc[df["true_label"] == lbl, "correct"].astype(float).mean() * 100
            for lbl in CFG.label_names
        ]
        bars = ax_acc.bar(CFG.label_names, class_accs, color=colours,
                          edgecolor="white", linewidth=1.5)

        # Keep the zoomed 80-100 view when possible, but widen it when a class
        # scores lower so its bar stays visible; leave headroom above 100 for
        # the value labels.
        finite_accs = [a for a in class_accs if not np.isnan(a)]
        lower = 80.0
        if finite_accs and min(finite_accs) < 85.0:
            lower = max(0.0, float(np.floor(min(finite_accs))) - 5.0)
        ax_acc.set_ylim(lower, 103)
        ax_acc.set_xlabel("Class", fontsize=11)
        ax_acc.set_ylabel("Accuracy (%)", fontsize=11)
        ax_acc.set_title("Per-Class Accuracy", fontsize=12)
        for bar, acc in zip(bars, class_accs):
            if np.isnan(acc):
                # Class has no examples in df; there is no value to annotate.
                continue
            ax_acc.text(bar.get_x() + bar.get_width() / 2,
                        bar.get_height() + 0.3,
                        f"{acc:.1f}%", ha="center", va="bottom",
                        fontsize=11, fontweight="bold")

        fig.tight_layout()
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
            path = os.path.join(save_dir, f"analysis_{model_name.replace('-', '_')}.png")
            fig.savefig(path, dpi=150)
            logger.info("Plot → %s", path)

        # Under the non-interactive Agg backend show() only emits a warning.
        if matplotlib.get_backend().lower() != "agg":
            plt.show()
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)
def main() -> None:
    """Command-line entry point.

    Reads ``--model`` (default ``roberta-base``) and runs :func:`analyse`,
    writing outputs to the ``error_analysis`` folder under ``CFG.outputs_dir``.
    """
    cli = argparse.ArgumentParser(description="Document classifier error analysis")
    cli.add_argument(
        "--model",
        default="roberta-base",
        help="Model name: 'lr', 'svm', or transformer checkpoint (e.g. 'roberta-base')",
    )
    options = cli.parse_args()

    output_dir = os.path.join(CFG.outputs_dir, "error_analysis")
    analyse(options.model, save_dir=output_dir)


if __name__ == "__main__":
    main()