Spaces:
Running
Running
| """Classification evaluation metrics — F1, confusion matrix, comparison table.""" | |
| import json | |
| import time | |
| from pathlib import Path | |
| from typing import Dict, List, Tuple | |
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| import seaborn as sns | |
| from loguru import logger | |
| from sklearn.metrics import ( | |
| classification_report, | |
| confusion_matrix, | |
| f1_score, | |
| accuracy_score, | |
| ) | |
| from src.data.dataset import INTENT_CATEGORIES | |
def evaluate_classifier(
    predictions: List[str],
    ground_truth: List[str],
    label: str,
    results_dir: str,
) -> Dict:
    """Compute classification metrics and save them with a confusion-matrix plot.

    Two artifacts are written into ``results_dir`` (created if missing):
    ``{label}_classification_report.json`` and ``{label}_confusion_matrix.png``.
    The report and the matrix both use the sorted ``INTENT_CATEGORIES`` as
    their label set, so rows/columns appear in that order.

    Args:
        predictions: List of predicted intent labels.
        ground_truth: List of true intent labels.
        label: Short name for the model (e.g., 'baseline', 'distilbert');
            used as the filename prefix and in the plot title.
        results_dir: Directory to save artifacts.

    Returns:
        Classification report as a dict (scikit-learn ``output_dict=True``).
    """
    out_dir = Path(results_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    labels_sorted = sorted(INTENT_CATEGORIES)

    # The report is computed twice on purpose: once as a dict for
    # serialization, once as the human-readable text for the log.
    report = classification_report(
        ground_truth, predictions, labels=labels_sorted, output_dict=True
    )
    report_text = classification_report(
        ground_truth, predictions, labels=labels_sorted
    )
    logger.info(f"[{label}] Classification report:\n{report_text}")

    report_path = out_dir / f"{label}_classification_report.json"
    with open(report_path, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)

    # Confusion matrix
    cm = confusion_matrix(ground_truth, predictions, labels=labels_sorted)
    cm_path = out_dir / f"{label}_confusion_matrix.png"
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        sns.heatmap(
            cm,
            annot=True,
            fmt="d",
            cmap="Blues",
            xticklabels=labels_sorted,
            yticklabels=labels_sorted,
            ax=ax,
        )
        ax.set_title(f"Confusion Matrix — {label}")
        ax.set_xlabel("Predicted")
        ax.set_ylabel("True")
        # Lay out this specific figure instead of pyplot's implicit
        # "current figure".
        fig.tight_layout()
        fig.savefig(cm_path, dpi=150)
    finally:
        # pyplot keeps every open figure in a global registry; close it even
        # when plotting or saving fails so repeated evaluations do not leak.
        plt.close(fig)
    logger.info(f"Saved confusion matrix → {cm_path}")
    return report
def generate_comparison_table(
    baseline_report: Dict,
    distilbert_report: Dict,
    baseline_inference_ms: float,
    distilbert_inference_ms: float,
    baseline_size_mb: float,
    distilbert_size_mb: float,
    results_dir: str,
) -> str:
    """Generate a markdown comparison table between baseline and DistilBERT.

    The table lists weighted F1, accuracy, one F1 row per intent in
    ``INTENT_CATEGORIES`` (sorted), inference time and model size. It is
    written to ``comparison_table.md`` inside ``results_dir`` (created if
    missing) as UTF-8 and also returned.

    Args:
        baseline_report: Classification report dict for the baseline.
        distilbert_report: Classification report dict for DistilBERT.
        baseline_inference_ms: Average inference time per sample (ms) for baseline.
        distilbert_inference_ms: Average inference time per sample (ms) for DistilBERT.
        baseline_size_mb: Baseline model size in MB.
        distilbert_size_mb: DistilBERT model size in MB.
        results_dir: Directory to save the comparison table.

    Returns:
        Markdown table string.

    Raises:
        KeyError: If either report lacks the 'weighted avg' or 'accuracy' entry.
    """
    # NOTE(review): the summary rows index the reports directly, so a report
    # without 'accuracy' / 'weighted avg' fails loudly here — confirm the
    # upstream reports always carry both keys.
    rows = [
        f"| Weighted F1 | {baseline_report['weighted avg']['f1-score']:.4f} "
        f"| {distilbert_report['weighted avg']['f1-score']:.4f} |",
        f"| Accuracy | {baseline_report['accuracy']:.4f} "
        f"| {distilbert_report['accuracy']:.4f} |",
    ]

    # Per-intent rows: an intent missing from a report is shown as 0.0000
    # rather than aborting the whole table.
    for intent in sorted(INTENT_CATEGORIES):
        b_f1 = baseline_report.get(intent, {}).get("f1-score", 0.0)
        d_f1 = distilbert_report.get(intent, {}).get("f1-score", 0.0)
        rows.append(f"| F1 — {intent} | {b_f1:.4f} | {d_f1:.4f} |")

    rows.append(
        f"| Inference time (ms/sample) | {baseline_inference_ms:.2f} "
        f"| {distilbert_inference_ms:.2f} |"
    )
    rows.append(
        f"| Model size (MB) | {baseline_size_mb:.1f} | {distilbert_size_mb:.1f} |"
    )

    header = (
        "| Metric | TF-IDF + LR Baseline | DistilBERT Fine-tuned |\n"
        "|--------|----------------------|----------------------|"
    )
    table = header + "\n" + "\n".join(rows)

    out_dir = Path(results_dir)
    # Create the directory here too, so this function works even when it is
    # called before any other artifact has been written.
    out_dir.mkdir(parents=True, exist_ok=True)
    path = out_dir / "comparison_table.md"
    # The table contains a non-ASCII em dash; pin UTF-8 so the write does not
    # depend on (or fail under) the platform's default locale encoding.
    path.write_text(table, encoding="utf-8")
    logger.info(f"Saved comparison table → {path}")
    return table
def measure_inference_time(
    predict_fn,
    texts: List[str],
    n_samples: int = 100,
) -> float:
    """Measure average per-sample inference time in milliseconds.

    A random sample of at most ``n_samples`` texts (without replacement) is
    passed to ``predict_fn`` in a single call; the wall-clock time of that
    call is divided by the number of sampled texts.

    Args:
        predict_fn: Callable that takes a list of texts and returns predictions.
        texts: List of input texts to sample from.
        n_samples: Number of samples to time.

    Returns:
        Average inference time per sample in milliseconds.

    Raises:
        ValueError: If ``texts`` is empty or ``n_samples`` is less than 1.
    """
    import random

    # Validate up front: with nothing to time, the division below would fail
    # with an opaque ZeroDivisionError after predict_fn had already been
    # called on an empty batch.
    if n_samples < 1:
        raise ValueError(f"n_samples must be at least 1, got {n_samples}")
    if len(texts) == 0:
        raise ValueError("texts must contain at least one item to time inference")

    sample = random.sample(texts, min(n_samples, len(texts)))
    start = time.perf_counter()
    predict_fn(sample)
    elapsed_ms = (time.perf_counter() - start) * 1000
    return elapsed_ms / len(sample)