| |
| """ |
| Utility functions for [SCD] token extraction, alignment, and timestamp calculation. |
| |
| Based on Tokenverse paper approach: |
| - Extract predicted task tokens from hypothesis |
| - Align tokens in time domain using acoustic frame indices |
| - Calculate timestamps using XLSR frame duration (25ms) and stride (20ms) |
| """ |
|
|
| import logging |
| from dataclasses import dataclass |
| from typing import List, Tuple, Optional, Dict |
| import torch |
|
|
|
|
@dataclass
class SCDToken:
    """A single detected [SCD] (speaker-change) token with timing info.

    Holds both the discrete position of the token in the decoded sequence
    and its mapping back onto the acoustic time axis.
    """

    # Literal token string (normalized to "[SCD]").
    token: str
    # Acoustic frame at which the token was emitted.
    frame_index: int
    # Emission time in seconds, derived from the frame index.
    timestamp: float
    # Position of the token within the decoded token sequence.
    token_index: int
    # Optional model confidence for the emission; None when unavailable.
    confidence: Optional[float] = None
|
|
|
|
@dataclass
class SCDPrediction:
    """Complete decoding result for one utterance, including [SCD] detections."""

    # Raw hypothesis text, [SCD] markers included.
    text: str
    # Hypothesis with [SCD] markers removed (used for WER scoring).
    text_clean: str
    # Detected speaker-change tokens with timing information.
    scd_tokens: List[SCDToken]
    # Number of acoustic frames in the utterance.
    total_frames: int
    # Utterance duration in seconds.
    duration: float
|
|
|
|
def extract_scd_timestamps(
    tokens: List[str],
    frame_indices: List[int],
    frame_duration_ms: float = 25.0,
    frame_stride_ms: float = 20.0,
) -> List[SCDToken]:
    """
    Collect [SCD] tokens and their timestamps from a decoded token sequence.

    Each timestamp is the start time of the frame at which the token was
    emitted: ``frame_index * frame_stride_ms`` converted to seconds.
    ``frame_duration_ms`` is accepted for API symmetry with the XLSR
    front-end (25 ms windows) but does not enter the calculation.

    Args:
        tokens: List of decoded tokens (including [SCD])
        frame_indices: Acoustic frame index for each token emission
        frame_duration_ms: Duration of each XLSR frame in milliseconds
        frame_stride_ms: Stride between consecutive frames in milliseconds

    Returns:
        List of SCDToken objects with timing information
    """
    detections: List[SCDToken] = []

    for position, (tok, frame) in enumerate(zip(tokens, frame_indices)):
        # Match the plain token and the sentencepiece word-boundary variant.
        if tok not in ("[SCD]", "▁[SCD]"):
            continue
        detections.append(
            SCDToken(
                token="[SCD]",
                frame_index=frame,
                timestamp=(frame * frame_stride_ms) / 1000.0,
                token_index=position,
            )
        )

    return detections
|
|
|
|
def align_predicted_and_reference_scd(
    pred_scd_tokens: List["SCDToken"],
    ref_scd_timestamps: List[float],
    tolerance_sec: float = 1.0,
) -> Tuple[int, int, int]:
    """
    Greedily align predicted [SCD] tokens with reference timestamps.

    Each prediction, in order, is matched to the first still-unmatched
    reference within ``tolerance_sec``; every reference is consumed at
    most once. Unmatched predictions count as false positives and
    unmatched references as false negatives.

    Args:
        pred_scd_tokens: Predicted [SCD] tokens with timestamps
        ref_scd_timestamps: Reference [SCD] timestamps (ground truth)
        tolerance_sec: Time tolerance for considering a match (seconds)

    Returns:
        Tuple of (true_positives, false_positives, false_negatives)
    """
    pred_times = [tok.timestamp for tok in pred_scd_tokens]
    matched_refs = set()

    tp = 0
    for pred_time in pred_times:
        for ref_idx, ref_time in enumerate(ref_scd_timestamps):
            if ref_idx in matched_refs:
                continue
            if abs(pred_time - ref_time) <= tolerance_sec:
                tp += 1
                matched_refs.add(ref_idx)
                break

    # Every prediction either matched (counted in tp) or is a false positive.
    fp = len(pred_times) - tp
    # References never claimed by any prediction are misses.
    fn = len(ref_scd_timestamps) - len(matched_refs)

    return tp, fp, fn
|
|
|
|
def calculate_scd_metrics(
    tp: int, fp: int, fn: int
) -> Dict[str, float]:
    """
    Compute precision, recall, and F1 for [SCD] detection.

    All three scores fall back to 0.0 when their denominator is zero
    (no predictions, no references, or no matches at all).

    Args:
        tp: True positives
        fp: False positives
        fn: False negatives

    Returns:
        Dictionary with precision, recall, f1, and the raw counts
    """
    pred_count = tp + fp
    ref_count = tp + fn

    precision = tp / pred_count if pred_count > 0 else 0.0
    recall = tp / ref_count if ref_count > 0 else 0.0

    denom = precision + recall
    f1 = (2 * precision * recall / denom) if denom > 0 else 0.0

    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "tp": tp,
        "fp": fp,
        "fn": fn,
    }
|
|
|
|
def remove_scd_tokens(text: str) -> str:
    """Strip [SCD] markers from *text* for WER calculation.

    Removes every "[SCD]" substring and normalizes whitespace: runs of
    whitespace collapse to a single space and the result is trimmed.
    (The previous ``.replace(" ", " ")`` was a no-op, so the double
    space left behind where a token was removed survived into the WER
    input.)
    """
    return " ".join(text.replace("[SCD]", "").split())
|
|
|
|
def extract_reference_scd_timestamps(
    reference_text: str,
    token_timestamps: Optional[List[Tuple[str, float]]] = None,
) -> List[float]:
    """
    Extract reference [SCD] timestamps from ground truth.

    Args:
        reference_text: Reference text with [SCD] tokens
        token_timestamps: Optional list of (token, timestamp) pairs

    Returns:
        List of timestamps where [SCD] tokens occur
    """
    if token_timestamps is None:
        # Without per-token timing there is nothing to extract from the text.
        logging.warning("No token timestamps provided for reference [SCD] extraction")
        return []

    return [stamp for tok, stamp in token_timestamps if tok == "[SCD]"]
|
|
|
|
def format_scd_output(
    utt_id: str,
    prediction: SCDPrediction,
    include_timestamps: bool = True,
) -> str:
    """
    Render an [SCD] prediction as a human-readable multi-line string.

    Args:
        utt_id: Utterance ID
        prediction: SCDPrediction object
        include_timestamps: Whether to include detailed timestamps

    Returns:
        Formatted string (one field per line, joined with newlines)
    """
    out = [
        f"Utterance: {utt_id}",
        f"Hypothesis: {prediction.text}",
        f"Clean text: {prediction.text_clean}",
        f"Total [SCD] detected: {len(prediction.scd_tokens)}",
    ]

    if include_timestamps and prediction.scd_tokens:
        out.append("SCD Timestamps:")
        out.extend(
            f"  {scd.timestamp:.3f}s (frame {scd.frame_index}, "
            f"token position {scd.token_index})"
            for scd in prediction.scd_tokens
        )

    return "\n".join(out)
|
|
|
|
def create_rttm_from_scd(
    utt_id: str,
    scd_tokens: List["SCDToken"],
    duration: float,
    default_speaker: str = "SPEAKER1",
) -> List[str]:
    """
    Create RTTM-style output from [SCD] tokens for speaker diarization.

    Each [SCD] timestamp is treated as a speaker-change boundary; the
    utterance is cut into segments at those boundaries and each segment
    gets the next sequential speaker label (SPEAKER1, SPEAKER2, ...).
    Boundaries are clamped into [0, duration] so an out-of-range [SCD]
    timestamp can neither produce a segment extending past the utterance
    nor drop the final segment.

    RTTM format: SPEAKER <file> 1 <start> <duration> <NA> <NA> <speaker> <NA> <NA>

    Args:
        utt_id: Utterance ID
        scd_tokens: List of [SCD] tokens with timestamps
        duration: Total duration of utterance
        default_speaker: Default speaker label

    Returns:
        List of RTTM-formatted strings
    """
    rttm_lines: List[str] = []

    if not scd_tokens:
        # No change points: the whole utterance is one speaker segment.
        rttm_lines.append(
            f"SPEAKER {utt_id} 1 0.000 {duration:.3f} <NA> <NA> {default_speaker} <NA> <NA>"
        )
        return rttm_lines

    # Process boundaries in temporal order regardless of decode order.
    boundaries = sorted(tok.timestamp for tok in scd_tokens)

    prev_time = 0.0
    speaker_idx = 1

    for boundary in boundaries:
        # Robustness: clamp boundaries into the utterance; a zero-length
        # segment (duplicate or clamped boundary) emits no line and does
        # not advance the speaker label.
        boundary = min(max(boundary, 0.0), duration)
        segment_duration = boundary - prev_time
        if segment_duration > 0:
            rttm_lines.append(
                f"SPEAKER {utt_id} 1 {prev_time:.3f} {segment_duration:.3f} "
                f"<NA> <NA> SPEAKER{speaker_idx} <NA> <NA>"
            )
            speaker_idx += 1
        prev_time = boundary

    # Emit the trailing segment from the last boundary to the end.
    final_duration = duration - prev_time
    if final_duration > 0:
        rttm_lines.append(
            f"SPEAKER {utt_id} 1 {prev_time:.3f} {final_duration:.3f} "
            f"<NA> <NA> SPEAKER{speaker_idx} <NA> <NA>"
        )

    return rttm_lines
|
|
|
|
class SCDEvaluator:
    """Accumulates [SCD] detection counts across utterances and reports metrics."""

    def __init__(self, tolerance_sec: float = 1.0):
        """
        Initialize evaluator with zeroed statistics.

        Args:
            tolerance_sec: Time tolerance for [SCD] matching (seconds)
        """
        self.tolerance_sec = tolerance_sec
        # Counters start at zero; reset() owns the initialization.
        self.reset()

    def add_utterance(
        self,
        pred_scd_tokens: List[SCDToken],
        ref_scd_timestamps: List[float],
    ):
        """
        Score one utterance and fold the counts into the running totals.

        Args:
            pred_scd_tokens: Predicted [SCD] tokens
            ref_scd_timestamps: Reference [SCD] timestamps
        """
        tp, fp, fn = align_predicted_and_reference_scd(
            pred_scd_tokens, ref_scd_timestamps, self.tolerance_sec
        )
        self.total_tp += tp
        self.total_fp += fp
        self.total_fn += fn
        self.num_utterances += 1

    def get_metrics(self) -> Dict[str, float]:
        """
        Return aggregate [SCD] detection metrics.

        Returns:
            Dictionary with precision, recall, F1, raw counts, and
            the number of utterances scored.
        """
        metrics = calculate_scd_metrics(self.total_tp, self.total_fp, self.total_fn)
        metrics["num_utterances"] = self.num_utterances
        return metrics

    def reset(self):
        """Zero all accumulated statistics."""
        self.total_tp = 0
        self.total_fp = 0
        self.total_fn = 0
        self.num_utterances = 0
|
|
|
def log_scd_statistics(predictions: List["SCDPrediction"]) -> None:
    """
    Log aggregate statistics about [SCD] predictions.

    Safe on an empty prediction list: logs zero counts instead of
    raising ZeroDivisionError (the previous version guarded the average
    but not the percentage of utterances with [SCD]).

    Args:
        predictions: List of SCDPrediction objects
    """
    num_utts = len(predictions)
    total_scd = sum(len(p.scd_tokens) for p in predictions)
    utts_with_scd = sum(1 for p in predictions if p.scd_tokens)
    avg_scd_per_utt = total_scd / num_utts if num_utts else 0
    pct_with_scd = utts_with_scd / num_utts * 100 if num_utts else 0.0

    # Min/max are computed over utterances that contain at least one [SCD].
    scd_counts = [len(p.scd_tokens) for p in predictions if p.scd_tokens]
    max_scd = max(scd_counts) if scd_counts else 0
    min_scd = min(scd_counts) if scd_counts else 0

    logging.info("=== [SCD] Token Statistics ===")
    logging.info(f"Total utterances: {num_utts}")
    logging.info(f"Utterances with [SCD]: {utts_with_scd} ({pct_with_scd:.1f}%)")
    logging.info(f"Total [SCD] tokens: {total_scd}")
    logging.info(f"Average [SCD] per utterance: {avg_scd_per_utt:.2f}")
    logging.info(f"Min/Max [SCD] in utterance: {min_scd}/{max_scd}")
|
|