customer-support-agent / src /evaluation /classifier_eval.py
pro580's picture
Fix rate limiter to use X-Forwarded-For header behind HF proxy
e323466
Raw
History Blame Contribute Delete
4.81 kB
"""Classification evaluation metrics — F1, confusion matrix, comparison table."""
import json
import time
from pathlib import Path
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from loguru import logger
from sklearn.metrics import (
classification_report,
confusion_matrix,
f1_score,
accuracy_score,
)
from src.data.dataset import INTENT_CATEGORIES
def _save_confusion_matrix(cm, tick_labels: List[str], title: str, out_path: Path) -> None:
    """Render ``cm`` as an annotated heatmap and save it to ``out_path`` at 150 dpi."""
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        sns.heatmap(
            cm,
            annot=True,
            fmt="d",
            cmap="Blues",
            xticklabels=tick_labels,
            yticklabels=tick_labels,
            ax=ax,
        )
        ax.set_title(title)
        ax.set_xlabel("Predicted")
        ax.set_ylabel("True")
        fig.tight_layout()
        fig.savefig(out_path, dpi=150)
    finally:
        # Always release the figure, even if plotting/saving fails, so repeated
        # evaluations do not accumulate open matplotlib figures.
        plt.close(fig)


def evaluate_classifier(
    predictions: List[str],
    ground_truth: List[str],
    label: str,
    results_dir: str,
) -> Dict:
    """Compute and save classification metrics.

    Writes ``{label}_classification_report.json`` and
    ``{label}_confusion_matrix.png`` into ``results_dir`` (created if missing).

    Args:
        predictions: List of predicted intent labels.
        ground_truth: List of true intent labels, aligned with ``predictions``.
        label: Short name for the model (e.g., 'baseline', 'distilbert').
        results_dir: Directory to save artifacts.

    Returns:
        Classification report as a dict (sklearn ``classification_report`` with
        ``output_dict=True``). An ``"accuracy"`` entry is always present.
    """
    out_dir = Path(results_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    labels_sorted = sorted(INTENT_CATEGORIES)

    # Labels outside INTENT_CATEGORIES are silently dropped from per-class rows
    # and the confusion matrix; surface them so the omission is visible.
    unknown = sorted(
        (set(predictions) | set(ground_truth)) - set(labels_sorted), key=str
    )
    if unknown:
        logger.warning(
            f"[{label}] Labels outside INTENT_CATEGORIES excluded from "
            f"per-class metrics and confusion matrix: {unknown}"
        )

    # zero_division=0 yields the same values as sklearn's default ("warn"),
    # without emitting UndefinedMetricWarning for never-predicted classes.
    report = classification_report(
        ground_truth,
        predictions,
        labels=labels_sorted,
        output_dict=True,
        zero_division=0,
    )
    # When the given labels do not cover every label seen, sklearn reports
    # "micro avg" instead of "accuracy"; downstream code reads "accuracy".
    if "accuracy" not in report:
        report["accuracy"] = float(accuracy_score(ground_truth, predictions))

    report_text = classification_report(
        ground_truth, predictions, labels=labels_sorted, zero_division=0
    )
    logger.info(f"[{label}] Classification report:\n{report_text}")

    report_path = out_dir / f"{label}_classification_report.json"
    with open(report_path, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)

    # Confusion matrix
    cm = confusion_matrix(ground_truth, predictions, labels=labels_sorted)
    cm_path = out_dir / f"{label}_confusion_matrix.png"
    _save_confusion_matrix(cm, labels_sorted, f"Confusion Matrix — {label}", cm_path)
    logger.info(f"Saved confusion matrix → {cm_path}")
    return report
def generate_comparison_table(
    baseline_report: Dict,
    distilbert_report: Dict,
    baseline_inference_ms: float,
    distilbert_inference_ms: float,
    baseline_size_mb: float,
    distilbert_size_mb: float,
    results_dir: str,
) -> str:
    """Generate a markdown comparison table between baseline and DistilBERT.

    The table is also written to ``{results_dir}/comparison_table.md`` (UTF-8);
    ``results_dir`` is created if it does not exist.

    Args:
        baseline_report: Classification report dict for the baseline.
        distilbert_report: Classification report dict for DistilBERT.
        baseline_inference_ms: Average inference time per sample (ms) for baseline.
        distilbert_inference_ms: Average inference time per sample (ms) for DistilBERT.
        baseline_size_mb: Baseline model size in MB.
        distilbert_size_mb: DistilBERT model size in MB.
        results_dir: Directory to save the comparison table.

    Returns:
        Markdown table string. Metrics missing from a report render as ``n/a``.
    """

    def _fmt(value, spec: str) -> str:
        # A missing metric (e.g. no "accuracy" key) shows as "n/a" rather
        # than raising KeyError/TypeError and losing the whole table.
        return "n/a" if value is None else format(value, spec)

    def _row(name: str, baseline_value, distilbert_value, spec: str) -> str:
        return f"| {name} | {_fmt(baseline_value, spec)} | {_fmt(distilbert_value, spec)} |"

    rows = [
        _row(
            "Weighted F1",
            baseline_report.get("weighted avg", {}).get("f1-score"),
            distilbert_report.get("weighted avg", {}).get("f1-score"),
            ".4f",
        ),
        _row(
            "Accuracy",
            baseline_report.get("accuracy"),
            distilbert_report.get("accuracy"),
            ".4f",
        ),
    ]
    for intent in sorted(INTENT_CATEGORIES):
        rows.append(
            _row(
                f"F1 — {intent}",
                baseline_report.get(intent, {}).get("f1-score", 0.0),
                distilbert_report.get(intent, {}).get("f1-score", 0.0),
                ".4f",
            )
        )
    rows.append(
        _row("Inference time (ms/sample)", baseline_inference_ms, distilbert_inference_ms, ".2f")
    )
    rows.append(_row("Model size (MB)", baseline_size_mb, distilbert_size_mb, ".1f"))

    header = (
        "| Metric | TF-IDF + LR Baseline | DistilBERT Fine-tuned |\n"
        "|--------|----------------------|----------------------|"
    )
    table = header + "\n" + "\n".join(rows)

    out_dir = Path(results_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    path = out_dir / "comparison_table.md"
    # Explicit UTF-8: the table contains non-ASCII characters ("—").
    path.write_text(table, encoding="utf-8")
    logger.info(f"Saved comparison table → {path}")
    return table
def measure_inference_time(
    predict_fn,
    texts: List[str],
    n_samples: int = 100,
    *,
    seed: "int | None" = None,
    warmup_runs: int = 0,
) -> float:
    """Measure average per-sample inference time in milliseconds.

    A random subset of ``texts`` is passed to ``predict_fn`` in one batched
    call; the wall-clock duration of that call is divided by the subset size.

    Args:
        predict_fn: Callable that takes a list of texts and returns predictions.
        texts: List of input texts to sample from.
        n_samples: Number of samples to time (capped at ``len(texts)``).
        seed: Optional seed for a private RNG so the sampled subset is
            reproducible. ``None`` (default) draws from the module-level
            ``random`` state, as before.
        warmup_runs: Untimed calls to ``predict_fn`` on the same subset made
            before the measured call, to exclude one-off costs (e.g. lazy
            initialisation). Defaults to 0.

    Returns:
        Average inference time per sample in milliseconds.

    Raises:
        ValueError: If ``texts`` is empty, ``n_samples < 1``, or
            ``warmup_runs < 0``.
    """
    import random

    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")
    if warmup_runs < 0:
        raise ValueError(f"warmup_runs must be >= 0, got {warmup_runs}")
    k = min(n_samples, len(texts))
    if k == 0:
        # Previously surfaced as an opaque ZeroDivisionError.
        raise ValueError("texts must contain at least one input to time")

    # Use the global RNG when unseeded so callers relying on random.seed()
    # keep their existing behavior.
    rng = random if seed is None else random.Random(seed)
    sample = rng.sample(texts, k)

    for _ in range(warmup_runs):
        predict_fn(sample)

    start = time.perf_counter()
    predict_fn(sample)
    elapsed_ms = (time.perf_counter() - start) * 1000
    return elapsed_ms / len(sample)