# FunASR quantized ASR model wrapper (VAD + Paraformer ASR + punctuation restoration).
from pathlib import Path
from funasr_onnx import SeacoParaformer, CT_Transformer, Fsmn_vad

from lib.utils import Timer
from lib.asr_models.base_model import AbstractASRModel, ModelName


# Default directory holding the three downloaded FunASR model folders.
# NOTE(review): machine-specific absolute path — consider moving to an env var or config.
MODEL_DIR = "/Users/jeqin/work/code/Translator/python_server/moyoyo_asr_models"

class FunasrQuant(AbstractASRModel):
    """Quantized FunASR pipeline: FSMN VAD + SeacoParaformer ASR + CT-Transformer punctuation.

    Wraps the ONNX (int8-quantized) FunASR models behind the project's
    AbstractASRModel interface.
    """

    def __init__(self, device='mps'):
        super().__init__(device=device)
        self.name = ModelName.FUNASR_QUANT

    def load(self, model_dir=MODEL_DIR, language=""):
        """Load the quantized VAD, ASR, and punctuation models.

        Args:
            model_dir: Directory containing the three FunASR model folders.
            language: Unused here; kept for interface compatibility with
                other ASR model loaders.
        """
        # Always load the quantized ONNX weights for this model variant.
        quantize = True
        with Timer("Loading Fun-ASR-Quant model"):
            model_dir = Path(model_dir)
            asr_model_path = model_dir / 'speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'
            vad_model_path = model_dir / 'speech_fsmn_vad_zh-cn-16k-common-pytorch'
            punc_model_path = model_dir / 'punc_ct-transformer_cn-en-common-vocab471067-large'
            # NOTE(review): Path objects are passed directly to the funasr_onnx
            # constructors — confirm they accept os.PathLike (transcribe() below
            # defensively wraps its path in str()).
            self.vad_model = Fsmn_vad(vad_model_path, quantize=quantize)
            self.asr_model = SeacoParaformer(asr_model_path, quantize=quantize)
            self.punc_model = CT_Transformer(punc_model_path, quantize=quantize)

    def transcribe(self, wav, language="zh"):
        """Transcribe an audio file and restore punctuation.

        Args:
            wav: Path to the input wav file.
            language: Language hint forwarded to the ASR model.

        Returns:
            Tuple of (text, duration): the punctuated transcript ("" when the
            ASR model recognized nothing) and elapsed seconds from Timer.
        """
        with Timer("Transcribing audio") as t:
            asr_res = self.asr_model(str(wav), hotwords="", language=language)
            text = ""
            # Guard against an empty recognition result instead of indexing blindly.
            if asr_res:
                punctuated = self.punc_model(asr_res[0]["preds"])
                text = punctuated[0]
        return text, t.duration
    
if __name__ == "__main__":
    # Smoke test: load the quantized pipeline and transcribe one sample file.
    asr = FunasrQuant(device='mps')
    asr.load()
    transcript, elapsed = asr.transcribe(
        '../../test_data/recordings/1.wav', language="en"
    )
    print("inference time: ", elapsed)
    print(transcript)