StyleGAN / inference.py
masterofaudio2077's picture
Upload 3 files
7048b74 verified
Raw
History Blame Contribute Delete
4.46 kB
import os
os.environ["KERAS_BACKEND"] = "jax"
import numpy as np
import jax
import jax.numpy as jnp
from loading_model import generator
from configuration import LATENT_DIM, RESOLUTIONS
# Directory containing this script; both weight files are looked up here.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Weight sources checked by load_weights(): the Keras weights file first,
# then the joblib checkpoint.
WEIGHTS_H5, CHECKPOINT_PKL = (
    os.path.join(BASE_DIR, file_name)
    for file_name in ("generator.weights.h5", "checkpoint.pkl")
)

# Mutable module state, filled in by load_weights().
_gen_state = None       # checkpoint variables for stateless_call, or None
_use_stateless = False  # True when generation goes through stateless_call
_saved_alpha = 1.0      # alpha read from the checkpoint (stateless path only)

# Resolution every generation call pins the generator to.
TARGET_RES = 256
def _postprocess(images: np.ndarray) -> list:
    """Rescale generator output with (x + 1) / 2, clip to [0, 1], split per image.

    Returns a list holding one array per entry along the leading axis of
    ``images``.
    """
    unit_range = np.clip((np.array(images) + 1.0) / 2.0, 0.0, 1.0)
    # Iterating an ndarray yields its sub-arrays along axis 0.
    return list(unit_range)
def _build_model():
    """Run one dummy forward pass of the generator at every resolution.

    Each entry of RESOLUTIONS is visited with a zero latent, alpha=1 and a
    fixed PRNG key, presumably so that every resolution-specific layer creates
    its variables (TODO confirm against loading_model). Afterwards the
    generator is reset to TARGET_RES and the variable counts are printed.
    """
    latent_zero = jnp.zeros((1, LATENT_DIM))
    alpha_full = jnp.array(1.0)
    build_key = jax.random.PRNGKey(0)
    for stage_res in RESOLUTIONS:
        generator.current_resolution = stage_res
        generator(latent_zero, alpha=alpha_full, rng_key=build_key)
    generator.current_resolution = TARGET_RES
    n_trainable = len(generator.trainable_variables)
    n_frozen = len(generator.non_trainable_variables)
    print(f"βœ… Model built β€” trainable: {n_trainable} non-trainable: {n_frozen}")
# Set to True by load_weights() once either weight source has been applied.
# The Keras .h5 path never fills _gen_state, so generate_images() cannot use
# _gen_state alone to tell "loaded" apart from "not loaded yet".
_weights_loaded = False


def _first_present(data, keys, default=None):
    """Return the first non-empty value stored in *data* under one of *keys*.

    ``None`` and zero-length values are skipped, so a checkpoint that stores an
    empty list under the newer key still falls back to the older key. ``len``
    is used instead of truthiness so NumPy-array values do not raise.
    """
    for key in keys:
        value = data.get(key)
        if value is not None and len(value) > 0:
            return value
    return default


def _load_checkpoint():
    """Load CHECKPOINT_PKL into ``_gen_state`` for ``stateless_call`` inference.

    Module state (_gen_state, _use_stateless, _saved_alpha) is only updated
    after every check has passed.

    Raises:
        ValueError: if the checkpoint has no trainable variables, or its
            variable counts do not match the built generator.
    """
    global _gen_state, _use_stateless, _saved_alpha
    print(f"πŸ“‚ Loading checkpoint: {CHECKPOINT_PKL}")
    import joblib  # only needed for the checkpoint fallback

    # NOTE(review): joblib.load unpickles the file - only load trusted checkpoints.
    data = joblib.load(CHECKPOINT_PKL)
    print(f" Keys: {list(data.keys())}")

    # Support both key naming conventions: "ema_trainable" (new training
    # code) takes precedence over "gen_trainable" (old).
    gen_trainable = _first_present(data, ("ema_trainable", "gen_trainable"))
    if gen_trainable is None:
        raise ValueError(
            f"Checkpoint {CHECKPOINT_PKL} has no 'ema_trainable' or "
            f"'gen_trainable' entry (keys: {list(data.keys())})"
        )
    gen_non_trainable = _first_present(
        data, ("gen_non_trainable", "non_trainable"), default=[]
    )
    alpha = float(data.get("alpha", 1.0))
    print(f" trainable={len(gen_trainable)} non_trainable={len(gen_non_trainable)} alpha={alpha:.4f}")

    n_model = len(generator.trainable_variables)
    if len(gen_trainable) != n_model:
        raise ValueError(
            f"❌ Mismatch β€” checkpoint:{len(gen_trainable)} model:{n_model}\n"
            f" Make sure _build_model() mirrors training exactly."
        )

    n_non = len(generator.non_trainable_variables)
    if len(gen_non_trainable) > 0:
        # Extra trailing entries are ignored (as before). Too few cannot line
        # up with the model's variables, so reject that case with a clear message.
        if len(gen_non_trainable) < n_non:
            raise ValueError(
                f"Non-trainable mismatch - checkpoint:{len(gen_non_trainable)} "
                f"model:{n_non}"
            )
        non_trainable = [jnp.asarray(t) for t in gen_non_trainable[:n_non]]
    else:
        # No saved non-trainable state: reuse the freshly built model's values.
        non_trainable = [jnp.asarray(v.value) for v in generator.non_trainable_variables]

    # Commit module state only now that every check has passed.
    _saved_alpha = alpha
    _gen_state = {
        "ema_trainable": [jnp.asarray(t) for t in gen_trainable],
        "non_trainable": non_trainable,
    }
    _use_stateless = True
    print("βœ… Checkpoint loaded.")


def load_weights(resolution=256):
    """Build the generator and load its weights from BASE_DIR.

    The Keras weights file (WEIGHTS_H5) is preferred. If it is absent, the
    joblib checkpoint (CHECKPOINT_PKL) is loaded instead, and generation then
    goes through ``generator.stateless_call``.

    Args:
        resolution: Currently ignored; output is always produced at
            TARGET_RES. Kept so existing callers keep working.

    Returns:
        True once weights are loaded.

    Raises:
        FileNotFoundError: if neither weight file exists.
        ValueError: if the checkpoint lacks trainable variables or its
            variable counts do not match the built model.
    """
    global _gen_state, _use_stateless, _weights_loaded
    _build_model()
    if os.path.isfile(WEIGHTS_H5):
        print(f"πŸ“‚ Loading Keras weights: {WEIGHTS_H5}")
        generator.load_weights(WEIGHTS_H5)
        # Drop checkpoint state left from an earlier load so modes never mix.
        _gen_state = None
        _use_stateless = False
        print("βœ… Keras weights loaded.")
    elif os.path.isfile(CHECKPOINT_PKL):
        _load_checkpoint()
    else:
        raise FileNotFoundError(f"❌ No weights found in: {BASE_DIR}")
    _weights_loaded = True
    generator.current_resolution = TARGET_RES
    print(f"πŸ–ΌοΈ Resolution locked to: {TARGET_RES}")
    return True


def generate_images(n_images, resolution=None, seed=None):
    """Sample ``n_images`` images from the loaded generator.

    Args:
        n_images: Number of latent vectors to sample.
        resolution: Currently ignored; images are produced at TARGET_RES.
        seed: Integer seed for the JAX PRNG. When None, one is drawn from
            NumPy's global random state.

    Returns:
        A list with one post-processed image array per sample.

    Raises:
        RuntimeError: if load_weights() has not completed successfully.
    """
    # Explicit flag: the Keras-weights path leaves _gen_state as None, so the
    # old "_gen_state is None" test rejected correctly loaded .h5 weights.
    if not _weights_loaded:
        raise RuntimeError("❌ Call load_weights() first.")
    if seed is None:
        seed = int(np.random.randint(0, 2**31))
    rng = jax.random.PRNGKey(seed)
    if _use_stateless and _gen_state is not None:
        return _generate_stateless(n_images, rng)
    return _generate_keras(n_images, rng)
def _generate_keras(n_images, rng):
    """Sample ``n_images`` by calling the generator directly at TARGET_RES.

    The PRNG key is split three ways; the second part seeds the latent batch
    and the third is passed to the generator as ``rng_key``. Alpha is fixed at
    1.0. If the generator returns a list/tuple, only its first element is kept.
    """
    generator.current_resolution = TARGET_RES
    _, latent_key, noise_rng = jax.random.split(rng, 3)
    latents = jax.random.normal(latent_key, (n_images, LATENT_DIM))
    result = generator(latents, alpha=1.0, rng_key=noise_rng)
    if isinstance(result, (list, tuple)):
        result = result[0]
    return _postprocess(np.array(result))
def _generate_stateless(n_images, rng):
    """Sample ``n_images`` through ``stateless_call`` using the checkpoint variables.

    The trainable and non-trainable variables come from ``_gen_state``, and
    alpha comes from ``_saved_alpha``. The state returned by
    ``stateless_call`` is discarded.
    """
    generator.current_resolution = TARGET_RES
    _, latent_key, noise_rng = jax.random.split(rng, 3)
    latents = jax.random.normal(latent_key, (n_images, LATENT_DIM))
    trainable_vars = _gen_state["ema_trainable"]
    frozen_vars = _gen_state["non_trainable"]
    outputs, _unused_state = generator.stateless_call(
        trainable_vars,
        frozen_vars,
        latents,
        jnp.array(_saved_alpha),
        noise_rng,
    )
    return _postprocess(np.array(outputs))