---
tags:
- jax
- flax
- flax-nnx
- diffusion
library_name: diffuse
---

---

## MNIST Generation

A flow-matching diffusion model trained to generate MNIST digits. Use it with [**diffuse**](https://github.com/jcopo/diffuse), a JAX/Flax sampling library. The model is light enough to run on CPU.

---

## Model Details

* **Framework:** JAX/Flax (NNX)
* **Format:** msgpack
* **Prediction Type:** Velocity (Flow Matching)

---

## Usage

### Download and Load Model

```python
import importlib.util

from flax import nnx, serialization
from huggingface_hub import hf_hub_download

# Download model weights and config from the Hugging Face Hub
model_path = hf_hub_download(repo_id="jcopo/mnist", filename="model.msgpack")
config_path = hf_hub_download(repo_id="jcopo/mnist", filename="config.py")

# Load config.py as a Python module; it defines the model architecture
spec = importlib.util.spec_from_file_location("model_config", config_path)
config_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config_module)

# Freshly initialized model instance defined in config.py
model = config_module.model

# Decode the msgpack file into a nested dict of arrays
with open(model_path, "rb") as f:
    state_dict = serialization.msgpack_restore(f.read())

# Restore the weights into the model
graphdef, state = nnx.split(model)
state.replace_by_pure_dict(state_dict)
model = nnx.merge(graphdef, state)
model.eval()  # Set to evaluation mode (e.g. disables dropout)

print("✅ Model loaded successfully!")
```
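The card lists the prediction type as velocity (flow matching) but does not show how to sample. The sketch below is one minimal way to do it, assuming the common linear-interpolation convention `x_t = (1 - t) * noise + t * data` (so `t = 0` is Gaussian noise and `t = 1` is data) and a network that maps `(x, t)` to a velocity. Samples are then produced by integrating `dx/dt = v(x, t)` from `t = 0` to `t = 1` with explicit Euler steps. The function `euler_sample`, its arguments, and the default step count are hypothetical and not part of diffuse's API; diffuse ships its own samplers, and this checkpoint may use a different time convention or call signature.

```python
import jax
import jax.numpy as jnp


def euler_sample(velocity_fn, key, shape, num_steps=50):
    """Hypothetical helper: integrate dx/dt = velocity_fn(x, t) from t=0 to t=1.

    ASSUMPTION: flow-matching convention x_t = (1 - t) * noise + t * data,
    so the learned velocity points from noise (t=0) toward data (t=1).
    `shape` must have a leading batch axis, e.g. (batch, H, W, C).
    """
    x = jax.random.normal(key, shape)  # start from standard Gaussian noise
    dt = 1.0 / num_steps

    def step(x, i):
        t = i * dt  # current time, in [0, 1 - dt]
        t_batch = jnp.full((shape[0],), t)  # one time value per sample
        x = x + dt * velocity_fn(x, t_batch)  # explicit Euler update
        return x, None

    x, _ = jax.lax.scan(step, x, jnp.arange(num_steps))
    return x
```

With the model loaded above, a call might look like `euler_sample(lambda x, t: model(x, t), jax.random.PRNGKey(0), (16, 28, 28, 1))`, but the `model(x, t)` signature and the `(28, 28, 1)` image shape are guesses; check `config.py` for the actual ones. Euler is the simplest ODE integrator; higher-order schemes usually need fewer steps for the same quality.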