# app.py — Fun-CosyVoice3 TTS Gradio Space (author: Arrcttacsrks, commit 61f0f43, ~20.7 kB)
import gradio as gr
import sys
import os
import torch
import torchaudio
import torchaudio.transforms as T
import numpy as np
import tempfile
import librosa
from pathlib import Path
print("=" * 60)
print("🎙️ Fun-CosyVoice3 TTS Initialization")
print("=" * 60)

# Step 1: Resolve the directory layout every later setup step relies on.
print("\n📁 Step 1: Setting up directories...")
WORK_DIR = Path.cwd()
COSYVOICE_DIR = WORK_DIR / "CosyVoice"
MODEL_DIR = COSYVOICE_DIR / "pretrained_models" / "Fun-CosyVoice3-0.5B"
for _label, _path in (
    ("Working directory", WORK_DIR),
    ("CosyVoice directory", COSYVOICE_DIR),
    ("Model directory", MODEL_DIR),
):
    print(f"{_label}: {_path}")
# Step 2: Clone the CosyVoice repository (with submodules) on first run only.
if COSYVOICE_DIR.exists():
    print("\n✅ Step 2: CosyVoice repository already exists")
else:
    print("\n📥 Step 2: Cloning CosyVoice repository...")
    import subprocess

    try:
        # List-form argv (no shell) with --recursive so third_party/Matcha-TTS
        # comes along; check=True raises on a non-zero git exit code.
        subprocess.run(
            [
                'git', 'clone', '--recursive',
                'https://github.com/FunAudioLLM/CosyVoice.git',
                str(COSYVOICE_DIR),
            ],
            check=True,
        )
        print("✅ Repository cloned successfully")
    except Exception as e:
        print(f"❌ Failed to clone repository: {e}")
        raise
# Step 3: Fetch the Fun-CosyVoice3 model weights from the Hugging Face Hub on
# first run; later runs reuse the local copy under MODEL_DIR.
if not MODEL_DIR.exists():
    print("\n📥 Step 3: Downloading models (this may take a few minutes)...")
    from huggingface_hub import snapshot_download
    try:
        print("Downloading Fun-CosyVoice3-0.5B-2512...")
        # NOTE: the old local_dir_use_symlinks=False kwarg was dropped — it is
        # deprecated and ignored by huggingface_hub >= 0.23 (local_dir
        # downloads are plain files by default now).
        snapshot_download(
            'FunAudioLLM/Fun-CosyVoice3-0.5B-2512',
            local_dir=str(MODEL_DIR),
        )
        print("✅ Model downloaded successfully")
    except Exception as e:
        print(f"❌ Failed to download model: {e}")
        raise
else:
    print("\n✅ Step 3: Models already exist")
# Step 4: Optionally fetch the ttsfrd text-normalization resources. Failure is
# non-fatal: CosyVoice falls back to WeText normalization.
TTSFRD_DIR = COSYVOICE_DIR / "pretrained_models" / "CosyVoice-ttsfrd"
if not TTSFRD_DIR.exists():
    print("\n📥 Step 4: Downloading ttsfrd...")
    from huggingface_hub import snapshot_download
    try:
        # NOTE: deprecated local_dir_use_symlinks kwarg removed — ignored by
        # huggingface_hub >= 0.23.
        snapshot_download(
            'FunAudioLLM/CosyVoice-ttsfrd',
            local_dir=str(TTSFRD_DIR),
        )
        print("✅ ttsfrd downloaded successfully")
    except Exception as e:
        print(f"⚠️ Failed to download ttsfrd (will use WeText): {e}")
else:
    print("\n✅ Step 4: ttsfrd already exists")
# Step 5: Make the cloned repo (and its vendored Matcha-TTS copy) importable.
print("\n🔧 Step 5: Configuring Python path...")
_matcha_dir = COSYVOICE_DIR / "third_party" / "Matcha-TTS"
sys.path.insert(0, str(COSYVOICE_DIR))
sys.path.insert(0, str(_matcha_dir))
print(f"Added to path: {COSYVOICE_DIR}")
print(f"Added to path: {_matcha_dir}")

# Step 6: Import the CosyVoice API now that sys.path points at the checkout.
print("\n📦 Step 6: Importing CosyVoice...")
try:
    from cosyvoice.cli.cosyvoice import AutoModel as CosyVoiceAutoModel
    from cosyvoice.utils.file_utils import load_wav
    from cosyvoice.utils.common import set_all_random_seed
    print("✅ CosyVoice imported successfully")
except Exception as e:
    print(f"❌ Failed to import CosyVoice: {e}")
    raise

print("\n" + "=" * 60)
print("✅ Initialization completed successfully!")
print("=" * 60 + "\n")
# Global variables
cosyvoice = None      # lazily-created model instance; populated by load_model()
target_sr = 24000     # output sample rate used when saving/returning audio
prompt_sr = 16000     # minimum sample rate accepted for prompt audio
max_val = 0.8         # peak-normalization ceiling applied in postprocess()
top_db = 60           # librosa.effects.trim silence threshold (dB)
hop_length = 220      # librosa.effects.trim hop length (samples)
win_length = 440      # librosa.effects.trim frame length (samples)
def load_model():
    """Lazily load the shared CosyVoice model on first use.

    Returns:
        The module-level ``CosyVoiceAutoModel`` instance (created once,
        reused afterwards).

    Raises:
        gr.Error: if model construction fails; chained from the original
            exception so the real cause is preserved.
    """
    global cosyvoice
    if cosyvoice is None:
        print("🚀 Loading CosyVoice model...")
        try:
            cosyvoice = CosyVoiceAutoModel(
                model_dir=str(MODEL_DIR),
                load_trt=False,
                fp16=False
            )
            print("✅ Model loaded successfully!")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            import traceback
            traceback.print_exc()
            # FIX: chain the cause (`from e`) so the gr.Error does not sever
            # the original traceback.
            raise gr.Error(f"Failed to load model: {e}") from e
    return cosyvoice
def postprocess(wav_path):
    """Trim edge silence, peak-normalize, and pad 0.2 s of trailing silence.

    The wav at *wav_path* is rewritten in place (mirrors the official
    CosyVoice demo post-processing). On any failure the file is left as-is
    and the original path is returned — this is deliberately best-effort.
    """
    try:
        speech = load_wav(wav_path, target_sr=target_sr, min_sr=16000)
        # Strip silence at both ends of the clip.
        speech, _ = librosa.effects.trim(
            speech,
            top_db=top_db,
            frame_length=win_length,
            hop_length=hop_length,
        )
        # Scale down only when the peak exceeds the configured ceiling.
        peak = speech.abs().max()
        if peak > max_val:
            speech = speech / peak * max_val
        # Append 0.2 s of silence so playback doesn't end abruptly.
        tail = torch.zeros(1, int(target_sr * 0.2))
        speech = torch.concat([speech, tail], dim=1)
        torchaudio.save(wav_path, speech, target_sr)
        return wav_path
    except Exception as e:
        print(f"⚠️ Postprocess warning: {e}")
        return wav_path
def process_audio(audio_input):
    """Convert a Gradio ``(sample_rate, ndarray)`` tuple into a mono wav file.

    Normalizes dtype to float32 (int16/int32/float64 handled), down-mixes
    stereo to mono, truncates to 30 s, resamples to ``target_sr``, writes a
    temp wav, and runs ``postprocess`` on it.

    Args:
        audio_input: ``(sr, data)`` from a Gradio Audio component, or ``None``.

    Returns:
        Path to the processed temporary wav file, or ``None`` on any failure
        (including audio shorter than 1 second).
    """
    if audio_input is None:
        return None
    try:
        sr, audio_data = audio_input
        print(f"📊 Input audio - shape: {audio_data.shape}, dtype: {audio_data.dtype}, sr: {sr}Hz")
        # Step 1: Normalize data type to float32 (integer PCM is rescaled to
        # [-1, 1); any other dtype is simply cast).
        if audio_data.dtype == np.int16:
            audio_data = audio_data.astype(np.float32) / 32768.0
        elif audio_data.dtype == np.int32:
            audio_data = audio_data.astype(np.float32) / 2147483648.0
        elif audio_data.dtype != np.float32:
            # Covers float64 and anything else.
            audio_data = audio_data.astype(np.float32)
        # Step 2: Convert stereo to mono if needed
        if audio_data.ndim == 2:
            print(f" Converting stereo ({audio_data.shape[1]} channels) to mono...")
            if audio_data.shape[1] == 2:
                audio_data = audio_data.mean(axis=1)
            elif audio_data.shape[1] == 1:
                audio_data = audio_data.squeeze()
            else:
                # More than 2 channels: keep the first one.
                audio_data = audio_data[:, 0]
        # Step 3: Ensure 1D array
        audio_data = audio_data.flatten()
        # Step 4: Check and adjust duration
        duration = len(audio_data) / sr
        print(f" Duration: {duration:.2f}s")
        if duration < 1:
            # BUGFIX: this path used to return a (None, message) TUPLE, which
            # is truthy and not None, so callers passed it to torchaudio.info
            # and crashed. All failure paths now uniformly return None.
            print(" ❌ Audio too short (minimum 1 second)")
            return None
        if duration > 30:
            print(f" ⚠️ Truncating audio from {duration:.2f}s to 30s")
            audio_data = audio_data[:sr * 30]
        # Steps 5-6: torch tensor with a channel dimension -> (1, samples)
        audio_tensor = torch.from_numpy(audio_data).float()
        if audio_tensor.dim() == 1:
            audio_tensor = audio_tensor.unsqueeze(0)
        print(f" Tensor shape: {audio_tensor.shape}")
        # Step 7: Resample if needed
        if sr != target_sr:
            print(f" 🔄 Resampling from {sr}Hz to {target_sr}Hz...")
            resampler = T.Resample(sr, target_sr)
            audio_tensor = resampler(audio_tensor)
            sr = target_sr
        # Step 8: Save to a temporary file. tempfile.mktemp is deprecated and
        # race-prone; NamedTemporaryFile(delete=False) atomically reserves the
        # name and leaves the file for the caller to clean up.
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
            temp_path = tmp.name
        torchaudio.save(temp_path, audio_tensor, sr)
        # Step 9: Post-process (trim silence, normalize)
        temp_path = postprocess(temp_path)
        print(f" ✅ Audio processed and saved: {os.path.basename(temp_path)}")
        return temp_path
    except Exception as e:
        print(f"❌ Error processing audio: {e}")
        import traceback
        traceback.print_exc()
        return None
def zero_shot_tts(tts_text, prompt_text, prompt_audio, seed, speed):
    """Zero-shot voice-cloning TTS (official CosyVoice inference flow).

    Args:
        tts_text: text to synthesize (max 200 characters).
        prompt_text: transcription of the reference audio.
        prompt_audio: ``(sr, ndarray)`` tuple from a Gradio Audio component.
        seed: random seed for reproducible generation.
        speed: playback-speed multiplier.

    Returns:
        ``((sample_rate, waveform), status_message)``; the audio element is
        ``None`` when validation or synthesis fails.
    """
    # FIX: bound before the try so the cleanup below can never NameError when
    # an exception fires before process_audio() runs.
    prompt_audio_path = None
    try:
        # Validation
        if not tts_text or not tts_text.strip():
            return None, "❌ Please provide text to synthesize"
        if len(tts_text) > 200:
            return None, "❌ Text too long, please keep within 200 characters"
        if not prompt_audio:
            return None, "❌ Please upload reference audio"
        if not prompt_text or not prompt_text.strip():
            return None, "❌ Please provide prompt text"
        # Load model (lazy; raises gr.Error on failure)
        model = load_model()
        # Process audio into a temp wav
        prompt_audio_path = process_audio(prompt_audio)
        if prompt_audio_path is None:
            return None, "❌ Failed to process audio"
        # Check sample rate
        info = torchaudio.info(prompt_audio_path)
        if info.sample_rate < prompt_sr:
            return None, f"❌ Audio sample rate {info.sample_rate} is below {prompt_sr}Hz"
        # Check duration
        duration = info.num_frames / info.sample_rate
        if duration > 10:
            return None, "❌ Please keep prompt audio within 10 seconds"
        # Clean inputs
        tts_text = tts_text.strip()
        prompt_text = prompt_text.strip()
        # Build prompt following the official format (system preamble +
        # <|endofprompt|> sentinel + transcription).
        full_prompt = f"You are a helpful assistant.<|endofprompt|>{prompt_text}"
        print(f"\n🎵 Generating speech...")
        print(f" TTS text: '{tts_text[:100]}{'...' if len(tts_text) > 100 else ''}'")
        print(f" Prompt text: '{prompt_text[:50]}{'...' if len(prompt_text) > 50 else ''}'")
        print(f" Full prompt: '{full_prompt[:80]}{'...' if len(full_prompt) > 80 else ''}'")
        print(f" Seed: {seed}, Speed: {speed}")
        # Set random seed
        set_all_random_seed(seed)
        # Generate - following official code exactly (non-streaming, so the
        # generator may still yield several segments to concatenate)
        speech_list = []
        for chunk in model.inference_zero_shot(
            tts_text,           # Text to synthesize
            full_prompt,        # Prompt with special format
            prompt_audio_path,  # Processed prompt audio
            stream=False,
            speed=speed
        ):
            speech_list.append(chunk["tts_speech"])
        # Concatenate all speech segments along the time axis
        output_speech = torch.concat(speech_list, dim=1)
        print(f" ✅ Generated audio shape: {output_speech.shape}")
        print("✅ Speech generated successfully!\n")
        # Return as numpy array for Gradio
        return (target_sr, output_speech.numpy().flatten()), "✅ Success!"
    except Exception as e:
        print(f"❌ Error in zero_shot_tts: {e}")
        import traceback
        traceback.print_exc()
        return None, f"❌ Error: {str(e)}"
    finally:
        # FIX: cleanup now runs on EVERY exit path — the temp prompt file used
        # to leak on the sample-rate/duration early returns — and the bare
        # `except:` is narrowed to OSError (best-effort removal).
        if prompt_audio_path and os.path.exists(prompt_audio_path):
            try:
                os.remove(prompt_audio_path)
            except OSError:
                pass
def instruct_tts(tts_text, instruct_text, prompt_audio, seed, speed):
    """Instruction-controlled TTS (dialect / pace / tone via natural language).

    Args:
        tts_text: text to synthesize (max 200 characters).
        instruct_text: natural-language instruction (see ``instruct_options``).
        prompt_audio: ``(sr, ndarray)`` tuple from a Gradio Audio component.
        seed: random seed for reproducible generation.
        speed: playback-speed multiplier.

    Returns:
        ``((sample_rate, waveform), status_message)``; the audio element is
        ``None`` when validation or synthesis fails.
    """
    # FIX: bound before the try so the cleanup below can never NameError when
    # an exception fires before process_audio() runs.
    prompt_audio_path = None
    try:
        # Validation
        if not tts_text or not tts_text.strip():
            return None, "❌ Please provide text to synthesize"
        if len(tts_text) > 200:
            return None, "❌ Text too long, please keep within 200 characters"
        if not prompt_audio:
            return None, "❌ Please upload reference audio"
        if not instruct_text or not instruct_text.strip():
            return None, "❌ Please provide instruction text"
        # Load model (lazy; raises gr.Error on failure)
        model = load_model()
        # Process audio into a temp wav
        prompt_audio_path = process_audio(prompt_audio)
        if prompt_audio_path is None:
            return None, "❌ Failed to process audio"
        # Clean inputs
        tts_text = tts_text.strip()
        instruct_text = instruct_text.strip()
        print(f"\n📝 Generating speech with instruction...")
        print(f" TTS text: '{tts_text[:100]}{'...' if len(tts_text) > 100 else ''}'")
        print(f" Instruction: '{instruct_text}'")
        print(f" Seed: {seed}, Speed: {speed}")
        # Set random seed
        set_all_random_seed(seed)
        # Generate - following official code (non-streaming; concatenate any
        # yielded segments)
        speech_list = []
        for chunk in model.inference_instruct2(
            tts_text,           # Text to synthesize
            instruct_text,      # Instruction
            prompt_audio_path,  # Processed prompt audio
            stream=False,
            speed=speed
        ):
            speech_list.append(chunk["tts_speech"])
        # Concatenate all speech segments along the time axis
        output_speech = torch.concat(speech_list, dim=1)
        print(f" ✅ Generated audio shape: {output_speech.shape}")
        print("✅ Speech generated successfully!\n")
        # Return as numpy array for Gradio
        return (target_sr, output_speech.numpy().flatten()), "✅ Success!"
    except Exception as e:
        print(f"❌ Error: {e}")
        import traceback
        traceback.print_exc()
        return None, f"❌ Error: {str(e)}"
    finally:
        # FIX: cleanup now runs on EVERY exit path (the temp prompt file used
        # to leak on early validation returns) and the bare `except:` is
        # narrowed to OSError (best-effort removal).
        if prompt_audio_path and os.path.exists(prompt_audio_path):
            try:
                os.remove(prompt_audio_path)
            except OSError:
                pass
# Preset natural-language instructions (taken from the official demo). Each
# ends with the <|endofprompt|> sentinel the model expects.
instruct_options = [
    "You are a helpful assistant. 请用广东话表达。<|endofprompt|>",
    "You are a helpful assistant. 请用尽可能快地语速说一句话。<|endofprompt|>",
    "You are a helpful assistant. 请用正常的语速说一句话。<|endofprompt|>",
    "You are a helpful assistant. 请用慢一点的语速说一句话。<|endofprompt|>",
    "You are a helpful assistant. Please speak in a professional tone.<|endofprompt|>",
    "You are a helpful assistant. Please speak in a friendly tone.<|endofprompt|>",
]
# Build the Gradio interface: two tabs (zero-shot cloning, instruction
# control) plus a footer with capability notes.
with gr.Blocks(title="Fun-CosyVoice3 TTS") as demo:
    gr.Markdown("""
# 🎙️ Fun-CosyVoice3-0.5B Text-to-Speech
Advanced multilingual zero-shot TTS system supporting **9 languages** and **18+ Chinese dialects**.
Based on the official [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) implementation.
""")
    with gr.Tabs():
        # Tab 1: Zero-Shot TTS
        with gr.Tab("🎯 Zero-Shot Voice Cloning (3s Fast Cloning)"):
            gr.Markdown("""
### Clone any voice with 3-10 seconds of reference audio
**Steps:**
1. Upload or record reference audio (≤30s, ≥16kHz)
2. Enter the **prompt text** (transcription of the reference audio)
3. Enter the **text to synthesize** (what you want the voice to say)
4. Click Generate
""")
            with gr.Row():
                with gr.Column():
                    zs_tts_text = gr.Textbox(
                        label="Text to synthesize (what will be spoken)",
                        placeholder="Enter the text you want to synthesize...",
                        lines=2,
                        value="Her handwriting is very neat, which suggests she likes things tidy.",
                    )
                    zs_prompt_audio = gr.Audio(
                        label="Reference audio (upload or record)",
                        type="numpy",
                    )
                    zs_prompt_text = gr.Textbox(
                        label="Prompt text (transcription of reference audio)",
                        placeholder="Enter what is said in the reference audio...",
                        lines=2,
                        value="",
                    )
                    with gr.Row():
                        zs_seed = gr.Number(label="Random seed", value=0, precision=0)
                        zs_speed = gr.Slider(label="Speed", minimum=0.5, maximum=2.0, value=1.0, step=0.1)
                    zs_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")
                with gr.Column():
                    zs_output = gr.Audio(label="Generated speech")
                    zs_status = gr.Textbox(label="Status", interactive=False)
            # Wire the button to the zero-shot synthesis handler.
            zs_btn.click(
                fn=zero_shot_tts,
                inputs=[zs_tts_text, zs_prompt_text, zs_prompt_audio, zs_seed, zs_speed],
                outputs=[zs_output, zs_status],
            )
            gr.Markdown("""
**Important:**
- **Text to synthesize**: The new text you want to hear in the cloned voice
- **Prompt text**: Transcription of what is said in your reference audio
- **Reference audio**: 3-10 seconds of clear speech
**Example:**
- Reference audio: Someone saying "Hello, how are you?"
- Prompt text: "Hello, how are you?"
- Text to synthesize: "This is a test of voice cloning"
- Result: "This is a test of voice cloning" in the cloned voice
""")
        # Tab 2: Instruction-Based TTS
        with gr.Tab("📝 Instruction-Based Control (Natural Language)"):
            gr.Markdown("""
### Control voice characteristics with natural language instructions
**Steps:**
1. Upload or record reference audio
2. Select or enter instruction (speed, dialect, emotion)
3. Enter text to synthesize
4. Click Generate
""")
            with gr.Row():
                with gr.Column():
                    inst_tts_text = gr.Textbox(
                        label="Text to synthesize",
                        placeholder="Enter your text...",
                        lines=2,
                        value="Welcome to the natural language control demo.",
                    )
                    inst_prompt_audio = gr.Audio(
                        label="Reference audio",
                        type="numpy",
                    )
                    inst_text = gr.Dropdown(
                        label="Instruction",
                        choices=instruct_options,
                        value=instruct_options[0],
                    )
                    with gr.Row():
                        inst_seed = gr.Number(label="Random seed", value=0, precision=0)
                        inst_speed = gr.Slider(label="Speed", minimum=0.5, maximum=2.0, value=1.0, step=0.1)
                    inst_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")
                with gr.Column():
                    inst_output = gr.Audio(label="Generated speech")
                    inst_status = gr.Textbox(label="Status", interactive=False)
            # Wire the button to the instruction-based synthesis handler.
            inst_btn.click(
                fn=instruct_tts,
                inputs=[inst_tts_text, inst_text, inst_prompt_audio, inst_seed, inst_speed],
                outputs=[inst_output, inst_status],
            )
            gr.Markdown("""
**Example instructions:**
- "请用广东话表达" (Speak in Cantonese)
- "请用尽可能快地语速说" (Speak as fast as possible)
- "Please speak in a professional tone"
""")
    gr.Markdown("""
---
### 📋 Supported Languages & Dialects
**Languages:** Chinese, English, Japanese, Korean, German, Spanish, French, Italian, Russian
**Chinese Dialects:** Guangdong, Minnan, Sichuan, Dongbei, Shanxi, Shanghai, Tianjin, Shandong, and more
### ⚡ Performance
- Model: Fun-CosyVoice3-0.5B (500M parameters)
- Sample Rate: 24kHz
- Latency: ~5-10s on CPU, ~2-3s on GPU
### 📚 Resources
[Paper](https://arxiv.org/abs/2505.17589) • [GitHub](https://github.com/FunAudioLLM/CosyVoice) • [Model](https://huggingface.co/FunAudioLLM/Fun-CosyVoice3-0.5B-2512)
""")
# Script entry point: enable request queuing, then serve on all interfaces at
# the standard Hugging Face Space port.
if __name__ == "__main__":
    print("\n🚀 Launching Gradio interface...")
    demo.queue(max_size=10, default_concurrency_limit=2)
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)