Spaces:
Sleeping
Sleeping
| import os | |
| os.environ["KERAS_BACKEND"] = "jax" | |
| import numpy as np | |
| import jax | |
| import jax.numpy as jnp | |
| from loading_model import generator | |
| from configuration import LATENT_DIM, RESOLUTIONS | |
# Weight files are resolved next to this module, independent of the CWD.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]
WEIGHTS_H5 = os.path.join(BASE_DIR, "generator.weights.h5")  # tried first by load_weights()
CHECKPOINT_PKL = os.path.join(BASE_DIR, "checkpoint.pkl")  # fallback joblib checkpoint

# Mutable module state written by load_weights():
#   _gen_state     -- dict of variable lists ("ema_trainable", "non_trainable") or None
#   _use_stateless -- True when generation goes through generator.stateless_call
#   _saved_alpha   -- "alpha" value read from the checkpoint (defaults to 1.0)
_gen_state, _use_stateless, _saved_alpha = None, False, 1.0

# Resolution the generator is pinned to for inference.
TARGET_RES = 256
def _postprocess(images: np.ndarray) -> list:
    """Rescale a batch with ``(x + 1) / 2``, clip to [0, 1], and split it per image.

    Presumably the generator emits values in [-1, 1]; after this call every
    element lies in [0, 1]. Returns one array per entry along axis 0.
    """
    rescaled = np.clip((np.array(images) + 1.0) / 2.0, 0.0, 1.0)
    batch_size = rescaled.shape[0]
    return [rescaled[k] for k in range(batch_size)]
def _build_model():
    """Run one dummy forward pass at every resolution in RESOLUTIONS.

    Presumably this creates the variables of every resolution's layers before
    weights are loaded (load_weights() compares variable counts afterwards).
    Leaves the generator set to TARGET_RES and prints the variable counts.
    """
    z_probe = jnp.zeros((1, LATENT_DIM))
    alpha_probe = jnp.array(1.0)
    key_probe = jax.random.PRNGKey(0)
    for res_level in RESOLUTIONS:
        generator.current_resolution = res_level
        generator(z_probe, alpha=alpha_probe, rng_key=key_probe)
    generator.current_resolution = TARGET_RES
    n_train = len(generator.trainable_variables)
    n_frozen = len(generator.non_trainable_variables)
    print(f"β Model built β trainable: {n_train} non-trainable: {n_frozen}")
def _first_nonempty(data, *keys):
    """Return the first value in `data` under `keys` that is non-None and non-empty, else None."""
    # len() instead of truthiness: array-valued entries would otherwise raise
    # NumPy's "truth value of an array is ambiguous" error.
    for key in keys:
        value = data.get(key)
        if value is not None and len(value) > 0:
            return value
    return None


def _load_checkpoint_state(path):
    """Read a joblib checkpoint and convert it into stateless-call state.

    Returns:
        (state, alpha): ``state`` is a dict with "ema_trainable" and
        "non_trainable" lists of jnp arrays matching the built generator;
        ``alpha`` is the checkpoint's "alpha" entry as float (1.0 if absent).

    Raises:
        ValueError: no trainable weights in the checkpoint, or variable counts
            that don't fit the built generator.
    """
    import joblib

    print(f"π Loading checkpoint: {path}")
    # NOTE(security): joblib.load unpickles arbitrary objects -- only load
    # checkpoints from a trusted source.
    data = joblib.load(path)
    print(f" Keys: {list(data.keys())}")

    # Support both key naming conventions: new EMA key first, then the old one.
    gen_trainable = _first_nonempty(data, "ema_trainable", "gen_trainable")
    if gen_trainable is None:
        raise ValueError(
            "Checkpoint has no trainable weights under 'ema_trainable' or "
            f"'gen_trainable'; keys found: {list(data.keys())}"
        )
    gen_non_trainable = _first_nonempty(data, "gen_non_trainable", "non_trainable")
    if gen_non_trainable is None:
        gen_non_trainable = []

    alpha = float(data.get("alpha", 1.0))
    print(f" trainable={len(gen_trainable)} non_trainable={len(gen_non_trainable)} alpha={alpha:.4f}")

    n_model = len(generator.trainable_variables)
    if len(gen_trainable) != n_model:
        raise ValueError(
            f"β Mismatch β checkpoint:{len(gen_trainable)} model:{n_model}\n"
            f" Make sure _build_model() mirrors training exactly."
        )

    n_non = len(generator.non_trainable_variables)
    if len(gen_non_trainable) > 0:
        # Extra saved entries beyond the model's count are dropped; too few
        # would only fail later inside stateless_call, so reject them here.
        if len(gen_non_trainable) < n_non:
            raise ValueError(
                f"Non-trainable mismatch: checkpoint has {len(gen_non_trainable)}, "
                f"model needs {n_non}."
            )
        non_trainable = [jnp.asarray(t) for t in gen_non_trainable[:n_non]]
    else:
        # No saved non-trainables: reuse the freshly built model's values.
        non_trainable = [jnp.asarray(v.value) for v in generator.non_trainable_variables]

    state = {
        "ema_trainable": [jnp.asarray(t) for t in gen_trainable],
        "non_trainable": non_trainable,
    }
    return state, alpha


def load_weights(resolution=256):
    """Build the generator and load weights from WEIGHTS_H5 or CHECKPOINT_PKL.

    The Keras ``.weights.h5`` file is preferred; otherwise the joblib checkpoint
    is loaded and generation switches to ``generator.stateless_call``. Module
    globals (_gen_state, _use_stateless, _saved_alpha) are only updated after
    a source loads successfully.

    Args:
        resolution: accepted for API compatibility; not used -- the generator
            is always locked to TARGET_RES.

    Returns:
        True on success.

    Raises:
        FileNotFoundError: neither weight file exists in BASE_DIR.
        ValueError: the checkpoint is missing weights or doesn't match the model.
    """
    global _gen_state, _use_stateless, _saved_alpha
    _build_model()
    if os.path.isfile(WEIGHTS_H5):
        print(f"π Loading Keras weights: {WEIGHTS_H5}")
        generator.load_weights(WEIGHTS_H5)
        # Record the loaded variables so _gen_state is non-None on this path
        # too: generate_images() treats "_gen_state is None and not
        # _use_stateless" as "nothing loaded" and would otherwise always raise.
        _gen_state = {
            "ema_trainable": [jnp.asarray(v.value) for v in generator.trainable_variables],
            "non_trainable": [jnp.asarray(v.value) for v in generator.non_trainable_variables],
        }
        _use_stateless = False
        print("β Keras weights loaded.")
    elif os.path.isfile(CHECKPOINT_PKL):
        _gen_state, _saved_alpha = _load_checkpoint_state(CHECKPOINT_PKL)
        _use_stateless = True
        print("β Checkpoint loaded.")
    else:
        raise FileNotFoundError(f"β No weights found in: {BASE_DIR}")
    generator.current_resolution = TARGET_RES
    print(f"πΌοΈ Resolution locked to: {TARGET_RES}")
    return True
def generate_images(n_images, resolution=None, seed=None):
    """Sample `n_images` images from the loaded generator.

    Args:
        n_images: number of latent vectors to sample.
        resolution: accepted for API compatibility; not read here (the
            sampling helpers pin the generator to TARGET_RES).
        seed: integer PRNG seed; a random one in [0, 2**31) is drawn when None.

    Returns:
        Whatever the selected sampling helper returns (a list of per-image arrays).

    Raises:
        RuntimeError: when `_gen_state` is None and `_use_stateless` is False.
            load_weights() must leave `_gen_state` non-None on every successful
            path for this check to pass.
    """
    if _gen_state is None and not _use_stateless:
        raise RuntimeError("β Call load_weights() first.")
    if seed is not None:
        base_seed = seed
    else:
        base_seed = int(np.random.randint(0, 2**31))
    key = jax.random.PRNGKey(base_seed)
    sampler = (
        _generate_stateless
        if _use_stateless and _gen_state is not None
        else _generate_keras
    )
    return sampler(n_images, key)
def _generate_keras(n_images, rng):
    """Sample latents and run them through the generator's own (Keras-held) weights.

    Uses alpha=1.0. If the model returns a list/tuple, the first element is
    taken as the image batch.
    """
    generator.current_resolution = TARGET_RES
    # Three-way split matches _generate_stateless, so a given seed yields the
    # same latents on either path; the first sub-key is unused.
    _unused_key, latent_key, noise_key = jax.random.split(rng, 3)
    latents = jax.random.normal(latent_key, (n_images, LATENT_DIM))
    result = generator(latents, alpha=1.0, rng_key=noise_key)
    if isinstance(result, (list, tuple)):
        result = result[0]
    return _postprocess(np.array(result))
def _generate_stateless(n_images, rng):
    """Sample latents and run the generator with the variables held in `_gen_state`.

    Calls ``generator.stateless_call`` with the checkpoint's trainable and
    non-trainable lists and ``_saved_alpha``. The updated non-trainable state
    it returns is discarded.
    """
    generator.current_resolution = TARGET_RES
    # Same three-way split as _generate_keras so seeds map to identical latents.
    _, z_key, noise_key = jax.random.split(rng, 3)
    z = jax.random.normal(z_key, (n_images, LATENT_DIM))
    outputs, _ = generator.stateless_call(
        _gen_state["ema_trainable"],
        _gen_state["non_trainable"],
        z,
        # Keywords, matching how the generator is called elsewhere, rather than
        # relying on the positional order of its call() signature.
        alpha=jnp.array(_saved_alpha),
        rng_key=noise_key,
    )
    # Unwrap multi-output models exactly as _generate_keras does.
    images = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
    return _postprocess(np.array(images))