sharath25 committed on
Commit
024f3d7
·
1 Parent(s): 7a6b47b

fix alignment

Browse files
Files changed (1) hide show
  1. app.py +6 -9
app.py CHANGED
@@ -31,6 +31,7 @@ except ImportError:
31
 
32
  from tada.modules.encoder import Encoder, EncoderOutput # noqa: E402
33
  from tada.modules.tada import InferenceOptions, TadaForCausalLM # noqa: E402
 
34
 
35
  logging.basicConfig(level=logging.INFO)
36
  logger = logging.getLogger(__name__)
@@ -358,21 +359,17 @@ def generate_speech(
358
 
359
  audio_duration = wav.shape[-1] / 24_000
360
 
361
- # Extract only user-text step_logs, reconstructing any prefilled (missing) entries
362
  all_logs = output.step_logs or []
363
  if _model is not None and text and output.input_text_ids is not None:
364
  input_ids = output.input_text_ids[0]
365
  seq_len = input_ids.shape[0]
366
  n_eos = _model.config.shift_acoustic
367
- # Find text boundary: last <|end_header_id|> token marks end of assistant header
368
- end_header_id = _model.tokenizer.convert_tokens_to_ids("<|end_header_id|>")
369
- # Scan backwards for the last end_header token
370
- text_start = 0
371
- for i in range(seq_len - 1, -1, -1):
372
- if input_ids[i].item() == end_header_id:
373
- text_start = i + 1
374
- break
375
  text_end = seq_len - n_eos
 
376
 
377
  # Build a step -> log lookup from existing step_logs
378
  log_by_step = {e["step"]: e for e in all_logs}
 
31
 
32
  from tada.modules.encoder import Encoder, EncoderOutput # noqa: E402
33
  from tada.modules.tada import InferenceOptions, TadaForCausalLM # noqa: E402
34
+ from tada.utils.text import normalize_text as normalize_text_fn # noqa: E402
35
 
36
  logging.basicConfig(level=logging.INFO)
37
  logger = logging.getLogger(__name__)
 
359
 
360
  audio_duration = wav.shape[-1] / 24_000
361
 
362
+ # Extract only text-to-speak step_logs, reconstructing any prefilled (missing) entries
363
  all_logs = output.step_logs or []
364
  if _model is not None and text and output.input_text_ids is not None:
365
  input_ids = output.input_text_ids[0]
366
  seq_len = input_ids.shape[0]
367
  n_eos = _model.config.shift_acoustic
368
+ # Count text-to-speak tokens (same logic as generate())
369
+ normalized = normalize_text_fn(text) if normalize_text else text
370
+ n_text_tokens = len(_model.tokenizer.encode(normalized, add_special_tokens=False))
 
 
 
 
 
371
  text_end = seq_len - n_eos
372
+ text_start = text_end - n_text_tokens
373
 
374
  # Build a step -> log lookup from existing step_logs
375
  log_by_step = {e["step"]: e for e in all_logs}