Spaces:
Running on Zero
Running on Zero
fix alignment
Browse files
app.py
CHANGED
|
@@ -31,6 +31,7 @@ except ImportError:
|
|
| 31 |
|
| 32 |
from tada.modules.encoder import Encoder, EncoderOutput # noqa: E402
|
| 33 |
from tada.modules.tada import InferenceOptions, TadaForCausalLM # noqa: E402
|
|
|
|
| 34 |
|
| 35 |
logging.basicConfig(level=logging.INFO)
|
| 36 |
logger = logging.getLogger(__name__)
|
|
@@ -358,21 +359,17 @@ def generate_speech(
|
|
| 358 |
|
| 359 |
audio_duration = wav.shape[-1] / 24_000
|
| 360 |
|
| 361 |
-
# Extract only
|
| 362 |
all_logs = output.step_logs or []
|
| 363 |
if _model is not None and text and output.input_text_ids is not None:
|
| 364 |
input_ids = output.input_text_ids[0]
|
| 365 |
seq_len = input_ids.shape[0]
|
| 366 |
n_eos = _model.config.shift_acoustic
|
| 367 |
-
#
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
text_start = 0
|
| 371 |
-
for i in range(seq_len - 1, -1, -1):
|
| 372 |
-
if input_ids[i].item() == end_header_id:
|
| 373 |
-
text_start = i + 1
|
| 374 |
-
break
|
| 375 |
text_end = seq_len - n_eos
|
|
|
|
| 376 |
|
| 377 |
# Build a step -> log lookup from existing step_logs
|
| 378 |
log_by_step = {e["step"]: e for e in all_logs}
|
|
|
|
| 31 |
|
| 32 |
from tada.modules.encoder import Encoder, EncoderOutput # noqa: E402
|
| 33 |
from tada.modules.tada import InferenceOptions, TadaForCausalLM # noqa: E402
|
| 34 |
+
from tada.utils.text import normalize_text as normalize_text_fn # noqa: E402
|
| 35 |
|
| 36 |
logging.basicConfig(level=logging.INFO)
|
| 37 |
logger = logging.getLogger(__name__)
|
|
|
|
| 359 |
|
| 360 |
audio_duration = wav.shape[-1] / 24_000
|
| 361 |
|
| 362 |
+
# Extract only text-to-speak step_logs, reconstructing any prefilled (missing) entries
|
| 363 |
all_logs = output.step_logs or []
|
| 364 |
if _model is not None and text and output.input_text_ids is not None:
|
| 365 |
input_ids = output.input_text_ids[0]
|
| 366 |
seq_len = input_ids.shape[0]
|
| 367 |
n_eos = _model.config.shift_acoustic
|
| 368 |
+
# Count text-to-speak tokens (same logic as generate())
|
| 369 |
+
normalized = normalize_text_fn(text) if normalize_text else text
|
| 370 |
+
n_text_tokens = len(_model.tokenizer.encode(normalized, add_special_tokens=False))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 371 |
text_end = seq_len - n_eos
|
| 372 |
+
text_start = text_end - n_text_tokens
|
| 373 |
|
| 374 |
# Build a step -> log lookup from existing step_logs
|
| 375 |
log_by_step = {e["step"]: e for e in all_logs}
|