# VOX33 / server.py
# Uploaded by ssasio ("Upload 10 files", commit 7c72478).
"""
BgTTS-38M Web Server — Gradio Interface
========================================
Voice cloning TTS with Bulgarian + English support.
"""
import sys
import os
import torch
import numpy as np
import tempfile
import time
import soundfile as sf
# Add parent dir to path for imports
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from config import (
AUDIO_OFFSET, NUM_AUDIO_TOKENS, END_OF_SPEECH_TOKEN_ID,
START_OF_SPEECH_TOKEN_ID, CODEC_SAMPLE_RATE, CODEC_FRAME_RATE,
)
from tokenizer import TTSTokenizer
from codec import CodecV6
from model import load_for_inference
from inference import generate, _split_text
# ── Global state ──────────────────────────────────────────────
# Singletons populated by load_model() and read by synthesize_speech().
MODEL, TOKENIZER, CODEC = None, None, None

# Run on the GPU whenever PyTorch can see one, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Default checkpoint sits next to this file; the CLI may override it.
CHECKPOINT_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "checkpoint_inference.pt",
)
def load_model():
    """Create the model, tokenizer and codec singletons (call once at startup).

    CHECKPOINT_PATH and DEVICE are read when this runs, so any command-line
    overrides assigned to those globals beforehand are honoured.
    """
    global MODEL, TOKENIZER, CODEC
    ckpt_path, device = CHECKPOINT_PATH, DEVICE
    print(f"Loading model from {ckpt_path} on {device}...")
    MODEL = load_for_inference(ckpt_path, device=device)
    TOKENIZER = TTSTokenizer()
    CODEC = CodecV6(device=device)
    print("Model loaded!")
def synthesize_speech(text, ref_audio, temperature, top_k, top_p, rep_penalty):
    """
    Generate speech from text using reference audio for voice cloning.

    Args:
        text: Text to synthesize. It is split with
            ``_split_text(text, TOKENIZER, max_len=250)`` and each chunk is
            generated separately, then the codes are concatenated.
        ref_audio: ``(sample_rate, samples)`` tuple as delivered by a
            ``gr.Audio(type="numpy")`` component, or ``None``.
        temperature, top_k, top_p, rep_penalty: Sampling settings forwarded
            to ``generate``; ``top_k`` is cast to ``int`` because slider
            values may arrive as floats.

    Returns:
        ``(audio, info)`` -- one value per Gradio output component.
        ``audio`` is ``(sample_rate, waveform)`` or ``None`` when nothing
        could be generated; ``info`` is a status line shown to the user.
    """
    # The click handler feeds TWO output components, so every path must
    # return a pair; a bare ``None`` makes Gradio raise about missing outputs.
    if not text or not text.strip():
        return None, "⚠️ Please enter some text."
    if ref_audio is None:
        return None, "⚠️ Please upload or record a reference voice."

    # Encode reference audio for speaker embedding
    sr_ref, audio_ref = ref_audio
    audio_ref = audio_ref.astype(np.float32)
    if audio_ref.size == 0:
        # .max() below would raise on an empty array.
        return None, "⚠️ The reference audio is empty."
    # Samples outside [-1, 1] (e.g. integer PCM) are peak-normalised into
    # range; input already within range is left untouched.
    peak = float(np.abs(audio_ref).max())
    if peak > 1.0:
        audio_ref = audio_ref / peak

    waveform = torch.from_numpy(audio_ref)
    if waveform.dim() == 2:
        # Multi-channel input: average over dim 1 to get a mono signal.
        waveform = waveform.mean(1)

    # Pure inference: no autograd graph is needed anywhere below.
    with torch.no_grad():
        result = CODEC.encode_waveform(waveform, sr_ref)
        speaker_emb = result['global_embedding'].to(DEVICE)

        # Split text into chunks and generate each one with the same voice.
        chunks = _split_text(text, TOKENIZER, max_len=250)
        t0 = time.perf_counter()
        all_codes = []
        for chunk in chunks:
            codes = generate(
                MODEL, TOKENIZER, chunk, speaker_emb,
                max_new_tokens=512,
                temperature=temperature,
                top_k=int(top_k),
                top_p=top_p,
                rep_penalty=rep_penalty,
                device=DEVICE,
            )
            if codes is not None and len(codes) > 0:
                all_codes.append(codes)
        gen_time = time.perf_counter() - t0

        if not all_codes:
            return None, "⚠️ No audio was generated — try different text or settings."

        codes = torch.cat(all_codes)
        audio_dur = len(codes) / CODEC_FRAME_RATE
        rtf = gen_time / audio_dur if audio_dur > 0 else float('inf')

        # Decode to waveform
        wav = CODEC.decode(codes, speaker_emb)

    # .numpy() requires a CPU tensor without grad; squeeze drops any singleton
    # batch/channel dims so Gradio receives a plain (samples,) array.
    wav_np = wav.detach().cpu().numpy().squeeze()

    info = f"✅ {len(codes)} tokens | {audio_dur:.1f}s audio | {gen_time:.1f}s gen | RTF: {rtf:.3f}"
    return (CODEC_SAMPLE_RATE, wav_np), info
def build_ui():
    """Assemble the Gradio Blocks app and return it without launching.

    Layout, in creation order: a two-line header; a row holding a wide
    input column (text box, reference-voice input, Generate/Clear buttons)
    and a narrow column (collapsed sampling settings, output audio, status
    box); then example prompts. Gradio is imported lazily inside the function.
    """
    import gradio as gr

    page_css = (
        "\n"
        ".main-title { text-align: center; margin-bottom: 0.5em; }\n"
        ".subtitle { text-align: center; color: #666; margin-bottom: 1.5em; }\n"
    )
    page_theme = gr.themes.Soft(primary_hue="blue", secondary_hue="slate")
    example_prompts = [
        "Българският език е изключително богат и мелодичен.",
        "Artificial intelligence has reached a fascinating stage.",
        "Когато говорим за истински multitasking, способността ми да превключвам плавно между български и English е от огромно значение.",
        "Здравейте! Казвам се Ани и мога да говоря на български и английски.",
        "The quick brown fox jumps over the lazy dog.",
    ]

    with gr.Blocks(
        title="BgTTS-38M — Bulgarian Text-to-Speech",
        theme=page_theme,
        css=page_css,
    ) as app:
        # Header.
        gr.HTML('<h1 class="main-title">🎙️ BgTTS-38M</h1>')
        gr.HTML('<p class="subtitle">Bulgarian + English Text-to-Speech with Voice Cloning | 38M params | 153MB</p>')

        with gr.Row():
            # Left column: text + reference voice + action buttons.
            with gr.Column(scale=2):
                text_input = gr.Textbox(
                    label="Текст / Text",
                    placeholder="Въведете текст на български или английски...\nEnter text in Bulgarian or English...",
                    lines=5,
                    max_lines=15,
                )
                ref_audio = gr.Audio(
                    label="🎤 Reference Voice (за клониране на глас)",
                    type="numpy",
                    sources=["upload", "microphone"],
                )
                with gr.Row():
                    generate_btn = gr.Button("🔊 Генерирай / Generate", variant="primary", size="lg")
                    clear_btn = gr.Button("🗑️ Изчисти", size="lg")

            # Right column: collapsible sampling settings, then the results.
            with gr.Column(scale=1):
                with gr.Accordion("⚙️ Настройки / Settings", open=False):
                    # Hint text reads: "lower = cleaner, higher = more varied".
                    temperature = gr.Slider(
                        minimum=0.05, maximum=1.5, value=0.3, step=0.05,
                        label="Temperature",
                        info="По-ниска = по-чисто, по-висока = по-разнообразно",
                    )
                    top_k = gr.Slider(
                        minimum=1, maximum=500, value=250, step=10,
                        label="Top-K",
                    )
                    top_p = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.95, step=0.05,
                        label="Top-P (Nucleus)",
                    )
                    rep_penalty = gr.Slider(
                        minimum=1.0, maximum=2.0, value=1.1, step=0.05,
                        label="Repetition Penalty",
                    )
                output_audio = gr.Audio(
                    label="🔊 Резултат / Output",
                    type="numpy",
                    interactive=False,
                )
                info_text = gr.Textbox(
                    label="ℹ️ Информация",
                    interactive=False,
                    lines=2,
                )

        # Each example row fills only the text box.
        gr.Examples(
            examples=[[prompt] for prompt in example_prompts],
            inputs=[text_input],
            label="📝 Примери / Examples",
        )

        # Wiring: Generate -> (audio, status); Clear resets text, audio, status.
        generate_btn.click(
            fn=synthesize_speech,
            inputs=[text_input, ref_audio, temperature, top_k, top_p, rep_penalty],
            outputs=[output_audio, info_text],
        )
        clear_btn.click(
            fn=lambda: (None, None, ""),
            outputs=[text_input, output_audio, info_text],
        )

    return app
if __name__ == "__main__":
    import argparse

    # CLI: every option defaults to the module-level value computed on import.
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default=CHECKPOINT_PATH)
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--device", default=DEVICE)
    cli = parser.parse_args()

    # Rebind the globals before loading: load_model() and synthesize_speech()
    # read CHECKPOINT_PATH / DEVICE from module scope.
    CHECKPOINT_PATH, DEVICE = cli.checkpoint, cli.device

    load_model()
    build_ui().launch(
        server_name=cli.host,
        server_port=cli.port,
        share=cli.share,
    )