""" LongCat-AudioDiT Enhanced – Gradio Web UI Primary workflow: Voice Cloning 1. Upload reference audio → auto-transcribe with Whisper 2. Type text to synthesise in the cloned voice 3. Generate → save to Voice Library with a name 4. Reuse any saved voice from the dropdown All actions are exposed as Gradio REST API endpoints. Usage: python app.py python app.py --port 7860 --share python app.py --device cpu """ import argparse import logging import os import socket import time from pathlib import Path import gradio as gr import numpy as np import soundfile as sf import torch import torch.nn.functional as F from utils import normalize_text, load_audio, approx_duration_from_text from memory_manager import ModelMemoryManager from voice_library import get_library from download_models import ( download_audiodit, download_whisper, _audiodit_present, _whisper_present, AUDIODIT_MODELS, WHISPER_MODELS, AUDIODIT_DIR, WHISPER_DIR, ) logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") logger = logging.getLogger(__name__) OUTPUT_DIR = Path(__file__).parent / "outputs" OUTPUT_DIR.mkdir(exist_ok=True) # --------------------------------------------------------------------------- # Memory manager # --------------------------------------------------------------------------- _mgr: ModelMemoryManager = None def get_manager(mode: str = "auto") -> ModelMemoryManager: global _mgr if _mgr is None or _mgr.mode.value != mode: if _mgr is not None: _mgr.release_all() _mgr = ModelMemoryManager(mode=mode) return _mgr # --------------------------------------------------------------------------- # Port helpers # --------------------------------------------------------------------------- def _port_free(port: int) -> bool: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.settimeout(1) return s.connect_ex(("127.0.0.1", port)) != 0 def find_free_port(start: int = 7860, end: int = 7960) -> int: for p in range(start, end): if _port_free(p): return p raise RuntimeError(f"No free port found in {start}-{end}") # --------------------------------------------------------------------------- # Core: transcribe reference audio # --------------------------------------------------------------------------- def transcribe_reference(audio_path, whisper_size: str, language: str, memory_mode: str, device: str): """ Transcribe a reference audio file with Whisper. Returns (transcription_text, status_msg). """ if audio_path is None: return "", "Upload a reference audio file first." mgr = get_manager(memory_mode) try: whisper = mgr.get_whisper(whisper_size=whisper_size) except Exception as e: return "", f"Failed to load Whisper: {e}" lang_arg = language if language and language != "auto" else None try: text, detected = whisper.transcribe(str(audio_path), language=lang_arg) except Exception as e: return "", f"Transcription failed: {e}" return text, f"Transcribed [{detected}] — {len(text)} characters" # --------------------------------------------------------------------------- # Core: clone voice (reference audio + transcription → new speech) # --------------------------------------------------------------------------- def clone_voice( text: str, ref_audio_path, ref_transcription: str, audiodit_size: str, nfe: int, guidance_strength: float, guidance_method: str, seed: int, memory_mode: str, device: str, ): """ Synthesise `text` in the voice captured from `ref_audio_path`. Returns (output_audio_path, status_msg). """ if not text or not text.strip(): return None, "Enter text to synthesise." if ref_audio_path is None: return None, "Upload a reference audio file." if not ref_transcription or not ref_transcription.strip(): return None, "Reference transcription is empty. Use 'Auto-Transcribe' first." mgr = get_manager(memory_mode) try: model, tokenizer = mgr.get_tts(audiodit_size=audiodit_size, device=device) except Exception as e: return None, f"Failed to load TTS model: {e}" torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed(seed) sr = model.config.sampling_rate full_hop = model.config.latent_hop max_dur = model.config.max_wav_duration synth_text = normalize_text(text) ref_text = normalize_text(ref_transcription) full_text = f"{ref_text} {synth_text}" inputs = tokenizer([full_text], padding="longest", return_tensors="pt") inputs = {k: v.to(device) for k, v in inputs.items()} # Encode reference audio to get prompt duration try: off = 3 pw = load_audio(str(ref_audio_path), sr) if pw.shape[-1] % full_hop != 0: pw = F.pad(pw, (0, full_hop - pw.shape[-1] % full_hop)) pw_padded = F.pad(pw, (0, full_hop * off)) with torch.no_grad(): plt = model.vae.encode(pw_padded.unsqueeze(0).to(device)) if off: plt = plt[..., :-off] prompt_dur = plt.shape[-1] prompt_wav = load_audio(str(ref_audio_path), sr).unsqueeze(0) except Exception as e: return None, f"Failed to process reference audio: {e}" prompt_time = prompt_dur * full_hop / sr dur_sec = approx_duration_from_text(synth_text, max_duration=max_dur - prompt_time) try: approx_pd = approx_duration_from_text(ref_text, max_duration=max_dur) ratio = np.clip(prompt_time / approx_pd, 1.0, 1.5) dur_sec = dur_sec * ratio except Exception: pass duration = int(dur_sec * sr // full_hop) duration = min(duration + prompt_dur, int(max_dur * sr // full_hop)) try: with torch.no_grad(): output = model( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], prompt_audio=prompt_wav, duration=duration, steps=nfe, cfg_strength=guidance_strength, guidance_method=guidance_method, ) except Exception as e: return None, f"Generation failed: {e}" wav = output.waveform.squeeze().detach().cpu().numpy() out_path = OUTPUT_DIR / f"clone_{int(time.time())}.wav" sf.write(str(out_path), wav, sr) return str(out_path), f"Done — {len(wav)/sr:.2f}s generated" # --------------------------------------------------------------------------- # Core: plain TTS (no reference voice) # --------------------------------------------------------------------------- def plain_tts( text: str, audiodit_size: str, nfe: int, guidance_strength: float, guidance_method: str, seed: int, memory_mode: str, device: str, ): """Synthesise text with no voice reference (random voice).""" if not text or not text.strip(): return None, "Enter text to synthesise." mgr = get_manager(memory_mode) try: model, tokenizer = mgr.get_tts(audiodit_size=audiodit_size, device=device) except Exception as e: return None, f"Failed to load TTS model: {e}" torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed(seed) sr = model.config.sampling_rate full_hop = model.config.latent_hop max_dur = model.config.max_wav_duration t = normalize_text(text) inputs = tokenizer([t], padding="longest", return_tensors="pt") inputs = {k: v.to(device) for k, v in inputs.items()} dur_sec = approx_duration_from_text(t, max_duration=max_dur) duration = int(dur_sec * sr // full_hop) duration = min(duration, int(max_dur * sr // full_hop)) try: with torch.no_grad(): output = model( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], prompt_audio=None, duration=duration, steps=nfe, cfg_strength=guidance_strength, guidance_method=guidance_method, ) except Exception as e: return None, f"Generation failed: {e}" wav = output.waveform.squeeze().detach().cpu().numpy() out_path = OUTPUT_DIR / f"tts_{int(time.time())}.wav" sf.write(str(out_path), wav, sr) return str(out_path), f"Done — {len(wav)/sr:.2f}s generated" # --------------------------------------------------------------------------- # Voice Library helpers (called from UI) # --------------------------------------------------------------------------- def library_names_with_placeholder() -> list[str]: lib = get_library() names = lib.names() return ["— select saved voice —"] + names def save_voice_to_library(name: str, audio_path, transcription: str): """Save a (audio, transcription) pair to the library. Returns (new_dropdown, status).""" name = (name or "").strip() if not name: return gr.update(), "Enter a name for this voice." if audio_path is None: return gr.update(), "No reference audio to save." if not transcription or not transcription.strip(): return gr.update(), "Transcription is empty — auto-transcribe first." try: get_library().add(name, str(audio_path), transcription) except Exception as e: return gr.update(), f"Save failed: {e}" choices = library_names_with_placeholder() return gr.update(choices=choices, value=name), f"Saved '{name}' to voice library." def load_voice_from_library(name: str): """Load a saved voice. Returns (audio_path, transcription, status).""" if not name or name.startswith("—"): return None, "", "" entry = get_library().get(name) if entry is None: return None, "", f"Voice '{name}' not found." audio = entry["audio_path"] if not Path(audio).exists(): return None, "", f"Audio file missing: {audio}" return audio, entry["transcription"], f"Loaded '{name}'" def delete_voice_from_library(name: str): """Delete a voice. Returns (new_dropdown_update, status).""" if not name or name.startswith("—"): return gr.update(), "Select a voice to delete." ok = get_library().remove(name) choices = library_names_with_placeholder() msg = f"Deleted '{name}'." if ok else f"Voice '{name}' not found." return gr.update(choices=choices, value=choices[0]), msg def refresh_library_dropdown(): choices = library_names_with_placeholder() return gr.update(choices=choices) def library_summary(): return get_library().summary_text() # --------------------------------------------------------------------------- # Status / unload # --------------------------------------------------------------------------- def get_status(memory_mode: str) -> str: return get_manager(memory_mode).status_str() def unload_all(memory_mode: str) -> str: mgr = get_manager(memory_mode) mgr.release_all() return "All models unloaded.\n" + mgr.status_str() # --------------------------------------------------------------------------- # Download helpers # --------------------------------------------------------------------------- def _model_inventory() -> str: lines = ["AudioDiT TTS models:"] for k, (repo, hint) in AUDIODIT_MODELS.items(): st = "[downloaded]" if _audiodit_present(k) else "not downloaded" lines.append(f" AudioDiT-{k:<6} {hint:<8} {st}") lines.append("") lines.append("Whisper STT models:") for k, (repo, hint) in WHISPER_MODELS.items(): st = "[downloaded]" if _whisper_present(k) else "not downloaded" lines.append(f" Whisper-{k:<10} {hint:<8} {st}") return "\n".join(lines) def download_with_progress(selected_models: list): if not selected_models: yield "Nothing selected." return log = [] def emit(msg): log.append(msg) for label in selected_models: if label.startswith("AudioDiT-"): size = label.replace("AudioDiT-", "") _, hint = AUDIODIT_MODELS.get(size, ("", "?")) log.append(f"AudioDiT-{size} ({hint}): {'already downloaded' if _audiodit_present(size) else 'downloading...'}"); yield "\n".join(log) download_audiodit(size, callback=emit); yield "\n".join(log) elif label.startswith("Whisper-"): size = label.replace("Whisper-", "") _, hint = WHISPER_MODELS.get(size, ("", "?")) log.append(f"Whisper-{size} ({hint}): {'already downloaded' if _whisper_present(size) else 'downloading...'}"); yield "\n".join(log) download_whisper(size, callback=emit); yield "\n".join(log) log.extend(["", _model_inventory()]) yield "\n".join(log) # --------------------------------------------------------------------------- # Gradio UI # --------------------------------------------------------------------------- def build_ui(default_device: str = "cuda"): AUDIODIT_CHOICES = ["1B", "3.5B"] WHISPER_CHOICES = ["turbo", "large-v3", "medium", "small"] MEMORY_MODES = ["auto", "simultaneous", "sequential"] GUIDANCE_METHODS = ["cfg", "apg"] LANGUAGE_CHOICES = [ "auto", "en", "zh", "ja", "ko", "de", "fr", "es", "pt", "ru", "ar", "hi", "it", "nl", "pl", "tr", "uk", "vi", "id", "th", ] with gr.Blocks(title="LongCat-AudioDiT — Voice Cloning") as demo: gr.Markdown( "# LongCat-AudioDiT — Voice Cloning Studio\n" "State-of-the-art voice cloning based on " "[LongCat-AudioDiT](https://github.com/meituan-longcat/LongCat-AudioDiT) " "by the Meituan LongCat Team. " "Give it a reference audio, type your text, get the result." ) gr.HTML( '