"""
Comprehensive evaluation of all three models on the held-out test set.
Produces:
- Per-model metrics (accuracy, precision, recall, F1, ROC AUC)
- Confusion matrices (saved as PNGs)
- ROC curves overlay
- Feature importance plot (XGBoost)
- Error analysis: 5 specific mispredictions per model with root-cause explanation
- Results summary CSV
"""
import argparse
import os
import warnings

import matplotlib
matplotlib.use("Agg") # non-interactive backend for servers
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score,
    precision_recall_curve,
    precision_score,
    recall_score,
    roc_auc_score,
    roc_curve,
)

from build_features import get_feature_columns
from model import ClassicalMLModel, NaiveBaseline, TraceTransformer

warnings.filterwarnings("ignore")
sns.set_style("whitegrid")


def compute_metrics(y_true, y_pred, y_proba=None) -> dict:
    """Compute the full metric suite for one model."""
    metrics = {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, zero_division=0),
        "recall": recall_score(y_true, y_pred, zero_division=0),
        "f1": f1_score(y_true, y_pred, zero_division=0),
        "f1_macro": f1_score(y_true, y_pred, average="macro", zero_division=0),
    }
    # y_proba is expected in sklearn predict_proba shape (n_samples, 2), with
    # column 1 holding P(anomalous). ROC AUC is undefined when only one class
    # is present in y_true, hence the guard.
    if y_proba is not None and len(np.unique(y_true)) > 1:
        try:
            metrics["roc_auc"] = roc_auc_score(y_true, y_proba[:, 1])
        except Exception:
            metrics["roc_auc"] = None
    return metrics
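
# Example with illustrative toy values:
#   compute_metrics(np.array([0, 1, 1, 0]),
#                   np.array([0, 1, 0, 0]),
#                   y_proba=np.array([[0.9, 0.1], [0.2, 0.8],
#                                     [0.6, 0.4], [0.7, 0.3]]))
#   -> accuracy 0.75, precision 1.0, recall 0.5, f1 ~0.667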


def plot_confusion_matrix(y_true, y_pred, model_name, output_dir):
    """Save a confusion matrix heatmap."""
    cm = confusion_matrix(y_true, y_pred)
    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(
        cm, annot=True, fmt="d", cmap="Blues",
        xticklabels=["Normal", "Anomalous"],
        yticklabels=["Normal", "Anomalous"],
        ax=ax,
    )
    ax.set_xlabel("Predicted")
    ax.set_ylabel("Actual")
    ax.set_title(f"Confusion Matrix — {model_name}")
    plt.tight_layout()
    path = os.path.join(output_dir, f"cm_{model_name.lower().replace(' ', '_')}.png")
    fig.savefig(path, dpi=150)
    plt.close(fig)
    print(f"  [SAVED] {path}")


def plot_roc_curves(results: dict, output_dir: str):
    """Overlay ROC curves for all models that have probability outputs."""
    fig, ax = plt.subplots(figsize=(8, 6))
    for name, res in results.items():
        if res.get("y_proba") is not None and res.get("roc_auc") is not None:
            fpr, tpr, _ = roc_curve(res["y_true"], res["y_proba"][:, 1])
            ax.plot(fpr, tpr, label=f'{name} (AUC={res["roc_auc"]:.3f})')
    ax.plot([0, 1], [0, 1], "k--", alpha=0.3, label="Random")
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title("ROC Curves — Model Comparison")
    ax.legend(loc="lower right")
    plt.tight_layout()
    path = os.path.join(output_dir, "roc_curves.png")
    fig.savefig(path, dpi=150)
    plt.close(fig)
    print(f"[SAVED] {path}")


def plot_precision_recall_curves(results: dict, output_dir: str):
    """Overlay precision-recall curves."""
    fig, ax = plt.subplots(figsize=(8, 6))
    for name, res in results.items():
        if res.get("y_proba") is not None:
            prec_vals, rec_vals, _ = precision_recall_curve(
                res["y_true"], res["y_proba"][:, 1]
            )
            ax.plot(rec_vals, prec_vals, label=name)
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_title("Precision-Recall Curves — Model Comparison")
    ax.legend(loc="upper right")
    plt.tight_layout()
    path = os.path.join(output_dir, "pr_curves.png")
    fig.savefig(path, dpi=150)
    plt.close(fig)
    print(f"[SAVED] {path}")


def plot_feature_importance(model: ClassicalMLModel, output_dir: str, top_n: int = 20):
    """Bar chart of top feature importances."""
    # Assumes get_feature_importance() returns a pandas Series sorted by
    # importance in descending order, so head(top_n) keeps the top features.
    importance = model.get_feature_importance().head(top_n)
    fig, ax = plt.subplots(figsize=(10, 7))
    importance.plot(kind="barh", ax=ax, color="steelblue")
    ax.set_xlabel("Importance (Gain)")
    ax.set_title(f"XGBoost — Top {top_n} Feature Importances")
    ax.invert_yaxis()  # most important feature at the top
    plt.tight_layout()
    path = os.path.join(output_dir, "feature_importance.png")
    fig.savefig(path, dpi=150)
    plt.close(fig)
    print(f"[SAVED] {path}")


def error_analysis(y_true, y_pred, df, model_name, n_errors=5) -> list[dict]:
    """
    Identify mispredictions and attach heuristic root-cause notes.
    Returns a list of dicts with error details.
    """
    errors = []
    wrong_mask = y_true != y_pred
    wrong_indices = np.where(wrong_mask)[0]
    if len(wrong_indices) == 0:
        print(f"  [{model_name}] No errors found!")
        return errors
    # Sample up to n_errors, drawing up to 3 false positives and up to 3
    # false negatives (the [:n_errors] truncation favors FPs when both
    # types are plentiful).
    fp_idx = wrong_indices[(y_true[wrong_indices] == 0) & (y_pred[wrong_indices] == 1)]
    fn_idx = wrong_indices[(y_true[wrong_indices] == 1) & (y_pred[wrong_indices] == 0)]
    sample_fp = fp_idx[:min(3, len(fp_idx))]
    sample_fn = fn_idx[:min(3, len(fn_idx))]
    sampled = np.concatenate([sample_fp, sample_fn])[:n_errors]
    for idx in sampled:
        row = df.iloc[idx]
        error_type = "False Positive" if y_true[idx] == 0 else "False Negative"
        # The explanations below are templated heuristics inferred from
        # trace-level features, not causal attributions from the model.
        query_preview = str(row.get("user_query", ""))[:150]
        n_tools = row.get("num_tool_calls", 0)
        n_turns = row.get("num_turns", 0)
        if error_type == "False Positive":
            cause = (
                f"Model predicted anomalous but trace was normal. "
                f"Trace had {n_tools} tool calls across {n_turns} turns. "
                f"Possible cause: hedging or cautious language in the final "
                f"response triggered a false failure signal."
            )
            mitigation = (
                "Add features that distinguish cautious-but-successful responses "
                "from actual failures. Consider sentiment analysis on the final turn."
            )
        else:
            cause = (
                f"Model predicted normal but trace was actually anomalous. "
                f"Trace had {n_tools} tool calls across {n_turns} turns. "
                f"Possible cause: the agent completed tool calls but the final "
                f"answer was incorrect/incomplete without obvious failure language."
            )
            mitigation = (
                "Incorporate semantic similarity between query intent and final response. "
                "Add features measuring whether tool outputs were actually used in the answer."
            )
        errors.append({
            "index": int(idx),
            "model": model_name,
            "error_type": error_type,
            "true_label": int(y_true[idx]),
            "pred_label": int(y_pred[idx]),
            "query_preview": query_preview,
            "root_cause": cause,
            "mitigation": mitigation,
        })
    return errors
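
# One returned record looks like this (values are illustrative only):
#   {"index": 12, "model": "XGBoost", "error_type": "False Negative",
#    "true_label": 1, "pred_label": 0, "query_preview": "...",
#    "root_cause": "...", "mitigation": "..."}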


def evaluate_all(data_dir: str, model_dir: str, output_dir: str):
    """Run the full evaluation pipeline."""
    os.makedirs(output_dir, exist_ok=True)

    # load test data
    test_raw = pd.read_parquet(os.path.join(data_dir, "test.parquet"))
    test_feat_path = os.path.join(data_dir, "test_features.parquet")
    test_feat = pd.read_parquet(test_feat_path) if os.path.exists(test_feat_path) else None
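    # Sanity check (assumes build_features wrote test_features.parquet
    # row-for-row from test.parquet, so labels and features stay aligned):
    if test_feat is not None and len(test_feat) != len(test_raw):
        raise ValueError(
            f"Row count mismatch: test_features.parquet has {len(test_feat)} "
            f"rows but test.parquet has {len(test_raw)}"
        )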

    y_true = test_raw["label"].values
    results = {}
    all_errors = []

    # Naive Baseline
    naive_path = os.path.join(model_dir, "naive_baseline.joblib")
    if os.path.exists(naive_path):
        print("\nEvaluating: Naive Baseline")
        naive = NaiveBaseline.load(naive_path)
        # NaiveBaseline is assumed to ignore its input values; y_true is
        # passed only so the output has the right number of predictions.
        y_pred = naive.predict(y_true)
        y_proba = naive.predict_proba(y_true)
        metrics = compute_metrics(y_true, y_pred, y_proba)
        results["Naive Baseline"] = {**metrics, "y_true": y_true, "y_pred": y_pred, "y_proba": y_proba}
        plot_confusion_matrix(y_true, y_pred, "Naive Baseline", output_dir)
        all_errors += error_analysis(y_true, y_pred, test_raw, "Naive Baseline")
        print(classification_report(y_true, y_pred, target_names=["Normal", "Anomalous"]))

    # XGBoost
    xgb_path = os.path.join(model_dir, "xgboost_model.joblib")
    if os.path.exists(xgb_path) and test_feat is not None:
        print("\nEvaluating: XGBoost")
        xgb_model = ClassicalMLModel.load(xgb_path)
        feat_cols = get_feature_columns(test_feat)
        X_test = test_feat[feat_cols]
        y_pred = xgb_model.predict(X_test)
        y_proba = xgb_model.predict_proba(X_test)
        metrics = compute_metrics(y_true, y_pred, y_proba)
        results["XGBoost"] = {**metrics, "y_true": y_true, "y_pred": y_pred, "y_proba": y_proba}
        plot_confusion_matrix(y_true, y_pred, "XGBoost", output_dir)
        plot_feature_importance(xgb_model, output_dir)
        all_errors += error_analysis(y_true, y_pred, test_raw, "XGBoost")
        print(classification_report(y_true, y_pred, target_names=["Normal", "Anomalous"]))

    # DistilBERT
    dl_path = os.path.join(model_dir, "distilbert_trace")
    if os.path.exists(dl_path):
        print("\nEvaluating: DistilBERT")
        dl_model = TraceTransformer.load(dl_path)
        X_test = test_raw["raw_trace"].tolist()
        y_pred = dl_model.predict(X_test)
        y_proba = dl_model.predict_proba(X_test)
        metrics = compute_metrics(y_true, y_pred, y_proba)
        results["DistilBERT"] = {**metrics, "y_true": y_true, "y_pred": y_pred, "y_proba": y_proba}
        plot_confusion_matrix(y_true, y_pred, "DistilBERT", output_dir)
        all_errors += error_analysis(y_true, y_pred, test_raw, "DistilBERT")
        print(classification_report(y_true, y_pred, target_names=["Normal", "Anomalous"]))

    # Comparison plots
    if results:
        plot_roc_curves(results, output_dir)
        plot_precision_recall_curves(results, output_dir)

    # Summary table
    summary_rows = []
    for name, res in results.items():
        row = {k: v for k, v in res.items() if k not in ("y_true", "y_pred", "y_proba")}
        row["model"] = name
        summary_rows.append(row)
    if summary_rows:
        summary_df = pd.DataFrame(summary_rows).set_index("model")
        print("\n" + "═" * 60)
        print("  FINAL TEST SET COMPARISON")
        print("═" * 60)
        print(summary_df.to_string())
        summary_path = os.path.join(output_dir, "results_summary.csv")
        summary_df.to_csv(summary_path)
        print(f"\n[SAVED] {summary_path}")

    # Error analysis details
    if all_errors:
        errors_df = pd.DataFrame(all_errors)
        errors_path = os.path.join(output_dir, "error_analysis.csv")
        errors_df.to_csv(errors_path, index=False)
        print(f"[SAVED] {errors_path}")
        print("\n" + "═" * 60)
        print("  ERROR ANALYSIS (up to 5 mispredictions per model)")
        print("═" * 60)
        for err in all_errors:
            print(f"\n  [{err['model']}] {err['error_type']}")
            print(f"  True: {err['true_label']}, Predicted: {err['pred_label']}")
            print(f"  Query: {err['query_preview']}")
            print(f"  Root cause: {err['root_cause']}")
            print(f"  Mitigation: {err['mitigation']}")

    # Bar chart comparison
    if summary_rows:
        plot_metric_comparison(summary_rows, output_dir)
    print("\n[DONE] Evaluation complete.")


def plot_metric_comparison(summary_rows: list[dict], output_dir: str):
    """Bar chart comparing key metrics across models."""
    df = pd.DataFrame(summary_rows).set_index("model")
    plot_cols = ["accuracy", "precision", "recall", "f1", "f1_macro"]
    plot_cols = [c for c in plot_cols if c in df.columns]
    fig, ax = plt.subplots(figsize=(10, 6))
    df[plot_cols].plot(kind="bar", ax=ax, colormap="Set2")
    ax.set_ylabel("Score")
    ax.set_title("Model Comparison — Test Set Metrics")
    ax.set_ylim(0, 1.05)
    ax.legend(loc="lower right")
    plt.xticks(rotation=0)
    plt.tight_layout()
    path = os.path.join(output_dir, "model_comparison.png")
    fig.savefig(path, dpi=150)
    plt.close(fig)
    print(f"[SAVED] {path}")


def main():
    parser = argparse.ArgumentParser(description="Evaluate all models on the test set")
    parser.add_argument("--data_dir", type=str, default="data/processed")
    parser.add_argument("--model_dir", type=str, default="models")
    parser.add_argument("--output_dir", type=str, default="data/outputs")
    args = parser.parse_args()
    evaluate_all(args.data_dir, args.model_dir, args.output_dir)


if __name__ == "__main__":
    main()