# TTSIE / Inference.py
# Uploaded by masterofaudio2077 ("Upload 13 files", commit c39b616, 5.68 kB).
from Imports import *
from Configuration import *
from Model_Loading import *
from Weights_Loading import model_state
from preprocessing import *
# ── JIT-compiled model wrappers (encoder, decoder step, postnet, vocoder) ─────
@jax.jit
def jit_encoder(enc_ema, enc_nt, ids_jnp):
    """Run the text encoder once, in inference mode, under ``jax.jit``.

    Args:
        enc_ema: trainable-variable values handed to ``stateless_call``
            (``inference`` passes ``model_state['enc']['ema']``).
        enc_nt: non-trainable-variable values for the encoder.
        ids_jnp: batch of token ids to encode.

    Returns:
        The ``(outputs, non_trainable_variables)`` pair produced by
        ``encoder_model.stateless_call``.
    """
    encoded, updated_nt = encoder_model.stateless_call(
        enc_ema,
        enc_nt,
        ids_jnp,
        training=False,
    )
    return encoded, updated_nt
@jax.jit
def jit_decoder_step(dec_ema, dec_nt, enc_out,
                     prev_mel, h1_h, h1_c, h2_h, h2_c,
                     prev_attn, cum_attn, context, text_mask):
    """Advance the decoder by one autoregressive step under ``jax.jit``.

    The decoder receives a single list of inputs in this fixed order:
    encoder output, previous mel frame, the four recurrent state tensors
    (``h1_h, h1_c, h2_h, h2_c``), previous attention, cumulative attention,
    context vector, and the text mask.

    Returns:
        The ``(outputs, non_trainable_variables)`` pair from
        ``decoder_model.stateless_call``. ``inference`` unpacks ``outputs`` as
        ``(mel, stop_token, h1_h, h1_c, h2_h, h2_c, attention, context)``.
    """
    recurrent_state = [h1_h, h1_c, h2_h, h2_c]
    attention_state = [prev_attn, cum_attn, context]
    decoder_inputs = [enc_out, prev_mel, *recurrent_state, *attention_state, text_mask]
    step_outputs, updated_nt = decoder_model.stateless_call(
        dec_ema,
        dec_nt,
        decoder_inputs,
        training=False,
    )
    return step_outputs, updated_nt
@jax.jit
def jit_postnet(post_ema, post_nt, mel_out):
    """Compute the postnet's mel residual in inference mode under ``jax.jit``.

    ``inference`` adds the returned residual to the decoder's mel output.

    Returns:
        The ``(residual, non_trainable_variables)`` pair from
        ``postnet_model.stateless_call``.
    """
    residual, updated_nt = postnet_model.stateless_call(
        post_ema,
        post_nt,
        mel_out,
        training=False,
    )
    return residual, updated_nt
@jax.jit
def jit_vocoder(voc_ema, voc_nt, mel_out):
    """Convert a mel spectrogram batch into audio with the vocoder under ``jax.jit``.

    ``inference`` calls this with the postnet-refined mel.

    Returns:
        The ``(waveform, non_trainable_variables)`` pair from
        ``vocoder.stateless_call``.
    """
    waveform, updated_nt = vocoder.stateless_call(
        voc_ema,
        voc_nt,
        mel_out,
        training=False,
    )
    return waveform, updated_nt
def _prepare_input_ids(text):
    """Tokenize *text* into a ``(1, seq_len)`` JAX id array for the encoder.

    Spaces are replaced by the comma token rather than removed (the original
    author notes the model learned to pause on commas). An EOS id is appended
    unless the sequence already ends with one. Empty input yields ``[EOS]``
    instead of raising ``IndexError``.
    """
    ids = text_to_ids_tf(text).numpy()
    space_id = char2id[' ']
    pause_id = char2id[',']  # comma — model learned to pause here
    ids = np.where(ids == space_id, pause_id, ids)
    # The size check guards ids[-1], which would raise on an empty sequence.
    if ids.size == 0 or ids[-1] != EOS_ID:
        ids = np.append(ids, EOS_ID)
    return jnp.array(ids)[None, :]


def _report_progress(progress, frac, desc):
    """Forward decoding progress to an optional callback.

    The callback is first called as ``progress(frac, desc=desc)``. If that
    raises ``TypeError`` (e.g. it takes no ``desc`` keyword), it is retried as
    ``progress(frac)``. Passing ``None`` disables reporting.
    """
    if progress is None:
        return
    try:
        progress(frac, desc=desc)
    except TypeError:
        progress(frac)


def inference(
    text,
    model_state,
    vocoder_state,
    max_steps=MAX_MEL_LEN,
    THRESHOLD_STOP=0.5,
    EOS_ATTN_THRESHOLD=0.2,
    EOS_CONSEC_STEPS=3,
    progress=None,
    return_alignment=False,
):
    """Synthesize speech for *text* with autoregressive decoding, a postnet and a vocoder.

    Args:
        text: input string, tokenized by ``text_to_ids_tf``.
        model_state: dict with ``'enc'``, ``'dec'`` and ``'post'`` entries,
            each holding ``'ema'`` and ``'non_trainable'`` variable values.
        vocoder_state: dict with ``'ema'`` and ``'non_trainable'`` values.
        max_steps: upper bound on generated mel frames. The decoder runs at
            most ``max_steps // R`` steps, each emitting ``R`` frames.
        THRESHOLD_STOP: decoding stops once ``sigmoid(stop_token)`` exceeds it.
        EOS_ATTN_THRESHOLD: minimum cumulative attention on the last input
            position. It counts toward stopping only while the cumulative-attention
            peak is past 85% of the input.
        EOS_CONSEC_STEPS: consecutive steps that EOS condition must hold
            before decoding stops.
        progress: optional callback ``progress(fraction, desc=...)``.
        return_alignment: if True, additionally return the attention matrix.

    Returns:
        ``(mel_np, wav_np)``:
        - ``mel_np`` is the postnet-refined mel, transposed to
          ``(NUM_MEL_BINS, frames)``.
        - ``wav_np`` is ``wav_out[0, :, 0]`` as float32 with its mean removed.
        With ``return_alignment=True``, a third element ``attn_matrix`` is
        appended, holding one attention row per mel frame.

    Raises:
        ValueError: if ``max_steps < R``, so no decoder step could run.
    """
    total_steps = max_steps // R
    if total_steps < 1:
        raise ValueError(
            f"max_steps={max_steps} is smaller than the reduction factor R={R}; "
            "no decoder step can run"
        )

    ids_jnp = _prepare_input_ids(text)

    # ── encoder ───────────────────────────────────────────────────────────────
    enc_out, _ = jit_encoder(
        model_state['enc']['ema'],
        model_state['enc']['non_trainable'],
        ids_jnp,
    )
    text_len = enc_out.shape[1]
    # Assumes the encoder preserves sequence length, so the last position is the
    # EOS token appended by _prepare_input_ids — TODO confirm.
    eos_pos = text_len - 1
    text_mask = jnp.ones((1, text_len), dtype=jnp.float32)  # single unpadded utterance

    # Zero initial decoder state. 1024 / 512 presumably match the decoder's
    # recurrent and context sizes — TODO confirm against the decoder definition.
    h1_h = h1_c = jnp.zeros((1, 1024))
    h2_h = h2_c = jnp.zeros((1, 1024))
    prev_attn = jnp.zeros((1, text_len))
    cum_attn = jnp.zeros((1, text_len))
    context = jnp.zeros((1, 512))
    prev_mel = jnp.zeros((1, NUM_MEL_BINS))  # all-zero initial input frame

    mel_chunks = []   # one (1, R, NUM_MEL_BINS) block per decoder step
    attn_frames = []  # filled only when return_alignment is True
    eos_consec = 0

    # ── autoregressive decoder ────────────────────────────────────────────────
    _report_progress(progress, 0, "Decoding (autoregressive)…")
    for step in tqdm(range(total_steps), desc="Decoding"):
        _report_progress(
            progress,
            (step + 1) / total_steps,
            f"Decoding… {step + 1}/{total_steps}",
        )

        outputs, _ = jit_decoder_step(
            model_state['dec']['ema'],
            model_state['dec']['non_trainable'],
            enc_out, prev_mel,
            h1_h, h1_c, h2_h, h2_c,
            prev_attn, cum_attn, context, text_mask,
        )
        mel_frames_r, stop_tok, h1_h, h1_c, h2_h, h2_c, prev_attn, context = outputs
        cum_attn = cum_attn + prev_attn

        mel_frames_r = jnp.reshape(mel_frames_r, (1, R, NUM_MEL_BINS))
        mel_chunks.append(mel_frames_r)
        if return_alignment:
            # One host copy per step, shared by all R frames it produced.
            attn_row = np.array(prev_attn[0])
            attn_frames.extend([attn_row] * R)
        prev_mel = mel_frames_r[:, -1, :]  # the last of the R frames feeds the next step

        stop_prob = float(jax.nn.sigmoid(stop_tok[0, 0]))
        eos_weight = float(cum_attn[0, eos_pos])
        # The cumulative-attention peak serves as a proxy for how far through
        # the text the decoder has read.
        attn_position = float(jnp.argmax(cum_attn[0]))
        progress_ratio = attn_position / max(text_len - 1, 1)

        if step % 50 == 0:
            print(f" step {step:04d} | stop={stop_prob:.3f} | "
                  f"eos_cum_attn={eos_weight:.3f} | attn_pos={attn_position:.0f}/{text_len-1}")

        if progress_ratio > 0.85 and eos_weight > EOS_ATTN_THRESHOLD:
            eos_consec += 1
        else:
            eos_consec = 0
        if eos_consec >= EOS_CONSEC_STEPS:
            print(f" stopped at step {step} — EOS attention ({eos_weight:.3f} > {EOS_ATTN_THRESHOLD})")
            break
        if stop_prob > THRESHOLD_STOP:
            print(f" stopped at step {step} — stop token ({stop_prob:.3f})")
            break
    else:
        # for/else: reached only when every step ran without a stop criterion firing.
        print(f"WARNING: hit max_steps={max_steps}, model did not stop")

    # (1, steps * R, NUM_MEL_BINS), same frame order as stacking frame-by-frame.
    mel_out = jnp.concatenate(mel_chunks, axis=1)

    # ── postnet + vocoder ─────────────────────────────────────────────────────
    mel_residual, _ = jit_postnet(
        model_state['post']['ema'],
        model_state['post']['non_trainable'],
        mel_out,
    )
    mel_out = mel_out + mel_residual
    print(f"postnet residual mean abs: {float(jnp.mean(jnp.abs(mel_residual))):.6f}")

    wav_out, _ = jit_vocoder(
        vocoder_state['ema'],
        vocoder_state['non_trainable'],
        mel_out,
    )

    mel_np = np.array(mel_out[0]).T
    wav_np = np.array(wav_out[0, :, 0], dtype=np.float32)
    wav_np = wav_np - float(np.mean(wav_np))  # remove DC offset

    if return_alignment:
        attn_matrix = np.stack(attn_frames, axis=0)
        return mel_np, wav_np, attn_matrix
    return mel_np, wav_np