import subprocess
import sys
# Install runtime Python dependencies at import time.
# NOTE(review): this runs on every cold start of the endpoint, needs network
# access, and raises subprocess.CalledProcessError (aborting the import) if
# pip fails — presumably intentional fail-fast behavior; confirm before changing.
subprocess.check_call([sys.executable, "-m", "pip", "install", "outetts==0.2.3", "uroman", "-q"])
import torch
import torchaudio
import io
import base64
from transformers import AutoModelForCausalLM
from huggingface_hub import hf_hub_download
# Import local audiotokenizer (in same directory as handler.py)
from audiotokenizer import AudioTokenizerV2
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for YarnGPT text-to-speech.

    Loads the WavTokenizer codec (config + checkpoint from the Hub) and the
    YarnGPT causal LM, then serves ``__call__`` requests that turn text into
    base64-encoded WAV audio.
    """

    def __init__(self, path=""):
        """Download codec assets and load the model.

        Args:
            path: Local path (or Hub id) of the YarnGPT model repository;
                  also passed to ``AudioTokenizerV2``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        print("Downloading WavTokenizer config...")
        wav_config_path = hf_hub_download(
            repo_id="novateur/WavTokenizer-medium-speech-75token",
            filename="wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
        )

        print("Downloading WavTokenizer model...")
        wav_model_path = hf_hub_download(
            repo_id="novateur/WavTokenizer-large-speech-75token",
            filename="wavtokenizer_large_speech_320_v2.ckpt"
        )

        print("Initializing AudioTokenizer...")
        self.audio_tokenizer = AudioTokenizerV2(
            path,
            wav_model_path,
            wav_config_path
        )

        print("Loading YarnGPT model...")
        # bfloat16 only on GPU; CPU inference stays in float32 for stability.
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16 if self.device == "cuda" else torch.float32
        ).to(self.device)
        self.model.eval()
        print("Model loaded successfully!")

        # Supported speaker names; unknown voices fall back to "tayo".
        self.voices = ["tayo", "idera", "jude", "osagie", "umar", "emma", "zainab", "remi", "regina", "chinenye", "joke"]

    def __call__(self, data):
        """Handle one inference request.

        Args:
            data: Request payload dict. Text is read from ``inputs`` (falling
                back to ``text``); optional ``voice`` and ``language`` keys
                select the speaker and language.

        Returns:
            On success: dict with base64 WAV ``audio``, ``sampling_rate``,
            ``format`` and the resolved ``voice``. On failure: dict with
            ``error`` (and ``traceback`` for generation-time exceptions).
        """
        # Be defensive about the payload: a missing/None body or a non-string
        # "voice" value (e.g. explicit null in the JSON) must yield an error
        # dict, not an unhandled exception — validation runs before try/except.
        data = data or {}
        text = data.get("inputs", data.get("text", ""))
        if not text:
            return {"error": "No input text provided"}

        # "voice": null would crash .lower(); coerce via `or` + str() first.
        voice = str(data.get("voice") or "tayo").lower()
        if voice not in self.voices:
            voice = "tayo"
        language = data.get("language") or "english"

        try:
            prompt = self.audio_tokenizer.create_prompt(text, lang=language, speaker_name=voice)
            input_ids = self.audio_tokenizer.tokenize_prompt(prompt).to(self.device)

            with torch.no_grad():
                # NOTE(review): max_length counts prompt tokens too — if very
                # long prompts get truncated output, consider max_new_tokens.
                output = self.model.generate(
                    input_ids=input_ids,
                    temperature=0.1,
                    repetition_penalty=1.1,
                    max_length=8000,
                    do_sample=True
                )

            codes = self.audio_tokenizer.get_codes(output)
            audio = self.audio_tokenizer.get_audio(codes)

            # Serialize to WAV in memory and return it base64-encoded.
            buffer = io.BytesIO()
            torchaudio.save(buffer, audio.cpu(), sample_rate=24000, format="wav")
            audio_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

            return {
                "audio": audio_base64,
                "sampling_rate": 24000,
                "format": "wav",
                "voice": voice
            }
        except Exception as e:
            # Surface the full traceback to the caller for debuggability.
            import traceback
            return {"error": str(e), "traceback": traceback.format_exc()}