"""
Gradio app for TADA inference (English-only, single model).

Usage:
    pip install hume-tada
    python app.py
    # or with hot reload + share link:
    GRADIO_SHARE=1 gradio app.py
"""

import html
import json
import logging
import os
import shutil
import tempfile
import time

import torch
import torchaudio
import gradio as gr

try:
    import spaces

    gpu_decorator = spaces.GPU
except ImportError:

    def gpu_decorator(fn=None, **kw):
        """No-op stand-in for ``spaces.GPU`` when the ``spaces`` package is absent.

        Works both as ``@gpu_decorator`` and as ``@gpu_decorator(duration=...)``:
        keyword arguments are accepted and ignored.
        """
        return fn if fn else (lambda f: f)


from tada.modules.encoder import Encoder, EncoderOutput  # noqa: E402
from tada.modules.tada import InferenceOptions, TadaForCausalLM  # noqa: E402
from tada.utils.text import normalize_text as normalize_text_fn  # noqa: E402

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Preset samples & transcripts (English only)
# ---------------------------------------------------------------------------

_script_dir = os.path.dirname(os.path.abspath(__file__))
_SAMPLES_DIR = os.path.join(_script_dir, "samples")
_AUDIO_EXTENSIONS = (".wav", ".mp3", ".flac")


def _discover_preset_samples() -> dict[str, str]:
    """Return {display_name: absolute_path} for audio files in samples/en/.

    Returns an empty dict when the directory does not exist. Entries are
    inserted in sorted filename order.
    """
    search_dir = os.path.join(_SAMPLES_DIR, "en")
    if not os.path.isdir(search_dir):
        return {}
    return {
        fname: os.path.join(search_dir, fname)
        for fname in sorted(os.listdir(search_dir))
        # str.endswith accepts a tuple, so one call covers every extension.
        if fname.lower().endswith(_AUDIO_EXTENSIONS)
    }


def _load_json_mapping(filename: str) -> dict[str, str]:
    """Load ``samples/en/<filename>`` as JSON; return {} when the file is missing."""
    candidate = os.path.join(_SAMPLES_DIR, "en", filename)
    if not os.path.isfile(candidate):
        return {}
    # Explicit UTF-8: without it the decoding depends on the host locale.
    with open(candidate, encoding="utf-8") as f:
        return json.load(f)


def _load_preset_transcripts() -> dict[str, str]:
    """Load preset transcripts from synth_transcripts.json."""
    return _load_json_mapping("synth_transcripts.json")


def _load_prompt_transcripts() -> dict[str, str]:
    """Load prompt transcripts from prompt_transcripts.json."""
    return _load_json_mapping("prompt_transcripts.json")


_PRESET_SAMPLES = _discover_preset_samples()
_PRESET_TRANSCRIPTS = _load_preset_transcripts()
_PROMPT_TRANSCRIPTS = _load_prompt_transcripts()
logger.info("Discovered %d preset audio samples, %d transcripts", len(_PRESET_SAMPLES), len(_PRESET_TRANSCRIPTS))

# ---------------------------------------------------------------------------
# Global model state — single model, single encoder
# ---------------------------------------------------------------------------

_MODEL_NAME = "HumeAI/tada-3b-ml"
_device = "cuda"

# Sample rate used when writing the generated waveform to disk.
_OUTPUT_SAMPLE_RATE = 24_000


def _validate_no_meta_tensors(model, name: str = "model"):
    """Raise if any parameter is on the meta device (not materialised).

    Raises:
        RuntimeError: naming the first offending parameter.
    """
    for param_name, param in model.named_parameters():
        if param.device.type == "meta":
            raise RuntimeError(
                f"{name} has meta-device parameter: {param_name}. "
                "Pass low_cpu_mem_usage=False to from_pretrained()."
            )


logger.info("Loading encoder ...")
_encoder = Encoder.from_pretrained("HumeAI/tada-codec", language=None, low_cpu_mem_usage=False).to(_device)
_validate_no_meta_tensors(_encoder, "Encoder")

logger.info("Loading %s ...", _MODEL_NAME)
# The language model is not moved to the device here; generate() does that on each call.
_model = TadaForCausalLM.from_pretrained(_MODEL_NAME, low_cpu_mem_usage=False)
_validate_no_meta_tensors(_model, "TadaForCausalLM")
logger.info("Models loaded.")

# ---------------------------------------------------------------------------
# Core inference helpers
# ---------------------------------------------------------------------------

# Inline styles for the alignment views.
# NOTE(review): the markup of the alignment formatters was lost from the source
# this file was restored from; these styles are a reconstruction that follows
# the formatters' docstrings (grey dots, bold coloured tokens). Adjust freely.
_ALIGN_WRAPPER_STYLE = (
    "font-family:monospace;font-size:13px;line-height:1.6;white-space:pre-wrap;word-break:break-all;"
)
_ALIGN_HEADER_STYLE = "color:#6b7280;margin-bottom:4px;"
_ALIGN_GAP_STYLE = "color:#9ca3af;"
_ALIGN_TOKEN_STYLE = "color:#2563eb;font-weight:bold;"


def _decode_tokens_individually(tokenizer, token_ids: list[int]) -> list[str]:
    """Decode a list of token IDs into per-token strings, handling multi-byte characters.

    Each label is the text that decoding one more token adds to the decoded
    prefix, so the labels concatenate back to the full decoded string.
    """
    if not token_ids:
        return []
    labels: list[str] = []
    # The prefix decoded at step i is exactly the full decode of step i - 1,
    # so carry it forward instead of decoding every prefix twice.
    decoded_prefix = tokenizer.decode(token_ids[:0], skip_special_tokens=True)
    for end in range(1, len(token_ids) + 1):
        decoded = tokenizer.decode(token_ids[:end], skip_special_tokens=True)
        labels.append(decoded[len(decoded_prefix) :])
        decoded_prefix = decoded
    return labels


def _gap_span(count: int) -> str:
    """Return a grey run of ``count`` dots."""
    dots = "." * count
    return f'<span style="{_ALIGN_GAP_STYLE}">{dots}</span>'


def _token_span(label: str) -> str:
    """Return ``label`` HTML-escaped inside a bold coloured span."""
    return f'<span style="{_ALIGN_TOKEN_STYLE}">{html.escape(label)}</span>'


def _wrap_alignment_html(header: str, body: str) -> str:
    """Wrap an already-built HTML ``body`` under a plain-text ``header`` line."""
    return (
        f'<div style="{_ALIGN_WRAPPER_STYLE}">'
        f'<div style="{_ALIGN_HEADER_STYLE}">{html.escape(header)}</div>'
        f"{body}</div>"
    )


def _format_token_alignment(prompt: EncoderOutput) -> str:
    """Build an HTML string: dots in grey, tokens as bold coloured spans.

    Only the first batch item of ``prompt`` is rendered. Returns "" when the
    prompt carries no text tokens or no token positions.
    """
    if prompt.text_tokens is None or prompt.token_positions is None:
        return ""

    tokenizer = _encoder.tokenizer
    if prompt.text_tokens_len is not None:
        n_tokens = int(prompt.text_tokens_len[0].item())
    else:
        n_tokens = prompt.text_tokens.shape[1]

    token_ids = prompt.text_tokens[0, :n_tokens].cpu().tolist()
    positions = prompt.token_positions[0, :n_tokens].cpu().long().tolist()
    labels = _decode_tokens_individually(tokenizer, token_ids)

    audio_dur = prompt.audio.shape[-1] / prompt.sample_rate if prompt.audio.numel() > 0 else 0.0
    header = f"{n_tokens} tokens | {audio_dur:.2f}s audio"

    parts: list[str] = []
    prev_pos = 0
    for pos, label in zip(positions, labels):
        # One dot per position skipped since the previous token.
        gap = max(0, pos - prev_pos)
        if gap > 0:
            parts.append(_gap_span(gap))
        parts.append(_token_span(label))
        prev_pos = pos + 1

    return _wrap_alignment_html(header, "".join(parts))


def _decode_byte_tokens(raw_tokens: list[str]) -> list[str]:
    """Decode GPT-2 byte-level token strings into proper Unicode per-token labels.

    Best effort: if the tokenizer round-trip fails for any reason, fall back to
    replacing the byte-level space marker (U+0120) with a plain space.
    """
    if not raw_tokens:
        return raw_tokens
    try:
        tokenizer = _model.tokenizer
        token_ids = tokenizer.convert_tokens_to_ids(raw_tokens)
        return _decode_tokens_individually(tokenizer, token_ids)
    except Exception:
        # Display-only helper: never let a decoding problem fail the request.
        logger.debug("Falling back to raw token labels", exc_info=True)
        return [t.replace("\u0120", " ") for t in raw_tokens]


def _format_step_logs(step_logs: list[dict], audio_duration: float, wall_time: float) -> str:
    """Build an HTML string from step_logs: dots for n_frames_before, tokens highlighted.

    Args:
        step_logs: per-step dicts; the "token" and "n_frames_before" keys are read.
        audio_duration: length of the generated audio in seconds.
        wall_time: generation wall-clock time in seconds.

    Returns:
        The HTML string, or "" when ``step_logs`` is empty.
    """
    if not step_logs:
        return ""

    n_tokens = len(step_logs)
    total_frames = sum(entry.get("n_frames_before", 0) for entry in step_logs)
    # Real-time factor; infinite when no audio was produced.
    rtf = wall_time / audio_duration if audio_duration > 0 else float("inf")
    header = (
        f"{n_tokens} steps | {audio_duration:.1f}s audio | {total_frames} frames"
        f" | {wall_time:.1f}s wall | RTF {rtf:.2f}"
    )

    raw_tokens = [entry.get("token", "") for entry in step_logs]
    labels = _decode_byte_tokens(raw_tokens)

    parts: list[str] = []
    for entry, label in zip(step_logs, labels):
        n_frames = entry.get("n_frames_before", 0)
        if n_frames > 0:
            parts.append(_gap_span(n_frames))
        parts.append(_token_span(label))

    return _wrap_alignment_html(header, "".join(parts))


# ---------------------------------------------------------------------------
# Single generate function (merged prompt encoding + generation)
# ---------------------------------------------------------------------------


def _lookup_prompt_transcript(audio_path: str) -> str | None:
    """Return the known transcript for a preset prompt file, or None."""
    audio_fname = os.path.basename(audio_path)
    # Presets picked in the UI are copied to a temp file named "tada_preset_<name>",
    # so try the name both as-is and with that prefix removed.
    for key in (audio_fname, audio_fname.removeprefix("tada_preset_")):
        if key in _PROMPT_TRANSCRIPTS:
            return _PROMPT_TRANSCRIPTS[key]
    return None


def _encode_prompt(audio_path: str | None) -> tuple[EncoderOutput, str]:
    """Encode the voice prompt; return (prompt, prompt_alignment_html)."""
    if audio_path is None or audio_path == "":
        return EncoderOutput.empty(_device), "No audio provided (zero-shot mode)."

    audio, sample_rate = torchaudio.load(audio_path)
    audio = audio.mean(dim=0, keepdim=True)  # mono
    # Peak-normalise to 0.95; the clamp avoids dividing by zero on silent input.
    audio = audio / audio.abs().max().clamp(min=1e-8) * 0.95
    audio = audio.to(_device)

    # Look up prompt transcript for preset samples
    prompt_text = _lookup_prompt_transcript(audio_path)
    text_kwarg = [prompt_text] if prompt_text else None
    prompt = _encoder(audio, text=text_kwarg, sample_rate=sample_rate)
    return prompt, _format_token_alignment(prompt)


def _save_output_wav(wav: torch.Tensor) -> str:
    """Write ``wav`` to a fresh temp file and return its path."""
    # A uniquely named file per call: the previous id()-based name could repeat
    # across requests, letting one request overwrite another's output.
    with tempfile.NamedTemporaryFile(prefix="tada_output_", suffix=".wav", delete=False) as tmp:
        tmp_path = tmp.name
    torchaudio.save(tmp_path, wav, _OUTPUT_SAMPLE_RATE)
    return tmp_path


def _select_text_step_logs(output, text: str, normalize_text: bool) -> list[dict]:
    """Return the step logs covering the text-to-speak tokens.

    Falls back to all of ``output.step_logs`` when ``text`` is empty or the
    output carries no input token ids. Steps without a log entry get a
    synthetic "prefilled" entry built from the input token at that position.
    """
    all_logs = output.step_logs or []
    if not text or output.input_text_ids is None:
        return all_logs

    input_ids = output.input_text_ids[0]
    seq_len = input_ids.shape[0]
    # The last `shift_acoustic` positions are excluded from the text span
    # (the count was named n_eos here). NOTE(review): confirm they are EOS padding.
    n_eos = _model.config.shift_acoustic
    normalized = normalize_text_fn(text) if normalize_text else text
    n_text_tokens = len(_model.tokenizer.encode(normalized, add_special_tokens=False))
    text_end = seq_len - n_eos
    # Clamp: a negative start would silently index input_ids from the end.
    text_start = max(0, text_end - n_text_tokens)

    log_by_step = {entry["step"]: entry for entry in all_logs}
    text_logs: list[dict] = []
    for step in range(text_start, text_end):
        entry = log_by_step.get(step)
        if entry is None:
            token_id = input_ids[step].item()
            token_str = _model.tokenizer.convert_ids_to_tokens([token_id])[0]
            entry = {
                "step": step,
                "token": token_str,
                "n_frames_before": 0,
                "n_frames_after": 0,
                "n_frames_src": "prefilled",
                "acoustic_mask": 1,
                "acoustic_feat_src": "prefilled",
                "acoustic_feat_norm": 0.0,
            }
        text_logs.append(entry)
    return text_logs


@gpu_decorator(duration=120)
@torch.inference_mode()
def generate(
    audio_path: str | None,
    text: str,
    num_extra_steps: float = 0,
    noise_temperature: float = 0.9,
    acoustic_cfg_scale: float = 2.0,
    duration_cfg_scale: float = 2.0,
    num_flow_matching_steps: float = 20,
    negative_condition_source: str = "negative_step_output",
    text_only_logit_scale: float = 0.0,
    num_acoustic_candidates: float = 1,
    scorer: str = "likelihood",
    spkr_verification_weight: float = 1.0,
    speed_up_factor: float = 0.0,
    normalize_text: bool = True,
) -> tuple[str | None, str, str]:
    """Encode prompt + generate speech in a single GPU call.

    Args:
        audio_path: path of the voice prompt; None or "" selects zero-shot mode.
        text: the text to speak.
        speed_up_factor: values <= 0 disable the option (None is passed on).
        normalize_text: forwarded to the model and applied when locating the
            text tokens for the alignment view.

        The remaining parameters are cast to int/float and forwarded to
        ``InferenceOptions`` / ``_model.generate`` unchanged.

    Returns:
        (wav_path, prompt_alignment_html, generated_alignment_html).

    Raises:
        gr.Error: when generation or post-processing fails.
    """
    # Move model + encoder to GPU
    _encoder.to(_device)
    _model.to(_device)
    _model.decoder.to(_device)

    # --- Encode prompt ---
    prompt, prompt_html = _encode_prompt(audio_path)

    # --- Generate speech ---
    try:
        logger.info("Generating speech for text: %s", text)
        suf = float(speed_up_factor) if speed_up_factor > 0 else None

        t0 = time.time()
        output = _model.generate(
            prompt=prompt,
            text=text,
            num_transition_steps=0,
            num_extra_steps=int(num_extra_steps),
            normalize_text=normalize_text,
            inference_options=InferenceOptions(
                acoustic_cfg_scale=float(acoustic_cfg_scale),
                duration_cfg_scale=float(duration_cfg_scale),
                num_flow_matching_steps=int(num_flow_matching_steps),
                noise_temperature=float(noise_temperature),
                speed_up_factor=suf,
                time_schedule="logsnr",
                negative_condition_source=negative_condition_source,
                text_only_logit_scale=float(text_only_logit_scale),
                num_acoustic_candidates=int(num_acoustic_candidates),
                scorer=scorer,
                spkr_verification_weight=float(spkr_verification_weight),
            ),
            system_prompt="",
        )
        wall_time = time.time() - t0

        wav = output.audio[0].detach().cpu().float()
        if wav.dim() == 1:
            # torchaudio.save is given a 2-D tensor, so add a leading dimension.
            wav = wav.unsqueeze(0)
        tmp_path = _save_output_wav(wav)
        audio_duration = wav.shape[-1] / _OUTPUT_SAMPLE_RATE

        # Extract text-to-speak step_logs
        generated_logs = _select_text_step_logs(output, text, normalize_text)
        generated_html = _format_step_logs(generated_logs, audio_duration, wall_time)
        return tmp_path, prompt_html, generated_html
    except gr.Error:
        raise
    except Exception as e:
        logger.exception("Generation failed")
        raise gr.Error(f"Generation failed: {e}") from e


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------


def build_ui() -> gr.Blocks:
    """Build the Gradio Blocks app and wire the Generate button to ``generate``.

    Layout: a settings column (text and acoustic options), a voice-prompt
    column, and a transcript/output column.
    """
    # Sentinel dropdown entries meaning "no preset selected".
    no_preset = "None (zero-shot)"
    custom_transcript = "(custom)"

    with gr.Blocks(
        title="TADA Inference",
        css=(
            ".gradio-container { max-width: 1400px !important; width: 100% !important; margin: auto !important; } "
            ".compact-audio { min-height: 0 !important; } "
            ".compact-audio audio { height: 36px !important; } "
        ),
    ) as demo:
        gr.Markdown(
            "# TADA - Text-Acoustic Dual Alignment LLM\n"
            "A demo of **tada-3b-ml** \u2014 "
            "a text-to-speech model that clones voice, emotion, and timing from a short audio prompt.\n\n"
            "**How to use:** Choose a voice prompt (or upload your own), enter text, and click **Generate**. "
            "The model will encode the prompt and generate speech in one step."
        )

        with gr.Row(equal_height=False):
            with gr.Column(scale=1):
                with gr.Accordion("Text Settings", open=False):
                    num_extra_steps = gr.Slider(
                        minimum=0,
                        maximum=200,
                        value=0,
                        step=1,
                        label="Text Tokens to Generate",
                    )
                    text_only_logit_scale = gr.Slider(
                        minimum=0.0,
                        maximum=5.0,
                        value=0.0,
                        step=0.1,
                        label="Text-Only Logit Scale",
                        info="0 = disabled. Blends text-only logits with audio-conditioned logits.",
                    )
                    normalize_text_cb = gr.Checkbox(
                        value=True,
                        label="Normalize Text",
                        info="Apply text normalization before generation",
                    )

                with gr.Accordion("Acoustic Settings", open=False):
                    acoustic_cfg_scale = gr.Slider(
                        minimum=1.0,
                        maximum=3.0,
                        value=1.6,
                        step=0.1,
                        label="Acoustic CFG Scale",
                    )
                    duration_cfg_scale = gr.Slider(
                        minimum=1.0,
                        maximum=3.0,
                        value=1.0,
                        step=0.1,
                        label="Duration CFG Scale",
                    )
                    negative_condition_source = gr.Dropdown(
                        choices=["negative_step_output", "prompt", "zero"],
                        value="negative_step_output",
                        label="Negative Condition Source",
                    )
                    noise_temperature = gr.Slider(
                        minimum=0.4,
                        maximum=1.2,
                        value=0.9,
                        step=0.1,
                        label="Noise Temperature",
                    )
                    num_flow_matching_steps = gr.Slider(
                        minimum=5,
                        maximum=50,
                        value=20,
                        step=5,
                        label="Flow Matching Steps",
                    )
                    speed_up_factor = gr.Slider(
                        minimum=0.0,
                        maximum=3.0,
                        value=0.0,
                        step=0.1,
                        label="Speed Up Factor",
                        info="0 = disabled (natural duration). >0 scales speech speed.",
                    )
                    num_acoustic_candidates = gr.Slider(
                        minimum=1,
                        maximum=16,
                        value=1,
                        step=1,
                        label="Acoustic Candidates",
                        info="Number of candidates to generate and rank.",
                    )
                    scorer_dropdown = gr.Dropdown(
                        choices=["likelihood", "spkr_verification", "duration_median"],
                        value="likelihood",
                        label="Scorer",
                        info="How to rank acoustic candidates.",
                    )
                    spkr_verification_weight = gr.Slider(
                        minimum=0.0,
                        maximum=5.0,
                        value=1.0,
                        step=0.1,
                        label="Speaker Verification Weight",
                        info="Weight for spkr_verification scorer.",
                    )

            with gr.Column(scale=2):
                preset_choices = [no_preset] + list(_PRESET_SAMPLES.keys())
                _default_voice = "fb_ears_emo_amusement_freeform.wav"
                preset_dropdown = gr.Dropdown(
                    choices=preset_choices,
                    value=_default_voice if _default_voice in _PRESET_SAMPLES else no_preset,
                    label="Voice Prompt",
                    info="Pick a preset or upload / record your own",
                )
                _default_voice_path = _PRESET_SAMPLES.get(_default_voice)
                audio_input = gr.Audio(
                    label="Prompt Preview",
                    type="filepath",
                    sources=["upload", "microphone"],
                    value=_default_voice_path,
                    elem_classes=["compact-audio"],
                )

                def _on_preset_selected(choice: str) -> str | None:
                    """Copy the chosen preset into the temp dir and return that path."""
                    if choice == no_preset:
                        return None
                    path = _PRESET_SAMPLES.get(choice)
                    if path is None:
                        return None
                    # generate() strips the "tada_preset_" prefix again when it
                    # looks up the prompt transcript.
                    tmp_path = os.path.join(tempfile.gettempdir(), f"tada_preset_{choice}")
                    shutil.copy2(path, tmp_path)
                    return tmp_path

                preset_dropdown.change(
                    fn=_on_preset_selected,
                    inputs=[preset_dropdown],
                    outputs=[audio_input],
                )

                with gr.Accordion("Prompt Token Alignment", open=True):
                    prompt_alignment = gr.HTML(value="Generate to see prompt alignment.")

            with gr.Column(scale=2):
                _default_transcript = "emo_interest_sentences"
                transcript_choices = [custom_transcript] + list(_PRESET_TRANSCRIPTS.keys())
                transcript_dropdown = gr.Dropdown(
                    choices=transcript_choices,
                    value=_default_transcript if _default_transcript in _PRESET_TRANSCRIPTS else custom_transcript,
                    label="Transcript",
                    info="Pick a preset or type your own below",
                )
                text_input = gr.Textbox(
                    label="Text to Speak",
                    placeholder="Type what you want the model to say ...",
                    autoscroll=False,
                    max_lines=20,
                    value=_PRESET_TRANSCRIPTS.get(_default_transcript, ""),
                )

                def _on_transcript_selected(choice: str) -> str:
                    """Return the preset transcript text, or "" for the custom entry."""
                    if choice == custom_transcript:
                        return ""
                    return _PRESET_TRANSCRIPTS.get(choice, "")

                transcript_dropdown.change(
                    fn=_on_transcript_selected,
                    inputs=[transcript_dropdown],
                    outputs=[text_input],
                )

                generate_btn = gr.Button("Generate", variant="primary", size="lg")

                # --- Output ---
                audio_output = gr.Audio(label="Generated Audio")
                with gr.Accordion("Generated Alignment", open=False):
                    generated_text_display = gr.HTML(value="Generate speech to see the alignment")

        # Wire up generate button. The order must match generate()'s positional parameters.
        all_inputs = [
            audio_input,
            text_input,
            num_extra_steps,
            noise_temperature,
            acoustic_cfg_scale,
            duration_cfg_scale,
            num_flow_matching_steps,
            negative_condition_source,
            text_only_logit_scale,
            num_acoustic_candidates,
            scorer_dropdown,
            spkr_verification_weight,
            speed_up_factor,
            normalize_text_cb,
        ]
        generate_btn.click(
            fn=generate,
            inputs=all_inputs,
            outputs=[audio_output, prompt_alignment, generated_text_display],
        )

    return demo


# ---------------------------------------------------------------------------
# Entry-point
# ---------------------------------------------------------------------------

_share = os.environ.get("GRADIO_SHARE", "").lower() in ("1", "true", "yes")
_port = int(os.environ.get("GRADIO_PORT", "7860"))

# `demo` at module scope so the `gradio` CLI / HF Spaces can discover it.
demo = build_ui()

# Resolve the launch settings first, then launch once.
if __name__ == "__main__":
    import argparse

    # Run as a script: command-line flags override the environment defaults.
    parser = argparse.ArgumentParser(description="TADA Inference Gradio App")
    parser.add_argument("--share", action="store_true", default=_share, help="Create a public Gradio share link")
    parser.add_argument("--port", type=int, default=_port, help="Server port (default: 7860)")
    args = parser.parse_args()
    _launch_port, _launch_share = args.port, args.share
else:
    # Imported rather than executed: use the environment-derived defaults.
    # NOTE(review): this launches the server on any import of the module — confirm that is intended.
    _launch_port, _launch_share = _port, _share

demo.launch(
    server_name="0.0.0.0",
    server_port=_launch_port,
    share=_launch_share,
    allowed_paths=[_SAMPLES_DIR],
)