Spaces:
Sleeping
Sleeping
| from Imports import * | |
| from Configuration import * | |
| from Model_Loading import * | |
| from Weights_Loading import model_state | |
| from preprocessing import * | |
# ── Stateless model-call wrappers: encoder, decoder step, postnet, vocoder ──
def jit_encoder(enc_ema, enc_nt, ids_jnp):
    """Run the encoder once in inference mode via its stateless API.

    Args:
        enc_ema: Variable values handed to ``encoder_model.stateless_call``
            as its first argument (presumably the EMA weights -- confirm
            against the checkpoint loader).
        enc_nt: Non-trainable variable values for the encoder.
        ids_jnp: Encoder input, forwarded unchanged.

    Returns:
        The ``(encoder_output, updated_non_trainable)`` pair produced by
        ``stateless_call``.
    """
    # NOTE(review): the ``jit_`` prefix suggests this is meant to be wrapped
    # in ``jax.jit``; no such wrapping is visible here -- confirm.
    result = encoder_model.stateless_call(
        enc_ema, enc_nt, ids_jnp, training=False
    )
    encoder_output, updated_nt = result
    return encoder_output, updated_nt
def jit_decoder_step(dec_ema, dec_nt, enc_out,
                     prev_mel, h1_h, h1_c, h2_h, h2_c,
                     prev_attn, cum_attn, context, text_mask):
    """Advance the decoder by one step in inference mode.

    All per-step tensors are packed, in the order listed below, into a single
    list that is passed as the decoder's input.

    Args:
        dec_ema: Variable values handed to ``decoder_model.stateless_call``
            as its first argument (presumably the EMA weights -- confirm).
        dec_nt: Non-trainable variable values for the decoder.
        enc_out: Encoder output.
        prev_mel: Previously generated mel frame.
        h1_h, h1_c, h2_h, h2_c: Recurrent states carried between steps
            (presumably hidden/cell states of two LSTM layers -- confirm).
        prev_attn: Attention weights from the previous step.
        cum_attn: Running sum of attention weights.
        context: Attention context carried between steps.
        text_mask: Mask over the text positions.

    Returns:
        The ``(outputs, updated_non_trainable)`` pair produced by
        ``stateless_call``.
    """
    decoder_inputs = [
        enc_out, prev_mel,
        h1_h, h1_c, h2_h, h2_c,
        prev_attn, cum_attn, context, text_mask,
    ]
    step_outputs, updated_nt = decoder_model.stateless_call(
        dec_ema, dec_nt, decoder_inputs, training=False
    )
    return step_outputs, updated_nt
def jit_postnet(post_ema, post_nt, mel_out):
    """Run the post-net in inference mode via its stateless API.

    Args:
        post_ema: Variable values handed to ``postnet_model.stateless_call``
            as its first argument (presumably the EMA weights -- confirm).
        post_nt: Non-trainable variable values for the post-net.
        mel_out: Mel spectrogram fed to the post-net.

    Returns:
        The ``(mel_residual, updated_non_trainable)`` pair produced by
        ``stateless_call``.
    """
    result = postnet_model.stateless_call(
        post_ema, post_nt, mel_out, training=False
    )
    residual, updated_nt = result
    return residual, updated_nt
def jit_vocoder(voc_ema, voc_nt, mel_out):
    """Run the vocoder in inference mode via its stateless API.

    Args:
        voc_ema: Variable values handed to ``vocoder.stateless_call`` as its
            first argument (presumably the EMA weights -- confirm).
        voc_nt: Non-trainable variable values for the vocoder.
        mel_out: Mel spectrogram to convert.

    Returns:
        The ``(wav_out, updated_non_trainable)`` pair produced by
        ``stateless_call``.
    """
    result = vocoder.stateless_call(
        voc_ema, voc_nt, mel_out, training=False
    )
    waveform, updated_nt = result
    return waveform, updated_nt
def _report_progress(progress, fraction, desc):
    """Send one progress update, tolerating callbacks without a ``desc`` keyword."""
    if progress is None:
        return
    try:
        progress(fraction, desc=desc)
    except TypeError:
        # The callback does not accept ``desc``; report the fraction only.
        progress(fraction)


def inference(
    text,
    model_state,
    vocoder_state,
    max_steps=MAX_MEL_LEN,
    THRESHOLD_STOP=0.5,
    EOS_ATTN_THRESHOLD=0.2,
    EOS_CONSEC_STEPS=3,
    progress=None,
):
    """Synthesize a mel spectrogram and a waveform for ``text``.

    The text is encoded once, then the decoder is run autoregressively
    (``R`` mel frames per step) until a stop criterion fires or the step
    budget is exhausted. The post-net residual is added to the decoder
    output and the result is passed to the vocoder.

    Args:
        text: Input text, converted to ids with ``text_to_ids_tf``.
        model_state: Mapping with ``'enc'``, ``'dec'`` and ``'post'`` entries,
            each holding ``'ema'`` and ``'non_trainable'`` variable values.
        vocoder_state: Mapping with ``'ema'`` and ``'non_trainable'`` variable
            values for the vocoder.
        max_steps: Upper bound on the number of generated mel frames; the
            decoder runs at most ``max_steps // R`` steps.
        THRESHOLD_STOP: Decoding stops once the sigmoid of the stop token
            exceeds this value.
        EOS_ATTN_THRESHOLD: Cumulative attention on the last text position
            that counts towards the attention-based stop.
        EOS_CONSEC_STEPS: Number of consecutive steps the attention-based
            condition must hold before decoding stops.
        progress: Optional callback invoked as ``progress(fraction, desc=...)``
            or, if that raises ``TypeError``, as ``progress(fraction)``.

    Returns:
        ``(mel_np, wav_np)``: the transposed first batch element of the
        post-net-corrected mel spectrogram, and the first batch element /
        first channel of the vocoder output as ``float32`` with its mean
        subtracted.

    Raises:
        ValueError: If ``max_steps`` is smaller than ``R`` (no decoder step
            would run) or if ``text`` yields no token ids.
    """
    total_steps = max_steps // R
    if total_steps < 1:
        raise ValueError(
            f"max_steps={max_steps} allows no decoder step (R={R})"
        )

    ids_full = text_to_ids_tf(text).numpy()
    if ids_full.size == 0:
        raise ValueError("text produced no token ids; nothing to synthesize")

    # Spaces are replaced (not removed) by the comma token, which the model
    # learned to render as a pause.
    space_id = char2id[' ']
    pause_id = char2id[',']
    ids = np.where(ids_full == space_id, pause_id, ids_full)
    if ids[-1] != EOS_ID:
        ids = np.append(ids, EOS_ID)
    ids_jnp = jnp.array(ids)[None, :]

    # ── encoder ──────────────────────────────────────────────────────────
    enc_out, _ = jit_encoder(
        model_state['enc']['ema'],
        model_state['enc']['non_trainable'],
        ids_jnp,
    )
    text_len = enc_out.shape[1]
    eos_pos = text_len - 1
    last_text_index = max(text_len - 1, 1)  # avoids division by zero below
    text_mask = jnp.ones((1, text_len), dtype=jnp.float32)

    # ── initial decoder state ────────────────────────────────────────────
    h1_h = h1_c = jnp.zeros((1, 1024))
    h2_h = h2_c = jnp.zeros((1, 1024))
    prev_attn = jnp.zeros((1, text_len))
    cum_attn = jnp.zeros((1, text_len))
    context = jnp.zeros((1, 512))
    prev_mel = jnp.zeros((1, NUM_MEL_BINS))

    mel_blocks = []  # one (1, R, NUM_MEL_BINS) block per decoder step
    eos_consec = 0

    # ── autoregressive decoding ──────────────────────────────────────────
    _report_progress(progress, 0, "Decoding (autoregressive)…")
    for step in tqdm(range(total_steps), desc="Decoding"):
        _report_progress(
            progress,
            (step + 1) / total_steps,
            f"Decoding… {step + 1}/{total_steps}",
        )

        outputs, _ = jit_decoder_step(
            model_state['dec']['ema'],
            model_state['dec']['non_trainable'],
            enc_out, prev_mel,
            h1_h, h1_c, h2_h, h2_c,
            prev_attn, cum_attn, context, text_mask,
        )
        (mel_frames_r, stop_tok,
         h1_h, h1_c, h2_h, h2_c,
         prev_attn, context) = outputs
        cum_attn = cum_attn + prev_attn

        mel_frames_r = jnp.reshape(mel_frames_r, (1, R, NUM_MEL_BINS))
        mel_blocks.append(mel_frames_r)
        # Only the last of the R frames is fed back into the next step.
        prev_mel = mel_frames_r[:, -1, :]

        stop_prob = float(jax.nn.sigmoid(stop_tok[0, 0]))
        eos_weight = float(cum_attn[0, eos_pos])
        attn_position = int(jnp.argmax(cum_attn[0]))
        progress_ratio = attn_position / last_text_index

        if step % 50 == 0:
            print(f" step {step:04d} | stop={stop_prob:.3f} | "
                  f"eos_cum_attn={eos_weight:.3f} | "
                  f"attn_pos={attn_position}/{text_len - 1}")

        # Attention-based stop: the attention peak is near the end of the
        # text and enough cumulative attention sits on the last position.
        if progress_ratio > 0.85 and eos_weight > EOS_ATTN_THRESHOLD:
            eos_consec += 1
        else:
            eos_consec = 0
        if eos_consec >= EOS_CONSEC_STEPS:
            print(f" stopped at step {step} — EOS attention "
                  f"({eos_weight:.3f} > {EOS_ATTN_THRESHOLD})")
            break

        if stop_prob > THRESHOLD_STOP:
            print(f" stopped at step {step} — stop token ({stop_prob:.3f})")
            break
    else:
        # Reached only when the loop ran out of steps without a stop signal.
        print(f"WARNING: hit max_steps={max_steps}, model did not stop")

    # Joining the per-step blocks along the time axis gives the same frame
    # order as stacking the individual frames.
    mel_out = jnp.concatenate(mel_blocks, axis=1)

    # ── post-net ─────────────────────────────────────────────────────────
    mel_residual, _ = jit_postnet(
        model_state['post']['ema'],
        model_state['post']['non_trainable'],
        mel_out,
    )
    mel_out = mel_out + mel_residual
    residual_mag = float(jnp.mean(jnp.abs(mel_residual)))
    print(f"postnet residual mean abs: {residual_mag:.6f}")

    # ── vocoder ──────────────────────────────────────────────────────────
    wav_out, _ = jit_vocoder(
        vocoder_state['ema'],
        vocoder_state['non_trainable'],
        mel_out,
    )

    mel_np = np.array(mel_out[0]).T
    wav_np = np.array(wav_out[0, :, 0]).astype(np.float32)
    wav_np = wav_np - float(np.mean(wav_np))  # subtract the mean (DC offset)
    return mel_np, wav_np