File size: 3,329 Bytes
cfe276e e7fd9f7 cbf8a35 626684c eb10866 cbf8a35 5e4143f cbf8a35 3c6df0d cbf8a35 3c6df0d cbf8a35 5e4143f cbf8a35 5e4143f 8a5d43b cbf8a35 5e4143f cbf8a35 8a5d43b f122944 805fc7d cbf8a35 1a6a41a cbf8a35 c74f8e0 cbf8a35 5e4143f cbf8a35 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import os
os.environ["COQUI_TOS_AGREED"] = "1"
from TTS.api import TTS
from TTS.utils.manage import ModelManager
from TTS.utils.generic_utils import get_user_data_dir
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
import torch
import time
import torchaudio
import io
import base64
import requests
import tempfile
def convert_audio_urls_to_paths(audio_urls):
    """Download each reference-audio URL into a local temporary file.

    Parameters
    ----------
    audio_urls : iterable of str
        URLs of reference audio clips; the final path segment of each URL
        supplies the temp-file name (and therefore its extension).

    Returns
    -------
    tuple[list[str], list]
        (local file paths, NamedTemporaryFile objects kept for cleanup).
        URLs that fail to download are skipped with a warning instead of
        contributing a ``None`` entry that would crash downstream callers.
    """
    audio_paths = []
    temp_files = []
    for url in audio_urls:
        filename = url.split("/")[-1]
        file_destination_path, file_object = download_tempfile(
            file_url=url, filename=filename
        )
        if file_destination_path is None:
            # download_tempfile already printed the underlying error.
            print(f"W: Skipping reference audio that failed to download: {url}")
            continue
        audio_paths.append(file_destination_path)
        temp_files.append(file_object)
    return audio_paths, temp_files
def download_tempfile(file_url, filename):
    """Fetch ``file_url`` and persist the response body to a named temp file.

    Parameters
    ----------
    file_url : str
        URL to download.
    filename : str
        Only the extension (text after the last ``.``) is used, so the temp
        file carries a suffix that audio loaders can recognize.

    Returns
    -------
    tuple
        ``(path, NamedTemporaryFile)`` on success, ``(None, None)`` on any
        failure (best-effort contract relied on by callers); errors are
        printed, not raised.
    """
    try:
        # Bounded timeout so a stalled server cannot hang the handler forever.
        response = requests.get(file_url, timeout=60)
        response.raise_for_status()
        filetype = filename.split(".")[-1]
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=f".{filetype}")
        try:
            temp_file.write(response.content)
        finally:
            # Close to flush Python's buffers: consumers reopen this file by
            # path, so unflushed data would be read back truncated.
            temp_file.close()
        return temp_file.name, temp_file
    except Exception as e:
        # Broad catch is deliberate: downloads are best-effort and failures
        # are signalled to callers via the (None, None) sentinel.
        print(f"Error downloading file: {e}")
        return None, None
class EndpointHandler:
    """Inference Endpoints handler wrapping a Coqui XTTS voice-cloning model."""

    def __init__(self, path=""):
        """Load the XTTS checkpoint from /repository/model.

        Runs on GPU when CUDA is available (DeepSpeed inference is enabled
        only in that case), otherwise on CPU.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        config = XttsConfig()
        config.load_json("/repository/model/config.json")
        model = Xtts.init_from_config(config)
        model.load_checkpoint(
            config,
            checkpoint_path="/repository/model/model.pth",
            vocab_path="/repository/model/vocab.json",
            speaker_file_path="/repository/model/speakers_xtts.pth",
            eval=True,
            use_deepspeed=device == "cuda",
        )
        model.to(device)
        self.model = model

    def __call__(self, model_input):
        """Synthesize speech in the voice of the referenced audio clips.

        Keys read from ``model_input``: "audio_urls", "gpt_cond_len",
        "gpt_cond_chunk_len", "max_ref_length", "text", "temperature",
        "repetition_penalty", and "language" (a sequence; only the first
        element is used).

        Returns ``{"data": <base64-encoded wav>, "format": "wav"}``.
        """
        audio_paths, temp_files = convert_audio_urls_to_paths(model_input["audio_urls"])
        try:
            (
                gpt_cond_latent,
                speaker_embedding,
            ) = self.model.get_conditioning_latents(
                audio_path=audio_paths,
                gpt_cond_len=int(model_input["gpt_cond_len"]),
                gpt_cond_chunk_len=int(model_input["gpt_cond_chunk_len"]),
                max_ref_length=int(model_input["max_ref_length"]),
            )
            print("Generating audio")
            t0 = time.time()
            out = self.model.inference(
                text=model_input["text"],
                speaker_embedding=speaker_embedding,
                gpt_cond_latent=gpt_cond_latent,
                temperature=float(model_input["temperature"]),
                repetition_penalty=float(model_input["repetition_penalty"]),
                language=model_input["language"][0],
                enable_text_splitting=True,
            )
            audio_file = io.BytesIO()
            # unsqueeze(0) adds the channel dimension torchaudio.save expects;
            # 24000 is the XTTS output sample rate.
            torchaudio.save(
                audio_file, torch.tensor(out["wav"]).unsqueeze(0), 24000, format="wav"
            )
            inference_time = time.time() - t0
            print(f"I: Time to generate audio: {inference_time} seconds")
            audio_str = base64.b64encode(audio_file.getvalue()).decode("utf-8")
            return {"data": audio_str, "format": "wav"}
        finally:
            # Clean up the downloaded reference clips even when inference
            # raises. Remove by path — the list holds file OBJECTS, and
            # os.remove() on an object raises TypeError (which previously
            # leaked every temp file). Per-file try so one failure does not
            # abandon the rest.
            for temp_file in temp_files:
                if temp_file is None:
                    continue  # placeholder left by a failed download
                try:
                    os.remove(temp_file.name)
                except OSError as e:
                    print(f"Error removing temp files: {e}")
|