|
|
from pywhispercpp.model import Model |
|
|
import soundfile |
|
|
import numpy as np |
|
|
from logging import getLogger |
|
|
from pathlib import Path |
|
|
|
|
|
from lib.utils import Timer, read_audio |
|
|
|
|
|
logger = getLogger(__name__) |
|
|
|
|
|
# Directory containing the ggml whisper model files.
# NOTE(review): hard-coded absolute path to a developer machine — should be
# made configurable (env var / CLI arg) before deployment.
MODEL_DIR = Path("/Users/jeqin/work/code/Translator/python_server/moyoyo_asr_models")


# Initial prompt used to bias decoding toward Simplified Chinese Mandarin.
# (Translation: "The following are sentences in Simplified Chinese Mandarin.")
WHISPER_PROMPT_ZH = "以下是简体中文普通话的句子。"


# English needs no biasing prompt.
WHISPER_PROMPT_EN = ""
|
|
|
|
|
class WhisperCPP:
    """Thin wrapper around pywhispercpp's whisper.cpp model for local ASR.

    Loads a quantized large-v3-turbo model once at construction time, runs a
    warm-up inference so the first real request doesn't pay initialization
    cost, and exposes a simple ``transcribe(audio, language)`` API.
    """

    def __init__(self, model_dir=MODEL_DIR, source_lange: str = 'en') -> None:
        """Load the whisper model from *model_dir* and warm it up.

        Args:
            model_dir: Directory holding the ggml whisper model files.
            source_lange: Unused; kept only for backward compatibility with
                existing callers (language is passed per-call to
                :meth:`transcribe` instead).
        """
        whisper_model = 'large-v3-turbo-q5_0'
        with Timer("load whisper"):
            self.model = Model(
                model=whisper_model,
                models_dir=str(model_dir),
                print_realtime=False,
                print_progress=False,
                print_timestamps=False,
                translate=False,
                temperature=0.,
                no_context=True
            )
        self._warmup()

    def _warmup(self):
        """Run one dummy inference (1 s of random noise at 16 kHz) to prime the model."""
        fake_audio = np.random.randn(16000).astype(np.float32)
        self.model.transcribe(fake_audio, print_progress=False)

    @staticmethod
    def config_language(language):
        """Return the initial whisper prompt for *language*.

        Args:
            language: Language code; only ``"zh"`` and ``"en"`` are supported.

        Raises:
            ValueError: If *language* is not supported.
        """
        if language == "zh":
            return WHISPER_PROMPT_ZH
        elif language == "en":
            return WHISPER_PROMPT_EN
        raise ValueError(f"Unsupported language : {language}")

    def transcribe(self, audio: np.ndarray, language):
        """Transcribe *audio* and return the text plus the inference time.

        Args:
            audio: Audio samples as a numpy array (presumably float32 mono
                at 16 kHz, matching the warm-up input — TODO confirm).
            language: Language code passed to :meth:`config_language`;
                an unsupported code raises ValueError (not caught here).

        Returns:
            Tuple of ``(text, seconds)``. On inference failure, returns
            ``("", 0.0)`` so callers that unpack two values still work.
        """
        prompt = self.config_language(language)
        try:
            with Timer("whisper inference") as t:
                segments = self.model.transcribe(
                    audio,
                    initial_prompt=prompt,
                    language=language,
                    split_on_word=True,
                )
            text = "".join(s.text for s in segments)
            return text, t.duration
        except Exception:
            # Bug fix: the original returned [] here, which broke callers
            # that unpack "text, cost = ..." (e.g. the __main__ block).
            # logger.exception also preserves the traceback, unlike the
            # original logger.error(e).
            logger.exception("whisper inference failed")
            return "", 0.0
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test: load the model and transcribe one known recording.
    # (read_audio is already imported at module level; the original
    # re-imported it redundantly here.)
    whisper = WhisperCPP()
    audio = read_audio(Path("/Users/jeqin/work/code/TestTranslator/test_data/recordings/1.wav"))
    text, time_cost = whisper.transcribe(audio, "zh")
    print(text)
    print(time_cost)