"""Load generator weights and sample images from the generator.

Weights come from one of two files next to this module:

* ``generator.weights.h5`` -- loaded into the model via ``load_weights``;
  images are then produced by calling the model directly.
* ``checkpoint.pkl`` -- a joblib checkpoint holding variable lists; images
  are then produced through ``generator.stateless_call``.

Call ``load_weights()`` once, then ``generate_images()``.
"""
import os

# Must be set before Keras is imported; presumably ``loading_model`` imports
# Keras, so this assignment has to stay above the project imports below.
os.environ["KERAS_BACKEND"] = "jax"

import numpy as np
import jax
import jax.numpy as jnp

from loading_model import generator
from configuration import LATENT_DIM, RESOLUTIONS

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
WEIGHTS_H5 = os.path.join(BASE_DIR, "generator.weights.h5")
CHECKPOINT_PKL = os.path.join(BASE_DIR, "checkpoint.pkl")

# Module state populated by load_weights().
_gen_state = None        # {"ema_trainable": [...], "non_trainable": [...]} for the checkpoint path
_use_stateless = False   # True when generation must go through stateless_call
_saved_alpha = 1.0       # alpha stored in the checkpoint (1.0 when absent)
_weights_loaded = False  # True once load_weights() has completed successfully

TARGET_RES = 256


def _postprocess(images: np.ndarray) -> list:
    """Map generator output from [-1, 1] to [0, 1] and split along axis 0.

    Args:
        images: Array-like batch of generator outputs.

    Returns:
        A list with one array per entry along the first axis, values
        clipped to the closed interval [0, 1].
    """
    images = (np.asarray(images) + 1.0) / 2.0
    images = np.clip(images, 0.0, 1.0)
    return list(images)


def _build_model():
    """Call the generator once per entry of ``RESOLUTIONS`` with dummy inputs.

    Presumably this creates the variables used at every resolution so that
    the variable lists line up with the saved weights (the mismatch error in
    ``load_weights`` relies on that). Leaves ``generator.current_resolution``
    set to ``TARGET_RES``.
    """
    dummy_alpha = jnp.array(1.0)
    dummy_rng = jax.random.PRNGKey(0)
    dummy_z = jnp.zeros((1, LATENT_DIM))

    for res in RESOLUTIONS:
        generator.current_resolution = res
        generator(dummy_z, alpha=dummy_alpha, rng_key=dummy_rng)

    generator.current_resolution = TARGET_RES
    print(f"✅ Model built — trainable: {len(generator.trainable_variables)} non-trainable: {len(generator.non_trainable_variables)}")


def _first_nonempty(mapping, *keys):
    """Return the first non-empty value stored under ``keys``, else ``None``.

    Uses ``len`` instead of truthiness so that array-valued entries do not
    raise an "ambiguous truth value" error.
    """
    for key in keys:
        value = mapping.get(key)
        if value is not None and len(value) > 0:
            return value
    return None


def _load_checkpoint_state(path):
    """Read a joblib checkpoint and build the state used by stateless_call.

    Args:
        path: Path of the checkpoint file.

    Returns:
        A ``(state, alpha)`` tuple where ``state`` is a dict with the keys
        ``"ema_trainable"`` and ``"non_trainable"`` (lists of JAX arrays) and
        ``alpha`` is the float stored under ``"alpha"`` (1.0 when absent).

    Raises:
        ValueError: If the checkpoint holds no trainable weights, or their
            count differs from the model's trainable variable count.
    """
    import joblib

    # NOTE(security): joblib.load unpickles the file and can execute
    # arbitrary code -- only load checkpoints from a trusted source.
    data = joblib.load(path)
    print(f" Keys: {list(data.keys())}")

    # Support both key naming conventions: "ema_trainable" is the newer
    # training-code key, "gen_trainable" the older one.
    gen_trainable = _first_nonempty(data, "ema_trainable", "gen_trainable")
    if gen_trainable is None:
        raise ValueError(
            "❌ Checkpoint has no generator weights "
            "(expected key 'ema_trainable' or 'gen_trainable')."
        )
    gen_non_trainable = _first_nonempty(data, "gen_non_trainable", "non_trainable")
    if gen_non_trainable is None:
        gen_non_trainable = []

    alpha = float(data.get("alpha", 1.0))
    print(f" trainable={len(gen_trainable)} non_trainable={len(gen_non_trainable)} alpha={alpha:.4f}")

    n_model = len(generator.trainable_variables)
    if len(gen_trainable) != n_model:
        raise ValueError(
            f"❌ Mismatch — checkpoint:{len(gen_trainable)} model:{n_model}\n"
            f" Make sure _build_model() mirrors training exactly."
        )

    n_non = len(generator.non_trainable_variables)
    if len(gen_non_trainable) > 0:
        # Extra checkpoint entries beyond the model's count are dropped.
        non_trainable = [jnp.asarray(t) for t in gen_non_trainable[:n_non]]
    else:
        # Nothing saved: fall back to the model's current values.
        non_trainable = [jnp.asarray(v.value) for v in generator.non_trainable_variables]

    state = {
        "ema_trainable": [jnp.asarray(t) for t in gen_trainable],
        "non_trainable": non_trainable,
    }
    return state, alpha


def load_weights(resolution=256):
    """Build the generator and load its weights from disk.

    ``generator.weights.h5`` is preferred; ``checkpoint.pkl`` is used only
    when the H5 file does not exist.

    Args:
        resolution: Accepted for interface compatibility but currently
            ignored -- the resolution is always locked to ``TARGET_RES``.

    Returns:
        ``True`` on success.

    Raises:
        FileNotFoundError: If neither weights file exists in ``BASE_DIR``.
        ValueError: If the checkpoint's weights do not match the model.
    """
    global _gen_state, _use_stateless, _saved_alpha, _weights_loaded

    _build_model()

    if os.path.isfile(WEIGHTS_H5):
        print(f"📂 Loading Keras weights: {WEIGHTS_H5}")
        generator.load_weights(WEIGHTS_H5)
        _use_stateless = False
        print("✅ Keras weights loaded.")
    elif os.path.isfile(CHECKPOINT_PKL):
        print(f"📂 Loading checkpoint: {CHECKPOINT_PKL}")
        # Assign the globals only after the checkpoint validated, so a
        # failed load cannot leave half-updated module state behind.
        state, alpha = _load_checkpoint_state(CHECKPOINT_PKL)
        _gen_state = state
        _saved_alpha = alpha
        _use_stateless = True
        print("✅ Checkpoint loaded.")
    else:
        raise FileNotFoundError(f"❌ No weights found in: {BASE_DIR}")

    generator.current_resolution = TARGET_RES
    _weights_loaded = True
    print(f"🖼️ Resolution locked to: {TARGET_RES}")
    return True


def generate_images(n_images, resolution=None, seed=None):
    """Sample ``n_images`` images from the loaded generator.

    Args:
        n_images: Number of latent vectors (and therefore images) to sample.
        resolution: Accepted for interface compatibility but currently
            ignored -- output is always produced at ``TARGET_RES``.
        seed: Integer seed for the JAX PRNG key; a random seed is drawn from
            NumPy when ``None``.

    Returns:
        A list of arrays with values in [0, 1], one per image.

    Raises:
        RuntimeError: If ``load_weights()`` has not completed successfully.
    """
    # An explicit flag is required here: after loading the H5 weights both
    # ``_use_stateless`` is False and ``_gen_state`` is None, so a guard
    # based on those two would reject a correctly loaded Keras model.
    if not _weights_loaded:
        raise RuntimeError("❌ Call load_weights() first.")

    if seed is None:
        seed = int(np.random.randint(0, 2**31))
    rng = jax.random.PRNGKey(seed)

    if _use_stateless and _gen_state is not None:
        return _generate_stateless(n_images, rng)
    return _generate_keras(n_images, rng)


def _sample_latents(n_images, rng):
    """Return ``(z, noise_key)`` derived from ``rng``.

    The key is split three ways and the first sub-key discarded; this keeps
    the outputs for a given seed identical to earlier versions of this file.
    """
    _, z_key, noise_key = jax.random.split(rng, 3)
    z = jax.random.normal(z_key, (n_images, LATENT_DIM))
    return z, noise_key


def _generate_keras(n_images, rng):
    """Generate images by calling the model directly with ``alpha=1.0``."""
    generator.current_resolution = TARGET_RES
    z, noise_key = _sample_latents(n_images, rng)
    out = generator(z, alpha=1.0, rng_key=noise_key)
    # The model may return a list/tuple; the images are its first element.
    images = out[0] if isinstance(out, (list, tuple)) else out
    return _postprocess(np.array(images))


def _generate_stateless(n_images, rng):
    """Generate images via ``stateless_call`` using the checkpoint state.

    Uses the alpha value stored in the checkpoint rather than a fixed 1.0.
    """
    generator.current_resolution = TARGET_RES
    z, noise_key = _sample_latents(n_images, rng)
    images, _ = generator.stateless_call(
        _gen_state["ema_trainable"],
        _gen_state["non_trainable"],
        z,
        jnp.array(_saved_alpha),
        noise_key,
    )
    return _postprocess(np.array(images))