from Imports import *
from Configuration import *
from Model_Loading import *
from Weights_Loading import model_state
from preprocessing import *


# ── JIT-compiled stateless forward passes ────────────────────────────────────
# Each wrapper calls a model's `stateless_call`, passing the variables in
# explicitly (first the `ema` set, then the non-trainable set) with
# training=False, and returns whatever `stateless_call` returns.

@jax.jit
def jit_encoder(enc_ema, enc_nt, ids_jnp):
    """Run ``encoder_model`` on a batch of token ids in inference mode.

    Args:
        enc_ema: Variables passed as the first (trainable) argument of
            ``stateless_call`` — presumably EMA weights; confirm with the
            training code.
        enc_nt: Non-trainable variables of the encoder.
        ids_jnp: Token ids; ``inference`` passes an array of shape (1, T).

    Returns:
        ``(enc_out, enc_nt_new)``: encoder output and the non-trainable
        variables returned by ``stateless_call``.
    """
    enc_out, enc_nt_new = encoder_model.stateless_call(
        enc_ema, enc_nt, ids_jnp, training=False
    )
    return enc_out, enc_nt_new


@jax.jit
def jit_decoder_step(dec_ema, dec_nt,
                     enc_out, prev_mel,
                     h1_h, h1_c, h2_h, h2_c,
                     prev_attn, cum_attn, context, text_mask):
    """Run one step of ``decoder_model`` in inference mode.

    The ten tensors after ``dec_nt`` are forwarded to the model as a single
    list, in the order of this signature.

    Returns:
        ``(outputs, new_nt)``: the decoder outputs (``inference`` unpacks them
        as mel frames, stop token, the four recurrent states, attention
        weights and context) and the returned non-trainable variables.
    """
    outputs, new_nt = decoder_model.stateless_call(
        dec_ema, dec_nt,
        [enc_out, prev_mel, h1_h, h1_c, h2_h, h2_c,
         prev_attn, cum_attn, context, text_mask],
        training=False
    )
    return outputs, new_nt


@jax.jit
def jit_postnet(post_ema, post_nt, mel_out):
    """Run ``postnet_model`` on a mel spectrogram in inference mode.

    Returns:
        ``(mel_residual, new_nt)``: the residual that ``inference`` adds to
        ``mel_out``, and the returned non-trainable variables.
    """
    mel_residual, new_nt = postnet_model.stateless_call(
        post_ema, post_nt, mel_out, training=False
    )
    return mel_residual, new_nt


@jax.jit
def jit_vocoder(voc_ema, voc_nt, mel_out):
    """Run ``vocoder`` on a mel spectrogram in inference mode.

    Returns:
        ``(wav_out, new_nt)``: the vocoder output (``inference`` reads the
        waveform from ``wav_out[0, :, 0]``) and the returned non-trainable
        variables.
    """
    wav_out, new_nt = vocoder.stateless_call(
        voc_ema, voc_nt, mel_out, training=False
    )
    return wav_out, new_nt


# Minimum fraction of the text the cumulative-attention peak must have reached
# before attention on the EOS token is allowed to count towards stopping.
_EOS_MIN_PROGRESS = 0.85


def _report_progress(progress, fraction, desc):
    """Call ``progress`` with a description, retrying without it on TypeError."""
    if progress is None:
        return
    try:
        progress(fraction, desc=desc)
    except TypeError:
        # Callback does not accept a `desc` keyword.
        progress(fraction)


def inference(
    text,
    model_state,
    vocoder_state,
    max_steps=MAX_MEL_LEN,
    THRESHOLD_STOP=0.5,
    EOS_ATTN_THRESHOLD=0.2,
    EOS_CONSEC_STEPS=3,
    progress=None,
):
    """Synthesize a mel spectrogram and waveform for ``text``.

    Pipeline: tokenize -> encoder -> autoregressive decoder loop (``R`` mel
    frames per step) -> postnet residual -> vocoder.

    Decoding stops at the first of:
      * the cumulative attention on the last token exceeding
        ``EOS_ATTN_THRESHOLD`` for ``EOS_CONSEC_STEPS`` consecutive steps
        while the cumulative-attention peak is past 85% of the text;
      * the sigmoid of the stop token exceeding ``THRESHOLD_STOP``;
      * ``max_steps // R`` decoder steps having run (a warning is printed).

    Args:
        text: Input text, converted to ids with ``text_to_ids_tf``.
        model_state: Mapping with ``'enc'``, ``'dec'`` and ``'post'`` entries,
            each holding ``'ema'`` and ``'non_trainable'`` variables.
        vocoder_state: Mapping with ``'ema'`` and ``'non_trainable'``
            variables for the vocoder.
        max_steps: Upper bound on the number of generated mel frames.
        THRESHOLD_STOP: Stop-token probability above which decoding ends.
        EOS_ATTN_THRESHOLD: Cumulative attention on the last token that
            counts as "attending to EOS".
        EOS_CONSEC_STEPS: Consecutive qualifying steps needed to stop.
        progress: Optional callback ``progress(fraction, desc=...)``; a
            callback without ``desc`` support is called as
            ``progress(fraction)``.

    Returns:
        ``(mel_np, wav_np)``: the post-processed mel spectrogram as a NumPy
        array of shape (NUM_MEL_BINS, frames), and the float32 waveform with
        its mean subtracted.

    Raises:
        ValueError: If ``text`` yields no token ids, or ``max_steps`` is
            smaller than ``R`` so that no decoder step would run.
    """
    total_steps = max_steps // R
    if total_steps < 1:
        raise ValueError(
            f"max_steps={max_steps} is smaller than R={R}; "
            "no mel frames would be decoded"
        )

    ids_full = text_to_ids_tf(text).numpy()
    if ids_full.size == 0:
        raise ValueError("inference() received text that produced no token ids")

    space_id = char2id[' ']
    pause_id = char2id[',']   # comma — model learned to pause here
    # Replace each space with the pause token instead of removing it.
    ids = np.where(ids_full == space_id, pause_id, ids_full)
    if ids[-1] != EOS_ID:
        ids = np.append(ids, EOS_ID)

    ids_jnp = jnp.array(ids)[None, :]   # add batch dimension

    # ── encoder ──────────────────────────────────────────────────────────────
    enc_out, _ = jit_encoder(
        model_state['enc']['ema'],
        model_state['enc']['non_trainable'],
        ids_jnp
    )

    text_len = enc_out.shape[1]
    eos_pos = text_len - 1
    text_mask = jnp.ones((1, text_len), dtype=jnp.float32)

    # ── initial decoder state ────────────────────────────────────────────────
    h1_h = h1_c = jnp.zeros((1, 1024))
    h2_h = h2_c = jnp.zeros((1, 1024))
    prev_attn = jnp.zeros((1, text_len))
    cum_attn = jnp.zeros((1, text_len))
    context = jnp.zeros((1, 512))
    prev_mel = jnp.zeros((1, NUM_MEL_BINS))

    mel_chunks = []   # one (1, R, NUM_MEL_BINS) chunk per decoder step
    eos_consec = 0

    _report_progress(progress, 0, "Decoding (autoregressive)…")

    # ── autoregressive decoder loop ──────────────────────────────────────────
    for step in tqdm(range(total_steps), desc="Decoding"):
        _report_progress(
            progress,
            (step + 1) / total_steps,
            f"Decoding… {step + 1}/{total_steps}",
        )

        outputs, _ = jit_decoder_step(
            model_state['dec']['ema'],
            model_state['dec']['non_trainable'],
            enc_out, prev_mel,
            h1_h, h1_c, h2_h, h2_c,
            prev_attn, cum_attn, context, text_mask
        )

        (mel_frames_r, stop_tok,
         h1_h, h1_c, h2_h, h2_c,
         prev_attn, context) = outputs
        cum_attn = cum_attn + prev_attn

        mel_frames_r = jnp.reshape(mel_frames_r, (1, R, NUM_MEL_BINS))
        mel_chunks.append(mel_frames_r)
        # The last frame of this step is fed back as the next step's input.
        prev_mel = mel_frames_r[:, -1, :]

        stop_prob = float(jax.nn.sigmoid(stop_tok[0, 0]))
        eos_weight = float(cum_attn[0, eos_pos])
        attn_position = float(jnp.argmax(cum_attn[0]))
        progress_ratio = attn_position / max(text_len - 1, 1)

        if step % 50 == 0:
            print(f" step {step:04d} | stop={stop_prob:.3f} | "
                  f"eos_cum_attn={eos_weight:.3f} | attn_pos={attn_position:.0f}/{text_len-1}")

        if progress_ratio > _EOS_MIN_PROGRESS and eos_weight > EOS_ATTN_THRESHOLD:
            eos_consec += 1
        else:
            eos_consec = 0

        if eos_consec >= EOS_CONSEC_STEPS:
            print(f" stopped at step {step} — EOS attention ({eos_weight:.3f} > {EOS_ATTN_THRESHOLD})")
            break

        if stop_prob > THRESHOLD_STOP:
            print(f" stopped at step {step} — stop token ({stop_prob:.3f})")
            break
    else:
        # Runs only when no stop criterion fired. This holds whether or not
        # max_steps is a multiple of R, unlike comparing a frame count.
        print(f"WARNING: hit max_steps={max_steps}, model did not stop")

    # (1, steps * R, NUM_MEL_BINS)
    mel_out = jnp.concatenate(mel_chunks, axis=1)

    # ── postnet ──────────────────────────────────────────────────────────────
    mel_residual, _ = jit_postnet(
        model_state['post']['ema'],
        model_state['post']['non_trainable'],
        mel_out
    )
    mel_out = mel_out + mel_residual
    print(f"postnet residual mean abs: {jnp.mean(jnp.abs(mel_residual)):.6f}")

    # ── vocoder ──────────────────────────────────────────────────────────────
    wav_out, _ = jit_vocoder(
        vocoder_state['ema'],
        vocoder_state['non_trainable'],
        mel_out
    )

    mel_np = np.array(mel_out[0]).T
    wav_np = np.array(wav_out[0, :, 0]).astype(np.float32)
    # Subtract the mean to remove any constant offset from the waveform.
    wav_np = wav_np - float(np.mean(wav_np))

    return mel_np, wav_np