Spaces:
Running on Zero
Running on Zero
| import re | |
| import sys | |
| from pathlib import Path | |
| import gradio as gr | |
| import librosa | |
| import numpy as np | |
| import spaces | |
| import torch | |
| # Register audiodit model type with transformers | |
| sys.path.insert(0, str(Path(__file__).resolve().parent / "vendor" / "LongCat-AudioDiT")) | |
| import audiodit # noqa: F401 | |
| from audiodit import AudioDiTModel | |
| from transformers import AutoTokenizer | |
| # --------------------------------------------------------------------------- | |
| # Text utilities (from upstream utils.py) | |
| # --------------------------------------------------------------------------- | |
# Upper bound for user-selectable RNG seeds (largest value fitting in uint32).
MAX_SEED = 2**32 - 1
# Rough per-character speaking durations in seconds, used to estimate output
# audio length from the input text (values from upstream utils.py).
EN_DUR_PER_CHAR = 0.082  # English / Latin-alphabet characters
ZH_DUR_PER_CHAR = 0.21  # Chinese (CJK) characters
def normalize_text(text: str) -> str:
    """Lowercase *text*, replace curly quotes with spaces, and collapse whitespace."""
    lowered = text.lower()
    # Curly double/single quotes confuse the tokenizer; blank them out.
    without_quotes = re.sub(r"[\u201c\u201d\u201e\u2018\u2019]", " ", lowered)
    collapsed = re.sub(r"\s+", " ", without_quotes)
    return collapsed.strip()
def approx_duration_from_text(text: str, max_duration: float = 30.0) -> float:
    """Estimate speech duration in seconds from character counts, capped at *max_duration*.

    Chinese (CJK) and Latin characters get different per-character rates;
    remaining characters (digits, punctuation) are billed at the rate of
    whichever script dominates the text.
    """
    compact = re.sub(r"\s+", "", text)
    zh_count = sum(1 for ch in compact if "\u4e00" <= ch <= "\u9fff")
    en_count = sum(
        1 for ch in compact if not ("\u4e00" <= ch <= "\u9fff") and ch.isalpha()
    )
    other_count = len(compact) - zh_count - en_count
    # Attribute ambiguous characters to the dominant script.
    if zh_count > en_count:
        zh_count += other_count
    else:
        en_count += other_count
    estimate = zh_count * ZH_DUR_PER_CHAR + en_count * EN_DUR_PER_CHAR
    return min(max_duration, estimate)
| # --------------------------------------------------------------------------- | |
| # Model loading | |
| # --------------------------------------------------------------------------- | |
# Hugging Face Hub repo id of the 3.5B diffusion-transformer TTS checkpoint.
MODEL_ID = "meituan-longcat/LongCat-AudioDiT-3.5B"
# Load once at import time so every request reuses the same weights.
# NOTE(review): assumes CUDA is reachable at import time — on ZeroGPU Spaces
# the GPU is normally attached per-call; confirm this works in that setup.
model = AudioDiTModel.from_pretrained(MODEL_ID).to("cuda")
model.vae.to_half()  # presumably casts the VAE to fp16 to save memory — verify
model.eval()  # inference only: disable dropout / training-mode behavior
# Tokenizer matching the text encoder bundled with the model config.
tokenizer = AutoTokenizer.from_pretrained(model.config.text_encoder_model)
| # --------------------------------------------------------------------------- | |
| # Inference | |
| # --------------------------------------------------------------------------- | |
def get_seed(randomize_seed: bool, seed: int) -> int:
    """Return a fresh random seed when *randomize_seed* is set, else *seed* unchanged."""
    if not randomize_seed:
        return seed
    return int(np.random.default_rng().integers(0, MAX_SEED))
@spaces.GPU  # required on ZeroGPU Spaces: attaches a GPU for the call so torch.cuda works
def generate_tts(
    text: str,
    guidance_method: str,
    nfe: int,
    guidance_strength: float,
    seed: int,
) -> tuple[int, np.ndarray]:
    """Synthesize speech from *text*.

    Args:
        text: Input text; normalized (lowercased, quotes stripped) before use.
        guidance_method: Sampler guidance variant ("cfg" or "apg" in the UI).
        nfe: Number of function evaluations (diffusion sampling steps).
        guidance_strength: Classifier-free-guidance scale.
        seed: RNG seed for reproducible sampling.

    Returns:
        ``(sample_rate, waveform)`` suitable for ``gr.Audio``.

    Raises:
        gr.Error: If the normalized text is empty.
    """
    text = normalize_text(text)
    if not text:
        raise gr.Error("Text is empty (or contains only whitespace/quotes).")
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    sr = model.config.sampling_rate
    full_hop = model.config.latent_hop
    max_duration = model.config.max_wav_duration
    inputs = tokenizer([text], padding="longest", return_tensors="pt")
    # Heuristic text-length -> seconds estimate, then seconds -> latent frames.
    dur_sec = approx_duration_from_text(text, max_duration=max_duration)
    duration = int(dur_sec * sr // full_hop)
    output = model(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        duration=duration,
        steps=nfe,
        cfg_strength=guidance_strength,
        guidance_method=guidance_method,
    )
    wav = output.waveform.squeeze().detach().cpu().numpy()
    return (sr, wav)
@spaces.GPU  # required on ZeroGPU Spaces: attaches a GPU for the call so torch.cuda works
def generate_voice_clone(
    text: str,
    prompt_text: str,
    prompt_audio: tuple[int, np.ndarray] | str | None,
    guidance_method: str,
    nfe: int,
    guidance_strength: float,
    seed: int,
) -> tuple[int, np.ndarray]:
    """Synthesize *text* in the voice of *prompt_audio* (zero-shot cloning).

    Args:
        text: Text to synthesize; normalized before use.
        prompt_text: Transcription of the prompt audio; normalized before use.
        prompt_audio: ``(sample_rate, ndarray)`` from ``gr.Audio(type="numpy")``,
            or a filepath string, or ``None``.
        guidance_method: Sampler guidance variant ("cfg" or "apg" in the UI).
        nfe: Number of diffusion sampling steps.
        guidance_strength: Classifier-free-guidance scale.
        seed: RNG seed for reproducible sampling.

    Returns:
        ``(sample_rate, waveform)`` suitable for ``gr.Audio``.

    Raises:
        gr.Error: If the prompt audio is missing or either text is empty.
    """
    if prompt_audio is None:
        raise gr.Error("Prompt audio is required.")
    # Validate all text inputs before the expensive audio encode (fail fast).
    text = normalize_text(text)
    if not text:
        raise gr.Error("Text is empty (or contains only whitespace/quotes).")
    prompt_text = normalize_text(prompt_text)
    if not prompt_text:
        raise gr.Error("Prompt text is empty (or contains only whitespace/quotes).")
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    sr = model.config.sampling_rate
    full_hop = model.config.latent_hop
    max_duration = model.config.max_wav_duration
    # gr.Audio(type="numpy") yields (sample_rate, ndarray); the declared type
    # also admits a filepath string, so handle that too instead of crashing on
    # tuple unpacking.
    if isinstance(prompt_audio, str):
        audio_np, input_sr = librosa.load(prompt_audio, sr=None)
    else:
        input_sr, audio_np = prompt_audio
    if audio_np.ndim > 1:
        audio_np = audio_np.mean(axis=-1)  # downmix multi-channel to mono
    audio_np = audio_np.astype(np.float32)
    peak = np.abs(audio_np).max()
    if peak > 1.0:
        audio_np = audio_np / peak  # integer-PCM input: normalize into [-1, 1]
    if input_sr != sr:
        audio_np = librosa.resample(audio_np, orig_sr=input_sr, target_sr=sr)
    prompt_wav = torch.from_numpy(audio_np).unsqueeze(0).unsqueeze(0)  # (1, 1, T)
    # encode_prompt_audio handles VAE padding/encoding/trimming internally.
    _, prompt_dur = model.encode_prompt_audio(prompt_wav)
    # The model conditions on the concatenated transcript + target text.
    full_text = f"{prompt_text} {text}"
    inputs = tokenizer([full_text], padding="longest", return_tensors="pt")
    # Duration estimation: scale the text-based estimate by the reference
    # speaker's observed pace (actual prompt time / estimated prompt time).
    prompt_time = prompt_dur * full_hop / sr
    dur_sec = approx_duration_from_text(text, max_duration=max_duration - prompt_time)
    approx_pd = approx_duration_from_text(prompt_text, max_duration=max_duration)
    ratio = np.clip(prompt_time / approx_pd, 1.0, 1.5)
    dur_sec = dur_sec * ratio
    duration = int(dur_sec * sr // full_hop)
    # Total latent frames include the prompt, capped at the model maximum.
    duration = min(duration + prompt_dur, int(max_duration * sr // full_hop))
    output = model(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        prompt_audio=prompt_wav,
        duration=duration,
        steps=nfe,
        cfg_strength=guidance_strength,
        guidance_method=guidance_method,
    )
    wav = output.waveform.squeeze().detach().cpu().numpy()
    return (sr, wav)
| # --------------------------------------------------------------------------- | |
| # UI | |
| # --------------------------------------------------------------------------- | |
# Build the Gradio UI: two tabs (plain TTS, zero-shot voice cloning) sharing
# one "Advanced Settings" accordion of sampling controls.
with gr.Blocks() as demo:
    gr.Markdown("# LongCat-AudioDiT")
    gr.Markdown(
        "Diffusion-based text-to-speech with zero-shot voice cloning. "
        "Based on [meituan-longcat/LongCat-AudioDiT](https://github.com/meituan-longcat/LongCat-AudioDiT)."
    )
    with gr.Tabs():
        # --- Tab 1: plain text-to-speech -----------------------------------
        with gr.Tab("TTS"):
            with gr.Row():
                with gr.Column():
                    tts_text = gr.Textbox(
                        label="Text",
                        lines=5,
                        placeholder="Enter text to synthesize...",
                    )
                    tts_btn = gr.Button("Generate")
                with gr.Column():
                    tts_output = gr.Audio(label="Output")
            # One English and one Chinese example sentence.
            gr.Examples(
                examples=[
                    [
                        "She sells seashells by the seashore. The shells she sells are surely seashells. So if she sells shells on the seashore, I'm sure she sells seashore shells."
                    ],
                    ["今天晴暖转阴雨,空气质量优至良,空气相对湿度较低。"],  # noqa: RUF001 — Chinese punctuation
                ],
                inputs=tts_text,
            )
        # --- Tab 2: voice cloning (prompt audio + transcript) --------------
        with gr.Tab("Voice Cloning"):
            with gr.Row():
                with gr.Column():
                    # type="numpy" makes the component emit (sample_rate, ndarray).
                    vc_prompt_audio = gr.Audio(label="Prompt Audio", type="numpy")
                    vc_prompt_text = gr.Textbox(
                        label="Prompt Text",
                        lines=2,
                        placeholder="Transcription of the prompt audio...",
                    )
                    vc_text = gr.Textbox(
                        label="Text to Synthesize",
                        lines=3,
                        placeholder="Enter text to synthesize in the cloned voice...",
                    )
                    vc_btn = gr.Button("Generate")
                with gr.Column():
                    vc_output = gr.Audio(label="Output")
    # Shared sampling controls used by both tabs.
    with gr.Accordion("Advanced Settings", open=False):
        guidance_method = gr.Radio(
            label="Guidance",
            choices=["cfg", "apg"],
            value="cfg",
        )
        nfe = gr.Slider(label="NFE Steps", minimum=1, maximum=64, step=1, value=16)
        guidance_strength = gr.Slider(
            label="Guidance Strength",
            minimum=0.0,
            maximum=10.0,
            step=0.1,
            value=4.0,
        )
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=1024,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
    # Each button first resolves the seed (instant, unqueued so the slider
    # updates immediately), then runs generation with the resolved value.
    tts_btn.click(
        fn=get_seed,
        inputs=[randomize_seed, seed],
        outputs=seed,
        queue=False,
    ).then(
        fn=generate_tts,
        inputs=[tts_text, guidance_method, nfe, guidance_strength, seed],
        outputs=tts_output,
    )
    vc_btn.click(
        fn=get_seed,
        inputs=[randomize_seed, seed],
        outputs=seed,
        queue=False,
    ).then(
        fn=generate_voice_clone,
        inputs=[
            vc_text,
            vc_prompt_text,
            vc_prompt_audio,
            guidance_method,
            nfe,
            guidance_strength,
            seed,
        ],
        outputs=vc_output,
    )
if __name__ == "__main__":
    demo.launch()