Spaces:
Paused
Paused
| import os | |
| import torch | |
| import soundfile as sf | |
| import logging | |
| import argparse | |
| import gradio as gr | |
| import json | |
| import threading | |
| import queue | |
| from datetime import datetime | |
| from pathlib import Path | |
| from mira.model import MiraTTS | |
MODEL = None  # MiraTTS instance; loaded lazily by the generation code or eagerly in __main__
# Safe device detection with fallback
def get_device():
    """Return "cuda" when a CUDA device can actually be reached, otherwise "cpu".

    Any error raised while probing CUDA (e.g. a broken driver) is logged as a
    warning and treated as "no CUDA".
    """
    try:
        cuda_ok = torch.cuda.is_available()
        if cuda_ok:
            # is_available() can be True while the driver is unusable; touching
            # the current device surfaces that failure here instead of later.
            torch.cuda.current_device()
    except Exception as e:
        logging.warning(f"CUDA not available or driver error: {e}")
        return "cpu"
    return "cuda" if cuda_ok else "cpu"
# Device chosen at import time; initialize_model() may overwrite it later.
DEVICE = get_device()
# JSON file that GenerationHistory reads and writes by default.
HISTORY_FILE = "generation_history.json"
# Consumed by background_worker(): items are (callback, args) tuples, or None to stop.
GENERATION_QUEUE = queue.Queue()
# Module-wide lock for serializing work across threads.
PROCESSING_LOCK = threading.Lock()
class GenerationHistory:
    """Manage a newest-first list of generation records persisted to a JSON file.

    Entries are plain dicts. The list is capped at ``max_entries`` (oldest
    dropped) and rewritten to disk after every mutation. A re-entrant lock
    guards both the in-memory list and the file, because Gradio runs event
    handlers (which call ``add_entry``) on worker threads.
    """

    def __init__(self, history_file=None, max_entries=100):
        """Create the manager and load any existing history.

        Args:
            history_file: Path of the JSON file. When omitted, the module-level
                ``HISTORY_FILE`` is used (resolved at call time).
            max_entries: Maximum number of entries retained; newest are kept.
        """
        self.history_file = HISTORY_FILE if history_file is None else history_file
        self.max_entries = max_entries
        # RLock: add_entry/clear_history hold it while calling save_history.
        self._lock = threading.RLock()
        self.history = self.load_history()

    def load_history(self):
        """Load history from the JSON file.

        Returns:
            The stored list, or ``[]`` when the file is missing, unreadable,
            or does not contain a JSON list (errors are logged, not raised).
        """
        if not os.path.exists(self.history_file):
            return []
        try:
            with open(self.history_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except Exception as e:
            logging.error(f"Error loading history: {e}")
            return []
        # A non-list payload would make add_entry's insert() fail later.
        if not isinstance(data, list):
            logging.error(
                f"Error loading history: expected a JSON list in "
                f"{self.history_file}, got {type(data).__name__}"
            )
            return []
        return data

    def save_history(self):
        """Write history to the JSON file; failures are logged, not raised.

        The data is written to a sibling temp file and moved into place with
        ``os.replace`` so a crash mid-write cannot leave a truncated file.
        """
        tmp_path = f"{self.history_file}.tmp"
        with self._lock:
            try:
                with open(tmp_path, 'w', encoding='utf-8') as f:
                    json.dump(self.history, f, indent=2, ensure_ascii=False)
                os.replace(tmp_path, self.history_file)
            except Exception as e:
                logging.error(f"Error saving history: {e}")

    def add_entry(self, entry):
        """Insert *entry* at the front, trim to ``max_entries``, and persist."""
        with self._lock:
            self.history.insert(0, entry)
            del self.history[self.max_entries:]
            self.save_history()

    def get_history(self):
        """Return a snapshot (shallow copy) of all entries, newest first."""
        with self._lock:
            return list(self.history)

    def clear_history(self):
        """Remove all entries and persist the empty list."""
        with self._lock:
            self.history = []
            self.save_history()
# Single process-wide history store shared by the Gradio callbacks.
HISTORY_MANAGER = GenerationHistory()
def _resolve_device(requested):
    """Return *requested*, or "cpu" when "cuda" is requested but cannot be used."""
    if requested != "cuda":
        return requested
    try:
        if not torch.cuda.is_available():
            logging.warning("CUDA requested but not available, falling back to CPU")
            return "cpu"
        torch.cuda.current_device()  # Test CUDA access
        return "cuda"
    except Exception as e:
        logging.warning(f"CUDA test failed: {e}, falling back to CPU")
        return "cpu"


def initialize_model(model_dir="YatharthS/MiraTTS", device=None):
    """Load the MiraTTS model and, when running on CUDA, move it there.

    Args:
        model_dir: Local path or HuggingFace model ID passed to ``MiraTTS``.
        device: Optional "cuda"/"cpu" override. A "cuda" request is validated
            and falls back to "cpu" if unusable. When omitted, the module-level
            ``DEVICE`` is left as is.

    Returns:
        The loaded model.

    Raises:
        Whatever ``MiraTTS(model_dir)`` raises (logged with traceback first).

    Side effects:
        Updates the module-level ``DEVICE``, including a fallback to "cpu"
        when moving the model to CUDA fails.
    """
    global DEVICE
    if device:
        DEVICE = _resolve_device(device)
    logging.info(f"Loading MiraTTS model from: {model_dir}")
    logging.info(f"Using device: {DEVICE}")
    try:
        model = MiraTTS(model_dir)
    except Exception as e:
        logging.exception(f"Error initializing model: {e}")
        raise
    # Only the CUDA case needs an explicit move; models without .to() are used as-is.
    if DEVICE == "cuda" and hasattr(model, 'to'):
        try:
            model = model.to(DEVICE)
        except Exception as e:
            logging.warning(f"Failed to move model to CUDA: {e}, using CPU")
            DEVICE = "cpu"
    return model
def generate_audio(text, prompt_audio_path):
    """Synthesize *text* in the voice of the reference clip at *prompt_audio_path*.

    Loads the global ``MODEL`` on first use. Model loading and inference are
    serialized with ``PROCESSING_LOCK`` because Gradio invokes handlers from
    worker threads, and concurrent first calls would otherwise load the model
    twice.

    Args:
        text: Text to synthesize.
        prompt_audio_path: Path to the reference audio used for voice cloning.

    Returns:
        ``(audio, 48000)``: the waveform (converted to a NumPy array when the
        model returns a tensor) and the sample rate used for writing.

    Raises:
        Any exception from model loading, prompt encoding, or generation
        (generation/encoding errors are logged with traceback first).
    """
    global MODEL
    with PROCESSING_LOCK:
        if MODEL is None:
            MODEL = initialize_model()
        try:
            # Encode the prompt audio
            context_tokens = MODEL.encode_audio(prompt_audio_path)
            if torch.is_tensor(context_tokens) and DEVICE == "cuda":
                try:
                    context_tokens = context_tokens.to(DEVICE)
                except Exception as e:
                    logging.warning(f"Failed to move tensors to CUDA: {e}")

            if DEVICE == "cpu":
                # No fallback on CPU: retrying the identical call cannot help.
                with torch.inference_mode():
                    audio = MODEL.generate(text, context_tokens)
            else:
                try:
                    # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
                    with torch.autocast(device_type="cuda"):
                        audio = MODEL.generate(text, context_tokens)
                except Exception as e:
                    logging.warning(f"Autocast failed: {e}, using standard generation")
                    with torch.inference_mode():
                        audio = MODEL.generate(text, context_tokens)

            if torch.is_tensor(audio):
                # The autocast path does not disable autograd, and .numpy() refuses
                # tensors that require grad; NumPy also has no bfloat16.
                audio = audio.detach()
                if audio.dtype in (torch.float16, torch.bfloat16):
                    audio = audio.float()
                audio = audio.cpu().numpy()

            # soundfile cannot write float16 (or other exotic dtypes): normalize to float32.
            dtype = getattr(audio, "dtype", None)
            if dtype is not None and dtype not in ('float32', 'float64', 'int16', 'int32'):
                audio = audio.astype('float32')
            return audio, 48000  # Return audio and sample rate
        except Exception as e:
            logging.exception(f"Error during generation: {e}")
            raise
def run_tts(text, prompt_audio_path, save_dir="results", mode="clone"):
    """Synthesize *text*, save it as a WAV under *save_dir*, and record it in history.

    Args:
        text: Text to synthesize.
        prompt_audio_path: Reference clip passed to generate_audio().
        save_dir: Output directory (created if missing).
        mode: Label stored in the history entry; the reference path is only
            recorded when this is "clone".

    Returns:
        Path of the written WAV file.

    Raises:
        Whatever generate_audio() or ``sf.write`` raise.
    """
    logging.info(f"Saving audio to: {save_dir}")
    os.makedirs(save_dir, exist_ok=True)
    # Include microseconds: with a seconds-only stamp, two requests finishing in
    # the same second wrote to the same file and the earlier audio was lost.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    save_path = os.path.join(save_dir, f"mira_tts_{timestamp}.wav")

    logging.info("Starting MiraTTS inference...")
    audio, sample_rate = generate_audio(text, prompt_audio_path)
    sf.write(save_path, audio, samplerate=sample_rate)
    logging.info(f"Audio saved at: {save_path}")

    history_entry = {
        "timestamp": datetime.now().isoformat(),
        "text": (text[:100] + "...") if len(text) > 100 else text,
        "full_text": text,
        "mode": mode,
        "file_path": save_path,
        "reference_audio": prompt_audio_path if mode == "clone" else None,
        "device": DEVICE,
    }
    HISTORY_MANAGER.add_entry(history_entry)
    return save_path
def background_worker(task_queue=None):
    """Run ``(callback, args)`` tasks from a queue until a ``None`` sentinel arrives.

    Args:
        task_queue: Queue to consume; defaults to the module-level
            ``GENERATION_QUEUE``.

    Each task is acknowledged with ``task_done()`` (including the sentinel),
    so ``task_queue.join()`` works. Exceptions raised by a callback are logged
    with traceback and do not stop the worker.
    """
    if task_queue is None:
        task_queue = GENERATION_QUEUE
    while True:
        task = task_queue.get()
        try:
            if task is None:  # Poison pill to stop the worker
                return
            callback, args = task
            callback(*args)
        except Exception as e:
            logging.exception(f"Error in background worker: {e}")
        finally:
            task_queue.task_done()
# Launch the queue consumer as a daemon thread so it never blocks interpreter exit.
worker_thread = threading.Thread(daemon=True, target=background_worker)
worker_thread.start()
def voice_clone_callback(text, prompt_audio_upload, prompt_audio_record, progress=gr.Progress()):
    """Gradio handler for the Voice Clone tab.

    Args:
        text: Text to synthesize.
        prompt_audio_upload: Filepath of an uploaded reference clip, or None.
        prompt_audio_record: Filepath of a microphone recording, or None; used
            only when no upload is provided.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        ``(audio_path_or_None, history_markdown)`` for the two output components.
        Invalid input and failures return ``None`` for the audio and show a
        warning toast instead of failing silently.
    """
    if not text or not text.strip():
        gr.Warning("Please enter some text to synthesize.")
        return None, get_history_display()
    # Use uploaded audio if present, otherwise the recording.
    prompt_audio = prompt_audio_upload or prompt_audio_record
    if not prompt_audio:
        gr.Warning("Please upload or record a reference audio clip.")
        return None, get_history_display()
    progress(0, desc="Initializing...")
    try:
        # Coarse milestones only: run_tts() reports no finer-grained progress.
        progress(0.3, desc="Encoding audio...")
        progress(0.6, desc="Generating speech...")
        audio_output_path = run_tts(text, prompt_audio, mode="clone")
        progress(1.0, desc="Complete!")
        return audio_output_path, get_history_display()
    except Exception as e:
        logging.exception(f"Error in voice cloning: {e}")
        gr.Warning(f"Voice cloning failed: {e}")
        return None, get_history_display()
def voice_creation_callback(text, temperature, top_p, top_k, progress=gr.Progress()):
    """Gradio handler for the Voice Creation tab.

    Applies the sampling parameters to the shared model, then synthesizes
    *text* using a bundled example clip as the voice prompt.

    Args:
        text: Text to synthesize.
        temperature: Sampling temperature passed to ``MODEL.set_params``.
        top_p: Nucleus-sampling threshold passed to ``MODEL.set_params``.
        top_k: Top-k cutoff passed to ``MODEL.set_params``.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        ``(audio_path_or_None, history_markdown)`` for the two output components.
    """
    global MODEL
    if not text or not text.strip():
        return None, get_history_display()
    # Lock so a concurrent first request cannot load the model twice.
    with PROCESSING_LOCK:
        if MODEL is None:
            MODEL = initialize_model()
    progress(0, desc="Initializing...")
    try:
        # NOTE(review): set_params is called on the shared MODEL, so these values
        # presumably remain in effect for later Voice Clone requests too — confirm.
        # Taking the lock waits for any in-flight generation before changing them.
        with PROCESSING_LOCK:
            MODEL.set_params(
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                max_new_tokens=1024,
                repetition_penalty=1.2
            )
        progress(0.3, desc="Loading default voice...")
        possible_paths = (
            "/models3/src/MiraTTS/models/MiraTTS/example1.wav",
            "models/MiraTTS/example1.wav",
            "./models/MiraTTS/example1.wav",
        )
        default_audio = next((p for p in possible_paths if os.path.exists(p)), None)
        if default_audio is None:
            logging.warning("No default audio found for voice creation")
            gr.Warning("No default voice sample found for voice creation.")
            return None, get_history_display()

        progress(0.6, desc="Generating speech...")
        # Same encode/generate/dtype pipeline as the clone path.
        audio, sample_rate = generate_audio(text, default_audio)

        os.makedirs("results", exist_ok=True)
        # Microseconds prevent two same-second requests from overwriting each other.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        save_path = os.path.join("results", f"mira_tts_creation_{timestamp}.wav")
        sf.write(save_path, audio, samplerate=sample_rate)

        history_entry = {
            "timestamp": datetime.now().isoformat(),
            "text": (text[:100] + "...") if len(text) > 100 else text,
            "full_text": text,
            "mode": "creation",
            "file_path": save_path,
            "reference_audio": None,
            "device": DEVICE,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
        }
        HISTORY_MANAGER.add_entry(history_entry)
        progress(1.0, desc="Complete!")
        return save_path, get_history_display()
    except Exception as e:
        logging.exception(f"Error in voice creation: {e}")
        gr.Warning(f"Voice creation failed: {e}")
        return None, get_history_display()
def get_history_display(history=None, limit=20):
    """Render recent generation history as Markdown.

    Args:
        history: Entries to render, newest first. When omitted, entries are
            read from ``HISTORY_MANAGER``.
        limit: Maximum number of entries to show (default 20).

    Returns:
        Markdown text, or ``"No generation history yet."`` when there is nothing
        to show. Entries with missing or malformed fields are rendered with
        placeholders instead of breaking the whole view.
    """
    if history is None:
        history = HISTORY_MANAGER.get_history()
    if not history:
        return "No generation history yet."

    parts = ["# Generation History\n\n"]
    for idx, entry in enumerate(history[:limit], start=1):
        raw_ts = entry.get('timestamp')
        try:
            timestamp = datetime.fromisoformat(raw_ts).strftime("%Y-%m-%d %H:%M:%S")
        except (TypeError, ValueError):
            timestamp = str(raw_ts) if raw_ts else "Unknown time"
        mode = str(entry.get('mode', 'unknown')).capitalize()
        file_name = os.path.basename(entry.get('file_path') or '') or 'N/A'

        parts.append(f"### {idx}. {timestamp} - {mode}\n")
        # List items: consecutive plain lines would be merged into one paragraph.
        parts.append(f"- **Text:** {entry.get('text', '')}\n")
        parts.append(f"- **File:** `{file_name}`\n")
        parts.append(f"- **Device:** {entry.get('device', 'N/A')}\n")
        if entry.get('temperature') is not None:
            parts.append(
                f"- **Params:** T={entry.get('temperature')}, "
                f"p={entry.get('top_p')}, k={entry.get('top_k')}\n"
            )
        parts.append("\n---\n\n")
    return "".join(parts)
def get_history_files():
    """Return ``(path, basename)`` pairs for history entries whose file still exists on disk."""
    available = []
    for record in HISTORY_MANAGER.get_history():
        path = record['file_path']
        if os.path.exists(path):
            available.append((path, os.path.basename(path)))
    return available
def clear_history_callback():
    """Wipe stored history; return refreshed Markdown plus an empty file list for the UI."""
    HISTORY_MANAGER.clear_history()
    refreshed = get_history_display()
    return refreshed, []
def build_ui():
    """Build the MiraTTS Gradio interface.

    Tabs: voice cloning from a reference clip, voice creation with sampling
    parameters, a history/download view, and an About page.

    Returns:
        The configured ``gr.Blocks`` app; the caller launches it.
    """
    with gr.Blocks(title="MiraTTS Web Interface", theme=gr.themes.Soft()) as demo:
        # Title
        gr.HTML('<h1 style="text-align: center;">MiraTTS - High Quality Voice Synthesis</h1>')

        # Device info
        device_info = f"🖥️ Running on: **{DEVICE.upper()}**"
        if DEVICE == "cuda":
            try:
                device_info += f" (GPU: {torch.cuda.get_device_name(0)})"
            except Exception:
                # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit propagate.
                device_info += " (GPU)"
        else:
            device_info += " (CPU mode - slower but works without GPU)"
        gr.Markdown(device_info)

        # Description
        gr.Markdown("""
        MiraTTS is a highly optimized Text-to-Speech model based on Spark-TTS with LMDeploy acceleration.
        It provides high-quality 48kHz audio output with background processing support.
        """)

        with gr.Tabs():
            # Voice Clone Tab
            with gr.TabItem("🎤 Voice Clone"):
                gr.Markdown("### Clone any voice using a reference audio sample")
                with gr.Row():
                    prompt_audio_upload = gr.Audio(
                        sources="upload",
                        type="filepath",
                        label="Upload Reference Audio (recommended: 3-30 seconds, 16kHz+)",
                    )
                    prompt_audio_record = gr.Audio(
                        sources="microphone",
                        type="filepath",
                        label="Record Reference Audio",
                    )
                text_input = gr.Textbox(
                    label="Text to Synthesize",
                    lines=3,
                    placeholder="Enter the text you want to convert to speech...",
                    value="Hello! This is a demonstration of MiraTTS voice cloning capabilities."
                )
                with gr.Row():
                    clone_button = gr.Button("🎵 Generate Audio", variant="primary")
                    clear_button = gr.Button("🗑️ Clear")
                audio_output_clone = gr.Audio(
                    label="Generated Audio",
                    autoplay=True
                )
                history_display_clone = gr.Markdown(get_history_display())
                clone_button.click(
                    voice_clone_callback,
                    inputs=[text_input, prompt_audio_upload, prompt_audio_record],
                    outputs=[audio_output_clone, history_display_clone],
                )
                clear_button.click(
                    lambda: (None, None, "", None),
                    outputs=[prompt_audio_upload, prompt_audio_record, text_input, audio_output_clone]
                )

            # Voice Creation Tab
            with gr.TabItem("✨ Voice Creation"):
                gr.Markdown("### Create synthetic voices with custom parameters")
                with gr.Row():
                    with gr.Column():
                        text_input_creation = gr.Textbox(
                            label="Text to Synthesize",
                            lines=3,
                            placeholder="Enter text here...",
                            value="You can create customized voices by adjusting the generation parameters below."
                        )
                        with gr.Row():
                            temperature = gr.Slider(
                                minimum=0.1,
                                maximum=1.5,
                                step=0.1,
                                value=0.8,
                                label="Temperature (creativity)"
                            )
                            top_p = gr.Slider(
                                minimum=0.1,
                                maximum=1.0,
                                step=0.05,
                                value=0.95,
                                label="Top-p (nucleus sampling)"
                            )
                            top_k = gr.Slider(
                                minimum=1,
                                maximum=100,
                                step=1,
                                value=50,
                                label="Top-k (vocabulary size)"
                            )
                    with gr.Column():
                        create_button = gr.Button("🎨 Create Voice", variant="primary")
                        audio_output_creation = gr.Audio(
                            label="Generated Audio",
                            autoplay=True
                        )
                history_display_creation = gr.Markdown(get_history_display())
                create_button.click(
                    voice_creation_callback,
                    inputs=[text_input_creation, temperature, top_p, top_k],
                    outputs=[audio_output_creation, history_display_creation],
                )

            # History Tab
            with gr.TabItem("📜 History"):
                gr.Markdown("### Review and download previous generations")
                with gr.Row():
                    refresh_button = gr.Button("🔄 Refresh History", variant="secondary")
                    clear_history_button = gr.Button("🗑️ Clear History", variant="stop")
                history_display_main = gr.Markdown(get_history_display())
                gr.Markdown("### Download Files")
                file_browser = gr.File(
                    label="Generated Audio Files",
                    file_count="multiple",
                    interactive=False
                )

                def refresh_history():
                    """Return the history Markdown and the paths of files still on disk."""
                    files = get_history_files()
                    return get_history_display(), [f[0] for f in files]

                refresh_button.click(
                    refresh_history,
                    outputs=[history_display_main, file_browser]
                )
                clear_history_button.click(
                    clear_history_callback,
                    outputs=[history_display_main, file_browser]
                )
                # Populate history/files when the page loads. demo.load fires once
                # per page load, not when this tab is opened; use Refresh for updates.
                demo.load(
                    refresh_history,
                    outputs=[history_display_main, file_browser]
                )

            # About Tab
            with gr.TabItem("ℹ️ About"):
                gr.Markdown(f"""
                ## About MiraTTS
                MiraTTS is an optimized version of Spark-TTS with the following features:
                - **Ultra-fast generation**: Over 100x realtime speed using LMDeploy optimization
                - **High quality**: Generates crisp 48kHz audio outputs
                - **Memory efficient**: Works within 6GB VRAM or on CPU
                - **Low latency**: As low as 100ms generation time (GPU)
                - **Voice cloning**: Clone any voice from a short audio sample
                - **Background processing**: Non-blocking audio generation
                - **Generation history**: Review and download all generated audio
                ### Current Configuration
                - **Device**: {DEVICE.upper()}
                - **Base model**: Spark-TTS-0.5B
                - **Optimization**: LMDeploy + FlashSR
                - **Sample rate**: 48kHz
                - **Model size**: ~500M parameters
                ### Usage Tips
                - For voice cloning, use clear audio samples between 3-30 seconds
                - Ensure reference audio is at least 16kHz quality
                - Longer text inputs may require more memory
                - Adjust generation parameters for different voice styles
                - CPU mode is slower but works without GPU
                - Check the History tab to download previous generations
                ### Performance Notes
                - **GPU**: ~100-200ms per generation
                - **CPU**: ~2-5 seconds per generation (depending on CPU)
                """)
    return demo
def parse_arguments():
    """Build the web-UI command-line parser and parse ``sys.argv``.

    Returns:
        ``argparse.Namespace`` with model_dir, device, server_name,
        server_port and share.
    """
    cli = argparse.ArgumentParser(description="MiraTTS Gradio Web Interface")
    option_specs = (
        ("--model_dir", dict(
            type=str,
            default="YatharthS/MiraTTS",
            help="Path to the MiraTTS model directory or HuggingFace model ID",
        )),
        ("--device", dict(
            type=str,
            default=None,
            choices=["cuda", "cpu"],
            help="Device to run model on (default: auto-detect)",
        )),
        ("--server_name", dict(
            type=str,
            default="127.0.0.1",
            help="Server host/IP for Gradio app",
        )),
        ("--server_port", dict(
            type=int,
            default=7860,
            help="Server port for Gradio app",
        )),
        ("--share", dict(
            action="store_true",
            help="Create a public shareable link",
        )),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
if __name__ == "__main__":
    # Configure logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )
    args = parse_arguments()

    # initialize_model() already validates --device (falling back to CPU when
    # CUDA is requested but unusable) and updates the module-level DEVICE, so
    # repeating that check here only duplicated the warnings.
    logging.info("Initializing MiraTTS model...")
    MODEL = initialize_model(args.model_dir, args.device)
    logging.info(f"Device selected: {DEVICE}")

    # Build and launch interface
    logging.info("Building Gradio interface...")
    demo = build_ui()
    logging.info(f"Launching web interface on {args.server_name}:{args.server_port}")
    logging.info(f"Device: {DEVICE}")
    demo.launch(
        server_name=args.server_name,
        server_port=args.server_port,
        share=args.share
    )