Spaces:
Running
Running
File size: 4,807 Bytes
e323466 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 | """Classification evaluation metrics — F1, confusion matrix, comparison table."""
import json
import time
from pathlib import Path
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from loguru import logger
from sklearn.metrics import (
classification_report,
confusion_matrix,
f1_score,
accuracy_score,
)
from src.data.dataset import INTENT_CATEGORIES
def evaluate_classifier(
predictions: List[str],
ground_truth: List[str],
label: str,
results_dir: str,
) -> Dict:
"""Compute and save classification metrics.
Args:
predictions: List of predicted intent labels.
ground_truth: List of true intent labels.
label: Short name for the model (e.g., 'baseline', 'distilbert').
results_dir: Directory to save artifacts.
Returns:
Classification report as a dict.
"""
Path(results_dir).mkdir(parents=True, exist_ok=True)
labels_sorted = sorted(INTENT_CATEGORIES)
report = classification_report(
ground_truth, predictions, labels=labels_sorted, output_dict=True
)
report_text = classification_report(ground_truth, predictions, labels=labels_sorted)
logger.info(f"[{label}] Classification report:\n{report_text}")
report_path = Path(results_dir) / f"{label}_classification_report.json"
with open(report_path, "w") as f:
json.dump(report, f, indent=2)
# Confusion matrix
cm = confusion_matrix(ground_truth, predictions, labels=labels_sorted)
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
cm,
annot=True,
fmt="d",
cmap="Blues",
xticklabels=labels_sorted,
yticklabels=labels_sorted,
ax=ax,
)
ax.set_title(f"Confusion Matrix — {label}")
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
plt.tight_layout()
cm_path = Path(results_dir) / f"{label}_confusion_matrix.png"
fig.savefig(cm_path, dpi=150)
plt.close(fig)
logger.info(f"Saved confusion matrix → {cm_path}")
return report
def generate_comparison_table(
baseline_report: Dict,
distilbert_report: Dict,
baseline_inference_ms: float,
distilbert_inference_ms: float,
baseline_size_mb: float,
distilbert_size_mb: float,
results_dir: str,
) -> str:
"""Generate a markdown comparison table between baseline and DistilBERT.
Args:
baseline_report: Classification report dict for the baseline.
distilbert_report: Classification report dict for DistilBERT.
baseline_inference_ms: Average inference time per sample (ms) for baseline.
distilbert_inference_ms: Average inference time per sample (ms) for DistilBERT.
baseline_size_mb: Baseline model size in MB.
distilbert_size_mb: DistilBERT model size in MB.
results_dir: Directory to save the comparison table.
Returns:
Markdown table string.
"""
rows = []
rows.append(
f"| Weighted F1 | {baseline_report['weighted avg']['f1-score']:.4f} "
f"| {distilbert_report['weighted avg']['f1-score']:.4f} |"
)
rows.append(
f"| Accuracy | {baseline_report['accuracy']:.4f} "
f"| {distilbert_report['accuracy']:.4f} |"
)
for intent in sorted(INTENT_CATEGORIES):
b_f1 = baseline_report.get(intent, {}).get("f1-score", 0.0)
d_f1 = distilbert_report.get(intent, {}).get("f1-score", 0.0)
rows.append(f"| F1 — {intent} | {b_f1:.4f} | {d_f1:.4f} |")
rows.append(
f"| Inference time (ms/sample) | {baseline_inference_ms:.2f} "
f"| {distilbert_inference_ms:.2f} |"
)
rows.append(
f"| Model size (MB) | {baseline_size_mb:.1f} | {distilbert_size_mb:.1f} |"
)
header = (
"| Metric | TF-IDF + LR Baseline | DistilBERT Fine-tuned |\n"
"|--------|----------------------|----------------------|"
)
table = header + "\n" + "\n".join(rows)
path = Path(results_dir) / "comparison_table.md"
path.write_text(table)
logger.info(f"Saved comparison table → {path}")
return table
def measure_inference_time(
predict_fn,
texts: List[str],
n_samples: int = 100,
) -> float:
"""Measure average per-sample inference time in milliseconds.
Args:
predict_fn: Callable that takes a list of texts and returns predictions.
texts: List of input texts to sample from.
n_samples: Number of samples to time.
Returns:
Average inference time per sample in milliseconds.
"""
import random
sample = random.sample(texts, min(n_samples, len(texts)))
start = time.perf_counter()
predict_fn(sample)
elapsed_ms = (time.perf_counter() - start) * 1000
return elapsed_ms / len(sample)
|