add test_models
- lib/models/__init__.py +0 -0
- lib/models/funasr.py +42 -0
- lib/models/kokoro.py +113 -0
- lib/models/llm.py +91 -0
- lib/models/whisper.py +68 -0
- scripts/asr_utils.py +0 -41
- test_data/audios.py +49 -0
- test_data/texts.py +18 -0
- test_data/{recordings/text → texts}/test_translation_en.txt +0 -0
- test_data/{recordings/text → texts}/test_translation_zh.txt +0 -0
- tests/test_models/__init__.py +0 -0
- tests/test_models/conftest.py +12 -0
- tests/test_models/test_funasr.py +22 -0
- tests/test_models/test_llm.py +30 -0
- tests/test_models/test_tts.py +31 -0
- tests/test_models/test_whisper.py +22 -0
lib/models/__init__.py
ADDED
File without changes
lib/models/funasr.py
ADDED
@@ -0,0 +1,42 @@
+from pathlib import Path
+import time
+import csv
+import numpy as np
+from funasr_onnx import SeacoParaformer, CT_Transformer, Fsmn_vad
+
+from lib.utils import Timer, read_audio
+
+MODEL_DIR = Path("/Users/jeqin/work/code/Translator/python_server/moyoyo_asr_models")
+
+class FunASR:
+    def __init__(self, model_dir=MODEL_DIR, quantize=True):
+        asr_model_path = model_dir / 'speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'
+        # vad_model_path = model_dir / 'speech_fsmn_vad_zh-cn-16k-common-pytorch'
+        punc_model_path = model_dir / 'punc_ct-transformer_cn-en-common-vocab471067-large'
+        t0 = time.time()
+        # vad_model = Fsmn_vad(vad_model_path, quantize=quantize)
+        with Timer("load FunASR"):
+            self.asr_model = SeacoParaformer(asr_model_path, quantize=quantize)
+            self.punc_model = CT_Transformer(punc_model_path, quantize=quantize)
+            self._warm_up()
+
+    def _warm_up(self):
+        # Generate 1 second of fake 16 kHz audio data
+        fake_audio = np.random.randn(16000).astype(np.float32)
+        self.asr_model(fake_audio, hotwords="")
+
+    def transcribe(self, audio: np.ndarray):
+        with Timer("FunASR inference") as t:
+            asr_res = self.asr_model(audio, hotwords="")
+            asr_text = asr_res[0]["preds"]
+            result = self.punc_model(asr_text)
+            text = result[0]
+        return text, t.duration
+
+if __name__ == '__main__':
+    funasr = FunASR()
+    audio = read_audio(Path("/Users/jeqin/work/code/TestTranslator/test_data/recordings/1.wav"))
+    text, time_cost = funasr.transcribe(audio)
+    print(text)
+    print(time_cost)
+
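Note: `Timer` and `read_audio` come from `lib.utils`, which is not part of this commit. A minimal sketch consistent with how they are used above (names and behavior are assumptions, not the actual lib.utils code):

import time
import soundfile

class Timer:
    # Context manager that reports and stores elapsed wall-clock time in .duration.
    def __init__(self, name: str):
        self.name = name
        self.duration = 0.0

    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, *exc):
        self.duration = time.perf_counter() - self._start
        print(f"{self.name}: {self.duration:.3f}s")

def read_audio(path):
    # Load a waveform as float32; the ASR models above expect mono 16 kHz input.
    samples, _sr = soundfile.read(path, dtype="float32")
    return samples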
lib/models/kokoro.py
ADDED
@@ -0,0 +1,113 @@
+import os
+from pathlib import Path
+from kokoro_onnx import Kokoro
+from misaki import espeak, en, zh
+from misaki.espeak import EspeakG2P
+from functools import lru_cache
+from logging import getLogger
+import librosa
+import onnxruntime
+
+from lib.utils import Timer, write_audio
+
+
+logger = getLogger(__name__)
+providers = onnxruntime.get_available_providers()
+MODEL_DIR = Path("/Users/jeqin/work/code/Translator/python_server/moyoyo_asr_models/kokoro")
+
+def create_session(model_path):
+    # See list of providers https://github.com/microsoft/onnxruntime/issues/22101#issuecomment-2357667377
+    providers = onnxruntime.get_available_providers()
+    print(f"Available onnx runtime providers: {providers}")
+
+    # See session options https://onnxruntime.ai/docs/performance/tune-performance/threading.html#thread-management
+    sess_options = onnxruntime.SessionOptions()
+    cpu_count = os.cpu_count() // 2
+    print(f"Setting threads to CPU cores count: {cpu_count}")
+    sess_options.intra_op_num_threads = cpu_count
+    session = onnxruntime.InferenceSession(
+        model_path, providers=["CPUExecutionProvider"], sess_options=sess_options
+    )
+    return session
+
+
+class KokoroTTS:
+    language_voice_mapping = {
+        "JP": "jf_alpha",
+        "JA": "jf_alpha",
+        "ZH": "zf_xiaoyi",
+        "EN": "af_heart",
+        "FR": "ff_siwis",
+        "IT": "im_nicola",
+        "HI": "hf_alpha",
+        "PT": "im_nicola",
+        "ES": "im_nicola"
+    }
+    language_word_mapping = {
+        "ZH": "你好",
+        "EN": "hello",
+        "FR": "Bonjour",
+        "IT": "Ciao",
+        "HI": "हेलो",
+        "PT": "Olá",
+        "ES": "Hola"
+    }
+
+    def __init__(self, model_path: str, voice_model_path: str, vocab_config=None, g2p=None, voice=None):
+        self._session = create_session(model_path)
+        self.model = Kokoro.from_session(self._session, voice_model_path, vocab_config=vocab_config)
+        self.g2p = g2p
+        self.voice = voice
+
+    @classmethod
+    def from_language(cls, language: str, model_dir: Path = MODEL_DIR):
+        model_path: str = str(model_dir / "kokoro-quant.onnx")
+        voice_model_path: str = str(model_dir / "voices-v1.0.bin")
+        voice = cls.language_voice_mapping.get(language.upper())
+        warm_up_text = cls.language_word_mapping.get(language.upper())
+        logger.info(f"[TTS] language: {language}")
+        if not voice:
+            raise ValueError(f"Unsupported language: {language}, voice: {voice}")
+        vocab_config = None
+        if language.upper() == "ZH":
+            g2p = zh.ZHG2P()
+            vocab_config = model_dir / "zh_config.json"
+        elif language.upper() == 'EN':
+            fallback = espeak.EspeakFallback(british=False)
+            g2p = en.G2P(trf=False, british=False, fallback=fallback)
+        elif language.upper() == "HI":
+            g2p = EspeakG2P(language="hi")
+        elif language.upper() == "IT":
+            g2p = EspeakG2P(language="it")
+        elif language.upper() == "PT":
+            g2p = EspeakG2P(language="pt-br")
+        elif language.upper() == "ES":
+            g2p = EspeakG2P(language="es")
+        elif language.upper() == "FR":
+            g2p = EspeakG2P(language="fr-fr")
+        else:
+            g2p = EspeakG2P(language.lower())
+        with Timer("load tts"):
+            tts = cls(model_path, voice_model_path, vocab_config=vocab_config, g2p=g2p, voice=voice)
+            tts.generate(warm_up_text)
+        return tts
+
+    def generate(self, text, speed=1.2):
+        with Timer("tts inference") as t:
+            phonemes, _ = self.g2p(text)
+            samples, sample_rate = self.model.create(phonemes, self.voice, is_phonemes=True, speed=speed)
+        return samples, sample_rate, t.duration
+        # return librosa.resample(samples, target_sr=44100, orig_sr=sample_rate)
+
+    async def stream(self, text, speed=1.2):
+        phonemes, _ = self.g2p(text)
+        stream = self.model.create_stream(phonemes, self.voice, is_phonemes=True, speed=speed)
+        async for samples, sample_rate in stream:
+            yield samples, sample_rate
+
+
+if __name__ == '__main__':
+    tts = KokoroTTS.from_language(language="ZH")
+    samples, sr, time_cost = tts.generate("今天天气怎么样?")
+    write_audio("tts_out.wav", samples, sr)
+    print(time_cost)
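The `stream` method is an async generator that the `__main__` block above never exercises. A hedged usage sketch (the output file name and chunk handling are illustrative, not part of the commit):

import asyncio
import numpy as np
from lib.utils import write_audio

async def demo():
    tts = KokoroTTS.from_language(language="ZH")
    chunks = []
    sample_rate = None
    async for samples, sample_rate in tts.stream("今天天气怎么样?"):
        chunks.append(samples)
    # Concatenate the streamed chunks and write a single file.
    write_audio("tts_stream_out.wav", np.concatenate(chunks), sample_rate)

asyncio.run(demo())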
lib/models/llm.py
ADDED
@@ -0,0 +1,91 @@
+from logging import getLogger
+from pathlib import Path
+from llama_cpp import Llama
+from functools import lru_cache
+
+from lib.utils import Timer
+
+logger = getLogger(__name__)
+LLM_SYS_PROMPT_EN = """
+你是一名专业的同声传译员,正在为 GOSIM 会议提供英中翻译服务。你的任务是准确、流畅地翻译发言内容。
+
+请遵循以下要求:
+1. 语言风格:翻译成中文时,请使用自然、流畅、符合现代汉语口语习惯的表达方式。避免生硬、逐字翻译的痕迹,要让听众容易理解。
+2. 专业术语:**请优先参考下方提供的术语对照表进行翻译。** 对于对照表中未包含的术语,如果该术语有公认的标准翻译,请使用标准翻译;如果没有或不确定,请保留英文原文。不要用通俗词汇替代专业术语。
+3. 专有名词:对于专有名词,如会议名称 "GOSIM"、人名、公司名、项目名、特定技术名称等,请保留其原始英文不做翻译。
+4. 流畅性与准确性:在追求口语化的同时,务必保证信息传达的准确性。
+5. 输出:请直接输出翻译结果,不要添加任何额外的解释或说明。
+
+**专业术语对照表:**
+* driver: 驱动
+* bus: 总线
+* mask: 掩码
+* preemption: 抢占
+* register: 寄存器
+* Library: 库
+* biases: 偏移
+* OpenAGI: OpenAGI
+* LLaMA Factory: LLaMA Factory
+* OPENGL: OPENGL
+
+现在,请将以下内容翻译成中文:
+"""
+
+LLM_SYS_PROMPT_ZH = """
+你是一位中英文翻译专家。请将以下中文文本翻译成英文,遵循以下要求:
+
+翻译要求:
+- 保留原文英文内容:以下内容请保持原始英文形式,不进行翻译或改写:
+  - 技术术语与专业词汇
+  - 产品名称、品牌名称
+  - 代码片段、函数名、变量名
+  - 专有名词、缩写、首字母缩略词(如 API、NLP、RAG 等)
+- 翻译符合英文表达习惯,流畅自然,不生硬直译。
+- 保持专业性与准确性,清晰传达原意。
+- 如遇原文表达模糊或逻辑不清的情况,允许适度调整语序或措辞,以增强英文表述的清晰度和逻辑性。
+
+注意:
+若难以确定某个词汇是否需要翻译,请优先保留原始英文形式。
+不需添加额外解释或注释,仅翻译正文内容。
+特别注意,翻译的内容只能包含英文,不能包含其他的语言。
+
+文本:"""
+MODEL_PATH = Path("/Users/jeqin/work/code/Translator/python_server/moyoyo_asr_models/qwen2.5-1.5b-instruct-q5_0.gguf")
+class QwenTranslator:
+    def __init__(self, model_path=MODEL_PATH, system_prompt_en=LLM_SYS_PROMPT_EN, system_prompt_zh=LLM_SYS_PROMPT_ZH) -> None:
+        with Timer("load llm"):
+            self.llm = Llama(
+                model_path=str(model_path),
+                chat_format="chatml",
+                verbose=False)
+        self.sys_prompt_en = system_prompt_en
+        self.sys_prompt_zh = system_prompt_zh
+        self._warmup()
+
+    def to_message(self, prompt, src_lang, dst_lang):
+        """Build the chat messages for the translation prompt."""
+        return [
+            {"role": "system", "content": self.sys_prompt_en if src_lang == "en" else self.sys_prompt_zh},
+            {"role": "user", "content": prompt},
+        ]
+
+    def _warmup(self):
+        self.translate(prompt="hello", src_lang="en", dst_lang="zh")
+
+    @lru_cache(maxsize=10)
+    def translate(self, prompt, src_lang, dst_lang) -> tuple[str, float]:
+        message = self.to_message(prompt, src_lang, dst_lang)
+        with Timer("llm inference") as t:
+            output = self.llm.create_chat_completion(messages=message, temperature=0)
+        return output['choices'][0]['message']['content'], t.duration
+
+
+if __name__ == '__main__':
+    model_dir = Path("/Users/jeqin/work/code/Translator/moyoyo_asr_models")
+    qwen2 = (model_dir / "qwen2.5-1.5b-instruct-q5_0.gguf").as_posix()
+    qwen3 = (model_dir / "Qwen_Qwen3-1.7B-Q4_K_M.gguf").as_posix()
+
+    translator = QwenTranslator(qwen3)
+    text, time_cost = translator.translate("今天天气怎么样?", "zh", "en")
+    print(text)
+    print(time_cost)
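Because `translate` is wrapped in `@lru_cache(maxsize=10)`, a repeated prompt returns the cached `(text, duration)` tuple, so the reported time_cost reflects the original inference rather than the cache hit. A quick illustration of that caveat:

translator = QwenTranslator()
text1, t1 = translator.translate("hello", src_lang="en", dst_lang="zh")
text2, t2 = translator.translate("hello", src_lang="en", dst_lang="zh")
# Cache hit: identical tuple, including the stale duration from the first call.
assert (text1, t1) == (text2, t2)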
lib/models/whisper.py
ADDED
@@ -0,0 +1,68 @@
+from pywhispercpp.model import Model
+import soundfile
+import numpy as np
+from logging import getLogger
+from pathlib import Path
+
+from lib.utils import Timer, read_audio
+
+logger = getLogger(__name__)
+
+MODEL_DIR = Path("/Users/jeqin/work/code/Translator/python_server/moyoyo_asr_models")
+WHISPER_PROMPT_ZH = "以下是简体中文普通话的句子。"
+WHISPER_PROMPT_EN = ""  # "The following is an English sentence."
+
+class WhisperCPP:
+    def __init__(self, model_dir=MODEL_DIR, source_lang: str = 'en') -> None:
+        whisper_model = 'large-v3-turbo-q5_0'
+        with Timer("load whisper"):
+            self.model = Model(
+                model=whisper_model,
+                models_dir=str(model_dir),
+                print_realtime=False,
+                print_progress=False,
+                print_timestamps=False,
+                translate=False,
+                # beam_search=1,
+                temperature=0.,
+                no_context=True
+            )
+        self._warmup()
+
+    def _warmup(self):
+        fake_audio = np.random.randn(16000).astype(np.float32)
+        self.model.transcribe(fake_audio, print_progress=False)
+
+    @staticmethod
+    def config_language(language):
+        if language == "zh":
+            return WHISPER_PROMPT_ZH
+        elif language == "en":
+            return WHISPER_PROMPT_EN
+        raise ValueError(f"Unsupported language: {language}")
+
+    def transcribe(self, audio: np.ndarray, language):
+        prompt = self.config_language(language)
+        try:
+            with Timer("whisper inference") as t:
+                segments = self.model.transcribe(
+                    audio,
+                    initial_prompt=prompt,
+                    language=language,
+                    # token_timestamps=True,
+                    split_on_word=True,
+                    # max_len=max_len
+                )
+            text = "".join([s.text for s in segments])
+            return text, t.duration
+        except Exception as e:
+            logger.error(e)
+            return "", 0.0
+
+if __name__ == '__main__':
+    from lib.utils import read_audio
+    whisper = WhisperCPP()
+    audio = read_audio(Path("/Users/jeqin/work/code/TestTranslator/test_data/recordings/1.wav"))
+    text, time_cost = whisper.transcribe(audio, "zh")
+    print(text)
+    print(time_cost)
scripts/asr_utils.py
CHANGED
@@ -7,17 +7,6 @@ from pathlib import Path
 import subprocess
 from subprocess import CompletedProcess
 
-
-def cmd(command: str, check=True, capture_output=False) -> CompletedProcess:
-    print(command)
-    if capture_output:
-        ret = subprocess.run(command, shell=True, check=check, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
-                             universal_newlines=True)
-    else:
-        ret = subprocess.run(command, shell=True, check=check)
-    print(ret.stdout)
-    return ret
-
 def add_text_index():
     text_file = '../test_data/text/test_asr_zh.txt'
     index = 1
@@ -89,37 +78,7 @@ def get_origin_text_dict():
         text_dict[idx] = text
     return text_dict
 
-def read_dataset(file):
-    """line sample: {"audio": {"path": "dataset/audio/data_aishell/wav/test/S0916/BAC009S0916W0158.wav"}, "sentence": "顾客体验的核心是真善美", "duration": 3.22, "sentences": [{"start": 0, "end": 3.22, "text": "顾客体验的核心是真善美"}]}"""
-    with open(file) as f:
-        lines = f.readlines()
-    for line in lines:
-        line = line.strip()
-        if not line:
-            continue
-        data = json.loads(line)
-
-        yield data["audio"]["path"], data["sentence"], data["duration"]
 
-def read_emilia(folder: Path, count_limit=None):
-    """Read the Emilia dataset; yields (audio path, text, duration).
-    Sample JSON file:
-    {"id": "ZH_B00000_S00110_W000000", "wav": "ZH_B00000/ZH_B00000_S00110/mp3/ZH_B00000_S00110_W000000.mp3", "text": "\u628a\u63e1\u6700\u524d\u6cbf\u7684\u91d1\u878d\u9886\u57df\u548c\u533a\u5757\u94fe\u6700\u65b0\u8d44\u8baf\u3002\u6211\u4eec\u4e00\u8d77\u6765\u4e86\u89e3\u4e00\u4e0b\u4eca\u5929\u5e02\u573a\u4e0a\u6709\u53d1\u751f\u54ea\u4e9b\u91cd\u8981\u4e8b\u4ef6\u3002", "duration": 7.963, "speaker": "ZH_B00000_S00110", "language": "zh", "dnsmos": 3.3808}"""
-    count = 0
-    for json_file in sorted(folder.glob("*.json")):
-        count += 1
-        if count_limit and count > count_limit:
-            break
-        with open(json_file, encoding="utf-8") as f:
-            data = json.load(f)
-        text = data["text"]
-        duration = data["duration"]
-        wav_path = folder / f'{json_file.stem}.wav'
-        if not wav_path.exists():
-            mp3_path = folder / f'{json_file.stem}.mp3'
-            command = f"ffmpeg -i {mp3_path} -ac 1 -ar 16000 {wav_path}"
-            cmd(command)
-        yield wav_path, text, duration
 
 
 
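The removed `cmd` helper is imported from `lib.utils` in test_data/audios.py below, so it has presumably moved there. A sketch of it as it would look in lib.utils (destination assumed; the `print(ret.stdout)` is guarded here so it no longer prints None when output isn't captured):

import subprocess
from subprocess import CompletedProcess

def cmd(command: str, check=True, capture_output=False) -> CompletedProcess:
    # Echo the command, run it through the shell, and optionally capture output.
    print(command)
    if capture_output:
        ret = subprocess.run(command, shell=True, check=check, stdout=subprocess.PIPE,
                             stderr=subprocess.STDOUT, universal_newlines=True)
        print(ret.stdout)
    else:
        ret = subprocess.run(command, shell=True, check=check)
    return ret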
test_data/audios.py
ADDED
@@ -0,0 +1,49 @@
+from pathlib import Path
+import json
+
+from lib.utils import cmd
+from environment import TEST_DATA
+
+
+def read_recording(folder: Path = Path("./recordings"), count_limit=None):
+    pass
+
+def read_dataset(file: Path = Path("./dataset_aishell/dataset.txt"), count_limit=None):
+    """line sample: {"audio": {"path": "dataset/audio/data_aishell/wav/test/S0916/BAC009S0916W0158.wav"}, "sentence": "顾客体验的核心是真善美", "duration": 3.22, "sentences": [{"start": 0, "end": 3.22, "text": "顾客体验的核心是真善美"}]}"""
+    with open(file) as f:
+        lines = f.readlines()
+    count = 0
+    for line in lines:
+        if count_limit and count > count_limit:
+            break
+        count += 1
+        line = line.strip()
+        if not line:
+            continue
+        data = json.loads(line)
+
+        yield data["audio"]["path"], data["sentence"], data["duration"]
+
+def read_emilia(folder: Path = TEST_DATA / "ZH-B000000", count_limit=None):
+    """Read the Emilia dataset; yields (audio path, text, duration).
+    Sample JSON file:
+    {"id": "ZH_B00000_S00110_W000000", "wav": "ZH_B00000/ZH_B00000_S00110/mp3/ZH_B00000_S00110_W000000.mp3", "text": "\u628a\u63e1\u6700\u524d\u6cbf\u7684\u91d1\u878d\u9886\u57df\u548c\u533a\u5757\u94fe\u6700\u65b0\u8d44\u8baf\u3002\u6211\u4eec\u4e00\u8d77\u6765\u4e86\u89e3\u4e00\u4e0b\u4eca\u5929\u5e02\u573a\u4e0a\u6709\u53d1\u751f\u54ea\u4e9b\u91cd\u8981\u4e8b\u4ef6\u3002", "duration": 7.963, "speaker": "ZH_B00000_S00110", "language": "zh", "dnsmos": 3.3808}"""
+    count = 0
+    for json_file in sorted(folder.glob("*.json")):
+        count += 1
+        if count_limit and count > count_limit:
+            break
+        with open(json_file, encoding="utf-8") as f:
+            data = json.load(f)
+        text = data["text"]
+        duration = data["duration"]
+        wav_path = folder / f'{json_file.stem}.wav'
+        if not wav_path.exists():
+            mp3_path = folder / f'{json_file.stem}.mp3'
+            command = f"ffmpeg -i {mp3_path} -ac 1 -ar 16000 {wav_path}"
+            cmd(command)
+        yield wav_path, text, duration
+
+if __name__ == '__main__':
+    for res in read_dataset(count_limit=3):
+        print(res)
test_data/texts.py
ADDED
@@ -0,0 +1,18 @@
+from environment import TEST_DATA
+
+def read_translation(language, count_limit=None):
+    if language == "zh":
+        text_file = TEST_DATA / "texts" / "test_translation_zh.txt"
+    elif language == "en":
+        text_file = TEST_DATA / "texts" / "test_translation_en.txt"
+    else:
+        raise ValueError(f"unsupported language: {language}")
+    count = 0
+    with open(text_file, encoding="utf-8") as f:
+        for line in f:
+            if not line.strip():
+                continue
+            count += 1
+            if count_limit is not None and count > count_limit:
+                break
+            yield line.strip()
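A quick usage sketch of `read_translation` (assuming the renamed text files below are in place):

for line in read_translation("zh", count_limit=2):
    print(line)  # first two non-empty lines of test_translation_zh.txt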
test_data/{recordings/text → texts}/test_translation_en.txt
RENAMED
File without changes
test_data/{recordings/text → texts}/test_translation_zh.txt
RENAMED
File without changes
tests/test_models/__init__.py
ADDED
File without changes
tests/test_models/conftest.py
ADDED
@@ -0,0 +1,12 @@
+import platform
+from pytest import fixture
+
+@fixture(scope="session")
+def get_platform():
+    processor = platform.processor()
+    if processor.startswith("Intel"):
+        return "intel"
+    elif processor.startswith("arm"):
+        return "apple"
+    else:
+        raise ValueError(f"Unsupported platform: {processor}")
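`platform.processor()` is inconsistent across systems (it can be empty on Linux, and returns strings like 'i386' or 'arm' on macOS), so the string checks above are tuned to the author's machines. A more portable variant, offered as an assumption rather than a drop-in fix, could key off `platform.machine()`:

import platform

def detect_platform() -> str:
    # Normalize the machine architecture instead of the processor string.
    machine = platform.machine().lower()
    if machine in ("x86_64", "amd64", "i386", "i686"):
        return "intel"
    if machine in ("arm64", "aarch64") or machine.startswith("arm"):
        return "apple"
    raise ValueError(f"Unsupported platform: {machine}")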
tests/test_models/test_funasr.py
ADDED
@@ -0,0 +1,22 @@
+import pytest
+from lib.models.funasr import FunASR
+from lib.utils import read_audio, save_csv
+from test_data.audios import read_emilia
+from environment import REPORTS_DIR
+
+@pytest.fixture(scope="module")
+def asr(get_platform) -> FunASR:
+    if get_platform == "apple":
+        return FunASR()
+    elif get_platform == "intel":
+        pass
+
+def test_inference(asr: FunASR):
+    # TODO: evaluate CER
+    report = []
+    for audio_file, text, duration in read_emilia(count_limit=100):
+        print(audio_file)
+        audio = read_audio(audio_file)
+        asr_text, time_cost = asr.transcribe(audio)
+        report.append([audio_file, duration, text, asr_text, time_cost])
+    save_csv(REPORTS_DIR / "funasr.csv", ["audio", "duration", "ref", "asr", "time"], report)
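`save_csv` and `REPORTS_DIR` also live outside this commit (in `lib.utils` and `environment`). A minimal `save_csv` sketch matching its call sites here (the signature is an assumption):

import csv
from pathlib import Path

def save_csv(path: Path, header: list, rows: list) -> None:
    # Write one header row followed by the report rows.
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        writer.writerows(rows)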
tests/test_models/test_llm.py
ADDED
@@ -0,0 +1,30 @@
+import pytest
+from lib.models.llm import QwenTranslator
+from test_data.texts import read_translation
+from lib.utils import save_csv
+from environment import REPORTS_DIR
+
+@pytest.fixture(scope="module")
+def llm(get_platform) -> QwenTranslator:
+    if get_platform == "apple":
+        return QwenTranslator()
+    elif get_platform == "intel":
+        pass
+
+def test_llm_zh(llm: QwenTranslator):
+    report = []
+    for src in read_translation("zh"):
+        dst, time_cost = llm.translate(src, src_lang="zh", dst_lang="en")
+        print("Prompt:", src)
+        print("Response:", dst)
+        report.append([src, dst, time_cost])
+    save_csv(REPORTS_DIR / "translation_zh.csv", ["src", "dst", "time"], report)
+
+def test_llm_en(llm: QwenTranslator):
+    report = []
+    for src in read_translation("en"):
+        dst, time_cost = llm.translate(src, src_lang="en", dst_lang="zh")
+        print("Prompt:", src)
+        print("Response:", dst)
+        report.append([src, dst, time_cost])
+    save_csv(REPORTS_DIR / "translation_en.csv", ["src", "dst", "time"], report)
tests/test_models/test_tts.py
ADDED
@@ -0,0 +1,31 @@
+import pytest
+from lib.models.kokoro import KokoroTTS
+from test_data.texts import read_translation
+from lib.utils import save_csv
+from environment import REPORTS_DIR
+
+
+@pytest.fixture(scope="module")
+def tts(get_platform) -> KokoroTTS:
+    if get_platform == "apple":
+        pass
+    elif get_platform == "intel":
+        pass
+
+
+def test_tts_zh():
+    tts = KokoroTTS.from_language("zh")
+    report = []
+    for text in read_translation("zh"):
+        samples, sr, time_cost = tts.generate(text)
+        report.append([text, time_cost])
+    save_csv(REPORTS_DIR / "tts_zh.csv", ["text", "time"], report)
+
+
+def test_tts_en():
+    tts = KokoroTTS.from_language("en")
+    report = []
+    for text in read_translation("en"):
+        samples, sr, time_cost = tts.generate(text)
+        report.append([text, time_cost])
+    save_csv(REPORTS_DIR / "tts_en.csv", ["text", "time"], report)
tests/test_models/test_whisper.py
ADDED
@@ -0,0 +1,22 @@
+import pytest
+from lib.models.whisper import WhisperCPP
+from lib.utils import read_audio, save_csv
+from test_data.audios import read_emilia
+from environment import REPORTS_DIR
+
+@pytest.fixture(scope="module")
+def whisper(get_platform) -> WhisperCPP:
+    if get_platform == "apple":
+        return WhisperCPP()
+    elif get_platform == "intel":
+        pass
+
+def test_inference(whisper: WhisperCPP):
+    # TODO: evaluate CER
+    report = []
+    for audio_file, text, duration in read_emilia(count_limit=100):
+        print(audio_file)
+        audio = read_audio(audio_file)
+        asr_text, time_cost = whisper.transcribe(audio, "zh")
+        report.append([audio_file, duration, text, asr_text, time_cost])
+    save_csv(REPORTS_DIR / "whisper.csv", ["audio", "duration", "ref", "asr", "time"], report)
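Both ASR tests carry a `# TODO: evaluate CER` note. A self-contained character-error-rate helper that could back that TODO (a sketch, not part of the commit):

def cer(ref: str, hyp: str) -> float:
    # Character error rate: Levenshtein distance over reference length (one-row DP).
    m, n = len(ref), len(hyp)
    dp = list(range(n + 1))
    for i in range(1, m + 1):
        prev, dp[0] = dp[0], i
        for j in range(1, n + 1):
            cur = dp[j]
            dp[j] = min(dp[j] + 1,                              # deletion
                        dp[j - 1] + 1,                          # insertion
                        prev + (ref[i - 1] != hyp[j - 1]))      # substitution
            prev = cur
    return dp[n] / max(m, 1)

print(cer("顾客体验的核心是真善美", "顾客体验的核心是真善美"))  # 0.0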