import os
from pathlib import Path

from funasr_onnx import SeacoParaformer, CT_Transformer, Fsmn_vad

from lib.utils import Timer
from lib.asr_models.base_model import AbstractASRModel, ModelName
# Default directory containing the downloaded ONNX model folders.
# Override via the MOYOYO_ASR_MODEL_DIR environment variable instead of
# editing this machine-specific absolute path.
MODEL_DIR = os.environ.get(
    "MOYOYO_ASR_MODEL_DIR",
    "/Users/jeqin/work/code/Translator/python_server/moyoyo_asr_models",
)
class FunasrQuant(AbstractASRModel):
    """Quantized FunASR ONNX pipeline: FSMN VAD + SeacoParaformer ASR + CT punctuation.

    Models are not loaded at construction time; call `load()` before
    `transcribe()`.
    """

    def __init__(self, device='mps'):
        super().__init__(device=device)
        self.name = ModelName.FUNASR_QUANT

    def load(self, model_dir=MODEL_DIR, language="", quantize=True):
        """Load the VAD, ASR and punctuation ONNX models from `model_dir`.

        Args:
            model_dir: directory containing the three model sub-folders.
            language: accepted for interface compatibility with other ASR
                models; not used here — the bundled models are fixed.
            quantize: load the quantized ONNX weights (default True, which
                preserves the previous hard-coded behavior).
        """
        with Timer("Loading Fun-ASR-Quant model"):
            model_dir = Path(model_dir)
            asr_model_path = model_dir / 'speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'
            vad_model_path = model_dir / 'speech_fsmn_vad_zh-cn-16k-common-pytorch'
            punc_model_path = model_dir / 'punc_ct-transformer_cn-en-common-vocab471067-large'
            self.vad_model = Fsmn_vad(vad_model_path, quantize=quantize)
            self.asr_model = SeacoParaformer(asr_model_path, quantize=quantize)
            self.punc_model = CT_Transformer(punc_model_path, quantize=quantize)

    def transcribe(self, wav, language="zh"):
        """Transcribe an audio file and return (punctuated_text, elapsed_seconds).

        Args:
            wav: path-like pointing at the audio file to transcribe.
            language: language hint forwarded to the ASR model.

        Returns:
            Tuple of the punctuated transcript (empty string when the ASR
            model produced no result) and the transcription duration as
            reported by the Timer context manager.
        """
        with Timer("Transcribing audio") as t:
            asr_res = self.asr_model(str(wav), hotwords="", language=language)
            text = ""
            if asr_res:  # empty result list -> empty transcript
                asr_text = asr_res[0]["preds"]
                # Restore punctuation; the punc model returns a sequence
                # whose first element is the punctuated text.
                punc_res = self.punc_model(asr_text)
                text = punc_res[0]
        return text, t.duration
if __name__ == "__main__":
    # Manual smoke test: load the quantized pipeline and transcribe a sample clip.
    asr = FunasrQuant(device='mps')
    asr.load()
    transcript, elapsed = asr.transcribe('../../test_data/recordings/1.wav', language="en")
    print("inference time: ", elapsed)
    print(transcript)