#!/usr/bin/env python3
"""
Utility functions for [SCD] token extraction, alignment, and timestamp calculation.
Based on Tokenverse paper approach:
- Extract predicted task tokens from hypothesis
- Align tokens in time domain using acoustic frame indices
- Calculate timestamps using XLSR frame duration (25ms) and stride (20ms)
"""
import logging
from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict
import torch
@dataclass
class SCDToken:
    """Represents a detected [SCD] token with timing information."""
    token: str  # normalized token text (extraction always stores "[SCD]")
    frame_index: int  # acoustic frame index at which the token was emitted
    timestamp: float  # onset time in seconds (frame_index * frame stride)
    token_index: int  # position in the decoded token sequence
    confidence: Optional[float] = None  # model confidence, when available
@dataclass
class SCDPrediction:
    """Complete prediction result with [SCD] tokens and timestamps."""
    text: str  # full hypothesis including [SCD] tokens
    text_clean: str  # hypothesis with [SCD] tokens removed (e.g. for WER)
    scd_tokens: List[SCDToken]  # detected change points with timing info
    total_frames: int  # number of acoustic frames in the utterance
    duration: float  # utterance duration in seconds
def extract_scd_timestamps(
    tokens: List[str],
    frame_indices: List[int],
    frame_duration_ms: float = 25.0,
    frame_stride_ms: float = 20.0,
) -> List[SCDToken]:
    """
    Extract [SCD] tokens with their timestamps from a decoded token sequence.

    The timestamp of an [SCD] token is the onset of the acoustic frame at
    which it was emitted: ``frame_index * frame_stride_ms / 1000``.

    Args:
        tokens: Decoded tokens (including [SCD] markers).
        frame_indices: Acoustic frame index for each token emission.
            Expected to be parallel to ``tokens``; extra trailing entries in
            either list are ignored (``zip`` semantics).
        frame_duration_ms: Duration of each XLSR frame in milliseconds.
            NOTE: kept for API compatibility, but only the stride is needed
            to compute onset timestamps, so this value is currently unused.
        frame_stride_ms: Stride between consecutive frames in milliseconds.

    Returns:
        List of SCDToken objects with timing information, in emission order.
    """
    # "▁[SCD]" covers SentencePiece-style tokenization with a word-boundary prefix.
    scd_markers = {"[SCD]", "▁[SCD]"}
    scd_tokens = []
    for token_idx, (token, frame_idx) in enumerate(zip(tokens, frame_indices)):
        if token in scd_markers:
            # Onset time in seconds: frame index times frame stride.
            timestamp = (frame_idx * frame_stride_ms) / 1000.0
            scd_tokens.append(
                SCDToken(
                    token="[SCD]",
                    frame_index=frame_idx,
                    timestamp=timestamp,
                    token_index=token_idx,
                )
            )
    return scd_tokens
def align_predicted_and_reference_scd(
    pred_scd_tokens: List[SCDToken],
    ref_scd_timestamps: List[float],
    tolerance_sec: float = 1.0,
) -> Tuple[int, int, int]:
    """
    Align predicted [SCD] tokens with reference timestamps.

    Greedy first-match alignment: each prediction (in order) is matched to
    the first still-unmatched reference whose time is within
    ``tolerance_sec``; each reference can be matched at most once.

    Args:
        pred_scd_tokens: Predicted [SCD] tokens with timestamps.
        ref_scd_timestamps: Reference [SCD] timestamps (ground truth).
        tolerance_sec: Time tolerance for considering a match (seconds).

    Returns:
        Tuple of (true_positives, false_positives, false_negatives).
    """
    pred_times = [t.timestamp for t in pred_scd_tokens]
    matched_refs = set()
    matched_preds = set()
    # Match predictions to references greedily.
    for pred_idx, pred_time in enumerate(pred_times):
        for ref_idx, ref_time in enumerate(ref_scd_timestamps):
            if ref_idx in matched_refs:
                continue
            if abs(pred_time - ref_time) <= tolerance_sec:
                matched_refs.add(ref_idx)
                matched_preds.add(pred_idx)
                break
    # True positives are the matched pairs; the unmatched remainders on each
    # side are false positives (predictions) and false negatives (references).
    tp = len(matched_preds)
    fp = len(pred_times) - len(matched_preds)
    fn = len(ref_scd_timestamps) - len(matched_refs)
    return tp, fp, fn
def calculate_scd_metrics(
    tp: int, fp: int, fn: int
) -> Dict[str, float]:
    """
    Compute precision, recall, and F1 score for [SCD] detection.

    Each ratio falls back to 0.0 when its denominator is zero.

    Args:
        tp: True positives.
        fp: False positives.
        fn: False negatives.

    Returns:
        Dictionary with "precision", "recall", "f1" plus the raw counts.
    """
    n_predicted = tp + fp
    n_reference = tp + fn
    precision = 0.0 if n_predicted == 0 else tp / n_predicted
    recall = 0.0 if n_reference == 0 else tp / n_reference
    pr_sum = precision + recall
    f1 = 0.0 if pr_sum == 0 else 2 * precision * recall / pr_sum
    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "tp": tp,
        "fp": fp,
        "fn": fn,
    }
def remove_scd_tokens(text: str) -> str:
    """
    Remove [SCD] tokens from text for WER calculation.

    Also collapses the whitespace runs left behind where a token was removed.
    The previous ``.replace(" ", " ")`` was a no-op (space replaced by space),
    so e.g. "a [SCD] b" became "a  b" with a double space; split/join
    normalizes all internal whitespace to single spaces and strips the ends.
    """
    return " ".join(text.replace("[SCD]", "").split())
def extract_reference_scd_timestamps(
    reference_text: str,
    token_timestamps: Optional[List[Tuple[str, float]]] = None,
) -> List[float]:
    """
    Extract reference [SCD] timestamps from the ground truth.

    Args:
        reference_text: Reference text containing [SCD] tokens.
        token_timestamps: Optional list of (token, timestamp) pairs.

    Returns:
        Timestamps (seconds) at which [SCD] tokens occur; empty when no
        token-level timing information is available.
    """
    if token_timestamps is None:
        # Without token-level timing we cannot place the [SCD] markers.
        # In practice this would be extracted from supervision metadata.
        logging.warning("No token timestamps provided for reference [SCD] extraction")
        return []
    return [timestamp for token, timestamp in token_timestamps if token == "[SCD]"]
def format_scd_output(
    utt_id: str,
    prediction: SCDPrediction,
    include_timestamps: bool = True,
) -> str:
    """
    Render an [SCD] prediction as a human-readable multi-line string.

    Args:
        utt_id: Utterance ID.
        prediction: SCDPrediction object.
        include_timestamps: When True, add one detail line per detected [SCD].

    Returns:
        Formatted string with lines joined by newlines.
    """
    out = [
        f"Utterance: {utt_id}",
        f"Hypothesis: {prediction.text}",
        f"Clean text: {prediction.text_clean}",
        f"Total [SCD] detected: {len(prediction.scd_tokens)}",
    ]
    if include_timestamps and prediction.scd_tokens:
        out.append("SCD Timestamps:")
        out.extend(
            f"  {scd.timestamp:.3f}s (frame {scd.frame_index}, "
            f"token position {scd.token_index})"
            for scd in prediction.scd_tokens
        )
    return "\n".join(out)
def create_rttm_from_scd(
    utt_id: str,
    scd_tokens: List[SCDToken],
    duration: float,
    default_speaker: str = "SPEAKER1",
) -> List[str]:
    """
    Build RTTM-style segment lines from [SCD] change points.

    Each [SCD] token marks a speaker change: the utterance is cut into
    segments at those timestamps and every emitted segment gets a fresh
    label (SPEAKER1, SPEAKER2, ...). Zero-length segments are skipped.

    RTTM format: SPEAKER <file> 1 <start> <duration> <NA> <NA> <speaker> <NA> <NA>

    Args:
        utt_id: Utterance ID.
        scd_tokens: Detected [SCD] tokens with timestamps.
        duration: Total duration of the utterance in seconds.
        default_speaker: Label used when there is no change point at all.

    Returns:
        List of RTTM-formatted strings.
    """
    def rttm_line(start: float, seg_dur: float, speaker: str) -> str:
        # One RTTM record; times rendered with millisecond precision.
        return (
            f"SPEAKER {utt_id} 1 {start:.3f} {seg_dur:.3f} "
            f"<NA> <NA> {speaker} <NA> <NA>"
        )

    if not scd_tokens:
        # No change point: a single speaker covers the whole utterance.
        return [rttm_line(0.0, duration, default_speaker)]

    lines = []
    segment_start = 0.0
    speaker_no = 1
    # Walk the change points in chronological order.
    for change in sorted(scd_tokens, key=lambda tok: tok.timestamp):
        seg_dur = change.timestamp - segment_start
        if seg_dur > 0:
            lines.append(rttm_line(segment_start, seg_dur, f"SPEAKER{speaker_no}"))
            speaker_no += 1
            segment_start = change.timestamp
    # Tail segment after the final change point.
    tail = duration - segment_start
    if tail > 0:
        lines.append(rttm_line(segment_start, tail, f"SPEAKER{speaker_no}"))
    return lines
class SCDEvaluator:
    """Accumulates [SCD] detection counts over utterances and reports metrics."""

    def __init__(self, tolerance_sec: float = 1.0):
        """
        Initialize the evaluator with zeroed statistics.

        Args:
            tolerance_sec: Time tolerance for [SCD] matching (seconds).
        """
        self.tolerance_sec = tolerance_sec
        # Delegate counter initialization so __init__ and reset stay in sync.
        self.reset()

    def add_utterance(
        self,
        pred_scd_tokens: List[SCDToken],
        ref_scd_timestamps: List[float],
    ):
        """
        Score one utterance and fold its counts into the running totals.

        Args:
            pred_scd_tokens: Predicted [SCD] tokens.
            ref_scd_timestamps: Reference [SCD] timestamps.
        """
        tp, fp, fn = align_predicted_and_reference_scd(
            pred_scd_tokens, ref_scd_timestamps, self.tolerance_sec
        )
        self.total_tp += tp
        self.total_fp += fp
        self.total_fn += fn
        self.num_utterances += 1

    def get_metrics(self) -> Dict[str, float]:
        """
        Return overall [SCD] detection metrics.

        Returns:
            Dictionary with precision, recall, F1, raw counts, and the
            number of utterances accumulated so far.
        """
        metrics = calculate_scd_metrics(self.total_tp, self.total_fp, self.total_fn)
        metrics["num_utterances"] = self.num_utterances
        return metrics

    def reset(self):
        """Zero out all accumulated statistics."""
        self.total_tp = 0
        self.total_fp = 0
        self.total_fn = 0
        self.num_utterances = 0
def log_scd_statistics(predictions: List[SCDPrediction]):
    """
    Log summary statistics about [SCD] predictions.

    Fix: the percentage of utterances with [SCD] previously divided by
    ``len(predictions)`` unguarded, raising ZeroDivisionError on an empty
    prediction list (only the per-utterance average was guarded).

    Args:
        predictions: List of SCDPrediction objects.
    """
    num_utts = len(predictions)
    total_scd = sum(len(p.scd_tokens) for p in predictions)
    utts_with_scd = sum(1 for p in predictions if p.scd_tokens)
    avg_scd_per_utt = total_scd / num_utts if num_utts else 0
    pct_with_scd = utts_with_scd / num_utts * 100 if num_utts else 0.0
    # Min/max are computed over utterances that contain at least one [SCD].
    scd_counts = [len(p.scd_tokens) for p in predictions if p.scd_tokens]
    max_scd = max(scd_counts, default=0)
    min_scd = min(scd_counts, default=0)
    logging.info("=== [SCD] Token Statistics ===")
    logging.info(f"Total utterances: {num_utts}")
    logging.info(f"Utterances with [SCD]: {utts_with_scd} ({pct_with_scd:.1f}%)")
    logging.info(f"Total [SCD] tokens: {total_scd}")
    logging.info(f"Average [SCD] per utterance: {avg_scd_per_utt:.2f}")
    logging.info(f"Min/Max [SCD] in utterance: {min_scd}/{max_scd}")