File size: 12,037 Bytes
0c38083 e046f39 0c38083 e046f39 813ffab 0c38083 813ffab 99b58ae 0c38083 e046f39 813ffab 0c38083 813ffab 0c38083 813ffab 0c38083 813ffab e046f39 813ffab b6e4de3 e046f39 813ffab e046f39 0c38083 e046f39 9f6a51c 813ffab e046f39 813ffab e046f39 0c38083 813ffab 99b58ae 813ffab 27321a0 813ffab 99b58ae 813ffab 0c38083 813ffab 99b58ae 0c38083 813ffab 0c38083 813ffab 0c38083 813ffab 0c38083 813ffab 99b58ae 813ffab 7dc6a6f 99b58ae 813ffab 99b58ae 7dc6a6f 813ffab e1e0093 813ffab e1e0093 e046f39 813ffab 0c38083 813ffab 0c38083 e1e0093 c556c3a 813ffab 27321a0 813ffab 99b58ae 813ffab 99b58ae 7dc6a6f b6e4de3 99b58ae 813ffab 99b58ae 813ffab b6e4de3 99b58ae 813ffab 27321a0 99b58ae 813ffab 27321a0 813ffab 27321a0 813ffab 27321a0 813ffab 99b58ae 813ffab 7dc6a6f 813ffab 7dc6a6f 813ffab 27321a0 7dc6a6f 813ffab fa46942 99b58ae fa46942 813ffab 99b58ae 813ffab 0c38083 813ffab 99b58ae 813ffab cd7fb92 813ffab fa46942 813ffab cd7fb92 813ffab cd7fb92 99b58ae 813ffab b6e4de3 813ffab b6e4de3 cd7fb92 813ffab e1e0093 813ffab b2b3b92 99b58ae 813ffab e1e0093 813ffab 99b58ae 813ffab 99b58ae | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 
259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 | import re
import collections
import logging
from difflib import SequenceMatcher
from itertools import chain
from dataclasses import dataclass, field
from typing import List, Tuple, Optional, Deque, Any, Iterator,Literal
from config import SENTENCE_END_MARKERS, ALL_MARKERS,SENTENCE_END_PATTERN,REGEX_MARKERS, PAUSEE_END_PATTERN,SAMPLE_RATE
import numpy as np
from enum import Enum
from itertools import chain
# Module-wide logger; handlers and level are expected to be configured by the application.
logger = logging.getLogger("TranscriptionStrategy")
class SplitMode(Enum):
    """Strategies for cutting a token stream into chunks."""

    PUNCTUATION = "punctuation"  # cut at any punctuation marker
    PAUSE = "pause"              # cut at pause markers
    END = "end"                  # cut at sentence-final punctuation
@dataclass
class TranscriptResult:
    """A single transcription output unit handed to downstream consumers."""

    seg_id: int = 0                # paragraph/segment identifier
    cut_index: int = 0             # audio-buffer sample index where stable text ends
    is_end_sentence: bool = False  # True when this result finalizes a paragraph
    context: str = ""              # the text payload

    def partial(self) -> bool:
        """Return True while this result is provisional (not a final sentence)."""
        return not self.is_end_sentence
@dataclass
class TranscriptToken:
    """One transcribed token together with its timing information."""

    text: str   # transcribed text content
    t0: float   # start time (centiseconds)
    t1: float   # end time (centiseconds)

    def is_punctuation(self) -> bool:
        """True when the text contains any punctuation marker."""
        return REGEX_MARKERS.search(self.text) is not None

    def is_end(self) -> bool:
        """True when the text carries a sentence-end marker."""
        return SENTENCE_END_PATTERN.search(self.text) is not None

    def is_pause(self) -> bool:
        """True when the text is a pause marker."""
        return PAUSEE_END_PATTERN.search(self.text) is not None

    def buffer_index(self) -> int:
        """Map the token's end time onto an audio-buffer sample index.

        Converts centiseconds to samples, steps back 300 samples as a safety
        margin, and clamps at zero.
        """
        return max(int(self.t1 / 100 * SAMPLE_RATE) - 300, 0)
@dataclass
class TranscriptChunk:
    """A group of transcript tokens that supports splitting and comparison."""

    separator: str = ""  # string used when joining token texts
    # Forward-ref annotation so the class definition does not require
    # TranscriptToken to already be bound at class-creation time.
    items: list["TranscriptToken"] = field(default_factory=list)

    @staticmethod
    def _calculate_similarity(text1: str, text2: str) -> float:
        """Return the similarity ratio of two strings in [0.0, 1.0]."""
        return SequenceMatcher(None, text1, text2).ratio()

    def split_by(self, mode: "SplitMode") -> list["TranscriptChunk"]:
        """Split the token list at markers selected by *mode*.

        Each cut point is shifted one index right so the token carrying the
        marker stays at the end of the preceding chunk ("separator belongs
        to the left segment").

        Raises:
            ValueError: if *mode* is not a supported SplitMode.
        """
        if mode == SplitMode.PUNCTUATION:
            indexes = [i for i, seg in enumerate(self.items) if seg.is_punctuation()]
        elif mode == SplitMode.PAUSE:
            indexes = [i for i, seg in enumerate(self.items) if seg.is_pause()]
        elif mode == SplitMode.END:
            indexes = [i for i, seg in enumerate(self.items) if seg.is_end()]
        else:
            raise ValueError(f"Unsupported mode: {mode}")
        cut_points = [0] + sorted(i + 1 for i in indexes) + [len(self.items)]
        return [
            TranscriptChunk(items=self.items[start:end], separator=self.separator)
            for start, end in zip(cut_points, cut_points[1:])
        ]

    def get_split_first_rest(self, mode: "SplitMode"):
        """Return (first_chunk, remaining_chunks) after splitting by *mode*."""
        chunks = self.split_by(mode)
        first_chunk = chunks[0] if chunks else self
        rest_chunks = chunks[1:] if chunks else None
        return first_chunk, rest_chunks

    def puncation_numbers(self) -> int:
        """Count tokens containing a punctuation marker.

        (Method name kept as-is for API compatibility; "punctuation" intended.)
        """
        return sum(1 for seg in self.items if seg.is_punctuation())

    def length(self) -> int:
        """Number of tokens in this chunk."""
        return len(self.items)

    def join(self) -> str:
        """Join all token texts with the configured separator."""
        return self.separator.join(seg.text for seg in self.items)

    def compare(self, chunk: Optional["TranscriptChunk"] = None) -> float:
        """Similarity between this chunk's text and another's; 0.0 when absent."""
        if not chunk:
            return 0.0  # was the int `0`; return a float to match the annotation
        score = self._calculate_similarity(self.join(), chunk.join())
        # getLogger returns the same module-level singleton logger instance.
        logging.getLogger("TranscriptionStrategy").debug(
            "Compare: %s vs %s : %s", self.join(), chunk.join(), score
        )
        return score

    def has_punctuation(self) -> bool:
        """True when any token contains a punctuation marker."""
        return any(seg.is_punctuation() for seg in self.items)

    def get_buffer_index(self) -> int:
        """Audio-buffer index of the last token (assumes non-empty items)."""
        return self.items[-1].buffer_index()

    def is_end_sentence(self) -> bool:
        """True when the last token ends a sentence (assumes non-empty items)."""
        return self.items[-1].is_end()
class TranscriptHistory:
    """Keeps the two most recent transcript chunks."""

    def __init__(self) -> None:
        # Newest chunk sits at index 0; the bounded deque silently drops older ones.
        self.history = collections.deque(maxlen=2)

    def add(self, chunk: TranscriptChunk):
        """Push *chunk* as the newest entry."""
        self.history.appendleft(chunk)

    def previous_chunk(self) -> Optional[TranscriptChunk]:
        """The chunk added just before the newest one, or None if fewer than two."""
        if len(self.history) == 2:
            return self.history[1]
        return None

    def lastest_chunk(self):
        """Return the entry at the right end of the deque.

        NOTE(review): because add() uses appendleft, history[-1] is the OLDEST
        entry once two chunks are stored — confirm this is the intent.
        """
        return self.history[-1]

    def clear(self):
        """Drop all stored chunks."""
        self.history.clear()
class TranscriptBuffer:
    """Tiered transcript buffer: pending text -> sentences -> paragraphs.

    |-- committed text --|-- observation window --|-- new input --|

    Text flows pending -> line (sentence) -> paragraph as it stabilizes.
    """

    # Paragraph-commit threshold, measured in words when a separator is
    # configured and in characters otherwise (e.g. CJK text).
    MAX_UNCOMMITTED_UNITS: int = 20

    def __init__(self, separator):
        # Bounded deque: only the two most recent confirmed paragraphs are kept.
        self._segments: Deque[str] = collections.deque(maxlen=2)
        self._sentences: List[str] = []   # sentences of the paragraph being built
        self._buffer: str = ""            # current uncommitted pending text
        self._current_seg_id: int = 0     # id of the paragraph currently open
        self._separator = separator      # word separator ("" when none applies)

    def get_seg_id(self) -> int:
        """Id of the currently open paragraph."""
        return self._current_seg_id

    def next_seg_id(self) -> int:
        """Id the next paragraph will receive."""
        return self._current_seg_id + 1

    def update_pending_text(self, text: str) -> None:
        """Replace the pending (unconfirmed) text."""
        self._buffer = text

    def commit_line(self) -> None:
        """Promote non-empty pending text to a confirmed sentence."""
        if self._buffer:
            self._sentences.append(self._buffer)
            self._buffer = ""

    def commit_paragraph(self) -> None:
        """Promote the accumulated sentences to one confirmed paragraph."""
        if self._sentences:
            self._segments.append("".join(self._sentences))
            self._sentences.clear()

    def update_and_commit(self, stable_string: str, remaining_string: str, is_end_sentence: bool = False) -> bool:
        """Commit *stable_string* as a sentence, then stage *remaining_string*.

        When *is_end_sentence* is set and the uncommitted text has reached
        MAX_UNCOMMITTED_UNITS, the open paragraph is finalized as well.

        Returns:
            True when a paragraph was committed; False otherwise.  (The
            original implementation fell through and returned None on the
            too-short end-sentence path; False is the explicit equivalent.)
        """
        # getLogger returns the same module-level singleton logger instance.
        logging.getLogger("TranscriptionStrategy").debug("%s", self.__dict__)
        self.update_pending_text(stable_string)
        self.commit_line()
        if is_end_sentence:
            text = self.current_not_commit_text
            # Count words when a separator exists, characters otherwise.
            current_text_len = len(text.split(self._separator)) if self._separator else len(text)
            self.update_pending_text(remaining_string)
            if current_text_len >= self.MAX_UNCOMMITTED_UNITS:
                self.commit_paragraph()
                self._current_seg_id += 1
                return True
            return False
        self.update_pending_text(remaining_string)
        return False

    @property
    def un_commit_paragraph(self) -> str:
        """Sentences of the open paragraph, joined."""
        return "".join(self._sentences)

    @property
    def pending_text(self) -> str:
        """Current unconfirmed text."""
        return self._buffer

    @property
    def latest_paragraph(self) -> str:
        """Most recently confirmed paragraph, or "" if none yet."""
        return self._segments[-1] if self._segments else ""

    @property
    def current_not_commit_text(self) -> str:
        """Everything not yet part of a confirmed paragraph."""
        return self.un_commit_paragraph + self.pending_text
class TranscriptStabilityAnalyzer:
    """Decides which freshly transcribed text is stable enough to commit."""

    def __init__(self, separator) -> None:
        self._transcript_buffer = TranscriptBuffer(separator=separator)
        self._transcript_history = TranscriptHistory()
        self._separator = separator
        logger.debug("Current separator: %s", self._separator)

    def merge_chunks(self, chunks: List["TranscriptChunk"]) -> str:
        """Concatenate the joined text of every chunk."""
        return "".join(c.join() for c in chunks)

    def analysis(self, current: List["TranscriptToken"], buffer_duration: float) -> Iterator[TranscriptResult]:
        """Analyze the newest transcription pass and yield display results.

        Args:
            current: raw token list of this pass (wrapped into a chunk here;
                the previous annotation of TranscriptChunk was incorrect).
            buffer_duration: seconds of audio currently buffered.
        """
        current = TranscriptChunk(items=current, separator=self._separator)
        self._transcript_history.add(current)
        prev = self._transcript_history.previous_chunk()
        self._transcript_buffer.update_pending_text(current.join())
        if not prev:
            # No history yet: this is a fresh utterance, emit it directly.
            yield TranscriptResult(
                context=self._transcript_buffer.current_not_commit_text,
                seg_id=self._transcript_buffer.get_seg_id(),
            )
            return
        if buffer_duration <= 12:
            yield from self._handle_short_buffer(current, prev)
        else:
            yield from self._handle_long_buffer(current)

    def _handle_short_buffer(self, curr: TranscriptChunk, prev: TranscriptChunk) -> Iterator[TranscriptResult]:
        """Short buffer: commit only when the leading sub-chunk is stable
        across two consecutive passes (similarity >= 0.8)."""
        curr_first, curr_rest = curr.get_split_first_rest(SplitMode.PUNCTUATION)
        prev_first, _ = prev.get_split_first_rest(SplitMode.PUNCTUATION)
        if curr_first and prev_first:
            score = curr_first.compare(prev_first)
            if score >= 0.8:
                yield from self._yield_commit_results(
                    curr_first, curr_rest, curr_first.is_end_sentence()
                )
                return
        # Not stable yet: re-emit the current uncommitted view.
        yield TranscriptResult(
            seg_id=self._transcript_buffer.get_seg_id(),
            context=self._transcript_buffer.current_not_commit_text,
        )

    def _handle_long_buffer(self, curr: TranscriptChunk) -> Iterator[TranscriptResult]:
        """Long buffer: force-commit all but the last two punctuation chunks."""
        chunks = curr.split_by(SplitMode.PUNCTUATION)
        if len(chunks) > 2:
            stable, remaining = chunks[:-2], chunks[-2:]
            # TODO(review): is_end_sentence is hard-coded True for forced commits.
            yield from self._yield_commit_results(stable, remaining, is_end_sentence=True)
        else:
            yield TranscriptResult(
                seg_id=self._transcript_buffer.get_seg_id(),
                context=self._transcript_buffer.current_not_commit_text,
            )

    def _yield_commit_results(self, stable_chunk, remaining_chunks, is_end_sentence: bool) -> Iterator[TranscriptResult]:
        """Commit the stable text and yield the resulting display updates.

        *stable_chunk* may be a single TranscriptChunk or a list of them.
        """
        if isinstance(stable_chunk, list):
            stable_str = self.merge_chunks(stable_chunk)
            frame_cut_index = stable_chunk[-1].get_buffer_index()
        else:
            stable_str = stable_chunk.join()
            frame_cut_index = stable_chunk.get_buffer_index()
        remaining_str = self.merge_chunks(remaining_chunks)
        prev_seg_id = self._transcript_buffer.get_seg_id()
        commit_paragraph = self._transcript_buffer.update_and_commit(
            stable_str, remaining_str, is_end_sentence
        )
        logger.debug("current buffer: %s", self._transcript_buffer.__dict__)
        if commit_paragraph:
            # A paragraph was finalized: emit it, then any leftover text on a new line.
            yield TranscriptResult(
                seg_id=prev_seg_id,
                cut_index=frame_cut_index,
                context=self._transcript_buffer.latest_paragraph,
                is_end_sentence=True,
            )
            if (context := self._transcript_buffer.current_not_commit_text.strip()):
                yield TranscriptResult(
                    seg_id=self._transcript_buffer.get_seg_id(),
                    context=context,
                )
        else:
            yield TranscriptResult(
                seg_id=self._transcript_buffer.get_seg_id(),
                cut_index=frame_cut_index,
                context=self._transcript_buffer.current_not_commit_text,
            )
|