#!/usr/bin/env python3
# License: CC-BY-NC-ND-4.0
# Created by: Patrick Lumbantobing, Vertox-AI
# Copyright (c) 2026 Vertox-AI. All rights reserved.
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-NoDerivatives 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-nd/4.0/
"""
Streaming ASR Output Segmenter for Translation Pipeline.

Description
-----------
Segments streaming ASR output into optimal chunks for NMT translation.

Designed for NVIDIA NeMo cache-aware streaming ASR feeding TranslateGemma
via llama.cpp on CPU.

Implements:
- Word-count based segmentation (max_words with hold_back).
- Punctuation boundary detection (sentence and clause punctuation).
- Simple text buffer with split-based word counting.

Not yet wired into the active code path (defined in config but unused):
- Pause detection (use_pause_detection, pause_threshold_ms).
- ASR FINAL hypothesis integration (honor_asr_final).
"""

from __future__ import annotations

import logging
import time
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Tuple

log = logging.getLogger(__name__)


class BoundaryType(Enum):
    """Types of segment boundaries that can trigger an emit."""

    FINAL = "final"
    PUNCTUATION = "punctuation"
    PAUSE = "pause"
    MAX_TOKENS = "max_tokens"
    FORCED = "forced"


@dataclass
class ASRToken:
    """Represents a single ASR token or word with optional metadata."""

    text: str
    timestamp: float = 0.0
    confidence: float = 1.0
    is_final: bool = False


@dataclass
class SegmentResult:
    """
    Metadata for a single emitted segment.

    Attributes
    ----------
    text :
        Concatenated segment text.
    tokens :
        Tokens contributing to this segment.
    boundary_type :
        The boundary condition that triggered this segment.
    start_time :
        Start timestamp of the segment (seconds).
    end_time :
        End timestamp of the segment (seconds).
    token_count :
        Number of tokens in the segment (derived from `tokens`).
    """

    text: str
    tokens: List[ASRToken]
    boundary_type: BoundaryType
    start_time: float
    end_time: float
    token_count: int = 0

    def __post_init__(self) -> None:
        self.token_count = len(self.tokens)


@dataclass
class SegmenterConfig:
    """
    Configuration for the streaming segmenter.

    Parameters
    ----------
    max_words :
        Maximum words in a segment before triggering MAX_TOKENS.
    min_words :
        Minimum words required before punctuation can trigger a segment.
    hold_back :
        Number of words held back when cutting on MAX_TOKENS to absorb
        ASR tail instability.
    sentence_punct :
        Characters considered strong sentence punctuation.
    clause_punct :
        Characters considered clause-level punctuation.
    use_punctuation :
        Whether punctuation-based boundaries are enabled.
    pause_threshold_ms :
        Pause duration in milliseconds for pause-based segmentation
        (currently not wired into the main path).
    use_pause_detection :
        Whether pause-based segmentation is enabled (currently unused).
    honor_asr_final :
        Whether to honor ASR FINAL hypotheses (currently unused).
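
    Examples
    --------
    Illustrative lower-latency settings (placeholder values, not tuned):

    >>> cfg = SegmenterConfig(max_words=4, hold_back=1)
    >>> cfg.max_words + cfg.hold_back  # words buffered before a MAX_TOKENS cut
    5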
    """

    max_words: int = 5
    min_words: int = 3
    hold_back: int = 2
    sentence_punct: str = ".!?"
    clause_punct: str = ",;:"
    use_punctuation: bool = True
    pause_threshold_ms: float = 700
    use_pause_detection: bool = True
    honor_asr_final: bool = True


class StreamingSegmenter:
    """
    Segment streaming ASR tokens for optimal NMT chunking.

    Uses word-count based segmentation with punctuation boundary detection.
    Accumulates tokens in a plain text buffer and triggers emission when
    a punctuation boundary or max_words threshold is reached.

    Boundary priority: sentence punctuation > clause punctuation > max_words.
    A hold-back buffer absorbs ASR tail instability on max_words triggers.

    Parameters
    ----------
    config :
        Segmenter configuration. Defaults to :class:`SegmenterConfig`.
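
    Examples
    --------
    A hand-crafted delta sequence (illustrative, not real ASR output):

    >>> seg = StreamingSegmenter(SegmenterConfig(max_words=3, min_words=2, hold_back=1))
    >>> seg.add_token("hello", 0.0) is None
    True
    >>> seg.add_token(" world.", 0.1)
    'hello world.'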
    """

    def __init__(self, config: Optional[SegmenterConfig] = None) -> None:
        self.config: SegmenterConfig = config or SegmenterConfig()

        # Text buffer state (legacy token buffer is kept for compatibility).
        self.text_buffer: str = ""
        self.text_split: List[str] = []
        self.buffer: List[ASRToken] = []

        self.last_token_time: Optional[float] = None
        self.segments_emitted: List[str] = []

        self._all_punct = set(self.config.sentence_punct + self.config.clause_punct)
        self._strong_punct = set(self.config.sentence_punct)

    # ─── Public API ─────────────────────────────────────────────────────────

    def add_token(self, new_text: str, new_timestamp: float) -> Optional[str]:
        """
        Add an ASR text delta and return a segment if a boundary is triggered.

        Parameters
        ----------
        new_text :
            Incremental ASR text token/delta.
        new_timestamp :
            Current timestamp (seconds) for pause detection.

        Returns
        -------
        str or None
            Emitted text segment if a boundary was triggered; otherwise ``None``.
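
        Examples
        --------
        Deltas are assumed space-prefixed mid-stream (illustrative input;
        the leading space of the very first token is stripped):

        >>> seg = StreamingSegmenter(SegmenterConfig(max_words=10, min_words=2))
        >>> seg.add_token(" one", 0.0) is None
        True
        >>> seg.add_token(" two,", 0.1)
        'one two,'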
        """
        # Detect punctuation in the new text.
        new_text = new_text.rstrip()
        new_text_punct = ""

        for punct in self._strong_punct:
            if punct in new_text:
                new_text_punct = punct
                break
        if not new_text_punct:
            for punct in self.config.clause_punct:
                if punct in new_text:
                    new_text_punct = punct
                    break

        log.debug(f"add-token text_buffer_before: {self.text_buffer!r}")
        log.debug(f"add-token text_split_before: {self.text_split!r}")

        # Append to text buffer, stripping leading space on first token.
        if self.text_buffer:
            self.text_buffer += new_text
        else:
            if new_text and new_text[0] == " ":
                self.text_buffer = new_text[1:]
            else:
                self.text_buffer = new_text

        self.text_split = self.text_buffer.split(" ")

        log.debug(f"add-token new_text: {new_text!r}")
        log.debug(f"add-token new_text_punct: {new_text_punct!r}")
        log.debug(f"add-token text_buffer_after: {self.text_buffer!r}")
        log.debug(f"add-token text_split_after: {self.text_split!r}")

        # Wall-clock time is recorded here; the `new_timestamp` argument is
        # reserved for pause detection, which is not wired in yet.
        self.last_token_time = time.time()

        boundary, index = self._check_boundary(new_text_punct)
        log.debug(f"boundary: {boundary}, index: {index}")
        if boundary:
            return self._emit(boundary, index)
        return None

    def flush(self) -> Optional[str]:
        """
        Force-flush remaining buffer contents as a segment.

        Returns
        -------
        str or None
            Remaining buffered text, or ``None`` if the buffer is empty.
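
        Examples
        --------
        >>> seg = StreamingSegmenter()
        >>> seg.add_token(" left", 0.0) is None
        True
        >>> seg.add_token(" over", 0.1) is None
        True
        >>> seg.flush()
        'left over'
        >>> seg.flush() is None
        True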
        """
        if self.text_split:
            text = " ".join(self.text_split)
            # Clear only buffer state; a full reset() would also discard
            # the segments_emitted history used by get_stats().
            self._clear_buffer()
            self.segments_emitted.append(text)
            return text
        return None

    def get_buffer_text(self) -> str:
        """Return current legacy token buffer content as a string."""
        return " ".join(t.text for t in self.buffer)

    def get_buffer_size(self) -> int:
        """Return number of token objects in the legacy buffer."""
        return len(self.buffer)

    def get_stats(self) -> Dict[str, int]:
        """
        Return basic segmentation statistics.

        Returns
        -------
        dict
            ``segments_emitted`` – count of segments emitted so far.
            ``buffered_tokens`` – number of words currently buffered.
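
        Examples
        --------
        >>> seg = StreamingSegmenter()
        >>> seg.add_token("two words", 0.0) is None
        True
        >>> seg.get_stats()
        {'segments_emitted': 0, 'buffered_tokens': 2}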
        """
        n = len(self.segments_emitted)
        return {
            "segments_emitted": n,
            "buffered_tokens": len(self.text_split),
        }

    def reset(self) -> None:
        """Reset all buffer, timing, and statistics state."""
        self._clear_buffer()
        self.segments_emitted.clear()

    # ─── Internals ─────────────────────────────────────────────────────────

    def _clear_buffer(self) -> None:
        """Clear buffer and timing state without touching statistics."""
        self.buffer.clear()
        self.text_split = []
        self.text_buffer = ""
        self.last_token_time = None

    def _check_boundary(self, punct_check: str = "") -> Tuple[Optional[BoundaryType], int]:
        """
        Check whether a segment boundary should be triggered.

        Parameters
        ----------
        punct_check :
            Punctuation character found in the latest token, or empty string.

        Returns
        -------
        (BoundaryType or None, int)
            Boundary type and split index, or (None, -1) if no boundary.
        """
        buf_len = len(self.text_split)

        # Punctuation-based boundary.
        if self.config.use_punctuation and punct_check and buf_len >= self.config.min_words:
            if punct_check not in self.text_buffer:
                return None, -1

            index = -1
            for i in range(buf_len - 1, -1, -1):
                if punct_check in self.text_split[i]:
                    index = i
                    break
            if index < 0:
                return None, -1
            return BoundaryType.PUNCTUATION, index

        # Max-token boundary.
        if buf_len >= self.config.max_words + self.config.hold_back:
            return BoundaryType.MAX_TOKENS, -1

        return None, -1

    def _emit(self, boundary_type: BoundaryType, index_buffer_end: int = -1) -> str:
        """
        Emit a segment from the buffer.

        Parameters
        ----------
        boundary_type :
            The boundary condition that triggered this emit.
        index_buffer_end :
            Word index (inclusive) at which to cut the buffer for punctuation
            boundaries. Ignored for MAX_TOKENS (uses max_words cut).

        Returns
        -------
        str
            The emitted text segment (may be empty if misconfigured).
        """
        # MAX_TOKENS cuts at max_words even when hold_back is 0; gating on
        # hold_back here would misroute the emit into the punctuation branch
        # (whose index is -1 for MAX_TOKENS) and drop the segment.
        if boundary_type == BoundaryType.MAX_TOKENS:
            if len(self.text_split) < self.config.max_words + self.config.hold_back:
                log.warning(
                    f"_emit MAX_TOKENS: buffer too short "
                    f"({len(self.text_split)} < "
                    f"{self.config.max_words + self.config.hold_back}), returning empty"
                )
                return ""
            cut = self.config.max_words
            emit_tokens = self.text_split[:cut]
            log.debug(f"_emit text_split_before: {self.text_split!r}")
            log.debug(f"_emit text_buffer_before: {self.text_buffer!r}")
            self.text_split = self.text_split[cut:]
            self.text_buffer = " ".join(t for t in self.text_split)
            log.debug(f"_emit emit_tokens: {emit_tokens!r}")
            log.debug(f"_emit text_split_after: {self.text_split!r}")
            log.debug(f"_emit text_buffer_after: {self.text_buffer!r}")
        else:
            if index_buffer_end < 0:
                log.warning("_emit punctuation: index_buffer_end < 0, returning empty")
                return ""

            emit_tokens = self.text_split[: index_buffer_end + 1]
            log.debug(f"_emit punct {index_buffer_end} text_split_before: {self.text_split!r}")
            log.debug(f"_emit punct {index_buffer_end} text_buffer_before: {self.text_buffer!r}")

            if index_buffer_end < (len(self.text_split) - 1):
                self.text_split = self.text_split[index_buffer_end + 1 :]
                self.text_buffer = " ".join(self.text_split)
            else:
                # Buffer fully consumed: clear it without a full reset(),
                # which would also wipe the segments_emitted history.
                self._clear_buffer()

            log.debug(f"_emit punct {index_buffer_end} emit_tokens: {emit_tokens!r}")
            log.debug(f"_emit punct {index_buffer_end} text_split_after: {self.text_split!r}")
            log.debug(f"_emit punct {index_buffer_end} text_buffer_after: {self.text_buffer!r}")

        text = " ".join(t for t in emit_tokens)
        self.segments_emitted.append(text)
        return text
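

if __name__ == "__main__":
    # Minimal smoke demo with synthetic deltas (not real NeMo output).
    # With the default config this shows a MAX_TOKENS cut (hold_back keeps
    # two words buffered), a clause-punctuation cut, and a flush of a tail
    # that is too short (< min_words) for its "." to trigger a boundary.
    logging.basicConfig(level=logging.INFO)
    demo = StreamingSegmenter()
    deltas = [" streaming", " asr", " words", " arrive", " one",
              " at", " a", " time,", " then", " flush."]
    for i, delta in enumerate(deltas):
        segment = demo.add_token(delta, float(i))
        if segment:
            print(f"segment: {segment!r}")
    tail = demo.flush()
    if tail:
        print(f"flushed: {tail!r}")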