|
|
| import collections |
| import logging |
| from difflib import SequenceMatcher |
| from itertools import chain |
| from dataclasses import dataclass, field |
| from typing import List, Tuple, Optional, Deque, Any, Iterator,Literal |
| from config import SENTENCE_END_MARKERS, ALL_MARKERS,SENTENCE_END_PATTERN,REGEX_MARKERS, PAUSEE_END_PATTERN,SAMPLE_RATE |
| from enum import Enum |
| import wordninja |
| import config |
| import re |
| logger = logging.getLogger("TranscriptionStrategy") |
|
|
|
|
class SplitMode(Enum):
    """Selects which token predicate TranscriptChunk.split_by() cuts on."""
    PUNCTUATION = "punctuation"  # cut after tokens containing any punctuation marker
    PAUSE = "pause"              # cut after tokens matching the pause pattern
    END = "end"                  # cut after tokens matching the sentence-end pattern
|
|
|
|
|
|
@dataclass
class TranscriptResult:
    """One analysis outcome: a piece of context text plus segment metadata."""
    seg_id: int = 0               # id of the paragraph/segment this result belongs to
    cut_index: int = 0            # audio-buffer index where stable text ends
    is_end_sentence: bool = False # True when this result closes a sentence/paragraph
    context: str = ""             # the text payload

    def partial(self):
        """Return True while the result does not yet close a sentence."""
        return False if self.is_end_sentence else True
|
|
@dataclass
class TranscriptToken:
    """A single transcription token carrying its text and timing information."""
    text: str  # token text, possibly containing punctuation
    t0: int    # start time
    t1: int    # end time

    def is_punctuation(self):
        """Whether the stripped token text contains any punctuation marker."""
        stripped = self.text.strip()
        return bool(REGEX_MARKERS.search(stripped))

    def is_end(self):
        """Whether the stripped token text matches a sentence-end marker."""
        stripped = self.text.strip()
        return bool(SENTENCE_END_PATTERN.search(stripped))

    def is_pause(self):
        """Whether the stripped token text matches a pause marker."""
        stripped = self.text.strip()
        return bool(PAUSEE_END_PATTERN.search(stripped))

    def buffer_index(self) -> int:
        """Map the end time to an index into the audio sample buffer."""
        # t1 appears to be in 10 ms units (t1/100 = seconds) — TODO confirm.
        seconds = self.t1 / 100
        return int(seconds * SAMPLE_RATE)
|
|
@dataclass
class TranscriptChunk:
    """A group of transcript tokens supporting split and comparison operations."""
    separator: str = ""
    # Forward-referenced annotation so the class is importable in isolation.
    items: list["TranscriptToken"] = field(default_factory=list)

    @staticmethod
    def _calculate_similarity(text1: str, text2: str) -> float:
        """Return the SequenceMatcher similarity ratio of two strings (0.0-1.0)."""
        return SequenceMatcher(None, text1, text2).ratio()

    def split_by(self, mode: "SplitMode") -> list['TranscriptChunk']:
        """Split the token list at the markers selected by *mode*.

        Each marker token closes its chunk; chunks made up purely of
        punctuation tokens (including empty slices) are discarded.

        Raises:
            ValueError: if *mode* is not a supported SplitMode.
        """
        if mode == SplitMode.PUNCTUATION:
            indexes = [i for i, seg in enumerate(self.items) if seg.is_punctuation()]
        elif mode == SplitMode.PAUSE:
            indexes = [i for i, seg in enumerate(self.items) if seg.is_pause()]
        elif mode == SplitMode.END:
            indexes = [i for i, seg in enumerate(self.items) if seg.is_end()]
        else:
            raise ValueError(f"Unsupported mode: {mode}")

        # Cut right after each marker token. enumerate() already yields
        # ascending indexes, so the previous sorted() call was redundant.
        cut_points = [0] + [i + 1 for i in indexes] + [len(self.items)]
        chunks = [
            TranscriptChunk(items=self.items[start:end], separator=self.separator)
            for start, end in zip(cut_points, cut_points[1:])
        ]
        return [ck for ck in chunks if not ck.only_punctuation()]

    def get_split_first_rest(self, mode: "SplitMode"):
        """Return (first_chunk, rest_chunks) after splitting by *mode*.

        Falls back to (self, None) when the split produces no chunks.
        """
        chunks = self.split_by(mode)
        first_chunk = chunks[0] if chunks else self
        rest_chunks = chunks[1:] if chunks else None
        return first_chunk, rest_chunks

    def puncation_numbers(self) -> int:
        """Count tokens containing punctuation. (Misspelled name kept for callers.)"""
        return sum(1 for seg in self.items if seg.is_punctuation())

    def length(self) -> int:
        """Number of tokens in this chunk."""
        return len(self.items)

    def join(self) -> str:
        """Concatenate all token texts using the configured separator."""
        return self.separator.join(seg.text for seg in self.items)

    def compare(self, chunk: Optional['TranscriptChunk'] = None) -> float:
        """Similarity between this chunk and *chunk*; 0.0 when *chunk* is falsy."""
        if not chunk:
            return 0.0  # was int 0; return a float as the annotation declares
        return self._calculate_similarity(self.join(), chunk.join())

    def only_punctuation(self) -> bool:
        """True when every token is punctuation (vacuously true for no tokens)."""
        return all(seg.is_punctuation() for seg in self.items)

    def has_punctuation(self) -> bool:
        """True when at least one token contains punctuation."""
        return any(seg.is_punctuation() for seg in self.items)

    def get_buffer_index(self) -> int:
        """Sample-buffer index of the last token; requires a non-empty chunk."""
        return self.items[-1].buffer_index()

    def is_end_sentence(self) -> bool:
        """True when the last token ends a sentence; requires a non-empty chunk."""
        return self.items[-1].is_end()
|
|
|
|
class TranscriptHistory:
    """Keeps the two most recent transcript chunks, newest first."""

    def __init__(self) -> None:
        # appendleft() places the newest chunk at index 0.
        self.history = collections.deque(maxlen=2)

    def add(self, chunk: "TranscriptChunk"):
        """Push the newest chunk to the front of the history."""
        self.history.appendleft(chunk)

    def previous_chunk(self) -> Optional["TranscriptChunk"]:
        """Return the chunk before the latest one, or None if fewer than two are stored."""
        return self.history[1] if len(self.history) == 2 else None

    def lastest_chunk(self):
        """Return the most recently added chunk.

        Bug fix: chunks are inserted with appendleft(), so the newest element
        lives at index 0; the previous ``history[-1]`` returned the OLDEST
        chunk whenever two chunks were stored.
        """
        return self.history[0]

    def clear(self):
        """Drop all stored chunks."""
        self.history.clear()
|
|
class TranscriptBuffer:
    """
    Tiered transcript buffer: pending text -> sentences -> paragraphs.

    |-- committed text --|-- observation window --|-- new input --|

    Manages the pending -> line -> paragraph promotion logic.
    """

    def __init__(self, source_lang: str, separator: str):
        # Completed paragraphs; only the two most recent are retained.
        self._segments: Deque[str] = collections.deque(maxlen=2)
        # Sentences committed but not yet folded into a paragraph.
        self._sentences: Deque[str] = collections.deque()
        # Raw, still-unstable text.
        self._buffer: str = ""
        self._current_seg_id: int = 0
        self.source_language = source_lang
        self._separator = separator

    def get_seg_id(self) -> int:
        """Identifier of the paragraph currently being assembled."""
        return self._current_seg_id

    @property
    def current_sentences_length(self) -> int:
        """Total committed-sentence length: words when a separator exists, characters otherwise."""
        count = 0
        for item in self._sentences:
            if self._separator:
                count += len(item.split(self._separator))
            else:
                count += len(item)
        return count

    def update_pending_text(self, text: str) -> None:
        """Replace the unstable pending text."""
        self._buffer = text

    def commit_line(self) -> None:
        """Promote the pending text (if any) to a committed sentence."""
        if self._buffer:
            self._sentences.append(self._buffer)
            self._buffer = ""

    def commit_paragraph(self) -> None:
        """Fold all committed sentences into a new paragraph.

        The previous implementation also tallied a word/char count here that
        was never used; that dead computation has been removed.
        """
        current_sentences = list(self._sentences)
        self._sentences.clear()
        if current_sentences:
            self._segments.append("".join(current_sentences))
            logger.debug("=== count to paragraph ===")
            logger.debug(f"push: {current_sentences}")
            logger.debug(f"rest: {self._sentences}")

    def rebuild(self, text):
        """Re-segment *text* into words after stripping the separator (English only path)."""
        output = self.split_and_join(text.replace(self._separator, ""))

        logger.debug("==== rebuild string ====")
        logger.debug(text)
        logger.debug(output)

        return output

    @staticmethod
    def split_and_join(text):
        """Split concatenated English text into words via wordninja, preserving markers.

        Marker characters attach to the preceding word and are followed by a
        space; ordinary words are joined with single spaces.
        """
        tokens = []
        word_buf = ''

        for char in text:
            if char in ALL_MARKERS:
                if word_buf:
                    tokens.extend(wordninja.split(word_buf))
                    word_buf = ''
                tokens.append(char)
            else:
                word_buf += char
        if word_buf:
            tokens.extend(wordninja.split(word_buf))

        output = ''
        for i, token in enumerate(tokens):
            if i == 0:
                output += token
            elif token in ALL_MARKERS:
                output += (token + " ")
            else:
                output += ' ' + token
        return output

    def update_and_commit(self, stable_strings: List[str], remaining_strings: List[str], is_end_sentence=False):
        """Absorb stabilized text and decide whether a paragraph was completed.

        Args:
            stable_strings: strings considered stable; committed as sentences.
            remaining_strings: still-unstable strings; kept as pending text.
            is_end_sentence: whether the stable part closes a sentence.

        Returns:
            True when a paragraph was committed, False otherwise.
            (The original fell through returning None on the under-threshold
            path; this is normalized to an explicit, equally-falsy False.)
        """
        if self.source_language == "en":
            stable_strings = [self.rebuild(i) for i in stable_strings]
            remaining_strings = [self.rebuild(i) for i in remaining_strings]
        remaining_string = "".join(remaining_strings)

        logger.debug(f"{self.__dict__}")
        if is_end_sentence:
            for stable_str in stable_strings:
                self.update_pending_text(stable_str)
                self.commit_line()

            # Measure BEFORE overwriting the pending text: words when a
            # separator exists, characters otherwise.
            current_text_len = len(self.current_not_commit_text.split(self._separator)) if self._separator else len(self.current_not_commit_text)
            self.update_pending_text(remaining_string)
            if current_text_len >= config.TEXT_THREHOLD:
                self.commit_paragraph()
                self._current_seg_id += 1
                return True
        else:
            for stable_str in stable_strings:
                self.update_pending_text(stable_str)
                self.commit_line()
            self.update_pending_text(remaining_string)
        return False

    @property
    def un_commit_paragraph(self) -> str:
        """Committed sentences not yet folded into a paragraph."""
        return "".join(self._sentences)

    @property
    def pending_text(self) -> str:
        """The still-unstable buffer content."""
        return self._buffer

    @property
    def latest_paragraph(self) -> str:
        """Most recently committed paragraph, or '' when none exists."""
        return self._segments[-1] if self._segments else ""

    @property
    def current_not_commit_text(self) -> str:
        """Everything not yet in a paragraph: committed sentences plus pending text."""
        return self.un_commit_paragraph + self.pending_text
|
|
|
|
|
|
class TranscriptStabilityAnalyzer:
    """Turns successive transcription passes into stable TranscriptResults.

    Keeps a two-deep chunk history and a TranscriptBuffer; the newest chunk
    is compared against the previous pass to decide which prefix is stable
    enough to commit.
    """
    def __init__(self, source_lang, separator) -> None:
        self._transcript_buffer = TranscriptBuffer(source_lang=source_lang,separator=separator)
        self._transcript_history = TranscriptHistory()
        self._separator = separator
        logger.debug(f"Current separator: {self._separator}")

    def merge_chunks(self, chunks: List[TranscriptChunk]) -> List[str]:
        """Join each chunk into a string; returns [""] for an empty/None input.

        (Annotation fixed: this has always returned a list of strings,
        not a single str.)
        """
        if not chunks:
            return [""]
        output = list(r.join() for r in chunks if r)
        return output

    def analysis(self, current: List[TranscriptToken], buffer_duration: float) -> Iterator[TranscriptResult]:
        """Process the newest token list and yield zero or more results.

        Args:
            current: raw token list for the current window; wrapped into a
                TranscriptChunk below (the old ``TranscriptChunk`` annotation
                was misleading).
            buffer_duration: buffered audio duration — presumably seconds;
                TODO confirm against the caller.
        """
        current = TranscriptChunk(items=current, separator=self._separator)
        self._transcript_history.add(current)

        prev = self._transcript_history.previous_chunk()
        self._transcript_buffer.update_pending_text(current.join())
        if not prev:
            # First pass ever: nothing to compare against, emit as partial.
            yield TranscriptResult(
                context=self._transcript_buffer.current_not_commit_text,
                seg_id=self._transcript_buffer.get_seg_id()
            )
            return

        # Short buffers rely on cross-pass similarity; long buffers force a cut.
        if buffer_duration <= 4:
            yield from self._handle_short_buffer(current, prev)
        else:
            yield from self._handle_long_buffer(current)

    def _handle_short_buffer(self, curr: TranscriptChunk, prev: TranscriptChunk) -> Iterator[TranscriptResult]:
        """Commit the leading sub-chunk when it repeats the previous pass."""
        curr_first, curr_rest = curr.get_split_first_rest(SplitMode.PUNCTUATION)
        prev_first, _ = prev.get_split_first_rest(SplitMode.PUNCTUATION)

        if curr_first and prev_first:
            # NOTE(review): `core` looks like a typo for `score` — kept as-is.
            core = curr_first.compare(prev_first)
            has_punctuation = curr_first.has_punctuation()
            # A punctuated first chunk that is >= 0.8 similar to the previous
            # pass is considered stable and gets committed.
            if core >= 0.8 and has_punctuation:
                yield from self._yield_commit_results(curr_first, curr_rest, curr_first.is_end_sentence())
                return

        # Not stable yet: emit the whole uncommitted text as a partial result.
        yield TranscriptResult(
            seg_id=self._transcript_buffer.get_seg_id(),
            context=self._transcript_buffer.current_not_commit_text
        )

    def _handle_long_buffer(self, curr: TranscriptChunk) -> Iterator[TranscriptResult]:
        """Force a commit of everything up to the last punctuation mark."""
        chunks = curr.split_by(SplitMode.PUNCTUATION)
        if len(chunks) > 1:
            # All but the last chunk are treated as stable; the tail stays pending.
            stable, remaining = chunks[:-1], chunks[-1:]
            yield from self._yield_commit_results(
                stable, remaining, is_end_sentence=True
            )
        else:
            yield TranscriptResult(
                seg_id=self._transcript_buffer.get_seg_id(),
                context=self._transcript_buffer.current_not_commit_text
            )

    def _yield_commit_results(self, stable_chunk, remaining_chunks, is_end_sentence: bool) -> Iterator[TranscriptResult]:
        """Commit stable text into the buffer and yield the resulting events.

        ``stable_chunk`` may be a single TranscriptChunk or a list of chunks;
        ``remaining_chunks`` is a (possibly None/empty) list of chunks.
        """
        stable_str_list = [stable_chunk.join()] if hasattr(stable_chunk, "join") else self.merge_chunks(stable_chunk)
        remaining_str_list = self.merge_chunks(remaining_chunks)
        # Sample index of the last stable token — callers use it to cut the audio buffer.
        frame_cut_index = stable_chunk[-1].get_buffer_index() if isinstance(stable_chunk, list) else stable_chunk.get_buffer_index()

        # Capture the segment id BEFORE the buffer may advance it.
        prev_seg_id = self._transcript_buffer.get_seg_id()
        commit_paragraph = self._transcript_buffer.update_and_commit(stable_str_list, remaining_str_list, is_end_sentence)
        logger.debug(f"current buffer: {self._transcript_buffer.__dict__}")

        if commit_paragraph:
            # A paragraph closed: emit it as final, then any leftover as partial.
            yield TranscriptResult(
                seg_id=prev_seg_id,
                cut_index=frame_cut_index,
                context=self._transcript_buffer.latest_paragraph,
                is_end_sentence=True
            )
            if (context := self._transcript_buffer.current_not_commit_text.strip()):
                yield TranscriptResult(
                    seg_id=self._transcript_buffer.get_seg_id(),
                    context=context,
                )
        else:
            yield TranscriptResult(
                seg_id=self._transcript_buffer.get_seg_id(),
                cut_index=frame_cut_index,
                context=self._transcript_buffer.current_not_commit_text,
            )
|
|
|
|