# voiceclone / main.py
# Uploaded by Anni24181410; last commit c0afaa3 ("Update main.py"), verified; file size 7.39 kB.
import os
import time
import torch
import torchaudio
import gradio as gr
from TTS.api import TTS
from pathlib import Path
from datetime import datetime
# Pick the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Generated clips are written here, relative to the current working
# directory. Nothing in this file deletes old clips.
OUTPUT_DIR = Path("outputs")
OUTPUT_DIR.mkdir(exist_ok=True)

# Record of successful clones. It is a module-level list, so every browser
# session of this process appends to and reads from the same list.
HISTORY = []

# Load the XTTS-v2 model once, at import time, and move it to `device`.
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)
# (display label, language code) pairs offered in the Language dropdown.
# The dropdown's default value is "en", so the second element of each pair
# is the value handed to clone_voice.
LANGUAGES = [
    # Western European
    ("English", "en"),
    ("Hindi", "hi"),
    ("French", "fr"),
    ("German", "de"),
    ("Spanish", "es"),
    ("Italian", "it"),
    ("Portuguese", "pt"),
    # East Asian
    ("Chinese", "zh"),
    ("Japanese", "ja"),
    ("Korean", "ko"),
    # Middle East / Eastern Europe
    ("Arabic", "ar"),
    ("Turkish", "tr"),
    ("Russian", "ru"),
    ("Dutch", "nl"),
    ("Polish", "pl"),
    ("Czech", "cs"),
    ("Hungarian", "hu"),
]
def validate_audio(audio_path):
    """Check that a voice sample is present, loadable, and 3-30 seconds long.

    Args:
        audio_path: Path to the uploaded/recorded sample, or None.

    Returns:
        A ``(is_valid, message)`` tuple; the message is shown to the user.
        Any error raised while loading the file is reported as invalid audio
        rather than propagated.
    """
    if audio_path is None:
        return False, "Please upload a voice sample."
    try:
        samples, rate = torchaudio.load(audio_path)
        seconds = samples.shape[1] / rate
        if seconds < 3:
            verdict = (False, f"Audio too short ({seconds:.1f}s). Minimum 3 seconds.")
        elif seconds > 30:
            verdict = (False, f"Audio too long ({seconds:.1f}s). Maximum 30 seconds.")
        else:
            verdict = (True, f"Audio valid ({seconds:.1f}s)")
        return verdict
    except Exception as err:
        return False, f"Invalid audio: {str(err)}"
def validate_text(text):
    """Check that the text to speak is 5-1000 characters after trimming.

    Args:
        text: Raw textbox contents (may be None or empty).

    Returns:
        A ``(is_valid, message)`` tuple; the message is "OK" on success.
    """
    cleaned = (text or "").strip()
    length = len(cleaned)
    if length == 0:
        return False, "Please enter some text."
    if length < 5:
        return False, "Text too short."
    if length > 1000:
        return False, "Text too long."
    return True, "OK"
def get_history_html(history=None):
    """Render the ten most recent clone records as an HTML table.

    Args:
        history: Optional list of record dicts to render. Defaults to the
            module-level HISTORY list. Each record must provide the keys
            'timestamp', 'text', 'language', 'duration' and 'time_taken'.

    Returns:
        A grey "No clones yet." paragraph when there are no records,
        otherwise an HTML table whose first row (numbered 1) is the newest
        record. All cell values are HTML-escaped.
    """
    import html  # local import keeps this function self-contained

    records = HISTORY if history is None else history
    if not records:
        return "<p style='color:gray;'>No clones yet.</p>"

    def cell(value):
        # 'text' is user input; without escaping, markup such as <script> or
        # a stray '<' would be injected into the page / break the table.
        return f"<td style='padding:8px;'>{html.escape(str(value))}</td>"

    columns = ("timestamp", "text", "language", "duration", "time_taken")
    rows = "".join(
        "<tr>" + cell(index) + "".join(cell(record[key]) for key in columns) + "</tr>"
        for index, record in enumerate(reversed(records[-10:]), 1)
    )
    return f"<table style='width:100%;border-collapse:collapse;font-size:13px;'><thead><tr style='background:#6366f1;color:white;'><th style='padding:8px;'>#</th><th style='padding:8px;'>Time</th><th style='padding:8px;'>Text</th><th style='padding:8px;'>Language</th><th style='padding:8px;'>Duration</th><th style='padding:8px;'>Generated in</th></tr></thead><tbody>{rows}</tbody></table>"
def clone_voice(text, speaker_audio, language, speed, progress=gr.Progress()):
    """Synthesize *text* in the voice from *speaker_audio* using XTTS-v2.

    Gradio handler for the "Clone Voice" button. Inputs are validated first.
    On success, the clip is written under OUTPUT_DIR and a record is appended
    to HISTORY.

    Args:
        text: Text for the cloned voice to speak (checked by validate_text).
        speaker_audio: File path of the reference voice sample.
        language: Language code selected in the dropdown, e.g. "en".
        speed: Speed value forwarded to ``tts.tts_to_file``.
        progress: Gradio progress tracker; Gradio supplies it at call time.

    Returns:
        Tuple ``(output_path, status_message, history_html)``.
        ``output_path`` is None when validation or synthesis fails; errors
        are reported through ``status_message`` instead of being raised.
    """
    start_time = time.time()
    progress(0, desc="Validating...")
    text_ok, text_msg = validate_text(text)
    if not text_ok:
        return None, text_msg, get_history_html()
    audio_ok, audio_msg = validate_audio(speaker_audio)
    if not audio_ok:
        return None, audio_msg, get_history_html()
    clean_text = text.strip()
    try:
        progress(0.3, desc="Cloning voice...")
        # Include microseconds. With whole-second names, two generations
        # started in the same second wrote to the same file: the earlier clip
        # was overwritten and two HISTORY records pointed at one path.
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        output_path = str(OUTPUT_DIR / f"clone_{stamp}.wav")
        tts.tts_to_file(
            text=clean_text,
            speaker_wav=speaker_audio,
            language=language,
            file_path=output_path,
            speed=speed,
        )
        progress(0.9, desc="Finalizing...")
        elapsed = time.time() - start_time
        waveform, sr = torchaudio.load(output_path)
        output_duration = waveform.shape[1] / sr
        # LANGUAGES holds (label, code) pairs and `language` is a code, so the
        # lookup must go code -> label. dict(LANGUAGES) maps label -> code and
        # never matched, which left the raw code in the history table.
        code_to_label = {code: label for label, code in LANGUAGES}
        lang_label = code_to_label.get(language, language)
        preview = clean_text[:60] + "..." if len(clean_text) > 60 else clean_text
        HISTORY.append({
            "timestamp": datetime.now().strftime("%d %b %Y %I:%M %p"),
            "text": preview,
            "language": lang_label,
            "duration": f"{output_duration:.1f}s",
            "time_taken": f"{elapsed:.1f}s",
            "path": output_path,
        })
        progress(1.0, desc="Done!")
        return output_path, f"Cloned in {elapsed:.1f}s | {output_duration:.1f}s output", get_history_html()
    except Exception as e:
        # UI boundary: surface any synthesis/IO failure as a status message.
        return None, f"Error: {str(e)}", get_history_html()
def count_chars(text):
    """Return the character counter label, e.g. "42 / 1000".

    The count is of the raw textbox contents (whitespace included); None or
    empty input counts as 0.
    """
    return "{} / 1000".format(len(text or ""))
def audio_info(audio_path):
    """Return the status line shown under the voice-sample input.

    Delegates to validate_audio and returns only its message.
    """
    if audio_path is not None:
        _valid, message = validate_audio(audio_path)
        return message
    return "No audio uploaded."
def clear_all():
    """Reset the Clone Voice tab to its initial state.

    Returns, in order, the values for: text, voice sample, language ("en"),
    speed (1.0), cloned audio, status message, and the current history HTML.
    """
    history_html = get_history_html()
    return ("", None, "en", 1.0, None, "Cleared.", history_html)
def clear_history():
    """Empty the in-memory HISTORY list and return the refreshed history HTML.

    Clip files already written to OUTPUT_DIR are not removed.
    """
    del HISTORY[:]
    return get_history_html()
def get_system_info():
    """Return a multi-line summary of device, model, GPU name and clone count."""
    lines = [f"Device: {device.upper()}", "Model: XTTS-v2"]
    if torch.cuda.is_available():
        lines.append(f"GPU: {torch.cuda.get_device_name(0)}")
    lines.append(f"Total clones: {len(HISTORY)}")
    return "\n".join(lines)
# Build the Gradio UI. Inside each layout context (Row/Column/Tab), the order
# in which components are created determines their on-screen placement, so
# the statements below are order-sensitive.
with gr.Blocks(title="AI Voice Cloning Studio", theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="purple")) as demo:
    gr.Markdown("# AI Voice Cloning Studio\n### Powered by XTTS-v2")
    with gr.Tabs():
        with gr.Tab("Clone Voice"):
            with gr.Row():
                # Left column: inputs and actions.
                with gr.Column(scale=1):
                    text_input = gr.Textbox(label="Text to speak", placeholder="Type what you want the cloned voice to say...", lines=6)
                    char_display = gr.Textbox(label="Character count", value="0 / 1000", interactive=False)
                    audio_input = gr.Audio(label="Voice sample (3 to 30 seconds)", type="filepath", sources=["upload", "microphone"])
                    audio_status = gr.Textbox(label="Audio info", interactive=False, value="No audio uploaded.")
                    with gr.Row():
                        language = gr.Dropdown(choices=LANGUAGES, value="en", label="Language")
                        speed = gr.Slider(minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Speed")
                    with gr.Row():
                        clone_btn = gr.Button("Clone Voice", variant="primary", size="lg", scale=3)
                        clear_btn = gr.Button("Clear", variant="secondary", size="lg", scale=1)
                # Right column: result, status and tips.
                with gr.Column(scale=1):
                    audio_output = gr.Audio(label="Cloned voice", type="filepath", interactive=False)
                    status_msg = gr.Textbox(label="Status", interactive=False)
                    gr.Markdown("### Tips\n- Clear noise free audio\n- 5 to 10 seconds works best\n- WAV or MP3 supported\n- One speaker only")
        with gr.Tab("History"):
            # Render via the same helper used by the event handlers instead
            # of duplicating the "No clones yet." markup here.
            history_display = gr.HTML(value=get_history_html())
            clear_history_btn = gr.Button("Clear History", variant="stop")
        with gr.Tab("System Info"):
            system_info = gr.Textbox(value=get_system_info(), interactive=False, lines=6, label="System")
            refresh_btn = gr.Button("Refresh", variant="secondary")

    # Event wiring (must be attached inside the Blocks context).
    text_input.change(fn=count_chars, inputs=text_input, outputs=char_display)
    audio_input.change(fn=audio_info, inputs=audio_input, outputs=audio_status)
    clone_btn.click(fn=clone_voice, inputs=[text_input, audio_input, language, speed], outputs=[audio_output, status_msg, history_display])
    clear_btn.click(fn=clear_all, inputs=[], outputs=[text_input, audio_input, language, speed, audio_output, status_msg, history_display])
    clear_history_btn.click(fn=clear_history, inputs=[], outputs=[history_display])
    refresh_btn.click(fn=get_system_info, inputs=[], outputs=[system_info])

# Only start the server when run as a script. Importing the module (e.g. by
# Gradio's reload mode or a test) should not launch a second server.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)