# nexa-classify-api / error_analysis.py
# Uploaded to Hugging Face by Prototype6239 via huggingface_hub (commit a229747, verified, 8.38 kB).
"""
error_analysis.py
─────────────────
Detailed analysis of model errors on the test set.
Generates confidence distributions, per-class accuracy bars,
and a CSV of the hardest misclassified examples.
Usage
─────
python error_analysis.py --model roberta-base
python error_analysis.py --model lr
python error_analysis.py --model svm
"""
import argparse
import logging
import os
from typing import List, Tuple
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from sklearn.metrics import accuracy_score
from config import CFG
from data_loader import load_test_only
import traditional_model as tm
import transformer_model as trm
# Root logging configuration applied at import time: INFO level, each line
# prefixed with a wall-clock time (HH:MM:SS) and a padded level name.
logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
# Module-scoped logger used by every function in this script.
logger = logging.getLogger(__name__)
# ── Probability extraction ─────────────────────────────────────────────────────
def _proba_sklearn(text_list: List[str], pipeline) -> np.ndarray:
    """Return an (n_samples, n_classes) probability matrix from a fitted sklearn pipeline.

    If the pipeline's final step exposes ``predict_proba`` it is used directly.
    Otherwise (e.g. ``LinearSVC``) the ``decision_function`` scores are turned
    into pseudo-probabilities with a numerically stable softmax. These are not
    calibrated probabilities; they only preserve the ordering of the scores.

    Parameters
    ----------
    text_list : texts to score.
    pipeline : fitted object exposing ``named_steps`` plus ``predict_proba``
        or ``decision_function``.

    Returns
    -------
    np.ndarray of shape (n_samples, n_classes) whose rows sum to 1.
    """
    clf = list(pipeline.named_steps.values())[-1]
    if hasattr(clf, "predict_proba"):
        return pipeline.predict_proba(text_list)

    # Float copy: never mutate the array handed back by the estimator.
    scores = np.array(pipeline.decision_function(text_list), dtype=float)
    if scores.ndim == 1:
        # Binary classifiers return one margin per sample (positive class).
        # Expand to [-s, s] so the softmax yields [P(neg), P(pos)], the same
        # convention sklearn uses for binary linear models.
        scores = np.column_stack([-scores, scores])

    # Subtract the row max before exponentiating to avoid overflow.
    scores -= scores.max(axis=1, keepdims=True)
    exp = np.exp(scores)
    return exp / exp.sum(axis=1, keepdims=True)
def _proba_transformer(
    text_list: List[str],
    model,
    tokenizer,
    batch_size: int = 32,
    max_length: int = None,
) -> np.ndarray:
    """Return an (n_samples, n_classes) softmax probability matrix from a torch model.

    Texts are tokenized and scored in batches under ``torch.no_grad()``.
    The model is switched to eval mode (disabling dropout) for deterministic
    outputs. Its previous train/eval mode is restored afterwards.

    Parameters
    ----------
    text_list : texts to score.
    model : torch module whose forward output has a ``.logits`` tensor.
    tokenizer : callable returning a mapping of tensors (``return_tensors="pt"``).
    batch_size : number of texts per forward pass (default 32).
    max_length : truncation length; defaults to ``CFG.max_length`` when None.
    """
    if max_length is None:
        max_length = CFG.max_length

    # Inputs must live on the same device as the model's weights.
    device = getattr(model, "device", None)
    if not isinstance(device, torch.device):
        device = next((p.device for p in model.parameters()), torch.device("cpu"))

    was_training = model.training
    model.eval()  # disable dropout so repeated runs give identical probabilities
    all_probs = []
    try:
        with torch.no_grad():
            for start in range(0, len(text_list), batch_size):
                batch = list(text_list[start : start + batch_size])
                enc = tokenizer(batch, truncation=True, max_length=max_length,
                                padding=True, return_tensors="pt")
                enc = {key: value.to(device) for key, value in enc.items()}
                logits = model(**enc).logits
                # .cpu() is required before .numpy() when the model is on a GPU.
                all_probs.append(torch.softmax(logits, dim=-1).cpu().numpy())
    finally:
        model.train(was_training)  # leave the caller's model in its original mode
    return np.vstack(all_probs)
# ── Main analysis ─────────────────────────────────────────────────────────────
def _predict_proba(model_name: str, X_test) -> np.ndarray:
    """Load *model_name* and return its (n_samples, n_classes) probabilities for *X_test*.

    ``'lr'`` and ``'svm'`` are loaded through ``traditional_model``; any other
    name is treated as a transformer checkpoint and loaded via ``transformer_model``.
    """
    if model_name in ("lr", "svm"):
        pipeline = tm.load_model(model_name)
        return _proba_sklearn(X_test, pipeline)
    model, tokenizer = trm.load_model(model_name)
    return _proba_transformer(X_test, model, tokenizer)


def _print_report(df: pd.DataFrame, errors: pd.DataFrame, corrects: pd.DataFrame,
                  model_name: str) -> None:
    """Print totals, errors per true class, top confused pairs and the 5 lowest-confidence errors."""
    total = len(df)
    denom = total or 1  # an empty test set would otherwise raise ZeroDivisionError
    print("\n" + "═" * 60)
    print(f" ERROR ANALYSIS — {model_name.upper()}")
    print("═" * 60)
    print(f" Total : {total:,}")
    print(f" Correct : {len(corrects):,} ({len(corrects) / denom * 100:.2f}%)")
    print(f" Errors : {len(errors):,} ({len(errors) / denom * 100:.2f}%)")

    print("\n Errors by true class:")
    per_class = errors["true_label"].value_counts()
    for label in CFG.label_names:
        print(f" {label:<12} {int(per_class.get(label, 0)):>4} errors")

    print("\n Top confused pairs (True → Predicted):")
    confused = (
        errors.groupby(["true_label", "pred_label"])
        .size()
        .sort_values(ascending=False)
        .head(6)
    )
    for (true, pred), count in confused.items():
        print(f" {true:<12} → {pred:<12} {count:>4} times")

    print("\n 5 Hardest Errors (lowest confidence):")
    for _, row in errors.nsmallest(5, "confidence").iterrows():
        text = str(row["text"])
        # Only append an ellipsis when the text was actually truncated.
        snippet = (text[:75] + "…") if len(text) > 75 else text
        print(f" [{row['true_label']} → {row['pred_label']} conf={row['confidence']:.3f}]")
        print(f" {snippet}\n")


def analyse(model_name: str, save_dir: str = None) -> pd.DataFrame:
    """
    Full error analysis pipeline.

    Loads the test set, predicts with *model_name*, and logs test accuracy.
    Prints a console report, draws the analysis figure and, when *save_dir* is
    given, writes the misclassified rows to ``errors_<model>.csv`` there.

    Parameters
    ----------
    model_name : 'lr', 'svm', or a transformer checkpoint name.
    save_dir : output directory for the plot and CSV; nothing is saved when None/empty.

    Returns
    -------
    DataFrame of all misclassified examples.
    """
    logger.info("Loading test set …")
    X_test, y_test = load_test_only()
    logger.info(f"Running predictions with: {model_name}")
    proba = _predict_proba(model_name, X_test)
    preds = proba.argmax(axis=1).tolist()
    acc = accuracy_score(y_test, preds)
    logger.info(f"Test accuracy: {acc * 100:.2f}%")

    # One row per test example, plus one probability column per class.
    df = pd.DataFrame({
        "text": X_test,
        "true_label": [CFG.label_names[y] for y in y_test],
        "pred_label": [CFG.label_names[p] for p in preds],
        "confidence": proba.max(axis=1),
        "correct": [int(y) == int(p) for y, p in zip(y_test, preds)],
    })
    for i, name in enumerate(CFG.label_names):
        df[f"prob_{name}"] = proba[:, i]

    is_correct = df["correct"].astype(bool)
    errors = df[~is_correct]
    corrects = df[is_correct]

    _print_report(df, errors, corrects, model_name)
    _plot_analysis(df, model_name, save_dir)

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        # Checkpoint names such as "org/model" must not create sub-directories.
        safe_name = model_name.replace("-", "_").replace("/", "_").replace("\\", "_")
        csv_path = os.path.join(save_dir, f"errors_{safe_name}.csv")
        errors.to_csv(csv_path, index=False)
        logger.info(f"Error CSV → {csv_path}")
    return errors
def _plot_analysis(df: pd.DataFrame, model_name: str, save_dir: str = None) -> None:
    """Two-panel figure: confidence distribution + per-class accuracy bars.

    Left panel: histograms of prediction confidence for correct vs. incorrect
    rows, with a dashed line at each group's mean (omitted for empty groups).
    Right panel: accuracy (%) per true class, in ``CFG.label_names`` order.
    When *save_dir* is given the figure is saved there as
    ``analysis_<model>.png``. The figure is always closed afterwards.
    """
    fig, (ax_conf, ax_acc) = plt.subplots(1, 2, figsize=(13, 5))
    fig.suptitle(f"Error Analysis — {model_name}", fontsize=14, fontweight="bold")

    # Panel 1: Confidence histograms
    is_correct = df["correct"].astype(bool)
    correct_conf = df.loc[is_correct, "confidence"]
    error_conf = df.loc[~is_correct, "confidence"]
    ax_conf.hist(correct_conf, bins=30, alpha=0.75, color="#27ae60",
                 label=f"Correct (n={len(correct_conf):,})")
    ax_conf.hist(error_conf, bins=30, alpha=0.75, color="#e74c3c",
                 label=f"Incorrect (n={len(error_conf):,})")
    # Mean of an empty group is NaN; skip its line rather than plot NaN.
    if len(correct_conf):
        ax_conf.axvline(correct_conf.mean(), color="#27ae60", linestyle="--", linewidth=1.2,
                        label=f"Mean correct: {correct_conf.mean():.3f}")
    if len(error_conf):
        ax_conf.axvline(error_conf.mean(), color="#e74c3c", linestyle="--", linewidth=1.2,
                        label=f"Mean error: {error_conf.mean():.3f}")
    ax_conf.set_xlabel("Prediction Confidence", fontsize=11)
    ax_conf.set_ylabel("Count", fontsize=11)
    ax_conf.set_title("Confidence Distribution", fontsize=12)
    # Build the legend after the mean lines so their labels are included.
    ax_conf.legend(fontsize=10)

    # Panel 2: Per-class accuracy
    colours = ["#3498db", "#27ae60", "#e67e22", "#9b59b6"]
    class_accs = [
        df.loc[df["true_label"] == lbl, "correct"].astype(float).mean() * 100
        for lbl in CFG.label_names
    ]
    bars = ax_acc.bar(CFG.label_names, class_accs, color=colours,
                      edgecolor="white", linewidth=1.5)
    # Keep the usual 80-100 % window, but widen it downward so weaker classes
    # are not clipped. Add headroom above 100 for the value labels.
    finite_accs = [a for a in class_accs if np.isfinite(a)]
    lower = min(80.0, max(0.0, min(finite_accs) - 5.0)) if finite_accs else 80.0
    ax_acc.set_ylim(lower, 103)
    ax_acc.set_xlabel("Class", fontsize=11)
    ax_acc.set_ylabel("Accuracy (%)", fontsize=11)
    ax_acc.set_title("Per-Class Accuracy", fontsize=12)
    for bar, acc in zip(bars, class_accs):
        if not np.isfinite(acc):  # class absent from the test set
            continue
        ax_acc.text(bar.get_x() + bar.get_width() / 2,
                    bar.get_height() + 0.3,
                    f"{acc:.1f}%", ha="center", va="bottom", fontsize=11, fontweight="bold")

    fig.tight_layout()
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        # Checkpoint names such as "org/model" must not create sub-directories.
        safe_name = model_name.replace("-", "_").replace("/", "_").replace("\\", "_")
        path = os.path.join(save_dir, f"analysis_{safe_name}.png")
        fig.savefig(path, dpi=150)
        logger.info(f"Plot → {path}")
    # No plt.show(): the module forces the non-interactive Agg backend.
    plt.close(fig)
def main() -> None:
    """Command-line entry point: pick a model via ``--model`` and run ``analyse``.

    Outputs go to the ``error_analysis`` sub-directory of ``CFG.outputs_dir``.
    """
    cli = argparse.ArgumentParser(description="Document classifier error analysis")
    cli.add_argument(
        "--model",
        default="roberta-base",
        help="Model name: 'lr', 'svm', or transformer checkpoint (e.g. 'roberta-base')",
    )
    options = cli.parse_args()
    output_dir = os.path.join(CFG.outputs_dir, "error_analysis")
    analyse(options.model, save_dir=output_dir)


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()