File size: 4,807 Bytes
e323466
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""Classification evaluation metrics — F1, confusion matrix, comparison table."""

import json
import time
from pathlib import Path
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from loguru import logger
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    f1_score,
    accuracy_score,
)

from src.data.dataset import INTENT_CATEGORIES


def evaluate_classifier(
    predictions: List[str],
    ground_truth: List[str],
    label: str,
    results_dir: str,
) -> Dict:
    """Compute and save classification metrics.

    Writes ``{label}_classification_report.json`` and
    ``{label}_confusion_matrix.png`` into ``results_dir`` (created if missing)
    and logs the text classification report.

    Args:
        predictions: List of predicted intent labels.
        ground_truth: List of true intent labels, aligned index-by-index with
            ``predictions``.
        label: Short name for the model (e.g., 'baseline', 'distilbert');
            used in log messages, the plot title and artifact filenames.
        results_dir: Directory to save artifacts.

    Returns:
        Classification report as a dict (``classification_report`` with
        ``output_dict=True``), always containing an ``"accuracy"`` entry.
    """
    out_dir = Path(results_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    labels_sorted = sorted(INTENT_CATEGORIES)

    # zero_division=0 yields the same 0.0 scores as sklearn's default ("warn")
    # for classes never predicted / never present, minus the warning spam.
    report_kwargs = {"labels": labels_sorted, "zero_division": 0}
    report = classification_report(
        ground_truth, predictions, output_dict=True, **report_kwargs
    )
    # If the data contains labels outside INTENT_CATEGORIES, sklearn reports
    # "micro avg" instead of "accuracy". Callers read report["accuracy"], so
    # guarantee it is present, computed over all samples.
    report.setdefault("accuracy", accuracy_score(ground_truth, predictions))
    report_text = classification_report(ground_truth, predictions, **report_kwargs)
    logger.info(f"[{label}] Classification report:\n{report_text}")

    report_path = out_dir / f"{label}_classification_report.json"
    with open(report_path, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)

    # Confusion matrix (rows = true label, columns = predicted label).
    cm = confusion_matrix(ground_truth, predictions, labels=labels_sorted)
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        sns.heatmap(
            cm,
            annot=True,
            fmt="d",
            cmap="Blues",
            xticklabels=labels_sorted,
            yticklabels=labels_sorted,
            ax=ax,
        )
        ax.set_title(f"Confusion Matrix — {label}")
        ax.set_xlabel("Predicted")
        ax.set_ylabel("True")
        fig.tight_layout()
        cm_path = out_dir / f"{label}_confusion_matrix.png"
        fig.savefig(cm_path, dpi=150)
    finally:
        # Always release the figure, even if plotting/saving fails, so
        # repeated evaluations don't accumulate open figures.
        plt.close(fig)
    logger.info(f"Saved confusion matrix → {cm_path}")

    return report


def generate_comparison_table(
    baseline_report: Dict,
    distilbert_report: Dict,
    baseline_inference_ms: float,
    distilbert_inference_ms: float,
    baseline_size_mb: float,
    distilbert_size_mb: float,
    results_dir: str,
) -> str:
    """Generate a markdown comparison table between baseline and DistilBERT.

    The table is written (UTF-8) to ``<results_dir>/comparison_table.md``,
    creating the directory if needed, and is also returned.

    Args:
        baseline_report: Classification report dict for the baseline
            (``classification_report(..., output_dict=True)`` format; must
            contain ``"weighted avg"``).
        distilbert_report: Classification report dict for DistilBERT, same
            format as ``baseline_report``.
        baseline_inference_ms: Average inference time per sample (ms) for baseline.
        distilbert_inference_ms: Average inference time per sample (ms) for DistilBERT.
        baseline_size_mb: Baseline model size in MB.
        distilbert_size_mb: DistilBERT model size in MB.
        results_dir: Directory to save the comparison table.

    Returns:
        Markdown table string.
    """

    def _accuracy_cell(report: Dict) -> str:
        # sklearn omits "accuracy" (emitting "micro avg" instead) when the
        # report's label subset doesn't cover every observed label; show
        # "n/a" instead of failing with KeyError.
        acc = report.get("accuracy")
        return "n/a" if acc is None else f"{acc:.4f}"

    def _class_f1(report: Dict, intent: str) -> float:
        # A class missing from a report is shown as 0.0.
        return report.get(intent, {}).get("f1-score", 0.0)

    rows = [
        f"| Weighted F1 | {baseline_report['weighted avg']['f1-score']:.4f} "
        f"| {distilbert_report['weighted avg']['f1-score']:.4f} |",
        f"| Accuracy | {_accuracy_cell(baseline_report)} "
        f"| {_accuracy_cell(distilbert_report)} |",
    ]
    rows.extend(
        f"| F1 — {intent} | {_class_f1(baseline_report, intent):.4f} "
        f"| {_class_f1(distilbert_report, intent):.4f} |"
        for intent in sorted(INTENT_CATEGORIES)
    )
    rows.append(
        f"| Inference time (ms/sample) | {baseline_inference_ms:.2f} "
        f"| {distilbert_inference_ms:.2f} |"
    )
    rows.append(
        f"| Model size (MB) | {baseline_size_mb:.1f} | {distilbert_size_mb:.1f} |"
    )

    header = (
        "| Metric | TF-IDF + LR Baseline | DistilBERT Fine-tuned |\n"
        "|--------|----------------------|----------------------|"
    )
    table = header + "\n" + "\n".join(rows)

    out_dir = Path(results_dir)
    # This function may be called on its own, so don't assume the directory
    # was already created by evaluate_classifier.
    out_dir.mkdir(parents=True, exist_ok=True)
    path = out_dir / "comparison_table.md"
    # Explicit UTF-8: the table contains "—", which the platform default
    # encoding may not be able to represent.
    path.write_text(table, encoding="utf-8")
    logger.info(f"Saved comparison table → {path}")
    return table


def measure_inference_time(
    predict_fn,
    texts: List[str],
    n_samples: int = 100,
) -> float:
    """Measure average per-sample inference time in milliseconds.

    A random subset of ``min(n_samples, len(texts))`` texts is passed to
    ``predict_fn`` in a single call. The wall-clock duration of that call
    (``time.perf_counter``) is divided by the subset size. The result is
    therefore the amortized per-sample latency of batched prediction. It
    includes any first-call overhead ``predict_fn`` may have (e.g. lazy
    initialization).

    Args:
        predict_fn: Callable that takes a list of texts and returns
            predictions. Its return value is ignored.
        texts: List of input texts to sample from.
        n_samples: Number of samples to time.

    Returns:
        Average inference time per sample in milliseconds.

    Raises:
        ValueError: If ``texts`` is empty or ``n_samples`` is not positive,
            because there would be nothing to time.
    """
    import random

    sample_size = min(n_samples, len(texts))
    if sample_size <= 0:
        # Previously this surfaced as an opaque ZeroDivisionError.
        raise ValueError(
            f"Cannot measure inference time: need at least one sample "
            f"(len(texts)={len(texts)}, n_samples={n_samples})"
        )

    sample = random.sample(texts, sample_size)
    start = time.perf_counter()
    predict_fn(sample)
    elapsed_ms = (time.perf_counter() - start) * 1000
    return elapsed_ms / sample_size