# fish-speech / model.py
# Author: playmak3r
# feat: implement custom wrappers for the UI inference function
# Commit: a2888ef
import logging, sys
from functools import partial
from groq import Groq
from tools.webui.inference import inference_wrapper
from typing import Callable
# Route log records to stdout instead of the default stderr stream so they
# appear in hosted-UI consoles (presumably a Hugging Face Space — TODO confirm).
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
stream=sys.stdout
)
# Default reference voice sample (remote URL) and its transcript. run_tts
# compares against these to detect when the UI's pre-filled transcript is
# stale relative to a user-supplied reference audio.
DEFAULT_REF_PATH = "https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac"
DEFAULT_REF_TEXT = "That place in the distance, it's huge and dedicated to Lady Shah. It can only mean one thing. I have a hidden place close to the cloister where night orchids bloom."
def transcribe(file_path: str) -> str:
    """Transcribe a reference audio file with Groq's Whisper endpoint.

    Args:
        file_path: Path to the local audio file to transcribe.

    Returns:
        The transcription text; may be an empty string if the service
        produced no text (a warning is logged in that case).
    """
    client = Groq()
    with open(file_path, "rb") as file:
        transcription = client.audio.transcriptions.create(
            file=(file_path, file.read()),
            model="whisper-large-v3-turbo",
            temperature=0,
            response_format="verbose_json",
        )
    # logging.warn is a deprecated alias of logging.warning; use the real name.
    if not transcription.text:
        logging.warning("Error while transcribing the reference audio.")
    return transcription.text
def run_tts(
    text,
    reference_id,
    reference_audio,
    reference_text,
    max_new_tokens,
    chunk_length,
    top_p,
    repetition_penalty,
    temperature,
    seed,
    use_memory_cache,
    engine,
):
    """Resolve the reference transcript, then delegate to inference_wrapper.

    The UI pre-fills DEFAULT_REF_TEXT for the bundled default reference
    audio. If the user swapped in their own audio but left that stale
    default transcript in place, it is discarded and the audio is
    re-transcribed via Groq/Whisper before synthesis.

    All parameters are forwarded unchanged to inference_wrapper.
    """
    if reference_text is None:
        reference_text = ""
    # Drop the stale default transcript when the audio is no longer the
    # bundled sample. Guard against reference_audio being None/empty, which
    # previously raised a TypeError on the substring test.
    if (
        reference_audio
        and "female_shadowheart4.flac" not in reference_audio
        and reference_text == DEFAULT_REF_TEXT
    ):
        reference_text = ""
    if not reference_text:
        reference_text = transcribe(reference_audio)
    return inference_wrapper(
        text,
        reference_id,
        reference_audio,
        reference_text,
        max_new_tokens,
        chunk_length,
        top_p,
        repetition_penalty,
        temperature,
        seed,
        use_memory_cache,
        engine,
    )
def custom_inference_wrapper(engine) -> Callable:
    """Bind the TTS engine to run_tts and return the resulting callable.

    The returned function accepts every run_tts argument except ``engine``,
    which is fixed up front for the UI layer.
    """
    return partial(run_tts, engine=engine)