| """ |
| Base classes and evaluation metrics for turn-taking benchmarks. |
| |
| Metrics follow standard turn-taking evaluation methodology: |
| - Ekstedt, E. & Torre, G. (2024). Voice Activity Projection: Self-supervised |
| Learning of Turn-taking Events. arXiv:2401.04868. |
| - Skantze, G. (2021). Turn-taking in Conversational Systems and Human-Robot |
| Interaction: A Review. Computer Speech & Language, 67. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import logging |
| import time |
| from abc import ABC, abstractmethod |
| from dataclasses import dataclass, field |
| from pathlib import Path |
|
|
| import numpy as np |
| from sklearn.metrics import ( |
| balanced_accuracy_score, |
| f1_score, |
| precision_score, |
| recall_score, |
| confusion_matrix, |
| ) |
|
|
| from setup_dataset import Conversation, TurnSegment |
|
|
# Module-level logger (standard per-module pattern).
log = logging.getLogger(__name__)


# Directory where benchmark results are persisted as JSON, one file per
# (model, dataset) pair — see save_result() / load_all_results().
RESULTS_DIR = Path(__file__).parent / "results"
|
|
|
|
@dataclass
class PredictedEvent:
    """A predicted turn-taking event."""
    # Event time in seconds from conversation start (evaluate_model compares
    # it against ground-truth times with a tolerance converted from ms to s).
    timestamp: float
    # Either "shift" (speaker change) or "hold" (current speaker continues);
    # these are the only two values evaluate_model distinguishes.
    event_type: str
    # Model confidence — presumably in [0, 1]; not read by evaluate_model.
    confidence: float = 1.0
    # Per-event latency reported by the model.
    # NOTE(review): currently unused — evaluate_model amortizes wall-clock
    # predict() time over events instead of reading this field.
    latency_ms: float = 0.0
|
|
|
|
@dataclass
class BenchmarkResult:
    """Aggregated metrics from evaluating a single model on one dataset."""
    model_name: str
    dataset_name: str

    # Classification quality ("shift" is the positive class, "hold" negative).
    precision_shift: float = 0.0
    recall_shift: float = 0.0
    f1_shift: float = 0.0
    precision_hold: float = 0.0
    recall_hold: float = 0.0
    f1_hold: float = 0.0
    balanced_accuracy: float = 0.0
    macro_f1: float = 0.0

    # Inference latency statistics, in milliseconds.
    mean_latency_ms: float = 0.0
    p50_latency_ms: float = 0.0
    p95_latency_ms: float = 0.0
    p99_latency_ms: float = 0.0

    # Turn-taking-specific error measures.
    mean_shift_delay_ms: float = 0.0
    false_interruption_rate: float = 0.0
    missed_shift_rate: float = 0.0

    # Resource footprint.
    model_size_mb: float = 0.0
    peak_memory_mb: float = 0.0
    requires_gpu: bool = False
    requires_asr: bool = False

    # Bookkeeping about the evaluation run itself.
    n_conversations: int = 0
    n_predictions: int = 0
    total_audio_hours: float = 0.0
    extra: dict = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Return a shallow copy of all fields as a plain dict (JSON-ready)."""
        return dict(vars(self))
|
|
|
|
class TurnTakingModel(ABC):
    """Interface every benchmarked turn-taking predictor must implement."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Model identifier (also used in result filenames by save_result)."""

    @property
    @abstractmethod
    def requires_gpu(self) -> bool:
        """True when the model needs a GPU to run."""

    @property
    @abstractmethod
    def requires_asr(self) -> bool:
        """True when the model depends on ASR output."""

    @abstractmethod
    def predict(self, conversation: Conversation) -> list[PredictedEvent]:
        """Predict turn-taking events for a conversation."""

    def get_model_size_mb(self) -> float:
        """Return model size in MB.

        Default is 0.0; concrete models may override with a real measurement.
        """
        return 0.0
|
|
|
|
def evaluate_model(
    model: TurnTakingModel,
    conversations: list[Conversation],
    dataset_name: str,
    tolerance_ms: float = 500.0,
) -> BenchmarkResult:
    """
    Evaluate a turn-taking model against ground truth annotations.

    Each predicted event is aligned to the *nearest* ground-truth shift
    (or, failing that, hold) annotation within ``tolerance_ms``.  Matched
    pairs become binary labels (1 = shift, 0 = hold) from which precision /
    recall / F1 / balanced accuracy are computed, alongside latency
    percentiles and turn-taking-specific error rates.

    Args:
        model: The model to evaluate
        conversations: List of conversations with ground truth
        dataset_name: Name of the dataset
        tolerance_ms: Matching tolerance in milliseconds for event alignment

    Returns:
        BenchmarkResult with all metrics computed
    """
    all_true_labels: list[int] = []
    all_pred_labels: list[int] = []
    all_latencies: list[float] = []
    shift_delays: list[float] = []
    false_interruptions = 0
    missed_shifts = 0
    total_shifts = 0
    total_predictions = 0

    # Ground-truth times are in seconds; convert the tolerance once.
    tolerance_s = tolerance_ms / 1000.0

    for conv in conversations:
        t0 = time.perf_counter()
        predictions = model.predict(conv)
        elapsed_ms = (time.perf_counter() - t0) * 1000.0

        if predictions:
            # predict() processes the whole conversation at once, so the
            # wall-clock time is amortized uniformly over its predictions.
            per_pred_latency = elapsed_ms / len(predictions)
            all_latencies.extend([per_pred_latency] * len(predictions))
            total_predictions += len(predictions)

        gt_shifts = set(conv.turn_shifts)
        gt_holds = set(conv.holds)
        total_shifts += len(gt_shifts)

        matched_shifts: set[float] = set()

        for pred in predictions:
            # Fix: match the NEAREST ground-truth event within tolerance.
            # The previous implementation took the first hit in arbitrary
            # set-iteration order, which made matching (and shift_delays)
            # depend on hash order when annotations lie close together.
            near_shifts = [t for t in gt_shifts
                           if abs(pred.timestamp - t) <= tolerance_s]
            if near_shifts:
                gt_t = min(near_shifts, key=lambda t: abs(pred.timestamp - t))
                if pred.event_type == "shift":
                    all_true_labels.append(1)
                    all_pred_labels.append(1)
                    matched_shifts.add(gt_t)
                    # Signed delay: positive = prediction after the true shift.
                    shift_delays.append((pred.timestamp - gt_t) * 1000.0)
                else:
                    # NOTE(review): a "hold" prediction near a shift records a
                    # (1, 0) label here AND the unmatched-shift sweep below
                    # adds another (1, 0) for the same shift — kept as-is to
                    # preserve the established metric definition, but confirm
                    # this double count is intended.
                    all_true_labels.append(1)
                    all_pred_labels.append(0)
                continue

            near_holds = [t for t in gt_holds
                          if abs(pred.timestamp - t) <= tolerance_s]
            if near_holds:
                if pred.event_type == "hold":
                    all_true_labels.append(0)
                    all_pred_labels.append(0)
                else:
                    all_true_labels.append(0)
                    all_pred_labels.append(1)
                    false_interruptions += 1
                continue

            # Prediction far from any annotation: a spurious "shift" counts as
            # a false interruption; anything else is a vacuous true hold.
            if pred.event_type == "shift":
                all_true_labels.append(0)
                all_pred_labels.append(1)
                false_interruptions += 1
            else:
                all_true_labels.append(0)
                all_pred_labels.append(0)

        # Ground-truth shifts never matched by a "shift" prediction are misses.
        for gt_t in gt_shifts:
            if gt_t not in matched_shifts:
                all_true_labels.append(1)
                all_pred_labels.append(0)
                missed_shifts += 1

    y_true = np.array(all_true_labels)
    y_pred = np.array(all_pred_labels)

    result = BenchmarkResult(
        model_name=model.name,
        dataset_name=dataset_name,
        n_conversations=len(conversations),
        n_predictions=total_predictions,
        total_audio_hours=sum(c.duration for c in conversations) / 3600.0,
        requires_gpu=model.requires_gpu,
        requires_asr=model.requires_asr,
        model_size_mb=model.get_model_size_mb(),
    )

    # sklearn metrics are only meaningful with both classes present;
    # otherwise the defaults of 0.0 are left in place.
    if len(y_true) > 0 and len(np.unique(y_true)) > 1:
        result.precision_shift = float(precision_score(y_true, y_pred, pos_label=1, zero_division=0))
        result.recall_shift = float(recall_score(y_true, y_pred, pos_label=1, zero_division=0))
        result.f1_shift = float(f1_score(y_true, y_pred, pos_label=1, zero_division=0))
        result.precision_hold = float(precision_score(y_true, y_pred, pos_label=0, zero_division=0))
        result.recall_hold = float(recall_score(y_true, y_pred, pos_label=0, zero_division=0))
        result.f1_hold = float(f1_score(y_true, y_pred, pos_label=0, zero_division=0))
        result.balanced_accuracy = float(balanced_accuracy_score(y_true, y_pred))
        result.macro_f1 = float(f1_score(y_true, y_pred, average="macro", zero_division=0))

    if all_latencies:
        arr = np.array(all_latencies)
        result.mean_latency_ms = float(np.mean(arr))
        result.p50_latency_ms = float(np.percentile(arr, 50))
        result.p95_latency_ms = float(np.percentile(arr, 95))
        result.p99_latency_ms = float(np.percentile(arr, 99))

    if shift_delays:
        result.mean_shift_delay_ms = float(np.mean(shift_delays))

    if total_shifts > 0:
        result.missed_shift_rate = missed_shifts / total_shifts

    # Approximation: label entries minus ground-truth shift count — label
    # rows and annotation counts are different bases, so treat this rate as
    # indicative rather than exact.
    total_non_shifts = len(all_true_labels) - total_shifts
    if total_non_shifts > 0:
        result.false_interruption_rate = false_interruptions / total_non_shifts

    return result
|
|
|
|
def save_result(result: BenchmarkResult) -> Path:
    """Save benchmark result to JSON.

    Args:
        result: The result to persist; written to
            ``results/<model_name>_<dataset_name>.json``.

    Returns:
        Path of the written JSON file.
    """
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    path = RESULTS_DIR / f"{result.model_name}_{result.dataset_name}.json"
    # Fix: explicit encoding — the platform default may not be UTF-8, and
    # names or `extra` values may contain non-ASCII text.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(result.to_dict(), f, indent=2)
    log.info("Saved result to %s", path)
    return path
|
|
|
|
def load_all_results() -> list[BenchmarkResult]:
    """Load all saved benchmark results.

    Returns:
        Results in filename-sorted order; empty list when RESULTS_DIR
        does not exist yet.
    """
    results: list[BenchmarkResult] = []
    if not RESULTS_DIR.exists():
        return results
    for path in sorted(RESULTS_DIR.glob("*.json")):
        # Fix: explicit encoding to match save_result's UTF-8 output
        # regardless of the platform default.
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
            results.append(BenchmarkResult(**data))
    return results
|
|