from imports import jax, jnp, np
from loading_model import generator


def inference_and_plot(gen_state, resolution, phase='stable'):
    """Run the generator once on fresh latents and advance the RNG state.

    Draws a batch of 16 latent vectors of width 512, calls
    ``generator.stateless_call`` with the ``'ema_trainable'`` and
    ``'non_trainable'`` entries of ``gen_state``, and prints the min/max of
    the raw output.

    Args:
        gen_state: Mapping that provides at least the keys ``'rng'``,
            ``'ema_trainable'`` and ``'non_trainable'``. It is not mutated.
        resolution: Not used by the code visible in this function.
        phase: Not used by the code visible in this function.

    Returns:
        A new dict with the contents of ``gen_state`` and ``'rng'`` replaced
        by the carried-over key from the split.

    NOTE(review): despite the name, nothing is plotted here and the
    normalised image array is neither used nor returned - confirm whether a
    plotting step is missing.
    """
    # One split yields three independent keys: the key carried forward in the
    # returned state, one for the latents, one handed to the generator.
    next_rng, latent_key, synth_key = jax.random.split(gen_state['rng'], 3)

    latents = jax.random.normal(latent_key, [16, 512])

    # The second element returned by stateless_call is discarded.
    fake_images, _ = generator.stateless_call(
        gen_state['ema_trainable'],
        gen_state['non_trainable'],
        jnp.array(latents),
        jnp.array(1.0),  # NOTE(review): meaning of this scalar is not visible here - confirm
        synth_key,
    )
    print(f"Fake image range: {float(jnp.min(fake_images)):.3f} to {float(jnp.max(fake_images)):.3f}")

    # Shift/scale by (x + 1) / 2, clamp to [0, 1], and copy to a NumPy array.
    # Presumably the generator emits values in [-1, 1] - TODO confirm.
    # The result is currently unused (see the note in the docstring).
    display_images = np.array(jnp.clip((fake_images + 1.0) / 2.0, 0.0, 1.0))

    return {**gen_state, 'rng': next_rng}