# VoiceApp / app.py
# Hugging Face Space by mujahid1214 — "Update app.py", commit 388d00a (verified)
import os

import torch
import gradio as gr

from openvoice import se_extractor
from openvoice.api import BaseSpeakerTTS, ToneColorConverter

# Auto-download checkpoints on first run (from the OpenVoice repo).
# NOTE(review): assumes download_checkpoints_v2() populates ./checkpoints with
# base_speakers/{EN,ZH} and converter — confirm against openvoice.utils.
if not os.path.isdir("checkpoints"):
    from openvoice.utils import download_checkpoints_v2

    print("Downloading OpenVoice V2 checkpoints (~1.5GB)...")
    download_checkpoints_v2()
# ------------------- Setup -------------------
# Run on GPU when available; every model and embedding below is loaded onto
# this same device so tensors never need cross-device copies.
device = "cuda" if torch.cuda.is_available() else "cpu"
# All generated wav files are written under this directory.
output_dir = 'outputs'
os.makedirs(output_dir, exist_ok=True)
# English Base Speaker (V1 style): produces the raw "source" speech that the
# tone color converter later re-voices.
ckpt_base_en = 'checkpoints/base_speakers/EN'
base_speaker_tts_en = BaseSpeakerTTS(f'{ckpt_base_en}/config.json', device=device)
base_speaker_tts_en.load_ckpt(f'{ckpt_base_en}/checkpoint.pth')
# Chinese Base Speaker — same role as the English one, for Chinese text.
ckpt_base_zh = 'checkpoints/base_speakers/ZH'
base_speaker_tts_zh = BaseSpeakerTTS(f'{ckpt_base_zh}/config.json', device=device)
base_speaker_tts_zh.load_ckpt(f'{ckpt_base_zh}/checkpoint.pth')
# Tone Color Converter (shared by both languages): maps base-speaker audio to
# the timbre of the user-supplied reference voice.
ckpt_converter = 'checkpoints/converter'
tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)
tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')
# Pre-load source style embeddings (speaker embeddings of the base voices).
# NOTE(review): en_style_se.pth is used here as the "whispering" style
# embedding — confirm against the checkpoint contents.
source_se_en_default = torch.load(f'{ckpt_base_en}/en_default_se.pth', map_location=device)
source_se_en_whisper = torch.load(f'{ckpt_base_en}/en_style_se.pth', map_location=device)
source_se_zh_default = torch.load(f'{ckpt_base_zh}/zh_default_se.pth', map_location=device)
# ------------------- Main Function -------------------
def voice_clone(
    reference_audio,
    text,
    language="English",
    style="default",
    speed=1.0,
):
    """Clone the voice in *reference_audio* and speak *text* with it.

    Pipeline: extract the target speaker embedding from the reference clip,
    synthesize base speech with the language's built-in speaker, then convert
    the base audio's tone color to the target voice.

    Args:
        reference_audio: Filesystem path to the uploaded clip (Gradio
            ``type="filepath"``), or ``None`` when nothing was uploaded.
        text: Text to synthesize; blank/None input is rejected.
        language: "English" or "Chinese" — selects the base speaker model.
        style: "default" or "whispering" (English only; forced to "default"
            for Chinese, which ships no extra styles).
        speed: Base-TTS speed multiplier.

    Returns:
        ``(wav_path, status_message)`` — ``wav_path`` is ``None`` when
        validation fails.
    """
    # Guard clauses: fail fast with a user-readable status instead of a trace.
    if reference_audio is None:
        return None, "Please upload a reference voice (5-30 seconds)"
    # `not text` also catches None, which `.strip()` alone would crash on.
    if not text or not text.strip():
        return None, "Please enter some text"

    # Extract the target speaker embedding (tone color) from the reference.
    target_se, _ = se_extractor.get_se(
        reference_audio, tone_color_converter, target_dir='processed', vad=True
    )

    # Pick the base TTS model and the matching source style embedding.
    if language == "English":
        tts = base_speaker_tts_en
        if style == "whispering":
            source_se = source_se_en_whisper
            speed = 0.9  # whisper style sounds more natural slightly slowed
        else:
            source_se = source_se_en_default
    else:  # Chinese
        tts = base_speaker_tts_zh
        source_se = source_se_zh_default
        style = "default"  # only the default style exists for Chinese

    # Generate base speech with the built-in speaker voice...
    src_path = f"{output_dir}/tmp.wav"
    tts.tts(text, src_path, speaker=style, language=language, speed=speed)

    # ...then convert its tone color to the cloned target voice.
    save_path = f"{output_dir}/output_cloned.wav"
    encode_message = "@MyShell"  # watermark message embedded in the output
    tone_color_converter.convert(
        audio_src_path=src_path,
        src_se=source_se,
        tgt_se=target_se,
        output_path=save_path,
        message=encode_message,
    )
    return save_path, f"Success! Cloned in {language} ({style})"
# ------------------- Gradio Interface -------------------
with gr.Blocks(title="OpenVoice Voice Style Control Demo") as demo:
    gr.Markdown("# OpenVoice Voice Style Control Demo")
    gr.Markdown("Upload any voice → Choose language & style → Generate cloned speech instantly!")

    with gr.Row():
        ref_audio = gr.Audio(
            label="Reference Voice (5-30s, clear speech)",
            type="filepath",  # voice_clone expects a path, not raw samples
            sources=["upload"],
        )
    with gr.Row():
        text_input = gr.Textbox(
            label="Text to Speak",
            value="This audio is generated by OpenVoice.",
            lines=3,
        )
    with gr.Row():
        language = gr.Dropdown(
            ["English", "Chinese"],
            value="English",
            label="Language",
        )
        style = gr.Dropdown(
            ["default", "whispering"],
            value="default",
            label="Style (English only)",
        )

    generate_btn = gr.Button("Generate Cloned Voice", variant="primary")

    with gr.Row():
        output_audio = gr.Audio(label="Cloned Output")
        status = gr.Textbox(label="Status")

    # `speed` is not wired to a UI control, so voice_clone's default is used.
    generate_btn.click(
        fn=voice_clone,
        inputs=[ref_audio, text_input, language, style],
        outputs=[output_audio, status],
    )

    gr.Markdown("""
    **Tech for good**: All outputs contain @MyShell watermark.
    Made with [OpenVoice by MyShell.ai](https://github.com/myshell-ai/OpenVoice)
    """)

# Guard the launch so importing this module (e.g. for tests) doesn't start a
# server; Spaces executes app.py as __main__, so behavior there is unchanged.
if __name__ == "__main__":
    demo.launch()