Spaces:
Sleeping
Sleeping
| from imports import jax,jnp,np | |
| from loading_model import generator | |
def inference_and_plot(gen_state, resolution, phase='stable'):
    """Run the generator on a fresh latent batch and advance the RNG state.

    Splits ``gen_state['rng']`` into three keys, draws a ``[16, 512]``
    standard-normal latent batch, calls ``generator.stateless_call`` with
    the ``'ema_trainable'`` and ``'non_trainable'`` entries of the state,
    and prints the min/max of the generated values.

    Args:
        gen_state: Mapping holding at least ``'rng'``, ``'ema_trainable'``
            and ``'non_trainable'``.
        resolution: Not read by the visible body.
        phase: Not read by the visible body.

    Returns:
        A new dict with the contents of ``gen_state`` and ``'rng'`` replaced
        by the carried-forward key; the input mapping is not mutated.

    NOTE(review): despite the name, nothing is plotted here, and the
    post-processed image array is neither used nor returned -- confirm
    whether a plotting step is missing.
    """
    # One key is carried forward in the state, one drives the latent draw,
    # and one is dedicated to the generator's noise input.
    next_rng, latent_key, generator_noise_key = jax.random.split(
        gen_state['rng'], 3
    )

    latents = jax.random.normal(latent_key, [16, 512])
    samples, _ = generator.stateless_call(
        gen_state['ema_trainable'],
        gen_state['non_trainable'],
        jnp.array(latents),
        jnp.array(1.0),  # presumably a blend/alpha factor -- TODO confirm
        generator_noise_key,
    )

    lowest = float(jnp.min(samples))
    highest = float(jnp.max(samples))
    print(f"Fake image range: {lowest:.3f} to {highest:.3f}")

    # Map x -> (x + 1) / 2, clamp to [0, 1] and copy to a NumPy array
    # (assumes generator output is nominally in [-1, 1] -- TODO confirm).
    display_images = np.array(jnp.clip((samples + 1.0) / 2.0, 0.0, 1.0))

    # Hand back a copy of the state carrying the unconsumed key.
    return dict(gen_state, rng=next_rng)