import os
os.environ["GRADIO_MCP_SERVER"] = "True"
import uuid
import time
import threading
import subprocess
from pathlib import Path
from typing import List, Tuple
import gradio as gr
import numpy as np
import torch
import torchaudio
from pydub import AudioSegment
from huggingface_hub import snapshot_download
# Optional deps (graceful fallbacks)
try:
import webrtcvad
WEBRTCVAD_AVAILABLE = True
except Exception:
WEBRTCVAD_AVAILABLE = False
try:
from phonemizer import phonemize
PHONEMIZER_AVAILABLE = True
except Exception:
PHONEMIZER_AVAILABLE = False
try:
import num2words
NUM2WORDS_AVAILABLE = True
except Exception:
NUM2WORDS_AVAILABLE = False
# ---- Coqui XTTS imports ----
try:
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
print("TTS modules imported successfully")
except ImportError as e:
print(f"TTS import error: {e}")
print("Make sure you have installed coqui-tts.")
print("You can install it with: pip install coqui-tts")
# Don't exit immediately, let the user see the error in the UI
TTS_AVAILABLE = False
else:
TTS_AVAILABLE = True
# ----------------- Paths & setup -----------------
BASE_DIR = Path(os.path.dirname(os.path.abspath(__file__)))
MODELS_DIR = BASE_DIR / "XTTS-v2"
REF_AUDIO_DIR = BASE_DIR / "ref_audio_files"
OUTPUT_DIR = BASE_DIR / "outputs"
TEMP_DIR = OUTPUT_DIR / "temp"
for p in [REF_AUDIO_DIR, OUTPUT_DIR, TEMP_DIR]:
p.mkdir(parents=True, exist_ok=True)
SUPPORTED_LANGUAGES = {
"English": "en",
"French": "fr",
"Spanish": "es",
"German": "de",
"Italian": "it",
"Portuguese": "pt",
"Polish": "pl",
"Turkish": "tr",
"Russian": "ru",
"Ukrainian": "uk",
"Dutch": "nl",
"Czech": "cs",
"Arabic": "ar",
"Chinese (zh)": "zh",
"Japanese": "ja",
"Korean": "ko",
"Hindi": "hi",
}
# ----------------- Model download / load -----------------
def ensure_xtts_repo():
if MODELS_DIR.exists() and (MODELS_DIR / "config.json").exists():
print("XTTS-v2 model already present.")
return
try:
print("Downloading XTTS-v2 model...")
snapshot_download(
repo_id="coqui/XTTS-v2",
local_dir=str(MODELS_DIR),
allow_patterns=["*.safetensors", "*.wav", "*.json", "*.pth"],
)
print("Model downloaded successfully!")
except Exception as e:
print(f"Snapshot download failed: {e}")
# Fallback: try git clone for Spaces that restrict hub fs ops
try:
print("Attempting git clone fallback...")
result = subprocess.run(
["git", "clone", "https://huggingface.co/coqui/XTTS-v2", str(MODELS_DIR)],
capture_output=True,
text=True,
)
if result.returncode == 0:
print("Model downloaded via git clone!")
else:
print("git clone error:", result.stderr)
raise RuntimeError(result.stderr)
except Exception as ge:
print(f"git clone failed: {ge}")
raise RuntimeError(
"Please add the model manually: git clone https://huggingface.co/coqui/XTTS-v2"
)
# Initialize model only if TTS is available
if TTS_AVAILABLE:
ensure_xtts_repo()
# Load config/model
print("Loading XTTS configuration...")
config = XttsConfig()
config.load_json(str(MODELS_DIR / "config.json"))
print("Configuration loaded.")
print("Initializing XTTS model...")
model = Xtts.init_from_config(config)
print("Model initialized.")
print("Loading checkpoint...")
model.load_checkpoint(
config,
checkpoint_dir=str(MODELS_DIR),
eval=True,
use_deepspeed=False,
)
print("Checkpoint loaded.")
if torch.cuda.is_available():
model.cuda()
print("Model on GPU.")
else:
print("GPU not available, using CPU.")
else:
print("TTS not available - model initialization skipped")
model = None
config = None
# ----------------- Audio/text utilities -----------------
def loudness_normalize_tensor(wav: torch.Tensor, target_rms: float = 0.03, eps: float = 1e-9) -> torch.Tensor:
"""Very light RMS-based normalization (EBU-like target without full LUFS graph)."""
rms = torch.sqrt(torch.clamp((wav ** 2).mean(), min=eps))
gain = target_rms / max(rms, eps)
out = torch.clamp(wav * gain, -1.0, 1.0)
return out
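# A hedged sketch of full EBU R 128 loudness normalization, assuming pyloudnorm
# is installed (it is NOT imported at the top of this file and is not a
# dependency of this app). Drop-in alternative to the RMS version above when an
# integrated-LUFS target is required:
def lufs_normalize(wav: torch.Tensor, sr: int, target_lufs: float = -23.0) -> torch.Tensor:
    import pyloudnorm as pyln  # pip install pyloudnorm
    data = wav.squeeze(0).cpu().numpy()
    meter = pyln.Meter(sr)  # ITU-R BS.1770 / EBU R 128 meter
    loudness = meter.integrated_loudness(data)
    out = pyln.normalize.loudness(data, loudness, target_lufs)
    return torch.from_numpy(out).unsqueeze(0).float()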
def optional_light_denoise(wav: torch.Tensor, sr: int) -> torch.Tensor:
"""Stub for RNNoise/spectral gate. Left identity by default."""
return wav
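# A minimal spectral-gate sketch that could replace the identity stub above.
# Assumptions: no separate noise recording exists, so the per-frequency noise
# floor is estimated from the quietest 10% of frames; this is a rough gate,
# not RNNoise.
def spectral_gate_denoise(wav: torch.Tensor, sr: int,
                          n_fft: int = 1024, gate_db: float = 10.0) -> torch.Tensor:
    # sr is kept for interface parity with optional_light_denoise().
    hop = n_fft // 4
    window = torch.hann_window(n_fft)
    spec = torch.stft(wav, n_fft=n_fft, hop_length=hop,
                      window=window, return_complex=True)
    mag = spec.abs()
    # Per-bin noise floor from the quietest frames, then gate gate_db above it.
    floor = torch.quantile(mag, 0.10, dim=-1, keepdim=True)
    mask = (mag > floor * (10.0 ** (gate_db / 20.0))).float()
    return torch.istft(spec * mask, n_fft=n_fft, hop_length=hop,
                       window=window, length=wav.shape[-1])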
def normalize_text(txt: str, language_code: str) -> str:
# Expand bare integers to words for English; pass-through for other locales.
if language_code == "en" and NUM2WORDS_AVAILABLE:
import re
def repl(m):
try:
return num2words.num2words(int(m.group(0)), lang="en")
except Exception:
return m.group(0)
txt = re.sub(r"\b\d{1,6}\b", repl, txt)
txt = txt.replace("&", " and ")
return txt
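# num2words also covers most locales listed in SUPPORTED_LANGUAGES (fr, es, de,
# it, pt, ru, ...), so number expansion could be extended beyond English. A
# hedged sketch; passing language_code straight through to num2words is an
# assumption, and unsupported codes simply keep their digits:
def normalize_text_multilocale(txt: str, language_code: str) -> str:
    if not NUM2WORDS_AVAILABLE:
        return txt
    import re
    def repl(m):
        try:
            return num2words.num2words(int(m.group(0)), lang=language_code)
        except Exception:
            return m.group(0)  # locale not supported by num2words: keep digits
    return re.sub(r"\b\d{1,6}\b", repl, txt)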
def maybe_phonemize(txt: str, language_code: str) -> str:
# XTTS handles graphemes well; keep as no-op by default.
# You can switch to phoneme-only text here if you find recurrent mispronunciations.
return txt
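# If recurring mispronunciations do show up, a phoneme pass could look like the
# sketch below. Assumptions: the espeak backend is installed, the "en"->"en-us"
# mapping is illustrative rather than exhaustive, and feeding phonemes to XTTS
# is experimental, not an officially supported input mode.
def phonemize_text(txt: str, language_code: str) -> str:
    if not PHONEMIZER_AVAILABLE:
        return txt
    lang = {"en": "en-us", "pt": "pt-br"}.get(language_code, language_code)
    try:
        return phonemize(txt, language=lang, backend="espeak", strip=True)
    except Exception:
        return txt  # fall back to graphemes on any phonemizer failure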
def vad_segments_webrtc(y: torch.Tensor, sr: int, frame_ms: int = 20,
aggressiveness: int = 2, min_speech_ms: int = 200,
max_merge_gap_ms: int = 200, pad_ms: int = 80) -> List[Tuple[int, int]]:
"""Return [(start_ms, end_ms), ...] speech regions using WebRTC-VAD with padding."""
if not WEBRTCVAD_AVAILABLE:
return [(0, int(1000 * y.shape[-1] / sr))]
    vad = webrtcvad.Vad(aggressiveness)
    # WebRTC-VAD only accepts 8/16/32/48 kHz audio; resample anything else
    # (e.g. the 24 kHz XTTS output) before framing. Returned times stay in ms,
    # so callers can keep slicing the original-rate audio.
    if sr not in (8000, 16000, 32000, 48000):
        y = torchaudio.functional.resample(y, sr, 16000)
        sr = 16000
    frame_len = int(sr * frame_ms / 1000)
    num_frames = max(1, -(-y.shape[-1] // frame_len))  # ceil: keep the final partial frame
regions = []
cur_start = None
last_t = 0
for i in range(num_frames):
seg = y[0, i * frame_len : (i + 1) * frame_len]
if seg.numel() < frame_len:
seg = torch.nn.functional.pad(seg, (0, frame_len - seg.numel()))
seg16 = (seg.clamp(-1, 1) * 32767.0).short().numpy().tobytes()
t_ms = i * frame_ms
is_sp = vad.is_speech(seg16, sample_rate=sr)
if is_sp and cur_start is None:
cur_start = t_ms
if (not is_sp) and cur_start is not None:
if t_ms - cur_start >= min_speech_ms:
regions.append([cur_start, t_ms])
cur_start = None
last_t = t_ms
if cur_start is not None:
regions.append([cur_start, last_t + frame_ms])
# Merge small gaps then pad
merged = []
for st, en in regions:
if not merged:
merged.append([st, en])
else:
if st - merged[-1][1] <= max_merge_gap_ms:
merged[-1][1] = en
else:
merged.append([st, en])
padded = []
for st, en in merged:
padded.append([max(0, st - pad_ms), en + pad_ms])
return [(st, en) for st, en in padded] if padded else [(0, int(1000 * y.shape[-1] / sr))]
# ----------------- Voice latent cache -----------------
_VOICE_CACHE = {} # key: (path, mtime) -> (gpt_latent, spk_emb) on model device
_MAX_CACHE_SIZE = 10 # Limit cache size to prevent memory issues
def get_latents(reference_audio_path: str):
if not TTS_AVAILABLE or model is None:
raise RuntimeError("TTS model not available. Please check your installation.")
key = (reference_audio_path, os.path.getmtime(reference_audio_path))
if key in _VOICE_CACHE:
return _VOICE_CACHE[key]
# Clean cache if it gets too large
if len(_VOICE_CACHE) >= _MAX_CACHE_SIZE:
# Remove oldest entries (simple FIFO)
oldest_keys = list(_VOICE_CACHE.keys())[:len(_VOICE_CACHE) - _MAX_CACHE_SIZE + 1]
for old_key in oldest_keys:
del _VOICE_CACHE[old_key]
try:
# Pre-clean reference then compute conditioning latents
ref, sr = torchaudio.load(reference_audio_path)
if sr != 24000:
ref = torchaudio.functional.resample(ref, sr, 24000)
sr = 24000
ref = ref.mean(dim=0, keepdim=True)
ref = loudness_normalize_tensor(ref)
ref = optional_light_denoise(ref, sr)
tmp_ref = str(TEMP_DIR / f"ref_{uuid.uuid4().hex}.wav")
torchaudio.save(tmp_ref, ref, sr)
gpt_latent, spk_emb = model.get_conditioning_latents(audio_path=[tmp_ref])
try:
os.remove(tmp_ref)
except Exception:
pass
dev = next(model.parameters()).device
_VOICE_CACHE[key] = (gpt_latent.to(dev), spk_emb.to(dev))
return _VOICE_CACHE[key]
except Exception as e:
print(f"Error getting latents: {e}")
raise
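# Design note: the FIFO eviction above relies on dict insertion order. If the
# same few voices are reused across requests, an LRU keeps them warm instead of
# evicting by arrival time; a minimal sketch that could replace the plain dict:
from collections import OrderedDict

class LRUVoiceCache(OrderedDict):
    def __init__(self, maxsize: int = _MAX_CACHE_SIZE):
        super().__init__()
        self.maxsize = maxsize

    def __getitem__(self, key):
        value = super().__getitem__(key)
        self.move_to_end(key)  # touch: mark as most recently used
        return value

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        if len(self) > self.maxsize:
            self.popitem(last=False)  # evict least recently used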
# ----------------- Synthesis core -----------------
def synthesize_speech(
text: str,
language: str,
temperature: float,
speed: float,
reference_audio_path: str,
do_sample: bool,
enable_text_splitting: bool,
repetition_penalty: float,
length_penalty: float,
gpt_cond_len: int, # kept for UI continuity (unused by inference w/ cached latents)
top_k: int,
top_p: float,
remove_silence_enabled: bool,
silence_threshold: float, # kept for back-compat; unused with VAD
min_silence_len: int,
keep_silence: int,
text_splitting_method: str,
max_chars_per_segment: int,
) -> Tuple[str, str]:
"""
Returns (mp3_path, wav_master_path)
"""
if not TTS_AVAILABLE or model is None:
print("Error: TTS model not available")
return None, None
try:
language_code = SUPPORTED_LANGUAGES.get(language, "en")
# Clean text
clean_text = normalize_text(text, language_code)
clean_text = maybe_phonemize(clean_text, language_code)
# Precompute latents once per request
gpt_latent, spk_emb = get_latents(reference_audio_path)
# Split strategy
def chunk_text(t: str, size: int = 250) -> List[str]:
if len(t) <= size:
return [t]
chunks, cur = [], []
for tok in t.split():
if sum(len(w) + 1 for w in cur) + len(tok) + 1 > size:
chunks.append(" ".join(cur))
cur = [tok]
else:
cur.append(tok)
if cur:
chunks.append(" ".join(cur))
return chunks
outputs_wav_list: List[np.ndarray] = []
if text_splitting_method == "Native XTTS splitting":
out = model.inference(
text=clean_text,
language=language_code,
gpt_cond_latent=gpt_latent,
speaker_embedding=spk_emb,
temperature=temperature,
do_sample=do_sample,
speed=speed,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
                length_penalty=length_penalty,
                enable_text_splitting=enable_text_splitting,
            )
outputs_wav_list.append(out["wav"])
elif text_splitting_method == "Custom splitting":
chunks = chunk_text(clean_text, max_chars_per_segment)
for i, chunk in enumerate(chunks, 1):
print(f"Processing segment {i}/{len(chunks)}")
out = model.inference(
text=chunk,
language=language_code,
gpt_cond_latent=gpt_latent,
speaker_embedding=spk_emb,
temperature=temperature,
do_sample=do_sample,
speed=speed,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty,
)
outputs_wav_list.append(out["wav"])
else:
# No splitting
out = model.inference(
text=clean_text,
language=language_code,
gpt_cond_latent=gpt_latent,
speaker_embedding=spk_emb,
temperature=temperature,
do_sample=do_sample,
speed=speed,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty,
)
outputs_wav_list.append(out["wav"])
wav_np = np.concatenate(outputs_wav_list) if len(outputs_wav_list) > 1 else outputs_wav_list[0]
wav_tensor = torch.tensor(wav_np, dtype=torch.float32).unsqueeze(0) # [1, T], 24k
# Optional VAD-based trimming on the result (more natural than amplitude-split)
if remove_silence_enabled:
print("Applying VAD-based trimming...")
segs = vad_segments_webrtc(
wav_tensor, sr=24000, frame_ms=20, aggressiveness=2,
min_speech_ms=max(120, min_silence_len // 2),
max_merge_gap_ms=keep_silence,
pad_ms=max(50, keep_silence // 2),
)
# Build trimmed pydub audio
tmp_wav = str(TEMP_DIR / f"gen_{uuid.uuid4().hex}.wav")
torchaudio.save(tmp_wav, wav_tensor, 24000)
audio_seg = AudioSegment.from_wav(tmp_wav)
if segs:
out_seg = AudioSegment.silent(duration=0, frame_rate=audio_seg.frame_rate)
for st_ms, en_ms in segs:
out_seg += audio_seg[st_ms:en_ms]
processed_seg = out_seg
else:
processed_seg = audio_seg
try:
os.remove(tmp_wav)
except Exception:
pass
else:
# Directly to pydub for finalization
tmp_wav = str(TEMP_DIR / f"gen_{uuid.uuid4().hex}.wav")
torchaudio.save(tmp_wav, wav_tensor, 24000)
processed_seg = AudioSegment.from_wav(tmp_wav)
try:
os.remove(tmp_wav)
except Exception:
pass
# Save master WAV (lossless) and MP3 preview
ts = time.strftime("%Y%m%d-%H%M%S")
master_wav_path = str(OUTPUT_DIR / f"lishani_{ts}_{uuid.uuid4().hex}.wav")
# Always write WAV (this is very reliable)
processed_seg.export(master_wav_path, format="wav")
# Try MP3, but don't fail the whole call if it breaks
mp3_path = None
try:
mp3_path = str(Path(master_wav_path).with_suffix(".mp3"))
processed_seg.export(mp3_path, format="mp3", bitrate="320k")
except Exception as e:
print("MP3 export failed; returning WAV only:", e)
mp3_path = None
        # Return whatever we have: the Audio output will happily preview WAV too
        return mp3_path or master_wav_path, master_wav_path
except Exception as e:
print(f"Error in synthesis: {e}")
return None, None
# ----------------- File hygiene -----------------
def cleanup_old_files(max_age_minutes: int = 60) -> int:
removed = 0
cutoff = time.time() - max_age_minutes * 60
for folder in [OUTPUT_DIR, TEMP_DIR]:
for p in folder.glob("*"):
try:
if p.is_file() and p.stat().st_mtime < cutoff:
p.unlink()
removed += 1
except Exception:
pass
return removed
# ----------------- UI -----------------
with gr.Blocks(theme=gr.themes.Soft(), fill_height=True, css="""
#title-bar {display:flex; align-items:center; gap:8px;}
#outs {display:grid; grid-template-columns: 1fr 1fr; gap: 12px;}
.mark {font-size: 0.95rem; opacity: 0.9;}
""") as interface:
with gr.Row():
with gr.Column(scale=3):
gr.HTML("""
<div id="title-bar">
<h1 style="margin:0;font-size:1.8rem;">🎙️ Lishani — XTTS-v2 Voice Cloning</h1>
</div>
""")
gr.Markdown(
"Upload up to **5 minutes** of a reference voice. Enter text, pick a language, and generate. "
"Outputs appear as a **Preview (MP3)** and a **Master (WAV)**."
)
if not TTS_AVAILABLE:
gr.Markdown(
"⚠️ **Warning**: TTS model not available. Please install coqui-tts: `pip install coqui-tts`",
elem_classes=["mark"]
)
with gr.Column(scale=1):
gr.Markdown(
"⚠️ Use responsibly. Only upload audio you have the right to use. Label outputs as synthetic."
)
with gr.Row():
with gr.Column(scale=2):
text_input = gr.Textbox(label="Text to speak", lines=6, placeholder="Type the text you want spoken…")
lang_dropdown = gr.Dropdown(choices=list(SUPPORTED_LANGUAGES.keys()), value="English", label="Language")
gr.Markdown("Adjust these settings to control style and quality.", elem_classes=["mark"])
with gr.Accordion("Generation Settings", open=True):
with gr.Row():
with gr.Column():
temperature_slider = gr.Slider(minimum=0.1, maximum=1.5, step=0.05, value=0.75, label="Temperature")
speed_slider = gr.Slider(minimum=0.5, maximum=2.0, step=0.05, value=1.0, label="Speed")
do_sample = gr.Checkbox(value=True, label="Enable Sampling")
with gr.Column():
repetition_penalty = gr.Slider(minimum=0.5, maximum=5.0, step=0.1, value=1.05, label="Repetition Penalty")
length_penalty = gr.Slider(minimum=0.8, maximum=2.0, step=0.1, value=1.2, label="Length Penalty")
gpt_cond_len = gr.Slider(minimum=1, maximum=50, step=1, value=30, label="(Legacy) GPT Conditioning Length")
top_k = gr.Slider(minimum=0, maximum=50, step=1, value=50, label="Top-K")
top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.85, label="Top-P")
with gr.Accordion("Text Splitting", open=False):
text_splitting_method = gr.Radio(
choices=["Native XTTS splitting", "Custom splitting", "No splitting"],
value="Native XTTS splitting",
label="Text Splitting Method"
)
enable_text_splitting = gr.Checkbox(
value=True,
label="enable_text_splitting (XTTS parameter)",
visible=False
)
max_chars_per_segment = gr.Slider(
minimum=50, maximum=400, step=10, value=250,
label="Max characters per segment (Custom splitting)"
)
with gr.Accordion("Silence Removal", open=False):
remove_silence_enabled = gr.Checkbox(value=True, label="Trim silence/breaths (VAD-based)")
silence_threshold = gr.Slider(minimum=-60, maximum=-20, step=5, value=-45,
label="Silence threshold (legacy; ignored with VAD)")
min_silence_len = gr.Slider(minimum=200, maximum=1000, step=50, value=300, label="Min speech (ms)")
keep_silence = gr.Slider(minimum=50, maximum=500, step=10, value=120, label="Padding (ms)")
with gr.Column(scale=1):
gr.Markdown("### Reference Voice")
reference_audio_input = gr.Audio(sources=["upload"], type="filepath", label="Reference audio (≤ 5 minutes)")
gr.Markdown("### Generate & Listen")
generate_button = gr.Button("Generate Audio", variant="primary", interactive=TTS_AVAILABLE)
status_text = gr.Textbox(label="Status", value="Ready" if TTS_AVAILABLE else "TTS model not available", interactive=False)
with gr.Row(elem_id="outs"):
output_audio_mp3 = gr.Audio(label="Preview (MP3)")
output_audio_wav = gr.File(label="Master (WAV)")
# -------- bindings --------
def validate_audio_file(file_path, max_size_mb=20, min_duration_sec=1, max_duration_sec=300):
try:
if file_path is None or not os.path.exists(file_path):
return False, "No audio file provided."
size_mb = os.path.getsize(file_path) / (1024 * 1024)
if size_mb > max_size_mb:
return False, f"Audio file is too large ({size_mb:.1f} MB). Max {max_size_mb} MB."
a = AudioSegment.from_file(file_path)
duration_sec = len(a) / 1000.0
if duration_sec < min_duration_sec:
return False, "Audio is too short."
if duration_sec > max_duration_sec:
return False, "Audio exceeds 5 minutes."
return True, None
except Exception as e:
return False, f"Failed to process audio: {e}"
def handle_click(
text, language, temperature, speed, reference_audio,
do_sample, enable_text_splitting, repetition_penalty, length_penalty,
gpt_cond_len, top_k, top_p, remove_silence_enabled, silence_threshold,
min_silence_len, keep_silence, text_splitting_method, max_chars_per_segment
):
if not TTS_AVAILABLE or model is None:
print("Error: TTS model not available. Please check your installation.")
return None, None
if not text or not reference_audio:
return None, None
ok, err = validate_audio_file(reference_audio)
if not ok:
print(err)
return None, None
try:
mp3_path, wav_path = synthesize_speech(
text=text,
language=language,
temperature=temperature,
speed=speed,
reference_audio_path=reference_audio,
do_sample=do_sample,
enable_text_splitting=enable_text_splitting,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty,
gpt_cond_len=gpt_cond_len,
top_k=top_k,
top_p=top_p,
remove_silence_enabled=remove_silence_enabled,
silence_threshold=silence_threshold,
min_silence_len=min_silence_len,
keep_silence=keep_silence,
text_splitting_method=text_splitting_method,
max_chars_per_segment=max_chars_per_segment,
)
return mp3_path, wav_path
except Exception as e:
print(f"Error in handle_click: {e}")
return None, None
generate_button.click(
handle_click,
inputs=[
text_input, lang_dropdown, temperature_slider, speed_slider,
reference_audio_input, do_sample,
enable_text_splitting, repetition_penalty, length_penalty,
gpt_cond_len, top_k, top_p, remove_silence_enabled,
silence_threshold, min_silence_len, keep_silence,
text_splitting_method, max_chars_per_segment
],
outputs=[output_audio_mp3, output_audio_wav],
api_name=False
)
def update_text_splitting_options(method):
is_native = method == "Native XTTS splitting"
is_custom = method == "Custom splitting"
return gr.update(value=is_native), gr.update(visible=is_custom)
text_splitting_method.change(
update_text_splitting_options,
inputs=[text_splitting_method],
outputs=[enable_text_splitting, max_chars_per_segment],
api_name=False
)
# ----------------- Background cleanup & launch -----------------
def periodic_cleanup():
while True:
try:
time.sleep(60 * 60) # 1 hour
removed = cleanup_old_files(60)
if removed:
print(f"Cleaned {removed} old files.")
except Exception as e:
print(f"Cleanup error: {e}")
if __name__ == "__main__":
if not TTS_AVAILABLE:
print("\n" + "="*50)
print("ERROR: TTS model not available!")
print("Please install coqui-tts: pip install coqui-tts")
print("="*50 + "\n")
# Start background cleanup thread
cleanup_thread = threading.Thread(target=periodic_cleanup, daemon=True)
cleanup_thread.start()
try:
interface.queue()
interface.launch(
share=False,
allowed_paths=[str(REF_AUDIO_DIR), str(OUTPUT_DIR), str(TEMP_DIR)]
)
except KeyboardInterrupt:
print("\nShutting down gracefully...")
except Exception as e:
print(f"Error launching interface: {e}")