# Voice-cloning / app.py
# Uploaded by Dhanush-37 ("Upload 7 files", commit 6bee22f, verified)
import os
import gc
import torch
import traceback
import numpy as np
import librosa
import gradio as gr
from pydub import AudioSegment
from pydub.effects import normalize
from huggingface_hub import snapshot_download
from tts.infer_cli import MegaTTS3DiTInfer, cut_wav
import spaces
# -----------------------------
# Utility Functions
# -----------------------------
def download_weights():
    """Download model weights from HuggingFace if not already present.

    Returns:
        str: Local directory containing the model checkpoints.
    """
    repo_id = "mrfakename/MegaTTS3-VoiceCloning"
    weights_dir = "checkpoints"
    if not os.path.exists(weights_dir):
        print("📥 Downloading model weights from HuggingFace...")
        # NOTE: `local_dir_use_symlinks` was removed here — it is deprecated
        # (and ignored) in huggingface_hub >= 0.23; files are always copied
        # into `local_dir` now.
        snapshot_download(
            repo_id=repo_id,
            local_dir=weights_dir,
        )
        print("✅ Model weights downloaded successfully!")
    else:
        print("✅ Model weights already exist.")
    return weights_dir
def cleanup_memory():
    """Release Python garbage and, when a GPU is present, CUDA cache memory."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
def reset_model():
    """Rebuild the global inference pipeline after a CUDA failure.

    Returns:
        bool: True when the model was reinitialized, False otherwise.
    """
    global infer_pipe
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        print("♻️ Reinitializing MegaTTS3 model...")
        infer_pipe = MegaTTS3DiTInfer()
    except Exception as err:
        print(f"❌ Failed to reinitialize model: {err}")
        return False
    print("✅ Model reinitialized successfully!")
    return True
def preprocess_audio_robust(audio_path, target_sr=22050, max_duration=30):
    """Robustly preprocess reference audio for TTS.

    Converts to mono, trims to `max_duration` seconds, normalizes loudness,
    resamples to `target_sr`, and validates the signal with librosa.

    Args:
        audio_path: Path to the uploaded reference audio file.
        target_sr: Target sample rate in Hz.
        max_duration: Maximum clip length in seconds; longer audio is truncated.

    Returns:
        str: Path to the processed 16-bit mono WAV file.

    Raises:
        ValueError: If decoding fails or the audio is silent/non-finite.
    """
    try:
        audio = AudioSegment.from_file(audio_path)
        # Convert to mono
        if audio.channels > 1:
            audio = audio.set_channels(1)
        # Trim if longer than max duration (pydub lengths are milliseconds)
        if len(audio) > max_duration * 1000:
            audio = audio[:max_duration * 1000]
        # Normalize audio
        audio = normalize(audio)
        # Set sample rate
        audio = audio.set_frame_rate(target_sr)
        # Export temp wav next to the input.
        # BUG FIX: the old str.replace(ext, ...) replaced EVERY occurrence of
        # the extension substring in the path, and inserted '_processed.wav'
        # between every character when the file had no extension. Build the
        # path from the splitext base instead.
        base, _ext = os.path.splitext(audio_path)
        temp_path = base + '_processed.wav'
        audio.export(
            temp_path,
            format="wav",
            parameters=["-acodec", "pcm_s16le", "-ac", "1", "-ar", str(target_sr)]
        )
        # Validate with librosa: reject non-finite or near-silent signals
        wav, sr = librosa.load(temp_path, sr=target_sr, mono=True)
        if np.any(np.isnan(wav)) or np.any(np.isinf(wav)):
            raise ValueError("Audio contains NaN or infinite values")
        if np.max(np.abs(wav)) < 1e-6:
            raise ValueError("Audio signal too quiet")
        import soundfile as sf
        sf.write(temp_path, wav, sr)
        return temp_path
    except Exception as e:
        raise ValueError(f"Audio preprocessing failed: {e}")
# -----------------------------
# Model Initialization
# -----------------------------
download_weights()
print("πŸš€ Initializing MegaTTS3 model...")
infer_pipe = MegaTTS3DiTInfer()
print("βœ… Model loaded successfully!")
# -----------------------------
# Speech Generation
# -----------------------------
@spaces.GPU
def generate_speech(inp_audio, inp_text, infer_timestep, p_w, t_w):
    """Clone the reference voice and synthesize `inp_text` with MegaTTS3.

    Args:
        inp_audio: Filepath of the reference audio (from the Gradio Audio input).
        inp_text: Text to synthesize.
        infer_timestep: Diffusion inference timesteps.
        p_w: Intelligibility weight.
        t_w: Similarity weight.

    Returns:
        WAV bytes suitable for the Gradio Audio output, or None on failure.
    """
    if not inp_audio or not inp_text:
        gr.Warning("⚠️ Please provide both reference audio and text.")
        return None
    processed_audio_path = None
    try:
        print(f"🎤 Generating speech for text: {inp_text}")
        # GPU check
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            print(f"🖥️ Using CUDA device: {torch.cuda.get_device_name()}")
        else:
            gr.Warning("⚠️ CUDA not available. Please check your GPU setup.")
            return None
        # Preprocess audio (writes a temp *_processed.wav; removed in finally)
        processed_audio_path = preprocess_audio_robust(inp_audio)
        cut_wav(processed_audio_path, max_len=28)
        with open(processed_audio_path, 'rb') as f:
            file_content = f.read()
        # Inference
        resource_context = infer_pipe.preprocess(file_content)
        wav_bytes = infer_pipe.forward(
            resource_context,
            inp_text,
            time_step=infer_timestep,
            p_w=p_w,
            t_w=t_w
        )
        return wav_bytes
    except RuntimeError as e:
        if "CUDA" in str(e):
            print(f"❌ CUDA error detected: {e}")
            if reset_model():
                gr.Warning("⚠️ CUDA error occurred. Model reset. Please try again.")
            else:
                gr.Warning("❌ CUDA error & reset failed. Restart app.")
            return None
        else:
            raise
    except Exception as e:
        traceback.print_exc()
        gr.Warning(f"❌ Speech generation failed: {e}")
        return None
    finally:
        # BUG FIX: the temp processed wav was never deleted (disk leak on a
        # long-running Space), and memory cleanup was skipped on the CUDA
        # error path. Do both unconditionally here.
        if processed_audio_path and os.path.exists(processed_audio_path):
            try:
                os.remove(processed_audio_path)
            except OSError:
                pass
        cleanup_memory()
# -----------------------------
# Gradio UI
# -----------------------------
# Declarative Gradio UI: reference-audio + text inputs on the left,
# generated audio on the right; advanced knobs hidden in an accordion.
with gr.Blocks(title="Dhanush Voice Cloning") as demo:
    gr.Markdown(
        """
        # 🎙️ Dhanush Voice Cloning
        **Dhanush Voice Cloning** is powered by MegaTTS3 with the unofficial WavVAE encoder.
        Clone voices by uploading a short audio sample and typing your desired text!
        ⚠️ For research and educational purposes only.
        """
    )
    with gr.Row():
        with gr.Column():
            # type="filepath" makes the component hand generate_speech() a path
            reference_audio = gr.Audio(
                label="🎧 Reference Audio",
                type="filepath",
                sources=["upload", "microphone"]
            )
            text_input = gr.Textbox(
                label="📝 Text to Generate",
                placeholder="Enter the text you want to synthesize...",
                lines=3
            )
            with gr.Accordion("⚙️ Advanced Options", open=False):
                # Defaults mirror generate_speech()'s expected operating range
                infer_timestep = gr.Number(label="Inference Timesteps", value=32, minimum=1, maximum=100, step=1)
                p_w = gr.Number(label="Intelligibility Weight", value=1.4, minimum=0.1, maximum=5.0, step=0.1)
                t_w = gr.Number(label="Similarity Weight", value=3.0, minimum=0.1, maximum=10.0, step=0.1)
            generate_btn = gr.Button("🚀 Generate Speech", variant="primary")
        with gr.Column():
            output_audio = gr.Audio(label="🔊 Generated Audio")
    # Wire the button: inputs are passed positionally to generate_speech()
    generate_btn.click(
        fn=generate_speech,
        inputs=[reference_audio, text_input, infer_timestep, p_w, t_w],
        outputs=[output_audio]
    )
# -----------------------------
# Launch App
# -----------------------------
if __name__ == '__main__':
    # Bind to all interfaces on the standard HF Spaces port (7860)
    demo.launch(server_name='0.0.0.0', server_port=7860, debug=True)