| import re |
| import collections |
| import logging |
| from difflib import SequenceMatcher |
| from itertools import chain |
| from dataclasses import dataclass, field |
| from typing import List, Tuple, Optional, Deque, Any, Iterator,Literal |
| from config import SENTENCE_END_MARKERS, ALL_MARKERS,SENTENCE_END_PATTERN,REGEX_MARKERS, PAUSEE_END_PATTERN,SAMPLE_RATE |
| import numpy as np |
| from enum import Enum |
| from itertools import chain |
| logger = logging.getLogger("TranscriptionStrategy") |
|
|
|
|
class SplitMode(Enum):
    """Criterion used by TranscriptChunk.split_by to segment a token run."""
    PUNCTUATION = "punctuation"
    PAUSE = "pause"
    END = "end"
|
|
|
|
|
|
@dataclass
class TranscriptResult:
    """One analyzer output: a piece of text plus where it belongs.

    Attributes:
        seg_id: identifier of the paragraph/segment this text belongs to.
        cut_index: audio-buffer sample index where stable text ends
            (0 when no cut is implied).
        is_end_sentence: True when this result finalizes a sentence/paragraph.
        context: the transcribed text itself.
    """
    seg_id: int = 0
    cut_index: int = 0
    is_end_sentence: bool = False
    context: str = ""

    def partial(self) -> bool:
        """Return True while the result is still provisional (not finalized)."""
        return False if self.is_end_sentence else True
|
|
@dataclass
class TranscriptToken:
    """A single transcription token with its timing information."""
    text: str   # token text, may include punctuation markers
    t0: float   # start timestamp
    t1: float   # end timestamp

    def is_punctuation(self) -> bool:
        """Return True when the token text contains any punctuation marker."""
        return bool(REGEX_MARKERS.search(self.text))

    def is_end(self) -> bool:
        """Return True when the token text marks the end of a sentence."""
        return bool(SENTENCE_END_PATTERN.search(self.text))

    def is_pause(self) -> bool:
        """Return True when the token text marks a pause."""
        return bool(PAUSEE_END_PATTERN.search(self.text))

    def buffer_index(self) -> int:
        """Audio-buffer sample index corresponding to the token's end time.

        NOTE(review): assumes t1 is in centiseconds (hence / 100) and backs
        off by a fixed 300-sample margin — confirm against the audio pipeline.
        """
        sample_pos = int(self.t1 / 100 * SAMPLE_RATE) - 300
        return sample_pos if sample_pos > 0 else 0
|
|
@dataclass
class TranscriptChunk:
    """A run of transcript tokens supporting splitting and comparison."""
    separator: str = ""
    items: list[TranscriptToken] = field(default_factory=list)

    @staticmethod
    def _calculate_similarity(text1: str, text2: str) -> float:
        """Similarity ratio in [0.0, 1.0] between two strings."""
        return SequenceMatcher(None, text1, text2).ratio()

    def split_by(self, mode: SplitMode) -> list['TranscriptChunk']:
        """Split the token run into sub-chunks, cutting AFTER each marker token.

        Args:
            mode: which marker type ends a sub-chunk.

        Returns:
            List of sub-chunks (always at least one, possibly with empty
            items when the run ends on a marker — preserved behavior).

        Raises:
            ValueError: on an unsupported mode.
        """
        if mode == SplitMode.PUNCTUATION:
            indexes = [i for i, seg in enumerate(self.items) if seg.is_punctuation()]
        elif mode == SplitMode.PAUSE:
            indexes = [i for i, seg in enumerate(self.items) if seg.is_pause()]
        elif mode == SplitMode.END:
            indexes = [i for i, seg in enumerate(self.items) if seg.is_end()]
        else:
            raise ValueError(f"Unsupported mode: {mode}")

        # enumerate() yields ascending indexes, so no sort is needed;
        # +1 keeps the marker token at the end of its own chunk.
        cut_points = [0] + [i + 1 for i in indexes] + [len(self.items)]
        return [
            TranscriptChunk(items=self.items[start:end], separator=self.separator)
            for start, end in zip(cut_points, cut_points[1:])
        ]

    def get_split_first_rest(self, mode: SplitMode):
        """Return (first_chunk, remaining_chunks) after splitting by *mode*."""
        chunks = self.split_by(mode)
        first_chunk = chunks[0] if chunks else self  # fix: local was misspelled "fisrt_chunk"
        rest_chunks = chunks[1:] if chunks else None
        return first_chunk, rest_chunks

    def puncation_numbers(self) -> int:
        """Count tokens containing punctuation. (Misspelled name kept for API compatibility.)"""
        return sum(1 for seg in self.items if seg.is_punctuation())

    def length(self) -> int:
        """Number of tokens in this chunk."""
        return len(self.items)

    def join(self) -> str:
        """Concatenate all token texts with the configured separator."""
        return self.separator.join(seg.text for seg in self.items)

    def compare(self, chunk: Optional['TranscriptChunk'] = None) -> float:
        """Similarity between this chunk's text and another's; 0.0 when absent."""
        if not chunk:
            return 0.0  # fix: was int 0, contradicting the float annotation

        # Join once instead of three times (was recomputed for the log line).
        own_text, other_text = self.join(), chunk.join()
        score = self._calculate_similarity(own_text, other_text)
        logger.debug(f"Compare: {own_text} vs {other_text} : {score}")
        return score

    def has_punctuation(self) -> bool:
        """True when any token contains punctuation."""
        return any(seg.is_punctuation() for seg in self.items)

    def get_buffer_index(self) -> int:
        # NOTE(review): raises IndexError on an empty chunk; callers appear
        # to guarantee non-empty input — confirm.
        return self.items[-1].buffer_index()

    def is_end_sentence(self) -> bool:
        """True when the last token marks the end of a sentence."""
        return self.items[-1].is_end()
| |
|
|
class TranscriptHistory:
    """Keeps the two most recent transcript chunks, newest first."""

    def __init__(self) -> None:
        # appendleft() keeps the newest chunk at index 0; maxlen=2 evicts older ones.
        self.history = collections.deque(maxlen=2)

    def add(self, chunk: "TranscriptChunk"):
        """Push the newest chunk to the front of the history."""
        self.history.appendleft(chunk)

    def previous_chunk(self) -> Optional["TranscriptChunk"]:
        """Return the chunk added before the latest one, or None if absent."""
        return self.history[1] if len(self.history) == 2 else None

    def lastest_chunk(self):
        """Return the most recently added chunk. (Misspelled name kept for API compatibility.)

        Fix: chunks are added with appendleft(), so the newest lives at
        index 0. The previous code returned history[-1], which is the
        OLDEST entry once two chunks are present.
        """
        return self.history[0]

    def clear(self):
        """Drop all recorded chunks."""
        self.history.clear()
|
|
class TranscriptBuffer:
    """
    Manages the tiered structure of transcribed text:
    pending string -> sentence -> paragraph

        |-- committed text --|-- observation window --|-- new input --|

    Implements the pending -> line -> paragraph buffering logic.
    """

    # A paragraph is committed once the un-committed text reaches this size
    # (word count when a separator is configured, character count otherwise).
    _PARAGRAPH_THRESHOLD = 20

    def __init__(self, separator):
        # fix: was annotated List[str] but is a deque; only the two most
        # recent paragraphs are retained (maxlen=2).
        self._segments: Deque[str] = collections.deque(maxlen=2)
        self._sentences: List[str] = []
        self._buffer: str = ""
        self._current_seg_id: int = 0
        self._separator = separator

    def get_seg_id(self) -> int:
        """Identifier of the paragraph currently being built."""
        return self._current_seg_id

    def next_seg_id(self) -> int:
        """Identifier the next paragraph will get."""
        return self._current_seg_id + 1

    def update_pending_text(self, text: str) -> None:
        """Replace the pending (unconfirmed) buffer string."""
        self._buffer = text

    def commit_line(self) -> None:
        """Promote the pending buffer to a finished sentence (if non-empty)."""
        if self._buffer:
            self._sentences.append(self._buffer)
            self._buffer = ""

    def commit_paragraph(self) -> None:
        """Join all finished sentences into one paragraph and store it."""
        if self._sentences:
            self._segments.append("".join(self._sentences))
            self._sentences.clear()

    def _uncommitted_length(self) -> int:
        """Size of the un-committed text: words if a separator is set, else characters."""
        text = self.current_not_commit_text
        return len(text.split(self._separator)) if self._separator else len(text)

    def update_and_commit(self, stable_string: str, remaining_string: str, is_end_sentence=False) -> bool:
        """Commit *stable_string* as a sentence and keep *remaining_string* pending.

        Args:
            stable_string: text confirmed stable across analysis passes.
            remaining_string: text still under observation.
            is_end_sentence: whether the stable text ends a sentence.

        Returns:
            True when a full paragraph was committed (threshold reached on a
            sentence end), False otherwise.  Fix: the original returned an
            implicit None when the end-of-sentence text was below the
            threshold; callers only test truthiness, so False is compatible.
        """
        logger.debug(f"{self.__dict__}")
        self.update_pending_text(stable_string)
        self.commit_line()
        if is_end_sentence:
            # Measure before the remaining text replaces the (now empty) buffer.
            current_text_len = self._uncommitted_length()
            self.update_pending_text(remaining_string)
            if current_text_len >= self._PARAGRAPH_THRESHOLD:
                self.commit_paragraph()
                self._current_seg_id += 1
                return True
            return False
        self.update_pending_text(remaining_string)
        return False

    @property
    def un_commit_paragraph(self) -> str:
        """Current sentences joined together (not yet a paragraph)."""
        return "".join(self._sentences)

    @property
    def pending_text(self) -> str:
        """Current pending buffer content."""
        return self._buffer

    @property
    def latest_paragraph(self) -> str:
        """Most recently committed paragraph, or "" when none exists."""
        return self._segments[-1] if self._segments else ""

    @property
    def current_not_commit_text(self) -> str:
        """All text not yet committed to a paragraph: sentences + pending buffer."""
        return self.un_commit_paragraph + self.pending_text
|
|
|
|
|
|
class TranscriptStabilityAnalyzer:
    """Turns successive transcription passes into stable TranscriptResult updates.

    Each call to analysis() compares the newest token run against the previous
    one; text that has stopped changing is committed through TranscriptBuffer,
    while still-changing text is re-emitted as a provisional result.
    """

    def __init__(self, separator) -> None:
        self._transcript_buffer = TranscriptBuffer(separator=separator)
        self._transcript_history = TranscriptHistory()
        self._separator = separator
        logger.debug(f"Current separator: {self._separator}")

    def merge_chunks(self, chunks: List[TranscriptChunk]) -> str:
        """Concatenate the joined text of several chunks.

        NOTE(review): chunks are glued with "" rather than the configured
        separator — presumably intentional; verify for space-separated languages.
        """
        return "".join(c.join() for c in chunks)

    def analysis(self, current: List[TranscriptToken], buffer_duration: float) -> Iterator[TranscriptResult]:
        """Analyze the latest token run and yield result updates.

        Args:
            current: the raw token list from the latest transcription pass.
                Fix: was annotated TranscriptChunk, but the argument is a
                token list that gets wrapped into a chunk here.
            buffer_duration: seconds of audio currently buffered.

        Yields:
            TranscriptResult objects (provisional and/or committed).
        """
        current_chunk = TranscriptChunk(items=current, separator=self._separator)
        self._transcript_history.add(current_chunk)

        prev = self._transcript_history.previous_chunk()
        self._transcript_buffer.update_pending_text(current_chunk.join())
        if not prev:
            # Very first pass: nothing to compare against yet.
            yield TranscriptResult(
                context=self._transcript_buffer.current_not_commit_text,
                seg_id=self._transcript_buffer.get_seg_id(),
            )
            return

        # Short buffers commit conservatively via similarity; long buffers
        # force-commit to keep latency and memory bounded (12 s threshold).
        if buffer_duration <= 12:
            yield from self._handle_short_buffer(current_chunk, prev)
        else:
            yield from self._handle_long_buffer(current_chunk)

    def _handle_short_buffer(self, curr: TranscriptChunk, prev: TranscriptChunk) -> Iterator[TranscriptResult]:
        """Commit the first sub-sentence once it is stable across two passes."""
        curr_first, curr_rest = curr.get_split_first_rest(SplitMode.PUNCTUATION)
        prev_first, _ = prev.get_split_first_rest(SplitMode.PUNCTUATION)

        if curr_first and prev_first:
            # fix: local was misleadingly named `core`; it is the similarity score.
            score = curr_first.compare(prev_first)
            # Two successive passes agreeing >= 0.8 means the first
            # sub-sentence is considered stable.
            if score >= 0.8:
                yield from self._yield_commit_results(curr_first, curr_rest, curr_first.is_end_sentence())
                return

        # Not stable yet: re-emit the current un-committed text.
        yield TranscriptResult(
            seg_id=self._transcript_buffer.get_seg_id(),
            context=self._transcript_buffer.current_not_commit_text,
        )

    def _handle_long_buffer(self, curr: TranscriptChunk) -> Iterator[TranscriptResult]:
        """Force-commit everything but the last two sub-chunks of a long buffer."""
        chunks = curr.split_by(SplitMode.PUNCTUATION)
        if len(chunks) > 2:
            stable, remaining = chunks[:-2], chunks[-2:]
            yield from self._yield_commit_results(
                stable, remaining, is_end_sentence=True
            )
        else:
            yield TranscriptResult(
                seg_id=self._transcript_buffer.get_seg_id(),
                context=self._transcript_buffer.current_not_commit_text,
            )

    def _yield_commit_results(self, stable_chunk, remaining_chunks, is_end_sentence: bool) -> Iterator[TranscriptResult]:
        """Push stable text into the buffer and emit the resulting update(s).

        *stable_chunk* is either a single TranscriptChunk (short-buffer path)
        or a list of chunks (long-buffer path); both shapes are handled.
        """
        stable_str = stable_chunk.join() if hasattr(stable_chunk, "join") else self.merge_chunks(stable_chunk)
        remaining_str = self.merge_chunks(remaining_chunks)
        # Audio sample index where the stable text ends — callers use it to cut the buffer.
        frame_cut_index = (
            stable_chunk[-1].get_buffer_index()
            if isinstance(stable_chunk, list)
            else stable_chunk.get_buffer_index()
        )

        prev_seg_id = self._transcript_buffer.get_seg_id()
        committed_paragraph = self._transcript_buffer.update_and_commit(stable_str, remaining_str, is_end_sentence)
        logger.debug(f"current buffer: {self._transcript_buffer.__dict__}")

        if committed_paragraph:
            # A whole paragraph was finalized: emit it under its old segment id...
            yield TranscriptResult(
                seg_id=prev_seg_id,
                cut_index=frame_cut_index,
                context=self._transcript_buffer.latest_paragraph,
                is_end_sentence=True,
            )
            # ...then emit any text already pending for the new segment.
            if (context := self._transcript_buffer.current_not_commit_text.strip()):
                yield TranscriptResult(
                    seg_id=self._transcript_buffer.get_seg_id(),
                    context=context,
                )
        else:
            yield TranscriptResult(
                seg_id=self._transcript_buffer.get_seg_id(),
                cut_index=frame_cut_index,
                context=self._transcript_buffer.current_not_commit_text,
            )
| |
|
|