sharath25 committed
Commit dc3cb88 · 1 Parent(s): d65321c

simplify the gradio app and make it more stable
app.py CHANGED
@@ -1,5 +1,5 @@
 """
-Gradio app for TADA inference.
+Gradio app for TADA inference (English-only, single model).
 
 Usage:
     pip install hume-tada
@@ -8,9 +8,7 @@ Usage:
     GRADIO_SHARE=1 gradio app.py
 """
 
-import dataclasses
 import html
-import json
 import logging
 import os
 import shutil
@@ -37,27 +35,18 @@ logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 # ---------------------------------------------------------------------------
-# Preset samples & transcripts
+# Preset samples & transcripts (English only)
 # ---------------------------------------------------------------------------
 _script_dir = os.path.dirname(os.path.abspath(__file__))
 _SAMPLES_DIR = os.path.join(_script_dir, "samples")
 
 _AUDIO_EXTENSIONS = (".wav", ".mp3", ".flac")
 
-LANGUAGE_MAP = {
-    "English": None,
-    "German": "de",
-    "Japanese": "ja",
-}
 
-_CODE_TO_LANG_DIR = {None: "en", "de": "de", "ja": "ja"}
-
-
-def _discover_preset_samples(lang_code: str | None = None) -> dict[str, str]:
-    """Return {display_name: absolute_path} for audio files in the language's samples/ subdir."""
+def _discover_preset_samples() -> dict[str, str]:
+    """Return {display_name: absolute_path} for audio files in samples/en/."""
     presets: dict[str, str] = {}
-    lang_dir = _CODE_TO_LANG_DIR.get(lang_code, "en")
-    search_dir = os.path.join(_SAMPLES_DIR, lang_dir)
+    search_dir = os.path.join(_SAMPLES_DIR, "en")
    if not os.path.isdir(search_dir):
        return presets
    for fname in sorted(os.listdir(search_dir)):
@@ -66,108 +55,56 @@ def _discover_preset_samples(lang_code: str | None = None) -> dict[str, str]:
     return presets
 
 
-def _load_preset_transcripts(lang_code: str | None = None) -> dict[str, str]:
-    """Load preset transcripts from synth_transcripts.json in the language's sample dir."""
-    lang_dir = _CODE_TO_LANG_DIR.get(lang_code, "en")
-    candidate = os.path.join(_SAMPLES_DIR, lang_dir, "synth_transcripts.json")
+def _load_preset_transcripts() -> dict[str, str]:
+    """Load preset transcripts from synth_transcripts.json."""
+    import json
+    candidate = os.path.join(_SAMPLES_DIR, "en", "synth_transcripts.json")
     if os.path.isfile(candidate):
         with open(candidate) as f:
             return json.load(f)
     return {}
 
 
-def _load_prompt_transcripts(lang_code: str | None = None) -> dict[str, str]:
-    """Load prompt transcripts from prompt_transcripts.json in the language's sample dir.
-
-    Returns {audio_filename: transcript} for providing ground-truth text to the encoder
-    instead of relying on ASR (which may not support the target language).
-    """
-    lang_dir = _CODE_TO_LANG_DIR.get(lang_code, "en")
-    candidate = os.path.join(_SAMPLES_DIR, lang_dir, "prompt_transcripts.json")
+def _load_prompt_transcripts() -> dict[str, str]:
+    """Load prompt transcripts from prompt_transcripts.json."""
+    import json
+    candidate = os.path.join(_SAMPLES_DIR, "en", "prompt_transcripts.json")
     if os.path.isfile(candidate):
         with open(candidate) as f:
             return json.load(f)
     return {}
 
 
-# Initialize with English samples
-_PRESET_SAMPLES = _discover_preset_samples(None)
-_PRESET_TRANSCRIPTS = _load_preset_transcripts(None)
-_PROMPT_TRANSCRIPTS = _load_prompt_transcripts(None)
+_PRESET_SAMPLES = _discover_preset_samples()
+_PRESET_TRANSCRIPTS = _load_preset_transcripts()
+_PROMPT_TRANSCRIPTS = _load_prompt_transcripts()
 logger.info("Discovered %d preset audio samples, %d transcripts", len(_PRESET_SAMPLES), len(_PRESET_TRANSCRIPTS))
 
 # ---------------------------------------------------------------------------
-# Global model state
+# Global model state — single model, single encoder
 # ---------------------------------------------------------------------------
-_MODEL_CHOICES = ["HumeAI/tada-1b", "HumeAI/tada-3b-ml"]
-_DEFAULT_MODEL = "HumeAI/tada-3b-ml"
-
-_MULTILINGUAL_MODELS = {"HumeAI/tada-3b-ml"}
-
-
-def _language_choices_for_model(model_name: str) -> list[str]:
-    """Return the list of language display names available for the given model."""
-    if model_name in _MULTILINGUAL_MODELS:
-        return list(LANGUAGE_MAP.keys())
-    return ["English"]
-
-_encoder_cache: dict[str | None, Encoder] = {}
-_model: TadaForCausalLM | None = None
-_current_model_name: str = ""
-_current_language: str | None = None
+_MODEL_NAME = "HumeAI/tada-3b-ml"
 _device = "cuda"
 
 
-def _move_encoder_output(output: EncoderOutput, device: str) -> EncoderOutput:
-    """Move all tensor fields of an EncoderOutput to the given device."""
-    kwargs = {}
-    for f in dataclasses.fields(output):
-        val = getattr(output, f.name)
-        if isinstance(val, torch.Tensor):
-            kwargs[f.name] = val.to(device)
-        else:
-            kwargs[f.name] = val
-    return EncoderOutput(**kwargs)
-
-
-def get_encoder(language_code: str | None = None) -> Encoder:
-    """Get or create an Encoder for the given language, with caching."""
-    if language_code not in _encoder_cache:
-        _encoder_cache[language_code] = Encoder.from_pretrained(
-            "HumeAI/tada-codec", language=language_code
-        ).to(_device)
-    return _encoder_cache[language_code]
-
-
-def _get_device_info() -> str:
-    if torch.cuda.is_available():
-        names = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]
-        return f"CUDA - {', '.join(names)}"
-    if torch.backends.mps.is_available():
-        return "MPS (Apple Silicon)"
-    return "CPU (ZeroGPU provides GPU during inference)"
-
-
-def load_models(model_name: str = _DEFAULT_MODEL) -> str:
-    """Load encoder and TADA model. Returns a status string."""
-    global _model, _current_model_name
+def _validate_no_meta_tensors(model, name: str = "model"):
+    """Raise if any parameter is on the meta device (not materialised)."""
+    for param_name, param in model.named_parameters():
+        if param.device.type == "meta":
+            raise RuntimeError(
+                f"{name} has meta-device parameter: {param_name}. "
+                "Pass low_cpu_mem_usage=False to from_pretrained()."
+            )
 
-    if _model is not None and _current_model_name == model_name:
-        return f"Loaded: {model_name} on {_get_device_info()}"
 
-    if _model is not None:
-        del _model
-        _model = None
+logger.info("Loading encoder ...")
+_encoder = Encoder.from_pretrained("HumeAI/tada-codec", language=None, low_cpu_mem_usage=False).to(_device)
+_validate_no_meta_tensors(_encoder, "Encoder")
 
-    get_encoder(_current_language)
-
-    logger.info("Loading %s ...", model_name)
-    _model = TadaForCausalLM.from_pretrained(model_name)
-
-    _current_model_name = model_name
-    status = f"Loaded: {model_name} on {_get_device_info()}"
-    logger.info(status)
-    return status
+logger.info("Loading %s ...", _MODEL_NAME)
+_model = TadaForCausalLM.from_pretrained(_MODEL_NAME, low_cpu_mem_usage=False)
+_validate_no_meta_tensors(_model, "TadaForCausalLM")
+logger.info("Models loaded.")
 
 
 # ---------------------------------------------------------------------------
@@ -175,26 +112,6 @@ def load_models(model_name: str = _DEFAULT_MODEL) -> str:
 # ---------------------------------------------------------------------------
 
 
-def _encode_prompt(audio_path: str | None, language_code: str | None = None, prompt_text: str | None = None) -> EncoderOutput:
-    """Encode an audio file into an EncoderOutput prompt (or return an empty one).
-
-    If *prompt_text* is provided it is passed to the encoder for forced alignment,
-    bypassing the built-in ASR transcription (which is English-only).
-    """
-    if audio_path is None or audio_path == "":
-        return EncoderOutput.empty(_device)
-
-    encoder = get_encoder(language_code)
-    audio, sample_rate = torchaudio.load(audio_path)
-    audio = audio.mean(dim=0, keepdim=True)  # mono
-    audio = audio / audio.abs().max().clamp(min=1e-8) * 0.95
-    audio = audio.to(_device)
-
-    text_kwarg = [prompt_text] if prompt_text else None
-    prompt = encoder(audio, text=text_kwarg, sample_rate=sample_rate)
-    return prompt
-
-
 def _decode_tokens_individually(tokenizer, token_ids: list[int]) -> list[str]:
     """Decode a list of token IDs into per-token strings, handling multi-byte characters."""
     labels: list[str] = []
@@ -206,13 +123,12 @@ def _decode_tokens_individually(tokenizer, token_ids: list[int]) -> list[str]:
     return labels
 
 
-def _format_token_alignment(prompt: EncoderOutput, language_code: str | None = None) -> str:
+def _format_token_alignment(prompt: EncoderOutput) -> str:
     """Build an HTML string: dots in grey, tokens as bold coloured spans."""
     if prompt.text_tokens is None or prompt.token_positions is None:
         return ""
 
-    encoder = get_encoder(language_code)
-    tokenizer = encoder.tokenizer
+    tokenizer = _encoder.tokenizer
     n_tokens = (
         int(prompt.text_tokens_len[0].item()) if prompt.text_tokens_len is not None else prompt.text_tokens.shape[1]
     )
@@ -245,44 +161,9 @@ def _format_token_alignment(prompt: EncoderOutput, language_code: str | None = None) -> str:
     )
 
 
-@gpu_decorator
-@torch.inference_mode()
-def process_prompt(audio_path: str | None, language: str = "English") -> tuple[str, EncoderOutput | None]:
-    """Encode the voice prompt and return (alignment_html, prompt_on_cpu)."""
-    global _current_language
-    language_code = LANGUAGE_MAP.get(language)
-    _current_language = language_code
-
-    _encoder = get_encoder(language_code)
-    _encoder.to(_device)
-
-    if audio_path is None or audio_path == "":
-        return "No audio provided (zero-shot mode).", None
-
-    try:
-        # Look up prompt transcript for preset samples (avoids English-only ASR for non-English audio)
-        prompt_text = None
-        if audio_path:
-            audio_fname = os.path.basename(audio_path)
-            # Check both the original filename and the preset-prefixed temp name
-            for key in (audio_fname, audio_fname.replace("tada_preset_", "")):
-                if key in _PROMPT_TRANSCRIPTS:
-                    prompt_text = _PROMPT_TRANSCRIPTS[key]
-                    break
-
-        prompt = _encode_prompt(audio_path, language_code, prompt_text=prompt_text)
-        alignment_html = _format_token_alignment(prompt, language_code)
-        # Move to CPU for gr.State serialization (ZeroGPU compatibility)
-        prompt_cpu = _move_encoder_output(prompt, "cpu")
-        return alignment_html, prompt_cpu
-    except Exception as e:
-        logger.exception("Prompt processing failed")
-        raise gr.Error(f"Prompt processing failed: {e}")
-
-
 def _decode_byte_tokens(raw_tokens: list[str]) -> list[str]:
     """Decode GPT-2 byte-level token strings into proper Unicode per-token labels."""
-    if not raw_tokens or _model is None:
+    if not raw_tokens:
         return raw_tokens
     try:
         tokenizer = _model.tokenizer
@@ -324,9 +205,15 @@ def _format_step_logs(step_logs: list[dict], audio_duration: float, wall_time: f
     )
 
 
+# ---------------------------------------------------------------------------
+# Single generate function (merged prompt encoding + generation)
+# ---------------------------------------------------------------------------
+
+
 @gpu_decorator(duration=120)
 @torch.inference_mode()
-def generate_speech(
+def generate(
+    audio_path: str | None,
     text: str,
     num_extra_steps: float = 0,
     noise_temperature: float = 0.9,
@@ -340,22 +227,43 @@
     spkr_verification_weight: float = 1.0,
     speed_up_factor: float = 0.0,
     normalize_text: bool = True,
-    cached_prompt: EncoderOutput | None = None,
-) -> tuple[str | None, str]:
-    """Run TADA generation using the provided prompt and return (wav_path, alignment_html)."""
-    if _model is None:
-        raise gr.Error("Models are not loaded. Click 'Load Model' first.")
-    if cached_prompt is None:
-        raise gr.Error("Please upload audio and click 'Process Prompt' first.")
+) -> tuple[str | None, str, str]:
+    """Encode prompt + generate speech in a single GPU call.
 
+    Returns (wav_path, prompt_alignment_html, generated_alignment_html).
+    """
+    # Move model + encoder to GPU
+    _encoder.to(_device)
     _model.to(_device)
     _model.decoder.to(_device)
 
+    # --- Encode prompt ---
+    if audio_path is None or audio_path == "":
+        prompt = EncoderOutput.empty(_device)
+        prompt_html = "No audio provided (zero-shot mode)."
+    else:
+        audio, sample_rate = torchaudio.load(audio_path)
+        audio = audio.mean(dim=0, keepdim=True)  # mono
+        audio = audio / audio.abs().max().clamp(min=1e-8) * 0.95
+        audio = audio.to(_device)
+
+        # Look up prompt transcript for preset samples
+        prompt_text = None
+        if audio_path:
+            audio_fname = os.path.basename(audio_path)
+            for key in (audio_fname, audio_fname.replace("tada_preset_", "")):
+                if key in _PROMPT_TRANSCRIPTS:
+                    prompt_text = _PROMPT_TRANSCRIPTS[key]
+                    break
+
+        text_kwarg = [prompt_text] if prompt_text else None
+        prompt = _encoder(audio, text=text_kwarg, sample_rate=sample_rate)
+        prompt_html = _format_token_alignment(prompt)
+
+    # --- Generate speech ---
     try:
-        prompt = _move_encoder_output(cached_prompt, _device)
         logger.info("Generating speech for text: %s", text)
 
-        # speed_up_factor: 0 means disabled (None)
        suf = float(speed_up_factor) if speed_up_factor > 0 else None

        t0 = time.time()
@@ -391,28 +299,24 @@
 
         audio_duration = wav.shape[-1] / 24_000
 
-        # Extract only text-to-speak step_logs, reconstructing any prefilled (missing) entries
+        # Extract text-to-speak step_logs
         all_logs = output.step_logs or []
-        if _model is not None and text and output.input_text_ids is not None:
+        if text and output.input_text_ids is not None:
             input_ids = output.input_text_ids[0]
             seq_len = input_ids.shape[0]
             n_eos = _model.config.shift_acoustic
-            # Count text-to-speak tokens (same logic as generate())
             normalized = normalize_text_fn(text) if normalize_text else text
             n_text_tokens = len(_model.tokenizer.encode(normalized, add_special_tokens=False))
             text_end = seq_len - n_eos
             text_start = text_end - n_text_tokens
 
-            # Build a step -> log lookup from existing step_logs
             log_by_step = {e["step"]: e for e in all_logs}
 
-            # Collect text-token entries, filling in any missing prefilled steps
             text_logs = []
             for s in range(text_start, text_end):
                 if s in log_by_step:
                     text_logs.append(log_by_step[s])
                 else:
-                    # Prefilled step — reconstruct from input_text_ids
                     token_id = input_ids[s].item()
                     token_str = _model.tokenizer.convert_ids_to_tokens([token_id])[0]
                     text_logs.append({
@@ -430,7 +334,7 @@
             generated_logs = all_logs
         generated_html = _format_step_logs(generated_logs, audio_duration, wall_time)
 
-        return tmp_path, generated_html
+        return tmp_path, prompt_html, generated_html
 
     except gr.Error:
         raise
@@ -454,38 +358,9 @@ def build_ui() -> gr.Blocks:
         ),
     ) as demo:
         gr.Markdown("# TADA - Text-Acoustic Dual Alignment LLM")
-        prompt_state = gr.State(value=None)
 
         with gr.Row(equal_height=False):
             with gr.Column(scale=1):
-                with gr.Row():
-                    model_dropdown = gr.Dropdown(
-                        choices=_MODEL_CHOICES,
-                        value=_current_model_name or _DEFAULT_MODEL,
-                        label="Model",
-                        scale=3,
-                    )
-                    load_btn = gr.Button("Load Model", scale=1)
-                load_status = gr.Textbox(label="Model Status", interactive=False, show_label=False)
-
-                language_dd = gr.Dropdown(
-                    choices=list(LANGUAGE_MAP.keys()),
-                    value="English",
-                    label="Language",
-                    info="Selects the aligner for prompt encoding",
-                )
-
-                def _on_model_selected(model_name: str):
-                    """Update language choices when model changes."""
-                    choices = _language_choices_for_model(model_name)
-                    return gr.update(choices=choices, value="English")
-
-                model_dropdown.change(
-                    fn=_on_model_selected,
-                    inputs=[model_dropdown],
-                    outputs=[language_dd],
-                )
-
                with gr.Accordion("Text Settings", open=False):
                    num_extra_steps = gr.Slider(
                        minimum=0, maximum=200, value=0, step=1,
@@ -580,25 +455,8 @@ def build_ui() -> gr.Blocks:
                    outputs=[audio_input],
                )
 
-                def _on_language_changed(language: str):
-                    """Update preset samples and transcripts when language changes."""
-                    lang_code = LANGUAGE_MAP.get(language)
-                    samples = _discover_preset_samples(lang_code)
-                    new_preset_choices = ["None (zero-shot)"] + list(samples.keys())
-                    global _PRESET_SAMPLES, _PRESET_TRANSCRIPTS, _PROMPT_TRANSCRIPTS
-                    _PRESET_SAMPLES = samples
-                    _PRESET_TRANSCRIPTS = _load_preset_transcripts(lang_code)
-                    _PROMPT_TRANSCRIPTS = _load_prompt_transcripts(lang_code)
-                    new_transcript_choices = ["(custom)"] + list(_PRESET_TRANSCRIPTS.keys())
-                    first_sample = new_preset_choices[1] if len(new_preset_choices) > 1 else "None (zero-shot)"
-                    return (
-                        gr.update(choices=new_preset_choices, value=first_sample),
-                        gr.update(choices=new_transcript_choices, value="(custom)"),
-                    )
-
-                process_prompt_btn = gr.Button("Process Prompt", variant="secondary", size="sm")
-                with gr.Accordion("Token Alignment", open=True):
-                    prompt_alignment = gr.HTML(value="Upload audio and click <b>Process Prompt</b> before generating.")
+                with gr.Accordion("Prompt Token Alignment", open=True):
+                    prompt_alignment = gr.HTML(value="Generate to see prompt alignment.")
 
            with gr.Column(scale=2):
                _default_transcript = "emo_interest_sentences"
@@ -630,52 +488,6 @@ def build_ui() -> gr.Blocks:
 
             generate_btn = gr.Button("Generate", variant="primary", size="lg")
 
-        # --- Wire language change to update presets + re-process prompt ---
-        language_dd.change(
-            fn=_on_language_changed,
-            inputs=[language_dd],
-            outputs=[preset_dropdown, transcript_dropdown],
-        )
-
-        # --- Shared chain: show "Processing..." -> encode -> restore button ---
-        def _wire_process_prompt(event):
-            """Chain process_prompt onto any event."""
-            event.then(
-                fn=lambda: (gr.update(value="Processing...", interactive=False), ""),
-                inputs=[],
-                outputs=[process_prompt_btn, prompt_alignment],
-            ).then(
-                fn=process_prompt,
-                inputs=[audio_input, language_dd],
-                outputs=[prompt_alignment, prompt_state],
-            ).then(
-                fn=lambda: gr.update(value="Process Prompt", interactive=True),
-                inputs=[],
-                outputs=[process_prompt_btn],
-            )
-
-        # Manual click
-        _wire_process_prompt(process_prompt_btn.click(fn=lambda: None, inputs=[], outputs=[]))
-
-        # Load model (no auto-process; user must click Process Prompt)
-        load_btn.click(
-            fn=lambda: (gr.update(interactive=False), "Loading model..."),
-            inputs=[],
-            outputs=[load_btn, load_status],
-        ).then(
-            fn=load_models,
-            inputs=[model_dropdown],
-            outputs=[load_status],
-        ).then(
-            fn=lambda: gr.update(interactive=True),
-            inputs=[],
-            outputs=[load_btn],
-        )
-
         # --- Output ---
         audio_output = gr.Audio(label="Generated Audio")
         with gr.Accordion("Generated Alignment", open=False):
@@ -683,6 +495,7 @@ def build_ui() -> gr.Blocks:
 
         # Wire up generate button
         all_inputs = [
+            audio_input,
             text_input,
             num_extra_steps,
             noise_temperature,
@@ -696,13 +509,12 @@ def build_ui() -> gr.Blocks:
             spkr_verification_weight,
             speed_up_factor,
             normalize_text_cb,
-            prompt_state,
         ]
 
         generate_btn.click(
-            fn=generate_speech,
+            fn=generate,
            inputs=all_inputs,
-            outputs=[audio_output, generated_text_display],
+            outputs=[audio_output, prompt_alignment, generated_text_display],
        )
 
    return demo
@@ -715,9 +527,6 @@ def build_ui() -> gr.Blocks:
 _share = os.environ.get("GRADIO_SHARE", "").lower() in ("1", "true", "yes")
 _port = int(os.environ.get("GRADIO_PORT", "7860"))
 
-# Auto-load models on startup
-load_models()
-
 # `demo` at module scope so the `gradio` CLI / HF Spaces can discover it.
 demo = build_ui()
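The stability half of this commit rests on never letting meta tensors survive model loading. Below is a minimal, self-contained sketch of the failure mode that the new `_validate_no_meta_tensors` helper guards against; it assumes PyTorch 2.x, and `validate_no_meta_tensors` here is a standalone copy of the idea, not the app's code.

import torch
from torch import nn


def validate_no_meta_tensors(model: nn.Module, name: str = "model") -> None:
    # Same check as the commit's helper: a meta-device parameter carries
    # shape/dtype metadata but no storage, so any forward pass would fail.
    for param_name, param in model.named_parameters():
        if param.device.type == "meta":
            raise RuntimeError(f"{name} has meta-device parameter: {param_name}")


# Modules built under the meta device allocate no real weights. Lazy loading
# paths (low_cpu_mem_usage=True) can leave parameters in this state if they
# are never materialised, which is why app.py now loads eagerly and verifies.
with torch.device("meta"):
    layer = nn.Linear(4, 4)

try:
    validate_no_meta_tensors(layer, "layer")
except RuntimeError as err:
    print(err)  # layer has meta-device parameter: weight

Running the check immediately after `from_pretrained()` turns a confusing CUDA error at generation time into an actionable error at startup.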
requirements.txt CHANGED
@@ -1,5 +1,6 @@
-hume-tada==0.1.6
+hume-tada==0.1.7
 descript-audio-codec
 transformers==4.57.3
 gradio==6.5.1
+accelerate==1.6.0
 spaces
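The new `accelerate` pin complements the loading change in app.py: the low-memory loading path in `transformers` depends on `accelerate`, so pinning it keeps the environment reproducible even though the app now forces the eager path. A sketch of that eager load, assuming `TadaForCausalLM` is importable from the hume-tada package (the exact import path is not shown in this diff):

# Hypothetical import path; only the class name appears in the diff above.
from tada import TadaForCausalLM

model = TadaForCausalLM.from_pretrained(
    "HumeAI/tada-3b-ml",
    low_cpu_mem_usage=False,  # materialise real weights up front, no meta tensors
)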
samples/de/prompt_transcripts.json DELETED
@@ -1,6 +0,0 @@
-{
-    "segment_002.wav": "Patsy, mich nennst du Tarsuin, Rika ist Rika, Ma, kannst du Miss McAllis? Neferra, korrigierte seine Mutter. Und Due, fügte sein Vater hinzu. Nennen, also, Neferra und Due. Tarsuin deutete sicherheitshalber auf seine Eltern.",
-    "segment_005.wav": "Die regionale UNIQA Generalagentur. Vertrauen, Versichern, Vorsorgen. Als Ihre Versicherungsagentur vor Ort bieten wir maßgeschneiderte Versicherungslösungen. Sowohl für den Privat- als auch für den Geschäftskunden.",
-    "segment_007.wav": "Ja, also interessanterweise finde ich meine Position gar nicht so kritisch, sondern es ist halt eine typisch gesundheitswissenschaftliche Position. Die schaut halt sehr auf das ganze Bild, also versucht irgendwie alle Parameter im Blick zu haben. Und jetzt nicht nur das virale Geschehen, sondern ich schaue durchaus auch auf ökonomische Dinge, obwohl ich jetzt kein Ökonom bin, aber Wirtschaft und Arbeitslosigkeit hängen halt sehr eng mit Gesundheit zusammen.",
-    "segment_010.wav": "Und ja, wie gesagt, diesen Beitrag hier kann ich euch empfehlen von transparentberaten.de. Ist kostenloses Streaming illegal? Könnt ihr es einfach nochmal durchlesen. Das ist so pauschal einfach, was das heißt. Also ich meine, falls irgendjemand hier Angst hat oder sich denkt, ja, aber vielleicht ist es ja doch legal. Nein, es ist nicht. Und hier könnt ihr es nochmal nachlesen. Aber ja."
-}
samples/de/segment_002.wav DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:e938f9cafef1ab7969dca137a1dbebd919593bc0881ce04ada9872c379c1eddd
-size 513078
samples/de/segment_005.wav DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d233c8afaa7d9f5a1edb4949ef67f8d08a6eca9e977d6c2285d83b88c6efee97
-size 486078
samples/de/segment_007.wav DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:05068225f5fb75350af2f39a3fbcb0b32cbac4d07790ee29b9f52b33b5f0e5a0
-size 1533138
samples/de/segment_010.wav DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:e749c6f2128ebaf60ba33b6a585434529383e7f8f09552ba74dacc1163946c7c
-size 585436
samples/de/synth_transcripts.json DELETED
@@ -1,8 +0,0 @@
-{
-    "casual_conversation": "Also, ich muss dir unbedingt erzählen, was mir gestern passiert ist. Ich war im Supermarkt, ganz normal einkaufen, und plötzlich steht mein alter Schulfreund vor mir. Den hab ich bestimmt zehn Jahre nicht gesehen! Wir haben dann einfach eine halbe Stunde im Gang gestanden und gequatscht, während die anderen Leute an uns vorbeigeschoben haben. Er wohnt jetzt in München und arbeitet bei so einem Start-up, irgendwas mit erneuerbaren Energien. Wir haben Nummern ausgetauscht und wollen uns nächste Woche auf einen Kaffee treffen. Ist schon verrückt, wie das Leben manchmal so spielt, oder?",
-    "storytelling": "Meine Oma hat mir als Kind immer diese eine Geschichte erzählt, von einem kleinen Fuchs, der sich im Schwarzwald verlaufen hat. Der Fuchs war noch ganz jung und hatte sich zu weit von seiner Familie entfernt. Es wurde dunkel und er konnte den Weg nicht mehr finden. Aber dann hat er eine alte Eule getroffen, die oben in einer riesigen Eiche saß. Die Eule hat ihm gesagt, er soll einfach dem Bach folgen, denn der führt immer nach Hause. Und tatsächlich, der kleine Fuchs ist am Wasser entlanggelaufen und hat seine Familie wiedergefunden. Ich fand die Geschichte damals so beruhigend, und ehrlich gesagt, denke ich heute noch manchmal daran, wenn ich mich irgendwo verloren fühle.",
-    "news_report": "In Berlin haben heute Tausende Menschen für mehr Klimaschutz demonstriert. Der Protestzug bewegte sich vom Brandenburger Tor bis zum Regierungsviertel, wo die Teilnehmer konkrete Maßnahmen gegen den Klimawandel forderten. Die Polizei schätzte die Zahl der Demonstranten auf etwa fünfzehntausend. Organisatoren sprachen von deutlich mehr Teilnehmern. Die Bundesregierung reagierte mit einer Erklärung, in der sie die Bedeutung des Klimaschutzes betonte und auf bereits beschlossene Maßnahmen verwies. Weitere Demonstrationen sind für das kommende Wochenende in Hamburg und München angekündigt.",
-    "emotional_reflection": "Manchmal frage ich mich, ob ich die richtigen Entscheidungen getroffen hab, weißt du? Nicht, dass ich unglücklich bin oder so, aber es gibt diese Momente, wo man einfach nachdenkt. Letztes Jahr bin ich umgezogen, neue Stadt, neuer Job, und am Anfang war alles aufregend. Aber dann kommen die ruhigen Abende, wo man alleine in der Wohnung sitzt und merkt, dass man hier noch niemanden richtig kennt. Es wird besser, ganz langsam. Ich hab angefangen, in einen Sportverein zu gehen, und die Leute da sind echt nett. Es braucht halt einfach Zeit, sich irgendwo zu Hause zu fühlen. Das vergisst man manchmal.",
-    "travel_experience": "Letzten Sommer sind wir mit dem Zug durch die Schweiz gefahren, und ich muss sagen, das war einer der schönsten Urlaube, die ich je hatte. Wir haben den Bernina Express genommen, dieser Zug, der über die Alpen fährt. Die Aussicht war einfach unglaublich, überall schneebedeckte Gipfel und türkisfarbene Seen. In Luzern haben wir dann zwei Tage verbracht und sind auf den Pilatus gewandert. Oben war es so klar, dass man gefühlt bis nach Italien gucken konnte. Abends haben wir in so einem kleinen Restaurant Käsefondue gegessen, und das war so gut, dass wir am nächsten Abend nochmal hingegangen sind.",
-    "food_and_cooking": "Ich hab am Wochenende zum ersten Mal versucht, Sauerteigbrot selbst zu backen, und Mann, das ist echt eine Wissenschaft für sich. Allein den Sauerteig-Starter zu machen, hat fünf Tage gedauert. Jeden Tag füttern, umrühren, warten. Und dann der Teig selbst, der musste zwölf Stunden gehen, dann nochmal falten, dann nochmal vier Stunden. Aber als ich das Brot dann aus dem Ofen geholt hab und diese perfekte Kruste gesehen hab, das war schon ein tolles Gefühl. Innen war es schön luftig mit großen Poren. Mein Mitbewohner meinte, es schmeckt besser als vom Bäcker. Das war wahrscheinlich übertrieben, aber trotzdem hat mich das richtig gefreut."
-}
samples/ja/prompt_transcripts.json DELETED
@@ -1,7 +0,0 @@
-{
-    "segment_003.wav": "じゃあ日本語もメロディーみたいな感じで覚えるで何言ってるか分かんないときはなんかタカタカとかサカサカとかカタカタって言ってればいいんですよ日本語って大体そういう音だから",
-    "segment_009.wav": "でもその部分は言わないでAさんはねなんか僕が話すといつもね嫌な顔するんですよとか言っちゃうと",
-    "segment_013.wav": "私も基本的にそうなんですね。でもその普通の日常をどうやったら楽しくすることができるかなって考えた時に、1日1回何か新しいことをしたらいいんじゃないかなと思いました。本当にちっちゃいことでいいです。",
-    "segment_020.wav": "それでも自民党を守ろうとする人、安倍さんを守ろうとする人たちは鼻で笑うんでしょうね。そんな人らが最近はようラジオ出てはりますよね。で、はーはーって言ってますよね。",
-    "segment_023.wav": "このムキムキのお兄さんがいるし バーだし少し高そうだと思いますよねこのバーの料金設定は良心的でした まあそんなに高くなかったです"
-}
samples/ja/segment_003.wav DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:bb1a6756a6850fb93c30443fc2b2ffcaae21da94c5557ad0e85313e00136c1df
-size 371598
samples/ja/segment_009.wav DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d8840d4f51ba1ae3d0feaa46562684aa2f9087b3a2401c04b1f0071525a0afe3
-size 335958
samples/ja/segment_013.wav DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:831c1f12d699579c9a9912b8a6c4998706a126ac735a1d0a277cccf59b61bf0a
-size 1361418
samples/ja/segment_020.wav DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:b8df94013091664f51214e9d8873fb8a085e5b7c8b05eb6f2dfd01d46d024fb8
-size 592996
samples/ja/segment_023.wav DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:e150c51e33ac04d7f828ba3c07f787be3f9e6c11f7184ec39181ca870ad57834
-size 410478
samples/ja/synth_transcripts.json DELETED
@@ -1,8 +0,0 @@
-{
-    "casual_conversation": "いやー、最近ちょっとバタバタしててさ、なかなか連絡できなくてごめんね。先週末やっと時間できたから、久しぶりに駅前のカフェに行ったんだけど、あそこリニューアルしたの知ってた?内装がすごくおしゃれになってて、メニューも全然変わっててびっくりしちゃった。抹茶のティラミスっていう新しいデザートがあって、それがもう本当においしくて。写真撮ったから今度見せるね。あ、そうだ、来週の土曜日空いてる?もし良かったら一緒に行かない?",
-    "storytelling": "子供の頃、おばあちゃんの家の裏に大きな竹林があってね、夏になると毎日そこで遊んでたんだ。ある日、竹林の奥の方に行ったら、見たことない小さな祠を見つけたの。苔がびっしり生えてて、すごく古い感じだった。おばあちゃんに聞いたら、昔この辺りに住んでいた人たちが水の神様を祀ってたらしいんだよね。それからなんか、その場所が特別に感じられて、雨の日もこっそり見に行ったりしてた。今思うと、あれが冒険心みたいなものの始まりだったのかもしれないなぁ。",
-    "news_report": "本日未明、北海道の広い範囲で記録的な大雪が観測されました。札幌市では24時間で60センチの降雪を記録し、交通機関に大きな影響が出ています。JR北海道は始発から運転を見合わせており、新千歳空港でも100便以上が欠航となっています。気象庁によりますと、この大雪は明日の昼頃まで続く見込みで、引き続き不要不急の外出を控えるよう呼びかけています。特に屋根の雪下ろしの際には十分な注意が必要だということです。",
-    "emotional_reflection": "最近さ、ふと立ち止まって考えることがあるんだよね。毎日忙しくて、目の前のことをこなすのに精一杯で、大事なことを見落としてないかなって。この前、昔の友達から急に連絡が来て、「元気?」ってたった一言だったんだけど、それだけですごく嬉しくて。人とのつながりって、当たり前じゃないんだなって改めて思ったんだ。もっと自分から連絡取るようにしなきゃなって。忙しいって言い訳にしちゃダメだよね、本当に。",
-    "travel_experience": "去年の秋に京都に行ったんだけど、もう紅葉がすごくてさ。嵐山のトロッコ列車に乗ったら、窓の外が一面真っ赤で、まるで絵の中に入ったみたいだった。途中で列車がゆっくり停まるポイントがあって、そこで写真撮れるんだけど、みんな一斉にカメラ構えるのがちょっと面白かった。その後、嵯峨野の竹林を歩いたんだけど、人が多くても不思議と静かな感じがするんだよね。風で竹がさわさわ揺れる音がすごく心地よくて、しばらくぼーっと立ってた。",
-    "food_and_cooking": "昨日初めて本格的なラーメンを作ってみたんだけど、スープがもう大変でさ。豚骨を8時間煮込んだんだよ、8時間だよ?途中で何回もアクを取って、火加減調整して。でもその甲斐があって、すっごく濃厚な白濁スープができたの。麺は製麺機がないから買ったやつなんだけど、チャーシューは自分で作った。醤油と味醂と生姜で煮込んで、最後にバーナーで炙ったらお店みたいな仕上がりになって感動しちゃった。次は味噌ラーメンに挑戦しようかな。"
-}