"""Hugging Face Inference Endpoint handler for the YarnGPT text-to-speech model.

Converts input text to speech: a causal LM generates WavTokenizer audio codes,
which are decoded to a waveform and returned as a base64-encoded WAV payload.
"""

import base64
import importlib.util
import io
import subprocess
import sys

# Runtime dependencies that are not baked into the endpoint image.
# Guarded so a warm container (or a dev box that already has them) skips the
# slow, network-bound pip call instead of re-running it on every import.
# NOTE(review): the guard only checks presence, not the ==0.2.3 pin — if a
# different outetts version is pre-installed it will be used as-is.
if importlib.util.find_spec("outetts") is None or importlib.util.find_spec("uroman") is None:
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "outetts==0.2.3", "uroman", "-q"]
    )

import torch
import torchaudio
from transformers import AutoModelForCausalLM
from huggingface_hub import hf_hub_download

# Local module shipped alongside handler.py in the endpoint repository.
from audiotokenizer import AudioTokenizerV2


class EndpointHandler:
    """Endpoint entry point: loads the model once at startup, then serves requests."""

    # Speaker presets accepted by the audio tokenizer's prompt builder.
    VOICES = (
        "tayo", "idera", "jude", "osagie", "umar", "emma",
        "zainab", "remi", "regina", "chinenye", "joke",
    )
    # Output rate of the WavTokenizer decoder (Hz).
    SAMPLE_RATE = 24000

    def __init__(self, path=""):
        """Download WavTokenizer assets and load the YarnGPT model.

        Args:
            path: Local path (or repo id) of the YarnGPT checkpoint; also
                handed to AudioTokenizerV2 as its tokenizer source.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # NOTE(review): the YAML config comes from the *medium* repo while the
        # checkpoint comes from the *large* repo — presumably intentional
        # (both are 75-token variants), but worth confirming upstream.
        print("Downloading WavTokenizer config...")
        wav_config_path = hf_hub_download(
            repo_id="novateur/WavTokenizer-medium-speech-75token",
            filename="wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml",
        )
        print("Downloading WavTokenizer model...")
        wav_model_path = hf_hub_download(
            repo_id="novateur/WavTokenizer-large-speech-75token",
            filename="wavtokenizer_large_speech_320_v2.ckpt",
        )

        print("Initializing AudioTokenizer...")
        self.audio_tokenizer = AudioTokenizerV2(path, wav_model_path, wav_config_path)

        print("Loading YarnGPT model...")
        # bf16 halves memory on GPU; fall back to fp32 on CPU, where bf16
        # support is inconsistent.
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16 if self.device == "cuda" else torch.float32,
        ).to(self.device)
        self.model.eval()
        print("Model loaded successfully!")

        # Kept as an instance attribute for backward compatibility with any
        # caller that inspects handler.voices.
        self.voices = list(self.VOICES)

    def __call__(self, data):
        """Handle one inference request.

        Args:
            data: Request payload dict. Reads "inputs" (or legacy "text") for
                the text to speak, optional "voice" (default "tayo") and
                "language" (default "english").

        Returns:
            On success: dict with base64 "audio" (WAV bytes), "sampling_rate",
            "format", and the resolved "voice". On failure: dict with "error"
            (plus "traceback" for unexpected exceptions).
        """
        text = data.get("inputs", data.get("text", ""))
        voice = data.get("voice", "tayo").lower()
        language = data.get("language", "english")

        if not text:
            return {"error": "No input text provided"}
        if voice not in self.voices:
            voice = "tayo"  # unknown speaker: silently fall back to default

        try:
            prompt = self.audio_tokenizer.create_prompt(
                text, lang=language, speaker_name=voice
            )
            input_ids = self.audio_tokenizer.tokenize_prompt(prompt).to(self.device)

            with torch.no_grad():
                output = self.model.generate(
                    input_ids=input_ids,
                    temperature=0.1,
                    repetition_penalty=1.1,
                    max_length=8000,
                    do_sample=True,
                )

            codes = self.audio_tokenizer.get_codes(output)
            audio = self.audio_tokenizer.get_audio(codes)

            # Serialize the waveform to an in-memory WAV and base64-encode it
            # so it survives the JSON response.
            buffer = io.BytesIO()
            torchaudio.save(
                buffer, audio.cpu(), sample_rate=self.SAMPLE_RATE, format="wav"
            )
            audio_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

            return {
                "audio": audio_base64,
                "sampling_rate": self.SAMPLE_RATE,
                "format": "wav",
                "voice": voice,
            }
        except Exception as e:
            # Top-level request boundary: report the failure to the caller
            # instead of letting the endpoint 500 with no detail.
            import traceback
            return {"error": str(e), "traceback": traceback.format_exc()}