# Spaces: Running on Zero
"""
Gradio app for TADA inference (English-only, single model).

Usage:
    pip install hume-tada
    python app.py

    # or with hot reload + share link:
    GRADIO_SHARE=1 gradio app.py
"""
| import html | |
| import logging | |
| import os | |
| import shutil | |
| import tempfile | |
| import time | |
| import torch | |
| import torchaudio | |
| import gradio as gr | |
# Prefer the `spaces.GPU` decorator when the `spaces` package is importable;
# otherwise install a stand-in that leaves the wrapped function untouched.
try:
    import spaces

    gpu_decorator = spaces.GPU
except ImportError:

    def gpu_decorator(fn=None, **kw):
        """No-op replacement for ``spaces.GPU``.

        Usable bare (``@gpu_decorator``) or called with keyword options
        (``@gpu_decorator(...)``); either way the function comes back as-is.
        """
        if fn:
            return fn
        return lambda f: f
| from tada.modules.encoder import Encoder, EncoderOutput # noqa: E402 | |
| from tada.modules.tada import InferenceOptions, TadaForCausalLM # noqa: E402 | |
| from tada.utils.text import normalize_text as normalize_text_fn # noqa: E402 | |
# Configure the root logger at INFO so the startup progress messages are shown.
logging.basicConfig(level=logging.INFO)
# Module-scoped logger used by the functions in this file.
logger = logging.getLogger(__name__)
| # --------------------------------------------------------------------------- | |
| # Preset samples & transcripts (English only) | |
| # --------------------------------------------------------------------------- | |
# Directory containing this script; bundled assets are resolved relative to it.
_script_dir = os.path.dirname(os.path.abspath(__file__))
# Root folder of the bundled samples (the loaders below read its "en" subfolder).
_SAMPLES_DIR = os.path.join(_script_dir, "samples")
# File suffixes (matched case-insensitively) that count as preset audio.
_AUDIO_EXTENSIONS = (".wav", ".mp3", ".flac")
def _discover_preset_samples() -> dict[str, str]:
    """Map every preset audio filename in samples/en/ to its full path.

    The result is empty when the directory is missing. Entries are inserted
    in sorted filename order; the suffix check ignores case.
    """
    en_dir = os.path.join(_SAMPLES_DIR, "en")
    if not os.path.isdir(en_dir):
        return {}
    return {
        name: os.path.join(en_dir, name)
        for name in sorted(os.listdir(en_dir))
        if name.lower().endswith(_AUDIO_EXTENSIONS)
    }
def _load_preset_transcripts() -> dict[str, str]:
    """Load the preset synthesis transcripts from samples/en/synth_transcripts.json.

    Returns:
        The parsed JSON content, or an empty dict when the file does not exist.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    import json

    candidate = os.path.join(_SAMPLES_DIR, "en", "synth_transcripts.json")
    if not os.path.isfile(candidate):
        return {}
    # Decode explicitly as UTF-8: without it open() falls back to the locale
    # encoding, which can garble or reject non-ASCII transcript text.
    with open(candidate, encoding="utf-8") as f:
        return json.load(f)
def _load_prompt_transcripts() -> dict[str, str]:
    """Load the voice-prompt transcripts from samples/en/prompt_transcripts.json.

    Returns:
        The parsed JSON content, or an empty dict when the file does not exist.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    import json

    candidate = os.path.join(_SAMPLES_DIR, "en", "prompt_transcripts.json")
    if not os.path.isfile(candidate):
        return {}
    # Decode explicitly as UTF-8: without it open() falls back to the locale
    # encoding, which can garble or reject non-ASCII transcript text.
    with open(candidate, encoding="utf-8") as f:
        return json.load(f)
# Scan the samples folder and read both transcript tables once, at import time.
_PRESET_SAMPLES = _discover_preset_samples()
_PRESET_TRANSCRIPTS = _load_preset_transcripts()
_PROMPT_TRANSCRIPTS = _load_prompt_transcripts()
# Only the synthesis transcripts are counted here; prompt transcripts are not reported.
logger.info("Discovered %d preset audio samples, %d transcripts", len(_PRESET_SAMPLES), len(_PRESET_TRANSCRIPTS))
| # --------------------------------------------------------------------------- | |
| # Global model state — single model, single encoder | |
| # --------------------------------------------------------------------------- | |
# Checkpoint identifier handed to TadaForCausalLM.from_pretrained().
_MODEL_NAME = "HumeAI/tada-3b-ml"
# Device string the encoder and model are moved to; hard-coded, with no CPU fallback.
_device = "cuda"
def _validate_no_meta_tensors(model, name: str = "model"):
    """Fail fast when ``model`` still has a parameter on the meta device.

    Args:
        model: Object exposing ``named_parameters()`` yielding (name, param) pairs.
        name: Label used for the model in the error message.

    Raises:
        RuntimeError: Naming the first meta-device parameter encountered.
    """
    offending = next(
        (pname for pname, tensor in model.named_parameters() if tensor.device.type == "meta"),
        None,
    )
    if offending is None:
        return
    raise RuntimeError(
        f"{name} has meta-device parameter: {offending}. "
        "Pass low_cpu_mem_usage=False to from_pretrained()."
    )
# Both networks are loaded here, at import time.
logger.info("Loading encoder ...")
# low_cpu_mem_usage=False is the setting _validate_no_meta_tensors() recommends
# for avoiding meta-device parameters; the encoder is moved to _device right away.
_encoder = Encoder.from_pretrained("HumeAI/tada-codec", language=None, low_cpu_mem_usage=False).to(_device)
_validate_no_meta_tensors(_encoder, "Encoder")
logger.info("Loading %s ...", _MODEL_NAME)
# Unlike the encoder, the language model is not moved to _device here;
# generate() performs that move on each call.
_model = TadaForCausalLM.from_pretrained(_MODEL_NAME, low_cpu_mem_usage=False)
_validate_no_meta_tensors(_model, "TadaForCausalLM")
logger.info("Models loaded.")
| # --------------------------------------------------------------------------- | |
| # Core inference helpers | |
| # --------------------------------------------------------------------------- | |
def _decode_tokens_individually(tokenizer, token_ids: list[int]) -> list[str]:
    """Decode token IDs into one display string per token.

    A token's label is the text that appending it adds to the decoded prefix,
    i.e. ``decode(ids[:i + 1])`` with ``decode(ids[:i])`` stripped from the
    front. Decoding growing prefixes (instead of single tokens) is what lets
    characters spanning several tokens come out intact.

    Args:
        tokenizer: Object exposing ``decode(ids, skip_special_tokens=True)``.
        token_ids: Token IDs in sequence order.

    Returns:
        One string per input token; a label may be the empty string.
    """
    if not token_ids:
        return []
    labels: list[str] = []
    # The decoded text of ids[:i] is exactly the previous iteration's `full`,
    # so carry it forward instead of decoding every prefix twice.
    prefix = tokenizer.decode(token_ids[:0], skip_special_tokens=True)
    for end in range(1, len(token_ids) + 1):
        full = tokenizer.decode(token_ids[:end], skip_special_tokens=True)
        labels.append(full[len(prefix):])
        prefix = full
    return labels
def _format_token_alignment(prompt: EncoderOutput) -> str:
    """Render the prompt's token alignment as a small HTML snippet.

    Each token is drawn as a highlighted span; the positions skipped between
    one token and the next are drawn as grey dots. A header line reports the
    token count and the prompt audio duration.

    Returns an empty string when the prompt carries no text tokens or no
    token positions.
    """
    if prompt.text_tokens is None or prompt.token_positions is None:
        return ""

    # Prefer the explicit length when present, else use the full second dimension.
    if prompt.text_tokens_len is not None:
        token_count = int(prompt.text_tokens_len[0].item())
    else:
        token_count = prompt.text_tokens.shape[1]

    ids = prompt.text_tokens[0, :token_count].cpu().tolist()
    frame_positions = prompt.token_positions[0, :token_count].cpu().long().tolist()
    token_labels = _decode_tokens_individually(_encoder.tokenizer, ids)

    if prompt.audio.numel() > 0:
        duration_s = prompt.audio.shape[-1] / prompt.sample_rate
    else:
        duration_s = 0.0
    header = f"{token_count} tokens | {duration_s:.2f}s audio"

    pieces: list[str] = []
    cursor = 0  # first position not yet accounted for
    for pos, label in zip(frame_positions, token_labels):
        if pos > cursor:
            pieces.append(f'<span style="color:#bbb">{"." * (pos - cursor)}</span>')
        # Token text is escaped so it cannot inject markup.
        pieces.append(
            f'<span style="color:#1a1a2e; background:#e8e8ff; border-radius:3px; padding:0 2px; font-weight:600">{html.escape(label)}</span>'
        )
        cursor = pos + 1

    body = "".join(pieces)
    return (
        f'<div style="font-family:monospace; font-size:13px; line-height:1.8; word-break:break-all; '
        f'padding:4px 0">'
        f'<div style="font-size:11px; color:#666; margin-bottom:4px">{header}</div>'
        f"{body}</div>"
    )
def _decode_byte_tokens(raw_tokens: list[str]) -> list[str]:
    """Turn raw tokenizer token strings into readable per-token labels.

    The tokens are mapped back to IDs with the model tokenizer and decoded
    through ``_decode_tokens_individually``. This is best-effort: if any step
    fails, the failure is logged and a plain fallback is used that only
    replaces the "\u0120" marker with a space.

    Args:
        raw_tokens: Token strings as recorded in the generation step logs.

    Returns:
        One label per input token; an empty input is returned unchanged.
    """
    if not raw_tokens:
        return raw_tokens
    try:
        tokenizer = _model.tokenizer
        token_ids = tokenizer.convert_tokens_to_ids(raw_tokens)
        return _decode_tokens_individually(tokenizer, token_ids)
    except Exception:
        # Labels are display-only, so never fail the request over them, but
        # do record why the proper decoding path was skipped.
        logger.warning("Token label decoding failed; using raw-token fallback", exc_info=True)
        return [t.replace("\u0120", " ") for t in raw_tokens]
def _format_step_logs(step_logs: list[dict], audio_duration: float, wall_time: float) -> str:
    """Render generation step logs as a small HTML snippet.

    Every log entry becomes a highlighted token span, preceded by one grey dot
    per ``n_frames_before``. The header reports the step count, audio length,
    total frame count, wall-clock time and the real-time factor
    (``wall_time / audio_duration``, or ``inf`` when the audio is empty).

    Returns an empty string when there are no step logs.
    """
    if not step_logs:
        return ""

    frames_before = [log.get("n_frames_before", 0) for log in step_logs]
    if audio_duration > 0:
        rtf = wall_time / audio_duration
    else:
        rtf = float("inf")
    header = (
        f"{len(step_logs)} steps | {audio_duration:.1f}s audio | {sum(frames_before)} frames | "
        f"{wall_time:.1f}s wall | RTF {rtf:.2f}"
    )

    token_labels = _decode_byte_tokens([log.get("token", "") for log in step_logs])

    pieces: list[str] = []
    for n_dots, label in zip(frames_before, token_labels):
        if n_dots > 0:
            pieces.append(f'<span style="color:#bbb">{"." * n_dots}</span>')
        # Token text is escaped so it cannot inject markup.
        pieces.append(
            f'<span style="color:#1a2e1a; background:#e8ffe8; border-radius:3px; padding:0 2px; font-weight:600">{html.escape(label)}</span>'
        )

    body = "".join(pieces)
    return (
        f'<div style="font-family:monospace; font-size:13px; line-height:1.8; word-break:break-all; '
        f'padding:4px 0">'
        f'<div style="font-size:11px; color:#666; margin-bottom:4px">{header}</div>'
        f"{body}</div>"
    )
| # --------------------------------------------------------------------------- | |
| # Single generate function (merged prompt encoding + generation) | |
| # --------------------------------------------------------------------------- | |
def _encode_prompt(audio_path):
    """Encode the optional voice prompt; return ``(encoder_output, alignment_html)``."""
    if not audio_path:
        return EncoderOutput.empty(_device), "No audio provided (zero-shot mode)."

    audio, sample_rate = torchaudio.load(audio_path)
    audio = audio.mean(dim=0, keepdim=True)  # downmix to mono
    # Peak-normalise to 0.95; the clamp avoids dividing by zero on silent input.
    audio = audio / audio.abs().max().clamp(min=1e-8) * 0.95
    audio = audio.to(_device)

    # Preset prompts have known transcripts. The preset copy staged by the UI
    # carries a "tada_preset_" filename prefix, so try the name with and
    # without that prefix.
    audio_fname = os.path.basename(audio_path)
    prompt_text = None
    for key in (audio_fname, audio_fname.replace("tada_preset_", "")):
        if key in _PROMPT_TRANSCRIPTS:
            prompt_text = _PROMPT_TRANSCRIPTS[key]
            break
    text_kwarg = [prompt_text] if prompt_text else None

    prompt = _encoder(audio, text=text_kwarg, sample_rate=sample_rate)
    return prompt, _format_token_alignment(prompt)


def _select_text_step_logs(output, text, normalize_text):
    """Return the step logs covering the text-to-speak tokens of ``output``."""
    all_logs = output.step_logs or []
    if not text or output.input_text_ids is None:
        return all_logs

    input_ids = output.input_text_ids[0]
    seq_len = input_ids.shape[0]
    # NOTE(review): shift_acoustic is used as the number of trailing positions
    # after the text tokens - confirm against the model config.
    n_eos = _model.config.shift_acoustic
    normalized = normalize_text_fn(text) if normalize_text else text
    n_text_tokens = len(_model.tokenizer.encode(normalized, add_special_tokens=False))
    text_end = seq_len - n_eos
    # Clamp at 0: a negative start would make input_ids[step] wrap around to
    # the end of the sequence and report the wrong tokens.
    text_start = max(0, text_end - n_text_tokens)

    logs_by_step = {entry["step"]: entry for entry in all_logs}
    text_logs = []
    for step in range(text_start, text_end):
        entry = logs_by_step.get(step)
        if entry is None:
            # No log was recorded for this step: synthesise a placeholder.
            token_id = input_ids[step].item()
            token_str = _model.tokenizer.convert_ids_to_tokens([token_id])[0]
            entry = {
                "step": step,
                "token": token_str,
                "n_frames_before": 0,
                "n_frames_after": 0,
                "n_frames_src": "prefilled",
                "acoustic_mask": 1,
                "acoustic_feat_src": "prefilled",
                "acoustic_feat_norm": 0.0,
            }
        text_logs.append(entry)
    return text_logs


@gpu_decorator
def generate(
    audio_path: str | None,
    text: str,
    num_extra_steps: float = 0,
    noise_temperature: float = 0.9,
    acoustic_cfg_scale: float = 2.0,
    duration_cfg_scale: float = 2.0,
    num_flow_matching_steps: float = 20,
    negative_condition_source: str = "negative_step_output",
    text_only_logit_scale: float = 0.0,
    num_acoustic_candidates: float = 1,
    scorer: str = "likelihood",
    spkr_verification_weight: float = 1.0,
    speed_up_factor: float = 0.0,
    normalize_text: bool = True,
) -> tuple[str | None, str, str]:
    """Encode the voice prompt and generate speech in a single GPU call.

    Wrapped with ``gpu_decorator`` (``spaces.GPU`` when available, otherwise a
    no-op) so that a GPU is attached for the duration of the call.

    Args:
        audio_path: Path of the prompt audio; ``None`` or ``""`` selects
            zero-shot mode with an empty prompt.
        text: Text to speak.
        num_extra_steps, num_flow_matching_steps, num_acoustic_candidates:
            Numeric UI values, cast to ``int`` before use.
        noise_temperature, acoustic_cfg_scale, duration_cfg_scale,
        text_only_logit_scale, spkr_verification_weight: Cast to ``float`` and
            forwarded through ``InferenceOptions``.
        negative_condition_source, scorer: Forwarded unchanged.
        speed_up_factor: Values ``<= 0`` are forwarded as ``None``.
        normalize_text: Forwarded to the model; also applied locally when
            working out which step logs belong to the text.

    Returns:
        ``(wav_path, prompt_alignment_html, generated_alignment_html)``.

    Raises:
        gr.Error: If prompt encoding or generation fails; the original
            exception is attached as the cause.
    """
    # Move model + encoder to GPU
    _encoder.to(_device)
    _model.to(_device)
    _model.decoder.to(_device)

    output_sample_rate = 24_000  # rate the generated waveform is written at

    try:
        # --- Encode prompt ---
        prompt, prompt_html = _encode_prompt(audio_path)

        # --- Generate speech ---
        logger.info("Generating speech for text: %s", text)
        suf = float(speed_up_factor) if speed_up_factor > 0 else None
        t0 = time.perf_counter()
        output = _model.generate(
            prompt=prompt,
            text=text,
            num_transition_steps=0,
            num_extra_steps=int(num_extra_steps),
            normalize_text=normalize_text,
            inference_options=InferenceOptions(
                acoustic_cfg_scale=float(acoustic_cfg_scale),
                duration_cfg_scale=float(duration_cfg_scale),
                num_flow_matching_steps=int(num_flow_matching_steps),
                noise_temperature=float(noise_temperature),
                speed_up_factor=suf,
                time_schedule="logsnr",
                negative_condition_source=negative_condition_source,
                text_only_logit_scale=float(text_only_logit_scale),
                num_acoustic_candidates=int(num_acoustic_candidates),
                scorer=scorer,
                spkr_verification_weight=float(spkr_verification_weight),
            ),
            system_prompt="",
        )
        wall_time = time.perf_counter() - t0

        wav = output.audio[0].detach().cpu().float()
        if wav.dim() == 1:
            wav = wav.unsqueeze(0)

        # mkstemp guarantees a fresh file per request; naming the file after
        # id(output) could repeat and let requests overwrite each other.
        fd, tmp_path = tempfile.mkstemp(prefix="tada_output_", suffix=".wav")
        os.close(fd)
        torchaudio.save(tmp_path, wav, output_sample_rate)
        audio_duration = wav.shape[-1] / output_sample_rate

        generated_logs = _select_text_step_logs(output, text, normalize_text)
        generated_html = _format_step_logs(generated_logs, audio_duration, wall_time)
        return tmp_path, prompt_html, generated_html
    except gr.Error:
        raise
    except Exception as e:
        logger.exception("Generation failed")
        raise gr.Error(f"Generation failed: {e}") from e
| # --------------------------------------------------------------------------- | |
| # Gradio UI | |
| # --------------------------------------------------------------------------- | |
def build_ui() -> gr.Blocks:
    """Assemble the Gradio Blocks interface and wire up its event handlers.

    Layout is a single row of three columns: generation settings, voice
    prompt selection with its alignment view, and the text input together
    with the generated output. The Generate button calls ``generate`` with
    the component values in the order of that function's parameters.
    """
    zero_shot_label = "None (zero-shot)"
    custom_label = "(custom)"

    page_css = (
        ".gradio-container { max-width: 1400px !important; width: 100% !important; margin: auto !important; } "
        ".compact-audio { min-height: 0 !important; } "
        ".compact-audio audio { height: 36px !important; } "
    )
    intro_markdown = (
        "# TADA - Text-Acoustic Dual Alignment LLM\n"
        "A demo of **tada-3b-ml** \u2014 "
        "a text-to-speech model that clones voice, emotion, and timing from a short audio prompt.\n\n"
        "**How to use:** Choose a voice prompt (or upload your own), enter text, and click **Generate**. "
        "The model will encode the prompt and generate speech in one step."
    )

    with gr.Blocks(title="TADA Inference", css=page_css) as demo:
        gr.Markdown(intro_markdown)

        with gr.Row(equal_height=False):
            # ---- Column 1: generation settings ----
            with gr.Column(scale=1):
                with gr.Accordion("Text Settings", open=False):
                    num_extra_steps = gr.Slider(
                        minimum=0,
                        maximum=200,
                        value=0,
                        step=1,
                        label="Text Tokens to Generate",
                    )
                    text_only_logit_scale = gr.Slider(
                        minimum=0.0,
                        maximum=5.0,
                        value=0.0,
                        step=0.1,
                        label="Text-Only Logit Scale",
                        info="0 = disabled. Blends text-only logits with audio-conditioned logits.",
                    )
                    normalize_text_cb = gr.Checkbox(
                        value=True,
                        label="Normalize Text",
                        info="Apply text normalization before generation",
                    )

                with gr.Accordion("Acoustic Settings", open=False):
                    acoustic_cfg_scale = gr.Slider(
                        minimum=1.0,
                        maximum=3.0,
                        value=1.6,
                        step=0.1,
                        label="Acoustic CFG Scale",
                    )
                    duration_cfg_scale = gr.Slider(
                        minimum=1.0,
                        maximum=3.0,
                        value=1.0,
                        step=0.1,
                        label="Duration CFG Scale",
                    )
                    negative_condition_source = gr.Dropdown(
                        choices=["negative_step_output", "prompt", "zero"],
                        value="negative_step_output",
                        label="Negative Condition Source",
                    )
                    noise_temperature = gr.Slider(
                        minimum=0.4,
                        maximum=1.2,
                        value=0.9,
                        step=0.1,
                        label="Noise Temperature",
                    )
                    num_flow_matching_steps = gr.Slider(
                        minimum=5,
                        maximum=50,
                        value=20,
                        step=5,
                        label="Flow Matching Steps",
                    )
                    speed_up_factor = gr.Slider(
                        minimum=0.0,
                        maximum=3.0,
                        value=0.0,
                        step=0.1,
                        label="Speed Up Factor",
                        info="0 = disabled (natural duration). >0 scales speech speed.",
                    )
                    num_acoustic_candidates = gr.Slider(
                        minimum=1,
                        maximum=16,
                        value=1,
                        step=1,
                        label="Acoustic Candidates",
                        info="Number of candidates to generate and rank.",
                    )
                    scorer_dropdown = gr.Dropdown(
                        choices=["likelihood", "spkr_verification", "duration_median"],
                        value="likelihood",
                        label="Scorer",
                        info="How to rank acoustic candidates.",
                    )
                    spkr_verification_weight = gr.Slider(
                        minimum=0.0,
                        maximum=5.0,
                        value=1.0,
                        step=0.1,
                        label="Speaker Verification Weight",
                        info="Weight for spkr_verification scorer.",
                    )

            # ---- Column 2: voice prompt ----
            with gr.Column(scale=2):
                default_voice = "fb_ears_emo_amusement_freeform.wav"
                # Fall back to zero-shot when the default preset is not bundled.
                initial_voice = default_voice if default_voice in _PRESET_SAMPLES else zero_shot_label
                initial_voice_path = _PRESET_SAMPLES.get(default_voice)

                preset_dropdown = gr.Dropdown(
                    choices=[zero_shot_label, *_PRESET_SAMPLES.keys()],
                    value=initial_voice,
                    label="Voice Prompt",
                    info="Pick a preset or upload / record your own",
                )
                audio_input = gr.Audio(
                    label="Prompt Preview",
                    type="filepath",
                    sources=["upload", "microphone"],
                    value=initial_voice_path,
                    elem_classes=["compact-audio"],
                )

                def _on_preset_selected(choice: str) -> str | None:
                    # Unknown names and the zero-shot entry both clear the preview.
                    source = None if choice == zero_shot_label else _PRESET_SAMPLES.get(choice)
                    if source is None:
                        return None
                    # Hand the audio widget a copy staged in the temp dir; generate()
                    # strips the "tada_preset_" prefix to find the prompt transcript.
                    staged = os.path.join(tempfile.gettempdir(), f"tada_preset_{choice}")
                    shutil.copy2(source, staged)
                    return staged

                preset_dropdown.change(
                    fn=_on_preset_selected,
                    inputs=[preset_dropdown],
                    outputs=[audio_input],
                )

                with gr.Accordion("Prompt Token Alignment", open=True):
                    prompt_alignment = gr.HTML(value="Generate to see prompt alignment.")

            # ---- Column 3: text input and generated output ----
            with gr.Column(scale=2):
                default_transcript = "emo_interest_sentences"
                initial_transcript = (
                    default_transcript if default_transcript in _PRESET_TRANSCRIPTS else custom_label
                )

                transcript_dropdown = gr.Dropdown(
                    choices=[custom_label, *_PRESET_TRANSCRIPTS.keys()],
                    value=initial_transcript,
                    label="Transcript",
                    info="Pick a preset or type your own below",
                )
                text_input = gr.Textbox(
                    label="Text to Speak",
                    placeholder="Type what you want the model to say ...",
                    autoscroll=False,
                    max_lines=20,
                    value=_PRESET_TRANSCRIPTS.get(default_transcript, ""),
                )

                def _on_transcript_selected(choice: str) -> str:
                    # "(custom)" empties the box; presets fill in their stored text.
                    return "" if choice == custom_label else _PRESET_TRANSCRIPTS.get(choice, "")

                transcript_dropdown.change(
                    fn=_on_transcript_selected,
                    inputs=[transcript_dropdown],
                    outputs=[text_input],
                )

                generate_btn = gr.Button("Generate", variant="primary", size="lg")

                # --- Output ---
                audio_output = gr.Audio(label="Generated Audio")
                with gr.Accordion("Generated Alignment", open=False):
                    generated_text_display = gr.HTML(value="Generate speech to see the alignment")

        # Component order must mirror the positional parameters of generate().
        generate_inputs = [
            audio_input,
            text_input,
            num_extra_steps,
            noise_temperature,
            acoustic_cfg_scale,
            duration_cfg_scale,
            num_flow_matching_steps,
            negative_condition_source,
            text_only_logit_scale,
            num_acoustic_candidates,
            scorer_dropdown,
            spkr_verification_weight,
            speed_up_factor,
            normalize_text_cb,
        ]
        generate_outputs = [audio_output, prompt_alignment, generated_text_display]
        generate_btn.click(fn=generate, inputs=generate_inputs, outputs=generate_outputs)

    return demo
| # --------------------------------------------------------------------------- | |
| # Entry-point | |
| # --------------------------------------------------------------------------- | |
# Launch defaults come from the environment; the command line can override
# them when this file is run as a script.
_share = os.environ.get("GRADIO_SHARE", "").lower() in ("1", "true", "yes")
_port = int(os.environ.get("GRADIO_PORT", "7860"))

# `demo` at module scope so the `gradio` CLI / HF Spaces can discover it.
demo = build_ui()

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="TADA Inference Gradio App")
    parser.add_argument("--share", action="store_true", default=_share, help="Create a public Gradio share link")
    parser.add_argument("--port", type=int, default=_port, help="Server port (default: 7860)")
    args = parser.parse_args()
    _launch_port, _launch_share = args.port, args.share
else:
    # NOTE: the app is launched on plain import as well, using the environment values.
    _launch_port, _launch_share = _port, _share

demo.launch(
    server_name="0.0.0.0",
    server_port=_launch_port,
    share=_launch_share,
    allowed_paths=[_SAMPLES_DIR],
)