yujuanqin committed on
Commit 1e495f3 · 1 Parent(s): 3d1d87d

add test_models

lib/models/__init__.py ADDED
File without changes
lib/models/funasr.py ADDED
@@ -0,0 +1,42 @@
+ from pathlib import Path
+ import numpy as np
+ from funasr_onnx import SeacoParaformer, CT_Transformer, Fsmn_vad
+
+ from lib.utils import Timer, read_audio
+
+ MODEL_DIR = Path("/Users/jeqin/work/code/Translator/python_server/moyoyo_asr_models")
+
+ class FunASR:
+     def __init__(self, model_dir=MODEL_DIR, quantize=True):
+         asr_model_path = model_dir / 'speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'
+         # vad_model_path = model_dir / 'speech_fsmn_vad_zh-cn-16k-common-pytorch'
+         punc_model_path = model_dir / 'punc_ct-transformer_cn-en-common-vocab471067-large'
+         # vad_model = Fsmn_vad(vad_model_path, quantize=quantize)
+         with Timer("load FunASR"):
+             self.asr_model = SeacoParaformer(asr_model_path, quantize=quantize)
+             self.punc_model = CT_Transformer(punc_model_path, quantize=quantize)
+             self._warm_up()
+
+     def _warm_up(self):
+         # Generate 1 second of fake 16 kHz audio to trigger model initialization
+         fake_audio = np.random.randn(16000).astype(np.float32)
+         self.asr_model(fake_audio, hotwords="")
+
+     def transcribe(self, audio: np.ndarray):
+         with Timer("FunASR inference") as t:
+             asr_res = self.asr_model(audio, hotwords="")
+             asr_text = asr_res[0]["preds"]
+             result = self.punc_model(asr_text)
+             text = result[0]
+         return text, t.duration
+
+ if __name__ == '__main__':
+     funasr = FunASR()
+     audio = read_audio(Path("/Users/jeqin/work/code/TestTranslator/test_data/recordings/1.wav"))
+     text, time_cost = funasr.transcribe(audio)
+     print(text)
+     print(time_cost)
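Note: `Timer` and `read_audio` come from `lib.utils`, which is not part of this commit. From the call sites (`with Timer("load FunASR")`, reading `t.duration` after the block) it is presumably a timing context manager; a minimal sketch under that assumption, not the committed implementation:

import time

class Timer:
    """Context manager that measures the wall-clock time of its block."""
    def __init__(self, name: str):
        self.name = name
        self.duration = 0.0

    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb):
        # Record elapsed seconds and log the labelled timing
        self.duration = time.perf_counter() - self._start
        print(f"{self.name}: {self.duration:.3f}s")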
lib/models/kokoro.py ADDED
@@ -0,0 +1,113 @@
+ import os
+ from pathlib import Path
+ from kokoro_onnx import Kokoro
+ from misaki import espeak, en, zh
+ from misaki.espeak import EspeakG2P
+ from logging import getLogger
+ import onnxruntime
+
+ from lib.utils import Timer, write_audio
+
+
+ logger = getLogger(__name__)
+ MODEL_DIR = Path("/Users/jeqin/work/code/Translator/python_server/moyoyo_asr_models/kokoro")
+
+ def create_session(model_path):
+     # See list of providers https://github.com/microsoft/onnxruntime/issues/22101#issuecomment-2357667377
+     providers = onnxruntime.get_available_providers()
+     print(f"Available onnx runtime providers: {providers}")
+
+     # See session options https://onnxruntime.ai/docs/performance/tune-performance/threading.html#thread-management
+     sess_options = onnxruntime.SessionOptions()
+     cpu_count = os.cpu_count() // 2
+     print(f"Setting threads to half the CPU core count: {cpu_count}")
+     sess_options.intra_op_num_threads = cpu_count
+     session = onnxruntime.InferenceSession(
+         model_path, providers=["CPUExecutionProvider"], sess_options=sess_options
+     )
+     return session
+
+
+ class KokoroTTS:
+     language_voice_mapping = {
+         "JP": "jf_alpha",
+         "JA": "jf_alpha",
+         "ZH": "zf_xiaoyi",
+         "EN": "af_heart",
+         "FR": "ff_siwis",
+         "IT": "im_nicola",
+         "HI": "hf_alpha",
+         "PT": "im_nicola",
+         "ES": "im_nicola",
+     }
+     # Short warm-up phrase per language
+     language_word_mapping = {
+         "ZH": "你好",
+         "EN": "hello",
+         "FR": "Bonjour",
+         "IT": "Ciao",
+         "HI": "हेलो",
+         "PT": "Olá",
+         "ES": "Hola",
+     }
+
+     def __init__(self, model_path: str, voice_model_path: str, vocab_config=None, g2p=None, voice=None):
+         self._session = create_session(model_path)
+         self.model = Kokoro.from_session(self._session, voice_model_path, vocab_config=vocab_config)
+         self.g2p = g2p
+         self.voice = voice
+
+     @classmethod
+     def from_language(cls, language: str, model_dir: Path = MODEL_DIR):
+         model_path: str = str(model_dir / "kokoro-quant.onnx")
+         voice_model_path: str = str(model_dir / "voices-v1.0.bin")
+         voice = cls.language_voice_mapping.get(language.upper())
+         warm_up_text = cls.language_word_mapping.get(language.upper())
+         logger.info(f"[TTS] language: {language}")
+         if not voice:
+             raise ValueError(f"Unsupported language: {language}, voice: {voice}")
+         vocab_config = None
+         if language.upper() == "ZH":
+             g2p = zh.ZHG2P()
+             vocab_config = model_dir / "zh_config.json"
+         elif language.upper() == 'EN':
+             fallback = espeak.EspeakFallback(british=False)
+             g2p = en.G2P(trf=False, british=False, fallback=fallback)
+         elif language.upper() == "HI":
+             g2p = EspeakG2P(language="hi")
+         elif language.upper() == "IT":
+             g2p = EspeakG2P(language="it")
+         elif language.upper() == "PT":
+             g2p = EspeakG2P(language="pt-br")
+         elif language.upper() == "ES":
+             g2p = EspeakG2P(language="es")
+         elif language.upper() == "FR":
+             g2p = EspeakG2P(language="fr-fr")
+         else:
+             g2p = EspeakG2P(language.lower())
+         with Timer("load tts"):
+             tts = cls(model_path, voice_model_path, vocab_config=vocab_config, g2p=g2p, voice=voice)
+             tts.generate(warm_up_text)
+         return tts
+
+     def generate(self, text, speed=1.2):
+         with Timer("tts inference") as t:
+             phonemes, _ = self.g2p(text)
+             samples, sample_rate = self.model.create(phonemes, self.voice, is_phonemes=True, speed=speed)
+         # Alternatively resample for playback: librosa.resample(samples, target_sr=44100, orig_sr=sample_rate)
+         return samples, sample_rate, t.duration
+
+     async def stream(self, text, speed=1.2):
+         phonemes, _ = self.g2p(text)
+         stream = self.model.create_stream(phonemes, self.voice, is_phonemes=True, speed=speed)
+         async for samples, sample_rate in stream:
+             yield samples, sample_rate
+
+
+ if __name__ == '__main__':
+     tts = KokoroTTS.from_language(language="ZH")
+     samples, sr, time_cost = tts.generate("今天天气怎么样?")
+     write_audio("tts_out.wav", samples, sr)
+     print(time_cost)
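Note: `stream` is an async generator, so callers need an event loop. A usage sketch (the chunk handling here is illustrative, not part of the commit):

import asyncio

async def demo_stream():
    tts = KokoroTTS.from_language("ZH")
    # Consume audio chunks as they are synthesized instead of waiting for the
    # full utterance, which is what a live interpreter pipeline needs.
    async for samples, sample_rate in tts.stream("今天天气怎么样?"):
        print(f"chunk: {len(samples)} samples @ {sample_rate} Hz")

asyncio.run(demo_stream())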
lib/models/llm.py ADDED
@@ -0,0 +1,91 @@
+ from logging import getLogger
+ from pathlib import Path
+ from llama_cpp import Llama
+ from functools import lru_cache
+
+ from lib.utils import Timer
+
+ logger = getLogger(__name__)
+ LLM_SYS_PROMPT_EN = """
+ 你是一名专业的同声传译员,正在为 GOSIM 会议提供英中翻译服务。你的任务是准确、流畅地翻译发言内容。
+
+ 请遵循以下要求:
+ 1. 语言风格:翻译成中文时,请使用自然、流畅、符合现代汉语口语习惯的表达方式。避免生硬、逐字翻译的痕迹,要让听众容易理解。
+ 2. 专业术语:**请优先参考下方提供的术语对照表进行翻译。** 对于对照表中未包含的术语,如果该术语有公认的标准翻译,请使用标准翻译;如果没有或不确定,请保留英文原文。不要用通俗词汇替代专业术语。
+ 3. 专有名词:对于专有名词,如会议名称 "GOSIM"、人名、公司名、项目名、特定技术名称等,请保留其原始英文不做翻译。
+ 4. 流畅性与准确性:在追求口语化的同时,务必保证信息传达的准确性。
+ 5. 输出:请直接输出翻译结果,不要添加任何额外的解释或说明。
+
+ **专业术语对照表:**
+ * driver: 驱动
+ * bus: 总线
+ * mask: 掩码
+ * preemption: 抢占
+ * register: 寄存器
+ * Library: 库
+ * biases: 偏移
+ * OpenAGI: OpenAGI
+ * LLaMA Factory: LLaMA Factory
+ * OPENGL: OPENGL
+
+ 现在,请将以下内容翻译成中文:
+ """
+
+ LLM_SYS_PROMPT_ZH = """
+ 你是一位中英文翻译专家。请将以下中文文本翻译成英文,遵循以下要求:
+
+ 翻译要求:
+ - 保留原文英文内容:以下内容请保持原始英文形式,不进行翻译或改写:
+   - 技术术语与专业词汇
+   - 产品名称、品牌名称
+   - 代码片段、函数名、变量名
+   - 专有名词、缩写、首字母缩略词(如 API、NLP、RAG 等)
+ - 翻译符合英文表达习惯,流畅自然,不生硬直译。
+ - 保持专业性与准确性,清晰传达原意。
+ - 如遇原文表达模糊或逻辑不清的情况,允许适度调整语序或措辞,以增强英文表述的清晰度和逻辑性。
+
+ 注意:
+ 若难以确定某个词汇是否需要翻译,请优先保留原始英文形式。
+ 不需添加额外解释或注释,仅翻译正文内容。
+ 特别注意,翻译的内容只能包含英文,不能包含其他的语言。
+
+ 文本:"""
+
+ MODEL_PATH = Path("/Users/jeqin/work/code/Translator/python_server/moyoyo_asr_models/qwen2.5-1.5b-instruct-q5_0.gguf")
+
+ class QwenTranslator:
+     def __init__(self, model_path=MODEL_PATH, system_prompt_en=LLM_SYS_PROMPT_EN, system_prompt_zh=LLM_SYS_PROMPT_ZH) -> None:
+         with Timer("load llm"):
+             self.llm = Llama(
+                 model_path=str(model_path),
+                 chat_format="chatml",
+                 verbose=False)
+         self.sys_prompt_en = system_prompt_en
+         self.sys_prompt_zh = system_prompt_zh
+         self._warmup()
+
+     def to_message(self, prompt, src_lang, dst_lang):
+         """Build the chat messages for the given source/target languages."""
+         return [
+             {"role": "system", "content": self.sys_prompt_en if src_lang == "en" else self.sys_prompt_zh},
+             {"role": "user", "content": prompt},
+         ]
+
+     def _warmup(self):
+         self.translate(prompt="hello", src_lang="en", dst_lang="zh")
+
+     @lru_cache(maxsize=10)
+     def translate(self, prompt, src_lang, dst_lang) -> tuple:
+         message = self.to_message(prompt, src_lang, dst_lang)
+         with Timer("llm inference") as t:
+             output = self.llm.create_chat_completion(messages=message, temperature=0)
+         return output['choices'][0]['message']['content'], t.duration
+
+
+ if __name__ == '__main__':
+     model_dir = Path("/Users/jeqin/work/code/Translator/moyoyo_asr_models")
+     qwen2 = (model_dir / "qwen2.5-1.5b-instruct-q5_0.gguf").as_posix()
+     qwen3 = (model_dir / "Qwen_Qwen3-1.7B-Q4_K_M.gguf").as_posix()
+
+     translator = QwenTranslator(qwen3)
+     text, time_cost = translator.translate("今天天气怎么样?", "zh", "en")
+     print(text)
+     print(time_cost)
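One caveat worth noting: `@lru_cache` memoizes the whole `(text, duration)` tuple and keys on `self` as well, so a repeated prompt returns the first call's timing rather than a fresh measurement. A quick illustration:

translator = QwenTranslator()
text1, t1 = translator.translate("hello", src_lang="en", dst_lang="zh")
text2, t2 = translator.translate("hello", src_lang="en", dst_lang="zh")  # cache hit
assert (text1, t1) == (text2, t2)  # t2 is the first call's duration, not ~0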
lib/models/whisper.py ADDED
@@ -0,0 +1,68 @@
+ from pywhispercpp.model import Model
+ import numpy as np
+ from logging import getLogger
+ from pathlib import Path
+
+ from lib.utils import Timer, read_audio
+
+ logger = getLogger(__name__)
+
+ MODEL_DIR = Path("/Users/jeqin/work/code/Translator/python_server/moyoyo_asr_models")
+ WHISPER_PROMPT_ZH = "以下是简体中文普通话的句子。"
+ WHISPER_PROMPT_EN = ""  # "The following is an English sentence."
+
+ class WhisperCPP:
+     def __init__(self, model_dir=MODEL_DIR, source_lang: str = 'en') -> None:
+         whisper_model = 'large-v3-turbo-q5_0'
+         with Timer("load whisper"):
+             self.model = Model(
+                 model=whisper_model,
+                 models_dir=str(model_dir),
+                 print_realtime=False,
+                 print_progress=False,
+                 print_timestamps=False,
+                 translate=False,
+                 # beam_search=1,
+                 temperature=0.,
+                 no_context=True
+             )
+             self._warmup()
+
+     def _warmup(self):
+         # One second of random 16 kHz audio to trigger model initialization
+         fake_audio = np.random.randn(16000).astype(np.float32)
+         self.model.transcribe(fake_audio, print_progress=False)
+
+     @staticmethod
+     def config_language(language):
+         if language == "zh":
+             return WHISPER_PROMPT_ZH
+         elif language == "en":
+             return WHISPER_PROMPT_EN
+         raise ValueError(f"Unsupported language: {language}")
+
+     def transcribe(self, audio: np.ndarray, language):
+         prompt = self.config_language(language)
+         try:
+             with Timer("whisper inference") as t:
+                 segments = self.model.transcribe(
+                     audio,
+                     initial_prompt=prompt,
+                     language=language,
+                     # token_timestamps=True,
+                     split_on_word=True,
+                     # max_len=max_len
+                 )
+                 text = "".join([s.text for s in segments])
+             return text, t.duration
+         except Exception as e:
+             logger.error(e)
+             # Keep the (text, duration) contract for callers on failure
+             return "", 0.0
+
+ if __name__ == '__main__':
+     whisper = WhisperCPP()
+     audio = read_audio(Path("/Users/jeqin/work/code/TestTranslator/test_data/recordings/1.wav"))
+     text, time_cost = whisper.transcribe(audio, "zh")
+     print(text)
+     print(time_cost)
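Note: the `read_audio`, `write_audio`, and `save_csv` helpers in `lib.utils` are also outside this commit. Given how they are called across these files (16 kHz mono float32 in, `(path, samples, sr)` out, `(path, header, rows)` for reports), plausible soundfile/csv-based sketches might be (assumptions, not the committed implementations):

import csv
from pathlib import Path

import numpy as np
import soundfile

def read_audio(path: Path) -> np.ndarray:
    # The models above expect 16 kHz mono float32 samples
    audio, sr = soundfile.read(path, dtype="float32")
    assert sr == 16000, f"expected 16 kHz audio, got {sr}"
    return audio

def write_audio(path, samples: np.ndarray, sample_rate: int):
    soundfile.write(path, samples, sample_rate)

def save_csv(path: Path, header: list, rows: list):
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        writer.writerows(rows)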
scripts/asr_utils.py CHANGED
@@ -7,17 +7,6 @@ from pathlib import Path
  import subprocess
  from subprocess import CompletedProcess
 
- def cmd(command: str, check=True, capture_output=False) -> CompletedProcess:
-     print(command)
-     if capture_output:
-         ret = subprocess.run(command, shell=True, check=check,
-                              stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
-                              universal_newlines=True)
-     else:
-         ret = subprocess.run(command, shell=True, check=check)
-     print(ret.stdout)
-     return ret
-
  def add_text_index():
      text_file = '../test_data/text/test_asr_zh.txt'
      index = 1
@@ -89,37 +78,7 @@ def get_origin_text_dict():
          text_dict[idx] = text
      return text_dict
 
- def read_dataset(file):
-     """line sample: {"audio": {"path": "dataset/audio/data_aishell/wav/test/S0916/BAC009S0916W0158.wav"}, "sentence": "顾客体验的核心是真善美", "duration": 3.22, "sentences": [{"start": 0, "end": 3.22, "text": "顾客体验的核心是真善美"}]}"""
-     with open(file) as f:
-         lines = f.readlines()
-     for line in lines:
-         line = line.strip()
-         if not line:
-             continue
-         data = json.loads(line)
-
-         yield data["audio"]["path"], data["sentence"], data["duration"]
 
- def read_emilia(folder: Path, count_limit=None):
-     """Read the Emilia dataset and yield (audio path, text, duration) per item.
-     Sample json file:
-     {"id": "ZH_B00000_S00110_W000000", "wav": "ZH_B00000/ZH_B00000_S00110/mp3/ZH_B00000_S00110_W000000.mp3", "text": "\u628a\u63e1\u6700\u524d\u6cbf\u7684\u91d1\u878d\u9886\u57df\u548c\u533a\u5757\u94fe\u6700\u65b0\u8d44\u8baf\u3002\u6211\u4eec\u4e00\u8d77\u6765\u4e86\u89e3\u4e00\u4e0b\u4eca\u5929\u5e02\u573a\u4e0a\u6709\u53d1\u751f\u54ea\u4e9b\u91cd\u8981\u4e8b\u4ef6\u3002", "duration": 7.963, "speaker": "ZH_B00000_S00110", "language": "zh", "dnsmos": 3.3808}"""
-     count = 0
-     for json_file in sorted(folder.glob("*.json")):
-         count += 1
-         if count_limit and count > count_limit:
-             break
-         with open(json_file, encoding="utf-8") as f:
-             data = json.load(f)
-         text = data["text"]
-         duration = data["duration"]
-         wav_path = folder / f'{json_file.stem}.wav'
-         if not wav_path.exists():
-             mp3_path = folder / f'{json_file.stem}.mp3'
-             command = f"ffmpeg -i {mp3_path} -ac 1 -ar 16000 {wav_path}"
-             cmd(command)
-         yield wav_path, text, duration
test_data/audios.py ADDED
@@ -0,0 +1,49 @@
+ from pathlib import Path
+ import json
+
+ from lib.utils import cmd
+ from environment import TEST_DATA
+
+
+ def read_recording(folder: Path = Path("./recordings"), count_limit=None):
+     pass
+
+ def read_dataset(file: Path = Path("./dataset_aishell/dataset.txt"), count_limit=None):
+     """line sample: {"audio": {"path": "dataset/audio/data_aishell/wav/test/S0916/BAC009S0916W0158.wav"}, "sentence": "顾客体验的核心是真善美", "duration": 3.22, "sentences": [{"start": 0, "end": 3.22, "text": "顾客体验的核心是真善美"}]}"""
+     with open(file) as f:
+         lines = f.readlines()
+     count = 0
+     for line in lines:
+         line = line.strip()
+         if not line:
+             continue
+         # Count only non-empty lines so count_limit yields exactly that many items
+         count += 1
+         if count_limit and count > count_limit:
+             break
+         data = json.loads(line)
+         yield data["audio"]["path"], data["sentence"], data["duration"]
+
+ def read_emilia(folder: Path = TEST_DATA / "ZH-B000000", count_limit=None):
+     """Read the Emilia dataset and yield (audio path, text, duration) per item.
+     Sample json file:
+     {"id": "ZH_B00000_S00110_W000000", "wav": "ZH_B00000/ZH_B00000_S00110/mp3/ZH_B00000_S00110_W000000.mp3", "text": "\u628a\u63e1\u6700\u524d\u6cbf\u7684\u91d1\u878d\u9886\u57df\u548c\u533a\u5757\u94fe\u6700\u65b0\u8d44\u8baf\u3002\u6211\u4eec\u4e00\u8d77\u6765\u4e86\u89e3\u4e00\u4e0b\u4eca\u5929\u5e02\u573a\u4e0a\u6709\u53d1\u751f\u54ea\u4e9b\u91cd\u8981\u4e8b\u4ef6\u3002", "duration": 7.963, "speaker": "ZH_B00000_S00110", "language": "zh", "dnsmos": 3.3808}"""
+     count = 0
+     for json_file in sorted(folder.glob("*.json")):
+         count += 1
+         if count_limit and count > count_limit:
+             break
+         with open(json_file, encoding="utf-8") as f:
+             data = json.load(f)
+         text = data["text"]
+         duration = data["duration"]
+         wav_path = folder / f'{json_file.stem}.wav'
+         if not wav_path.exists():
+             # Convert mp3 to 16 kHz mono wav on first access
+             mp3_path = folder / f'{json_file.stem}.mp3'
+             command = f"ffmpeg -i {mp3_path} -ac 1 -ar 16000 {wav_path}"
+             cmd(command)
+         yield wav_path, text, duration
+
+ if __name__ == '__main__':
+     for res in read_dataset(count_limit=3):
+         print(res)
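Note: `cmd` is now imported from `lib.utils`; assuming it kept the signature of the helper removed from scripts/asr_utils.py in this same commit, a minimal equivalent is:

import subprocess
from subprocess import CompletedProcess

def cmd(command: str, check=True, capture_output=False) -> CompletedProcess:
    # Echo the command, run it through the shell, and optionally capture output
    print(command)
    if capture_output:
        ret = subprocess.run(command, shell=True, check=check,
                             stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                             universal_newlines=True)
    else:
        ret = subprocess.run(command, shell=True, check=check)
    print(ret.stdout)
    return ret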
test_data/texts.py ADDED
@@ -0,0 +1,18 @@
+ from environment import TEST_DATA
+
+ def read_translation(language, count_limit=None):
+     if language == "zh":
+         text_file = TEST_DATA / "texts" / "test_translation_zh.txt"
+     elif language == "en":
+         text_file = TEST_DATA / "texts" / "test_translation_en.txt"
+     else:
+         raise ValueError(f"Unsupported language: {language}")
+     count = 0
+     with open(text_file, encoding="utf-8") as f:
+         for line in f:
+             if not line.strip():
+                 continue
+             count += 1
+             if count_limit is not None and count > count_limit:
+                 break
+             yield line.strip()
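Note: `environment.TEST_DATA` and `environment.REPORTS_DIR` are imported throughout the tests, but environment.py is not in this commit; a plausible sketch (the directory layout is an assumption):

from pathlib import Path

# Assumed layout: both directories live next to this file at the repo root
ROOT = Path(__file__).parent
TEST_DATA = ROOT / "test_data"
REPORTS_DIR = ROOT / "reports"
REPORTS_DIR.mkdir(exist_ok=True)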
test_data/{recordings/text → texts}/test_translation_en.txt RENAMED
File without changes
test_data/{recordings/text → texts}/test_translation_zh.txt RENAMED
File without changes
tests/test_models/__init__.py ADDED
File without changes
tests/test_models/conftest.py ADDED
@@ -0,0 +1,12 @@
+ import platform
+ from pytest import fixture
+
+ @fixture(scope="session")
+ def get_platform():
+     processor = platform.processor()
+     if processor.startswith("Intel"):
+         return "intel"
+     elif processor.startswith("arm"):
+         return "apple"
+     else:
+         raise ValueError(f"Unsupported platform: {processor}")
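Note: `platform.processor()` output varies by OS and can even be empty on some Linux builds, so this fixture is somewhat brittle; `platform.machine()` tends to be more stable. A hedged alternative for the same check:

import platform

def detect_platform() -> str:
    machine = platform.machine().lower()
    if machine in ("arm64", "aarch64"):
        return "apple"   # Apple Silicon when running on macOS
    if machine in ("x86_64", "amd64"):
        return "intel"
    raise ValueError(f"Unsupported platform: {machine}")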
tests/test_models/test_funasr.py ADDED
@@ -0,0 +1,22 @@
+ import pytest
+ from lib.models.funasr import FunASR
+ from lib.utils import read_audio, save_csv
+ from test_data.audios import read_emilia
+ from environment import REPORTS_DIR
+
+ @pytest.fixture(scope="module")
+ def asr(get_platform) -> FunASR:
+     if get_platform == "apple":
+         return FunASR()
+     elif get_platform == "intel":
+         pytest.skip("FunASR is not configured for Intel machines yet")
+
+ def test_inference(asr: FunASR):
+     # TODO: evaluate CER
+     report = []
+     for audio_file, text, duration in read_emilia(count_limit=100):
+         print(audio_file)
+         audio = read_audio(audio_file)
+         asr_text, time_cost = asr.transcribe(audio)
+         report.append([audio_file, duration, text, asr_text, time_cost])
+     save_csv(REPORTS_DIR / "funasr.csv", ["audio", "duration", "ref", "asr", "time"], report)
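One way to fill the CER TODO, assuming the `jiwer` package is acceptable (its `cer` helper compares a reference string against a hypothesis):

from jiwer import cer

def character_error_rate(ref: str, hyp: str) -> float:
    # CER = (substitutions + deletions + insertions) / reference length
    return cer(ref, hyp)

# e.g. inside the loop above:
#   report.append([audio_file, duration, text, asr_text,
#                  character_error_rate(text, asr_text), time_cost])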
tests/test_models/test_llm.py ADDED
@@ -0,0 +1,30 @@
+ import pytest
+ from lib.models.llm import QwenTranslator
+ from test_data.texts import read_translation
+ from lib.utils import save_csv
+ from environment import REPORTS_DIR
+
+ @pytest.fixture(scope="module")
+ def llm(get_platform) -> QwenTranslator:
+     if get_platform == "apple":
+         return QwenTranslator()
+     elif get_platform == "intel":
+         pytest.skip("LLM is not configured for Intel machines yet")
+
+ def test_llm_zh(llm: QwenTranslator):
+     report = []
+     for src in read_translation("zh"):
+         dst, time_cost = llm.translate(src, src_lang="zh", dst_lang="en")
+         print("Prompt:", src)
+         print("Response:", dst)
+         report.append([src, dst, time_cost])
+     save_csv(REPORTS_DIR / "translation_zh.csv", ["src", "dst", "time"], report)
+
+ def test_llm_en(llm: QwenTranslator):
+     report = []
+     for src in read_translation("en"):
+         dst, time_cost = llm.translate(src, src_lang="en", dst_lang="zh")
+         print("Prompt:", src)
+         print("Response:", dst)
+         report.append([src, dst, time_cost])
+     save_csv(REPORTS_DIR / "translation_en.csv", ["src", "dst", "time"], report)
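These tests only record the output for eyeballing; if reference translations were added to the test data, corpus BLEU via `sacrebleu` would be one option (a sketch under that assumption; `hypotheses` and `references` are hypothetical inputs):

import sacrebleu

def corpus_bleu_score(hypotheses: list[str], references: list[str]) -> float:
    # sacrebleu expects a list of hypothesis strings and a list of reference streams
    return sacrebleu.corpus_bleu(hypotheses, [references]).score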
tests/test_models/test_tts.py ADDED
@@ -0,0 +1,31 @@
+ import pytest
+ from lib.models.kokoro import KokoroTTS
+ from test_data.texts import read_translation
+ from lib.utils import save_csv
+ from environment import REPORTS_DIR
+
+
+ @pytest.fixture(scope="module")
+ def tts(get_platform) -> KokoroTTS:
+     # Placeholder: the tests below construct per-language instances directly
+     if get_platform == "apple":
+         pass
+     elif get_platform == "intel":
+         pass
+
+
+ def test_tts_zh():
+     tts = KokoroTTS.from_language("zh")
+     report = []
+     for text in read_translation("zh"):
+         samples, sr, time_cost = tts.generate(text)
+         report.append([text, time_cost])
+     save_csv(REPORTS_DIR / "tts_zh.csv", ["text", "time"], report)
+
+
+ def test_tts_en():
+     tts = KokoroTTS.from_language("en")
+     report = []
+     for text in read_translation("en"):
+         samples, sr, time_cost = tts.generate(text)
+         report.append([text, time_cost])
+     save_csv(REPORTS_DIR / "tts_en.csv", ["text", "time"], report)
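Raw `time_cost` says little for TTS without the output length; a real-time factor can be derived from values the loop already has (a sketch, not part of the commit):

def real_time_factor(time_cost: float, samples, sample_rate: int) -> float:
    # RTF < 1 means synthesis runs faster than playback
    audio_seconds = len(samples) / sample_rate
    return time_cost / audio_seconds

# e.g.: report.append([text, time_cost, real_time_factor(time_cost, samples, sr)])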
tests/test_models/test_whisper.py ADDED
@@ -0,0 +1,22 @@
+ import pytest
+ from lib.models.whisper import WhisperCPP
+ from lib.utils import read_audio, save_csv
+ from test_data.audios import read_emilia
+ from environment import REPORTS_DIR
+
+ @pytest.fixture(scope="module")
+ def whisper(get_platform) -> WhisperCPP:
+     if get_platform == "apple":
+         return WhisperCPP()
+     elif get_platform == "intel":
+         pytest.skip("Whisper is not configured for Intel machines yet")
+
+ def test_inference(whisper: WhisperCPP):
+     # TODO: evaluate CER
+     report = []
+     for audio_file, text, duration in read_emilia(count_limit=100):
+         print(audio_file)
+         audio = read_audio(audio_file)
+         asr_text, time_cost = whisper.transcribe(audio, "zh")
+         report.append([audio_file, duration, text, asr_text, time_cost])
+     save_csv(REPORTS_DIR / "whisper.csv", ["audio", "duration", "ref", "asr", "time"], report)