# Translator/transcribe/strategy.py
# Author: daihui.zhang
# Commit e046f39: add buffer clip via sequence strategy
from logging import getLogger
from difflib import SequenceMatcher
import collections
import config
import numpy as np
from itertools import chain
# Module-level logger; name corrected from the "Stragegy" typo so handler
# configuration keyed on the logger name resolves as expected.
logger = getLogger("Strategy")
class TripleTextBuffer:
    """Sliding window over recent transcription snapshots used to decide
    when a punctuation position in the audio buffer has become reliable.

    Despite the historical name, the window holds the last ``size``
    (default 2) entries: when two consecutive texts are similar enough,
    the buffer index recorded with the newer one is considered stable.
    """

    def __init__(self, size=2):
        # Bounded history of (text, index) pairs; older entries fall off.
        self.history = collections.deque(maxlen=size)

    def add_entry(self, text, index):
        """Record one transcription snapshot.

        Args:
            text: transcribed text.
            index: relative sample index into the current audio buffer.
        """
        self.history.append((text, index))

    def get_final_index(self, similarity_threshold=0.7):
        """Return a reliable punctuation index in the buffer, or ``None``.

        Compares the two most recent texts; when their similarity reaches
        ``similarity_threshold``, the history is cleared and the index
        recorded with the newer text is returned.
        """
        if len(self.history) < 2:
            return None
        text1, _ = self.history[0]
        text2, idx2 = self.history[1]
        sim_12 = self.text_similarity(text1, text2)
        if sim_12 >= similarity_threshold:
            # Texts have stabilized: commit this index and start fresh.
            self.history.clear()
            return idx2
        return None

    @staticmethod
    def text_similarity(text1, text2):
        """Similarity ratio in [0, 1] between the two strings."""
        return SequenceMatcher(None, text1, text2).ratio()
class SegmentManager:
    """Accumulates transcribed text through three stages: an in-flight
    temporary string, committed short sentences, and committed full
    segments (paragraphs)."""

    def __init__(self) -> None:
        self._commited_segments = []         # finalized paragraphs
        self._commited_short_sentences = []  # finalized short sentences
        self._temp_string = ""               # in-flight text awaiting a commit

    def handle(self, string):
        """Stage *string* as the current temporary text; returns self (fluent)."""
        self._temp_string = string
        return self

    @property
    def short_sentence(self) -> str:
        """All committed short sentences joined into a single string."""
        return "".join(self._commited_short_sentences)

    @property
    def segment(self):
        """The most recently committed paragraph, or "" when none exist."""
        if self._commited_segments:
            return self._commited_segments[-1]
        return ""

    def get_seg_id(self):
        """Number of committed paragraphs (equivalently, the next paragraph id)."""
        return len(self._commited_segments)

    @property
    def string(self):
        """The current temporary (uncommitted) text."""
        return self._temp_string

    def commit_short_sentence(self):
        """Move the temporary string into the short-sentence list and clear it."""
        self._commited_short_sentences.append(self._temp_string)
        self._temp_string = ""

    def commit_segment(self):
        """Merge all committed short sentences into a new paragraph."""
        self._commited_segments.append(self.short_sentence)
        self._commited_short_sentences = []

    def commit(self, is_end_sentence=False):
        """Commit the temporary string as a short sentence; when
        *is_end_sentence* is true, also finalize the current paragraph."""
        self.commit_short_sentence()
        if is_end_sentence:
            self.commit_segment()
def segement_merge(segments):
    """Group consecutive segments into full sentences by end punctuation.

    A sentence closes at any segment whose text contains one of the
    configured sentence-ending markers.

    Args:
        segments: iterable of segment objects exposing a ``.text`` attribute.

    Returns:
        list[list]: one inner list per sentence; a trailing incomplete
        sentence (no end marker seen yet) is appended as the last group.
    """
    sequences = []
    temp_seq = []
    for seg in segments:
        temp_seq.append(seg)
        # Generator form avoids materializing a throwaway list for any().
        if any(mk in seg.text for mk in config.SENTENCE_END_MARKERS):
            # No defensive copy needed: temp_seq is rebound, not mutated.
            sequences.append(temp_seq)
            temp_seq = []
    if temp_seq:
        sequences.append(temp_seq)
    return sequences
def segments_split(segments, audio_buffer: np.ndarray, sample_rate=16000):
    """Split *segments* at the leftmost usable punctuation mark.

    Walks segments left to right; when a segment's text is exactly one of
    the active punctuation markers and at least 1.5 s of audio remains
    after it, that sample position becomes the cut point.

    Args:
        segments: transcription segments with ``.text`` and ``.t1``
            (end time; ``t1 / 100`` is treated as seconds — presumably
            centiseconds, TODO confirm against the transcriber).
        audio_buffer: raw audio samples backing the segments.
        sample_rate: samples per second of ``audio_buffer``.

    Returns:
        tuple: (cut sample index — 0 when no cut was found, segments up to
        and including the marker, segments after the marker, is_end —
        currently always False).
    """
    left_watch_sequences = []
    left_watch_idx = 0
    right_watch_sequences = []
    is_end = False
    # Bug fix: `markers` used to be bound only in the short-buffer branch,
    # so buffers of 12 s or longer raised NameError below. Long buffers now
    # fall back to sentence-ending punctuation.
    markers = config.SENTENCE_END_MARKERS
    if (len(audio_buffer) / sample_rate) < 12:
        # Under 12 s, weaker pause punctuation (e.g. commas) is sufficient.
        markers = config.PAUSE_END_MARKERS
    for idx, seg in enumerate(segments):
        left_watch_sequences.append(seg)
        # NOTE: the segment's whole text must equal a marker (punctuation-only
        # segment), unlike segement_merge which searches inside the text.
        if seg.text in markers:
            seg_index = int(seg.t1 / 100 * sample_rate)
            rest_buffer_duration = (len(audio_buffer) - seg_index) / sample_rate
            right_watch_sequences = segments[idx + 1:]
            if rest_buffer_duration >= 1.5:
                left_watch_idx = seg_index
                break
    return left_watch_idx, left_watch_sequences, right_watch_sequences, is_end
def sequences_split(segments, audio_buffer: np.ndarray, sample_rate=16000):
    """Keep only the last two full sentences; mark older ones for clipping.

    Sentences are built with segement_merge(). When more than two complete
    sentences exist, everything before the last two becomes the left
    (clippable) part and the cut index is the end time of the final clipped
    segment.

    Returns:
        tuple: (cut sample index — 0 when nothing to clip, clipped segments
        iterator, retained segments iterator, whether a clip should happen).
    """
    left_watch_idx = 0
    left_watch_sequences = []
    right_watch_sequences = []
    is_end = False

    sequences = segement_merge(segments)
    if len(sequences) > 2:
        logger.info(f"buffer clip via sequence, current length: {len(sequences)}")
        is_end = True
        clipped = sequences[:-2]
        kept = sequences[-2:]
        left_watch_sequences = chain.from_iterable(clipped)
        right_watch_sequences = chain.from_iterable(kept)
        # Cut at the end time of the last segment of the last clipped sentence.
        boundary_segment = clipped[-1][-1]
        left_watch_idx = int(boundary_segment.t1 / 100 * sample_rate)
    return left_watch_idx, left_watch_sequences, right_watch_sequences, is_end