CopyBextts / app.py_13.11
archivartaunik's picture
Rename app.py to app.py_13.11
cdb8d30 verified
import os
import sys
import tempfile
import subprocess
import spaces
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from scipy.io.wavfile import write
import numpy as np
from tqdm import tqdm
from underthesea import sent_tokenize
# ---------------------------------------------------------
# 1. Клануем і падключаем coqui-ai-TTS (fork з падтрымкай BE)
# ---------------------------------------------------------
REPO_URL = "https://github.com/tuteishygpt/coqui-ai-TTS.git"
REPO_DIR = "coqui-ai-TTS"
if not os.path.exists(REPO_DIR):
# Клануем fork з беларускай падтрымкай
subprocess.run(
["git", "clone", REPO_URL, REPO_DIR],
check=True,
)
# Дадаём корань рэпазіторыя ў sys.path, каб "import TTS" бачыў пакет
repo_root = os.path.abspath(REPO_DIR)
if repo_root not in sys.path:
sys.path.insert(0, repo_root)
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
# ---------------------------------------------------------
# 2. Шляхі да файлаў мадэлі
# ---------------------------------------------------------
repo_id = "archivartaunik/BE_XTTS_V2_10ep250k"
model_dir = "./model"
os.makedirs(model_dir, exist_ok=True)
checkpoint_file = os.path.join(model_dir, "model.pth")
config_file = os.path.join(model_dir, "config.json")
vocab_file = os.path.join(model_dir, "vocab.json")
default_voice_file = os.path.join(model_dir, "voice.wav")
if not os.path.exists(checkpoint_file):
hf_hub_download(repo_id, filename="model.pth", local_dir=model_dir)
if not os.path.exists(config_file):
hf_hub_download(repo_id, filename="config.json", local_dir=model_dir)
if not os.path.exists(vocab_file):
hf_hub_download(repo_id, filename="vocab.json", local_dir=model_dir)
if not os.path.exists(default_voice_file):
hf_hub_download(repo_id, filename="voice.wav", local_dir=model_dir)
# ---------------------------------------------------------
# 3. Загрузка мадэлі
# ---------------------------------------------------------
config = XttsConfig()
config.load_json(config_file)
XTTS_MODEL = Xtts.init_from_config(config)
XTTS_MODEL.load_checkpoint(
config,
checkpoint_path=checkpoint_file,
vocab_path=vocab_file,
use_deepspeed=False,
)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
XTTS_MODEL.to(device)
sampling_rate = XTTS_MODEL.config.audio["sample_rate"]
# Базавыя значэнні для кантролю кандыцыянавання (бярэм з канфіга, з запасам па змоўчанні)
CFG_GPT_COND = int(getattr(XTTS_MODEL.config, "gpt_cond_len", 6))
CFG_MAX_REF = int(getattr(XTTS_MODEL.config, "max_ref_len", 20))
CFG_NORM = bool(getattr(XTTS_MODEL.config, "sound_norm_refs", True))
# ---------------------------------------------------------
# 4. Функцыя TTS з параметрамі кіравання
# ---------------------------------------------------------
@spaces.GPU(duration=60)
def text_to_speech(
belarusian_story: str,
speaker_audio_file: str | None,
language: str = "be",
gpt_cond_len: int = CFG_GPT_COND,
max_ref_len: int = CFG_MAX_REF,
sound_norm_refs: bool = CFG_NORM,
temperature: float = 0.2,
length_penalty: float = 1.0,
repetition_penalty: float = 7.0,
top_k: int = 30,
top_p: float = 0.8,
):
"""Генерацыя аўдыя з кіраваннем параметрамі.
Вяртае (sr, np.ndarray) для gr.Audio(type="numpy").
"""
if not belarusian_story or belarusian_story.strip() == "":
raise gr.Error("Увядзі хоць нейкі тэкст 🙂")
# калі аўдыё не перададзена — бярэм голас па змаўчанні
ref_path = speaker_audio_file
if not ref_path or (
not isinstance(ref_path, str)
and getattr(ref_path, "name", "") == ""
):
ref_path = default_voice_file
try:
gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
audio_path=ref_path,
gpt_cond_len=int(gpt_cond_len),
max_ref_length=int(max_ref_len),
sound_norm_refs=bool(sound_norm_refs),
)
except Exception as e:
raise gr.Error(f"Памылка пры атрыманні латэнтаў голасу: {e}")
try:
tts_texts = sent_tokenize(belarusian_story)
except Exception as e:
raise gr.Error(f"Памылка пры падзеле тэксту на сказы: {e}")
all_wavs = []
for text in tqdm(tts_texts):
try:
with torch.no_grad():
wav_chunk = XTTS_MODEL.inference(
text=text,
language=language,
gpt_cond_latent=gpt_cond_latent,
speaker_embedding=speaker_embedding,
temperature=float(temperature),
length_penalty=float(length_penalty),
repetition_penalty=float(repetition_penalty),
top_k=int(top_k),
top_p=float(top_p),
)
# wav_chunk["wav"] already np.ndarray float32 range [-1,1]
all_wavs.append(wav_chunk["wav"])
except Exception as e:
raise gr.Error(f"Памылка пры генерырацыі аўдыя: {e}")
if not all_wavs:
raise gr.Error("Нічога не згенеравалася — праверце ўваходныя даныя.")
try:
out_wav = np.concatenate(all_wavs).astype(np.float32)
except ValueError:
raise gr.Error(
"Немагчыма згенераваць аўдыё. Праверце ўваходны тэкст і аўдыёфайл."
)
except Exception as e:
raise gr.Error(f"Памылка пры аб'яднанні аўдыя: {e}")
# Для type="numpy" вяртаем (sr, waveform)
return (int(sampling_rate), out_wav)
# ---------------------------------------------------------
# 5. UI (Gradio Blocks) з параметрамі кіравання
# ---------------------------------------------------------
analytics_script = """
<script async src="https://www.googletagmanager.com/gtag/js?id=G-TKDCRCQ7FK"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag(){dataLayer.push(arguments);}
gtag('js', new Date());
gtag('config', 'G-TKDCRCQ7FK');
</script>
"""
with gr.Blocks() as demo:
gr.HTML(analytics_script)
gr.Markdown("# Belarusian TTS Demo")
gr.Markdown("Увядзіце тэкст, абярыце/загрузіце узор голасу і згенеруйце аўдыя.")
with gr.Row():
txt = gr.Textbox(lines=6, label="Тэкст на беларускай мове")
ref = gr.Audio(type="filepath", label="Прыклад голасу (≥7 с)")
with gr.Accordion("Параметры кандыцыянавання", open=False):
with gr.Row():
language = gr.Dropdown(
label="Мова (language)",
choices=[
"be","ru","uk","pl","cs","en","de","fr","es","it","pt","tr","vi","zh","ja"
],
value="be",
)
gpt_cond_len = gr.Slider(1, max(1, CFG_GPT_COND*3), step=1, value=CFG_GPT_COND, label="gpt_cond_len (сек.)")
max_ref_len = gr.Slider(1, max(1, CFG_MAX_REF*3), step=1, value=CFG_MAX_REF, label="max_ref_len (сек.)")
sound_norm_refs = gr.Checkbox(value=CFG_NORM, label="sound_norm_refs")
with gr.Accordion("Параметры генерацыі", open=True):
with gr.Row():
temperature = gr.Slider(0.0, 2.0, value=0.2, step=0.01, label="temperature")
top_k = gr.Slider(1, 100, value=30, step=1, label="top_k")
top_p = gr.Slider(0.0, 1.0, value=0.8, step=0.01, label="top_p")
with gr.Row():
length_penalty = gr.Slider(0.5, 3.5, value=1.0, step=0.05, label="length_penalty")
repetition_penalty = gr.Slider(0.5, 20.0, value=7.0, step=0.1, label="repetition_penalty")
out_audio = gr.Audio(type="numpy", label="Згенераванае аўдыя")
btn = gr.Button("🔊 Генераваць")
btn.click(
fn=text_to_speech,
inputs=[
txt,
ref,
language,
gpt_cond_len,
max_ref_len,
sound_norm_refs,
temperature,
length_penalty,
repetition_penalty,
top_k,
top_p,
],
outputs=out_audio,
)
# ---------------------------------------------------------
# 6. Запуск
# ---------------------------------------------------------
if __name__ == "__main__":
demo.launch()