| """ |
| Gradio app for TADA inference (English-only, single model). |
| |
| Usage: |
| pip install hume-tada |
| python app.py |
| # or with hot reload + share link: |
| GRADIO_SHARE=1 gradio app.py |
| """ |
|
|
| import html |
| import logging |
| import os |
| import shutil |
| import tempfile |
| import time |
|
|
| import torch |
| import torchaudio |
|
|
| import gradio as gr |
|
|
try:
    import spaces
except ImportError:

    def gpu_decorator(fn=None, **kw):
        """No-op stand-in for ``spaces.GPU`` when the ``spaces`` package is missing.

        Works both bare (``@gpu_decorator``) and with arguments
        (``@gpu_decorator(duration=120)``); keyword arguments are ignored and the
        decorated function is returned unchanged.
        """
        if fn:
            return fn
        return lambda f: f

else:
    gpu_decorator = spaces.GPU
|
|
| from tada.modules.encoder import Encoder, EncoderOutput |
| from tada.modules.tada import InferenceOptions, TadaForCausalLM |
| from tada.utils.text import normalize_text as normalize_text_fn |
|
|
# File suffixes treated as audio when scanning for preset voice prompts
# (the scanner lower-cases file names before matching).
_AUDIO_EXTENSIONS = (".wav", ".mp3", ".flac")

# Bundled voice prompts and transcripts live in "samples/" next to this script,
# resolved from __file__ so lookups do not depend on the working directory.
_script_dir = os.path.dirname(os.path.abspath(__file__))
_SAMPLES_DIR = os.path.join(_script_dir, "samples")

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
def _discover_preset_samples() -> dict[str, str]:
    """Map each audio file name in ``samples/en/`` to its absolute path.

    Files are matched case-insensitively against ``_AUDIO_EXTENSIONS`` and
    returned in sorted file-name order. Returns an empty dict when the
    directory does not exist.
    """
    en_dir = os.path.join(_SAMPLES_DIR, "en")
    if not os.path.isdir(en_dir):
        return {}
    return {
        entry: os.path.join(en_dir, entry)
        for entry in sorted(os.listdir(en_dir))
        if entry.lower().endswith(_AUDIO_EXTENSIONS)
    }
|
|
|
|
def _load_samples_json(fname: str) -> dict[str, str]:
    """Load ``samples/en/<fname>`` as JSON, or return {} if the file is absent.

    The file is read as UTF-8 explicitly: without ``encoding=`` Python uses the
    platform locale encoding, which can mis-decode non-ASCII transcript text.
    """
    import json

    candidate = os.path.join(_SAMPLES_DIR, "en", fname)
    if not os.path.isfile(candidate):
        return {}
    with open(candidate, encoding="utf-8") as f:
        return json.load(f)


def _load_preset_transcripts() -> dict[str, str]:
    """Load preset transcripts (texts to synthesise) from synth_transcripts.json."""
    return _load_samples_json("synth_transcripts.json")


def _load_prompt_transcripts() -> dict[str, str]:
    """Load prompt transcripts (texts spoken in the voice prompts) from prompt_transcripts.json."""
    return _load_samples_json("prompt_transcripts.json")
|
|
|
|
# Preset voices and transcripts are discovered once, at import time; the UI and
# generate() read these module-level dicts.
_PRESET_SAMPLES = _discover_preset_samples()
_PRESET_TRANSCRIPTS = _load_preset_transcripts()
_PROMPT_TRANSCRIPTS = _load_prompt_transcripts()
logger.info(
    "Discovered %d preset audio samples, %d transcripts",
    len(_PRESET_SAMPLES),
    len(_PRESET_TRANSCRIPTS),
)

# Checkpoint loaded into TadaForCausalLM.
_MODEL_NAME = "HumeAI/tada-1b"
# NOTE(review): inference is pinned to CPU even though generate() is wrapped in
# the GPU decorator -- confirm this is intended for the target deployment.
_device = "cpu"
|
|
|
|
def _validate_no_meta_tensors(model, name: str = "model"):
    """Raise if any parameter or buffer of *model* is on the meta device.

    Meta tensors carry no data, so a model holding any of them was not fully
    materialised by ``from_pretrained``. Buffers are checked as well as
    parameters because an unmaterialised buffer is just as unusable.

    Args:
        model: A ``torch.nn.Module`` (anything with ``named_parameters`` and
            ``named_buffers``).
        name: Label used in the error message.

    Raises:
        RuntimeError: Naming the first offending parameter (checked first) or buffer.
    """
    for kind, named_tensors in (
        ("parameter", model.named_parameters()),
        ("buffer", model.named_buffers()),
    ):
        for tensor_name, tensor in named_tensors:
            if tensor.device.type == "meta":
                raise RuntimeError(
                    f"{name} has meta-device {kind}: {tensor_name}. "
                    "Pass low_cpu_mem_usage=False to from_pretrained()."
                )
|
|
|
|
# Both networks are loaded at import time with low_cpu_mem_usage=False, and each
# is checked for unmaterialised (meta-device) tensors right after loading.
# The encoder is moved to _device here; the LM is moved inside generate().
logger.info("Loading encoder ...")
_encoder = Encoder.from_pretrained(
    "HumeAI/tada-codec",
    language=None,
    low_cpu_mem_usage=False,
).to(_device)
_validate_no_meta_tensors(_encoder, "Encoder")

logger.info("Loading %s ...", _MODEL_NAME)
_model = TadaForCausalLM.from_pretrained(
    _MODEL_NAME,
    low_cpu_mem_usage=False,
)
_validate_no_meta_tensors(_model, "TadaForCausalLM")

logger.info("Models loaded.")
|
|
|
|
| |
| |
| |
|
|
|
|
def _decode_tokens_individually(tokenizer, token_ids: list[int]) -> list[str]:
    """Split a token-ID sequence into per-token display strings.

    The label for token ``i`` is the text it appends to the decoded prefix
    ``tokenizer.decode(token_ids[:i])``. Decoding cumulative prefixes rather
    than single tokens lets the tokenizer apply its context-dependent decoding
    (leading-space handling, byte sequences split across tokens).

    NOTE: when a multi-byte character is split across tokens, an intermediate
    prefix may decode to replacement characters, so labels around such a split
    can be approximate.

    Each decoded prefix is reused as the next token's "before" text, so this
    makes ``len(token_ids) + 1`` decode calls instead of ``2 * len(token_ids)``.
    """
    labels: list[str] = []
    if not token_ids:
        return labels
    prev_text = tokenizer.decode(token_ids[:0], skip_special_tokens=True)
    for end in range(1, len(token_ids) + 1):
        text = tokenizer.decode(token_ids[:end], skip_special_tokens=True)
        labels.append(text[len(prev_text) :])
        prev_text = text
    return labels
|
|
|
|
def _format_token_alignment(prompt: EncoderOutput) -> str:
    """Render the encoded prompt's text-token alignment as an HTML snippet.

    Every text token becomes a bold, highlighted span. The gap between one
    token's position and the next is drawn as grey dots, one per position
    skipped. A small header shows the token count and prompt audio length.
    Returns "" when the prompt has no text tokens or no token positions.
    """
    tokens = prompt.text_tokens
    token_pos = prompt.token_positions
    if tokens is None or token_pos is None:
        return ""

    # Respect the recorded length when present; otherwise use the full width.
    if prompt.text_tokens_len is None:
        count = tokens.shape[1]
    else:
        count = int(prompt.text_tokens_len[0].item())
    ids = tokens[0, :count].cpu().tolist()
    pos_list = token_pos[0, :count].cpu().long().tolist()
    pieces = _decode_tokens_individually(_encoder.tokenizer, ids)

    duration = 0.0
    if prompt.audio.numel() > 0:
        duration = prompt.audio.shape[-1] / prompt.sample_rate
    summary = f"{count} tokens | {duration:.2f}s audio"

    spans: list[str] = []
    cursor = 0
    for position, piece in zip(pos_list, pieces):
        n_dots = max(0, position - cursor)
        if n_dots > 0:
            spans.append(f'<span style="color:#bbb">{"." * n_dots}</span>')
        escaped = html.escape(piece)
        spans.append(
            f'<span style="color:#1a1a2e; background:#e8e8ff; border-radius:3px; padding:0 2px; font-weight:600">{escaped}</span>'
        )
        # The next gap is measured from the slot right after this token.
        cursor = position + 1

    header = summary
    body = "".join(spans)
    return (
        f'<div style="font-family:monospace; font-size:13px; line-height:1.8; word-break:break-all; '
        f'padding:4px 0">'
        f'<div style="font-size:11px; color:#666; margin-bottom:4px">{header}</div>'
        f"{body}</div>"
    )
|
|
|
|
def _decode_byte_tokens(raw_tokens: list[str]) -> list[str]:
    """Turn tokenizer token strings (e.g. GPT-2 byte-level ``"Ġword"``) into display labels.

    The strings are mapped back to IDs with the model tokenizer and decoded one
    by one via ``_decode_tokens_individually``. This is a best-effort display
    helper: if that fails for any reason, it falls back to the raw strings with
    the byte-level space marker ``"Ġ"`` (U+0120) replaced by a space.
    An empty input is returned as-is.
    """
    if not raw_tokens:
        return raw_tokens
    try:
        tokenizer = _model.tokenizer
        token_ids = tokenizer.convert_tokens_to_ids(raw_tokens)
        return _decode_tokens_individually(tokenizer, token_ids)
    except Exception:
        # Never fail a generation over a label, but leave a trace instead of
        # swallowing the error silently.
        logger.debug("Token decoding failed; falling back to raw token labels", exc_info=True)
        return [t.replace("\u0120", " ") for t in raw_tokens]
|
|
|
|
def _format_step_logs(step_logs: list[dict], audio_duration: float, wall_time: float) -> str:
    """Render generation step logs as an HTML alignment snippet.

    Each entry's ``"token"`` becomes a highlighted span, preceded by one grey
    dot per ``"n_frames_before"``. The header summarises step count, audio
    length, total frames, wall-clock time and real-time factor
    (wall time / audio duration, ``inf`` when the audio is empty).
    Returns "" for an empty log list.
    """
    if not step_logs:
        return ""

    frames_before = [entry.get("n_frames_before", 0) for entry in step_logs]
    if audio_duration > 0:
        rtf = wall_time / audio_duration
    else:
        rtf = float("inf")
    header = (
        f"{len(step_logs)} steps | {audio_duration:.1f}s audio | {sum(frames_before)} frames | "
        f"{wall_time:.1f}s wall | RTF {rtf:.2f}"
    )

    labels = _decode_byte_tokens([entry.get("token", "") for entry in step_logs])

    spans: list[str] = []
    for n_dots, label in zip(frames_before, labels):
        if n_dots > 0:
            spans.append(f'<span style="color:#bbb">{"." * n_dots}</span>')
        escaped = html.escape(label)
        spans.append(
            f'<span style="color:#1a2e1a; background:#e8ffe8; border-radius:3px; padding:0 2px; font-weight:600">{escaped}</span>'
        )

    body = "".join(spans)
    return (
        f'<div style="font-family:monospace; font-size:13px; line-height:1.8; word-break:break-all; '
        f'padding:4px 0">'
        f'<div style="font-size:11px; color:#666; margin-bottom:4px">{header}</div>'
        f"{body}</div>"
    )
|
|
|
|
| |
| |
| |
|
|
|
|
def _encode_prompt(audio_path: str | None) -> tuple[EncoderOutput, str]:
    """Encode the voice prompt at *audio_path* and render its token alignment.

    A missing/empty path yields an empty prompt (zero-shot mode). Otherwise the
    audio is averaged to one channel, peak-normalised to 0.95 and encoded. When
    the file name (with or without the ``tada_preset_`` prefix that the preset
    dropdown adds to its copies) is a key of ``_PROMPT_TRANSCRIPTS``, that
    transcript is passed to the encoder too.

    Returns:
        (prompt, prompt_alignment_html)
    """
    if not audio_path:
        return EncoderOutput.empty(_device), "No audio provided (zero-shot mode)."

    audio, sample_rate = torchaudio.load(audio_path)
    audio = audio.mean(dim=0, keepdim=True)
    # Peak-normalise; the clamp avoids dividing by zero on silent input.
    audio = audio / audio.abs().max().clamp(min=1e-8) * 0.95
    audio = audio.to(_device)

    audio_fname = os.path.basename(audio_path)
    prompt_text = None
    for key in (audio_fname, audio_fname.replace("tada_preset_", "")):
        if key in _PROMPT_TRANSCRIPTS:
            prompt_text = _PROMPT_TRANSCRIPTS[key]
            break

    text_kwarg = [prompt_text] if prompt_text else None
    prompt = _encoder(audio, text=text_kwarg, sample_rate=sample_rate)
    return prompt, _format_token_alignment(prompt)


def _text_step_logs(output, text: str, normalize_text: bool) -> list[dict]:
    """Return step logs restricted to the input-text token window.

    The window ends ``config.shift_acoustic`` positions before the end of
    ``output.input_text_ids[0]`` and spans as many tokens as the (optionally
    normalised) *text* encodes to. Steps in the window without a log entry get a
    placeholder with zero frames, marked "prefilled". Falls back to all step
    logs when *text* is empty or the output has no ``input_text_ids``.
    """
    all_logs = output.step_logs or []
    if not text or output.input_text_ids is None:
        return all_logs

    input_ids = output.input_text_ids[0]
    seq_len = input_ids.shape[0]
    n_eos = _model.config.shift_acoustic
    normalized = normalize_text_fn(text) if normalize_text else text
    n_text_tokens = len(_model.tokenizer.encode(normalized, add_special_tokens=False))
    text_end = seq_len - n_eos
    # Clamp: if re-tokenising the text yields more tokens than the window holds,
    # a negative start would silently wrap around to the end of the sequence.
    text_start = max(0, text_end - n_text_tokens)

    log_by_step = {e["step"]: e for e in all_logs}
    text_logs = []
    for s in range(text_start, text_end):
        if s in log_by_step:
            text_logs.append(log_by_step[s])
            continue
        token_id = input_ids[s].item()
        token_str = _model.tokenizer.convert_ids_to_tokens([token_id])[0]
        text_logs.append({
            "step": s,
            "token": token_str,
            "n_frames_before": 0,
            "n_frames_after": 0,
            "n_frames_src": "prefilled",
            "acoustic_mask": 1,
            "acoustic_feat_src": "prefilled",
            "acoustic_feat_norm": 0.0,
        })
    return text_logs


@gpu_decorator(duration=120)
@torch.inference_mode()
def generate(
    audio_path: str | None,
    text: str,
    num_extra_steps: float = 0,
    noise_temperature: float = 0.9,
    acoustic_cfg_scale: float = 2.0,
    duration_cfg_scale: float = 2.0,
    num_flow_matching_steps: float = 20,
    negative_condition_source: str = "negative_step_output",
    text_only_logit_scale: float = 0.0,
    num_acoustic_candidates: float = 1,
    scorer: str = "likelihood",
    spkr_verification_weight: float = 1.0,
    speed_up_factor: float = 0.0,
    normalize_text: bool = True,
) -> tuple[str | None, str, str]:
    """Encode the voice prompt and synthesise *text* in a single GPU call.

    Slider values may arrive as floats; counts are cast to ``int`` and scales
    to ``float`` before reaching the model. ``speed_up_factor <= 0`` disables
    speed-up.

    Returns:
        (wav_path, prompt_alignment_html, generated_alignment_html); the WAV is
        written at 24 kHz to a unique file in the system temp directory.

    Raises:
        gr.Error: if prompt encoding or generation fails.
    """
    _encoder.to(_device)
    _model.to(_device)
    _model.decoder.to(_device)

    try:
        prompt, prompt_html = _encode_prompt(audio_path)
    except Exception as e:
        logger.exception("Prompt encoding failed")
        raise gr.Error(f"Prompt encoding failed: {e}") from e

    try:
        logger.info("Generating speech for text: %s", text)
        suf = float(speed_up_factor) if speed_up_factor > 0 else None

        t0 = time.perf_counter()
        output = _model.generate(
            prompt=prompt,
            text=text,
            num_transition_steps=0,
            num_extra_steps=int(num_extra_steps),
            normalize_text=normalize_text,
            inference_options=InferenceOptions(
                acoustic_cfg_scale=float(acoustic_cfg_scale),
                duration_cfg_scale=float(duration_cfg_scale),
                num_flow_matching_steps=int(num_flow_matching_steps),
                noise_temperature=float(noise_temperature),
                speed_up_factor=suf,
                time_schedule="logsnr",
                negative_condition_source=negative_condition_source,
                text_only_logit_scale=float(text_only_logit_scale),
                num_acoustic_candidates=int(num_acoustic_candidates),
                scorer=scorer,
                spkr_verification_weight=float(spkr_verification_weight),
            ),
            system_prompt="",
        )
        wall_time = time.perf_counter() - t0

        wav = output.audio[0].detach().cpu().float()
        if wav.dim() == 1:
            wav = wav.unsqueeze(0)

        # mkstemp gives every call its own file. A name built from id(output)
        # can repeat once the object is freed, so concurrent requests could
        # overwrite each other's output.
        fd, tmp_path = tempfile.mkstemp(prefix="tada_output_", suffix=".wav")
        os.close(fd)
        torchaudio.save(tmp_path, wav, 24_000)
        audio_duration = wav.shape[-1] / 24_000

        generated_logs = _text_step_logs(output, text, normalize_text)
        generated_html = _format_step_logs(generated_logs, audio_duration, wall_time)

        return tmp_path, prompt_html, generated_html

    except gr.Error:
        raise
    except Exception as e:
        logger.exception("Generation failed")
        raise gr.Error(f"Generation failed: {e}") from e
|
|
|
|
| |
| |
| |
|
|
|
|
def build_ui() -> gr.Blocks:
    """Assemble the Gradio interface for TADA inference.

    Three columns: generation settings (text and acoustic accordions), the
    voice prompt (preset picker, audio preview, prompt token alignment), and
    the text to speak (transcript picker, text box, Generate button, generated
    audio and alignment). Generate calls ``generate`` with the widgets in
    ``all_inputs``, whose order matches ``generate``'s parameters.

    Returns:
        The constructed ``gr.Blocks`` app (not launched).
    """
    # Show the checkpoint that is actually loaded instead of a hard-coded name
    # (the header previously advertised a different model than _MODEL_NAME).
    model_label = _MODEL_NAME.split("/")[-1]

    with gr.Blocks(
        title="TADA Inference",
        css=(
            ".gradio-container { max-width: 1400px !important; width: 100% !important; margin: auto !important; } "
            ".compact-audio { min-height: 0 !important; } "
            ".compact-audio audio { height: 36px !important; } "
        ),
    ) as demo:
        gr.Markdown(
            "# TADA - Text-Acoustic Dual Alignment LLM\n"
            f"A demo of **{model_label}** \u2014 "
            "a text-to-speech model that clones voice, emotion, and timing from a short audio prompt.\n\n"
            "**How to use:** Choose a voice prompt (or upload your own), enter text, and click **Generate**. "
            "The model will encode the prompt and generate speech in one step."
        )

        with gr.Row(equal_height=False):
            # Column 1: generation settings.
            with gr.Column(scale=1):
                with gr.Accordion("Text Settings", open=False):
                    num_extra_steps = gr.Slider(
                        minimum=0, maximum=200, value=0, step=1,
                        label="Text Tokens to Generate",
                    )
                    text_only_logit_scale = gr.Slider(
                        minimum=0.0, maximum=5.0, value=0.0, step=0.1,
                        label="Text-Only Logit Scale",
                        info="0 = disabled. Blends text-only logits with audio-conditioned logits.",
                    )
                    normalize_text_cb = gr.Checkbox(
                        value=True,
                        label="Normalize Text",
                        info="Apply text normalization before generation",
                    )

                with gr.Accordion("Acoustic Settings", open=False):
                    acoustic_cfg_scale = gr.Slider(
                        minimum=1.0, maximum=3.0, value=1.6, step=0.1,
                        label="Acoustic CFG Scale",
                    )
                    duration_cfg_scale = gr.Slider(
                        minimum=1.0, maximum=3.0, value=1.0, step=0.1,
                        label="Duration CFG Scale",
                    )
                    negative_condition_source = gr.Dropdown(
                        choices=["negative_step_output", "prompt", "zero"],
                        value="negative_step_output",
                        label="Negative Condition Source",
                    )
                    noise_temperature = gr.Slider(
                        minimum=0.4, maximum=1.2, value=0.9, step=0.1,
                        label="Noise Temperature",
                    )
                    num_flow_matching_steps = gr.Slider(
                        minimum=5, maximum=50, value=20, step=5,
                        label="Flow Matching Steps",
                    )
                    speed_up_factor = gr.Slider(
                        minimum=0.0, maximum=3.0, value=0.0, step=0.1,
                        label="Speed Up Factor",
                        info="0 = disabled (natural duration). >0 scales speech speed.",
                    )
                    num_acoustic_candidates = gr.Slider(
                        minimum=1, maximum=16, value=1, step=1,
                        label="Acoustic Candidates",
                        info="Number of candidates to generate and rank.",
                    )
                    scorer_dropdown = gr.Dropdown(
                        choices=["likelihood", "spkr_verification", "duration_median"],
                        value="likelihood",
                        label="Scorer",
                        info="How to rank acoustic candidates.",
                    )
                    spkr_verification_weight = gr.Slider(
                        minimum=0.0, maximum=5.0, value=1.0, step=0.1,
                        label="Speaker Verification Weight",
                        info="Weight for spkr_verification scorer.",
                    )

            # Column 2: voice prompt.
            with gr.Column(scale=2):
                preset_choices = ["None (zero-shot)"] + list(_PRESET_SAMPLES.keys())
                _default_voice = "fb_ears_emo_amusement_freeform.wav"
                preset_dropdown = gr.Dropdown(
                    choices=preset_choices,
                    value=_default_voice if _default_voice in _PRESET_SAMPLES else "None (zero-shot)",
                    label="Voice Prompt",
                    info="Pick a preset or upload / record your own",
                )
                _default_voice_path = _PRESET_SAMPLES.get(_default_voice)
                audio_input = gr.Audio(
                    label="Prompt Preview",
                    type="filepath",
                    sources=["upload", "microphone"],
                    value=_default_voice_path,
                    elem_classes=["compact-audio"],
                )

                def _on_preset_selected(choice: str) -> str | None:
                    """Copy the chosen preset to the temp dir and return the copy's path.

                    The copy is named ``tada_preset_<choice>``; generate() strips that
                    prefix when looking up the prompt transcript. Returns None for
                    zero-shot or an unknown choice.
                    """
                    if choice == "None (zero-shot)":
                        return None
                    path = _PRESET_SAMPLES.get(choice)
                    if path is None:
                        return None
                    tmp_path = os.path.join(tempfile.gettempdir(), f"tada_preset_{choice}")
                    shutil.copy2(path, tmp_path)
                    return tmp_path

                preset_dropdown.change(
                    fn=_on_preset_selected,
                    inputs=[preset_dropdown],
                    outputs=[audio_input],
                )

                with gr.Accordion("Prompt Token Alignment", open=True):
                    prompt_alignment = gr.HTML(value="Generate to see prompt alignment.")

            # Column 3: text to speak and generated output.
            with gr.Column(scale=2):
                _default_transcript = "emo_interest_sentences"
                transcript_choices = ["(custom)"] + list(_PRESET_TRANSCRIPTS.keys())
                transcript_dropdown = gr.Dropdown(
                    choices=transcript_choices,
                    value=_default_transcript if _default_transcript in _PRESET_TRANSCRIPTS else "(custom)",
                    label="Transcript",
                    info="Pick a preset or type your own below",
                )
                text_input = gr.Textbox(
                    label="Text to Speak",
                    placeholder="Type what you want the model to say ...",
                    autoscroll=False,
                    max_lines=20,
                    value=_PRESET_TRANSCRIPTS.get(_default_transcript, ""),
                )

                def _on_transcript_selected(choice: str) -> str:
                    """Return the preset transcript text, or "" for "(custom)"/unknown choices."""
                    if choice == "(custom)":
                        return ""
                    return _PRESET_TRANSCRIPTS.get(choice, "")

                transcript_dropdown.change(
                    fn=_on_transcript_selected,
                    inputs=[transcript_dropdown],
                    outputs=[text_input],
                )

                generate_btn = gr.Button("Generate", variant="primary", size="lg")

                audio_output = gr.Audio(label="Generated Audio")
                with gr.Accordion("Generated Alignment", open=False):
                    generated_text_display = gr.HTML(value="Generate speech to see the alignment")

        # Order must match generate()'s positional parameters.
        all_inputs = [
            audio_input,
            text_input,
            num_extra_steps,
            noise_temperature,
            acoustic_cfg_scale,
            duration_cfg_scale,
            num_flow_matching_steps,
            negative_condition_source,
            text_only_logit_scale,
            num_acoustic_candidates,
            scorer_dropdown,
            spkr_verification_weight,
            speed_up_factor,
            normalize_text_cb,
        ]

        generate_btn.click(
            fn=generate,
            inputs=all_inputs,
            outputs=[audio_output, prompt_alignment, generated_text_display],
        )

    return demo
|
|
|
|
| |
| |
| |
|
|
# Environment-driven defaults; when run as a script, CLI flags take these as
# their defaults.
_share = os.environ.get("GRADIO_SHARE", "").lower() in ("1", "true", "yes")
_port = int(os.environ.get("GRADIO_PORT", "7860"))

demo = build_ui()

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="TADA Inference Gradio App")
    parser.add_argument("--share", action="store_true", default=_share, help="Create a public Gradio share link")
    parser.add_argument("--port", type=int, default=_port, help="Server port (default: 7860)")
    args = parser.parse_args()
    _launch_share, _launch_port = args.share, args.port
else:
    # Imported rather than run directly: launch with the environment settings.
    _launch_share, _launch_port = _share, _port

demo.launch(
    server_name="0.0.0.0",
    server_port=_launch_port,
    share=_launch_share,
    allowed_paths=[_SAMPLES_DIR],
)
|
|