# Uploaded via huggingface_hub by marcosremar2 (commit 3c1eb61, verified)
"""
Base classes and evaluation metrics for turn-taking benchmarks.
Metrics follow standard turn-taking evaluation methodology:
- Ekstedt, E. & Torre, G. (2024). Voice Activity Projection: Self-supervised
Learning of Turn-taking Events. arXiv:2401.04868.
- Skantze, G. (2021). Turn-taking in Conversational Systems and Human-Robot
Interaction: A Review. Computer Speech & Language, 67.
"""
from __future__ import annotations
import json
import logging
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
import numpy as np
from sklearn.metrics import (
balanced_accuracy_score,
f1_score,
precision_score,
recall_score,
confusion_matrix,
)
from setup_dataset import Conversation, TurnSegment
log = logging.getLogger(__name__)
RESULTS_DIR = Path(__file__).parent / "results"
@dataclass
class PredictedEvent:
    """A predicted turn-taking event.

    Emitted by TurnTakingModel.predict; evaluate_model aligns these events
    against ground-truth shift/hold annotations by timestamp.
    """
    timestamp: float  # seconds, position of the event within the conversation
    event_type: str  # "shift" or "hold" (any other value is treated as "hold" by evaluate_model)
    confidence: float = 1.0  # model confidence; currently not read by evaluate_model
    latency_ms: float = 0.0  # inference latency; currently not read by evaluate_model (it times predict() itself)
@dataclass
class BenchmarkResult:
    """Results from evaluating a single model on the dataset.

    All metric fields default to 0.0 so that evaluate_model can construct
    the result with metadata first and fill the metrics in afterwards.
    """
    model_name: str
    dataset_name: str
    # Classification metrics ("shift" is the positive class, label 1)
    precision_shift: float = 0.0
    recall_shift: float = 0.0
    f1_shift: float = 0.0
    precision_hold: float = 0.0
    recall_hold: float = 0.0
    f1_hold: float = 0.0
    balanced_accuracy: float = 0.0
    macro_f1: float = 0.0
    # Timing metrics (per-prediction inference latency, milliseconds)
    mean_latency_ms: float = 0.0
    p50_latency_ms: float = 0.0
    p95_latency_ms: float = 0.0
    p99_latency_ms: float = 0.0
    # Turn-specific metrics
    mean_shift_delay_ms: float = 0.0  # How early/late shifts are detected (signed)
    false_interruption_rate: float = 0.0  # False positive shifts
    missed_shift_rate: float = 0.0  # False negative shifts
    # Resource usage
    model_size_mb: float = 0.0
    peak_memory_mb: float = 0.0
    requires_gpu: bool = False
    requires_asr: bool = False
    # Metadata
    n_conversations: int = 0
    n_predictions: int = 0
    total_audio_hours: float = 0.0
    extra: dict = field(default_factory=dict)
    def to_dict(self) -> dict:
        """Return a JSON-serializable shallow copy of all fields.

        The identity dict comprehension this replaces built the exact same
        shallow copy with more code.
        """
        return dict(self.__dict__)
class TurnTakingModel(ABC):
    """Abstract base for turn-taking prediction models.

    Subclasses implement predict() plus metadata properties; evaluate_model
    copies the metadata into the BenchmarkResult it returns.
    """
    @property
    @abstractmethod
    def name(self) -> str:
        """Model identifier; combined with the dataset name in result filenames."""
        ...
    @property
    @abstractmethod
    def requires_gpu(self) -> bool:
        """Whether inference needs a GPU (recorded in results, not enforced here)."""
        ...
    @property
    @abstractmethod
    def requires_asr(self) -> bool:
        """Whether the model depends on ASR output (recorded in results, not enforced here)."""
        ...
    @abstractmethod
    def predict(self, conversation: Conversation) -> list[PredictedEvent]:
        """Predict turn-taking events for a conversation."""
        ...
    def get_model_size_mb(self) -> float:
        """Return model size in MB. Defaults to 0.0; subclasses may override."""
        return 0.0
def _nearest_within(t: float, events: set[float], tolerance_s: float) -> float | None:
    """Return the event time in *events* nearest to *t*, or None if the
    nearest one is farther than *tolerance_s* seconds away."""
    if not events:
        return None
    nearest = min(events, key=lambda e: abs(t - e))
    return nearest if abs(t - nearest) <= tolerance_s else None


def evaluate_model(
    model: TurnTakingModel,
    conversations: list[Conversation],
    dataset_name: str,
    tolerance_ms: float = 500.0,
) -> BenchmarkResult:
    """
    Evaluate a turn-taking model against ground truth annotations.

    Each prediction is aligned to the *nearest* ground-truth event within
    ``tolerance_ms`` (shifts are checked before holds). A prediction that
    matches no ground-truth event is treated as a negative sample; if it is
    a "shift" it also counts as a false interruption. Ground-truth shifts
    that no "shift" prediction matched count as misses.

    Args:
        model: The model to evaluate
        conversations: List of conversations with ground truth
        dataset_name: Name of the dataset (recorded in the result)
        tolerance_ms: Matching tolerance in milliseconds for event alignment

    Returns:
        BenchmarkResult with all metrics computed
    """
    all_true_labels: list[int] = []
    all_pred_labels: list[int] = []
    all_latencies: list[float] = []
    shift_delays: list[float] = []
    false_interruptions = 0
    missed_shifts = 0
    total_shifts = 0
    total_predictions = 0
    tolerance_s = tolerance_ms / 1000.0
    for conv in conversations:
        t0 = time.perf_counter()
        predictions = model.predict(conv)
        elapsed_ms = (time.perf_counter() - t0) * 1000.0
        if predictions:
            # Amortize whole-conversation inference time over its predictions.
            per_pred_latency = elapsed_ms / len(predictions)
            all_latencies.extend([per_pred_latency] * len(predictions))
        total_predictions += len(predictions)
        # Build ground truth event timeline
        gt_shifts = set(conv.turn_shifts)
        gt_holds = set(conv.holds)
        total_shifts += len(gt_shifts)
        # Shifts that were matched by a correctly-typed ("shift") prediction.
        matched_shifts: set[float] = set()
        for pred in predictions:
            # BUGFIX: align to the *nearest* event within tolerance instead of
            # the first hit while iterating a set of floats, whose order is
            # arbitrary -- the old code picked a nondeterministic match when
            # two ground-truth events fell inside the tolerance window.
            shift_t = _nearest_within(pred.timestamp, gt_shifts, tolerance_s)
            if shift_t is not None:
                all_true_labels.append(1)
                if pred.event_type == "shift":
                    all_pred_labels.append(1)
                    matched_shifts.add(shift_t)
                    shift_delays.append((pred.timestamp - shift_t) * 1000.0)
                else:
                    all_pred_labels.append(0)
                continue
            hold_t = _nearest_within(pred.timestamp, gt_holds, tolerance_s)
            if hold_t is not None:
                all_true_labels.append(0)
                if pred.event_type == "hold":
                    all_pred_labels.append(0)
                else:
                    all_pred_labels.append(1)
                    false_interruptions += 1
                continue
            # Unmatched prediction = false positive if it claims a shift
            all_true_labels.append(0)
            if pred.event_type == "shift":
                all_pred_labels.append(1)
                false_interruptions += 1
            else:
                all_pred_labels.append(0)
        # Unmatched ground truth shifts = missed
        for gt_t in gt_shifts - matched_shifts:
            all_true_labels.append(1)
            all_pred_labels.append(0)
            missed_shifts += 1
    # Compute metrics
    y_true = np.array(all_true_labels)
    y_pred = np.array(all_pred_labels)
    result = BenchmarkResult(
        model_name=model.name,
        dataset_name=dataset_name,
        n_conversations=len(conversations),
        n_predictions=total_predictions,
        total_audio_hours=sum(c.duration for c in conversations) / 3600.0,
        requires_gpu=model.requires_gpu,
        requires_asr=model.requires_asr,
        model_size_mb=model.get_model_size_mb(),
    )
    # Classification metrics are only meaningful with both classes present.
    if len(y_true) > 0 and len(np.unique(y_true)) > 1:
        result.precision_shift = float(precision_score(y_true, y_pred, pos_label=1, zero_division=0))
        result.recall_shift = float(recall_score(y_true, y_pred, pos_label=1, zero_division=0))
        result.f1_shift = float(f1_score(y_true, y_pred, pos_label=1, zero_division=0))
        result.precision_hold = float(precision_score(y_true, y_pred, pos_label=0, zero_division=0))
        result.recall_hold = float(recall_score(y_true, y_pred, pos_label=0, zero_division=0))
        result.f1_hold = float(f1_score(y_true, y_pred, pos_label=0, zero_division=0))
        result.balanced_accuracy = float(balanced_accuracy_score(y_true, y_pred))
        result.macro_f1 = float(f1_score(y_true, y_pred, average="macro", zero_division=0))
    if all_latencies:
        arr = np.array(all_latencies)
        result.mean_latency_ms = float(np.mean(arr))
        result.p50_latency_ms = float(np.percentile(arr, 50))
        result.p95_latency_ms = float(np.percentile(arr, 95))
        result.p99_latency_ms = float(np.percentile(arr, 99))
    if shift_delays:
        result.mean_shift_delay_ms = float(np.mean(shift_delays))
    if total_shifts > 0:
        result.missed_shift_rate = missed_shifts / total_shifts
    # BUGFIX: the old denominator (len(all_true_labels) - total_shifts) was
    # wrong whenever several predictions aligned to the same ground-truth
    # shift: label entries are per-prediction while total_shifts is per-event.
    # Count the actual negative labels instead.
    total_non_shifts = int(np.sum(y_true == 0)) if len(y_true) else 0
    if total_non_shifts > 0:
        result.false_interruption_rate = false_interruptions / total_non_shifts
    return result
def save_result(result: BenchmarkResult) -> Path:
    """Serialize *result* as indented JSON under RESULTS_DIR and return the path.

    The file is named "<model_name>_<dataset_name>.json"; an existing file for
    the same pair is overwritten.
    """
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    out_path = RESULTS_DIR / f"{result.model_name}_{result.dataset_name}.json"
    out_path.write_text(json.dumps(result.to_dict(), indent=2))
    log.info("Saved result to %s", out_path)
    return out_path
def load_all_results() -> list[BenchmarkResult]:
    """Load every "*.json" benchmark result saved under RESULTS_DIR.

    Returns results in sorted filename order; an empty list if the results
    directory does not exist yet.
    """
    if not RESULTS_DIR.exists():
        return []
    return [
        BenchmarkResult(**json.loads(result_file.read_text()))
        for result_file in sorted(RESULTS_DIR.glob("*.json"))
    ]