# StyleGAN / generate.py
# Uploaded by masterofaudio2077 ("Upload 11 files", commit 18dcb10, 952 bytes).
from imports import jax,jnp,np
from loading_model import generator
def inference_and_plot(gen_state, resolution, phase='stable', *,
                       num_samples=16, latent_dim=512, return_images=False):
    """Sample a batch of images from the generator using the stored weights.

    Draws ``num_samples`` latent vectors from a standard normal, runs them
    through ``generator.stateless_call`` with the ``'ema_trainable'`` and
    ``'non_trainable'`` variables held in ``gen_state``, prints the raw
    output range, and returns ``gen_state`` with its ``'rng'`` advanced.

    Despite its name, this function does not plot anything. Pass
    ``return_images=True`` to receive the post-processed images so the
    caller can display or save them.

    Args:
        gen_state: Mapping with at least the keys ``'rng'`` (a JAX PRNG key),
            ``'ema_trainable'`` and ``'non_trainable'``. It is not mutated.
        resolution: Currently unused; kept for interface compatibility.
        phase: Currently unused; kept for interface compatibility.
        num_samples: Number of latent vectors to draw (default 16).
        latent_dim: Size of each latent vector (default 512).
            NOTE(review): must match the generator's expected input size.
        return_images: If True, also return the generated images rescaled
            with ``(x + 1) / 2``, clipped to ``[0, 1]``, as a NumPy array.

    Returns:
        A new dict equal to ``gen_state`` but with ``'rng'`` replaced by the
        advanced key; or ``(new_state, images)`` when ``return_images`` is
        True.
    """
    # Split into 3 keys: one carried forward in the state, one for the
    # latents, and a dedicated one for the generator's noise input.
    rng, z_key, noise_key = jax.random.split(gen_state['rng'], 3)
    # Fresh latents on every call (not a fixed evaluation batch).
    z = jax.random.normal(z_key, (num_samples, latent_dim))

    # stateless_call yields (outputs, updated non-trainable variables); the
    # updated variables are not needed for sampling, so they are dropped.
    fake_images, _ = generator.stateless_call(
        gen_state['ema_trainable'],
        gen_state['non_trainable'],
        z,
        # NOTE(review): presumably a blend/fade-in factor fixed at 1.0;
        # confirm against the generator's call signature.
        jnp.array(1.0),
        noise_key,
    )
    lo = float(jnp.min(fake_images))
    hi = float(jnp.max(fake_images))
    print(f"Fake image range: {lo:.3f} to {hi:.3f}")

    # Build a new dict instead of mutating the caller's state.
    new_state = {**gen_state, 'rng': rng}
    if not return_images:
        return new_state

    # (x + 1) / 2 then clip: assumes outputs lie roughly in [-1, 1]
    # (tanh-style) -- TODO confirm against the generator's output layer.
    images = jnp.clip((fake_images + 1.0) / 2.0, 0.0, 1.0)
    return new_state, np.asarray(images)