File size: 3,329 Bytes
cfe276e
 
e7fd9f7
cbf8a35
 
 
 
 
626684c
eb10866
 
cbf8a35
 
5e4143f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cbf8a35
 
 
 
 
 
3c6df0d
cbf8a35
 
 
3c6df0d
 
 
cbf8a35
 
 
 
 
 
 
 
5e4143f
cbf8a35
 
 
 
 
5e4143f
8a5d43b
 
 
cbf8a35
 
 
5e4143f
cbf8a35
 
 
 
 
8a5d43b
 
f122944
805fc7d
cbf8a35
 
1a6a41a
 
 
cbf8a35
c74f8e0
cbf8a35
5e4143f
 
 
 
 
 
 
cbf8a35
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os

os.environ["COQUI_TOS_AGREED"] = "1"
from TTS.api import TTS
from TTS.utils.manage import ModelManager
from TTS.utils.generic_utils import get_user_data_dir
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
import torch
import time
import torchaudio
import io
import base64
import requests
import tempfile


def convert_audio_urls_to_paths(audio_urls):
    """Download every URL in *audio_urls* to a temporary file.

    The name handed to ``download_tempfile`` is the last ``/``-separated
    segment of each URL.

    Args:
        audio_urls: Iterable of URL strings.

    Returns:
        A ``(audio_paths, temp_files)`` tuple of two lists in input order:
        the temporary file paths and the matching temporary file objects.
        A download that fails contributes ``None`` to both lists, because
        that is what ``download_tempfile`` returns on error.
    """
    downloads = [
        download_tempfile(file_url=link, filename=link.split("/")[-1])
        for link in audio_urls
    ]

    # Split the (path, file_object) pairs into two parallel lists.
    audio_paths = [destination for destination, _ in downloads]
    temp_files = [handle for _, handle in downloads]
    return audio_paths, temp_files


def download_tempfile(file_url, filename, timeout=30):
    """Download *file_url* into a named temporary file that persists on disk.

    Args:
        file_url: URL to fetch with an HTTP GET.
        filename: Name used only to pick the temp file's suffix: the text
            after the last ``.`` in *filename* becomes the extension.
        timeout: Seconds passed to ``requests.get`` so a stalled server
            cannot block the caller forever.

    Returns:
        ``(path, file_object)`` on success. The file object is already
        closed, so the full payload is on disk when *path* is opened by
        another reader; the file is created with ``delete=False`` and the
        caller is responsible for removing it. On any failure the error is
        printed and ``(None, None)`` is returned.
    """
    try:
        response = requests.get(file_url, timeout=timeout)
        response.raise_for_status()
        filetype = filename.split(".")[-1]
        # Writing inside a ``with`` block flushes and closes the handle
        # before the path is handed out; otherwise buffered bytes may not
        # be on disk yet when the file is re-opened by name.
        with tempfile.NamedTemporaryFile(
            delete=False, suffix=f".{filetype}"
        ) as temp_file:
            temp_file.write(response.content)
        return temp_file.name, temp_file
    except Exception as e:
        # Best-effort by design: report the problem and signal failure
        # through the return value instead of raising.
        print(f"Error downloading file: {e}")
        return None, None


class EndpointHandler:
    """Callable request handler that synthesizes speech with an XTTS model.

    The model is loaded once in ``__init__``; every call downloads the
    reference audio, computes conditioning latents, runs inference and
    returns the generated audio as a base64-encoded WAV.
    """

    def __init__(self, path=""):
        """Load the XTTS checkpoint and move it to the selected device.

        Args:
            path: Accepted but not used by this method; all model files
                are read from fixed locations under ``/repository/model``.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        config = XttsConfig()
        config.load_json("/repository/model/config.json")
        model = Xtts.init_from_config(config)
        model.load_checkpoint(
            config,
            checkpoint_path="/repository/model/model.pth",
            vocab_path="/repository/model/vocab.json",
            speaker_file_path="/repository/model/speakers_xtts.pth",
            eval=True,
            # DeepSpeed is requested only when running on CUDA.
            use_deepspeed=device == "cuda",
        )
        model.to(device)

        self.model = model

    @staticmethod
    def _remove_temp_files(temp_files):
        """Close and delete each temporary file, skipping ``None`` entries."""
        for temp_file in temp_files:
            if temp_file is None:
                # A failed download leaves a None placeholder in the list.
                continue
            try:
                temp_file.close()
                # os.remove needs a path, not the file object itself.
                os.remove(temp_file.name)
            except Exception as e:
                # Cleanup stays best-effort so it never masks the result
                # (or the original error) of the request being handled.
                print(f"Error removing temp files: {e}")

    def __call__(self, model_input):
        """Generate speech for one request.

        Args:
            model_input: Mapping that must provide ``"audio_urls"``,
                ``"gpt_cond_len"``, ``"gpt_cond_chunk_len"``,
                ``"max_ref_length"``, ``"text"``, ``"temperature"``,
                ``"repetition_penalty"`` and ``"language"``.

        Returns:
            ``{"data": <base64-encoded WAV bytes as str>, "format": "wav"}``.
        """
        audio_paths, temp_files = convert_audio_urls_to_paths(model_input["audio_urls"])

        try:
            (
                gpt_cond_latent,
                speaker_embedding,
            ) = self.model.get_conditioning_latents(
                audio_path=audio_paths,
                gpt_cond_len=int(model_input["gpt_cond_len"]),
                gpt_cond_chunk_len=int(model_input["gpt_cond_chunk_len"]),
                max_ref_length=int(model_input["max_ref_length"]),
            )

            print("Generating audio")

            t0 = time.time()
            out = self.model.inference(
                text=model_input["text"],
                speaker_embedding=speaker_embedding,
                gpt_cond_latent=gpt_cond_latent,
                temperature=float(model_input["temperature"]),
                repetition_penalty=float(model_input["repetition_penalty"]),
                # Only the first element of "language" is forwarded.
                # NOTE(review): presumably a list of language codes; a
                # plain string would be reduced to its first character.
                language=model_input["language"][0],
                enable_text_splitting=True,
            )
            audio_file = io.BytesIO()
            # The waveform gets a leading dimension added and is written
            # as WAV with a sample rate of 24000.
            torchaudio.save(
                audio_file, torch.tensor(out["wav"]).unsqueeze(0), 24000, format="wav"
            )
            inference_time = time.time() - t0
            print(f"I: Time to generate audio: {inference_time} seconds")
            audio_str = base64.b64encode(audio_file.getvalue()).decode("utf-8")

            return {"data": audio_str, "format": "wav"}
        finally:
            # Runs on success and on failure, so downloaded reference
            # audio never accumulates on disk.
            self._remove_temp_files(temp_files)