"""
Gradio app for TADA inference (English-only, single model).
Usage:
pip install hume-tada
python app.py
# or with hot reload + share link:
GRADIO_SHARE=1 gradio app.py
"""
import html
import logging
import os
import shutil
import tempfile
import time
import torch
import torchaudio
import gradio as gr
try:
import spaces
gpu_decorator = spaces.GPU
except ImportError:
gpu_decorator = lambda fn=None, **kw: fn if fn else (lambda f: f)
from tada.modules.encoder import Encoder, EncoderOutput # noqa: E402
from tada.modules.tada import InferenceOptions, TadaForCausalLM # noqa: E402
from tada.utils.text import normalize_text as normalize_text_fn # noqa: E402
# Configure root logging once at import time; all module loggers inherit INFO.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Preset samples & transcripts (English only)
# ---------------------------------------------------------------------------
# Resolve the samples directory relative to this file so discovery works
# regardless of the current working directory.
_script_dir = os.path.dirname(os.path.abspath(__file__))
_SAMPLES_DIR = os.path.join(_script_dir, "samples")
# Audio formats accepted as voice prompts (compared case-insensitively below).
_AUDIO_EXTENSIONS = (".wav", ".mp3", ".flac")
def _discover_preset_samples() -> dict[str, str]:
    """Return {display_name: absolute_path} for audio files in samples/en/."""
    search_dir = os.path.join(_SAMPLES_DIR, "en")
    if not os.path.isdir(search_dir):
        # Missing samples directory: no presets, not an error.
        return {}
    return {
        fname: os.path.join(search_dir, fname)
        for fname in sorted(os.listdir(search_dir))
        if fname.lower().endswith(_AUDIO_EXTENSIONS)
    }
def _load_preset_transcripts() -> dict[str, str]:
    """Load preset transcripts from synth_transcripts.json.

    Returns an empty dict when the file does not exist.
    """
    import json

    candidate = os.path.join(_SAMPLES_DIR, "en", "synth_transcripts.json")
    if os.path.isfile(candidate):
        # JSON is UTF-8 by spec; be explicit so non-ASCII transcripts load
        # correctly on platforms with a non-UTF-8 locale default encoding.
        with open(candidate, encoding="utf-8") as f:
            return json.load(f)
    return {}
def _load_prompt_transcripts() -> dict[str, str]:
    """Load prompt transcripts from prompt_transcripts.json.

    Returns an empty dict when the file does not exist.
    """
    import json

    candidate = os.path.join(_SAMPLES_DIR, "en", "prompt_transcripts.json")
    if os.path.isfile(candidate):
        # JSON is UTF-8 by spec; be explicit so non-ASCII transcripts load
        # correctly on platforms with a non-UTF-8 locale default encoding.
        with open(candidate, encoding="utf-8") as f:
            return json.load(f)
    return {}
# Discover presets once at import time; the UI builds its dropdowns from these.
_PRESET_SAMPLES = _discover_preset_samples()
_PRESET_TRANSCRIPTS = _load_preset_transcripts()
_PROMPT_TRANSCRIPTS = _load_prompt_transcripts()
logger.info("Discovered %d preset audio samples, %d transcripts", len(_PRESET_SAMPLES), len(_PRESET_TRANSCRIPTS))
# ---------------------------------------------------------------------------
# Global model state — single model, single encoder
# ---------------------------------------------------------------------------
_MODEL_NAME = "HumeAI/tada-3b-ml"
_device = "cuda"
def _validate_no_meta_tensors(model, name: str = "model"):
"""Raise if any parameter is on the meta device (not materialised)."""
for param_name, param in model.named_parameters():
if param.device.type == "meta":
raise RuntimeError(
f"{name} has meta-device parameter: {param_name}. "
"Pass low_cpu_mem_usage=False to from_pretrained()."
)
# Load the codec encoder and the TTS model eagerly at import time so the first
# request does not pay the load cost. low_cpu_mem_usage=False forces fully
# materialised (non-meta) tensors, which is validated immediately after.
logger.info("Loading encoder ...")
_encoder = Encoder.from_pretrained("HumeAI/tada-codec", language=None, low_cpu_mem_usage=False).to(_device)
_validate_no_meta_tensors(_encoder, "Encoder")
logger.info("Loading %s ...", _MODEL_NAME)
# NOTE(review): unlike the encoder, the model is not moved to _device here;
# generate() moves it on each call — presumably intentional for ZeroGPU. Confirm.
_model = TadaForCausalLM.from_pretrained(_MODEL_NAME, low_cpu_mem_usage=False)
_validate_no_meta_tensors(_model, "TadaForCausalLM")
logger.info("Models loaded.")
# ---------------------------------------------------------------------------
# Core inference helpers
# ---------------------------------------------------------------------------
def _decode_tokens_individually(tokenizer, token_ids: list[int]) -> list[str]:
"""Decode a list of token IDs into per-token strings, handling multi-byte characters."""
labels: list[str] = []
for i in range(len(token_ids)):
prefix = tokenizer.decode(token_ids[:i], skip_special_tokens=True)
full = tokenizer.decode(token_ids[: i + 1], skip_special_tokens=True)
token_str = full[len(prefix) :]
labels.append(token_str)
return labels
def _format_token_alignment(prompt: EncoderOutput) -> str:
"""Build an HTML string: dots in grey, tokens as bold coloured spans."""
if prompt.text_tokens is None or prompt.token_positions is None:
return ""
tokenizer = _encoder.tokenizer
n_tokens = (
int(prompt.text_tokens_len[0].item()) if prompt.text_tokens_len is not None else prompt.text_tokens.shape[1]
)
token_ids = prompt.text_tokens[0, :n_tokens].cpu().tolist()
positions = prompt.token_positions[0, :n_tokens].cpu().long().tolist()
labels = _decode_tokens_individually(tokenizer, token_ids)
audio_dur = prompt.audio.shape[-1] / prompt.sample_rate if prompt.audio.numel() > 0 else 0.0
header = f"{n_tokens} tokens | {audio_dur:.2f}s audio"
parts: list[str] = []
prev_pos = 0
for pos, label in zip(positions, labels):
gap = max(0, pos - prev_pos)
if gap > 0:
parts.append(f'{"." * gap}')
escaped = html.escape(label)
parts.append(
f'{escaped}'
)
prev_pos = pos + 1
body = "".join(parts)
return (
f'
"
)
def _decode_byte_tokens(raw_tokens: list[str]) -> list[str]:
"""Decode GPT-2 byte-level token strings into proper Unicode per-token labels."""
if not raw_tokens:
return raw_tokens
try:
tokenizer = _model.tokenizer
token_ids = tokenizer.convert_tokens_to_ids(raw_tokens)
return _decode_tokens_individually(tokenizer, token_ids)
except Exception:
return [t.replace("\u0120", " ") for t in raw_tokens]
def _format_step_logs(step_logs: list[dict], audio_duration: float, wall_time: float) -> str:
"""Build an HTML string from step_logs: dots for n_frames_before, tokens highlighted."""
if not step_logs:
return ""
n_tokens = len(step_logs)
total_frames = sum(entry.get("n_frames_before", 0) for entry in step_logs)
rtf = wall_time / audio_duration if audio_duration > 0 else float("inf")
header = f"{n_tokens} steps | {audio_duration:.1f}s audio | {total_frames} frames | {wall_time:.1f}s wall | RTF {rtf:.2f}"
raw_tokens = [entry.get("token", "") for entry in step_logs]
labels = _decode_byte_tokens(raw_tokens)
parts: list[str] = []
for entry, label in zip(step_logs, labels):
n_frames = entry.get("n_frames_before", 0)
if n_frames > 0:
parts.append(f'{"." * n_frames}')
escaped = html.escape(label)
parts.append(
f'{escaped}'
)
body = "".join(parts)
return (
f'"
)
# ---------------------------------------------------------------------------
# Single generate function (merged prompt encoding + generation)
# ---------------------------------------------------------------------------
@gpu_decorator(duration=120)
@torch.inference_mode()
def generate(
    audio_path: str | None,
    text: str,
    num_extra_steps: float = 0,
    noise_temperature: float = 0.9,
    acoustic_cfg_scale: float = 2.0,
    duration_cfg_scale: float = 2.0,
    num_flow_matching_steps: float = 20,
    negative_condition_source: str = "negative_step_output",
    text_only_logit_scale: float = 0.0,
    num_acoustic_candidates: float = 1,
    scorer: str = "likelihood",
    spkr_verification_weight: float = 1.0,
    speed_up_factor: float = 0.0,
    normalize_text: bool = True,
) -> tuple[str | None, str, str]:
    """Encode prompt + generate speech in a single GPU call.

    Numeric parameters arrive as floats from Gradio sliders and are cast to
    their real types before use.

    Returns (wav_path, prompt_alignment_html, generated_alignment_html).

    Raises:
        gr.Error: if generation fails for any reason (original exception is
            logged with traceback).
    """
    # Move model + encoder to GPU. On ZeroGPU the device is only available
    # inside the spaces.GPU-decorated call, so this must happen per request.
    _encoder.to(_device)
    _model.to(_device)
    _model.decoder.to(_device)
    # --- Encode prompt ---
    if audio_path is None or audio_path == "":
        # Zero-shot mode: no reference voice.
        prompt = EncoderOutput.empty(_device)
        prompt_html = "No audio provided (zero-shot mode)."
    else:
        audio, sample_rate = torchaudio.load(audio_path)
        audio = audio.mean(dim=0, keepdim=True)  # mono
        # Peak-normalize to 0.95; the clamp avoids division by zero on silence.
        audio = audio / audio.abs().max().clamp(min=1e-8) * 0.95
        audio = audio.to(_device)
        # Look up prompt transcript for preset samples. Preset files are copied
        # to tmp with a "tada_preset_" prefix by the UI, so try both the raw
        # basename and the prefix-stripped name against the transcript keys.
        prompt_text = None
        if audio_path:
            audio_fname = os.path.basename(audio_path)
            for key in (audio_fname, audio_fname.replace("tada_preset_", "")):
                if key in _PROMPT_TRANSCRIPTS:
                    prompt_text = _PROMPT_TRANSCRIPTS[key]
                    break
        # Encoder expects a batch of transcripts (or None for transcript-free).
        text_kwarg = [prompt_text] if prompt_text else None
        prompt = _encoder(audio, text=text_kwarg, sample_rate=sample_rate)
        prompt_html = _format_token_alignment(prompt)
    # --- Generate speech ---
    try:
        logger.info("Generating speech for text: %s", text)
        # 0 means "disabled" in the UI; the model API expects None for that.
        suf = float(speed_up_factor) if speed_up_factor > 0 else None
        t0 = time.time()
        output = _model.generate(
            prompt=prompt,
            text=text,
            num_transition_steps=0,
            num_extra_steps=int(num_extra_steps),
            normalize_text=normalize_text,
            inference_options=InferenceOptions(
                acoustic_cfg_scale=float(acoustic_cfg_scale),
                duration_cfg_scale=float(duration_cfg_scale),
                num_flow_matching_steps=int(num_flow_matching_steps),
                noise_temperature=float(noise_temperature),
                speed_up_factor=suf,
                time_schedule="logsnr",
                negative_condition_source=negative_condition_source,
                text_only_logit_scale=float(text_only_logit_scale),
                num_acoustic_candidates=int(num_acoustic_candidates),
                scorer=scorer,
                spkr_verification_weight=float(spkr_verification_weight),
            ),
            system_prompt="",
        )
        wall_time = time.time() - t0
        wav = output.audio[0].detach().cpu().float()
        if wav.dim() == 1:
            # torchaudio.save expects (channels, samples).
            wav = wav.unsqueeze(0)
        # NOTE(review): id(output) can be reused across calls, so a later
        # request may overwrite an earlier file — consider tempfile.mkstemp.
        tmp_path = os.path.join(tempfile.gettempdir(), f"tada_output_{id(output)}.wav")
        torchaudio.save(tmp_path, wav, 24_000)
        audio_duration = wav.shape[-1] / 24_000
        # Extract text-to-speak step_logs: keep only the steps that correspond
        # to the user's text tokens (exclude prompt / EOS positions).
        all_logs = output.step_logs or []
        if text and output.input_text_ids is not None:
            input_ids = output.input_text_ids[0]
            seq_len = input_ids.shape[0]
            n_eos = _model.config.shift_acoustic
            # Tokenize the (possibly normalized) text the same way the model
            # did, to locate its span at the tail of the input sequence.
            normalized = normalize_text_fn(text) if normalize_text else text
            n_text_tokens = len(_model.tokenizer.encode(normalized, add_special_tokens=False))
            text_end = seq_len - n_eos
            text_start = text_end - n_text_tokens
            log_by_step = {e["step"]: e for e in all_logs}
            text_logs = []
            for s in range(text_start, text_end):
                if s in log_by_step:
                    text_logs.append(log_by_step[s])
                else:
                    # Step had no log entry (token was prefilled, not
                    # generated); synthesize a placeholder entry for display.
                    token_id = input_ids[s].item()
                    token_str = _model.tokenizer.convert_ids_to_tokens([token_id])[0]
                    text_logs.append({
                        "step": s,
                        "token": token_str,
                        "n_frames_before": 0,
                        "n_frames_after": 0,
                        "n_frames_src": "prefilled",
                        "acoustic_mask": 1,
                        "acoustic_feat_src": "prefilled",
                        "acoustic_feat_norm": 0.0,
                    })
            generated_logs = text_logs
        else:
            generated_logs = all_logs
        generated_html = _format_step_logs(generated_logs, audio_duration, wall_time)
        return tmp_path, prompt_html, generated_html
    except gr.Error:
        # Already a user-facing error; propagate unchanged.
        raise
    except Exception as e:
        logger.exception("Generation failed")
        raise gr.Error(f"Generation failed: {e}")
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
def build_ui() -> gr.Blocks:
    """Construct the Gradio Blocks UI and wire it to :func:`generate`.

    Layout: a settings column (text + acoustic accordions), a voice-prompt
    column (preset dropdown, audio preview, prompt alignment), and a
    transcript column (preset dropdown, text box, generate button, outputs).
    """
    with gr.Blocks(
        title="TADA Inference",
        css=(
            ".gradio-container { max-width: 1400px !important; width: 100% !important; margin: auto !important; } "
            ".compact-audio { min-height: 0 !important; } "
            ".compact-audio audio { height: 36px !important; } "
        ),
    ) as demo:
        gr.Markdown(
            "# TADA - Text-Acoustic Dual Alignment LLM\n"
            "A demo of **tada-3b-ml** \u2014 "
            "a text-to-speech model that clones voice, emotion, and timing from a short audio prompt.\n\n"
            "**How to use:** Choose a voice prompt (or upload your own), enter text, and click **Generate**. "
            "The model will encode the prompt and generate speech in one step."
        )
        with gr.Row(equal_height=False):
            # --- Column 1: inference settings (collapsed by default) ---
            with gr.Column(scale=1):
                with gr.Accordion("Text Settings", open=False):
                    num_extra_steps = gr.Slider(
                        minimum=0, maximum=200, value=0, step=1,
                        label="Text Tokens to Generate",
                    )
                    text_only_logit_scale = gr.Slider(
                        minimum=0.0, maximum=5.0, value=0.0, step=0.1,
                        label="Text-Only Logit Scale",
                        info="0 = disabled. Blends text-only logits with audio-conditioned logits.",
                    )
                    normalize_text_cb = gr.Checkbox(
                        value=True,
                        label="Normalize Text",
                        info="Apply text normalization before generation",
                    )
                with gr.Accordion("Acoustic Settings", open=False):
                    # NOTE(review): UI default 1.6 differs from generate()'s
                    # signature default of 2.0 — the UI value wins. Confirm intended.
                    acoustic_cfg_scale = gr.Slider(
                        minimum=1.0, maximum=3.0, value=1.6, step=0.1,
                        label="Acoustic CFG Scale",
                    )
                    duration_cfg_scale = gr.Slider(
                        minimum=1.0, maximum=3.0, value=1.0, step=0.1,
                        label="Duration CFG Scale",
                    )
                    negative_condition_source = gr.Dropdown(
                        choices=["negative_step_output", "prompt", "zero"],
                        value="negative_step_output",
                        label="Negative Condition Source",
                    )
                    noise_temperature = gr.Slider(
                        minimum=0.4, maximum=1.2, value=0.9, step=0.1,
                        label="Noise Temperature",
                    )
                    num_flow_matching_steps = gr.Slider(
                        minimum=5, maximum=50, value=20, step=5,
                        label="Flow Matching Steps",
                    )
                    speed_up_factor = gr.Slider(
                        minimum=0.0, maximum=3.0, value=0.0, step=0.1,
                        label="Speed Up Factor",
                        info="0 = disabled (natural duration). >0 scales speech speed.",
                    )
                    num_acoustic_candidates = gr.Slider(
                        minimum=1, maximum=16, value=1, step=1,
                        label="Acoustic Candidates",
                        info="Number of candidates to generate and rank.",
                    )
                    scorer_dropdown = gr.Dropdown(
                        choices=["likelihood", "spkr_verification", "duration_median"],
                        value="likelihood",
                        label="Scorer",
                        info="How to rank acoustic candidates.",
                    )
                    spkr_verification_weight = gr.Slider(
                        minimum=0.0, maximum=5.0, value=1.0, step=0.1,
                        label="Speaker Verification Weight",
                        info="Weight for spkr_verification scorer.",
                    )
            # --- Column 2: voice prompt selection + preview ---
            with gr.Column(scale=2):
                preset_choices = ["None (zero-shot)"] + list(_PRESET_SAMPLES.keys())
                _default_voice = "fb_ears_emo_amusement_freeform.wav"
                preset_dropdown = gr.Dropdown(
                    choices=preset_choices,
                    value=_default_voice if _default_voice in _PRESET_SAMPLES else "None (zero-shot)",
                    label="Voice Prompt",
                    info="Pick a preset or upload / record your own",
                )
                _default_voice_path = _PRESET_SAMPLES.get(_default_voice)
                audio_input = gr.Audio(
                    label="Prompt Preview",
                    type="filepath",
                    sources=["upload", "microphone"],
                    value=_default_voice_path,
                    elem_classes=["compact-audio"],
                )
                def _on_preset_selected(choice: str) -> str | None:
                    """Copy the chosen preset to tmp and return the path (None = zero-shot)."""
                    if choice == "None (zero-shot)":
                        return None
                    path = _PRESET_SAMPLES.get(choice)
                    if path is None:
                        return None
                    # Copy with a recognizable prefix; generate() strips
                    # "tada_preset_" to look up the prompt transcript.
                    tmp_path = os.path.join(tempfile.gettempdir(), f"tada_preset_{choice}")
                    shutil.copy2(path, tmp_path)
                    return tmp_path
                preset_dropdown.change(
                    fn=_on_preset_selected,
                    inputs=[preset_dropdown],
                    outputs=[audio_input],
                )
                with gr.Accordion("Prompt Token Alignment", open=True):
                    prompt_alignment = gr.HTML(value="Generate to see prompt alignment.")
            # --- Column 3: transcript input, generate button, outputs ---
            with gr.Column(scale=2):
                _default_transcript = "emo_interest_sentences"
                transcript_choices = ["(custom)"] + list(_PRESET_TRANSCRIPTS.keys())
                transcript_dropdown = gr.Dropdown(
                    choices=transcript_choices,
                    value=_default_transcript if _default_transcript in _PRESET_TRANSCRIPTS else "(custom)",
                    label="Transcript",
                    info="Pick a preset or type your own below",
                )
                text_input = gr.Textbox(
                    label="Text to Speak",
                    placeholder="Type what you want the model to say ...",
                    autoscroll=False,
                    max_lines=20,
                    value=_PRESET_TRANSCRIPTS.get(_default_transcript, ""),
                )
                def _on_transcript_selected(choice: str) -> str:
                    """Fill the text box from the chosen preset ("" for custom)."""
                    if choice == "(custom)":
                        return ""
                    return _PRESET_TRANSCRIPTS.get(choice, "")
                transcript_dropdown.change(
                    fn=_on_transcript_selected,
                    inputs=[transcript_dropdown],
                    outputs=[text_input],
                )
                generate_btn = gr.Button("Generate", variant="primary", size="lg")
                # --- Output ---
                audio_output = gr.Audio(label="Generated Audio")
                with gr.Accordion("Generated Alignment", open=False):
                    generated_text_display = gr.HTML(value="Generate speech to see the alignment")
        # Wire up generate button. Order must match generate()'s positional
        # parameters exactly.
        all_inputs = [
            audio_input,
            text_input,
            num_extra_steps,
            noise_temperature,
            acoustic_cfg_scale,
            duration_cfg_scale,
            num_flow_matching_steps,
            negative_condition_source,
            text_only_logit_scale,
            num_acoustic_candidates,
            scorer_dropdown,
            spkr_verification_weight,
            speed_up_factor,
            normalize_text_cb,
        ]
        generate_btn.click(
            fn=generate,
            inputs=all_inputs,
            outputs=[audio_output, prompt_alignment, generated_text_display],
        )
    return demo
# ---------------------------------------------------------------------------
# Entry-point
# ---------------------------------------------------------------------------
# Read launch configuration from the environment (container / Spaces friendly).
_share = os.environ.get("GRADIO_SHARE", "").lower() in ("1", "true", "yes")
_port = int(os.environ.get("GRADIO_PORT", "7860"))
# `demo` at module scope so the `gradio` CLI / HF Spaces can discover it.
demo = build_ui()
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="TADA Inference Gradio App")
    parser.add_argument("--share", action="store_true", default=_share, help="Create a public Gradio share link")
    parser.add_argument("--port", type=int, default=_port, help="Server port (default: 7860)")
    args = parser.parse_args()
    demo.launch(server_name="0.0.0.0", server_port=args.port, share=args.share, allowed_paths=[_SAMPLES_DIR])
else:
    # NOTE(review): launching here means ANY import of this module starts a
    # blocking server, and the `gradio` CLI also launches discovered Blocks
    # itself — confirm this does not double-launch under hot reload.
    demo.launch(server_name="0.0.0.0", server_port=_port, share=_share, allowed_paths=[_SAMPLES_DIR])