"""Debug script: compare checkpoint weights with the generator's trainable variables.

Builds the model, loads ``weights/checkpoint.pkl`` and prints:

1. every entry of the checkpoint's ``"gen_trainable"`` list with its shape,
2. every trainable variable of ``generator`` with its shape and name,
3. a summary of count and per-index shape mismatches between the two, so a
   loading problem is visible without eyeballing both listings.
"""

import joblib
import jax.numpy as jnp

from loading_model import generator
from inference import _build_model

CHECKPOINT_PATH = "weights/checkpoint.pkl"

# Build the generator first so its trainable variables exist before listing them.
_build_model()

# Load checkpoint.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code; only point this at checkpoints you produced yourself.
ck = joblib.load(CHECKPOINT_PATH)
ckpt_weights = ck["gen_trainable"]
model_vars = generator.trainable_variables

print("=== CHECKPOINT variables ===")
for i, w in enumerate(ckpt_weights):
    print(f" [{i:03d}] shape={w.shape}")

print("\n=== MODEL variables ===")
for i, v in enumerate(model_vars):
    print(f" [{i:03d}] shape={str(v.shape):30s} name={v.name}")

# Normalise both sides to plain tuples so shape objects of different types
# (e.g. a numpy shape vs. a framework shape object) compare by value.
ckpt_shapes = [tuple(w.shape) for w in ckpt_weights]
model_shapes = [tuple(v.shape) for v in model_vars]

print("\n=== COMPARISON ===")
print(f" checkpoint: {len(ckpt_shapes)} tensors, model: {len(model_shapes)} variables")
counts_match = len(ckpt_shapes) == len(model_shapes)
if not counts_match:
    print(" WARNING: checkpoint and model variable counts differ")

# zip stops at the shorter list; extra entries are covered by the count warning.
mismatches = [
    (i, c_shape, m_shape)
    for i, (c_shape, m_shape) in enumerate(zip(ckpt_shapes, model_shapes))
    if c_shape != m_shape
]
for i, c_shape, m_shape in mismatches:
    name = model_vars[i].name
    print(f" [{i:03d}] MISMATCH checkpoint={c_shape} model={m_shape} name={name}")

if counts_match and not mismatches:
    print(" all shapes match in order")