Spaces:
Running
Running
File size: 8,382 Bytes
"""
error_analysis.py
─────────────────
Detailed analysis of model errors on the test set.
Generates confidence distributions, per-class accuracy bars,
and a CSV of all misclassified examples (with per-class probabilities).
Outputs are written to <CFG.outputs_dir>/error_analysis/.

Usage
─────
python error_analysis.py --model roberta-base
python error_analysis.py --model lr
python error_analysis.py --model svm
"""
import argparse
import logging
import os
from typing import List, Optional, Tuple

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from sklearn.metrics import accuracy_score

from config import CFG
from data_loader import load_test_only
import traditional_model as tm
import transformer_model as trm
# Root logging: INFO and above, each line stamped with wall-clock time and a
# left-aligned, fixed-width level name.
logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# ── Probability extraction ──────────────────────────────────────────────────
def _proba_sklearn(text_list: List[str], pipeline) -> np.ndarray:
clf = list(pipeline.named_steps.values())[-1]
if hasattr(clf, "predict_proba"):
return pipeline.predict_proba(text_list)
# LinearSVC: convert decision scores to pseudo-probabilities via softmax
scores = pipeline.decision_function(text_list)
scores -= scores.max(axis=1, keepdims=True)
exp = np.exp(scores)
return exp / exp.sum(axis=1, keepdims=True)
def _proba_transformer(text_list: List[str], model, tokenizer) -> np.ndarray:
all_probs = []
batch_size = 32
for i in range(0, len(text_list), batch_size):
batch = text_list[i : i + batch_size]
enc = tokenizer(batch, truncation=True, max_length=CFG.max_length,
padding=True, return_tensors="pt")
with torch.no_grad():
logits = model(**enc).logits
all_probs.append(torch.softmax(logits, dim=-1).numpy())
return np.vstack(all_probs)
# ── Main analysis ───────────────────────────────────────────────────────────
def _predict_proba(model_name: str, texts: List[str]) -> np.ndarray:
    """Load ``model_name`` and return its (n_samples, n_classes) probability matrix.

    ``"lr"`` and ``"svm"`` are loaded through ``traditional_model``; any other
    name is treated as a transformer checkpoint and loaded through
    ``transformer_model``.
    """
    if model_name in ("lr", "svm"):
        return _proba_sklearn(texts, tm.load_model(model_name))
    model, tokenizer = trm.load_model(model_name)
    return _proba_transformer(texts, model, tokenizer)


def _print_report(df: pd.DataFrame, errors: pd.DataFrame, corrects: pd.DataFrame,
                  model_name: str) -> None:
    """Print the console summary: totals, per-class errors, top confusions, hardest errors."""
    total = len(df)

    def pct(count: int) -> float:
        # Guard against an empty test set instead of dividing by zero.
        return count / total * 100 if total else 0.0

    rule = "─" * 60
    print("\n" + rule)
    print(f" ERROR ANALYSIS — {model_name.upper()}")
    print(rule)
    print(f" Total : {total:,}")
    print(f" Correct : {len(corrects):,} ({pct(len(corrects)):.2f}%)")
    print(f" Errors : {len(errors):,} ({pct(len(errors)):.2f}%)")

    print("\n Errors by true class:")
    for label in CFG.label_names:
        n = int((errors["true_label"] == label).sum())
        print(f" {label:<12} {n:>4} errors")

    print("\n Top confused pairs (True → Predicted):")
    confused = (
        errors.groupby(["true_label", "pred_label"])
        .size()
        .sort_values(ascending=False)
        .head(6)
    )
    for (true, pred), count in confused.items():
        print(f" {true:<12} → {pred:<12} {count:>4} times")

    print("\n 5 Hardest Errors (lowest confidence):")
    for _, row in errors.nsmallest(5, "confidence").iterrows():
        text = str(row["text"])
        # Only mark the snippet as truncated when the text was actually cut.
        snippet = text if len(text) <= 75 else text[:75] + "…"
        print(f" [{row['true_label']} → {row['pred_label']} conf={row['confidence']:.3f}]")
        print(f" {snippet}\n")


def analyse(model_name: str, save_dir: Optional[str] = None) -> pd.DataFrame:
    """
    Full error analysis pipeline.

    The pipeline:

    1. Loads the test set and predicts with ``model_name``.
    2. Logs the accuracy and prints a console report.
    3. Draws the analysis figure.
    4. When ``save_dir`` is given, writes every misclassified example to
       ``errors_<model>.csv`` in that directory.

    Parameters
    ----------
    model_name : ``"lr"``, ``"svm"``, or a transformer checkpoint name.
    save_dir : output directory for the plot and CSV; nothing is written
        when it is ``None`` or empty.

    Returns
    -------
    DataFrame of all misclassified examples.
    Its columns are ``text``, ``true_label``, ``pred_label``, ``confidence``,
    ``correct`` and one ``prob_<label>`` column per class.
    """
    logger.info("Loading test set …")
    X_test, y_test = load_test_only()

    logger.info(f"Running predictions with: {model_name}")
    proba = _predict_proba(model_name, X_test)
    preds = proba.argmax(axis=1).tolist()

    acc = accuracy_score(y_test, preds)
    logger.info(f"Test accuracy: {acc * 100:.2f}%")

    # One row per test example, plus a probability column per class.
    df = pd.DataFrame({
        "text": X_test,
        "true_label": [CFG.label_names[y] for y in y_test],
        "pred_label": [CFG.label_names[p] for p in preds],
        "confidence": proba.max(axis=1),
        "correct": [int(y) == int(p) for y, p in zip(y_test, preds)],
    })
    for i, name in enumerate(CFG.label_names):
        df[f"prob_{name}"] = proba[:, i]

    is_correct = df["correct"].astype(bool)
    errors = df[~is_correct]
    corrects = df[is_correct]

    _print_report(df, errors, corrects, model_name)
    _plot_analysis(df, model_name, save_dir)

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        # Characters other than letters, digits, '_' and '.' become '_'
        # (matching the old '-' -> '_' rule) so that hub ids such as
        # "org/model" do not point into a non-existent sub-directory.
        stem = "".join(ch if ch.isalnum() or ch in "._" else "_" for ch in model_name)
        csv_path = os.path.join(save_dir, f"errors_{stem}.csv")
        errors.to_csv(csv_path, index=False)
        logger.info(f"Error CSV → {csv_path}")
    return errors
def _plot_analysis(df: pd.DataFrame, model_name: str, save_dir: Optional[str] = None) -> None:
    """Two-panel figure: confidence distribution + per-class accuracy bars.

    The left panel overlays histograms of prediction confidence for correct
    and incorrect rows. It adds dashed lines at each group's mean, and these
    appear in the legend.

    The right panel shows accuracy (%) per true class in ``CFG.label_names``.
    Its y-axis keeps the 80–100 % zoom unless a class falls below 80 %, in
    which case the axis is widened.

    When ``save_dir`` is given, the figure is saved as
    ``<save_dir>/analysis_<model>.png``. The figure is always closed
    afterwards, even if saving fails.

    The module forces the non-interactive Agg backend, so nothing is shown
    on screen.
    """
    fig, axes = plt.subplots(1, 2, figsize=(13, 5))
    try:
        fig.suptitle(f"Error Analysis — {model_name}", fontsize=14, fontweight="bold")

        # Panel 1: confidence histograms for correct vs. incorrect predictions.
        is_correct = df["correct"].astype(bool)
        correct_conf = df.loc[is_correct, "confidence"]
        error_conf = df.loc[~is_correct, "confidence"]
        ax = axes[0]
        ax.hist(correct_conf, bins=30, alpha=0.75, color="#27ae60",
                label=f"Correct (n={len(correct_conf):,})")
        ax.hist(error_conf, bins=30, alpha=0.75, color="#e74c3c",
                label=f"Incorrect (n={len(error_conf):,})")
        # Skip the mean line for an empty group: its mean would be NaN.
        if len(correct_conf):
            ax.axvline(correct_conf.mean(), color="#27ae60", linestyle="--", linewidth=1.2,
                       label=f"Mean correct: {correct_conf.mean():.3f}")
        if len(error_conf):
            ax.axvline(error_conf.mean(), color="#e74c3c", linestyle="--", linewidth=1.2,
                       label=f"Mean error: {error_conf.mean():.3f}")
        ax.set_xlabel("Prediction Confidence", fontsize=11)
        ax.set_ylabel("Count", fontsize=11)
        ax.set_title("Confidence Distribution", fontsize=12)
        # Build the legend last so it picks up the mean-line labels too.
        ax.legend(fontsize=10)

        # Panel 2: accuracy per true class.
        ax = axes[1]
        palette = ["#3498db", "#27ae60", "#e67e22", "#9b59b6"]
        labels = list(CFG.label_names)
        class_accs = [
            df.loc[df["true_label"] == lbl, "correct"].astype(float).mean() * 100
            for lbl in labels
        ]
        # A class with no test rows has no accuracy (NaN): draw it at 0, label "n/a".
        heights = [0.0 if np.isnan(acc) else acc for acc in class_accs]
        bars = ax.bar(labels, heights,
                      color=[palette[i % len(palette)] for i in range(len(labels))],
                      edgecolor="white", linewidth=1.5)
        observed = [acc for acc in class_accs if not np.isnan(acc)]
        lower = 80.0
        if observed and min(observed) < lower:
            # Widen the axis (to a multiple of 10) so low bars are not clipped.
            lower = max(0.0, float(np.floor((min(observed) - 5) / 10) * 10))
        ax.set_ylim(lower, 100)
        ax.set_xlabel("Class", fontsize=11)
        ax.set_ylabel("Accuracy (%)", fontsize=11)
        ax.set_title("Per-Class Accuracy", fontsize=12)
        for bar, acc in zip(bars, class_accs):
            caption = "n/a" if np.isnan(acc) else f"{acc:.1f}%"
            ax.text(bar.get_x() + bar.get_width() / 2,
                    max(bar.get_height(), lower) + 0.3,
                    caption, ha="center", va="bottom", fontsize=11, fontweight="bold")

        fig.tight_layout()
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
            # Same sanitisation as the error CSV: hub ids like "org/model"
            # must not create a sub-path.
            stem = "".join(ch if ch.isalnum() or ch in "._" else "_" for ch in model_name)
            path = os.path.join(save_dir, f"analysis_{stem}.png")
            fig.savefig(path, dpi=150)
            logger.info(f"Plot → {path}")
    finally:
        plt.close(fig)
def main() -> None:
    """Command-line entry point: run the error analysis for the ``--model`` choice.

    The plot and error CSV are written under ``<CFG.outputs_dir>/error_analysis``.
    """
    cli = argparse.ArgumentParser(description="Document classifier error analysis")
    cli.add_argument(
        "--model",
        default="roberta-base",
        help="Model name: 'lr', 'svm', or transformer checkpoint (e.g. 'roberta-base')",
    )
    chosen_model = cli.parse_args().model
    out_dir = os.path.join(CFG.outputs_dir, "error_analysis")
    analyse(chosen_model, save_dir=out_dir)


if __name__ == "__main__":
    main()
|