Merge branch 'vad'
Browse files
* vad:
[fix]: logging level.
add string replace
[fix]: parameter.
[fix]: hot words.
[fix]: update.
fix 'transcrible' naming error
[fix]: hot words.
update some keywords
add speech start padding 100ms
[fix]: words.
fix max speech duration bug
remove time delay in loop
add DESIGN_TIME_THREHOLD
添加热词文件路径配置,并在生成模型时使用热词参数。
Disable FunASR pbar in Warmup.
update log level
remove unused codes
remove unused codes
add log to debug silence ms
# Conflicts:
# transcribe/pipelines/pipe_vad.py
- api_model.py +2 -2
- config.py +17 -24
- main.py +1 -5
- moyoyo_asr_models/hotwords.json +7 -0
- moyoyo_asr_models/hotwords.txt +34 -0
- tests/audio_utils.py +54 -0
- tests/test_vad.ipynb +129 -0
- transcribe/client.py +0 -677
- transcribe/helpers/funasr.py +5 -8
- transcribe/helpers/vadprocessor.py +8 -8
- transcribe/pipelines/pipe_vad.py +5 -32
- transcribe/server.py +0 -382
- transcribe/strategy.py +0 -405
- transcribe/transcription.py +0 -334
- transcribe/translatepipes.py +3 -14
- transcribe/utils.py +37 -12
- transcribe/whisper_llm_serve.py +75 -162
api_model.py
CHANGED
|
@@ -18,9 +18,9 @@ class TransResult(BaseModel):
|
|
| 18 |
class DebugResult(BaseModel):
|
| 19 |
# trans_pattern: str
|
| 20 |
seg_id: int
|
| 21 |
-
|
| 22 |
translate_time:float
|
| 23 |
-
context: str = Field(alias="
|
| 24 |
from_: str = Field(alias="from")
|
| 25 |
to: str
|
| 26 |
tran_content: str = Field(alias="translateContent")
|
|
|
|
| 18 |
class DebugResult(BaseModel):
|
| 19 |
# trans_pattern: str
|
| 20 |
seg_id: int
|
| 21 |
+
transcribe_time: float
|
| 22 |
translate_time:float
|
| 23 |
+
context: str = Field(alias="transcribeContent")
|
| 24 |
from_: str = Field(alias="from")
|
| 25 |
to: str
|
| 26 |
tran_content: str = Field(alias="translateContent")
|
config.py
CHANGED
|
@@ -1,12 +1,15 @@
|
|
| 1 |
import pathlib
|
| 2 |
import re
|
| 3 |
import logging
|
|
|
|
| 4 |
|
| 5 |
-
|
|
|
|
|
|
|
| 6 |
|
| 7 |
logging.getLogger("pywhispercpp").setLevel(logging.WARNING)
|
| 8 |
logging.basicConfig(
|
| 9 |
-
level=
|
| 10 |
format="%(asctime)s - %(levelname)s - %(message)s",
|
| 11 |
filename='translator.log',
|
| 12 |
datefmt="%H:%M:%S"
|
|
@@ -15,13 +18,15 @@ logging.basicConfig(
|
|
| 15 |
SAVE_DATA_SAVE = False
|
| 16 |
# Add terminal log
|
| 17 |
console_handler = logging.StreamHandler()
|
| 18 |
-
console_handler.setLevel(
|
| 19 |
console_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
| 20 |
console_handler.setFormatter(console_formatter)
|
| 21 |
logging.getLogger().addHandler(console_handler)
|
| 22 |
|
| 23 |
-
#
|
| 24 |
-
|
|
|
|
|
|
|
| 25 |
|
| 26 |
BASE_DIR = pathlib.Path(__file__).parent
|
| 27 |
MODEL_DIR = BASE_DIR / "moyoyo_asr_models"
|
|
@@ -29,7 +34,7 @@ ASSERT_DIR = BASE_DIR / "assets"
|
|
| 29 |
|
| 30 |
SAMPLE_RATE = 16000
|
| 31 |
# 标点
|
| 32 |
-
SENTENCE_END_MARKERS =
|
| 33 |
PAUSE_END_MARKERS = [',', ',', '、']
|
| 34 |
# 合并所有标点
|
| 35 |
ALL_MARKERS = SENTENCE_END_MARKERS + PAUSE_END_MARKERS
|
|
@@ -41,13 +46,13 @@ SENTENCE_END_PATTERN = re.compile(f'[{sentence_end_chars}]')
|
|
| 41 |
|
| 42 |
# Method 2: Alternative approach with a character class
|
| 43 |
pattern_string = '[' + ''.join([re.escape(char) for char in PAUSE_END_MARKERS]) + r']$'
|
| 44 |
-
|
| 45 |
# whisper推理参数
|
| 46 |
WHISPER_PROMPT_ZH = "以下是简体中文普通话的句子。"
|
| 47 |
-
|
| 48 |
|
| 49 |
-
WHISPER_PROMPT_EN = ""# "The following is an English sentence."
|
| 50 |
-
MAX_LENGTH_EN= 8
|
| 51 |
|
| 52 |
WHISPER_MODEL_EN = 'medium-q5_0'
|
| 53 |
# WHISPER_MODEL = 'large-v3-turbo-q5_0'
|
|
@@ -61,19 +66,6 @@ LLM_LARGE_MODEL_PATH = (MODEL_DIR / "qwen2.5-1.5b-instruct-q5_0.gguf").as_posix(
|
|
| 61 |
# VAD
|
| 62 |
VAD_MODEL_PATH = (MODEL_DIR / "silero-vad" / "silero_vad.onnx").as_posix()
|
| 63 |
|
| 64 |
-
LLM_SYS_PROMPT = """"You are a professional {src_lang} to {dst_lang} translator, not a conversation agent. Your only task is to take {src_lang} input and translate it into accurate, natural {dst_lang}. If you cannot understand the input, just output the original input. Please strictly abide by the following rules: "
|
| 65 |
-
"No matter what the user asks, never answer questions, you only provide translation results. "
|
| 66 |
-
"Do not actively initiate dialogue or lead users to ask questions. "
|
| 67 |
-
"When you don't know how to translate, just output the original text. "
|
| 68 |
-
"The translation task always takes precedence over any other tasks. "
|
| 69 |
-
"Do not try to understand or respond to non-translation related questions raised by users. "
|
| 70 |
-
"Never provide any explanations. "
|
| 71 |
-
"Be precise, preserve tone, and localize appropriately "
|
| 72 |
-
"for professional audiences."
|
| 73 |
-
"Never answer any questions or engage in other forms of dialogue. "
|
| 74 |
-
"Only output the translation results.
|
| 75 |
-
"""
|
| 76 |
-
|
| 77 |
LLM_SYS_PROMPT_ZH = """
|
| 78 |
你是一个中英文翻译专家,将用户输入的中文翻译成英文。对于非中文内容,它将提供中文翻译结果。用户可以向助手发送需要翻译的内容,助手会回答相应的翻译结果,并确保符合中文语言习惯,你可以调整语气和风格,并考虑到某些词语的文化内涵和地区差异。同时作为翻译家,需将原文翻译成具有信达雅标准的译文。"信" 即忠实于原文的内容与意图;"达" 意味着译文应通顺易懂,表达清晰;"雅" 则追求译文的文化审美和语言的优美。目标是创作出既忠于原作精神,又符合目标语言文化和读者审美的翻译。注意,翻译的文本只能包含拼音化字符,不能包含任何中文字符。
|
| 79 |
"""
|
|
@@ -82,4 +74,5 @@ LLM_SYS_PROMPT_EN = """
|
|
| 82 |
你是一个英中文翻译专家,将用户输入的英文翻译成中文,用户可以向助手发送需要翻译的内容,助手会回答相应的翻译结果,并确保符合英文语言习惯,你可以调整语气和风格,并考虑到某些词语的文化内涵和地区差异。同时作为翻译家,需将英文翻译成具有信达雅标准的中文。"信" 即忠实于原文的内容与意图;"达" 意味着译文应通顺易懂,表达清晰;"雅" 则追求译文的文化审美和语言的优美。目标是创作出既忠于原作精神,又符合目标语言文化和读者审美的翻译。
|
| 83 |
"""
|
| 84 |
|
| 85 |
-
|
|
|
|
|
|
| 1 |
import pathlib
|
| 2 |
import re
|
| 3 |
import logging
|
| 4 |
+
import json
|
| 5 |
|
| 6 |
+
|
| 7 |
+
DEBUG = False
|
| 8 |
+
LOG_LEVEL = logging.DEBUG if DEBUG else logging.WARNING
|
| 9 |
|
| 10 |
logging.getLogger("pywhispercpp").setLevel(logging.WARNING)
|
| 11 |
logging.basicConfig(
|
| 12 |
+
level=LOG_LEVEL,
|
| 13 |
format="%(asctime)s - %(levelname)s - %(message)s",
|
| 14 |
filename='translator.log',
|
| 15 |
datefmt="%H:%M:%S"
|
|
|
|
| 18 |
SAVE_DATA_SAVE = False
|
| 19 |
# Add terminal log
|
| 20 |
console_handler = logging.StreamHandler()
|
| 21 |
+
console_handler.setLevel(LOG_LEVEL)
|
| 22 |
console_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
| 23 |
console_handler.setFormatter(console_formatter)
|
| 24 |
logging.getLogger().addHandler(console_handler)
|
| 25 |
|
| 26 |
+
# 音频段的决策时间
|
| 27 |
+
FRAME_SCOPE_TIME_THRESHOLD = 4
|
| 28 |
+
# 最长语音时长
|
| 29 |
+
MAX_SPEECH_DURATION_S = 15
|
| 30 |
|
| 31 |
BASE_DIR = pathlib.Path(__file__).parent
|
| 32 |
MODEL_DIR = BASE_DIR / "moyoyo_asr_models"
|
|
|
|
| 34 |
|
| 35 |
SAMPLE_RATE = 16000
|
| 36 |
# 标点
|
| 37 |
+
SENTENCE_END_MARKERS = ['.', '!', '?', '。', '!', '?', ';', ';', ':', ':']
|
| 38 |
PAUSE_END_MARKERS = [',', ',', '、']
|
| 39 |
# 合并所有标点
|
| 40 |
ALL_MARKERS = SENTENCE_END_MARKERS + PAUSE_END_MARKERS
|
|
|
|
| 46 |
|
| 47 |
# Method 2: Alternative approach with a character class
|
| 48 |
pattern_string = '[' + ''.join([re.escape(char) for char in PAUSE_END_MARKERS]) + r']$'
|
| 49 |
+
PAUSE_END_PATTERN = re.compile(pattern_string)
|
| 50 |
# whisper推理参数
|
| 51 |
WHISPER_PROMPT_ZH = "以下是简体中文普通话的句子。"
|
| 52 |
+
MAX_LENGTH_ZH = 4
|
| 53 |
|
| 54 |
+
WHISPER_PROMPT_EN = "" # "The following is an English sentence."
|
| 55 |
+
MAX_LENGTH_EN = 8
|
| 56 |
|
| 57 |
WHISPER_MODEL_EN = 'medium-q5_0'
|
| 58 |
# WHISPER_MODEL = 'large-v3-turbo-q5_0'
|
|
|
|
| 66 |
# VAD
|
| 67 |
VAD_MODEL_PATH = (MODEL_DIR / "silero-vad" / "silero_vad.onnx").as_posix()
|
| 68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
LLM_SYS_PROMPT_ZH = """
|
| 70 |
你是一个中英文翻译专家,将用户输入的中文翻译成英文。对于非中文内容,它将提供中文翻译结果。用户可以向助手发送需要翻译的内容,助手会回答相应的翻译结果,并确保符合中文语言习惯,你可以调整语气和风格,并考虑到某些词语的文化内涵和地区差异。同时作为翻译家,需将原文翻译成具有信达雅标准的译文。"信" 即忠实于原文的内容与意图;"达" 意味着译文应通顺易懂,表达清晰;"雅" 则追求译文的文化审美和语言的优美。目标是创作出既忠于原作精神,又符合目标语言文化和读者审美的翻译。注意,翻译的文本只能包含拼音化字符,不能包含任何中文字符。
|
| 71 |
"""
|
|
|
|
| 74 |
你是一个英中文翻译专家,将用户输入的英文翻译成中文,用户可以向助手发送需要翻译的内容,助手会回答相应的翻译结果,并确保符合英文语言习惯,你可以调整语气和风格,并考虑到某些词语的文化内涵和地区差异。同时作为翻译家,需将英文翻译成具有信达雅标准的中文。"信" 即忠实于原文的内容与意图;"达" 意味着译文应通顺易懂,表达清晰;"雅" 则追求译文的文化审美和语言的优美。目标是创作出既忠于原作精神,又符合目标语言文化和读者审美的翻译。
|
| 75 |
"""
|
| 76 |
|
| 77 |
+
hotwords_file = MODEL_DIR / 'hotwords.txt'
|
| 78 |
+
hotwords_json = json.loads((MODEL_DIR / 'hotwords.json').read_text())
|
main.py
CHANGED
|
@@ -11,6 +11,7 @@ from fastapi.staticfiles import StaticFiles
|
|
| 11 |
from fastapi.responses import RedirectResponse
|
| 12 |
import os
|
| 13 |
from transcribe.utils import pcm_bytes_to_np_array
|
|
|
|
| 14 |
logger = getLogger(__name__)
|
| 15 |
|
| 16 |
|
|
@@ -39,9 +40,6 @@ async def lifespan(app:FastAPI):
|
|
| 39 |
yield
|
| 40 |
|
| 41 |
|
| 42 |
-
# 获取当前文件所在目录的绝对路径
|
| 43 |
-
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 44 |
-
# 构建frontend目录的绝对路径
|
| 45 |
FRONTEND_DIR = os.path.join(BASE_DIR, "frontend")
|
| 46 |
|
| 47 |
|
|
@@ -66,9 +64,7 @@ async def translate(websocket: WebSocket):
|
|
| 66 |
client_uid=f"{uuid1()}",
|
| 67 |
)
|
| 68 |
|
| 69 |
-
|
| 70 |
if from_lang and to_lang and client:
|
| 71 |
-
client.set_language(from_lang, to_lang)
|
| 72 |
logger.info(f"Source lange: {from_lang} -> Dst lange: {to_lang}")
|
| 73 |
await websocket.accept()
|
| 74 |
try:
|
|
|
|
| 11 |
from fastapi.responses import RedirectResponse
|
| 12 |
import os
|
| 13 |
from transcribe.utils import pcm_bytes_to_np_array
|
| 14 |
+
from config import BASE_DIR
|
| 15 |
logger = getLogger(__name__)
|
| 16 |
|
| 17 |
|
|
|
|
| 40 |
yield
|
| 41 |
|
| 42 |
|
|
|
|
|
|
|
|
|
|
| 43 |
FRONTEND_DIR = os.path.join(BASE_DIR, "frontend")
|
| 44 |
|
| 45 |
|
|
|
|
| 64 |
client_uid=f"{uuid1()}",
|
| 65 |
)
|
| 66 |
|
|
|
|
| 67 |
if from_lang and to_lang and client:
|
|
|
|
| 68 |
logger.info(f"Source lange: {from_lang} -> Dst lange: {to_lang}")
|
| 69 |
await websocket.accept()
|
| 70 |
try:
|
moyoyo_asr_models/hotwords.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"高斯姆": "GOSIM",
|
| 3 |
+
"GO SIM": "GOSIM",
|
| 4 |
+
"go sim": "GOSIM",
|
| 5 |
+
"GO SAME": "GOSIM",
|
| 6 |
+
"go same": "GOSIM"
|
| 7 |
+
}
|
moyoyo_asr_models/hotwords.txt
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
GOSIM
|
| 2 |
+
CSDN
|
| 3 |
+
Rust
|
| 4 |
+
git
|
| 5 |
+
lib
|
| 6 |
+
HUAWEI
|
| 7 |
+
Futurewei
|
| 8 |
+
Cloud
|
| 9 |
+
OpenAI
|
| 10 |
+
PYTHON
|
| 11 |
+
千问
|
| 12 |
+
鸿蒙
|
| 13 |
+
vLLM
|
| 14 |
+
MiniCPM
|
| 15 |
+
ChatGPT
|
| 16 |
+
GPT
|
| 17 |
+
GPT2
|
| 18 |
+
GPT3
|
| 19 |
+
GPT4
|
| 20 |
+
Llama
|
| 21 |
+
Llama2
|
| 22 |
+
Llama3
|
| 23 |
+
MISTRAL
|
| 24 |
+
Large
|
| 25 |
+
Mistral
|
| 26 |
+
Small
|
| 27 |
+
LoRA
|
| 28 |
+
finetune
|
| 29 |
+
quantization
|
| 30 |
+
pruning
|
| 31 |
+
MoXIN
|
| 32 |
+
Function
|
| 33 |
+
Func
|
| 34 |
+
Lava
|
tests/audio_utils.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import soundfile as sf
|
| 3 |
+
import time
|
| 4 |
+
|
| 5 |
+
def audio_stream_generator(audio_file_path, chunk_size=4096, simulate_realtime=True):
|
| 6 |
+
"""
|
| 7 |
+
音频流生成器,从音频文件中读取数据并以流的方式输出
|
| 8 |
+
|
| 9 |
+
参数:
|
| 10 |
+
audio_file_path: 音频文件路径
|
| 11 |
+
chunk_size: 每个数据块的大小(采样点数)
|
| 12 |
+
simulate_realtime: 是否模拟实时流处理的速度
|
| 13 |
+
|
| 14 |
+
生成:
|
| 15 |
+
numpy.ndarray: 每次生成一个chunk_size大小的np.float32数据块
|
| 16 |
+
"""
|
| 17 |
+
# 加载音频文件
|
| 18 |
+
audio_data, sample_rate = sf.read(audio_file_path)
|
| 19 |
+
|
| 20 |
+
# 确保音频数据是float32类型
|
| 21 |
+
if audio_data.dtype != np.float32:
|
| 22 |
+
audio_data = audio_data.astype(np.float32)
|
| 23 |
+
|
| 24 |
+
# 如果是立体声,转换为单声道
|
| 25 |
+
if len(audio_data.shape) > 1 and audio_data.shape[1] > 1:
|
| 26 |
+
audio_data = audio_data.mean(axis=1)
|
| 27 |
+
|
| 28 |
+
print(f"已加载音频文件: {audio_file_path}")
|
| 29 |
+
print(f"采样率: {sample_rate} Hz")
|
| 30 |
+
print(f"音频长度: {len(audio_data)/sample_rate:.2f} 秒")
|
| 31 |
+
|
| 32 |
+
# 计算每个块的时长(秒)
|
| 33 |
+
chunk_duration = chunk_size / sample_rate if simulate_realtime else 0
|
| 34 |
+
|
| 35 |
+
# 按块生成数据
|
| 36 |
+
audio_len = len(audio_data)
|
| 37 |
+
for pos in range(0, audio_len, chunk_size):
|
| 38 |
+
# 获取当前块
|
| 39 |
+
end_pos = min(pos + chunk_size, audio_len)
|
| 40 |
+
chunk = audio_data[pos:end_pos]
|
| 41 |
+
|
| 42 |
+
# 如果块大小不足,用0填充
|
| 43 |
+
if len(chunk) < chunk_size:
|
| 44 |
+
padded_chunk = np.zeros(chunk_size, dtype=np.float32)
|
| 45 |
+
padded_chunk[:len(chunk)] = chunk
|
| 46 |
+
chunk = padded_chunk
|
| 47 |
+
|
| 48 |
+
# 模拟实时处理的延迟
|
| 49 |
+
if simulate_realtime:
|
| 50 |
+
time.sleep(chunk_duration)
|
| 51 |
+
|
| 52 |
+
yield chunk
|
| 53 |
+
|
| 54 |
+
print("音频流处理完成")
|
tests/test_vad.ipynb
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 2,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"from audio_utils import audio_stream_generator\n",
|
| 10 |
+
"import IPython.display as ipd\n",
|
| 11 |
+
"import sys\n",
|
| 12 |
+
"sys.path.append(\"..\")\n",
|
| 13 |
+
"from transcribe.helpers.vadprocessor import FixedVADIterator\n"
|
| 14 |
+
]
|
| 15 |
+
},
|
| 16 |
+
{
|
| 17 |
+
"cell_type": "code",
|
| 18 |
+
"execution_count": 3,
|
| 19 |
+
"metadata": {},
|
| 20 |
+
"outputs": [],
|
| 21 |
+
"source": [
|
| 22 |
+
"vac = FixedVADIterator(\n",
|
| 23 |
+
" threshold=0.5,\n",
|
| 24 |
+
" sampling_rate=16000,\n",
|
| 25 |
+
" # speech_pad_ms=10\n",
|
| 26 |
+
" min_silence_duration_ms = 100,\n",
|
| 27 |
+
" # speech_pad_ms = 30,\n",
|
| 28 |
+
" max_speech_duration_s=5.0,\n",
|
| 29 |
+
" )\n"
|
| 30 |
+
]
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"cell_type": "code",
|
| 34 |
+
"execution_count": 10,
|
| 35 |
+
"metadata": {},
|
| 36 |
+
"outputs": [],
|
| 37 |
+
"source": [
|
| 38 |
+
"SAMPLE_FILE_PATH = \"/Users/david/Samples/Audio/zh/liyongle.wav\"\n",
|
| 39 |
+
"SAMPLING_RATE = 16000\n",
|
| 40 |
+
"\n",
|
| 41 |
+
"chunks_generator = audio_stream_generator(SAMPLE_FILE_PATH, chunk_size=4096)\n",
|
| 42 |
+
"vac.reset_states()"
|
| 43 |
+
]
|
| 44 |
+
},
|
| 45 |
+
{
|
| 46 |
+
"cell_type": "code",
|
| 47 |
+
"execution_count": 11,
|
| 48 |
+
"metadata": {},
|
| 49 |
+
"outputs": [
|
| 50 |
+
{
|
| 51 |
+
"name": "stdout",
|
| 52 |
+
"output_type": "stream",
|
| 53 |
+
"text": [
|
| 54 |
+
"已加载音频文件: /Users/david/Samples/Audio/zh/liyongle.wav\n",
|
| 55 |
+
"采样率: 16000 Hz\n",
|
| 56 |
+
"音频长度: 64.00 秒\n",
|
| 57 |
+
"{'start': 3616}\n",
|
| 58 |
+
"{'end': 83968}\n",
|
| 59 |
+
"{'end': 164352}\n",
|
| 60 |
+
"{'end': 244736}\n",
|
| 61 |
+
"{'end': 325120}\n",
|
| 62 |
+
"{'end': 405504}\n",
|
| 63 |
+
"{'end': 485888}\n",
|
| 64 |
+
"{'end': 566272}\n",
|
| 65 |
+
"{'end': 624608}\n",
|
| 66 |
+
"{'start': 631328}\n",
|
| 67 |
+
"{'end': 691168}\n",
|
| 68 |
+
"{'start': 698912}\n",
|
| 69 |
+
"{'end': 779264}\n",
|
| 70 |
+
"{'end': 800736}\n",
|
| 71 |
+
"{'start': 805920}\n",
|
| 72 |
+
"{'end': 846816}\n",
|
| 73 |
+
"{'start': 855072}\n",
|
| 74 |
+
"{'end': 862176}\n",
|
| 75 |
+
"{'start': 864288}\n",
|
| 76 |
+
"{'end': 890336}\n",
|
| 77 |
+
"{'start': 893984}\n",
|
| 78 |
+
"{'end': 912352}\n",
|
| 79 |
+
"{'start': 917536}\n",
|
| 80 |
+
"{'end': 932320}\n",
|
| 81 |
+
"{'start': 939040}\n",
|
| 82 |
+
"{'end': 966112}\n",
|
| 83 |
+
"{'start': 970784}\n",
|
| 84 |
+
"{'end': 1015264}\n",
|
| 85 |
+
"{'start': 1019424}\n",
|
| 86 |
+
"音频流处理完成\n"
|
| 87 |
+
]
|
| 88 |
+
}
|
| 89 |
+
],
|
| 90 |
+
"source": [
|
| 91 |
+
"for chunk in chunks_generator:\n",
|
| 92 |
+
" # vad_iterator.reset_states()\n",
|
| 93 |
+
" # audio_buffer = np.append(audio_buffer, chunk)\n",
|
| 94 |
+
" \n",
|
| 95 |
+
" speech_dict = vac(chunk, return_seconds=False)\n",
|
| 96 |
+
" if speech_dict:\n",
|
| 97 |
+
" print(speech_dict)"
|
| 98 |
+
]
|
| 99 |
+
},
|
| 100 |
+
{
|
| 101 |
+
"cell_type": "code",
|
| 102 |
+
"execution_count": null,
|
| 103 |
+
"metadata": {},
|
| 104 |
+
"outputs": [],
|
| 105 |
+
"source": []
|
| 106 |
+
}
|
| 107 |
+
],
|
| 108 |
+
"metadata": {
|
| 109 |
+
"kernelspec": {
|
| 110 |
+
"display_name": ".venv",
|
| 111 |
+
"language": "python",
|
| 112 |
+
"name": "python3"
|
| 113 |
+
},
|
| 114 |
+
"language_info": {
|
| 115 |
+
"codemirror_mode": {
|
| 116 |
+
"name": "ipython",
|
| 117 |
+
"version": 3
|
| 118 |
+
},
|
| 119 |
+
"file_extension": ".py",
|
| 120 |
+
"mimetype": "text/x-python",
|
| 121 |
+
"name": "python",
|
| 122 |
+
"nbconvert_exporter": "python",
|
| 123 |
+
"pygments_lexer": "ipython3",
|
| 124 |
+
"version": "3.11.11"
|
| 125 |
+
}
|
| 126 |
+
},
|
| 127 |
+
"nbformat": 4,
|
| 128 |
+
"nbformat_minor": 2
|
| 129 |
+
}
|
transcribe/client.py
DELETED
|
@@ -1,677 +0,0 @@
|
|
| 1 |
-
import json
|
| 2 |
-
import os
|
| 3 |
-
import shutil
|
| 4 |
-
import threading
|
| 5 |
-
import time
|
| 6 |
-
import uuid
|
| 7 |
-
import wave
|
| 8 |
-
|
| 9 |
-
import av
|
| 10 |
-
import numpy as np
|
| 11 |
-
import pyaudio
|
| 12 |
-
import websocket
|
| 13 |
-
|
| 14 |
-
import transcribe.utils as utils
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
class Client:
|
| 18 |
-
"""
|
| 19 |
-
Handles communication with a server using WebSocket.
|
| 20 |
-
"""
|
| 21 |
-
INSTANCES = {}
|
| 22 |
-
END_OF_AUDIO = "END_OF_AUDIO"
|
| 23 |
-
|
| 24 |
-
def __init__(
|
| 25 |
-
self,
|
| 26 |
-
host=None,
|
| 27 |
-
port=None,
|
| 28 |
-
lang=None,
|
| 29 |
-
log_transcription=True,
|
| 30 |
-
max_clients=4,
|
| 31 |
-
max_connection_time=600,
|
| 32 |
-
dst_lang='zh',
|
| 33 |
-
):
|
| 34 |
-
"""
|
| 35 |
-
Initializes a Client instance for audio recording and streaming to a server.
|
| 36 |
-
|
| 37 |
-
If host and port are not provided, the WebSocket connection will not be established.
|
| 38 |
-
the audio recording starts immediately upon initialization.
|
| 39 |
-
|
| 40 |
-
Args:
|
| 41 |
-
host (str): The hostname or IP address of the server.
|
| 42 |
-
port (int): The port number for the WebSocket server.
|
| 43 |
-
lang (str, optional): The selected language for transcription. Default is None.
|
| 44 |
-
log_transcription (bool, optional): Whether to log transcription output to the console. Default is True.
|
| 45 |
-
max_clients (int, optional): Maximum number of client connections allowed. Default is 4.
|
| 46 |
-
max_connection_time (int, optional): Maximum allowed connection time in seconds. Default is 600.
|
| 47 |
-
"""
|
| 48 |
-
self.recording = False
|
| 49 |
-
self.uid = str(uuid.uuid4())
|
| 50 |
-
self.waiting = False
|
| 51 |
-
self.last_response_received = None
|
| 52 |
-
self.disconnect_if_no_response_for = 15
|
| 53 |
-
self.language = lang
|
| 54 |
-
self.server_error = False
|
| 55 |
-
self.last_segment = None
|
| 56 |
-
self.last_received_segment = None
|
| 57 |
-
self.log_transcription = log_transcription
|
| 58 |
-
self.max_clients = max_clients
|
| 59 |
-
self.max_connection_time = max_connection_time
|
| 60 |
-
self.dst_lang = dst_lang
|
| 61 |
-
|
| 62 |
-
self.audio_bytes = None
|
| 63 |
-
|
| 64 |
-
if host is not None and port is not None:
|
| 65 |
-
socket_url = f"ws://{host}:{port}?from={self.language}&to={self.dst_lang}"
|
| 66 |
-
self.client_socket = websocket.WebSocketApp(
|
| 67 |
-
socket_url,
|
| 68 |
-
on_open=lambda ws: self.on_open(ws),
|
| 69 |
-
on_message=lambda ws, message: self.on_message(ws, message),
|
| 70 |
-
on_error=lambda ws, error: self.on_error(ws, error),
|
| 71 |
-
on_close=lambda ws, close_status_code, close_msg: self.on_close(
|
| 72 |
-
ws, close_status_code, close_msg
|
| 73 |
-
),
|
| 74 |
-
)
|
| 75 |
-
else:
|
| 76 |
-
print("[ERROR]: No host or port specified.")
|
| 77 |
-
return
|
| 78 |
-
|
| 79 |
-
Client.INSTANCES[self.uid] = self
|
| 80 |
-
|
| 81 |
-
# start websocket client in a thread
|
| 82 |
-
self.ws_thread = threading.Thread(target=self.client_socket.run_forever)
|
| 83 |
-
self.ws_thread.daemon = True
|
| 84 |
-
self.ws_thread.start()
|
| 85 |
-
|
| 86 |
-
self.transcript = []
|
| 87 |
-
print("[INFO]: * recording")
|
| 88 |
-
|
| 89 |
-
def handle_status_messages(self, message_data):
|
| 90 |
-
"""Handles server status messages."""
|
| 91 |
-
status = message_data["status"]
|
| 92 |
-
if status == "WAIT":
|
| 93 |
-
self.waiting = True
|
| 94 |
-
print(f"[INFO]: Server is full. Estimated wait time {round(message_data['message'])} minutes.")
|
| 95 |
-
elif status == "ERROR":
|
| 96 |
-
print(f"Message from Server: {message_data['message']}")
|
| 97 |
-
self.server_error = True
|
| 98 |
-
elif status == "WARNING":
|
| 99 |
-
print(f"Message from Server: {message_data['message']}")
|
| 100 |
-
|
| 101 |
-
def process_segments(self, segments):
|
| 102 |
-
"""Processes transcript segments."""
|
| 103 |
-
text = []
|
| 104 |
-
for i, seg in enumerate(segments):
|
| 105 |
-
if not text or text[-1] != seg["text"]:
|
| 106 |
-
text.append(seg["text"])
|
| 107 |
-
if i == len(segments) - 1 and not seg.get("completed", False):
|
| 108 |
-
self.last_segment = seg
|
| 109 |
-
|
| 110 |
-
# update last received segment and last valid response time
|
| 111 |
-
if self.last_received_segment is None or self.last_received_segment != segments[-1]["text"]:
|
| 112 |
-
self.last_response_received = time.time()
|
| 113 |
-
self.last_received_segment = segments[-1]["text"]
|
| 114 |
-
|
| 115 |
-
if self.log_transcription:
|
| 116 |
-
# Truncate to last 3 entries for brevity.
|
| 117 |
-
text = text[-3:]
|
| 118 |
-
utils.clear_screen()
|
| 119 |
-
utils.print_transcript(text)
|
| 120 |
-
|
| 121 |
-
def on_message(self, ws, message):
|
| 122 |
-
"""
|
| 123 |
-
Callback function called when a message is received from the server.
|
| 124 |
-
|
| 125 |
-
It updates various attributes of the client based on the received message, including
|
| 126 |
-
recording status, language detection, and server messages. If a disconnect message
|
| 127 |
-
is received, it sets the recording status to False.
|
| 128 |
-
|
| 129 |
-
Args:
|
| 130 |
-
ws (websocket.WebSocketApp): The WebSocket client instance.
|
| 131 |
-
message (str): The received message from the server.
|
| 132 |
-
|
| 133 |
-
"""
|
| 134 |
-
message = json.loads(message)
|
| 135 |
-
|
| 136 |
-
# if self.uid != message.get("uid"):
|
| 137 |
-
# print("[ERROR]: invalid client uid")
|
| 138 |
-
# return
|
| 139 |
-
|
| 140 |
-
if "status" in message.keys():
|
| 141 |
-
self.handle_status_messages(message)
|
| 142 |
-
return
|
| 143 |
-
|
| 144 |
-
if "message" in message.keys() and message["message"] == "DISCONNECT":
|
| 145 |
-
print("[INFO]: Server disconnected due to overtime.")
|
| 146 |
-
self.recording = False
|
| 147 |
-
|
| 148 |
-
if "message" in message.keys() and message["message"] == "SERVER_READY":
|
| 149 |
-
self.last_response_received = time.time()
|
| 150 |
-
self.recording = True
|
| 151 |
-
self.server_backend = message["backend"]
|
| 152 |
-
print(f"[INFO]: Server Running with backend {self.server_backend}")
|
| 153 |
-
return
|
| 154 |
-
|
| 155 |
-
if "language" in message.keys():
|
| 156 |
-
self.language = message.get("language")
|
| 157 |
-
lang_prob = message.get("language_prob")
|
| 158 |
-
print(
|
| 159 |
-
f"[INFO]: Server detected language {self.language} with probability {lang_prob}"
|
| 160 |
-
)
|
| 161 |
-
return
|
| 162 |
-
|
| 163 |
-
if "segments" in message.keys():
|
| 164 |
-
self.process_segments(message["segments"])
|
| 165 |
-
|
| 166 |
-
def on_error(self, ws, error):
|
| 167 |
-
print(f"[ERROR] WebSocket Error: {error}")
|
| 168 |
-
self.server_error = True
|
| 169 |
-
self.error_message = error
|
| 170 |
-
|
| 171 |
-
def on_close(self, ws, close_status_code, close_msg):
|
| 172 |
-
print(f"[INFO]: Websocket connection closed: {close_status_code}: {close_msg}")
|
| 173 |
-
self.recording = False
|
| 174 |
-
self.waiting = False
|
| 175 |
-
|
| 176 |
-
def on_open(self, ws):
|
| 177 |
-
"""
|
| 178 |
-
Callback function called when the WebSocket connection is successfully opened.
|
| 179 |
-
|
| 180 |
-
Sends an initial configuration message to the server, including client UID,
|
| 181 |
-
language selection, and task type.
|
| 182 |
-
|
| 183 |
-
Args:
|
| 184 |
-
ws (websocket.WebSocketApp): The WebSocket client instance.
|
| 185 |
-
|
| 186 |
-
"""
|
| 187 |
-
print("[INFO]: Opened connection")
|
| 188 |
-
ws.send(
|
| 189 |
-
json.dumps(
|
| 190 |
-
{
|
| 191 |
-
"uid": self.uid,
|
| 192 |
-
"language": self.language,
|
| 193 |
-
"max_clients": self.max_clients,
|
| 194 |
-
"max_connection_time": self.max_connection_time,
|
| 195 |
-
}
|
| 196 |
-
)
|
| 197 |
-
)
|
| 198 |
-
|
| 199 |
-
def send_packet_to_server(self, message):
|
| 200 |
-
"""
|
| 201 |
-
Send an audio packet to the server using WebSocket.
|
| 202 |
-
|
| 203 |
-
Args:
|
| 204 |
-
message (bytes): The audio data packet in bytes to be sent to the server.
|
| 205 |
-
|
| 206 |
-
"""
|
| 207 |
-
try:
|
| 208 |
-
self.client_socket.send(message, websocket.ABNF.OPCODE_BINARY)
|
| 209 |
-
except Exception as e:
|
| 210 |
-
print(e)
|
| 211 |
-
|
| 212 |
-
def close_websocket(self):
|
| 213 |
-
"""
|
| 214 |
-
Close the WebSocket connection and join the WebSocket thread.
|
| 215 |
-
|
| 216 |
-
First attempts to close the WebSocket connection using `self.client_socket.close()`. After
|
| 217 |
-
closing the connection, it joins the WebSocket thread to ensure proper termination.
|
| 218 |
-
|
| 219 |
-
"""
|
| 220 |
-
try:
|
| 221 |
-
self.client_socket.close()
|
| 222 |
-
except Exception as e:
|
| 223 |
-
print("[ERROR]: Error closing WebSocket:", e)
|
| 224 |
-
|
| 225 |
-
try:
|
| 226 |
-
self.ws_thread.join()
|
| 227 |
-
except Exception as e:
|
| 228 |
-
print("[ERROR:] Error joining WebSocket thread:", e)
|
| 229 |
-
|
| 230 |
-
def get_client_socket(self):
|
| 231 |
-
"""
|
| 232 |
-
Get the WebSocket client socket instance.
|
| 233 |
-
|
| 234 |
-
Returns:
|
| 235 |
-
WebSocketApp: The WebSocket client socket instance currently in use by the client.
|
| 236 |
-
"""
|
| 237 |
-
return self.client_socket
|
| 238 |
-
|
| 239 |
-
def wait_before_disconnect(self):
|
| 240 |
-
"""Waits a bit before disconnecting in order to process pending responses."""
|
| 241 |
-
assert self.last_response_received
|
| 242 |
-
while time.time() - self.last_response_received < self.disconnect_if_no_response_for:
|
| 243 |
-
continue
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
class TranscriptionTeeClient:
|
| 247 |
-
"""
|
| 248 |
-
Client for handling audio recording, streaming, and transcription tasks via one or more
|
| 249 |
-
WebSocket connections.
|
| 250 |
-
|
| 251 |
-
Acts as a high-level client for audio transcription tasks using a WebSocket connection. It can be used
|
| 252 |
-
to send audio data for transcription to one or more servers, and receive transcribed text segments.
|
| 253 |
-
Args:
|
| 254 |
-
clients (list): one or more previously initialized Client instances
|
| 255 |
-
|
| 256 |
-
Attributes:
|
| 257 |
-
clients (list): the underlying Client instances responsible for handling WebSocket connections.
|
| 258 |
-
"""
|
| 259 |
-
|
| 260 |
-
def __init__(self, clients, save_output_recording=False, output_recording_filename="./output_recording.wav",
|
| 261 |
-
mute_audio_playback=False):
|
| 262 |
-
self.clients = clients
|
| 263 |
-
if not self.clients:
|
| 264 |
-
raise Exception("At least one client is required.")
|
| 265 |
-
self.chunk = 4096
|
| 266 |
-
self.format = pyaudio.paInt16
|
| 267 |
-
self.channels = 1
|
| 268 |
-
self.rate = 16000
|
| 269 |
-
self.record_seconds = 60000
|
| 270 |
-
self.save_output_recording = save_output_recording
|
| 271 |
-
self.output_recording_filename = output_recording_filename
|
| 272 |
-
self.mute_audio_playback = mute_audio_playback
|
| 273 |
-
self.frames = b""
|
| 274 |
-
self.p = pyaudio.PyAudio()
|
| 275 |
-
try:
|
| 276 |
-
self.stream = self.p.open(
|
| 277 |
-
format=self.format,
|
| 278 |
-
channels=self.channels,
|
| 279 |
-
rate=self.rate,
|
| 280 |
-
input=True,
|
| 281 |
-
frames_per_buffer=self.chunk,
|
| 282 |
-
)
|
| 283 |
-
except OSError as error:
|
| 284 |
-
print(f"[WARN]: Unable to access microphone. {error}")
|
| 285 |
-
self.stream = None
|
| 286 |
-
|
| 287 |
-
def __call__(self, audio=None, rtsp_url=None, hls_url=None, save_file=None):
|
| 288 |
-
"""
|
| 289 |
-
Start the transcription process.
|
| 290 |
-
|
| 291 |
-
Initiates the transcription process by connecting to the server via a WebSocket. It waits for the server
|
| 292 |
-
to be ready to receive audio data and then sends audio for transcription. If an audio file is provided, it
|
| 293 |
-
will be played and streamed to the server; otherwise, it will perform live recording.
|
| 294 |
-
|
| 295 |
-
Args:
|
| 296 |
-
audio (str, optional): Path to an audio file for transcription. Default is None, which triggers live recording.
|
| 297 |
-
|
| 298 |
-
"""
|
| 299 |
-
assert sum(
|
| 300 |
-
source is not None for source in [audio, rtsp_url, hls_url]
|
| 301 |
-
) <= 1, 'You must provide only one selected source'
|
| 302 |
-
|
| 303 |
-
print("[INFO]: Waiting for server ready ...")
|
| 304 |
-
for client in self.clients:
|
| 305 |
-
while not client.recording:
|
| 306 |
-
if client.waiting or client.server_error:
|
| 307 |
-
self.close_all_clients()
|
| 308 |
-
return
|
| 309 |
-
|
| 310 |
-
print("[INFO]: Server Ready!")
|
| 311 |
-
if hls_url is not None:
|
| 312 |
-
self.process_hls_stream(hls_url, save_file)
|
| 313 |
-
elif audio is not None:
|
| 314 |
-
resampled_file = utils.resample(audio)
|
| 315 |
-
self.play_file(resampled_file)
|
| 316 |
-
elif rtsp_url is not None:
|
| 317 |
-
self.process_rtsp_stream(rtsp_url)
|
| 318 |
-
else:
|
| 319 |
-
self.record()
|
| 320 |
-
|
| 321 |
-
def close_all_clients(self):
|
| 322 |
-
"""Closes all client websockets."""
|
| 323 |
-
for client in self.clients:
|
| 324 |
-
client.close_websocket()
|
| 325 |
-
|
| 326 |
-
def multicast_packet(self, packet, unconditional=False):
|
| 327 |
-
"""
|
| 328 |
-
Sends an identical packet via all clients.
|
| 329 |
-
|
| 330 |
-
Args:
|
| 331 |
-
packet (bytes): The audio data packet in bytes to be sent.
|
| 332 |
-
unconditional (bool, optional): If true, send regardless of whether clients are recording. Default is False.
|
| 333 |
-
"""
|
| 334 |
-
for client in self.clients:
|
| 335 |
-
if (unconditional or client.recording):
|
| 336 |
-
client.send_packet_to_server(packet)
|
| 337 |
-
|
| 338 |
-
def play_file(self, filename):
|
| 339 |
-
"""
|
| 340 |
-
Play an audio file and send it to the server for processing.
|
| 341 |
-
|
| 342 |
-
Reads an audio file, plays it through the audio output, and simultaneously sends
|
| 343 |
-
the audio data to the server for processing. It uses PyAudio to create an audio
|
| 344 |
-
stream for playback. The audio data is read from the file in chunks, converted to
|
| 345 |
-
floating-point format, and sent to the server using WebSocket communication.
|
| 346 |
-
This method is typically used when you want to process pre-recorded audio and send it
|
| 347 |
-
to the server in real-time.
|
| 348 |
-
|
| 349 |
-
Args:
|
| 350 |
-
filename (str): The path to the audio file to be played and sent to the server.
|
| 351 |
-
"""
|
| 352 |
-
|
| 353 |
-
# read audio and create pyaudio stream
|
| 354 |
-
with wave.open(filename, "rb") as wavfile:
|
| 355 |
-
self.stream = self.p.open(
|
| 356 |
-
format=self.p.get_format_from_width(wavfile.getsampwidth()),
|
| 357 |
-
channels=wavfile.getnchannels(),
|
| 358 |
-
rate=wavfile.getframerate(),
|
| 359 |
-
input=True,
|
| 360 |
-
output=True,
|
| 361 |
-
frames_per_buffer=self.chunk,
|
| 362 |
-
)
|
| 363 |
-
chunk_duration = self.chunk / float(wavfile.getframerate())
|
| 364 |
-
try:
|
| 365 |
-
while any(client.recording for client in self.clients):
|
| 366 |
-
data = wavfile.readframes(self.chunk)
|
| 367 |
-
if data == b"":
|
| 368 |
-
break
|
| 369 |
-
|
| 370 |
-
audio_array = self.bytes_to_float_array(data)
|
| 371 |
-
self.multicast_packet(audio_array.tobytes())
|
| 372 |
-
if self.mute_audio_playback:
|
| 373 |
-
time.sleep(chunk_duration)
|
| 374 |
-
else:
|
| 375 |
-
self.stream.write(data)
|
| 376 |
-
|
| 377 |
-
wavfile.close()
|
| 378 |
-
|
| 379 |
-
for client in self.clients:
|
| 380 |
-
client.wait_before_disconnect()
|
| 381 |
-
self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
|
| 382 |
-
self.stream.close()
|
| 383 |
-
self.close_all_clients()
|
| 384 |
-
|
| 385 |
-
except KeyboardInterrupt:
|
| 386 |
-
wavfile.close()
|
| 387 |
-
self.stream.stop_stream()
|
| 388 |
-
self.stream.close()
|
| 389 |
-
self.p.terminate()
|
| 390 |
-
self.close_all_clients()
|
| 391 |
-
print("[INFO]: Keyboard interrupt.")
|
| 392 |
-
|
| 393 |
-
def process_rtsp_stream(self, rtsp_url):
|
| 394 |
-
"""
|
| 395 |
-
Connect to an RTSP source, process the audio stream, and send it for transcription.
|
| 396 |
-
|
| 397 |
-
Args:
|
| 398 |
-
rtsp_url (str): The URL of the RTSP stream source.
|
| 399 |
-
"""
|
| 400 |
-
print("[INFO]: Connecting to RTSP stream...")
|
| 401 |
-
try:
|
| 402 |
-
container = av.open(rtsp_url, format="rtsp", options={"rtsp_transport": "tcp"})
|
| 403 |
-
self.process_av_stream(container, stream_type="RTSP")
|
| 404 |
-
except Exception as e:
|
| 405 |
-
print(f"[ERROR]: Failed to process RTSP stream: {e}")
|
| 406 |
-
finally:
|
| 407 |
-
for client in self.clients:
|
| 408 |
-
client.wait_before_disconnect()
|
| 409 |
-
self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
|
| 410 |
-
self.close_all_clients()
|
| 411 |
-
print("[INFO]: RTSP stream processing finished.")
|
| 412 |
-
|
| 413 |
-
def process_hls_stream(self, hls_url, save_file=None):
|
| 414 |
-
"""
|
| 415 |
-
Connect to an HLS source, process the audio stream, and send it for transcription.
|
| 416 |
-
|
| 417 |
-
Args:
|
| 418 |
-
hls_url (str): The URL of the HLS stream source.
|
| 419 |
-
save_file (str, optional): Local path to save the network stream.
|
| 420 |
-
"""
|
| 421 |
-
print("[INFO]: Connecting to HLS stream...")
|
| 422 |
-
try:
|
| 423 |
-
container = av.open(hls_url, format="hls")
|
| 424 |
-
self.process_av_stream(container, stream_type="HLS", save_file=save_file)
|
| 425 |
-
except Exception as e:
|
| 426 |
-
print(f"[ERROR]: Failed to process HLS stream: {e}")
|
| 427 |
-
finally:
|
| 428 |
-
for client in self.clients:
|
| 429 |
-
client.wait_before_disconnect()
|
| 430 |
-
self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
|
| 431 |
-
self.close_all_clients()
|
| 432 |
-
print("[INFO]: HLS stream processing finished.")
|
| 433 |
-
|
| 434 |
-
def process_av_stream(self, container, stream_type, save_file=None):
|
| 435 |
-
"""
|
| 436 |
-
Process an AV container stream and send audio packets to the server.
|
| 437 |
-
|
| 438 |
-
Args:
|
| 439 |
-
container (av.container.InputContainer): The input container to process.
|
| 440 |
-
stream_type (str): The type of stream being processed ("RTSP" or "HLS").
|
| 441 |
-
save_file (str, optional): Local path to save the stream. Default is None.
|
| 442 |
-
"""
|
| 443 |
-
audio_stream = next((s for s in container.streams if s.type == "audio"), None)
|
| 444 |
-
if not audio_stream:
|
| 445 |
-
print(f"[ERROR]: No audio stream found in {stream_type} source.")
|
| 446 |
-
return
|
| 447 |
-
|
| 448 |
-
output_container = None
|
| 449 |
-
if save_file:
|
| 450 |
-
output_container = av.open(save_file, mode="w")
|
| 451 |
-
output_audio_stream = output_container.add_stream(codec_name="pcm_s16le", rate=self.rate)
|
| 452 |
-
|
| 453 |
-
try:
|
| 454 |
-
for packet in container.demux(audio_stream):
|
| 455 |
-
for frame in packet.decode():
|
| 456 |
-
audio_data = frame.to_ndarray().tobytes()
|
| 457 |
-
self.multicast_packet(audio_data)
|
| 458 |
-
|
| 459 |
-
if save_file:
|
| 460 |
-
output_container.mux(frame)
|
| 461 |
-
except Exception as e:
|
| 462 |
-
print(f"[ERROR]: Error during {stream_type} stream processing: {e}")
|
| 463 |
-
finally:
|
| 464 |
-
# Wait for server to send any leftover transcription.
|
| 465 |
-
time.sleep(5)
|
| 466 |
-
self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
|
| 467 |
-
if output_container:
|
| 468 |
-
output_container.close()
|
| 469 |
-
container.close()
|
| 470 |
-
|
| 471 |
-
def save_chunk(self, n_audio_file):
|
| 472 |
-
"""
|
| 473 |
-
Saves the current audio frames to a WAV file in a separate thread.
|
| 474 |
-
|
| 475 |
-
Args:
|
| 476 |
-
n_audio_file (int): The index of the audio file which determines the filename.
|
| 477 |
-
This helps in maintaining the order and uniqueness of each chunk.
|
| 478 |
-
"""
|
| 479 |
-
t = threading.Thread(
|
| 480 |
-
target=self.write_audio_frames_to_file,
|
| 481 |
-
args=(self.frames[:], f"chunks/{n_audio_file}.wav",),
|
| 482 |
-
)
|
| 483 |
-
t.start()
|
| 484 |
-
|
| 485 |
-
def finalize_recording(self, n_audio_file):
|
| 486 |
-
"""
|
| 487 |
-
Finalizes the recording process by saving any remaining audio frames,
|
| 488 |
-
closing the audio stream, and terminating the process.
|
| 489 |
-
|
| 490 |
-
Args:
|
| 491 |
-
n_audio_file (int): The file index to be used if there are remaining audio frames to be saved.
|
| 492 |
-
This index is incremented before use if the last chunk is saved.
|
| 493 |
-
"""
|
| 494 |
-
if self.save_output_recording and len(self.frames):
|
| 495 |
-
self.write_audio_frames_to_file(
|
| 496 |
-
self.frames[:], f"chunks/{n_audio_file}.wav"
|
| 497 |
-
)
|
| 498 |
-
n_audio_file += 1
|
| 499 |
-
self.stream.stop_stream()
|
| 500 |
-
self.stream.close()
|
| 501 |
-
self.p.terminate()
|
| 502 |
-
self.close_all_clients()
|
| 503 |
-
if self.save_output_recording:
|
| 504 |
-
self.write_output_recording(n_audio_file)
|
| 505 |
-
|
| 506 |
-
def record(self):
|
| 507 |
-
"""
|
| 508 |
-
Record audio data from the input stream and save it to a WAV file.
|
| 509 |
-
|
| 510 |
-
Continuously records audio data from the input stream, sends it to the server via a WebSocket
|
| 511 |
-
connection, and simultaneously saves it to multiple WAV files in chunks. It stops recording when
|
| 512 |
-
the `RECORD_SECONDS` duration is reached or when the `RECORDING` flag is set to `False`.
|
| 513 |
-
|
| 514 |
-
Audio data is saved in chunks to the "chunks" directory. Each chunk is saved as a separate WAV file.
|
| 515 |
-
The recording will continue until the specified duration is reached or until the `RECORDING` flag is set to `False`.
|
| 516 |
-
The recording process can be interrupted by sending a KeyboardInterrupt (e.g., pressing Ctrl+C). After recording,
|
| 517 |
-
the method combines all the saved audio chunks into the specified `out_file`.
|
| 518 |
-
"""
|
| 519 |
-
n_audio_file = 0
|
| 520 |
-
if self.save_output_recording:
|
| 521 |
-
if os.path.exists("chunks"):
|
| 522 |
-
shutil.rmtree("chunks")
|
| 523 |
-
os.makedirs("chunks")
|
| 524 |
-
try:
|
| 525 |
-
for _ in range(0, int(self.rate / self.chunk * self.record_seconds)):
|
| 526 |
-
if not any(client.recording for client in self.clients):
|
| 527 |
-
break
|
| 528 |
-
data = self.stream.read(self.chunk, exception_on_overflow=False)
|
| 529 |
-
self.frames += data
|
| 530 |
-
|
| 531 |
-
audio_array = self.bytes_to_float_array(data)
|
| 532 |
-
|
| 533 |
-
self.multicast_packet(audio_array.tobytes())
|
| 534 |
-
|
| 535 |
-
# save frames if more than a minute
|
| 536 |
-
if len(self.frames) > 60 * self.rate:
|
| 537 |
-
if self.save_output_recording:
|
| 538 |
-
self.save_chunk(n_audio_file)
|
| 539 |
-
n_audio_file += 1
|
| 540 |
-
self.frames = b""
|
| 541 |
-
|
| 542 |
-
except KeyboardInterrupt:
|
| 543 |
-
self.finalize_recording(n_audio_file)
|
| 544 |
-
|
| 545 |
-
def write_audio_frames_to_file(self, frames, file_name):
|
| 546 |
-
"""
|
| 547 |
-
Write audio frames to a WAV file.
|
| 548 |
-
|
| 549 |
-
The WAV file is created or overwritten with the specified name. The audio frames should be
|
| 550 |
-
in the correct format and match the specified channel, sample width, and sample rate.
|
| 551 |
-
|
| 552 |
-
Args:
|
| 553 |
-
frames (bytes): The audio frames to be written to the file.
|
| 554 |
-
file_name (str): The name of the WAV file to which the frames will be written.
|
| 555 |
-
|
| 556 |
-
"""
|
| 557 |
-
with wave.open(file_name, "wb") as wavfile:
|
| 558 |
-
wavfile: wave.Wave_write
|
| 559 |
-
wavfile.setnchannels(self.channels)
|
| 560 |
-
wavfile.setsampwidth(2)
|
| 561 |
-
wavfile.setframerate(self.rate)
|
| 562 |
-
wavfile.writeframes(frames)
|
| 563 |
-
|
| 564 |
-
def write_output_recording(self, n_audio_file):
|
| 565 |
-
"""
|
| 566 |
-
Combine and save recorded audio chunks into a single WAV file.
|
| 567 |
-
|
| 568 |
-
The individual audio chunk files are expected to be located in the "chunks" directory. Reads each chunk
|
| 569 |
-
file, appends its audio data to the final recording, and then deletes the chunk file. After combining
|
| 570 |
-
and saving, the final recording is stored in the specified `out_file`.
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
Args:
|
| 574 |
-
n_audio_file (int): The number of audio chunk files to combine.
|
| 575 |
-
out_file (str): The name of the output WAV file to save the final recording.
|
| 576 |
-
|
| 577 |
-
"""
|
| 578 |
-
input_files = [
|
| 579 |
-
f"chunks/{i}.wav"
|
| 580 |
-
for i in range(n_audio_file)
|
| 581 |
-
if os.path.exists(f"chunks/{i}.wav")
|
| 582 |
-
]
|
| 583 |
-
with wave.open(self.output_recording_filename, "wb") as wavfile:
|
| 584 |
-
wavfile: wave.Wave_write
|
| 585 |
-
wavfile.setnchannels(self.channels)
|
| 586 |
-
wavfile.setsampwidth(2)
|
| 587 |
-
wavfile.setframerate(self.rate)
|
| 588 |
-
for in_file in input_files:
|
| 589 |
-
with wave.open(in_file, "rb") as wav_in:
|
| 590 |
-
while True:
|
| 591 |
-
data = wav_in.readframes(self.chunk)
|
| 592 |
-
if data == b"":
|
| 593 |
-
break
|
| 594 |
-
wavfile.writeframes(data)
|
| 595 |
-
# remove this file
|
| 596 |
-
os.remove(in_file)
|
| 597 |
-
wavfile.close()
|
| 598 |
-
# clean up temporary directory to store chunks
|
| 599 |
-
if os.path.exists("chunks"):
|
| 600 |
-
shutil.rmtree("chunks")
|
| 601 |
-
|
| 602 |
-
@staticmethod
|
| 603 |
-
def bytes_to_float_array(audio_bytes):
|
| 604 |
-
"""
|
| 605 |
-
Convert audio data from bytes to a NumPy float array.
|
| 606 |
-
|
| 607 |
-
It assumes that the audio data is in 16-bit PCM format. The audio data is normalized to
|
| 608 |
-
have values between -1 and 1.
|
| 609 |
-
|
| 610 |
-
Args:
|
| 611 |
-
audio_bytes (bytes): Audio data in bytes.
|
| 612 |
-
|
| 613 |
-
Returns:
|
| 614 |
-
np.ndarray: A NumPy array containing the audio data as float values normalized between -1 and 1.
|
| 615 |
-
"""
|
| 616 |
-
raw_data = np.frombuffer(buffer=audio_bytes, dtype=np.int16)
|
| 617 |
-
return raw_data.astype(np.float32) / 32768.0
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
class TranscriptionClient(TranscriptionTeeClient):
|
| 621 |
-
"""
|
| 622 |
-
Client for handling audio transcription tasks via a single WebSocket connection.
|
| 623 |
-
|
| 624 |
-
Acts as a high-level client for audio transcription tasks using a WebSocket connection. It can be used
|
| 625 |
-
to send audio data for transcription to a server and receive transcribed text segments.
|
| 626 |
-
|
| 627 |
-
Args:
|
| 628 |
-
host (str): The hostname or IP address of the server.
|
| 629 |
-
port (int): The port number to connect to on the server.
|
| 630 |
-
lang (str, optional): The primary language for transcription. Default is None, which defaults to English ('en').
|
| 631 |
-
save_output_recording (bool, optional): Whether to save the microphone recording. Default is False.
|
| 632 |
-
output_recording_filename (str, optional): Path to save the output recording WAV file. Default is "./output_recording.wav".
|
| 633 |
-
output_transcription_path (str, optional): File path to save the output transcription (SRT file). Default is "./output.srt".
|
| 634 |
-
log_transcription (bool, optional): Whether to log transcription output to the console. Default is True.
|
| 635 |
-
max_clients (int, optional): Maximum number of client connections allowed. Default is 4.
|
| 636 |
-
max_connection_time (int, optional): Maximum allowed connection time in seconds. Default is 600.
|
| 637 |
-
mute_audio_playback (bool, optional): If True, mutes audio playback during file playback. Default is False.
|
| 638 |
-
|
| 639 |
-
Attributes:
|
| 640 |
-
client (Client): An instance of the underlying Client class responsible for handling the WebSocket connection.
|
| 641 |
-
|
| 642 |
-
Example:
|
| 643 |
-
To create a TranscriptionClient and start transcription on microphone audio:
|
| 644 |
-
```python
|
| 645 |
-
transcription_client = TranscriptionClient(host="localhost", port=9090)
|
| 646 |
-
transcription_client()
|
| 647 |
-
```
|
| 648 |
-
"""
|
| 649 |
-
|
| 650 |
-
def __init__(
|
| 651 |
-
self,
|
| 652 |
-
host,
|
| 653 |
-
port,
|
| 654 |
-
lang=None,
|
| 655 |
-
save_output_recording=False,
|
| 656 |
-
output_recording_filename="./output_recording.wav",
|
| 657 |
-
log_transcription=True,
|
| 658 |
-
max_clients=4,
|
| 659 |
-
max_connection_time=600,
|
| 660 |
-
mute_audio_playback=False,
|
| 661 |
-
dst_lang='en',
|
| 662 |
-
):
|
| 663 |
-
self.client = Client(
|
| 664 |
-
host, port, lang, log_transcription=log_transcription, max_clients=max_clients,
|
| 665 |
-
max_connection_time=max_connection_time, dst_lang=dst_lang
|
| 666 |
-
)
|
| 667 |
-
|
| 668 |
-
if save_output_recording and not output_recording_filename.endswith(".wav"):
|
| 669 |
-
raise ValueError(f"Please provide a valid `output_recording_filename`: {output_recording_filename}")
|
| 670 |
-
|
| 671 |
-
TranscriptionTeeClient.__init__(
|
| 672 |
-
self,
|
| 673 |
-
[self.client],
|
| 674 |
-
save_output_recording=save_output_recording,
|
| 675 |
-
output_recording_filename=output_recording_filename,
|
| 676 |
-
mute_audio_playback=mute_audio_playback,
|
| 677 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
transcribe/helpers/funasr.py
CHANGED
|
@@ -1,14 +1,11 @@
|
|
| 1 |
-
import
|
| 2 |
-
import uuid
|
| 3 |
-
from logging import getLogger
|
| 4 |
|
| 5 |
import numpy as np
|
| 6 |
from funasr import AutoModel
|
| 7 |
-
import soundfile as sf
|
| 8 |
|
| 9 |
import config
|
| 10 |
|
| 11 |
-
logger = getLogger(__name__)
|
| 12 |
|
| 13 |
|
| 14 |
class FunASR:
|
|
@@ -16,7 +13,7 @@ class FunASR:
|
|
| 16 |
self.source_lange = source_lange
|
| 17 |
|
| 18 |
self.model = AutoModel(
|
| 19 |
-
model="paraformer-zh", vad_model="fsmn-vad", punc_model="ct-punc"
|
| 20 |
)
|
| 21 |
if warmup:
|
| 22 |
self.warmup()
|
|
@@ -30,8 +27,8 @@ class FunASR:
|
|
| 30 |
audio_frames = np.frombuffer(audio_buffer, dtype=np.float32)
|
| 31 |
# sf.write(f'{config.ASSERT_DIR}/{time.time()}.wav', audio_frames, samplerate=16000)
|
| 32 |
try:
|
| 33 |
-
output = self.model.generate(input=audio_frames, disable_pbar=True)
|
| 34 |
return output
|
| 35 |
except Exception as e:
|
| 36 |
-
|
| 37 |
return []
|
|
|
|
| 1 |
+
# from logging import getLogger
|
|
|
|
|
|
|
| 2 |
|
| 3 |
import numpy as np
|
| 4 |
from funasr import AutoModel
|
|
|
|
| 5 |
|
| 6 |
import config
|
| 7 |
|
| 8 |
+
# logger = getLogger(__name__)
|
| 9 |
|
| 10 |
|
| 11 |
class FunASR:
|
|
|
|
| 13 |
self.source_lange = source_lange
|
| 14 |
|
| 15 |
self.model = AutoModel(
|
| 16 |
+
model="paraformer-zh", vad_model="fsmn-vad", punc_model="ct-punc", log_level="ERROR",
|
| 17 |
)
|
| 18 |
if warmup:
|
| 19 |
self.warmup()
|
|
|
|
| 27 |
audio_frames = np.frombuffer(audio_buffer, dtype=np.float32)
|
| 28 |
# sf.write(f'{config.ASSERT_DIR}/{time.time()}.wav', audio_frames, samplerate=16000)
|
| 29 |
try:
|
| 30 |
+
output = self.model.generate(input=audio_frames, disable_pbar=True, hotword=config.hotwords_file.as_posix())
|
| 31 |
return output
|
| 32 |
except Exception as e:
|
| 33 |
+
print(f"Error during transcription: {e}")
|
| 34 |
return []
|
transcribe/helpers/vadprocessor.py
CHANGED
|
@@ -36,7 +36,7 @@ class AdaptiveSilenceController:
|
|
| 36 |
speed_factor = 0.5
|
| 37 |
elif avg_speech < 600:
|
| 38 |
speed_factor = 0.8
|
| 39 |
-
|
| 40 |
# 3. silence 的变化趋势也考虑进去
|
| 41 |
adaptive = self.base * speed_factor + 0.3 * avg_silence
|
| 42 |
|
|
@@ -155,7 +155,7 @@ class VADIteratorOnnx:
|
|
| 155 |
raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')
|
| 156 |
|
| 157 |
self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
|
| 158 |
-
self.max_speech_samples = int(sampling_rate * max_speech_duration_s)
|
| 159 |
self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
|
| 160 |
self.reset_states()
|
| 161 |
|
|
@@ -184,7 +184,7 @@ class VADIteratorOnnx:
|
|
| 184 |
self.current_sample += window_size_samples
|
| 185 |
|
| 186 |
speech_prob = self.model(x, self.sampling_rate)[0,0]
|
| 187 |
-
|
| 188 |
|
| 189 |
if (speech_prob >= self.threshold) and self.temp_end:
|
| 190 |
self.temp_end = 0
|
|
@@ -196,11 +196,11 @@ class VADIteratorOnnx:
|
|
| 196 |
self.start = speech_start
|
| 197 |
return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}
|
| 198 |
|
| 199 |
-
if (speech_prob >= self.threshold) and self.current_sample - self.start >= self.max_speech_samples:
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
|
| 205 |
if (speech_prob < self.threshold - 0.15) and self.triggered:
|
| 206 |
if not self.temp_end:
|
|
|
|
| 36 |
speed_factor = 0.5
|
| 37 |
elif avg_speech < 600:
|
| 38 |
speed_factor = 0.8
|
| 39 |
+
logging.warning(f"Avg speech :{avg_speech}, Avg silence: {avg_silence}")
|
| 40 |
# 3. silence 的变化趋势也考虑进去
|
| 41 |
adaptive = self.base * speed_factor + 0.3 * avg_silence
|
| 42 |
|
|
|
|
| 155 |
raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')
|
| 156 |
|
| 157 |
self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
|
| 158 |
+
# self.max_speech_samples = int(sampling_rate * max_speech_duration_s)
|
| 159 |
self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
|
| 160 |
self.reset_states()
|
| 161 |
|
|
|
|
| 184 |
self.current_sample += window_size_samples
|
| 185 |
|
| 186 |
speech_prob = self.model(x, self.sampling_rate)[0,0]
|
| 187 |
+
|
| 188 |
|
| 189 |
if (speech_prob >= self.threshold) and self.temp_end:
|
| 190 |
self.temp_end = 0
|
|
|
|
| 196 |
self.start = speech_start
|
| 197 |
return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}
|
| 198 |
|
| 199 |
+
# if (speech_prob >= self.threshold) and self.current_sample - self.start >= self.max_speech_samples:
|
| 200 |
+
# if self.temp_end:
|
| 201 |
+
# self.temp_end = 0
|
| 202 |
+
# self.start = self.current_sample
|
| 203 |
+
# return {'end': int(self.current_sample) if not return_seconds else round(self.current_sample / self.sampling_rate, 1)}
|
| 204 |
|
| 205 |
if (speech_prob < self.threshold - 0.15) and self.triggered:
|
| 206 |
if not self.temp_end:
|
transcribe/pipelines/pipe_vad.py
CHANGED
|
@@ -1,12 +1,10 @@
|
|
| 1 |
|
| 2 |
from .base import MetaItem, BasePipe
|
| 3 |
-
from ..helpers.vadprocessor import FixedVADIterator
|
| 4 |
|
| 5 |
import numpy as np
|
| 6 |
-
from silero_vad import get_speech_timestamps
|
| 7 |
-
from typing import List
|
| 8 |
import logging
|
| 9 |
-
|
| 10 |
# import noisereduce as nr
|
| 11 |
|
| 12 |
|
|
@@ -18,15 +16,12 @@ class VadPipe(BasePipe):
|
|
| 18 |
super().__init__(in_queue, out_queue)
|
| 19 |
self._offset = 0 # 处理的frame size offset
|
| 20 |
self._status = 'END'
|
| 21 |
-
self.last_state_change_offset = 0
|
| 22 |
-
self.adaptive_ctrl = AdaptiveSilenceController()
|
| 23 |
|
| 24 |
|
| 25 |
def reset(self):
|
| 26 |
self._offset = 0
|
| 27 |
self._status = 'END'
|
| 28 |
-
|
| 29 |
-
self.adaptive_ctrl = AdaptiveSilenceController()
|
| 30 |
self.vac.reset_states()
|
| 31 |
|
| 32 |
@classmethod
|
|
@@ -38,7 +33,6 @@ class VadPipe(BasePipe):
|
|
| 38 |
# speech_pad_ms=10
|
| 39 |
min_silence_duration_ms = 80,
|
| 40 |
# speech_pad_ms = 30,
|
| 41 |
-
max_speech_duration_s=25.0,
|
| 42 |
)
|
| 43 |
cls.vac.reset_states()
|
| 44 |
|
|
@@ -55,15 +49,9 @@ class VadPipe(BasePipe):
|
|
| 55 |
if start_frame:
|
| 56 |
relative_start_frame =start_frame - self._offset
|
| 57 |
if end_frame:
|
| 58 |
-
relative_end_frame =
|
| 59 |
return relative_start_frame, relative_end_frame
|
| 60 |
|
| 61 |
-
def update_silence_ms(self):
|
| 62 |
-
min_silence = self.adaptive_ctrl.get_adaptive_silence_ms()
|
| 63 |
-
min_silence_samples = self.sample_rate * min_silence / 1000
|
| 64 |
-
self.vac.min_silence_samples = min_silence_samples
|
| 65 |
-
logging.warning(f"🫠 update_silence_ms :{min_silence} => current: {self.vac.min_silence_samples} ")
|
| 66 |
-
|
| 67 |
def process(self, in_data: MetaItem) -> MetaItem:
|
| 68 |
if self._offset == 0:
|
| 69 |
self.vac.reset_states()
|
|
@@ -73,34 +61,19 @@ class VadPipe(BasePipe):
|
|
| 73 |
speech_data = self._process_speech_chunk(source_audio)
|
| 74 |
|
| 75 |
if speech_data: # 表示有音频的变化点出现
|
| 76 |
-
# self.update_silence_ms()
|
| 77 |
rel_start_frame, rel_end_frame = speech_data
|
| 78 |
if rel_start_frame is not None and rel_end_frame is None:
|
| 79 |
self._status = "START" # 语音开始
|
| 80 |
-
target_audio = source_audio[rel_start_frame:]
|
| 81 |
-
|
| 82 |
-
# 计算上一段静音长度
|
| 83 |
-
silence_len = (self._offset + rel_start_frame - self.last_state_change_offset) / self.sample_rate * 1000
|
| 84 |
-
self.adaptive_ctrl.update_silence(silence_len)
|
| 85 |
-
self.last_state_change_offset = self._offset + rel_start_frame
|
| 86 |
-
|
| 87 |
logging.debug("🫸 Speech start frame: {}".format(rel_start_frame))
|
| 88 |
elif rel_start_frame is None and rel_end_frame is not None:
|
| 89 |
self._status = "END" # 音频结束
|
| 90 |
target_audio = source_audio[:rel_end_frame]
|
| 91 |
-
|
| 92 |
-
speech_len = (rel_end_frame) / self.sample_rate * 1000
|
| 93 |
-
self.adaptive_ctrl.update_speech(speech_len)
|
| 94 |
-
self.last_state_change_offset = self._offset + rel_end_frame
|
| 95 |
logging.debug(" 🫷Speech ended, capturing audio up to frame: {}".format(rel_end_frame))
|
| 96 |
else:
|
| 97 |
self._status = 'END'
|
| 98 |
target_audio = source_audio[rel_start_frame:rel_end_frame]
|
| 99 |
logging.debug(" 🔄 Speech segment captured from frame {} to frame {}".format(rel_start_frame, rel_end_frame))
|
| 100 |
-
|
| 101 |
-
seg_len = (rel_end_frame - rel_start_frame) / self.sample_rate * 1000
|
| 102 |
-
self.adaptive_ctrl.update_speech(seg_len)
|
| 103 |
-
self.last_state_change_offset = self._offset + rel_end_frame
|
| 104 |
# logging.debug("❌ No valid speech segment detected, setting status to END")
|
| 105 |
else:
|
| 106 |
if self._status == 'START':
|
|
|
|
| 1 |
|
| 2 |
from .base import MetaItem, BasePipe
|
| 3 |
+
from ..helpers.vadprocessor import FixedVADIterator
|
| 4 |
|
| 5 |
import numpy as np
|
|
|
|
|
|
|
| 6 |
import logging
|
| 7 |
+
|
| 8 |
# import noisereduce as nr
|
| 9 |
|
| 10 |
|
|
|
|
| 16 |
super().__init__(in_queue, out_queue)
|
| 17 |
self._offset = 0 # 处理的frame size offset
|
| 18 |
self._status = 'END'
|
|
|
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
def reset(self):
|
| 22 |
self._offset = 0
|
| 23 |
self._status = 'END'
|
| 24 |
+
|
|
|
|
| 25 |
self.vac.reset_states()
|
| 26 |
|
| 27 |
@classmethod
|
|
|
|
| 33 |
# speech_pad_ms=10
|
| 34 |
min_silence_duration_ms = 80,
|
| 35 |
# speech_pad_ms = 30,
|
|
|
|
| 36 |
)
|
| 37 |
cls.vac.reset_states()
|
| 38 |
|
|
|
|
| 49 |
if start_frame:
|
| 50 |
relative_start_frame =start_frame - self._offset
|
| 51 |
if end_frame:
|
| 52 |
+
relative_end_frame = end_frame - self._offset
|
| 53 |
return relative_start_frame, relative_end_frame
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
def process(self, in_data: MetaItem) -> MetaItem:
|
| 56 |
if self._offset == 0:
|
| 57 |
self.vac.reset_states()
|
|
|
|
| 61 |
speech_data = self._process_speech_chunk(source_audio)
|
| 62 |
|
| 63 |
if speech_data: # 表示有音频的变化点出现
|
|
|
|
| 64 |
rel_start_frame, rel_end_frame = speech_data
|
| 65 |
if rel_start_frame is not None and rel_end_frame is None:
|
| 66 |
self._status = "START" # 语音开始
|
| 67 |
+
target_audio = source_audio[max(rel_start_frame-100, 0):]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
logging.debug("🫸 Speech start frame: {}".format(rel_start_frame))
|
| 69 |
elif rel_start_frame is None and rel_end_frame is not None:
|
| 70 |
self._status = "END" # 音频结束
|
| 71 |
target_audio = source_audio[:rel_end_frame]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
logging.debug(" 🫷Speech ended, capturing audio up to frame: {}".format(rel_end_frame))
|
| 73 |
else:
|
| 74 |
self._status = 'END'
|
| 75 |
target_audio = source_audio[rel_start_frame:rel_end_frame]
|
| 76 |
logging.debug(" 🔄 Speech segment captured from frame {} to frame {}".format(rel_start_frame, rel_end_frame))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
# logging.debug("❌ No valid speech segment detected, setting status to END")
|
| 78 |
else:
|
| 79 |
if self._status == 'START':
|
transcribe/server.py
DELETED
|
@@ -1,382 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
import json
|
| 3 |
-
import logging
|
| 4 |
-
import threading
|
| 5 |
-
import time
|
| 6 |
-
import config
|
| 7 |
-
import librosa
|
| 8 |
-
import numpy as np
|
| 9 |
-
import soundfile
|
| 10 |
-
from pywhispercpp.model import Model
|
| 11 |
-
|
| 12 |
-
logging.basicConfig(level=logging.INFO)
|
| 13 |
-
|
| 14 |
-
class ServeClientBase(object):
    """State and helpers shared by every per-client transcription backend.

    An instance owns the rolling audio buffer of one client, the two offsets
    that decide which part of that buffer still has to be transcribed, the
    list of committed transcript segments and the websocket used to talk to
    the client.  Subclasses provide ``speech_to_text``, ``transcribe_audio``
    and ``handle_transcription_output``.
    """

    RATE = 16000
    SERVER_READY = "SERVER_READY"
    DISCONNECT = "DISCONNECT"

    def __init__(self, client_uid, websocket):
        """Initialise empty buffers and bookkeeping for one client.

        Args:
            client_uid: Identifier echoed back in every message to the client.
            websocket: Object exposing ``send(str)``; used for all outgoing
                messages.
        """
        self.client_uid = client_uid
        self.websocket = websocket
        self.frames = b""
        self.timestamp_offset = 0.0
        self.frames_np = None
        self.frames_offset = 0.0
        self.text = []
        self.current_out = ''
        self.prev_out = ''
        self.t_start = None
        self.exit = False
        self.same_output_count = 0
        # if pause (no output from whisper) show previous output for 5 seconds
        self.show_prev_out_thresh = 5
        # add a blank to segment list as a pause (no speech) for 3 seconds
        self.add_pause_thresh = 3
        self.transcript = []
        self.send_last_n_segments = 10

        # text formatting
        self.pick_previous_segments = 2

        # threading: guards frames_np / frames_offset / timestamp_offset
        self.lock = threading.Lock()

    def speech_to_text(self):
        """Run the transcription loop; must be implemented by subclasses."""
        raise NotImplementedError

    def transcribe_audio(self):
        """Transcribe one audio chunk; must be implemented by subclasses."""
        raise NotImplementedError

    def handle_transcription_output(self):
        """Publish a transcription result; must be implemented by subclasses."""
        raise NotImplementedError

    def add_frames(self, frame_np):
        """Append an audio frame to the rolling buffer.

        When the buffer already holds more than 45 seconds of samples, the
        oldest 30 seconds are dropped first and ``frames_offset`` advances by
        30.  If ``timestamp_offset`` then lies before the new buffer start it
        is moved up to ``frames_offset``.

        Args:
            frame_np (numpy.ndarray): The audio frame data as a NumPy array.
        """
        # ``with`` guarantees the lock is released even if the concatenation
        # raises (e.g. on a frame with a mismatching shape).
        with self.lock:
            if self.frames_np is not None and self.frames_np.shape[0] > 45 * self.RATE:
                self.frames_offset += 30.0
                self.frames_np = self.frames_np[int(30 * self.RATE):]
                # timestamp_offset should be >= frames_offset; if it is not,
                # no speech advanced it while the dropped audio was buffered.
                if self.timestamp_offset < self.frames_offset:
                    self.timestamp_offset = self.frames_offset
            if self.frames_np is None:
                self.frames_np = frame_np.copy()
            else:
                self.frames_np = np.concatenate((self.frames_np, frame_np), axis=0)

    def clip_audio_if_no_valid_segment(self):
        """Skip ahead when too much unprocessed audio has piled up.

        If more than 25 seconds of samples lie after ``timestamp_offset``,
        the offset is moved so that only the last 5 seconds of the buffer
        remain to be processed.
        """
        with self.lock:
            pending_start = int((self.timestamp_offset - self.frames_offset) * self.RATE)
            if self.frames_np[pending_start:].shape[0] > 25 * self.RATE:
                duration = self.frames_np.shape[0] / self.RATE
                self.timestamp_offset = self.frames_offset + duration - 5

    def get_audio_chunk_for_processing(self):
        """Return a copy of the not-yet-processed tail of the buffer.

        The tail starts at ``timestamp_offset - frames_offset`` seconds into
        the buffer (clamped at 0).

        Returns:
            tuple: ``(input_bytes, duration)`` where ``input_bytes`` is the
            copied sample array and ``duration`` its length in seconds.
        """
        with self.lock:
            samples_take = max(0, (self.timestamp_offset - self.frames_offset) * self.RATE)
            input_bytes = self.frames_np[int(samples_take):].copy()
        duration = input_bytes.shape[0] / self.RATE
        return input_bytes, duration

    def prepare_segments(self, last_segment=None):
        """Collect the segments that should be sent to the client.

        Args:
            last_segment: Optional segment appended after the committed ones.

        Returns:
            list: A new list holding at most ``send_last_n_segments`` entries
            of ``self.transcript`` followed by ``last_segment`` if given.
        """
        # A negative-start slice already yields the whole list when it is
        # shorter than the window, so no length check is required.
        segments = list(self.transcript[-self.send_last_n_segments:])
        if last_segment is not None:
            segments = segments + [last_segment]
        logging.info(f"{segments}")
        return segments

    def get_audio_chunk_duration(self, input_bytes):
        """Return the duration of an audio chunk in seconds.

        Args:
            input_bytes (numpy.ndarray): The audio chunk to measure.

        Returns:
            float: Number of samples along axis 0 divided by ``RATE``.
        """
        return input_bytes.shape[0] / self.RATE

    def send_transcription_to_client(self, segments):
        """Send transcription segments to the client as a JSON message.

        Any exception raised while serialising or sending is logged and
        swallowed.

        Args:
            segments (list): Transcription segments to be sent to the client.
        """
        try:
            self.websocket.send(
                json.dumps({
                    "uid": self.client_uid,
                    "segments": segments,
                })
            )
        except Exception as e:
            logging.error(f"[ERROR]: Sending data to client: {e}")

    def disconnect(self):
        """Send the ``DISCONNECT`` message to the client."""
        self.websocket.send(json.dumps({
            "uid": self.client_uid,
            "message": self.DISCONNECT
        }))

    def cleanup(self):
        """Set the ``exit`` flag, which ``speech_to_text`` loops are expected to poll."""
        logging.info("Cleaning up.")
        self.exit = True
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
class ServeClientWhisperCPP(ServeClientBase):
    """Transcription client backed by a pywhispercpp ``Model``."""

    # Model shared by all clients created with ``single_model=True``.
    SINGLE_MODEL = None
    # Serialises access to the model while a shared model exists.
    SINGLE_MODEL_LOCK = threading.Lock()

    def __init__(self, websocket, language=None, client_uid=None,
                 single_model=False):
        """Create the model, start the worker thread and greet the client.

        Args:
            websocket (WebSocket): The WebSocket connection for the client.
            language (str, optional): The language for transcription. Defaults to None.
            client_uid (str, optional): A unique identifier for the client. Defaults to None.
            single_model (bool, optional): When True, one model is created by
                the first client and reused by later ones; when False every
                client instantiates its own model. Defaults to False.
        """
        super().__init__(client_uid, websocket)
        self.language = language
        self.eos = False

        if single_model:
            if ServeClientWhisperCPP.SINGLE_MODEL is None:
                self.create_model()
                ServeClientWhisperCPP.SINGLE_MODEL = self.transcriber
            else:
                self.transcriber = ServeClientWhisperCPP.SINGLE_MODEL
        else:
            self.create_model()

        # threading
        logging.info('Create a thread to process audio.')
        self.trans_thread = threading.Thread(target=self.speech_to_text)
        self.trans_thread.start()

        self.websocket.send(json.dumps({
            "uid": self.client_uid,
            "message": self.SERVER_READY,
            "backend": "pywhispercpp"
        }))

    def create_model(self, warmup=True):
        """Instantiate a new model, store it as ``self.transcriber`` and optionally warm it up."""
        self.transcriber = Model(model=config.WHISPER_MODEL, models_dir=config.MODEL_DIR)
        if warmup:
            self.warmup()

    def warmup(self, warmup_steps=1):
        """Run the model on a bundled sample before serving real audio.

        Args:
            warmup_steps (int): Number of times ``assets/jfk.flac`` is transcribed.
        """
        logging.info("[INFO:] Warming up whisper.cpp engine..")
        mel, _, = soundfile.read("assets/jfk.flac")
        for _step in range(warmup_steps):
            self.transcriber.transcribe(mel, print_progress=False)

    def set_eos(self, eos):
        """Set the End of Speech (EOS) flag under the instance lock.

        Args:
            eos (bool): The value to set for the EOS flag.
        """
        with self.lock:
            self.eos = eos

    def handle_transcription_output(self, last_segment, duration):
        """Send the current text to the client and, at end of speech, commit it.

        Args:
            last_segment (str): Text produced for the most recent audio chunk.
            duration (float): Duration of the transcribed audio chunk.
        """
        segments = self.prepare_segments({"text": last_segment})
        self.send_transcription_to_client(segments)
        if self.eos:
            self.update_timestamp_offset(last_segment, duration)

    def transcribe_audio(self, input_bytes):
        """Transcribe one audio chunk and forward any non-empty text.

        Args:
            input_bytes (np.array): The audio chunk to transcribe.
        """
        # Decide once whether the shared lock is needed: re-reading
        # SINGLE_MODEL at release time could disagree with this check if
        # another client installs the shared model in between.
        use_shared_lock = bool(ServeClientWhisperCPP.SINGLE_MODEL)
        if use_shared_lock:
            ServeClientWhisperCPP.SINGLE_MODEL_LOCK.acquire()
        try:
            logging.info(f"[pywhispercpp:] Processing audio with duration: {input_bytes.shape[0] / self.RATE}")
            duration = librosa.get_duration(y=input_bytes, sr=self.RATE)

            if self.language == "zh":
                prompt = '以下是简体中文普通话的句子。'
            else:
                prompt = 'The following is an English sentence.'

            segments = self.transcriber.transcribe(
                input_bytes,
                language=self.language,
                initial_prompt=prompt,
                token_timestamps=True,
                # max_len=max_len,
                print_progress=False
            )
            last_segment = ' '.join(segment.text for segment in segments)

            logging.info(f"[pywhispercpp:] Last segment: {last_segment}")
        finally:
            # Always release: the caller swallows exceptions and loops, so a
            # leaked lock would block every later transcription forever.
            if use_shared_lock:
                ServeClientWhisperCPP.SINGLE_MODEL_LOCK.release()

        if last_segment:
            self.handle_transcription_output(last_segment, duration)

    def update_timestamp_offset(self, last_segment, duration):
        """Commit ``last_segment`` to the transcript and advance the offset.

        The segment is appended (with a trailing space) unless it equals the
        stripped text of the last committed entry.

        Args:
            last_segment (str): Last transcribed audio from the whisper model.
            duration (float): Duration of the last audio chunk.
        """
        if not self.transcript or self.transcript[-1]["text"].strip() != last_segment:
            self.transcript.append({"text": last_segment + " "})

        logging.info(f'Transcript list context: {self.transcript}')

        with self.lock:
            self.timestamp_offset += duration

    def speech_to_text(self):
        """Worker loop: transcribe buffered audio until ``self.exit`` is set.

        Each iteration clips an over-long backlog, takes the unprocessed tail
        of the buffer and transcribes it once at least one second is
        available.  Exceptions raised during transcription are logged and the
        loop continues.
        """
        while True:
            if self.exit:
                logging.info("Exiting speech to text thread")
                break

            if self.frames_np is None:
                time.sleep(0.02)  # wait for any audio to arrive
                continue

            self.clip_audio_if_no_valid_segment()

            input_bytes, duration = self.get_audio_chunk_for_processing()
            if duration < 1:
                # NOTE(review): no sleep here, so the loop spins until one
                # second of audio is buffered - confirm this is intended.
                continue

            try:
                input_sample = input_bytes.copy()
                logging.info(f"[pywhispercpp:] Processing audio with duration: {duration}")
                self.transcribe_audio(input_sample)

            except Exception as e:
                logging.error(f"[ERROR]: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
transcribe/strategy.py
DELETED
|
@@ -1,405 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
import collections
|
| 3 |
-
import logging
|
| 4 |
-
from difflib import SequenceMatcher
|
| 5 |
-
from itertools import chain
|
| 6 |
-
from dataclasses import dataclass, field
|
| 7 |
-
from typing import List, Tuple, Optional, Deque, Any, Iterator,Literal
|
| 8 |
-
from config import SENTENCE_END_MARKERS, ALL_MARKERS,SENTENCE_END_PATTERN,REGEX_MARKERS, PAUSEE_END_PATTERN,SAMPLE_RATE
|
| 9 |
-
from enum import Enum
|
| 10 |
-
import wordninja
|
| 11 |
-
import config
|
| 12 |
-
import re
|
| 13 |
-
logger = logging.getLogger("TranscriptionStrategy")
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
class SplitMode(Enum):
    """Kinds of marker on which a transcript chunk can be split."""

    PUNCTUATION = "punctuation"
    PAUSE = "pause"
    END = "end"
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
@dataclass
class TranscriptResult:
    """One result record carrying a segment id, a cut index and the text."""

    seg_id: int = 0
    cut_index: int = 0
    is_end_sentence: bool = False
    context: str = ""

    def partial(self):
        """Return True while ``is_end_sentence`` is not set."""
        finished = self.is_end_sentence
        return not finished
|
| 32 |
-
|
| 33 |
-
@dataclass
class TranscriptToken:
    """A single transcribed token together with its timing information."""

    text: str  # transcribed text content
    t0: int  # start time (hundredths of a second)
    t1: int  # end time (hundredths of a second)

    def _found(self, pattern) -> bool:
        """Return True if ``pattern`` matches anywhere in the stripped text."""
        return pattern.search(self.text.strip()) is not None

    def is_punctuation(self):
        """Check whether the text contains a punctuation marker."""
        return self._found(REGEX_MARKERS)

    def is_end(self):
        """Check whether the text contains a sentence-end marker."""
        return self._found(SENTENCE_END_PATTERN)

    def is_pause(self):
        """Check whether the text contains a pause marker."""
        return self._found(PAUSEE_END_PATTERN)

    def buffer_index(self) -> int:
        """Convert the end time ``t1`` into a sample index at ``SAMPLE_RATE``."""
        end_seconds = self.t1 / 100
        return int(end_seconds * SAMPLE_RATE)
|
| 54 |
-
|
| 55 |
-
@dataclass
class TranscriptChunk:
    """An ordered group of transcript tokens that can be split and compared."""

    separator: str = ""  # string placed between token texts when joining
    items: "list[TranscriptToken]" = field(default_factory=list)  # the tokens

    @staticmethod
    def _calculate_similarity(text1: str, text2: str) -> float:
        """Return the ``SequenceMatcher`` ratio of the two texts."""
        return SequenceMatcher(None, text1, text2).ratio()

    def split_by(self, mode: "SplitMode") -> "list[TranscriptChunk]":
        """Split the tokens after every token matching ``mode``.

        Args:
            mode: ``PUNCTUATION``, ``PAUSE`` or ``END``; selects which token
                predicate marks a split point.

        Returns:
            The resulting chunks, without those for which
            ``only_punctuation()`` is true (which includes empty chunks).

        Raises:
            ValueError: If ``mode`` is not one of the supported modes.
        """
        if mode == SplitMode.PUNCTUATION:
            indexes = [i for i, seg in enumerate(self.items) if seg.is_punctuation()]
        elif mode == SplitMode.PAUSE:
            indexes = [i for i, seg in enumerate(self.items) if seg.is_pause()]
        elif mode == SplitMode.END:
            indexes = [i for i, seg in enumerate(self.items) if seg.is_end()]
        else:
            raise ValueError(f"Unsupported mode: {mode}")

        # Shift every split point one index to the right so that the marker
        # token stays with the chunk that precedes it.
        cut_points = [0] + sorted(i + 1 for i in indexes) + [len(self.items)]
        chunks = (
            TranscriptChunk(items=self.items[start:end], separator=self.separator)
            for start, end in zip(cut_points, cut_points[1:])
        )
        return [ck for ck in chunks if not ck.only_punctuation()]

    def get_split_first_rest(self, mode: "SplitMode"):
        """Split by ``mode`` and return ``(first_chunk, remaining_chunks)``.

        When the split yields no chunk, ``(self, None)`` is returned.
        """
        chunks = self.split_by(mode)
        if not chunks:
            return self, None
        return chunks[0], chunks[1:]

    def puncation_numbers(self) -> int:
        """Count the tokens that contain punctuation."""
        return sum(1 for seg in self.items if seg.is_punctuation())

    def length(self) -> int:
        """Return the number of tokens."""
        return len(self.items)

    def join(self) -> str:
        """Join the token texts with ``separator``."""
        return self.separator.join(seg.text for seg in self.items)

    def compare(self, chunk: "Optional[TranscriptChunk]" = None) -> float:
        """Return the similarity between this chunk's text and ``chunk``'s.

        Returns 0 when ``chunk`` is falsy (e.g. ``None``).
        """
        if not chunk:
            return 0
        return self._calculate_similarity(self.join(), chunk.join())

    def only_punctuation(self) -> bool:
        """Return True if every token is punctuation (also true when empty)."""
        return all(seg.is_punctuation() for seg in self.items)

    def has_punctuation(self) -> bool:
        """Return True if at least one token is punctuation."""
        return any(seg.is_punctuation() for seg in self.items)

    def get_buffer_index(self) -> int:
        """Return the buffer index of the last token, or 0 for an empty chunk."""
        # Guard: indexing items[-1] on an empty chunk would raise IndexError.
        if not self.items:
            return 0
        return self.items[-1].buffer_index()

    def is_end_sentence(self) -> bool:
        """Return True if the last token is a sentence end; False when empty."""
        if not self.items:
            return False
        return self.items[-1].is_end()
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
class TranscriptHistory:
    """Keeps the two most recently added transcript chunks."""

    def __init__(self) -> None:
        # Newest chunk sits at index 0 because add() uses appendleft; the
        # maxlen of 2 drops anything older than the previous chunk.
        self.history = collections.deque(maxlen=2)

    def add(self, chunk: "TranscriptChunk"):
        """Record ``chunk`` as the newest history entry."""
        self.history.appendleft(chunk)

    def previous_chunk(self) -> "Optional[TranscriptChunk]":
        """Return the chunk added before the newest one, or None if there is none."""
        return self.history[1] if len(self.history) == 2 else None

    def lastest_chunk(self):
        """Return the most recently added chunk, or None when history is empty.

        Index 0 is used because ``add`` prepends; index -1 would be the older
        entry once two chunks are stored.
        """
        return self.history[0] if self.history else None

    def clear(self):
        """Forget all recorded chunks."""
        self.history.clear()
|
| 150 |
-
|
| 151 |
-
class TranscriptBuffer:
    """Three-level store for transcribed text: pending -> line -> paragraph.

    Layout of the text being tracked::

        |-- confirmed text --|-- observation window --|-- new input --|

    ``_buffer`` holds the pending (still changing) string, ``_sentences`` the
    committed short lines of the current paragraph and ``_segments`` the last
    two committed paragraphs.
    """

    def __init__(self, source_lang: str, separator: str):
        """Create an empty buffer.

        Args:
            source_lang: Source language code; ``"en"`` enables word
                re-segmentation in ``update_and_commit``.
            separator: String that separates words in this language; when it
                is empty, lengths are counted in characters instead of words.
        """
        self._segments: Deque[str] = collections.deque(maxlen=2)  # committed paragraphs
        self._sentences: Deque[str] = collections.deque()  # lines of the current paragraph
        self._buffer: str = ""  # pending text
        self._current_seg_id: int = 0
        self.source_language = source_lang
        self._separator = separator

    def _unit_count(self, text: str) -> int:
        """Count ``text`` in separator-delimited pieces, or characters if no separator."""
        if self._separator:
            return len(text.split(self._separator))
        return len(text)

    def get_seg_id(self) -> int:
        """Return the id of the paragraph currently being built."""
        return self._current_seg_id

    @property
    def current_sentences_length(self) -> int:
        """Total length of all committed lines, each counted separately."""
        return sum(self._unit_count(item) for item in self._sentences)

    def update_pending_text(self, text: str) -> None:
        """Replace the pending text."""
        self._buffer = text

    def commit_line(self) -> None:
        """Move non-empty pending text into the list of lines and reset it."""
        if self._buffer:
            self._sentences.append(self._buffer)
            self._buffer = ""

    def commit_paragraph(self) -> None:
        """Concatenate all committed lines into one paragraph and store it.

        The line list is emptied; nothing is stored when there were no lines.
        """
        current_sentences = list(self._sentences)
        self._sentences.clear()
        if current_sentences:
            self._segments.append("".join(current_sentences))
        logger.debug(f"=== count to paragraph ===")
        logger.debug(f"push: {current_sentences}")
        logger.debug(f"rest: {self._sentences}")

    def rebuild(self, text):
        """Strip the separator from ``text`` and re-segment it into spaced words."""
        output = self.split_and_join(text.replace(self._separator, ""))

        logger.debug("==== rebuild string ====")
        logger.debug(text)
        logger.debug(output)

        return output

    @staticmethod
    def split_and_join(text):
        """Split runs of non-marker characters into words and re-join them.

        Characters in ``ALL_MARKERS`` become tokens of their own; everything
        between them is passed to ``wordninja.split``.  Words are prefixed
        with a space and markers suffixed with one (the first token gets
        neither prefix nor the word-space).

        NOTE(review): a word that follows a marker therefore ends up after
        two spaces - confirm whether that is intended.
        """
        tokens = []
        word_chars = []

        def flush_word():
            # Hand the collected run to wordninja and start a new run.
            if word_chars:
                tokens.extend(wordninja.split("".join(word_chars)))
                word_chars.clear()

        for char in text:
            if char in ALL_MARKERS:
                flush_word()
                tokens.append(char)
            else:
                word_chars.append(char)
        flush_word()

        pieces = []
        for i, token in enumerate(tokens):
            if i == 0:
                pieces.append(token)
            elif token in ALL_MARKERS:
                pieces.append(token + " ")
            else:
                pieces.append(" " + token)
        return "".join(pieces)

    def update_and_commit(self, stable_strings: List[str], remaining_strings: List[str], is_end_sentence=False):
        """Commit the stable strings as lines and keep the rest pending.

        Args:
            stable_strings: Texts committed one by one as lines.
            remaining_strings: Texts concatenated into the new pending text.
            is_end_sentence: When True, a paragraph is committed if the
                uncommitted text has reached ``config.TEXT_THREHOLD``.

        Returns:
            bool: True if a paragraph was committed (the segment id is then
            incremented), otherwise False.
        """
        if self.source_language == "en":
            stable_strings = [self.rebuild(i) for i in stable_strings]
            remaining_strings = [self.rebuild(i) for i in remaining_strings]
        remaining_string = "".join(remaining_strings)

        logger.debug(f"{self.__dict__}")

        for stable_str in stable_strings:
            self.update_pending_text(stable_str)
            self.commit_line()

        if not is_end_sentence:
            self.update_pending_text(remaining_string)
            return False

        # Measure before the remainder becomes the pending text, so that the
        # threshold is applied to the text committed so far.
        current_text_len = self._unit_count(self.current_not_commit_text)
        self.update_pending_text(remaining_string)
        if current_text_len >= config.TEXT_THREHOLD:
            self.commit_paragraph()
            self._current_seg_id += 1
            return True
        return False

    @property
    def un_commit_paragraph(self) -> str:
        """Concatenation of the lines committed for the current paragraph."""
        return "".join(self._sentences)

    @property
    def pending_text(self) -> str:
        """The current pending text."""
        return self._buffer

    @property
    def latest_paragraph(self) -> str:
        """The most recently committed paragraph, or ``""`` if there is none."""
        return self._segments[-1] if self._segments else ""

    @property
    def current_not_commit_text(self) -> str:
        """Committed lines of the current paragraph followed by the pending text."""
        return self.un_commit_paragraph + self.pending_text
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
class TranscriptStabilityAnalyzer:
    """Decides which part of successive transcriptions can be committed.

    Every call to ``analysis`` compares the new token list with the previous
    one and yields ``TranscriptResult`` objects describing the text to show
    and, when text was committed, the buffer index at which audio can be cut.
    """

    def __init__(self, source_lang, separator) -> None:
        """Create the analyser with an empty buffer and history.

        Args:
            source_lang: Source language code passed to ``TranscriptBuffer``.
            separator: Word separator used when joining token texts.
        """
        self._transcript_buffer = TranscriptBuffer(source_lang=source_lang, separator=separator)
        self._transcript_history = TranscriptHistory()
        self._separator = separator
        logger.debug(f"Current separator: {self._separator}")

    def merge_chunks(self, chunks: "List[TranscriptChunk]") -> List[str]:
        """Return the joined text of every truthy chunk, or ``[""]`` if there are none."""
        if not chunks:
            return [""]
        return [chunk.join() for chunk in chunks if chunk]

    def analysis(self, current: "List[TranscriptToken]", buffer_duration: float) -> "Iterator[TranscriptResult]":
        """Analyse a new transcription of the current audio buffer.

        Args:
            current: Tokens of the newest transcription; wrapped into a
                ``TranscriptChunk`` here.
            buffer_duration: Length of the audio buffer; values up to 4 use
                the short-buffer strategy, larger ones the long-buffer one.

        Yields:
            TranscriptResult: Text to display and optional cut information.
        """
        current = TranscriptChunk(items=current, separator=self._separator)
        self._transcript_history.add(current)

        prev = self._transcript_history.previous_chunk()
        self._transcript_buffer.update_pending_text(current.join())
        if not prev:
            # No history means this is a new utterance: emit it directly.
            yield TranscriptResult(
                context=self._transcript_buffer.current_not_commit_text,
                seg_id=self._transcript_buffer.get_seg_id()
            )
            return

        if buffer_duration <= 4:
            yield from self._handle_short_buffer(current, prev)
        else:
            yield from self._handle_long_buffer(current)

    def _handle_short_buffer(self, curr: "TranscriptChunk", prev: "TranscriptChunk") -> "Iterator[TranscriptResult]":
        """Commit the first punctuation-delimited chunk once it has stabilised.

        The first chunk is committed when it contains punctuation and its
        similarity to the previous transcription's first chunk is at least
        0.8; otherwise the uncommitted text is emitted unchanged.
        """
        curr_first, curr_rest = curr.get_split_first_rest(SplitMode.PUNCTUATION)
        prev_first, _ = prev.get_split_first_rest(SplitMode.PUNCTUATION)

        if curr_first and prev_first:
            score = curr_first.compare(prev_first)
            has_punctuation = curr_first.has_punctuation()
            if score >= 0.8 and has_punctuation:
                yield from self._yield_commit_results(curr_first, curr_rest, curr_first.is_end_sentence())
                return

        yield TranscriptResult(
            seg_id=self._transcript_buffer.get_seg_id(),
            context=self._transcript_buffer.current_not_commit_text
        )

    def _handle_long_buffer(self, curr: "TranscriptChunk") -> "Iterator[TranscriptResult]":
        """Commit everything but the last punctuation-delimited chunk.

        With a single chunk nothing is committed and the uncommitted text is
        emitted unchanged.
        """
        chunks = curr.split_by(SplitMode.PUNCTUATION)
        if len(chunks) > 1:
            stable, remaining = chunks[:-1], chunks[-1:]
            yield from self._yield_commit_results(
                stable, remaining, is_end_sentence=True  # hard-coded to True for now
            )
        else:
            yield TranscriptResult(
                seg_id=self._transcript_buffer.get_seg_id(),
                context=self._transcript_buffer.current_not_commit_text
            )

    def _yield_commit_results(self, stable_chunk, remaining_chunks, is_end_sentence: bool) -> "Iterator[TranscriptResult]":
        """Commit the stable text and yield the resulting results.

        Args:
            stable_chunk: A single chunk or a list of chunks to commit.
            remaining_chunks: Chunks kept as pending text (may be None).
            is_end_sentence: Forwarded to ``TranscriptBuffer.update_and_commit``.

        Yields:
            When a paragraph was committed: the finished paragraph (with the
            pre-commit segment id and ``is_end_sentence=True``) and, if any
            non-blank text remains, that remainder under the new segment id.
            Otherwise a single result with the uncommitted text.
        """
        stable_str_list = [stable_chunk.join()] if hasattr(stable_chunk, "join") else self.merge_chunks(stable_chunk)
        remaining_str_list = self.merge_chunks(remaining_chunks)
        # The audio can be cut at the end of the last stable token.
        frame_cut_index = stable_chunk[-1].get_buffer_index() if isinstance(stable_chunk, list) else stable_chunk.get_buffer_index()

        prev_seg_id = self._transcript_buffer.get_seg_id()
        commit_paragraph = self._transcript_buffer.update_and_commit(stable_str_list, remaining_str_list, is_end_sentence)
        logger.debug(f"current buffer: {self._transcript_buffer.__dict__}")

        if commit_paragraph:
            # A new paragraph was produced: emit it as a finished line.
            yield TranscriptResult(
                seg_id=prev_seg_id,
                cut_index=frame_cut_index,
                context=self._transcript_buffer.latest_paragraph,
                is_end_sentence=True
            )
            if (context := self._transcript_buffer.current_not_commit_text.strip()):
                yield TranscriptResult(
                    seg_id=self._transcript_buffer.get_seg_id(),
                    context=context,
                )
        else:
            yield TranscriptResult(
                seg_id=self._transcript_buffer.get_seg_id(),
                cut_index=frame_cut_index,
                context=self._transcript_buffer.current_not_commit_text,
            )
|
| 405 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
transcribe/transcription.py
DELETED
|
@@ -1,334 +0,0 @@
|
|
| 1 |
-
import logging
|
| 2 |
-
import time
|
| 3 |
-
import functools
|
| 4 |
-
import json
|
| 5 |
-
import logging
|
| 6 |
-
import time
|
| 7 |
-
from enum import Enum
|
| 8 |
-
from typing import List, Optional
|
| 9 |
-
import numpy as np
|
| 10 |
-
from .server import ServeClientBase
|
| 11 |
-
from .whisper_llm_serve import PyWhiperCppServe
|
| 12 |
-
from .vad import VoiceActivityDetector
|
| 13 |
-
from urllib.parse import urlparse, parse_qsl
|
| 14 |
-
from websockets.exceptions import ConnectionClosed
|
| 15 |
-
from websockets.sync.server import serve
|
| 16 |
-
from uuid import uuid1
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
logging.basicConfig(level=logging.INFO)
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
class ClientManager:
|
| 23 |
-
def __init__(self, max_clients=4, max_connection_time=600):
|
| 24 |
-
"""
|
| 25 |
-
Initializes the ClientManager with specified limits on client connections and connection durations.
|
| 26 |
-
|
| 27 |
-
Args:
|
| 28 |
-
max_clients (int, optional): The maximum number of simultaneous client connections allowed. Defaults to 4.
|
| 29 |
-
max_connection_time (int, optional): The maximum duration (in seconds) a client can stay connected. Defaults
|
| 30 |
-
to 600 seconds (10 minutes).
|
| 31 |
-
"""
|
| 32 |
-
self.clients = {}
|
| 33 |
-
self.start_times = {}
|
| 34 |
-
self.max_clients = max_clients
|
| 35 |
-
self.max_connection_time = max_connection_time
|
| 36 |
-
|
| 37 |
-
def add_client(self, websocket, client):
|
| 38 |
-
"""
|
| 39 |
-
Adds a client and their connection start time to the tracking dictionaries.
|
| 40 |
-
|
| 41 |
-
Args:
|
| 42 |
-
websocket: The websocket associated with the client to add.
|
| 43 |
-
client: The client object to be added and tracked.
|
| 44 |
-
"""
|
| 45 |
-
self.clients[websocket] = client
|
| 46 |
-
self.start_times[websocket] = time.time()
|
| 47 |
-
|
| 48 |
-
def get_client(self, websocket):
|
| 49 |
-
"""
|
| 50 |
-
Retrieves a client associated with the given websocket.
|
| 51 |
-
|
| 52 |
-
Args:
|
| 53 |
-
websocket: The websocket associated with the client to retrieve.
|
| 54 |
-
|
| 55 |
-
Returns:
|
| 56 |
-
The client object if found, False otherwise.
|
| 57 |
-
"""
|
| 58 |
-
if websocket in self.clients:
|
| 59 |
-
return self.clients[websocket]
|
| 60 |
-
return False
|
| 61 |
-
|
| 62 |
-
def remove_client(self, websocket):
|
| 63 |
-
"""
|
| 64 |
-
Removes a client and their connection start time from the tracking dictionaries. Performs cleanup on the
|
| 65 |
-
client if necessary.
|
| 66 |
-
|
| 67 |
-
Args:
|
| 68 |
-
websocket: The websocket associated with the client to be removed.
|
| 69 |
-
"""
|
| 70 |
-
client = self.clients.pop(websocket, None)
|
| 71 |
-
if client:
|
| 72 |
-
client.cleanup()
|
| 73 |
-
self.start_times.pop(websocket, None)
|
| 74 |
-
|
| 75 |
-
def get_wait_time(self):
|
| 76 |
-
"""
|
| 77 |
-
Calculates the estimated wait time for new clients based on the remaining connection times of current clients.
|
| 78 |
-
|
| 79 |
-
Returns:
|
| 80 |
-
The estimated wait time in minutes for new clients to connect. Returns 0 if there are available slots.
|
| 81 |
-
"""
|
| 82 |
-
wait_time = None
|
| 83 |
-
for start_time in self.start_times.values():
|
| 84 |
-
current_client_time_remaining = self.max_connection_time - (time.time() - start_time)
|
| 85 |
-
if wait_time is None or current_client_time_remaining < wait_time:
|
| 86 |
-
wait_time = current_client_time_remaining
|
| 87 |
-
return wait_time / 60 if wait_time is not None else 0
|
| 88 |
-
|
| 89 |
-
def is_server_full(self, websocket, options):
|
| 90 |
-
"""
|
| 91 |
-
Checks if the server is at its maximum client capacity and sends a wait message to the client if necessary.
|
| 92 |
-
|
| 93 |
-
Args:
|
| 94 |
-
websocket: The websocket of the client attempting to connect.
|
| 95 |
-
options: A dictionary of options that may include the client's unique identifier.
|
| 96 |
-
|
| 97 |
-
Returns:
|
| 98 |
-
True if the server is full, False otherwise.
|
| 99 |
-
"""
|
| 100 |
-
if len(self.clients) >= self.max_clients:
|
| 101 |
-
wait_time = self.get_wait_time()
|
| 102 |
-
response = {"uid": options["uid"], "status": "WAIT", "message": wait_time}
|
| 103 |
-
websocket.send(json.dumps(response))
|
| 104 |
-
return True
|
| 105 |
-
return False
|
| 106 |
-
|
| 107 |
-
def is_client_timeout(self, websocket):
|
| 108 |
-
"""
|
| 109 |
-
Checks if a client has exceeded the maximum allowed connection time and disconnects them if so, issuing a warning.
|
| 110 |
-
|
| 111 |
-
Args:
|
| 112 |
-
websocket: The websocket associated with the client to check.
|
| 113 |
-
|
| 114 |
-
Returns:
|
| 115 |
-
True if the client's connection time has exceeded the maximum limit, False otherwise.
|
| 116 |
-
"""
|
| 117 |
-
elapsed_time = time.time() - self.start_times[websocket]
|
| 118 |
-
if elapsed_time >= self.max_connection_time:
|
| 119 |
-
self.clients[websocket].disconnect()
|
| 120 |
-
logging.warning(f"Client with uid '{self.clients[websocket].client_uid}' disconnected due to overtime.")
|
| 121 |
-
return True
|
| 122 |
-
return False
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
class BackendType(Enum):
|
| 126 |
-
PYWHISPERCPP = "pywhispercpp"
|
| 127 |
-
|
| 128 |
-
@staticmethod
|
| 129 |
-
def valid_types() -> List[str]:
|
| 130 |
-
return [backend_type.value for backend_type in BackendType]
|
| 131 |
-
|
| 132 |
-
@staticmethod
|
| 133 |
-
def is_valid(backend: str) -> bool:
|
| 134 |
-
return backend in BackendType.valid_types()
|
| 135 |
-
|
| 136 |
-
def is_pywhispercpp(self) -> bool:
|
| 137 |
-
return self == BackendType.PYWHISPERCPP
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
class TranscriptionServer:
|
| 141 |
-
RATE = 16000
|
| 142 |
-
|
| 143 |
-
def __init__(self):
|
| 144 |
-
self.client_manager = None
|
| 145 |
-
self.no_voice_activity_chunks = 0
|
| 146 |
-
self.single_model = False
|
| 147 |
-
|
| 148 |
-
def initialize_client(
|
| 149 |
-
self, websocket, options
|
| 150 |
-
):
|
| 151 |
-
client: Optional[ServeClientBase] = None
|
| 152 |
-
|
| 153 |
-
if self.backend.is_pywhispercpp():
|
| 154 |
-
client = PyWhiperCppServe(
|
| 155 |
-
websocket,
|
| 156 |
-
language=options["language"],
|
| 157 |
-
client_uid=options["uid"],
|
| 158 |
-
)
|
| 159 |
-
logging.info("Running pywhispercpp backend.")
|
| 160 |
-
|
| 161 |
-
if client is None:
|
| 162 |
-
raise ValueError(f"Backend type {self.backend.value} not recognised or not handled.")
|
| 163 |
-
|
| 164 |
-
self.client_manager.add_client(websocket, client)
|
| 165 |
-
|
| 166 |
-
def get_audio_from_websocket(self, websocket):
|
| 167 |
-
"""
|
| 168 |
-
Receives audio buffer from websocket and creates a numpy array out of it.
|
| 169 |
-
|
| 170 |
-
Args:
|
| 171 |
-
websocket: The websocket to receive audio from.
|
| 172 |
-
|
| 173 |
-
Returns:
|
| 174 |
-
A numpy array containing the audio.
|
| 175 |
-
"""
|
| 176 |
-
frame_data = websocket.recv()
|
| 177 |
-
if frame_data == b"END_OF_AUDIO":
|
| 178 |
-
return False
|
| 179 |
-
return np.frombuffer(frame_data, dtype=np.int16).astype(np.float32) / 32768.0
|
| 180 |
-
# return np.frombuffer(frame_data, dtype=np.float32)
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
def handle_new_connection(self, websocket):
|
| 184 |
-
query_parameters_dict = dict(parse_qsl(urlparse(websocket.request.path).query))
|
| 185 |
-
from_lang, to_lang = query_parameters_dict.get('from'), query_parameters_dict.get('to')
|
| 186 |
-
|
| 187 |
-
try:
|
| 188 |
-
logging.info("New client connected")
|
| 189 |
-
options = websocket.recv()
|
| 190 |
-
try:
|
| 191 |
-
options = json.loads(options)
|
| 192 |
-
except Exception as e:
|
| 193 |
-
options = {"language": from_lang, "uid": str(uuid1())}
|
| 194 |
-
if self.client_manager is None:
|
| 195 |
-
max_clients = options.get('max_clients', 4)
|
| 196 |
-
max_connection_time = options.get('max_connection_time', 600)
|
| 197 |
-
self.client_manager = ClientManager(max_clients, max_connection_time)
|
| 198 |
-
|
| 199 |
-
if self.client_manager.is_server_full(websocket, options):
|
| 200 |
-
websocket.close()
|
| 201 |
-
return False # Indicates that the connection should not continue
|
| 202 |
-
|
| 203 |
-
if self.backend.is_pywhispercpp():
|
| 204 |
-
self.vad_detector = VoiceActivityDetector(frame_rate=self.RATE)
|
| 205 |
-
|
| 206 |
-
self.initialize_client(websocket, options)
|
| 207 |
-
if from_lang and to_lang:
|
| 208 |
-
self.set_lang(websocket, from_lang, to_lang)
|
| 209 |
-
logging.info(f"Source lange: {from_lang} -> Dst lange: {to_lang}")
|
| 210 |
-
return True
|
| 211 |
-
except json.JSONDecodeError:
|
| 212 |
-
logging.error("Failed to decode JSON from client")
|
| 213 |
-
return False
|
| 214 |
-
except ConnectionClosed:
|
| 215 |
-
logging.info("Connection closed by client")
|
| 216 |
-
return False
|
| 217 |
-
except Exception as e:
|
| 218 |
-
logging.error(f"Error during new connection initialization: {str(e)}")
|
| 219 |
-
return False
|
| 220 |
-
|
| 221 |
-
def process_audio_frames(self, websocket):
|
| 222 |
-
frame_np = self.get_audio_from_websocket(websocket)
|
| 223 |
-
client = self.client_manager.get_client(websocket)
|
| 224 |
-
|
| 225 |
-
# TODO Vad has some problem, it will be blocking process loop
|
| 226 |
-
# if frame_np is False:
|
| 227 |
-
# if self.backend.is_pywhispercpp():
|
| 228 |
-
# client.set_eos(True)
|
| 229 |
-
# return False
|
| 230 |
-
|
| 231 |
-
# if self.backend.is_pywhispercpp():
|
| 232 |
-
# voice_active = self.voice_activity(websocket, frame_np)
|
| 233 |
-
# if voice_active:
|
| 234 |
-
# self.no_voice_activity_chunks = 0
|
| 235 |
-
# client.set_eos(False)
|
| 236 |
-
# if self.use_vad and not voice_active:
|
| 237 |
-
# return True
|
| 238 |
-
|
| 239 |
-
client.add_frames(frame_np)
|
| 240 |
-
return True
|
| 241 |
-
|
| 242 |
-
def set_lang(self, websocket, src_lang, dst_lang):
|
| 243 |
-
client = self.client_manager.get_client(websocket)
|
| 244 |
-
if isinstance(client, PyWhiperCppServe):
|
| 245 |
-
client.set_lang(src_lang, dst_lang)
|
| 246 |
-
|
| 247 |
-
def recv_audio(self,
|
| 248 |
-
websocket,
|
| 249 |
-
backend: BackendType = BackendType.PYWHISPERCPP):
|
| 250 |
-
|
| 251 |
-
self.backend = backend
|
| 252 |
-
if not self.handle_new_connection(websocket):
|
| 253 |
-
return
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
try:
|
| 257 |
-
while not self.client_manager.is_client_timeout(websocket):
|
| 258 |
-
if not self.process_audio_frames(websocket):
|
| 259 |
-
break
|
| 260 |
-
except ConnectionClosed:
|
| 261 |
-
logging.info("Connection closed by client")
|
| 262 |
-
except Exception as e:
|
| 263 |
-
logging.error(f"Unexpected error: {str(e)}")
|
| 264 |
-
finally:
|
| 265 |
-
if self.client_manager.get_client(websocket):
|
| 266 |
-
self.cleanup(websocket)
|
| 267 |
-
websocket.close()
|
| 268 |
-
del websocket
|
| 269 |
-
|
| 270 |
-
def run(self,
|
| 271 |
-
host,
|
| 272 |
-
port=9090,
|
| 273 |
-
backend="pywhispercpp"):
|
| 274 |
-
"""
|
| 275 |
-
Run the transcription server.
|
| 276 |
-
|
| 277 |
-
Args:
|
| 278 |
-
host (str): The host address to bind the server.
|
| 279 |
-
port (int): The port number to bind the server.
|
| 280 |
-
"""
|
| 281 |
-
|
| 282 |
-
if not BackendType.is_valid(backend):
|
| 283 |
-
raise ValueError(f"{backend} is not a valid backend type. Choose backend from {BackendType.valid_types()}")
|
| 284 |
-
|
| 285 |
-
with serve(
|
| 286 |
-
functools.partial(
|
| 287 |
-
self.recv_audio,
|
| 288 |
-
backend=BackendType(backend),
|
| 289 |
-
),
|
| 290 |
-
host,
|
| 291 |
-
port
|
| 292 |
-
) as server:
|
| 293 |
-
server.serve_forever()
|
| 294 |
-
|
| 295 |
-
def voice_activity(self, websocket, frame_np):
|
| 296 |
-
"""
|
| 297 |
-
Evaluates the voice activity in a given audio frame and manages the state of voice activity detection.
|
| 298 |
-
|
| 299 |
-
This method uses the configured voice activity detection (VAD) model to assess whether the given audio frame
|
| 300 |
-
contains speech. If the VAD model detects no voice activity for more than three consecutive frames,
|
| 301 |
-
it sets an end-of-speech (EOS) flag for the associated client. This method aims to efficiently manage
|
| 302 |
-
speech detection to improve subsequent processing steps.
|
| 303 |
-
|
| 304 |
-
Args:
|
| 305 |
-
websocket: The websocket associated with the current client. Used to retrieve the client object
|
| 306 |
-
from the client manager for state management.
|
| 307 |
-
frame_np (numpy.ndarray): The audio frame to be analyzed. This should be a NumPy array containing
|
| 308 |
-
the audio data for the current frame.
|
| 309 |
-
|
| 310 |
-
Returns:
|
| 311 |
-
bool: True if voice activity is detected in the current frame, False otherwise. When returning False
|
| 312 |
-
after detecting no voice activity for more than three consecutive frames, it also triggers the
|
| 313 |
-
end-of-speech (EOS) flag for the client.
|
| 314 |
-
"""
|
| 315 |
-
if not self.vad_detector(frame_np):
|
| 316 |
-
self.no_voice_activity_chunks += 1
|
| 317 |
-
if self.no_voice_activity_chunks > 3:
|
| 318 |
-
client = self.client_manager.get_client(websocket)
|
| 319 |
-
if not client.eos:
|
| 320 |
-
client.set_eos(True)
|
| 321 |
-
time.sleep(0.1) # Sleep 100m; wait some voice activity.
|
| 322 |
-
return False
|
| 323 |
-
return True
|
| 324 |
-
|
| 325 |
-
def cleanup(self, websocket):
|
| 326 |
-
"""
|
| 327 |
-
Cleans up resources associated with a given client's websocket.
|
| 328 |
-
|
| 329 |
-
Args:
|
| 330 |
-
websocket: The websocket associated with the client to be cleaned up.
|
| 331 |
-
"""
|
| 332 |
-
if self.client_manager.get_client(websocket):
|
| 333 |
-
self.client_manager.remove_client(websocket)
|
| 334 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
transcribe/translatepipes.py
CHANGED
|
@@ -3,9 +3,7 @@ from transcribe.pipelines import WhisperPipe, MetaItem, WhisperChinese, Translat
|
|
| 3 |
|
| 4 |
class TranslatePipes:
|
| 5 |
def __init__(self) -> None:
|
| 6 |
-
|
| 7 |
-
# self.translate_input_q = mp.Queue()
|
| 8 |
-
# self.result_queue = mp.Queue()
|
| 9 |
self._process = []
|
| 10 |
# whisper 转录
|
| 11 |
self._whisper_pipe_en = self._launch_process(WhisperPipe())
|
|
@@ -14,13 +12,9 @@ class TranslatePipes:
|
|
| 14 |
|
| 15 |
# llm 翻译
|
| 16 |
# self._translate_pipe = self._launch_process(TranslatePipe())
|
| 17 |
-
|
| 18 |
self._translate_7b_pipe = self._launch_process(Translate7BPipe())
|
| 19 |
# vad
|
| 20 |
self._vad_pipe = self._launch_process(VadPipe())
|
| 21 |
-
|
| 22 |
-
# def reset(self):
|
| 23 |
-
# self._vad_pipe.reset()
|
| 24 |
|
| 25 |
def _launch_process(self, process_obj):
|
| 26 |
process_obj.daemon = True
|
|
@@ -48,17 +42,12 @@ class TranslatePipes:
|
|
| 48 |
self._translate_7b_pipe.input_queue.put(item)
|
| 49 |
return self._translate_7b_pipe.output_queue.get()
|
| 50 |
|
| 51 |
-
def get_whisper_model(self, lang: str = 'en'):
|
| 52 |
-
if lang == 'zh':
|
| 53 |
-
return self._whisper_pipe_zh
|
| 54 |
-
return self._whisper_pipe_en
|
| 55 |
-
|
| 56 |
def get_transcription_model(self, lang: str = 'en'):
|
| 57 |
if lang == 'zh':
|
| 58 |
return self._funasr_pipe
|
| 59 |
return self._whisper_pipe_en
|
| 60 |
|
| 61 |
-
def
|
| 62 |
transcription_model = self.get_transcription_model(src_lang)
|
| 63 |
item = MetaItem(audio=audio_buffer, source_language=src_lang)
|
| 64 |
transcription_model.input_queue.put(item)
|
|
@@ -76,6 +65,6 @@ if __name__ == "__main__":
|
|
| 76 |
tp = TranslatePipes()
|
| 77 |
# result = tp.translate("你好,今天天气怎么样?", src_lang="zh", dst_lang="en")
|
| 78 |
mel, _, = soundfile.read("assets/jfk.flac")
|
| 79 |
-
# result = tp.
|
| 80 |
result = tp.voice_detect(mel)
|
| 81 |
print(result)
|
|
|
|
| 3 |
|
| 4 |
class TranslatePipes:
|
| 5 |
def __init__(self) -> None:
|
| 6 |
+
|
|
|
|
|
|
|
| 7 |
self._process = []
|
| 8 |
# whisper 转录
|
| 9 |
self._whisper_pipe_en = self._launch_process(WhisperPipe())
|
|
|
|
| 12 |
|
| 13 |
# llm 翻译
|
| 14 |
# self._translate_pipe = self._launch_process(TranslatePipe())
|
|
|
|
| 15 |
self._translate_7b_pipe = self._launch_process(Translate7BPipe())
|
| 16 |
# vad
|
| 17 |
self._vad_pipe = self._launch_process(VadPipe())
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
def _launch_process(self, process_obj):
|
| 20 |
process_obj.daemon = True
|
|
|
|
| 42 |
self._translate_7b_pipe.input_queue.put(item)
|
| 43 |
return self._translate_7b_pipe.output_queue.get()
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
def get_transcription_model(self, lang: str = 'en'):
|
| 46 |
if lang == 'zh':
|
| 47 |
return self._funasr_pipe
|
| 48 |
return self._whisper_pipe_en
|
| 49 |
|
| 50 |
+
def transcribe(self, audio_buffer: bytes, src_lang: str) -> MetaItem:
|
| 51 |
transcription_model = self.get_transcription_model(src_lang)
|
| 52 |
item = MetaItem(audio=audio_buffer, source_language=src_lang)
|
| 53 |
transcription_model.input_queue.put(item)
|
|
|
|
| 65 |
tp = TranslatePipes()
|
| 66 |
# result = tp.translate("你好,今天天气怎么样?", src_lang="zh", dst_lang="en")
|
| 67 |
mel, _, = soundfile.read("assets/jfk.flac")
|
| 68 |
+
# result = tp.transcribe(mel, 'en')
|
| 69 |
result = tp.voice_detect(mel)
|
| 70 |
print(result)
|
transcribe/utils.py
CHANGED
|
@@ -8,6 +8,7 @@ import config
|
|
| 8 |
import csv
|
| 9 |
import av
|
| 10 |
import re
|
|
|
|
| 11 |
|
| 12 |
# Compile regex patterns once outside the loop for better performance
|
| 13 |
p_pattern = re.compile(r"(\s*\[.*?\])")
|
|
@@ -18,43 +19,67 @@ p_end_pattern = re.compile(r"(\s*.*\])")
|
|
| 18 |
def filter_words(res_word):
|
| 19 |
"""
|
| 20 |
Filter words according to specific bracket patterns.
|
| 21 |
-
|
| 22 |
Args:
|
| 23 |
res_word: Iterable of word objects with a 'text' attribute
|
| 24 |
-
|
| 25 |
Returns:
|
| 26 |
List of filtered word objects
|
| 27 |
"""
|
| 28 |
asr_results = []
|
| 29 |
skip_word = False
|
| 30 |
-
|
| 31 |
for word in res_word:
|
| 32 |
# Skip words that completely match the pattern
|
| 33 |
if p_pattern.match(word.text):
|
| 34 |
continue
|
| 35 |
-
|
| 36 |
# Mark the start of a section to skip
|
| 37 |
if p_start_pattern.match(word.text):
|
| 38 |
skip_word = True
|
| 39 |
continue
|
| 40 |
-
|
| 41 |
# Mark the end of a section to skip
|
| 42 |
if p_end_pattern.match(word.text) and skip_word:
|
| 43 |
skip_word = False
|
| 44 |
continue
|
| 45 |
-
|
| 46 |
# Skip words if we're in a skip section
|
| 47 |
if skip_word:
|
| 48 |
continue
|
| 49 |
-
|
|
|
|
|
|
|
| 50 |
# Add the word to results if it passed all filters
|
| 51 |
asr_results.append(word)
|
| 52 |
-
|
| 53 |
return asr_results
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
def log_block(key: str, value, unit=''):
|
| 56 |
if config.DEBUG:
|
| 57 |
-
return
|
| 58 |
"""格式化输出日志内容"""
|
| 59 |
key_fmt = f"[ {key.ljust(25)}]" # 左对齐填充
|
| 60 |
val_fmt = f"{value} {unit}".strip()
|
|
@@ -157,8 +182,8 @@ class TestDataWriter:
|
|
| 157 |
def __init__(self, file_path='test_data.csv'):
|
| 158 |
self.file_path = file_path
|
| 159 |
self.fieldnames = [
|
| 160 |
-
'seg_id', '
|
| 161 |
-
'
|
| 162 |
]
|
| 163 |
self._ensure_file_has_header()
|
| 164 |
|
|
@@ -171,4 +196,4 @@ class TestDataWriter:
|
|
| 171 |
def write(self, result: 'DebugResult'):
|
| 172 |
with open(self.file_path, mode='a', newline='') as file:
|
| 173 |
writer = csv.DictWriter(file, fieldnames=self.fieldnames)
|
| 174 |
-
writer.writerow(result.model_dump(by_alias=True))
|
|
|
|
| 8 |
import csv
|
| 9 |
import av
|
| 10 |
import re
|
| 11 |
+
import json
|
| 12 |
|
| 13 |
# Compile regex patterns once outside the loop for better performance
|
| 14 |
p_pattern = re.compile(r"(\s*\[.*?\])")
|
|
|
|
| 19 |
def filter_words(res_word):
|
| 20 |
"""
|
| 21 |
Filter words according to specific bracket patterns.
|
| 22 |
+
|
| 23 |
Args:
|
| 24 |
res_word: Iterable of word objects with a 'text' attribute
|
| 25 |
+
|
| 26 |
Returns:
|
| 27 |
List of filtered word objects
|
| 28 |
"""
|
| 29 |
asr_results = []
|
| 30 |
skip_word = False
|
| 31 |
+
|
| 32 |
for word in res_word:
|
| 33 |
# Skip words that completely match the pattern
|
| 34 |
if p_pattern.match(word.text):
|
| 35 |
continue
|
| 36 |
+
|
| 37 |
# Mark the start of a section to skip
|
| 38 |
if p_start_pattern.match(word.text):
|
| 39 |
skip_word = True
|
| 40 |
continue
|
| 41 |
+
|
| 42 |
# Mark the end of a section to skip
|
| 43 |
if p_end_pattern.match(word.text) and skip_word:
|
| 44 |
skip_word = False
|
| 45 |
continue
|
| 46 |
+
|
| 47 |
# Skip words if we're in a skip section
|
| 48 |
if skip_word:
|
| 49 |
continue
|
| 50 |
+
|
| 51 |
+
word.text = replace_hotwords(word.text)
|
| 52 |
+
|
| 53 |
# Add the word to results if it passed all filters
|
| 54 |
asr_results.append(word)
|
| 55 |
+
|
| 56 |
return asr_results
|
| 57 |
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def replace_hotwords(text: str) -> str:
|
| 61 |
+
"""
|
| 62 |
+
Reads hotwords from a JSON file and replaces occurrences in the input text.
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
text: The input string to process.
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
The string with hotwords replaced.
|
| 69 |
+
"""
|
| 70 |
+
|
| 71 |
+
processed_text = text
|
| 72 |
+
# Iterate through the hotwords dictionary
|
| 73 |
+
for key, value in config.hotwords_json.items():
|
| 74 |
+
# Replace all occurrences of the key with the value in the text
|
| 75 |
+
processed_text = processed_text.replace(key, value)
|
| 76 |
+
logging.debug(f"Replace string: {text} => {processed_text}")
|
| 77 |
+
return processed_text
|
| 78 |
+
|
| 79 |
+
|
| 80 |
def log_block(key: str, value, unit=''):
|
| 81 |
if config.DEBUG:
|
| 82 |
+
return
|
| 83 |
"""格式化输出日志内容"""
|
| 84 |
key_fmt = f"[ {key.ljust(25)}]" # 左对齐填充
|
| 85 |
val_fmt = f"{value} {unit}".strip()
|
|
|
|
| 182 |
def __init__(self, file_path='test_data.csv'):
|
| 183 |
self.file_path = file_path
|
| 184 |
self.fieldnames = [
|
| 185 |
+
'seg_id', 'transcribe_time', 'translate_time',
|
| 186 |
+
'transcribeContent', 'from', 'to', 'translateContent', 'partial'
|
| 187 |
]
|
| 188 |
self._ensure_file_has_header()
|
| 189 |
|
|
|
|
| 196 |
def write(self, result: 'DebugResult'):
|
| 197 |
with open(self.file_path, mode='a', newline='') as file:
|
| 198 |
writer = csv.DictWriter(file, fieldnames=self.fieldnames)
|
| 199 |
+
writer.writerow(result.model_dump(by_alias=True))
|
transcribe/whisper_llm_serve.py
CHANGED
|
@@ -1,10 +1,8 @@
|
|
| 1 |
-
|
| 2 |
-
import json
|
| 3 |
import queue
|
| 4 |
import threading
|
| 5 |
import time
|
| 6 |
from logging import getLogger
|
| 7 |
-
from typing import List, Optional, Iterator, Tuple, Any
|
| 8 |
import asyncio
|
| 9 |
import numpy as np
|
| 10 |
import config
|
|
@@ -13,16 +11,26 @@ from api_model import TransResult, Message, DebugResult
|
|
| 13 |
|
| 14 |
from .utils import log_block, save_to_wave, TestDataWriter, filter_words
|
| 15 |
from .translatepipes import TranslatePipes
|
| 16 |
-
|
| 17 |
-
TranscriptStabilityAnalyzer, TranscriptToken)
|
| 18 |
-
from transcribe.helpers.vadprocessor import VadProcessor
|
| 19 |
-
# from transcribe.helpers.vad_dynamic import VadProcessor
|
| 20 |
-
# from transcribe.helpers.vadprocessor import VadProcessor
|
| 21 |
from transcribe.pipelines import MetaItem
|
| 22 |
|
|
|
|
| 23 |
logger = getLogger("TranscriptionService")
|
| 24 |
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
class WhisperTranscriptionService:
|
| 27 |
"""
|
| 28 |
Whisper语音转录服务类,处理音频流转录和翻译
|
|
@@ -42,45 +50,35 @@ class WhisperTranscriptionService:
|
|
| 42 |
self._translate_pipe = pipe
|
| 43 |
|
| 44 |
# 音频处理相关
|
| 45 |
-
self.sample_rate =
|
| 46 |
|
| 47 |
self.lock = threading.Lock()
|
| 48 |
-
|
| 49 |
-
|
| 50 |
# 文本分隔符,根据语言设置
|
| 51 |
-
self.text_separator =
|
| 52 |
self.loop = asyncio.get_event_loop()
|
| 53 |
-
# 发送就绪状态
|
| 54 |
# 原始音频队列
|
| 55 |
self._frame_queue = queue.Queue()
|
| 56 |
# 音频队列缓冲区
|
| 57 |
-
self.frames_np =
|
|
|
|
| 58 |
# 完整音频队列
|
| 59 |
-
self.
|
| 60 |
-
self._temp_string = ""
|
| 61 |
-
|
| 62 |
-
self._transcrible_analysis = None
|
| 63 |
# 启动处理线程
|
| 64 |
self._translate_thread_stop = threading.Event()
|
| 65 |
self._frame_processing_thread_stop = threading.Event()
|
| 66 |
|
| 67 |
-
self.translate_thread =
|
| 68 |
-
self.frame_processing_thread =
|
| 69 |
-
# if language == "zh":
|
| 70 |
-
# self._vad = VadProcessor(prob_threshold=0.8, silence_s=0.2, cache_s=0.15)
|
| 71 |
-
# else:
|
| 72 |
-
# self._vad = VadProcessor(prob_threshold=0.7, silence_s=0.2, cache_s=0.15)
|
| 73 |
self.row_number = 0
|
| 74 |
# for test
|
| 75 |
-
self.
|
| 76 |
self._translate_time_cost = 0.
|
| 77 |
|
| 78 |
if config.SAVE_DATA_SAVE:
|
| 79 |
self._save_task_stop = threading.Event()
|
| 80 |
self._save_queue = queue.Queue()
|
| 81 |
-
self._save_thread =
|
| 82 |
|
| 83 |
-
# self._c = 0
|
| 84 |
|
| 85 |
def save_data_loop(self):
|
| 86 |
writer = TestDataWriter()
|
|
@@ -88,33 +86,6 @@ class WhisperTranscriptionService:
|
|
| 88 |
test_data = self._save_queue.get()
|
| 89 |
writer.write(test_data) # Save test_data to CSV
|
| 90 |
|
| 91 |
-
|
| 92 |
-
def _start_thread(self, target_function) -> threading.Thread:
|
| 93 |
-
"""启动守护线程执行指定函数"""
|
| 94 |
-
thread = threading.Thread(target=target_function)
|
| 95 |
-
thread.daemon = True
|
| 96 |
-
thread.start()
|
| 97 |
-
return thread
|
| 98 |
-
|
| 99 |
-
def _get_text_separator(self, language: str) -> str:
|
| 100 |
-
"""根据语言返回适当的文本分隔符"""
|
| 101 |
-
return "" if language == "zh" else " "
|
| 102 |
-
|
| 103 |
-
async def send_ready_state(self) -> None:
|
| 104 |
-
"""发送服务就绪状态消息"""
|
| 105 |
-
await self.websocket.send(json.dumps({
|
| 106 |
-
"uid": self.client_uid,
|
| 107 |
-
"message": self.SERVER_READY,
|
| 108 |
-
"backend": "whisper_transcription"
|
| 109 |
-
}))
|
| 110 |
-
|
| 111 |
-
def set_language(self, source_lang: str, target_lang: str) -> None:
|
| 112 |
-
"""设置源语言和目标语言"""
|
| 113 |
-
self.source_language = source_lang
|
| 114 |
-
self.target_language = target_lang
|
| 115 |
-
self.text_separator = self._get_text_separator(source_lang)
|
| 116 |
-
# self._transcrible_analysis = TranscriptStabilityAnalyzer(self.source_language, self.text_separator)
|
| 117 |
-
|
| 118 |
def add_frames(self, frame_np: np.ndarray) -> None:
|
| 119 |
"""添加音频帧到处理队列"""
|
| 120 |
self._frame_queue.put(frame_np)
|
|
@@ -126,100 +97,88 @@ class WhisperTranscriptionService:
|
|
| 126 |
speech_status = processed_audio.speech_status
|
| 127 |
return speech_audio, speech_status
|
| 128 |
|
|
|
|
| 129 |
def _frame_processing_loop(self) -> None:
|
| 130 |
"""从队列获取音频帧并合并到缓冲区"""
|
| 131 |
while not self._frame_processing_thread_stop.is_set():
|
| 132 |
try:
|
| 133 |
frame_np = self._frame_queue.get(timeout=0.1)
|
| 134 |
frame_np, speech_status = self._apply_voice_activity_detection(frame_np)
|
| 135 |
-
|
|
|
|
| 136 |
continue
|
|
|
|
| 137 |
with self.lock:
|
| 138 |
-
if self.
|
| 139 |
-
self.
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
if
|
| 143 |
-
self.
|
|
|
|
|
|
|
| 144 |
self.frames_np = np.array([], dtype=np.float32)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
except queue.Empty:
|
| 146 |
pass
|
| 147 |
|
| 148 |
-
def _process_transcription_results_2(self, seg_text:str,partial):
|
| 149 |
-
|
| 150 |
-
item = TransResult(
|
| 151 |
-
seg_id=self.row_number,
|
| 152 |
-
context=seg_text,
|
| 153 |
-
from_=self.source_language,
|
| 154 |
-
to=self.target_language,
|
| 155 |
-
tran_content=self._translate_text_large(seg_text),
|
| 156 |
-
partial=partial
|
| 157 |
-
)
|
| 158 |
-
if partial == False:
|
| 159 |
-
self.row_number += 1
|
| 160 |
-
return item
|
| 161 |
-
|
| 162 |
def _transcription_processing_loop(self) -> None:
|
| 163 |
"""主转录处理循环"""
|
| 164 |
frame_epoch = 1
|
| 165 |
-
while not self._translate_thread_stop.is_set():
|
| 166 |
-
|
| 167 |
-
if self.frames_np is None:
|
| 168 |
-
time.sleep(0.01)
|
| 169 |
-
continue
|
| 170 |
|
|
|
|
| 171 |
|
| 172 |
-
if len(self.
|
| 173 |
-
audio_buffer = self.segments_queue.pop()
|
| 174 |
-
partial = False
|
| 175 |
-
else:
|
| 176 |
-
with self.lock:
|
| 177 |
-
audio_buffer = self.frames_np[:int(frame_epoch * 1.5 * self.sample_rate)].copy()# 获取 1.5s * epoch 个音频长度
|
| 178 |
-
partial = True
|
| 179 |
-
|
| 180 |
-
if len(audio_buffer) ==0:
|
| 181 |
time.sleep(0.01)
|
| 182 |
continue
|
| 183 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
if len(audio_buffer) < int(self.sample_rate):
|
| 185 |
silence_audio = np.zeros(self.sample_rate, dtype=np.float32)
|
| 186 |
silence_audio[-len(audio_buffer):] = audio_buffer
|
| 187 |
audio_buffer = silence_audio
|
| 188 |
|
| 189 |
-
|
| 190 |
logger.debug(f"audio buffer size: {len(audio_buffer) / self.sample_rate:.2f}s")
|
| 191 |
-
# try:
|
| 192 |
meta_item = self._transcribe_audio(audio_buffer)
|
| 193 |
segments = meta_item.segments
|
| 194 |
logger.debug(f"Segments: {segments}")
|
| 195 |
segments = filter_words(segments)
|
|
|
|
| 196 |
if len(segments):
|
| 197 |
seg_text = self.text_separator.join(seg.text for seg in segments)
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
self._temp_string = ""
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
result = self._process_transcription_results_2(seg_text, partial)
|
| 210 |
self._send_result_to_client(result)
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
if partial == False:
|
| 214 |
frame_epoch = 1
|
| 215 |
else:
|
| 216 |
frame_epoch += 1
|
| 217 |
-
# 处理转录结果并发送到客户端
|
| 218 |
-
# for result in self._process_transcription_results(segments, audio_buffer):
|
| 219 |
-
# self._send_result_to_client(result)
|
| 220 |
|
| 221 |
-
|
| 222 |
-
# logger.error(f"Error processing audio: {e}")
|
| 223 |
|
| 224 |
|
| 225 |
def _transcribe_audio(self, audio_buffer: np.ndarray)->MetaItem:
|
|
@@ -227,14 +186,13 @@ class WhisperTranscriptionService:
|
|
| 227 |
log_block("Audio buffer length", f"{audio_buffer.shape[0]/self.sample_rate:.2f}", "s")
|
| 228 |
start_time = time.perf_counter()
|
| 229 |
|
| 230 |
-
result = self._translate_pipe.
|
| 231 |
segments = result.segments
|
| 232 |
time_diff = (time.perf_counter() - start_time)
|
| 233 |
-
logger.debug(f"📝
|
| 234 |
-
|
| 235 |
-
log_block("📝
|
| 236 |
-
|
| 237 |
-
self._transcrible_time_cost = round(time_diff, 3)
|
| 238 |
return result
|
| 239 |
|
| 240 |
def _translate_text(self, text: str) -> str:
|
|
@@ -270,51 +228,6 @@ class WhisperTranscriptionService:
|
|
| 270 |
return translated_text
|
| 271 |
|
| 272 |
|
| 273 |
-
|
| 274 |
-
def _process_transcription_results(self, segments: List[TranscriptToken], audio_buffer: np.ndarray) -> Iterator[TransResult]:
|
| 275 |
-
"""
|
| 276 |
-
处理转录结果,生成翻译结果
|
| 277 |
-
|
| 278 |
-
Returns:
|
| 279 |
-
TransResult对象的迭代器
|
| 280 |
-
"""
|
| 281 |
-
|
| 282 |
-
if not segments:
|
| 283 |
-
return
|
| 284 |
-
start_time = time.perf_counter()
|
| 285 |
-
for ana_result in self._transcrible_analysis.analysis(segments, len(audio_buffer)/self.sample_rate):
|
| 286 |
-
if (cut_index :=ana_result.cut_index)>0:
|
| 287 |
-
# 更新音频缓冲区,移除已处理部分
|
| 288 |
-
self._update_audio_buffer(cut_index)
|
| 289 |
-
if ana_result.partial():
|
| 290 |
-
translated_context = self._translate_text(ana_result.context)
|
| 291 |
-
else:
|
| 292 |
-
translated_context = self._translate_text_large(ana_result.context)
|
| 293 |
-
|
| 294 |
-
yield TransResult(
|
| 295 |
-
seg_id=ana_result.seg_id,
|
| 296 |
-
context=ana_result.context,
|
| 297 |
-
from_=self.source_language,
|
| 298 |
-
to=self.target_language,
|
| 299 |
-
tran_content=translated_context,
|
| 300 |
-
partial=ana_result.partial()
|
| 301 |
-
)
|
| 302 |
-
current_time = time.perf_counter()
|
| 303 |
-
time_diff = current_time - start_time
|
| 304 |
-
if config.SAVE_DATA_SAVE:
|
| 305 |
-
self._save_queue.put(DebugResult(
|
| 306 |
-
seg_id=ana_result.seg_id,
|
| 307 |
-
transcrible_time=self._transcrible_time_cost,
|
| 308 |
-
translate_time=self._translate_time_cost,
|
| 309 |
-
context=ana_result.context,
|
| 310 |
-
from_=self.source_language,
|
| 311 |
-
to=self.target_language,
|
| 312 |
-
tran_content=translated_context,
|
| 313 |
-
partial=ana_result.partial()
|
| 314 |
-
))
|
| 315 |
-
log_block("🚦 Traffic times diff", round(time_diff, 2), 's')
|
| 316 |
-
|
| 317 |
-
|
| 318 |
def _send_result_to_client(self, result: TransResult) -> None:
|
| 319 |
"""发送翻译结果到客户端"""
|
| 320 |
try:
|
|
|
|
| 1 |
+
|
|
|
|
| 2 |
import queue
|
| 3 |
import threading
|
| 4 |
import time
|
| 5 |
from logging import getLogger
|
|
|
|
| 6 |
import asyncio
|
| 7 |
import numpy as np
|
| 8 |
import config
|
|
|
|
| 11 |
|
| 12 |
from .utils import log_block, save_to_wave, TestDataWriter, filter_words
|
| 13 |
from .translatepipes import TranslatePipes
|
| 14 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
from transcribe.pipelines import MetaItem
|
| 16 |
|
| 17 |
+
|
| 18 |
logger = getLogger("TranscriptionService")
|
| 19 |
|
| 20 |
|
| 21 |
+
def _get_text_separator(language: str) -> str:
|
| 22 |
+
"""根据语言返回适当的文本分隔符"""
|
| 23 |
+
return "" if language == "zh" else " "
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _start_thread(target_function) -> threading.Thread:
|
| 27 |
+
"""启动守护线程执行指定函数"""
|
| 28 |
+
thread = threading.Thread(target=target_function)
|
| 29 |
+
thread.daemon = True
|
| 30 |
+
thread.start()
|
| 31 |
+
return thread
|
| 32 |
+
|
| 33 |
+
|
| 34 |
class WhisperTranscriptionService:
|
| 35 |
"""
|
| 36 |
Whisper语音转录服务类,处理音频流转录和翻译
|
|
|
|
| 50 |
self._translate_pipe = pipe
|
| 51 |
|
| 52 |
# 音频处理相关
|
| 53 |
+
self.sample_rate = config.SAMPLE_RATE
|
| 54 |
|
| 55 |
self.lock = threading.Lock()
|
|
|
|
|
|
|
| 56 |
# 文本分隔符,根据语言设置
|
| 57 |
+
self.text_separator = _get_text_separator(language)
|
| 58 |
self.loop = asyncio.get_event_loop()
|
|
|
|
| 59 |
# 原始音频队列
|
| 60 |
self._frame_queue = queue.Queue()
|
| 61 |
# 音频队列缓冲区
|
| 62 |
+
self.frames_np = np.array([], dtype=np.float32)
|
| 63 |
+
self.frames_np_start_timestamp = None
|
| 64 |
# 完整音频队列
|
| 65 |
+
self.full_segments_queue = collections.deque()
|
|
|
|
|
|
|
|
|
|
| 66 |
# 启动处理线程
|
| 67 |
self._translate_thread_stop = threading.Event()
|
| 68 |
self._frame_processing_thread_stop = threading.Event()
|
| 69 |
|
| 70 |
+
self.translate_thread = _start_thread(self._transcription_processing_loop)
|
| 71 |
+
self.frame_processing_thread = _start_thread(self._frame_processing_loop)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
self.row_number = 0
|
| 73 |
# for test
|
| 74 |
+
self._transcribe_time_cost = 0.
|
| 75 |
self._translate_time_cost = 0.
|
| 76 |
|
| 77 |
if config.SAVE_DATA_SAVE:
|
| 78 |
self._save_task_stop = threading.Event()
|
| 79 |
self._save_queue = queue.Queue()
|
| 80 |
+
self._save_thread = _start_thread(self.save_data_loop)
|
| 81 |
|
|
|
|
| 82 |
|
| 83 |
def save_data_loop(self):
|
| 84 |
writer = TestDataWriter()
|
|
|
|
| 86 |
test_data = self._save_queue.get()
|
| 87 |
writer.write(test_data) # Save test_data to CSV
|
| 88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
def add_frames(self, frame_np: np.ndarray) -> None:
|
| 90 |
"""添加音频帧到处理队列"""
|
| 91 |
self._frame_queue.put(frame_np)
|
|
|
|
| 97 |
speech_status = processed_audio.speech_status
|
| 98 |
return speech_audio, speech_status
|
| 99 |
|
| 100 |
+
|
| 101 |
def _frame_processing_loop(self) -> None:
|
| 102 |
"""从队列获取音频帧并合并到缓冲区"""
|
| 103 |
while not self._frame_processing_thread_stop.is_set():
|
| 104 |
try:
|
| 105 |
frame_np = self._frame_queue.get(timeout=0.1)
|
| 106 |
frame_np, speech_status = self._apply_voice_activity_detection(frame_np)
|
| 107 |
+
|
| 108 |
+
if frame_np is None:
|
| 109 |
continue
|
| 110 |
+
|
| 111 |
with self.lock:
|
| 112 |
+
if speech_status == "START" and self.frames_np_start_timestamp is None:
|
| 113 |
+
self.frames_np_start_timestamp = time.time()
|
| 114 |
+
# 添加音频到音频缓冲区
|
| 115 |
+
self.frames_np = np.append(self.frames_np, frame_np)
|
| 116 |
+
if len(self.frames_np) >= self.sample_rate * config.MAX_SPEECH_DURATION_S:
|
| 117 |
+
audio_array=self.frames_np.copy()
|
| 118 |
+
self.full_segments_queue.appendleft(audio_array) # 根据时间是否满足三秒长度 来整合音频块
|
| 119 |
+
self.frames_np_start_timestamp = time.time()
|
| 120 |
self.frames_np = np.array([], dtype=np.float32)
|
| 121 |
+
|
| 122 |
+
elif speech_status == "END" and len(self.frames_np) > 0 and self.frames_np_start_timestamp:
|
| 123 |
+
time_diff = time.time() - self.frames_np_start_timestamp
|
| 124 |
+
if time_diff >= config.FRAME_SCOPE_TIME_THRESHOLD:
|
| 125 |
+
audio_array=self.frames_np.copy()
|
| 126 |
+
self.full_segments_queue.appendleft(audio_array) # 根据时间是否满足三秒长度 来整合音频块
|
| 127 |
+
self.frames_np_start_timestamp = None
|
| 128 |
+
self.frames_np = np.array([], dtype=np.float32)
|
| 129 |
+
else:
|
| 130 |
+
logger.debug(f"🥳 当前时间与上一句的时间差: {time_diff:.2f}s,继续增加缓冲区")
|
| 131 |
+
|
| 132 |
except queue.Empty:
|
| 133 |
pass
|
| 134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
def _transcription_processing_loop(self) -> None:
|
| 136 |
"""主转录处理循环"""
|
| 137 |
frame_epoch = 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
|
| 139 |
+
while not self._translate_thread_stop.is_set():
|
| 140 |
|
| 141 |
+
if len(self.frames_np) ==0:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
time.sleep(0.01)
|
| 143 |
continue
|
| 144 |
|
| 145 |
+
with self.lock:
|
| 146 |
+
if len(self.full_segments_queue) > 0:
|
| 147 |
+
audio_buffer = self.full_segments_queue.pop()
|
| 148 |
+
partial = False
|
| 149 |
+
else:
|
| 150 |
+
audio_buffer = self.frames_np[:int(frame_epoch * 1.5 * self.sample_rate)].copy()# 获取 1.5s * epoch 个音频长度
|
| 151 |
+
partial = True
|
| 152 |
+
|
| 153 |
if len(audio_buffer) < int(self.sample_rate):
|
| 154 |
silence_audio = np.zeros(self.sample_rate, dtype=np.float32)
|
| 155 |
silence_audio[-len(audio_buffer):] = audio_buffer
|
| 156 |
audio_buffer = silence_audio
|
| 157 |
|
|
|
|
| 158 |
logger.debug(f"audio buffer size: {len(audio_buffer) / self.sample_rate:.2f}s")
|
|
|
|
| 159 |
meta_item = self._transcribe_audio(audio_buffer)
|
| 160 |
segments = meta_item.segments
|
| 161 |
logger.debug(f"Segments: {segments}")
|
| 162 |
segments = filter_words(segments)
|
| 163 |
+
|
| 164 |
if len(segments):
|
| 165 |
seg_text = self.text_separator.join(seg.text for seg in segments)
|
| 166 |
+
result = TransResult(
|
| 167 |
+
seg_id=self.row_number,
|
| 168 |
+
context=seg_text,
|
| 169 |
+
from_=self.source_language,
|
| 170 |
+
to=self.target_language,
|
| 171 |
+
tran_content=self._translate_text_large(seg_text),
|
| 172 |
+
partial=partial
|
| 173 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
self._send_result_to_client(result)
|
| 175 |
+
if not partial:
|
| 176 |
+
self.row_number += 1
|
|
|
|
| 177 |
frame_epoch = 1
|
| 178 |
else:
|
| 179 |
frame_epoch += 1
|
|
|
|
|
|
|
|
|
|
| 180 |
|
| 181 |
+
|
|
|
|
| 182 |
|
| 183 |
|
| 184 |
def _transcribe_audio(self, audio_buffer: np.ndarray)->MetaItem:
|
|
|
|
| 186 |
log_block("Audio buffer length", f"{audio_buffer.shape[0]/self.sample_rate:.2f}", "s")
|
| 187 |
start_time = time.perf_counter()
|
| 188 |
|
| 189 |
+
result = self._translate_pipe.transcribe(audio_buffer.tobytes(), self.source_language)
|
| 190 |
segments = result.segments
|
| 191 |
time_diff = (time.perf_counter() - start_time)
|
| 192 |
+
logger.debug(f"📝 transcribe Segments: {segments} ")
|
| 193 |
+
log_block("📝 transcribe output", f"{self.text_separator.join(seg.text for seg in segments)}", "")
|
| 194 |
+
log_block("📝 transcribe time", f"{time_diff:.3f}", "s")
|
| 195 |
+
self._transcribe_time_cost = round(time_diff, 3)
|
|
|
|
| 196 |
return result
|
| 197 |
|
| 198 |
def _translate_text(self, text: str) -> str:
|
|
|
|
| 228 |
return translated_text
|
| 229 |
|
| 230 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
def _send_result_to_client(self, result: TransResult) -> None:
|
| 232 |
"""发送翻译结果到客户端"""
|
| 233 |
try:
|