# voiceclone / main.py
# Source: Hugging Face Space (user Anni24181410), commit 50ba386 ("Update main.py").
import html
import os
import time
from datetime import datetime
from pathlib import Path

import gradio as gr
import torch
import torchaudio
from pydub import AudioSegment
from TTS.api import TTS
from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
device = "cuda" if torch.cuda.is_available() else "cpu"
OUTPUT_DIR = Path("outputs")
OUTPUT_DIR.mkdir(exist_ok=True)
HISTORY = []
tts = None
# PyTorch 2.6 fix - allowlist all XTTS classes
torch.serialization.add_safe_globals([
XttsConfig,
XttsAudioConfig,
BaseDatasetConfig,
BaseAudioConfig,
Xtts,
XttsArgs,
])
def get_tts():
global tts
if tts is None:
os.environ["COQUI_TOS_AGREED"] = "1"
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)
return tts
LANGUAGES = [
("English", "en"),
("Hindi", "hi"),
("French", "fr"),
("German", "de"),
("Spanish", "es"),
("Italian", "it"),
("Portuguese", "pt"),
("Chinese", "zh"),
("Japanese", "ja"),
("Korean", "ko"),
("Arabic", "ar"),
("Turkish", "tr"),
("Russian", "ru"),
("Dutch", "nl"),
("Polish", "pl"),
("Czech", "cs"),
("Hungarian", "hu"),
]
ALLOWED_FORMATS = [
".wav", ".mp3", ".flac", ".ogg",
".m4a", ".aac", ".wma", ".opus",
".mpeg", ".mp4"
]
def convert_to_wav(audio_path):
ext = Path(audio_path).suffix.lower()
if ext == ".wav":
return audio_path
try:
audio = AudioSegment.from_file(audio_path)
wav_path = str(Path(audio_path).with_suffix(".wav"))
audio.export(wav_path, format="wav")
return wav_path
except Exception as e:
raise Exception(f"Could not convert audio to WAV: {str(e)}")
def validate_audio(audio_path):
if audio_path is None:
return False, "Please upload a voice sample."
ext = Path(audio_path).suffix.lower()
if ext not in ALLOWED_FORMATS:
return False, f"Unsupported format '{ext}'. Allowed: WAV, MP3, FLAC, OGG, M4A, AAC, WMA, OPUS, MPEG, MP4"
# Try torchaudio first
try:
waveform, sample_rate = torchaudio.load(audio_path)
duration = waveform.shape[1] / sample_rate
if duration < 3:
return False, f"Audio too short ({duration:.1f}s). Minimum 3 seconds."
if duration > 30:
return False, f"Audio too long ({duration:.1f}s). Maximum 30 seconds."
return True, f"Audio valid | Format: {ext.upper()} | Duration: {duration:.1f}s"
except Exception:
pass
# Fallback to pydub
try:
audio = AudioSegment.from_file(audio_path)
duration = len(audio) / 1000
if duration < 3:
return False, f"Audio too short ({duration:.1f}s). Minimum 3 seconds."
if duration > 30:
return False, f"Audio too long ({duration:.1f}s). Maximum 30 seconds."
return True, f"Audio valid | Format: {ext.upper()} | Duration: {duration:.1f}s"
except Exception as e:
return False, f"Could not read audio file: {str(e)}"
def validate_text(text):
if not text or not text.strip():
return False, "Please enter some text."
if len(text.strip()) < 5:
return False, "Text too short."
if len(text.strip()) > 1000:
return False, "Text too long."
return True, "OK"
def get_history_html():
if not HISTORY:
return "<p style='color:gray;'>No clones yet.</p>"
rows = ""
for i, h in enumerate(reversed(HISTORY[-10:]), 1):
rows += (
f"<tr>"
f"<td style='padding:8px;'>{i}</td>"
f"<td style='padding:8px;'>{h['timestamp']}</td>"
f"<td style='padding:8px;'>{h['text']}</td>"
f"<td style='padding:8px;'>{h['language']}</td>"
f"<td style='padding:8px;'>{h['duration']}</td>"
f"<td style='padding:8px;'>{h['time_taken']}</td>"
f"</tr>"
)
return (
f"<table style='width:100%;border-collapse:collapse;font-size:13px;'>"
f"<thead><tr style='background:#6366f1;color:white;'>"
f"<th style='padding:8px;'>#</th>"
f"<th style='padding:8px;'>Time</th>"
f"<th style='padding:8px;'>Text</th>"
f"<th style='padding:8px;'>Language</th>"
f"<th style='padding:8px;'>Duration</th>"
f"<th style='padding:8px;'>Generated in</th>"
f"</tr></thead><tbody>{rows}</tbody></table>"
)
def clone_voice(text, speaker_audio, language, speed, progress=gr.Progress()):
start_time = time.time()
progress(0, desc="Validating...")
text_ok, text_msg = validate_text(text)
if not text_ok:
return None, text_msg, get_history_html()
audio_path = speaker_audio if isinstance(speaker_audio, str) else (
speaker_audio.name if speaker_audio is not None else None
)
audio_ok, audio_msg = validate_audio(audio_path)
if not audio_ok:
return None, audio_msg, get_history_html()
try:
progress(0.2, desc="Converting audio...")
speaker_wav = convert_to_wav(audio_path)
progress(0.3, desc="Loading model...")
model = get_tts()
progress(0.5, desc="Cloning voice...")
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_path = str(OUTPUT_DIR / f"clone_{timestamp}.wav")
model.tts_to_file(
text=text.strip(),
speaker_wav=speaker_wav,
language=language,
file_path=output_path,
speed=speed
)
progress(0.9, desc="Finalizing...")
elapsed = time.time() - start_time
waveform, sr = torchaudio.load(output_path)
output_duration = waveform.shape[1] / sr
lang_label = dict(LANGUAGES).get(language, language)
HISTORY.append({
"timestamp": datetime.now().strftime("%d %b %Y %I:%M %p"),
"text": text.strip()[:60] + "..." if len(text.strip()) > 60 else text.strip(),
"language": lang_label,
"duration": f"{output_duration:.1f}s",
"time_taken": f"{elapsed:.1f}s",
"path": output_path
})
progress(1.0, desc="Done!")
return output_path, f"Cloned in {elapsed:.1f}s | {output_duration:.1f}s output", get_history_html()
except Exception as e:
return None, f"Error: {str(e)}", get_history_html()
def count_chars(text):
return f"{len(text) if text else 0} / 1000"
def audio_info(audio_path):
if audio_path is None:
return "No audio uploaded."
path = audio_path if isinstance(audio_path, str) else audio_path.name
_, msg = validate_audio(path)
return msg
def clear_all():
return "", None, "en", 1.0, None, "Cleared.", get_history_html()
def clear_history():
HISTORY.clear()
return get_history_html()
def get_system_info():
info = f"Device: {device.upper()}\nModel: XTTS-v2\n"
if torch.cuda.is_available():
info += f"GPU: {torch.cuda.get_device_name(0)}\n"
info += f"Total clones: {len(HISTORY)}"
return info
with gr.Blocks(title="AI Voice Cloning Studio") as demo:
gr.Markdown("# AI Voice Cloning Studio\n### Powered by XTTS-v2")
with gr.Tabs():
with gr.Tab("Clone Voice"):
with gr.Row():
with gr.Column(scale=1):
text_input = gr.Textbox(
label="Text to speak",
placeholder="Type what you want the cloned voice to say...",
lines=6,
max_lines=10
)
char_display = gr.Textbox(
label="Character count",
value="0 / 1000",
interactive=False
)
audio_input = gr.File(
label="Voice sample (3 to 30 seconds)",
file_types=[
".wav", ".mp3", ".flac", ".ogg",
".m4a", ".aac", ".wma", ".opus",
".mpeg", ".mp4"
],
type="filepath"
)
audio_status = gr.Textbox(
label="Audio info",
interactive=False,
value="No audio uploaded."
)
with gr.Row():
language = gr.Dropdown(
choices=LANGUAGES,
value="en",
label="Language"
)
speed = gr.Slider(
minimum=0.5,
maximum=2.0,
value=1.0,
step=0.1,
label="Speed"
)
with gr.Row():
clone_btn = gr.Button("Clone Voice", variant="primary", size="lg", scale=3)
clear_btn = gr.Button("Clear", variant="secondary", size="lg", scale=1)
with gr.Column(scale=1):
audio_output = gr.Audio(
label="Cloned voice",
type="filepath",
interactive=False
)
status_msg = gr.Textbox(label="Status", interactive=False)
gr.Markdown(
"### Tips\n"
"- Use clear, noise-free audio\n"
"- 5 to 10 seconds works best\n"
"- Supported: WAV, MP3, FLAC, OGG, M4A, AAC, WMA, OPUS, MPEG, MP4\n"
"- One speaker only"
)
with gr.Tab("History"):
history_display = gr.HTML(value="<p style='color:gray;'>No clones yet.</p>")
clear_history_btn = gr.Button("Clear History", variant="stop")
with gr.Tab("System Info"):
system_info = gr.Textbox(
value=get_system_info(),
interactive=False,
lines=6,
label="System"
)
refresh_btn = gr.Button("Refresh", variant="secondary")
text_input.change(fn=count_chars, inputs=text_input, outputs=char_display)
audio_input.change(fn=audio_info, inputs=audio_input, outputs=audio_status)
clone_btn.click(
fn=clone_voice,
inputs=[text_input, audio_input, language, speed],
outputs=[audio_output, status_msg, history_display]
)
clear_btn.click(
fn=clear_all,
inputs=[],
outputs=[text_input, audio_input, language, speed, audio_output, status_msg, history_display]
)
clear_history_btn.click(fn=clear_history, inputs=[], outputs=[history_display])
refresh_btn.click(fn=get_system_info, inputs=[], outputs=[system_info])
demo.launch(
server_name="0.0.0.0",
server_port=7860,
theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="purple")
)