# Source: streaming-speech-translation / src/nmt/streaming_segmenter.py
# Last change (commit 56cdd5d, pltobing): Fix ASR-NMT and TTS on the server
# side. Still need to fix on the client side for playback.
#!/usr/bin/env python3
# License: CC-BY-NC-ND-4.0
# Created by: Patrick Lumbantobing, Vertox-AI
# Copyright (c) 2026 Vertox-AI. All rights reserved.
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-NoDerivatives 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-nd/4.0/
"""
Streaming ASR Output Segmenter for Translation Pipeline.
Description
-----------
Segments streaming ASR output into optimal chunks for NMT translation.
Designed for NVIDIA NeMo cache-aware streaming ASR feeding TranslateGemma
via llama.cpp on CPU.
Implements:
- Word-count based segmentation (max_words with hold_back).
- Punctuation boundary detection (sentence and clause punctuation).
- Simple text buffer with split-based word counting.
Not yet wired into the active code path (defined in config but unused):
- Pause detection (use_pause_detection, pause_threshold_ms).
- ASR FINAL hypothesis integration (honor_asr_final).
"""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Tuple
log = logging.getLogger(__name__)
class BoundaryType(Enum):
    """Types of segment boundaries that can trigger an emit."""

    # ASR FINAL hypothesis boundary (per the module docstring, FINAL
    # integration is not yet wired into the active code path).
    FINAL = "final"
    # Sentence- or clause-level punctuation found in the latest delta.
    PUNCTUATION = "punctuation"
    # Pause-based boundary (pause detection is not yet wired in).
    PAUSE = "pause"
    # Buffer reached max_words (+ hold_back) words.
    MAX_TOKENS = "max_tokens"
    # Reserved for forced flushes; note flush() currently bypasses
    # boundary typing entirely.
    FORCED = "forced"
@dataclass
class ASRToken:
    """Represents a single ASR token or word with optional metadata."""

    # Surface text of the token or word.
    text: str
    # Emission timestamp in seconds (0.0 when unknown).
    timestamp: float = 0.0
    # ASR confidence score; presumably in [0, 1] — not validated here.
    confidence: float = 1.0
    # Whether this token belongs to a FINAL hypothesis.
    is_final: bool = False
@dataclass
class SegmentResult:
    """
    Metadata for a single emitted segment.
    Attributes
    ----------
    text :
        Concatenated segment text.
    tokens :
        Tokens contributing to this segment.
    boundary_type :
        The boundary condition that triggered this segment.
    start_time :
        Start timestamp of the segment (seconds).
    end_time :
        End timestamp of the segment (seconds).
    token_count :
        Number of tokens in the segment (derived from `tokens`).
    """

    text: str
    tokens: List[ASRToken]
    boundary_type: BoundaryType
    start_time: float
    end_time: float
    token_count: int = 0

    def __post_init__(self) -> None:
        # token_count is always derived from `tokens`; any value passed by
        # the caller is overwritten here.
        self.token_count = len(self.tokens)
@dataclass
class SegmenterConfig:
    """
    Configuration for the streaming segmenter.
    Parameters
    ----------
    max_words :
        Maximum words in a segment before triggering MAX_TOKENS.
    min_words :
        Minimum words required before punctuation can trigger a segment.
    hold_back :
        Number of words held back when cutting on MAX_TOKENS to absorb
        ASR tail instability.
    sentence_punct :
        Characters considered strong sentence punctuation.
    clause_punct :
        Characters considered clause-level punctuation.
    use_punctuation :
        Whether punctuation-based boundaries are enabled.
    pause_threshold_ms :
        Pause duration in milliseconds for pause-based segmentation
        (currently not wired into the main path).
    use_pause_detection :
        Whether pause-based segmentation is enabled (currently unused).
    honor_asr_final :
        Whether to honor ASR FINAL hypotheses (currently unused).
    """

    # NOTE(review): the MAX_TOKENS boundary actually fires once the buffer
    # holds max_words + hold_back words; max_words is the *cut* size.
    max_words: int = 5
    min_words: int = 3
    hold_back: int = 2
    sentence_punct: str = ".!?"
    clause_punct: str = ",;:"
    use_punctuation: bool = True
    # Declared for future pause detection; not read by the current path.
    pause_threshold_ms: float = 700
    use_pause_detection: bool = True
    honor_asr_final: bool = True
class StreamingSegmenter:
    """
    Segment streaming ASR tokens for optimal NMT chunking.

    Uses word-count based segmentation with punctuation boundary detection.
    Accumulates incremental ASR text deltas in a plain text buffer and emits
    a segment when a punctuation boundary or the max_words threshold is hit.
    Boundary priority: sentence punctuation > clause punctuation > max_words.
    A hold-back buffer absorbs ASR tail instability on max_words triggers.

    Parameters
    ----------
    config :
        Segmenter configuration. Defaults to :class:`SegmenterConfig`.
    """

    def __init__(self, config: Optional[SegmenterConfig] = None) -> None:
        self.config: SegmenterConfig = config or SegmenterConfig()
        # Text buffer state (legacy token buffer is kept for compatibility).
        self.text_buffer: str = ""
        self.text_split: List[str] = []
        self.buffer: List[ASRToken] = []
        self.last_token_time: Optional[float] = None
        # Full history of emitted segment texts; also drives get_stats().
        self.segments_emitted: List[str] = []
        self._all_punct = set(self.config.sentence_punct + self.config.clause_punct)
        self._strong_punct = set(self.config.sentence_punct)

    # ─── Public API ─────────────────────────────────────────────────────────

    def add_token(self, new_text: str, new_timestamp: float) -> Optional[str]:
        """
        Add an ASR text delta and return a segment if a boundary is triggered.

        Parameters
        ----------
        new_text :
            Incremental ASR text token/delta. Empty or whitespace-only
            deltas are ignored.
        new_timestamp :
            Current timestamp (seconds); reserved for pause detection,
            which is not yet wired into this path.

        Returns
        -------
        str or None
            Emitted text segment if a boundary was triggered; otherwise ``None``.
        """
        new_text = new_text.rstrip()
        if not new_text:
            # Guard: an empty delta would otherwise make "".split(" ")
            # return [""] and inflate the buffered word count.
            return None
        new_text_punct = self._detect_punct(new_text)
        log.debug("add-token text_buffer_before: %r", self.text_buffer)
        log.debug("add-token text_split_before: %r", self.text_split)
        # Append to the text buffer, stripping the leading space on the
        # first token so word splitting stays clean.
        if self.text_buffer:
            self.text_buffer += new_text
        else:
            self.text_buffer = new_text[1:] if new_text.startswith(" ") else new_text
        self.text_split = self.text_buffer.split(" ")
        log.debug("add-token new_text: %r", new_text)
        log.debug("add-token new_text_punct: %r", new_text_punct)
        log.debug("add-token text_buffer_after: %r", self.text_buffer)
        log.debug("add-token text_split_after: %r", self.text_split)
        self.last_token_time = time.time()
        boundary, index = self._check_boundary(new_text_punct)
        log.debug("boundary: %s, index: %s", boundary, index)
        if boundary:
            return self._emit(boundary, index)
        return None

    def flush(self) -> Optional[str]:
        """
        Force-flush remaining buffer contents as a segment.

        The flushed text is recorded in ``segments_emitted`` so that
        ``get_stats()`` stays consistent with actual output. (The old code
        called ``reset()`` here, which wiped the emission history.)

        Returns
        -------
        str or None
            Remaining buffered text, or ``None`` if the buffer is empty.
        """
        if not self.text_split:
            return None
        text = " ".join(self.text_split)
        self._clear_buffer()
        self.segments_emitted.append(text)
        return text

    def get_buffer_text(self) -> str:
        """Return current legacy token buffer content as a string."""
        return " ".join(t.text for t in self.buffer)

    def get_buffer_size(self) -> int:
        """Return number of token objects in the legacy buffer."""
        return len(self.buffer)

    def get_stats(self) -> Dict[str, int]:
        """
        Return basic segmentation statistics.

        Returns
        -------
        dict
            ``segments_emitted`` – count of segments emitted so far
            (including flushed segments).
            ``buffered_tokens`` – number of words currently buffered.
        """
        return {
            "segments_emitted": len(self.segments_emitted),
            "buffered_tokens": len(self.text_split),
        }

    def reset(self) -> None:
        """Reset all buffer, timing, and emission-history state."""
        self._clear_buffer()
        self.segments_emitted.clear()

    # ─── Internals ─────────────────────────────────────────────────────────

    def _clear_buffer(self) -> None:
        """Clear buffer/timing state while preserving emission history."""
        self.buffer.clear()
        self.text_split = []
        self.text_buffer = ""
        self.last_token_time = None

    def _detect_punct(self, text: str) -> str:
        """
        Return the highest-priority punctuation character found in *text*.

        Sentence punctuation wins over clause punctuation. The configured
        strings are scanned in order (the old code iterated a set, so the
        result was nondeterministic when several matched). Returns "" when
        no configured punctuation is present.
        """
        for punct in self.config.sentence_punct:
            if punct in text:
                return punct
        for punct in self.config.clause_punct:
            if punct in text:
                return punct
        return ""

    def _check_boundary(self, punct_check: str = "") -> Tuple[Optional[BoundaryType], int]:
        """
        Check whether a segment boundary should be triggered.

        Parameters
        ----------
        punct_check :
            Punctuation character found in the latest token, or empty string.

        Returns
        -------
        (BoundaryType or None, int)
            Boundary type and split index, or (None, -1) if no boundary.
        """
        buf_len = len(self.text_split)
        # Punctuation-based boundary: cut at the last word containing the
        # punctuation character.
        if self.config.use_punctuation and punct_check and buf_len >= self.config.min_words:
            for i in range(buf_len - 1, -1, -1):
                if punct_check in self.text_split[i]:
                    return BoundaryType.PUNCTUATION, i
            # Punctuation not found in the buffer (should not normally
            # happen): fall through to the max-token check instead of
            # returning early as the old code did.
        # Max-token boundary: requires hold_back extra words so the cut
        # always leaves a stable tail behind.
        if buf_len >= self.config.max_words + self.config.hold_back:
            return BoundaryType.MAX_TOKENS, -1
        return None, -1

    def _emit(self, boundary_type: BoundaryType, index_buffer_end: int = -1) -> str:
        """
        Emit a segment from the buffer.

        Parameters
        ----------
        boundary_type :
            The boundary condition that triggered this emit.
        index_buffer_end :
            Word index (inclusive) at which to cut the buffer for punctuation
            boundaries. Ignored for MAX_TOKENS (uses max_words cut).

        Returns
        -------
        str
            The emitted text segment (may be empty if misconfigured).
        """
        # MAX_TOKENS is handled here unconditionally; the old code only
        # entered this branch when hold_back > 0, so hold_back == 0 fell
        # into the punctuation path and silently emitted "" forever.
        if boundary_type == BoundaryType.MAX_TOKENS:
            needed = self.config.max_words + self.config.hold_back
            if len(self.text_split) < needed:
                log.warning(
                    "_emit MAX_TOKENS: buffer too short (%d < %d), returning empty",
                    len(self.text_split),
                    needed,
                )
                return ""
            cut = self.config.max_words
            emit_tokens = self.text_split[:cut]
            self.text_split = self.text_split[cut:]
            self.text_buffer = " ".join(self.text_split)
        else:
            if index_buffer_end < 0:
                log.warning("_emit punctuation: index_buffer_end < 0, returning empty")
                return ""
            emit_tokens = self.text_split[: index_buffer_end + 1]
            if index_buffer_end < len(self.text_split) - 1:
                # Keep the unconsumed tail in the buffer.
                self.text_split = self.text_split[index_buffer_end + 1 :]
                self.text_buffer = " ".join(self.text_split)
            else:
                # Whole buffer consumed. Clear only buffer state: the old
                # code called reset() here, which also wiped
                # segments_emitted and corrupted get_stats().
                self._clear_buffer()
        log.debug("_emit %s emit_tokens: %r", boundary_type, emit_tokens)
        log.debug("_emit %s text_split_after: %r", boundary_type, self.text_split)
        text = " ".join(emit_tokens)
        self.segments_emitted.append(text)
        return text