# Hugging Face Spaces header residue: "Spaces: Running on Zero" (ZeroGPU)
import gradio as gr
import sys
import os
import torch
import torchaudio
import torchaudio.transforms as T
import numpy as np
import tempfile
import librosa
from pathlib import Path

# ---------------------------------------------------------------------------
# Module-level initialization: runs once at import time. It clones the
# CosyVoice repo, downloads model weights, wires sys.path, and imports the
# CosyVoice API. All steps are idempotent (skipped when already present).
# ---------------------------------------------------------------------------
print("=" * 60)
print("🎙️ Fun-CosyVoice3 TTS Initialization")
print("=" * 60)

# Step 1: Setup directories (everything lives under the current working dir)
print("\n📁 Step 1: Setting up directories...")
WORK_DIR = Path.cwd()
COSYVOICE_DIR = WORK_DIR / "CosyVoice"
MODEL_DIR = COSYVOICE_DIR / "pretrained_models" / "Fun-CosyVoice3-0.5B"
print(f"Working directory: {WORK_DIR}")
print(f"CosyVoice directory: {COSYVOICE_DIR}")
print(f"Model directory: {MODEL_DIR}")

# Step 2: Clone CosyVoice if needed
# --recursive also fetches the Matcha-TTS submodule that Step 5 adds to sys.path.
if not COSYVOICE_DIR.exists():
    print("\n📥 Step 2: Cloning CosyVoice repository...")
    import subprocess
    try:
        subprocess.run([
            'git', 'clone', '--recursive',
            'https://github.com/FunAudioLLM/CosyVoice.git',
            str(COSYVOICE_DIR)
        ], check=True)
        print("✅ Repository cloned successfully")
    except Exception as e:
        # Fatal: the app cannot work without the repository.
        print(f"❌ Failed to clone repository: {e}")
        raise
else:
    print("\n✅ Step 2: CosyVoice repository already exists")

# Step 3: Download model weights from the Hugging Face Hub (fatal on failure)
if not MODEL_DIR.exists():
    print("\n📥 Step 3: Downloading models (this may take a few minutes)...")
    from huggingface_hub import snapshot_download
    try:
        print("Downloading Fun-CosyVoice3-0.5B-2512...")
        snapshot_download(
            'FunAudioLLM/Fun-CosyVoice3-0.5B-2512',
            local_dir=str(MODEL_DIR),
            local_dir_use_symlinks=False
        )
        print("✅ Model downloaded successfully")
    except Exception as e:
        print(f"❌ Failed to download model: {e}")
        raise
else:
    print("\n✅ Step 3: Models already exist")

# Step 4: Download ttsfrd text-frontend resources (optional — failure is
# non-fatal; the message says CosyVoice falls back to WeText)
TTSFRD_DIR = COSYVOICE_DIR / "pretrained_models" / "CosyVoice-ttsfrd"
if not TTSFRD_DIR.exists():
    print("\n📥 Step 4: Downloading ttsfrd...")
    from huggingface_hub import snapshot_download
    try:
        snapshot_download(
            'FunAudioLLM/CosyVoice-ttsfrd',
            local_dir=str(TTSFRD_DIR),
            local_dir_use_symlinks=False
        )
        print("✅ ttsfrd downloaded successfully")
    except Exception as e:
        print(f"⚠️ Failed to download ttsfrd (will use WeText): {e}")
else:
    print("\n✅ Step 4: ttsfrd already exists")

# Step 5: Add the cloned repo (and its Matcha-TTS submodule) to sys.path so
# the `cosyvoice` package below resolves.
print("\n🔧 Step 5: Configuring Python path...")
sys.path.insert(0, str(COSYVOICE_DIR))
sys.path.insert(0, str(COSYVOICE_DIR / "third_party" / "Matcha-TTS"))
print(f"Added to path: {COSYVOICE_DIR}")
print(f"Added to path: {COSYVOICE_DIR / 'third_party' / 'Matcha-TTS'}")

# Step 6: Import the CosyVoice API (fatal on failure)
print("\n📦 Step 6: Importing CosyVoice...")
try:
    from cosyvoice.cli.cosyvoice import AutoModel as CosyVoiceAutoModel
    from cosyvoice.utils.file_utils import load_wav
    from cosyvoice.utils.common import set_all_random_seed
    print("✅ CosyVoice imported successfully")
except Exception as e:
    print(f"❌ Failed to import CosyVoice: {e}")
    raise

print("\n" + "=" * 60)
print("✅ Initialization completed successfully!")
print("=" * 60 + "\n")

# Global variables
cosyvoice = None    # lazily created model instance; populated by load_model()
target_sr = 24000   # output sample rate in Hz (used for resampling and saving)
prompt_sr = 16000   # minimum accepted sample rate for prompt audio (Hz)
max_val = 0.8       # peak-amplitude cap applied in postprocess()
top_db = 60         # silence-trim threshold (dB) for librosa.effects.trim
hop_length = 220    # trim hop length (samples) for librosa.effects.trim
win_length = 440    # trim frame length (samples) for librosa.effects.trim
def load_model():
    """Return the shared CosyVoice model, instantiating it on first use.

    Raises:
        gr.Error: when model construction fails (surfaced in the UI).
    """
    global cosyvoice
    # Already loaded — reuse the cached instance.
    if cosyvoice is not None:
        return cosyvoice
    print("🚀 Loading CosyVoice model...")
    try:
        cosyvoice = CosyVoiceAutoModel(
            model_dir=str(MODEL_DIR),
            load_trt=False,
            fp16=False
        )
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        import traceback
        traceback.print_exc()
        raise gr.Error(f"Failed to load model: {e}")
    print("✅ Model loaded successfully!")
    return cosyvoice
def postprocess(wav_path):
    """Trim silence, cap the peak level, and append a short silent tail.

    The file at ``wav_path`` is rewritten in place. Best-effort: any failure
    is reported as a warning and the original path is returned untouched.
    """
    try:
        # NOTE(review): assumes cosyvoice's load_wav accepts a min_sr
        # keyword — confirm against the pinned CosyVoice revision.
        wav = load_wav(wav_path, target_sr=target_sr, min_sr=16000)
        # Strip silence from both ends (thresholds match the official demo).
        wav, _ = librosa.effects.trim(
            wav, top_db=top_db,
            frame_length=win_length,
            hop_length=hop_length
        )
        # Rescale only when the peak exceeds the allowed maximum.
        peak = wav.abs().max()
        if peak > max_val:
            wav = wav / peak * max_val
        # Pad 0.2 s of silence at the end, then overwrite the source file.
        tail = torch.zeros(1, int(target_sr * 0.2))
        wav = torch.concat([wav, tail], dim=1)
        torchaudio.save(wav_path, wav, target_sr)
    except Exception as e:
        print(f"⚠️ Postprocess warning: {e}")
    return wav_path
def process_audio(audio_input):
    """Convert a Gradio ``(sample_rate, ndarray)`` tuple into a processed WAV.

    The audio is converted to mono float32, truncated to 30 s, resampled to
    ``target_sr``, written to a temporary file, and run through
    ``postprocess`` (silence trim + normalization).

    Args:
        audio_input: ``(sample_rate, numpy_array)`` from a gr.Audio component,
            or None when the user supplied nothing.

    Returns:
        str | None: path to the processed temporary WAV file, or None when
        the input is missing, shorter than 1 second, or processing fails.
    """
    if audio_input is None:
        return None
    try:
        sr, audio_data = audio_input
        print(f"📊 Input audio - shape: {audio_data.shape}, dtype: {audio_data.dtype}, sr: {sr}Hz")
        # Step 1: Normalize data type to float32 (integer PCM is scaled to [-1, 1])
        if audio_data.dtype == np.int16:
            audio_data = audio_data.astype(np.float32) / 32768.0
        elif audio_data.dtype == np.int32:
            audio_data = audio_data.astype(np.float32) / 2147483648.0
        elif audio_data.dtype != np.float32:
            # Covers float64 and any other dtype in a single branch.
            audio_data = audio_data.astype(np.float32)
        # Step 2: Convert stereo to mono if needed
        if audio_data.ndim == 2:
            print(f" Converting stereo ({audio_data.shape[1]} channels) to mono...")
            if audio_data.shape[1] == 2:
                audio_data = audio_data.mean(axis=1)
            elif audio_data.shape[1] == 1:
                audio_data = audio_data.squeeze()
            else:
                # Unexpected channel count: keep the first channel only.
                audio_data = audio_data[:, 0]
        # Step 3: Ensure 1D array
        audio_data = audio_data.flatten()
        # Step 4: Check and adjust duration
        duration = len(audio_data) / sr
        print(f" Duration: {duration:.2f}s")
        if duration < 1:
            # BUG FIX: this used to `return None, "..."` — a truthy 2-tuple
            # that callers (which expect path-or-None) then treated as a
            # valid file path. Report the problem and return plain None.
            print("❌ Audio too short (minimum 1 second)")
            return None
        if duration > 30:
            print(f" ⚠️ Truncating audio from {duration:.2f}s to 30s")
            audio_data = audio_data[:sr * 30]
        # Steps 5-6: Torch tensor with a channel dimension: (1, samples)
        audio_tensor = torch.from_numpy(audio_data).float()
        if audio_tensor.dim() == 1:
            audio_tensor = audio_tensor.unsqueeze(0)
        print(f" Tensor shape: {audio_tensor.shape}")
        # Step 7: Resample if needed
        if sr != target_sr:
            print(f" 🔄 Resampling from {sr}Hz to {target_sr}Hz...")
            resampler = T.Resample(sr, target_sr)
            audio_tensor = resampler(audio_tensor)
            sr = target_sr
        # Step 8: Save to a temporary file. mkstemp replaces the deprecated,
        # race-prone tempfile.mktemp; close the fd since torchaudio reopens
        # the file by path.
        fd, temp_path = tempfile.mkstemp(suffix='.wav')
        os.close(fd)
        torchaudio.save(temp_path, audio_tensor, sr)
        # Step 9: Post-process (trim silence, normalize)
        temp_path = postprocess(temp_path)
        print(f" ✅ Audio processed and saved: {os.path.basename(temp_path)}")
        return temp_path
    except Exception as e:
        print(f"❌ Error processing audio: {e}")
        import traceback
        traceback.print_exc()
        return None
def zero_shot_tts(tts_text, prompt_text, prompt_audio, seed, speed):
    """Zero-shot voice cloning: speak ``tts_text`` in the reference voice.

    Args:
        tts_text: text to synthesize (<= 200 characters).
        prompt_text: transcription of the reference audio.
        prompt_audio: Gradio ``(sample_rate, ndarray)`` tuple.
        seed: random seed for reproducible generation.
        speed: speech speed multiplier.

    Returns:
        ``((target_sr, ndarray), status)`` on success,
        ``(None, error_message)`` on any validation or runtime failure.
    """
    # Defined up front so the error-cleanup path below can never hit a
    # NameError (previously it relied on a bare except to swallow one).
    prompt_audio_path = None
    try:
        # Validation
        if not tts_text or not tts_text.strip():
            return None, "❌ Please provide text to synthesize"
        if len(tts_text) > 200:
            return None, "❌ Text too long, please keep within 200 characters"
        if not prompt_audio:
            return None, "❌ Please upload reference audio"
        if not prompt_text or not prompt_text.strip():
            return None, "❌ Please provide prompt text"
        # Load model (lazy singleton)
        model = load_model()
        # Convert the Gradio tuple into a processed WAV on disk
        prompt_audio_path = process_audio(prompt_audio)
        if prompt_audio_path is None:
            return None, "❌ Failed to process audio"
        # Check sample rate; remove the temp file on rejection (was leaked)
        info = torchaudio.info(prompt_audio_path)
        if info.sample_rate < prompt_sr:
            os.remove(prompt_audio_path)
            return None, f"❌ Audio sample rate {info.sample_rate} is below {prompt_sr}Hz"
        # Check duration; remove the temp file on rejection (was leaked)
        duration = info.num_frames / info.sample_rate
        if duration > 10:
            os.remove(prompt_audio_path)
            return None, "❌ Please keep prompt audio within 10 seconds"
        # Clean inputs
        tts_text = tts_text.strip()
        prompt_text = prompt_text.strip()
        # Build prompt following the official format:
        # system preamble + <|endofprompt|> + reference transcription.
        full_prompt = f"You are a helpful assistant.<|endofprompt|>{prompt_text}"
        print(f"\n🎵 Generating speech...")
        print(f" TTS text: '{tts_text[:100]}{'...' if len(tts_text) > 100 else ''}'")
        print(f" Prompt text: '{prompt_text[:50]}{'...' if len(prompt_text) > 50 else ''}'")
        print(f" Full prompt: '{full_prompt[:80]}{'...' if len(full_prompt) > 80 else ''}'")
        print(f" Seed: {seed}, Speed: {speed}")
        # Set random seed for reproducibility
        set_all_random_seed(seed)
        # Generate - following the official code; collect streamed segments
        speech_list = []
        for i in model.inference_zero_shot(
            tts_text,           # Text to synthesize
            full_prompt,        # Prompt with special format
            prompt_audio_path,  # Processed prompt audio
            stream=False,
            speed=speed
        ):
            speech_list.append(i["tts_speech"])
        # Concatenate all speech segments along the time axis
        output_speech = torch.concat(speech_list, dim=1)
        # Clean up the temporary prompt file
        if os.path.exists(prompt_audio_path):
            os.remove(prompt_audio_path)
        print(f" ✅ Generated audio shape: {output_speech.shape}")
        print("✅ Speech generated successfully!\n")
        # Return as numpy array for Gradio
        return (target_sr, output_speech.numpy().flatten()), "✅ Success!"
    except Exception as e:
        print(f"❌ Error in zero_shot_tts: {e}")
        import traceback
        traceback.print_exc()
        # Clean up on error; only the removal itself can fail here
        try:
            if prompt_audio_path and os.path.exists(prompt_audio_path):
                os.remove(prompt_audio_path)
        except OSError:
            pass
        return None, f"❌ Error: {str(e)}"
def instruct_tts(tts_text, instruct_text, prompt_audio, seed, speed):
    """Instruction-controlled TTS in the reference voice.

    Args:
        tts_text: text to synthesize (<= 200 characters).
        instruct_text: natural-language instruction (dialect/speed/tone).
        prompt_audio: Gradio ``(sample_rate, ndarray)`` tuple.
        seed: random seed for reproducible generation.
        speed: speech speed multiplier.

    Returns:
        ``((target_sr, ndarray), status)`` on success,
        ``(None, error_message)`` on any validation or runtime failure.
    """
    # Defined up front so the error-cleanup path below can never hit a
    # NameError (previously it relied on a bare except to swallow one).
    prompt_audio_path = None
    try:
        # Validation
        if not tts_text or not tts_text.strip():
            return None, "❌ Please provide text to synthesize"
        if len(tts_text) > 200:
            return None, "❌ Text too long, please keep within 200 characters"
        if not prompt_audio:
            return None, "❌ Please upload reference audio"
        if not instruct_text or not instruct_text.strip():
            return None, "❌ Please provide instruction text"
        # Load model (lazy singleton)
        model = load_model()
        # Convert the Gradio tuple into a processed WAV on disk
        prompt_audio_path = process_audio(prompt_audio)
        if prompt_audio_path is None:
            return None, "❌ Failed to process audio"
        # Clean inputs
        tts_text = tts_text.strip()
        instruct_text = instruct_text.strip()
        print(f"\n📝 Generating speech with instruction...")
        print(f" TTS text: '{tts_text[:100]}{'...' if len(tts_text) > 100 else ''}'")
        print(f" Instruction: '{instruct_text}'")
        print(f" Seed: {seed}, Speed: {speed}")
        # Set random seed for reproducibility
        set_all_random_seed(seed)
        # Generate - following the official code; collect streamed segments
        speech_list = []
        for i in model.inference_instruct2(
            tts_text,           # Text to synthesize
            instruct_text,      # Instruction
            prompt_audio_path,  # Processed prompt audio
            stream=False,
            speed=speed
        ):
            speech_list.append(i["tts_speech"])
        # Concatenate all speech segments along the time axis
        output_speech = torch.concat(speech_list, dim=1)
        # Clean up the temporary prompt file
        if os.path.exists(prompt_audio_path):
            os.remove(prompt_audio_path)
        print(f" ✅ Generated audio shape: {output_speech.shape}")
        print("✅ Speech generated successfully!\n")
        # Return as numpy array for Gradio
        return (target_sr, output_speech.numpy().flatten()), "✅ Success!"
    except Exception as e:
        print(f"❌ Error: {e}")
        import traceback
        traceback.print_exc()
        # Clean up on error; only the removal itself can fail here
        try:
            if prompt_audio_path and os.path.exists(prompt_audio_path):
                os.remove(prompt_audio_path)
        except OSError:
            pass
        return None, f"❌ Error: {str(e)}"
# Preset instruction prompts offered in the UI dropdown (from official code).
# Each string is a complete system prompt; the trailing <|endofprompt|>
# marker is part of the model's expected prompt format (see instruct_tts).
instruct_options = [
    "You are a helpful assistant. 请用广东话表达。<|endofprompt|>",
    "You are a helpful assistant. 请用尽可能快地语速说一句话。<|endofprompt|>",
    "You are a helpful assistant. 请用正常的语速说一句话。<|endofprompt|>",
    "You are a helpful assistant. 请用慢一点的语速说一句话。<|endofprompt|>",
    "You are a helpful assistant. Please speak in a professional tone.<|endofprompt|>",
    "You are a helpful assistant. Please speak in a friendly tone.<|endofprompt|>",
]
# Create Gradio interface: two tabs (zero-shot cloning, instruction control)
# wired to zero_shot_tts and instruct_tts respectively.
with gr.Blocks(title="Fun-CosyVoice3 TTS") as demo:
    gr.Markdown("""
    # 🎙️ Fun-CosyVoice3-0.5B Text-to-Speech
    Advanced multilingual zero-shot TTS system supporting **9 languages** and **18+ Chinese dialects**.
    Based on the official [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) implementation.
    """)
    with gr.Tabs():
        # Tab 1: Zero-Shot TTS (clone a voice from a short reference clip)
        with gr.Tab("🎯 Zero-Shot Voice Cloning (3s Fast Cloning)"):
            gr.Markdown("""
            ### Clone any voice with 3-10 seconds of reference audio
            **Steps:**
            1. Upload or record reference audio (≤30s, ≥16kHz)
            2. Enter the **prompt text** (transcription of the reference audio)
            3. Enter the **text to synthesize** (what you want the voice to say)
            4. Click Generate
            """)
            with gr.Row():
                with gr.Column():
                    zs_tts_text = gr.Textbox(
                        label="Text to synthesize (what will be spoken)",
                        placeholder="Enter the text you want to synthesize...",
                        lines=2,
                        value="Her handwriting is very neat, which suggests she likes things tidy."
                    )
                    # type="numpy" delivers (sample_rate, ndarray) to the handler,
                    # matching what process_audio expects.
                    zs_prompt_audio = gr.Audio(
                        label="Reference audio (upload or record)",
                        type="numpy"
                    )
                    zs_prompt_text = gr.Textbox(
                        label="Prompt text (transcription of reference audio)",
                        placeholder="Enter what is said in the reference audio...",
                        lines=2,
                        value=""
                    )
                    with gr.Row():
                        # precision=0 makes the Number component return an int seed
                        zs_seed = gr.Number(label="Random seed", value=0, precision=0)
                        zs_speed = gr.Slider(label="Speed", minimum=0.5, maximum=2.0, value=1.0, step=0.1)
                    zs_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")
                with gr.Column():
                    zs_output = gr.Audio(label="Generated speech")
                    zs_status = gr.Textbox(label="Status", interactive=False)
            # Wire the button to the zero-shot handler; input order must match
            # the zero_shot_tts(tts_text, prompt_text, prompt_audio, seed, speed) signature.
            zs_btn.click(
                fn=zero_shot_tts,
                inputs=[zs_tts_text, zs_prompt_text, zs_prompt_audio, zs_seed, zs_speed],
                outputs=[zs_output, zs_status]
            )
            gr.Markdown("""
            **Important:**
            - **Text to synthesize**: The new text you want to hear in the cloned voice
            - **Prompt text**: Transcription of what is said in your reference audio
            - **Reference audio**: 3-10 seconds of clear speech
            **Example:**
            - Reference audio: Someone saying "Hello, how are you?"
            - Prompt text: "Hello, how are you?"
            - Text to synthesize: "This is a test of voice cloning"
            - Result: "This is a test of voice cloning" in the cloned voice
            """)
        # Tab 2: Instruction-Based TTS (control dialect/speed/tone via text)
        with gr.Tab("📝 Instruction-Based Control (Natural Language)"):
            gr.Markdown("""
            ### Control voice characteristics with natural language instructions
            **Steps:**
            1. Upload or record reference audio
            2. Select or enter instruction (speed, dialect, emotion)
            3. Enter text to synthesize
            4. Click Generate
            """)
            with gr.Row():
                with gr.Column():
                    inst_tts_text = gr.Textbox(
                        label="Text to synthesize",
                        placeholder="Enter your text...",
                        lines=2,
                        value="Welcome to the natural language control demo."
                    )
                    inst_prompt_audio = gr.Audio(
                        label="Reference audio",
                        type="numpy"
                    )
                    # Preset instructions defined in instruct_options above
                    inst_text = gr.Dropdown(
                        label="Instruction",
                        choices=instruct_options,
                        value=instruct_options[0]
                    )
                    with gr.Row():
                        inst_seed = gr.Number(label="Random seed", value=0, precision=0)
                        inst_speed = gr.Slider(label="Speed", minimum=0.5, maximum=2.0, value=1.0, step=0.1)
                    inst_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")
                with gr.Column():
                    inst_output = gr.Audio(label="Generated speech")
                    inst_status = gr.Textbox(label="Status", interactive=False)
            # Wire the button to the instruction handler; input order must match
            # the instruct_tts(tts_text, instruct_text, prompt_audio, seed, speed) signature.
            inst_btn.click(
                fn=instruct_tts,
                inputs=[inst_tts_text, inst_text, inst_prompt_audio, inst_seed, inst_speed],
                outputs=[inst_output, inst_status]
            )
            gr.Markdown("""
            **Example instructions:**
            - "请用广东话表达" (Speak in Cantonese)
            - "请用尽可能快地语速说" (Speak as fast as possible)
            - "Please speak in a professional tone"
            """)
    gr.Markdown("""
    ---
    ### 📋 Supported Languages & Dialects
    **Languages:** Chinese, English, Japanese, Korean, German, Spanish, French, Italian, Russian
    **Chinese Dialects:** Guangdong, Minnan, Sichuan, Dongbei, Shanxi, Shanghai, Tianjin, Shandong, and more
    ### ⚡ Performance
    - Model: Fun-CosyVoice3-0.5B (500M parameters)
    - Sample Rate: 24kHz
    - Latency: ~5-10s on CPU, ~2-3s on GPU
    ### 📚 Resources
    [Paper](https://arxiv.org/abs/2505.17589) • [GitHub](https://github.com/FunAudioLLM/CosyVoice) • [Model](https://huggingface.co/FunAudioLLM/Fun-CosyVoice3-0.5B-2512)
    """)
# Script entry point: start the web UI when run directly (not on import).
if __name__ == "__main__":
    print("\n🚀 Launching Gradio interface...")
    # Bound the request queue so long syntheses don't pile up unbounded.
    demo.queue(max_size=10, default_concurrency_limit=2)
    demo.launch(
        server_name="0.0.0.0",  # listen on all interfaces (needed in containers/Spaces)
        server_port=7860,       # conventional Gradio/Spaces port
        show_error=True         # surface handler exceptions in the UI
    )