# yarngpt-custom / handler.py
# Uploaded by laztopaz with huggingface_hub (commit 5b9a927, verified).
# (Hugging Face web-UI header converted to comments so the module parses.)
import subprocess
import sys
# Runtime dependency bootstrap: the endpoint image does not ship outetts or
# uroman, so install them at import time, BEFORE the imports below execute.
# NOTE(review): this runs on every cold start; outetts is pinned (0.2.3),
# uroman is not — consider pinning it for reproducible deployments.
subprocess.check_call([sys.executable, "-m", "pip", "install", "outetts==0.2.3", "uroman", "-q"])
import torch
import torchaudio
import io
import base64
from transformers import AutoModelForCausalLM
from huggingface_hub import hf_hub_download
# Import local audiotokenizer (in same directory as handler.py)
from audiotokenizer import AudioTokenizerV2
class EndpointHandler:
    """Custom Hugging Face Inference Endpoint handler for YarnGPT TTS.

    At startup it downloads the WavTokenizer codec config and checkpoint,
    builds the local ``AudioTokenizerV2``, and loads the YarnGPT causal
    language model from ``path``. Each request converts text to speech and
    returns the waveform as a base64-encoded WAV payload.
    """

    def __init__(self, path=""):
        """Download codec assets and load the model.

        Args:
            path: Local directory of the YarnGPT model repository
                (supplied by the endpoint runtime).
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # The codec's YAML config and its weights live in two different
        # upstream repos, hence two separate downloads.
        print("Downloading WavTokenizer config...")
        wav_config_path = hf_hub_download(
            repo_id="novateur/WavTokenizer-medium-speech-75token",
            filename="wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
        )
        print("Downloading WavTokenizer model...")
        wav_model_path = hf_hub_download(
            repo_id="novateur/WavTokenizer-large-speech-75token",
            filename="wavtokenizer_large_speech_320_v2.ckpt"
        )

        print("Initializing AudioTokenizer...")
        self.audio_tokenizer = AudioTokenizerV2(
            path,
            wav_model_path,
            wav_config_path
        )

        print("Loading YarnGPT model...")
        # bfloat16 halves memory on GPU; fp32 is the safe choice on CPU.
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16 if self.device == "cuda" else torch.float32
        ).to(self.device)
        self.model.eval()
        print("Model loaded successfully!")

        # Supported speaker presets; any other value falls back to "tayo".
        self.voices = ["tayo", "idera", "jude", "osagie", "umar", "emma", "zainab", "remi", "regina", "chinenye", "joke"]

    def __call__(self, data):
        """Handle one inference request.

        Args:
            data: Request payload. Text is read from ``"inputs"`` (HF
                convention) or ``"text"``. ``"voice"`` and ``"language"``
                may appear at the top level or nested under
                ``"parameters"`` (HF Inference Endpoints convention).

        Returns:
            On success, a dict with base64 ``"audio"``, ``"sampling_rate"``,
            ``"format"`` and the resolved ``"voice"``; on failure, a dict
            with ``"error"`` (and ``"traceback"`` for generation errors).
        """
        # HF Inference Endpoints clients commonly nest options under
        # "parameters"; accept both layouts, top-level keys winning.
        params = data.get("parameters") or {}
        text = data.get("inputs", data.get("text", ""))
        # str(...) guards against a non-string (e.g. JSON null) voice value,
        # which previously raised AttributeError before the fallback ran.
        voice = str(data.get("voice", params.get("voice", "tayo"))).lower()
        language = data.get("language", params.get("language", "english"))

        if not text:
            return {"error": "No input text provided"}
        if voice not in self.voices:
            voice = "tayo"

        try:
            prompt = self.audio_tokenizer.create_prompt(text, lang=language, speaker_name=voice)
            input_ids = self.audio_tokenizer.tokenize_prompt(prompt).to(self.device)

            with torch.no_grad():
                # Low temperature + sampling keeps prosody stable while
                # avoiding fully greedy repetition; penalty curbs loops.
                output = self.model.generate(
                    input_ids=input_ids,
                    temperature=0.1,
                    repetition_penalty=1.1,
                    max_length=8000,
                    do_sample=True
                )

            codes = self.audio_tokenizer.get_codes(output)
            audio = self.audio_tokenizer.get_audio(codes)

            # Serialize to WAV in memory; the codec operates at 24 kHz.
            # NOTE(review): assumes get_audio returns a (channels, frames)
            # tensor as torchaudio.save requires — confirm in audiotokenizer.
            buffer = io.BytesIO()
            torchaudio.save(buffer, audio.cpu(), sample_rate=24000, format="wav")
            audio_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

            return {
                "audio": audio_base64,
                "sampling_rate": 24000,
                "format": "wav",
                "voice": voice
            }
        except Exception as e:
            # Surface the traceback in the response body so failures can be
            # debugged from the client side without access to endpoint logs.
            import traceback
            return {"error": str(e), "traceback": traceback.format_exc()}