"""Gradio-facing TTS entry point that auto-transcribes the reference audio when needed."""

import logging
import sys
from functools import partial
from typing import Callable

from groq import Groq

from tools.webui.inference import inference_wrapper

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    stream=sys.stdout,
)

DEFAULT_REF_PATH = "https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac"
DEFAULT_REF_TEXT = "That place in the distance, it's huge and dedicated to Lady Shah. It can only mean one thing. I have a hidden place close to the cloister where night orchids bloom."

# File name of the bundled default prompt ("female_shadowheart4.flac").
# Derived from DEFAULT_REF_PATH so the two cannot drift apart.
_DEFAULT_REF_NAME = DEFAULT_REF_PATH.rsplit("/", 1)[-1]


def transcribe(file_path: str) -> str:
    """Transcribe a local audio file with Groq's Whisper endpoint.

    Args:
        file_path: Path to a local audio file. It is opened with ``open()``,
            so a remote URL will not work here.

    Returns:
        The transcribed text. If the service returned no text, a warning is
        logged and an empty string is returned.
    """
    client = Groq()
    with open(file_path, "rb") as file:
        transcription = client.audio.transcriptions.create(
            file=(file_path, file.read()),
            model="whisper-large-v3-turbo",
            temperature=0,
            response_format="verbose_json",
        )
    text = transcription.text or ""
    if not text:
        # logging.warn was removed in Python 3.13; warning() is the supported name.
        logging.warning("Error while transcripting the reference audio.")
    return text


def run_tts(
    text,
    reference_id,
    reference_audio,
    reference_text,
    max_new_tokens,
    chunk_length,
    top_p,
    repetition_penalty,
    temperature,
    seed,
    use_memory_cache,
    engine,
):
    """Normalize the reference prompt, then delegate to ``inference_wrapper``.

    The reference text is resolved in three steps:
      1. ``None`` is treated as an empty string.
      2. If the text is still the default demo transcript but the audio is not
         the default demo clip, the stale default text is discarded.
      3. If no text remains and a reference audio path is present, the audio
         is transcribed with :func:`transcribe`.

    All arguments are forwarded unchanged to ``inference_wrapper``, except
    ``reference_text``, which is replaced by the resolved text. Returns
    whatever ``inference_wrapper`` returns.
    """
    if reference_text is None:
        reference_text = ""

    # Guard against None: the original unconditional `in` check raised
    # TypeError when no reference audio was supplied.
    is_default_audio = (
        reference_audio is not None and _DEFAULT_REF_NAME in reference_audio
    )
    if not is_default_audio and reference_text == DEFAULT_REF_TEXT:
        # The default transcript belongs only to the default clip.
        reference_text = ""

    # Only transcribe when there is actually an audio file to read.
    if not reference_text and reference_audio:
        reference_text = transcribe(reference_audio)

    return inference_wrapper(
        text,
        reference_id,
        reference_audio,
        reference_text,
        max_new_tokens,
        chunk_length,
        top_p,
        repetition_penalty,
        temperature,
        seed,
        use_memory_cache,
        engine,
    )


def custom_inference_wrapper(engine) -> Callable:
    """Return :func:`run_tts` with ``engine`` bound as a keyword argument.

    Args:
        engine: The inference engine forwarded to ``inference_wrapper``.

    Returns:
        A ``functools.partial`` of ``run_tts`` that takes the remaining
        arguments.
    """
    return partial(
        run_tts,
        engine=engine,
    )