"""HuggingFace Inference Endpoint bootstrap for the YarnGPT TTS handler.

Installs runtime-only dependencies that are not baked into the endpoint
image, then imports the model/audio stack. The pip install must run
BEFORE the local `audiotokenizer` import below, which appears to rely on
the `outetts` package installed here.
"""
import subprocess
import sys

# Install extra dependencies at startup: the endpoint image cannot be
# customized, so they are pulled in at import time. `-q` keeps logs clean;
# check_call raises (failing startup loudly) if the install fails.
subprocess.check_call(
    [sys.executable, "-m", "pip", "install", "outetts==0.2.3", "uroman", "-q"]
)

import base64
import io

import torch
import torchaudio
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM

# Local module; presumably depends on `outetts` installed above — keep this
# import after the subprocess call.
from audiotokenizer import AudioTokenizerV2
class EndpointHandler:
    """Custom Inference Endpoint handler for YarnGPT text-to-speech.

    Downloads the WavTokenizer codec assets and loads the YarnGPT causal
    LM once at startup, then serves ``__call__`` requests that return
    base64-encoded WAV audio at 24 kHz.
    """

    def __init__(self, path=""):
        """Load all models.

        Args:
            path: Repository path of the YarnGPT model; also passed to
                AudioTokenizerV2 (its tokenizer source, per the call below).
        """
        # Prefer GPU when available; the LM runs in bf16 on CUDA, fp32 on CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        print("Downloading WavTokenizer config...")
        wav_config_path = hf_hub_download(
            repo_id="novateur/WavTokenizer-medium-speech-75token",
            filename="wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml",
        )

        print("Downloading WavTokenizer model...")
        wav_model_path = hf_hub_download(
            repo_id="novateur/WavTokenizer-large-speech-75token",
            filename="wavtokenizer_large_speech_320_v2.ckpt",
        )

        print("Initializing AudioTokenizer...")
        self.audio_tokenizer = AudioTokenizerV2(path, wav_model_path, wav_config_path)

        print("Loading YarnGPT model...")
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16 if self.device == "cuda" else torch.float32,
        ).to(self.device)
        self.model.eval()
        print("Model loaded successfully!")

        # Supported speaker presets; unknown voices fall back to "tayo".
        self.voices = [
            "tayo", "idera", "jude", "osagie", "umar", "emma",
            "zainab", "remi", "regina", "chinenye", "joke",
        ]

    def __call__(self, data):
        """Handle one inference request.

        Args:
            data: Request payload dict. Text is read from "inputs"
                (preferred) or "text"; optional "voice" selects a speaker
                preset and "language" the prompt language (default
                "english").

        Returns:
            On success: {"audio": <base64 WAV>, "sampling_rate": 24000,
            "format": "wav", "voice": <voice used>}.
            On failure: {"error": ...} (plus "traceback" for unexpected
            exceptions).
        """
        text = data.get("inputs", data.get("text", ""))
        # FIX: the original `data.get("voice", "tayo").lower()` raised an
        # unhandled AttributeError (HTTP 500) when a client sent an explicit
        # `"voice": null` or a non-string; tolerate both here.
        voice = str(data.get("voice") or "tayo").lower()
        language = data.get("language", "english")

        if not text:
            return {"error": "No input text provided"}

        if voice not in self.voices:
            voice = "tayo"  # silent fallback to the default speaker

        try:
            prompt = self.audio_tokenizer.create_prompt(text, lang=language, speaker_name=voice)
            input_ids = self.audio_tokenizer.tokenize_prompt(prompt).to(self.device)

            with torch.no_grad():
                # NOTE(review): max_length bounds prompt + generation
                # combined; max_new_tokens would bound the output alone —
                # confirm which is intended before changing.
                output = self.model.generate(
                    input_ids=input_ids,
                    temperature=0.1,
                    repetition_penalty=1.1,
                    max_length=8000,
                    do_sample=True,
                )

            codes = self.audio_tokenizer.get_codes(output)
            audio = self.audio_tokenizer.get_audio(codes)

            # Serialize to WAV in memory (assumes `audio` is a (channels,
            # frames) tensor as torchaudio.save expects — TODO confirm
            # AudioTokenizerV2.get_audio's output shape).
            buffer = io.BytesIO()
            torchaudio.save(buffer, audio.cpu(), sample_rate=24000, format="wav")
            audio_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

            return {
                "audio": audio_base64,
                "sampling_rate": 24000,
                "format": "wav",
                "voice": voice,
            }
        except Exception as e:
            # Report the failure to the client instead of 500-ing the endpoint.
            import traceback
            return {"error": str(e), "traceback": traceback.format_exc()}