File size: 3,138 Bytes
78c53cb
 
 
837e0e4
e17efa5
78c53cb
b591407
 
 
 
78c53cb
b591407
ff81a96
837e0e4
 
b591407
 
 
 
 
78c53cb
b591407
78c53cb
b591407
 
 
 
 
78c53cb
b591407
 
 
 
 
78c53cb
b591407
 
 
 
 
 
78c53cb
b591407
 
 
 
 
78c53cb
b591407
ff81a96
b591407
 
 
 
 
 
 
 
 
 
 
 
 
ff81a96
 
 
b591407
 
 
 
 
5b9a927
b591407
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78c53cb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import subprocess
import sys

# Install Python dependencies at import time. Inference Endpoints imports this
# handler module to build the worker, so a runtime pip install is the hook used
# here; "-q" keeps pip quiet in the container logs.
# NOTE(review): assumes the container has network access at startup — confirm.
subprocess.check_call([sys.executable, "-m", "pip", "install", "outetts==0.2.3", "uroman", "-q"])

# These imports must come AFTER the pip install above, since some of them are
# (transitively) provided by the packages installed there.
import torch
import torchaudio
import io
import base64
from transformers import AutoModelForCausalLM
from huggingface_hub import hf_hub_download

# Import local audiotokenizer (in same directory as handler.py)
from audiotokenizer import AudioTokenizerV2


class EndpointHandler:
    """Hugging Face Inference Endpoints handler for YarnGPT text-to-speech.

    Loads the WavTokenizer neural codec and the YarnGPT causal LM once at
    startup, then serves ``__call__`` requests that turn input text into
    base64-encoded 24 kHz WAV audio.
    """

    def __init__(self, path=""):
        """Download codec assets and load the model.

        Args:
            path: Local path / repo id of the YarnGPT model; forwarded to both
                ``AudioTokenizerV2`` and ``AutoModelForCausalLM``.
        """
        # Prefer GPU when available; bfloat16 is only used on CUDA below.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # NOTE(review): the config comes from the *medium* repo while the
        # checkpoint comes from the *large* repo — this mirrors the original
        # code, but confirm the pairing is intentional.
        print("Downloading WavTokenizer config...")
        wav_config_path = hf_hub_download(
            repo_id="novateur/WavTokenizer-medium-speech-75token",
            filename="wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
        )

        print("Downloading WavTokenizer model...")
        wav_model_path = hf_hub_download(
            repo_id="novateur/WavTokenizer-large-speech-75token",
            filename="wavtokenizer_large_speech_320_v2.ckpt"
        )

        print("Initializing AudioTokenizer...")
        self.audio_tokenizer = AudioTokenizerV2(
            path,
            wav_model_path,
            wav_config_path
        )

        print("Loading YarnGPT model...")
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            # bfloat16 halves memory on GPU; CPU inference stays in float32.
            torch_dtype=torch.bfloat16 if self.device == "cuda" else torch.float32
        ).to(self.device)
        self.model.eval()
        print("Model loaded successfully!")

        # Allowed speaker names; unknown voices fall back to "tayo" in __call__.
        self.voices = ["tayo", "idera", "jude", "osagie", "umar", "emma", "zainab", "remi", "regina", "chinenye", "joke"]

    def __call__(self, data):
        """Synthesize speech for one request.

        Args:
            data: Request payload dict. Reads ``inputs`` (falling back to
                ``text``) for the text, plus optional ``voice`` and
                ``language`` fields.

        Returns:
            On success: ``{"audio": <base64 WAV>, "sampling_rate": 24000,
            "format": "wav", "voice": <voice used>}``.
            On failure: ``{"error": ...}`` (with a ``traceback`` key for
            unexpected exceptions) — errors are returned, never raised.
        """
        text = data.get("inputs", data.get("text", ""))
        # FIX: dict.get only falls back on a *missing* key, so a request with
        # an explicit null/non-string voice used to crash on .lower() outside
        # the try/except. Coerce defensively so we always fall back gracefully.
        voice = str(data.get("voice") or "tayo").lower()
        language = data.get("language") or "english"  # None-safe, same default

        if not text:
            return {"error": "No input text provided"}

        if voice not in self.voices:
            voice = "tayo"

        try:
            prompt = self.audio_tokenizer.create_prompt(text, lang=language, speaker_name=voice)
            input_ids = self.audio_tokenizer.tokenize_prompt(prompt).to(self.device)

            # Low temperature keeps the audio-token stream stable; sampling is
            # still enabled so the repetition penalty can take effect.
            with torch.no_grad():
                output = self.model.generate(
                    input_ids=input_ids,
                    temperature=0.1,
                    repetition_penalty=1.1,
                    max_length=8000,
                    do_sample=True
                )

            # Decode generated codec tokens back to a waveform tensor.
            codes = self.audio_tokenizer.get_codes(output)
            audio = self.audio_tokenizer.get_audio(codes)

            # Serialize to WAV in memory and ship as base64 JSON-safe text.
            # 24000 Hz matches the WavTokenizer codec's output rate.
            buffer = io.BytesIO()
            torchaudio.save(buffer, audio.cpu(), sample_rate=24000, format="wav")
            audio_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

            return {
                "audio": audio_base64,
                "sampling_rate": 24000,
                "format": "wav",
                "voice": voice
            }

        except Exception as e:
            # Surface the failure to the caller instead of a bare 500.
            import traceback
            return {"error": str(e), "traceback": traceback.format_exc()}