Spaces:
Running
Running
| import os | |
| import time | |
| import torch | |
| import torchaudio | |
| import gradio as gr | |
| from TTS.api import TTS | |
| from TTS.tts.configs.xtts_config import XttsConfig | |
| from TTS.tts.models.xtts import XttsAudioConfig, Xtts, XttsArgs | |
| from TTS.config.shared_configs import BaseDatasetConfig, BaseAudioConfig | |
| from pathlib import Path | |
| from datetime import datetime | |
| from pydub import AudioSegment | |
# --- Module-level state ---

# Run inference on GPU when available; XTTS-v2 is very slow on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# All generated clips are written under ./outputs (created at import time).
OUTPUT_DIR = Path("outputs")
OUTPUT_DIR.mkdir(exist_ok=True)

# In-memory log of completed clones (list of dicts built in clone_voice).
# Not persisted: lost on process restart.
HISTORY = []

# Lazily-instantiated model singleton; populated on first call to get_tts().
tts = None
# PyTorch 2.6 fix - allowlist all XTTS classes.
# torch.load defaults to weights_only=True in PyTorch >= 2.6, which refuses to
# unpickle arbitrary classes; the Coqui XTTS checkpoint pickles these config /
# model classes, so they must be explicitly allowlisted before model load.
torch.serialization.add_safe_globals([
    XttsConfig,
    XttsAudioConfig,
    BaseDatasetConfig,
    BaseAudioConfig,
    Xtts,
    XttsArgs,
])
def get_tts():
    """Return the shared XTTS-v2 model, loading it on first use.

    Lazy initialisation keeps app start-up fast; the loaded model is cached
    in the module-level ``tts`` variable and reused for every request.
    """
    global tts
    if tts is not None:
        return tts
    # Accept the Coqui model license non-interactively (required on servers).
    os.environ["COQUI_TOS_AGREED"] = "1"
    tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)
    return tts
# (display label, ISO 639-1 code) pairs for the language dropdown.
# NOTE(review): these should mirror the languages supported by the installed
# XTTS-v2 checkpoint — confirm against the TTS package version in use.
LANGUAGES = [
    ("English", "en"),
    ("Hindi", "hi"),
    ("French", "fr"),
    ("German", "de"),
    ("Spanish", "es"),
    ("Italian", "it"),
    ("Portuguese", "pt"),
    ("Chinese", "zh"),
    ("Japanese", "ja"),
    ("Korean", "ko"),
    ("Arabic", "ar"),
    ("Turkish", "tr"),
    ("Russian", "ru"),
    ("Dutch", "nl"),
    ("Polish", "pl"),
    ("Czech", "cs"),
    ("Hungarian", "hu"),
]
# Upload extensions accepted by the UI (lower-case, with leading dot).
# Anything that is not already WAV is converted via pydub/ffmpeg before
# being handed to the model (see convert_to_wav).
ALLOWED_FORMATS = [
    ".wav", ".mp3", ".flac", ".ogg",
    ".m4a", ".aac", ".wma", ".opus",
    ".mpeg", ".mp4"
]
def convert_to_wav(audio_path):
    """Return a path to a WAV version of *audio_path*.

    WAV inputs are returned unchanged. Any other format is decoded with
    pydub (which shells out to ffmpeg) and re-exported as a WAV file next
    to the original, with the same stem.

    Raises:
        RuntimeError: if the file cannot be decoded or exported.
    """
    ext = Path(audio_path).suffix.lower()
    if ext == ".wav":
        return audio_path
    try:
        audio = AudioSegment.from_file(audio_path)
        wav_path = str(Path(audio_path).with_suffix(".wav"))
        audio.export(wav_path, format="wav")
        return wav_path
    except Exception as e:
        # Chain the original error instead of raising a bare generic
        # Exception. RuntimeError is still caught by the broad
        # `except Exception` boundary in clone_voice, so callers see the
        # same behaviour.
        raise RuntimeError(f"Could not convert audio to WAV: {str(e)}") from e
def _duration_verdict(duration, ext):
    """Map a clip duration in seconds to validate_audio's (ok, message) pair."""
    if duration < 3:
        return False, f"Audio too short ({duration:.1f}s). Minimum 3 seconds."
    if duration > 30:
        return False, f"Audio too long ({duration:.1f}s). Maximum 30 seconds."
    return True, f"Audio valid | Format: {ext.upper()} | Duration: {duration:.1f}s"


def validate_audio(audio_path):
    """Validate an uploaded voice sample.

    Checks that a file was provided, that its extension is in
    ALLOWED_FORMATS, and that its duration is within 3-30 seconds.
    Decoding is attempted with torchaudio first, then pydub as a fallback
    for containers torchaudio cannot read.

    Returns:
        (ok: bool, message: str) — message is user-facing status text.
    """
    if audio_path is None:
        return False, "Please upload a voice sample."
    ext = Path(audio_path).suffix.lower()
    if ext not in ALLOWED_FORMATS:
        return False, f"Unsupported format '{ext}'. Allowed: WAV, MP3, FLAC, OGG, M4A, AAC, WMA, OPUS, MPEG, MP4"
    # Try torchaudio first (fast, no ffmpeg subprocess).
    try:
        waveform, sample_rate = torchaudio.load(audio_path)
        return _duration_verdict(waveform.shape[1] / sample_rate, ext)
    except Exception:
        pass  # deliberate: fall through to the pydub decoder
    # Fallback to pydub (handles ffmpeg-only containers such as WMA/M4A).
    try:
        audio = AudioSegment.from_file(audio_path)
        return _duration_verdict(len(audio) / 1000, ext)
    except Exception as e:
        return False, f"Could not read audio file: {str(e)}"
def validate_text(text):
    """Check that the prompt text exists and is 5-1000 characters long.

    Returns:
        (ok: bool, message: str)
    """
    stripped = text.strip() if text else ""
    if not stripped:
        return False, "Please enter some text."
    if len(stripped) < 5:
        return False, "Text too short."
    if len(stripped) > 1000:
        return False, "Text too long."
    return True, "OK"
def get_history_html():
    """Render the last 10 HISTORY entries (newest first) as an HTML table.

    SECURITY FIX: every stored field is passed through html.escape before
    interpolation — `text` is raw user input, so the previous unescaped
    interpolation allowed HTML/script injection into the History tab.
    """
    import html  # stdlib; local import keeps this fix self-contained

    if not HISTORY:
        return "<p style='color:gray;'>No clones yet.</p>"
    rows = ""
    for i, h in enumerate(reversed(HISTORY[-10:]), 1):
        cells = (
            str(i),
            html.escape(str(h["timestamp"])),
            html.escape(str(h["text"])),
            html.escape(str(h["language"])),
            html.escape(str(h["duration"])),
            html.escape(str(h["time_taken"])),
        )
        rows += (
            "<tr>"
            + "".join(f"<td style='padding:8px;'>{c}</td>" for c in cells)
            + "</tr>"
        )
    return (
        f"<table style='width:100%;border-collapse:collapse;font-size:13px;'>"
        f"<thead><tr style='background:#6366f1;color:white;'>"
        f"<th style='padding:8px;'>#</th>"
        f"<th style='padding:8px;'>Time</th>"
        f"<th style='padding:8px;'>Text</th>"
        f"<th style='padding:8px;'>Language</th>"
        f"<th style='padding:8px;'>Duration</th>"
        f"<th style='padding:8px;'>Generated in</th>"
        f"</tr></thead><tbody>{rows}</tbody></table>"
    )
def clone_voice(text, speaker_audio, language, speed, progress=gr.Progress()):
    """Generate speech in the uploaded speaker's voice (main Gradio callback).

    Validates the prompt and the reference clip, converts the clip to WAV,
    runs XTTS-v2, records the result in HISTORY, and returns a 3-tuple of
    (output_path_or_None, status_message, history_html) matching the
    clone_btn.click outputs.

    NOTE(review): `progress=gr.Progress()` as a default argument is the
    documented Gradio idiom for progress tracking, not a mutable-default bug.
    """
    start_time = time.time()
    progress(0, desc="Validating...")
    text_ok, text_msg = validate_text(text)
    if not text_ok:
        return None, text_msg, get_history_html()
    # gr.File may hand us a plain path string or a tempfile-like object
    # exposing .name, depending on Gradio version/config.
    audio_path = speaker_audio if isinstance(speaker_audio, str) else (
        speaker_audio.name if speaker_audio is not None else None
    )
    audio_ok, audio_msg = validate_audio(audio_path)
    if not audio_ok:
        return None, audio_msg, get_history_html()
    try:
        progress(0.2, desc="Converting audio...")
        speaker_wav = convert_to_wav(audio_path)
        progress(0.3, desc="Loading model...")
        model = get_tts()
        progress(0.5, desc="Cloning voice...")
        # Timestamped filename keeps outputs unique per second.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_path = str(OUTPUT_DIR / f"clone_{timestamp}.wav")
        model.tts_to_file(
            text=text.strip(),
            speaker_wav=speaker_wav,
            language=language,
            file_path=output_path,
            speed=speed
        )
        progress(0.9, desc="Finalizing...")
        elapsed = time.time() - start_time
        # Reload the generated file to report its actual duration.
        waveform, sr = torchaudio.load(output_path)
        output_duration = waveform.shape[1] / sr
        lang_label = dict(LANGUAGES).get(language, language)
        HISTORY.append({
            "timestamp": datetime.now().strftime("%d %b %Y %I:%M %p"),
            # Truncate long prompts for display in the history table.
            "text": text.strip()[:60] + "..." if len(text.strip()) > 60 else text.strip(),
            "language": lang_label,
            "duration": f"{output_duration:.1f}s",
            "time_taken": f"{elapsed:.1f}s",
            "path": output_path
        })
        progress(1.0, desc="Done!")
        return output_path, f"Cloned in {elapsed:.1f}s | {output_duration:.1f}s output", get_history_html()
    except Exception as e:
        # Broad catch at the UI boundary: surface the error in the status
        # textbox instead of crashing the Gradio worker.
        return None, f"Error: {str(e)}", get_history_html()
def count_chars(text):
    """Live character counter shown under the prompt box ("N / 1000")."""
    length = len(text) if text else 0
    return f"{length} / 1000"
def audio_info(audio_path):
    """Describe the uploaded sample (format/duration) or report why it failed."""
    if audio_path is None:
        return "No audio uploaded."
    if not isinstance(audio_path, str):
        # gr.File may deliver a tempfile-like object exposing .name.
        audio_path = audio_path.name
    _, message = validate_audio(audio_path)
    return message
def clear_all():
    """Reset every Clone-tab widget to its initial state."""
    return (
        "",                  # text_input
        None,                # audio_input
        "en",                # language dropdown
        1.0,                 # speed slider
        None,                # audio_output
        "Cleared.",          # status_msg
        get_history_html(),  # history_display (history itself is untouched)
    )
def clear_history():
    """Drop all HISTORY entries in place and return the refreshed table."""
    del HISTORY[:]
    return get_history_html()
def get_system_info():
    """Summarise the runtime environment and usage for the System Info tab."""
    lines = [f"Device: {device.upper()}", "Model: XTTS-v2"]
    if torch.cuda.is_available():
        lines.append(f"GPU: {torch.cuda.get_device_name(0)}")
    lines.append(f"Total clones: {len(HISTORY)}")
    return "\n".join(lines)
# --- Gradio UI --------------------------------------------------------------
# BUG FIX: `theme` is a gr.Blocks() constructor argument, not a launch()
# argument. Passing theme= to demo.launch() raises TypeError at startup, so
# the theme is moved into the Blocks constructor and removed from launch().
with gr.Blocks(
    title="AI Voice Cloning Studio",
    theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="purple"),
) as demo:
    gr.Markdown("# AI Voice Cloning Studio\n### Powered by XTTS-v2")
    with gr.Tabs():
        with gr.Tab("Clone Voice"):
            with gr.Row():
                with gr.Column(scale=1):
                    text_input = gr.Textbox(
                        label="Text to speak",
                        placeholder="Type what you want the cloned voice to say...",
                        lines=6,
                        max_lines=10
                    )
                    char_display = gr.Textbox(
                        label="Character count",
                        value="0 / 1000",
                        interactive=False
                    )
                    audio_input = gr.File(
                        label="Voice sample (3 to 30 seconds)",
                        file_types=[
                            ".wav", ".mp3", ".flac", ".ogg",
                            ".m4a", ".aac", ".wma", ".opus",
                            ".mpeg", ".mp4"
                        ],
                        type="filepath"
                    )
                    audio_status = gr.Textbox(
                        label="Audio info",
                        interactive=False,
                        value="No audio uploaded."
                    )
                    with gr.Row():
                        language = gr.Dropdown(
                            choices=LANGUAGES,
                            value="en",
                            label="Language"
                        )
                        speed = gr.Slider(
                            minimum=0.5,
                            maximum=2.0,
                            value=1.0,
                            step=0.1,
                            label="Speed"
                        )
                    with gr.Row():
                        clone_btn = gr.Button("Clone Voice", variant="primary", size="lg", scale=3)
                        clear_btn = gr.Button("Clear", variant="secondary", size="lg", scale=1)
                with gr.Column(scale=1):
                    audio_output = gr.Audio(
                        label="Cloned voice",
                        type="filepath",
                        interactive=False
                    )
                    status_msg = gr.Textbox(label="Status", interactive=False)
                    gr.Markdown(
                        "### Tips\n"
                        "- Use clear, noise-free audio\n"
                        "- 5 to 10 seconds works best\n"
                        "- Supported: WAV, MP3, FLAC, OGG, M4A, AAC, WMA, OPUS, MPEG, MP4\n"
                        "- One speaker only"
                    )
        with gr.Tab("History"):
            history_display = gr.HTML(value="<p style='color:gray;'>No clones yet.</p>")
            clear_history_btn = gr.Button("Clear History", variant="stop")
        with gr.Tab("System Info"):
            system_info = gr.Textbox(
                value=get_system_info(),
                interactive=False,
                lines=6,
                label="System"
            )
            refresh_btn = gr.Button("Refresh", variant="secondary")

    # Event wiring: live counters, the main clone action, and resets.
    text_input.change(fn=count_chars, inputs=text_input, outputs=char_display)
    audio_input.change(fn=audio_info, inputs=audio_input, outputs=audio_status)
    clone_btn.click(
        fn=clone_voice,
        inputs=[text_input, audio_input, language, speed],
        outputs=[audio_output, status_msg, history_display]
    )
    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[text_input, audio_input, language, speed, audio_output, status_msg, history_display]
    )
    clear_history_btn.click(fn=clear_history, inputs=[], outputs=[history_display])
    refresh_btn.click(fn=get_system_info, inputs=[], outputs=[system_info])

if __name__ == "__main__":
    # 0.0.0.0:7860 is the standard bind for containerized/Spaces deployment.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )