# NOTE(review): scrape residue from the Hugging Face Spaces page — the Space
# reported "Runtime error" at capture time. The app source follows below.
| import os | |
| import gc | |
| import torch | |
| import traceback | |
| import numpy as np | |
| import librosa | |
| import gradio as gr | |
| from pydub import AudioSegment | |
| from pydub.effects import normalize | |
| from huggingface_hub import snapshot_download | |
| from tts.infer_cli import MegaTTS3DiTInfer, cut_wav | |
| import spaces | |
# -----------------------------
# Utility Functions
# -----------------------------
def download_weights():
    """Download model weights from HuggingFace if not already present.

    Returns:
        str: Local directory ("checkpoints") containing the model weights.
    """
    repo_id = "mrfakename/MegaTTS3-VoiceCloning"
    weights_dir = "checkpoints"
    if not os.path.exists(weights_dir):
        print("📥 Downloading model weights from HuggingFace...")
        # NOTE(review): local_dir_use_symlinks is deprecated and ignored by
        # recent huggingface_hub releases; kept for older-version compatibility.
        snapshot_download(
            repo_id=repo_id,
            local_dir=weights_dir,
            local_dir_use_symlinks=False
        )
        print("✅ Model weights downloaded successfully!")
    else:
        print("✅ Model weights already exist.")
    return weights_dir
def cleanup_memory():
    """Free memory between generations: run the GC, then drop CUDA caches."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
def reset_model():
    """Reset the inference pipeline to recover from CUDA errors.

    Rebinds the module-global `infer_pipe` to a fresh MegaTTS3DiTInfer
    instance after flushing CUDA caches.

    Returns:
        bool: True if the pipeline was rebuilt, False on any failure.
    """
    global infer_pipe
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        print("♻️ Reinitializing MegaTTS3 model...")
        infer_pipe = MegaTTS3DiTInfer()
        print("✅ Model reinitialized successfully!")
        return True
    except Exception as e:
        # Never propagate: callers branch on the bool to decide how to warn.
        print(f"❌ Failed to reinitialize model: {e}")
        return False
def preprocess_audio_robust(audio_path, target_sr=22050, max_duration=30):
    """Robustly preprocess reference audio for TTS.

    Converts to mono, trims to at most `max_duration` seconds, peak-normalizes,
    resamples to `target_sr`, exports 16-bit PCM wav, and validates the result
    (no NaN/inf, not near-silent) before rewriting it with soundfile.

    Args:
        audio_path: Path to the input audio (any format pydub/ffmpeg can read).
        target_sr: Target sample rate in Hz.
        max_duration: Maximum clip length in seconds.

    Returns:
        str: Path of the processed '<stem>_processed.wav' file.

    Raises:
        ValueError: If decoding, validation, or export fails (original error chained).
    """
    try:
        audio = AudioSegment.from_file(audio_path)
        # Collapse to a single channel.
        if audio.channels > 1:
            audio = audio.set_channels(1)
        # pydub lengths are in milliseconds.
        if len(audio) > max_duration * 1000:
            audio = audio[:max_duration * 1000]
        audio = normalize(audio)
        audio = audio.set_frame_rate(target_sr)
        # Build the temp path from the stem. The previous
        # audio_path.replace(ext, ...) corrupted paths where the extension
        # string occurred earlier in the path, and exploded when ext was ""
        # (str.replace with an empty needle inserts between every character).
        temp_path = os.path.splitext(audio_path)[0] + '_processed.wav'
        audio.export(
            temp_path,
            format="wav",
            parameters=["-acodec", "pcm_s16le", "-ac", "1", "-ar", str(target_sr)]
        )
        # Re-load with librosa to validate the exported signal.
        wav, sr = librosa.load(temp_path, sr=target_sr, mono=True)
        if np.any(np.isnan(wav)) or np.any(np.isinf(wav)):
            raise ValueError("Audio contains NaN or infinite values")
        if np.max(np.abs(wav)) < 1e-6:
            raise ValueError("Audio signal too quiet")
        import soundfile as sf
        sf.write(temp_path, wav, sr)
        return temp_path
    except Exception as e:
        # Chain the cause so the original traceback survives the re-wrap.
        raise ValueError(f"Audio preprocessing failed: {e}") from e
# -----------------------------
# Model Initialization
# -----------------------------
# Module-level side effects: fetch the weights, then build the global pipeline
# that generate_speech() and reset_model() share.
download_weights()
print("🚀 Initializing MegaTTS3 model...")
infer_pipe = MegaTTS3DiTInfer()
print("✅ Model loaded successfully!")
# -----------------------------
# Speech Generation
# -----------------------------
def generate_speech(inp_audio, inp_text, infer_timestep, p_w, t_w):
    """Clone the reference voice and synthesize `inp_text`.

    Args:
        inp_audio: Filepath of the reference audio (Gradio Audio widget).
        inp_text: Text to synthesize.
        infer_timestep: Diffusion inference timesteps.
        p_w: Intelligibility weight.
        t_w: Similarity weight.

    Returns:
        Wav bytes on success, or None on any handled failure (the user is
        notified via gr.Warning instead of an exception).
    """
    if not inp_audio or not inp_text:
        gr.Warning("⚠️ Please provide both reference audio and text.")
        return None
    processed_audio_path = None
    try:
        print(f"🎤 Generating speech for text: {inp_text}")
        # This model requires a GPU; bail out early if none is present.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            print(f"🖥️ Using CUDA device: {torch.cuda.get_device_name()}")
        else:
            gr.Warning("⚠️ CUDA not available. Please check your GPU setup.")
            return None
        # Clean/normalize the reference clip, then trim it for the encoder.
        processed_audio_path = preprocess_audio_robust(inp_audio)
        cut_wav(processed_audio_path, max_len=28)
        with open(processed_audio_path, 'rb') as f:
            file_content = f.read()
        # Inference: build voice resources from the reference, then decode.
        resource_context = infer_pipe.preprocess(file_content)
        wav_bytes = infer_pipe.forward(
            resource_context,
            inp_text,
            time_step=infer_timestep,
            p_w=p_w,
            t_w=t_w
        )
        cleanup_memory()
        return wav_bytes
    except RuntimeError as e:
        if "CUDA" in str(e):
            print(f"❌ CUDA error detected: {e}")
            # A CUDA fault usually poisons the context; rebuild the pipeline.
            if reset_model():
                gr.Warning("⚠️ CUDA error occurred. Model reset. Please try again.")
            else:
                gr.Warning("❌ CUDA error & reset failed. Restart app.")
            return None
        raise
    except Exception as e:
        traceback.print_exc()
        gr.Warning(f"❌ Speech generation failed: {e}")
        cleanup_memory()
        return None
    finally:
        # The processed wav is a temp artifact; previously it was leaked on
        # every call. Best-effort removal on all paths.
        if processed_audio_path and os.path.exists(processed_audio_path):
            try:
                os.remove(processed_audio_path)
            except OSError:
                pass
# -----------------------------
# Gradio UI
# -----------------------------
# Layout: left column = inputs (reference audio, text, advanced knobs, button),
# right column = generated audio output.
with gr.Blocks(title="Dhanush Voice Cloning") as demo:
    gr.Markdown(
        """
# 🎙️ Dhanush Voice Cloning
**Dhanush Voice Cloning** is powered by MegaTTS3 with the unofficial WavVAE encoder.
Clone voices by uploading a short audio sample and typing your desired text!
⚠️ For research and educational purposes only.
"""
    )
    with gr.Row():
        with gr.Column():
            reference_audio = gr.Audio(
                label="🎧 Reference Audio",
                type="filepath",
                sources=["upload", "microphone"]
            )
            text_input = gr.Textbox(
                label="📝 Text to Generate",
                placeholder="Enter the text you want to synthesize...",
                lines=3
            )
            with gr.Accordion("⚙️ Advanced Options", open=False):
                infer_timestep = gr.Number(label="Inference Timesteps", value=32, minimum=1, maximum=100, step=1)
                p_w = gr.Number(label="Intelligibility Weight", value=1.4, minimum=0.1, maximum=5.0, step=0.1)
                t_w = gr.Number(label="Similarity Weight", value=3.0, minimum=0.1, maximum=10.0, step=0.1)
            generate_btn = gr.Button("🚀 Generate Speech", variant="primary")
        with gr.Column():
            output_audio = gr.Audio(label="🔊 Generated Audio")
    generate_btn.click(
        fn=generate_speech,
        inputs=[reference_audio, text_input, infer_timestep, p_w, t_w],
        outputs=[output_audio]
    )
# -----------------------------
# Launch App
# -----------------------------
if __name__ == '__main__':
    # Bind on all interfaces so the Space/container can route traffic in.
    demo.launch(
        server_name='0.0.0.0',
        server_port=7860,
        debug=True,
    )