# Hugging Face Space app by hysts (HF Staff), commit f0a5bff ("Add files").
import re
import sys
from pathlib import Path
import gradio as gr
import librosa
import numpy as np
import spaces
import torch
# Register audiodit model type with transformers
sys.path.insert(0, str(Path(__file__).resolve().parent / "vendor" / "LongCat-AudioDiT"))
import audiodit # noqa: F401
from audiodit import AudioDiTModel
from transformers import AutoTokenizer
# ---------------------------------------------------------------------------
# Text utilities (from upstream utils.py)
# ---------------------------------------------------------------------------
# Largest value representable as an unsigned 32-bit seed.
MAX_SEED = 2**32 - 1
# Rough per-character speech durations in seconds, used to estimate how long
# the synthesized audio should be: Latin letters vs. CJK characters.
EN_DUR_PER_CHAR = 0.082
ZH_DUR_PER_CHAR = 0.21
def normalize_text(text: str) -> str:
    """Lowercase *text*, turn curly quotes into spaces, and collapse whitespace runs."""
    lowered = text.lower()
    # Curly double/single quotes become plain spaces so they never reach the model.
    no_quotes = re.sub(r"[\u201c\u201d\u201e\u2018\u2019]", " ", lowered)
    # Any run of whitespace (tabs, newlines, multiple spaces) becomes one space.
    collapsed = re.sub(r"\s+", " ", no_quotes)
    return collapsed.strip()
def approx_duration_from_text(text: str, max_duration: float = 30.0) -> float:
    """Estimate speech duration in seconds from character counts, capped at *max_duration*."""
    compact = re.sub(r"\s+", "", text)

    def _is_cjk(ch: str) -> bool:
        # CJK Unified Ideographs block.
        return "\u4e00" <= ch <= "\u9fff"

    num_cjk = sum(1 for ch in compact if _is_cjk(ch))
    num_latin = sum(1 for ch in compact if not _is_cjk(ch) and ch.isalpha())
    num_misc = len(compact) - num_cjk - num_latin
    # Attribute punctuation/digits to whichever script dominates the text.
    if num_cjk > num_latin:
        num_cjk += num_misc
    else:
        num_latin += num_misc
    estimate = num_cjk * ZH_DUR_PER_CHAR + num_latin * EN_DUR_PER_CHAR
    return min(max_duration, estimate)
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
# Load the 3.5B AudioDiT checkpoint onto the GPU at import time.
MODEL_ID = "meituan-longcat/LongCat-AudioDiT-3.5B"
model = AudioDiTModel.from_pretrained(MODEL_ID).to("cuda")
# Run the VAE in half precision — presumably to save GPU memory; TODO confirm
# against the upstream repo's recommended setup.
model.vae.to_half()
model.eval()
# Tokenizer matching the text encoder the model was configured with.
tokenizer = AutoTokenizer.from_pretrained(model.config.text_encoder_model)
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
def get_seed(randomize_seed: bool, seed: int) -> int:
    """Return *seed* unchanged, or draw a fresh seed in [0, MAX_SEED) when requested."""
    if not randomize_seed:
        return seed
    return int(np.random.default_rng().integers(0, MAX_SEED))
@spaces.GPU
def generate_tts(
    text: str,
    guidance_method: str,
    nfe: int,
    guidance_strength: float,
    seed: int,
) -> tuple[int, np.ndarray]:
    """Synthesize speech for *text* and return ``(sample_rate, waveform)``.

    Args:
        text: Text to synthesize; normalized before use.
        guidance_method: Guidance scheme name ("cfg" or "apg" per the UI).
        nfe: Number of diffusion sampling steps.
        guidance_strength: Guidance scale passed as ``cfg_strength``.
        seed: RNG seed for reproducible sampling.

    Raises:
        gr.Error: If the text is empty after normalization.
    """
    cleaned = normalize_text(text)
    if not cleaned:
        raise gr.Error("Text is empty (or contains only whitespace/quotes).")
    # Seed both CPU and CUDA RNGs so sampling is reproducible.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    sample_rate = model.config.sampling_rate
    hop = model.config.latent_hop
    # Estimate the target length from the text, then convert seconds -> latent frames.
    est_seconds = approx_duration_from_text(cleaned, max_duration=model.config.max_wav_duration)
    num_latents = int(est_seconds * sample_rate // hop)
    batch = tokenizer([cleaned], padding="longest", return_tensors="pt")
    result = model(
        input_ids=batch.input_ids,
        attention_mask=batch.attention_mask,
        duration=num_latents,
        steps=nfe,
        cfg_strength=guidance_strength,
        guidance_method=guidance_method,
    )
    waveform = result.waveform.squeeze().detach().cpu().numpy()
    return (sample_rate, waveform)
@spaces.GPU
def generate_voice_clone(
    text: str,
    prompt_text: str,
    prompt_audio: tuple[int, np.ndarray] | str | None,
    guidance_method: str,
    nfe: int,
    guidance_strength: float,
    seed: int,
) -> tuple[int, np.ndarray]:
    """Synthesize *text* in the voice of *prompt_audio* (zero-shot cloning).

    Args:
        text: Text to speak in the cloned voice.
        prompt_text: Transcription of the prompt audio.
        prompt_audio: ``(sample_rate, samples)`` tuple from ``gr.Audio(type="numpy")``.
        guidance_method: Guidance scheme name ("cfg" or "apg" per the UI).
        nfe: Number of diffusion sampling steps.
        guidance_strength: Guidance scale passed as ``cfg_strength``.
        seed: RNG seed for reproducible sampling.

    Returns:
        ``(sample_rate, waveform)`` suitable for a ``gr.Audio`` output.

    Raises:
        gr.Error: If the prompt audio is missing, or either text is empty
            after normalization.
    """
    if prompt_audio is None:
        raise gr.Error("Prompt audio is required.")
    # Seed both CPU and CUDA RNGs so sampling is reproducible.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    sr = model.config.sampling_rate
    full_hop = model.config.latent_hop
    max_duration = model.config.max_wav_duration
    # Load prompt audio — gr.Audio returns (sample_rate, ndarray)
    input_sr, audio_np = prompt_audio
    # Downmix multi-channel audio to mono by averaging channels.
    if audio_np.ndim > 1:
        audio_np = audio_np.mean(axis=-1)
    audio_np = audio_np.astype(np.float32)
    # Peak-normalize when samples exceed [-1, 1] — presumably to handle
    # integer PCM input from gr.Audio; TODO confirm expected input dtype.
    if np.abs(audio_np).max() > 1.0:
        audio_np = audio_np / np.abs(audio_np).max()
    # Resample to the model's sampling rate when they differ.
    if input_sr != sr:
        audio_np = librosa.resample(audio_np, orig_sr=input_sr, target_sr=sr)
    prompt_wav = torch.from_numpy(audio_np).unsqueeze(0).unsqueeze(0)  # (1, 1, T)
    # encode_prompt_audio handles VAE padding/encoding/trimming internally.
    # prompt_dur appears to be the prompt length in latent frames (it is
    # converted to seconds below via full_hop / sr) — verify against upstream.
    _, prompt_dur = model.encode_prompt_audio(prompt_wav)
    # Normalize and validate both texts.
    text = normalize_text(text)
    if not text:
        raise gr.Error("Text is empty (or contains only whitespace/quotes).")
    prompt_text = normalize_text(prompt_text)
    if not prompt_text:
        raise gr.Error("Prompt text is empty (or contains only whitespace/quotes).")
    # The model is conditioned on the prompt transcript and the target text
    # concatenated into one sequence.
    full_text = f"{prompt_text} {text}"
    inputs = tokenizer([full_text], padding="longest", return_tensors="pt")
    # Duration estimation: actual prompt length in seconds from its latent count.
    prompt_time = prompt_dur * full_hop / sr
    dur_sec = approx_duration_from_text(text, max_duration=max_duration - prompt_time)
    # Scale the heuristic estimate by how much longer the speaker's actual
    # prompt ran versus the heuristic's prediction, clipped to [1.0, 1.5].
    approx_pd = approx_duration_from_text(prompt_text, max_duration=max_duration)
    ratio = np.clip(prompt_time / approx_pd, 1.0, 1.5)
    dur_sec = dur_sec * ratio
    duration = int(dur_sec * sr // full_hop)
    # Total latent length = prompt + target, capped at the model's maximum.
    duration = min(duration + prompt_dur, int(max_duration * sr // full_hop))
    output = model(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        prompt_audio=prompt_wav,
        duration=duration,
        steps=nfe,
        cfg_strength=guidance_strength,
        guidance_method=guidance_method,
    )
    wav = output.waveform.squeeze().detach().cpu().numpy()
    return (sr, wav)
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
# Build the Gradio UI: two tabs (plain TTS and voice cloning) that share one
# set of advanced sampling controls.
with gr.Blocks() as demo:
    gr.Markdown("# LongCat-AudioDiT")
    gr.Markdown(
        "Diffusion-based text-to-speech with zero-shot voice cloning. "
        "Based on [meituan-longcat/LongCat-AudioDiT](https://github.com/meituan-longcat/LongCat-AudioDiT)."
    )
    with gr.Tabs():
        with gr.Tab("TTS"):
            with gr.Row():
                with gr.Column():
                    tts_text = gr.Textbox(
                        label="Text",
                        lines=5,
                        placeholder="Enter text to synthesize...",
                    )
                    tts_btn = gr.Button("Generate")
                with gr.Column():
                    tts_output = gr.Audio(label="Output")
            # One English and one Chinese sample sentence.
            gr.Examples(
                examples=[
                    [
                        "She sells seashells by the seashore. The shells she sells are surely seashells. So if she sells shells on the seashore, I'm sure she sells seashore shells."
                    ],
                    ["今天晴暖转阴雨,空气质量优至良,空气相对湿度较低。"],  # noqa: RUF001 — Chinese punctuation
                ],
                inputs=tts_text,
            )
        with gr.Tab("Voice Cloning"):
            with gr.Row():
                with gr.Column():
                    vc_prompt_audio = gr.Audio(label="Prompt Audio", type="numpy")
                    vc_prompt_text = gr.Textbox(
                        label="Prompt Text",
                        lines=2,
                        placeholder="Transcription of the prompt audio...",
                    )
                    vc_text = gr.Textbox(
                        label="Text to Synthesize",
                        lines=3,
                        placeholder="Enter text to synthesize in the cloned voice...",
                    )
                    vc_btn = gr.Button("Generate")
                with gr.Column():
                    vc_output = gr.Audio(label="Output")
    # Sampling controls shared by both tabs.
    with gr.Accordion("Advanced Settings", open=False):
        guidance_method = gr.Radio(
            label="Guidance",
            choices=["cfg", "apg"],
            value="cfg",
        )
        nfe = gr.Slider(label="NFE Steps", minimum=1, maximum=64, step=1, value=16)
        guidance_strength = gr.Slider(
            label="Guidance Strength",
            minimum=0.0,
            maximum=10.0,
            step=0.1,
            value=4.0,
        )
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=1024,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
    # Each button first resolves the seed (instantly, queue=False, writing it
    # back into the slider) and only then runs generation with that seed.
    tts_btn.click(
        fn=get_seed,
        inputs=[randomize_seed, seed],
        outputs=seed,
        queue=False,
    ).then(
        fn=generate_tts,
        inputs=[tts_text, guidance_method, nfe, guidance_strength, seed],
        outputs=tts_output,
    )
    vc_btn.click(
        fn=get_seed,
        inputs=[randomize_seed, seed],
        outputs=seed,
        queue=False,
    ).then(
        fn=generate_voice_clone,
        inputs=[
            vc_text,
            vc_prompt_text,
            vc_prompt_audio,
            guidance_method,
            nfe,
            guidance_strength,
            seed,
        ],
        outputs=vc_output,
    )

if __name__ == "__main__":
    demo.launch()