Spaces:
Paused
Paused
| import gradio as gr | |
| import torch | |
| import torchaudio | |
| import numpy as np | |
| import tempfile | |
| import time | |
| from pathlib import Path | |
| from huggingface_hub import hf_hub_download | |
| import os | |
| import spaces | |
| from transformers import pipeline | |
| # Import the inference module | |
| from infer import DMOInference | |
# Module-level state shared by the functions in this file.
# `model` and `asr_pipe` start empty and are assigned by the initializers.
model = asr_pipe = None

# Run on the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Initialize ASR pipeline
def initialize_asr_pipeline(device=device, dtype=None):
    """Create the Whisper ASR pipeline and store it in the global ``asr_pipe``.

    The pipeline is always placed on the CPU to save GPU memory. If anything
    goes wrong the error is printed and ``asr_pipe`` is set to ``None``.

    Args:
        device: Accepted for backward compatibility. The pipeline itself is
            pinned to the CPU, so this value does not affect placement or
            precision.
        dtype: Optional torch dtype for the ASR model. Defaults to
            ``torch.float32`` because the model runs on the CPU.
    """
    global asr_pipe

    # Keep ASR on CPU to save GPU memory
    asr_device = "cpu"

    print("Initializing ASR pipeline...")
    try:
        if dtype is None:
            # The precision must match the device the pipeline actually runs
            # on. Keying it off the TTS device would pick float16 for a
            # CPU-resident model whenever a GPU happens to be present.
            dtype = torch.float32

        asr_pipe = pipeline(
            "automatic-speech-recognition",
            model="openai/whisper-large-v3-turbo",
            torch_dtype=dtype,
            device=asr_device,
        )
        print("ASR pipeline initialized successfully")
    except Exception as e:
        # Best effort: the app still works without auto-transcription.
        print(f"Error initializing ASR pipeline: {e}")
        asr_pipe = None
# Transcribe function
def transcribe(ref_audio, language=None):
    """Return the transcript of ``ref_audio`` using the global ASR pipeline.

    Args:
        ref_audio: Audio input handed straight to the ASR pipeline.
        language: Optional language passed to the decoder; omitted from the
            generation options when falsy.

    Returns:
        The stripped transcript, or ``""`` when the pipeline is not loaded
        or the transcription raises.
    """
    global asr_pipe

    # No pipeline means auto-transcription is unavailable.
    if asr_pipe is None:
        return ""

    decode_options = {"task": "transcribe"}
    if language:
        decode_options["language"] = language

    try:
        output = asr_pipe(
            ref_audio,
            chunk_length_s=30,
            batch_size=128,
            generate_kwargs=decode_options,
            return_timestamps=False,
        )
        return output["text"].strip()
    except Exception as e:
        print(f"Transcription error: {e}")
        return ""
def download_models():
    """Fetch the two checkpoints from the HuggingFace Hub.

    Returns:
        ``(student_path, duration_path)`` as returned by ``hf_hub_download``,
        or ``(None, None)`` if any step raises.
    """
    repo = "yl4579/DMOSpeech2"
    cache = "./models"
    # Student model first, then the duration predictor.
    checkpoint_names = ("model_85000.pt", "model_1500.pt")

    try:
        print("Downloading models from HuggingFace...")
        student_path, duration_path = [
            hf_hub_download(repo_id=repo, filename=name, cache_dir=cache)
            for name in checkpoint_names
        ]
        print(f"Student model: {student_path}")
        print(f"Duration model: {duration_path}")
        return student_path, duration_path
    except Exception as e:
        print(f"Error downloading models: {e}")
        return None, None
def initialize_model():
    """Download the checkpoints and build the global ``model``.

    Returns:
        ``(ok, message)`` where ``ok`` is ``True`` on success and ``message``
        is a human-readable status string.
    """
    global model
    try:
        # Both checkpoints are required; bail out if either is missing.
        student_path, duration_path = download_models()
        if not (student_path and duration_path):
            return False, "Failed to download models from HuggingFace"

        model = DMOInference(
            student_checkpoint_path=student_path,
            duration_predictor_path=duration_path,
            device=device,
            model_type="F5TTS_Base",
        )
    except Exception as e:
        return False, f"Error initializing model: {e}"
    else:
        return True, f"Model loaded successfully on {device.upper()}"
# Load the TTS model and the ASR pipeline eagerly, at import time.
# `model_loaded` and `status_message` are read later by generate_speech and
# by the launch guard at the bottom of the file.
print("Initializing models...")
model_loaded, status_message = initialize_model()

# The ASR pipeline is optional; it only enables auto-transcription.
initialize_asr_pipeline()
# Sample rate used both to compute the duration and to write the WAV file.
_OUTPUT_SAMPLE_RATE = 24000

# mode label -> (teacher_steps, teacher_stopping_time, student_start_step)
_MODE_PRESETS = {
    "Student Only (4 steps)": (0, 1.0, 0),
    # Default configuration from the notebook
    "Teacher-Guided (8 steps)": (16, 0.07, 1),
    "High Diversity (16 steps)": (24, 0.3, 2),
}


def _resolve_step_config(
    mode,
    custom_teacher_steps,
    custom_teacher_stopping_time,
    custom_student_start_step,
):
    """Return (teacher_steps, teacher_stopping_time, student_start_step) for a mode."""
    preset = _MODE_PRESETS.get(mode)
    if preset is not None:
        return preset
    # Any other label is treated as "Custom". UI sliders may deliver floats,
    # so the two step counts are cast to int.
    return (
        int(custom_teacher_steps),
        custom_teacher_stopping_time,
        int(custom_student_start_step),
    )


def _format_metrics(processing_time, audio_duration):
    """Build the performance string shown in the "Performance Metrics" box."""
    rtf = processing_time / audio_duration
    # Guard the inverse so a zero elapsed time cannot raise.
    speed = f"{1 / rtf:.2f}x speed" if rtf > 0 else "n/a speed"
    return (
        f"RTF: {rtf:.2f}x ({speed}) | "
        f"Processing: {processing_time:.2f}s for {audio_duration:.2f}s audio"
    )


# Request GPU for up to 120 seconds
# NOTE(review): no GPU-requesting decorator is visible on this function even
# though `spaces` is imported at the top of the file; confirm whether one is
# expected here.
def generate_speech(
    prompt_audio,
    prompt_text,
    target_text,
    mode,
    temperature,
    custom_teacher_steps,
    custom_teacher_stopping_time,
    custom_student_start_step,
    verbose
):
    """Generate speech with different configurations.

    Args:
        prompt_audio: Path to the reference audio, or ``None`` if missing.
        prompt_text: Text spoken in the reference audio. When blank, the
            reference audio is transcribed automatically.
        target_text: Text to synthesize.
        mode: One of the preset labels in ``_MODE_PRESETS``; any other value
            selects the custom settings.
        temperature: Passed through to ``model.generate``.
        custom_teacher_steps: Teacher steps used in custom mode.
        custom_teacher_stopping_time: Teacher stopping time used in custom mode.
        custom_student_start_step: Student start step used in custom mode.
        verbose: Passed through to ``model.generate``.

    Returns:
        A 4-tuple ``(output_path, status, metrics, info)``. On any failure
        ``output_path`` is ``None``, ``status`` carries the message and the
        last two items are empty strings.
    """
    if not model_loaded or model is None:
        return None, "Model not loaded! Please refresh the page.", "", ""
    if prompt_audio is None:
        return None, "Please upload a reference audio!", "", ""
    if not target_text or not target_text.strip():
        return None, "Please enter text to generate!", "", ""

    try:
        # Auto-transcribe if prompt_text is empty (None, "" or whitespace).
        auto_transcribed = False
        if not prompt_text or not prompt_text.strip():
            print("Auto-transcribing reference audio...")
            prompt_text = transcribe(prompt_audio)
            print(f"Transcribed: {prompt_text}")
            # transcribe() returns "" when ASR is unavailable or fails.
            auto_transcribed = bool(prompt_text)

        start_time = time.time()

        # Configure parameters based on mode
        teacher_steps, teacher_stopping_time, student_start_step = _resolve_step_config(
            mode,
            custom_teacher_steps,
            custom_teacher_stopping_time,
            custom_student_start_step,
        )

        # Generate speech
        generated_audio = model.generate(
            gen_text=target_text,
            audio_path=prompt_audio,
            prompt_text=prompt_text if prompt_text else None,
            teacher_steps=teacher_steps,
            teacher_stopping_time=teacher_stopping_time,
            student_start_step=student_start_step,
            temperature=temperature,
            verbose=verbose
        )
        processing_time = time.time() - start_time

        # Normalize to a 2-D tensor, the layout handed to torchaudio.save.
        if isinstance(generated_audio, np.ndarray):
            generated_audio = torch.from_numpy(generated_audio)
        if generated_audio.dim() == 1:
            generated_audio = generated_audio.unsqueeze(0)

        audio_duration = generated_audio.shape[-1] / _OUTPUT_SAMPLE_RATE
        if audio_duration <= 0:
            # Avoid a division by zero below and an empty output file.
            return None, "Error: the model produced no audio", "", ""

        # Reserve a file name, then write after the handle has been closed.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            output_path = tmp_file.name
        torchaudio.save(output_path, generated_audio, _OUTPUT_SAMPLE_RATE)

        metrics = _format_metrics(processing_time, audio_duration)

        # Show the transcript preview only when it was auto-generated.
        info = f"Mode: {mode}"
        if auto_transcribed:
            preview = prompt_text[:50]
            if len(prompt_text) > 50:
                preview += "..."
            info += f" | Transcribed: {preview}"

        return output_path, "Success!", metrics, info
    except Exception as e:
        return None, f"Error: {str(e)}", "", ""
# Create Gradio interface
# NOTE(review): the emoji in the labels and headings below were restored from
# mis-encoded text; confirm that they are the intended glyphs.
with gr.Blocks(title="DMOSpeech 2 - Zero-Shot TTS") as demo:
    gr.Markdown("""
    # 🎙️ DMOSpeech 2: Zero-Shot Text-to-Speech
    Generate natural speech in any voice with just a short reference audio!
    """)

    with gr.Row():
        # Left column: inputs and generation settings.
        with gr.Column(scale=1):
            # Reference audio input
            prompt_audio = gr.Audio(
                label="📎 Reference Audio",
                type="filepath",
                sources=["upload", "microphone"]
            )
            prompt_text = gr.Textbox(
                label="📝 Reference Text (leave empty for auto-transcription)",
                placeholder="The text spoken in the reference audio...",
                lines=2
            )
            target_text = gr.Textbox(
                label="✍️ Text to Generate",
                placeholder="Enter the text you want to synthesize...",
                lines=4
            )

            # Generation mode
            mode = gr.Radio(
                choices=[
                    "Student Only (4 steps)",
                    "Teacher-Guided (8 steps)",
                    "High Diversity (16 steps)",
                    "Custom"
                ],
                value="Teacher-Guided (8 steps)",
                label="🚀 Generation Mode",
                info="Choose speed vs quality/diversity tradeoff"
            )

            # Advanced settings (collapsible)
            with gr.Accordion("⚙️ Advanced Settings", open=False):
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.0,
                    step=0.1,
                    label="Duration Temperature",
                    info="0 = deterministic, >0 = more variation in speech rhythm"
                )

                # Hidden until the "Custom" mode is selected (see mode.change below).
                with gr.Group(visible=False) as custom_settings:
                    gr.Markdown("### Custom Mode Settings")
                    custom_teacher_steps = gr.Slider(
                        minimum=0,
                        maximum=32,
                        value=16,
                        step=1,
                        label="Teacher Steps",
                        info="More steps = higher quality"
                    )
                    custom_teacher_stopping_time = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.07,
                        step=0.01,
                        label="Teacher Stopping Time",
                        info="When to switch to student"
                    )
                    custom_student_start_step = gr.Slider(
                        minimum=0,
                        maximum=4,
                        value=1,
                        step=1,
                        label="Student Start Step",
                        info="Which student step to start from"
                    )

                verbose = gr.Checkbox(
                    value=False,
                    label="Verbose Output",
                    info="Show detailed generation steps"
                )

            generate_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")

        # Right column: outputs and usage tips.
        with gr.Column(scale=1):
            # Output
            output_audio = gr.Audio(
                label="🔊 Generated Speech",
                type="filepath",
                autoplay=True
            )
            status = gr.Textbox(
                label="Status",
                interactive=False
            )
            metrics = gr.Textbox(
                label="Performance Metrics",
                interactive=False
            )
            info = gr.Textbox(
                label="Generation Info",
                interactive=False
            )

            # Tips
            gr.Markdown("""
            ### 💡 Quick Tips:
            - **Auto-transcription**: Leave reference text empty to auto-transcribe
            - **Student Only**: Fastest (4 steps), good quality
            - **Teacher-Guided**: Best balance (8 steps), recommended
            - **High Diversity**: More natural prosody (16 steps)
            - **Custom Mode**: Fine-tune all parameters
            ### 📊 Expected RTF (Real-Time Factor):
            - Student Only: ~0.05x (20x faster than real-time)
            - Teacher-Guided: ~0.10x (10x faster)
            - High Diversity: ~0.20x (5x faster)
            """)

    # Event handler: the input order must match generate_speech's parameters,
    # and the outputs match the 4-tuple it returns.
    generate_btn.click(
        generate_speech,
        inputs=[
            prompt_audio,
            prompt_text,
            target_text,
            mode,
            temperature,
            custom_teacher_steps,
            custom_teacher_stopping_time,
            custom_student_start_step,
            verbose
        ],
        outputs=[output_audio, status, metrics, info]
    )

    # Update visibility of custom settings based on mode
    def update_custom_visibility(mode):
        """Show the custom settings group only when the "Custom" mode is selected."""
        is_custom = (mode == "Custom")
        return gr.update(visible=is_custom)

    mode.change(
        update_custom_visibility,
        inputs=[mode],
        outputs=[custom_settings]
    )
# Launch the app
if __name__ == "__main__":
    # Collect start-up problems first, then report them before launching.
    startup_warnings = []
    if not model_loaded:
        startup_warnings.append(f"Warning: Model failed to load - {status_message}")
    if not asr_pipe:
        startup_warnings.append("Warning: ASR pipeline not available - auto-transcription disabled")

    for warning in startup_warnings:
        print(warning)

    # The UI is launched even when loading failed; generate_speech reports
    # the missing model to the user.
    demo.launch()