---
tags:
- jax
- flax
- flax-nnx
- diffusion
library_name: diffuse
---

---

## MNIST Generation

A flow-matching diffusion model trained to generate MNIST handwritten digits.

Use with [**diffuse**](https://github.com/jcopo/diffuse), a JAX/Flax sampling library.

The model is light enough to run on a CPU.
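
When a GPU or TPU is available, JAX places computations on it by default. The following is a minimal sketch of pinning work to the CPU instead with `jax.default_device`. The array shape is illustrative only and is not taken from this model's config.

```python
import jax
import jax.numpy as jnp

# Pick the first CPU device; JAX always exposes a CPU backend.
cpu = jax.devices("cpu")[0]

# Inside this context, newly created arrays and the computations on them
# default to the CPU device instead of any GPU/TPU.
with jax.default_device(cpu):
    x = jnp.ones((1, 28, 28, 1))  # illustrative shape
    y = jnp.tanh(x).sum()

print(y.devices())
```

Alternatively, you can set the environment variable `JAX_PLATFORMS=cpu` before launching Python. This makes the CPU the only backend for the whole process.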

---

## Model Details

* **Framework:** JAX/Flax (NNX)
* **Format:** msgpack
* **Prediction Type:** Velocity (Flow Matching)
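
With velocity prediction, the network estimates the velocity field of an ODE that carries Gaussian noise to data. Sampling therefore means integrating that ODE. The sketch below is a plain fixed-step Euler integrator, not the **diffuse** API. It assumes the interpolation `x_t = (1 - t) * x0 + t * x1`, with `t = 0` as noise and `t = 1` as data. The convention used to train this checkpoint may differ, since some libraries reverse the time axis. The hypothetical `toy_velocity` stands in for the trained network, whose call signature is not documented here.

```python
import jax
import jax.numpy as jnp


def euler_sample(velocity_fn, x0, num_steps=50):
    """Integrate dx/dt = velocity_fn(x, t) from t=0 (noise) to t=1 (data)
    with fixed-step explicit Euler."""
    dt = 1.0 / num_steps

    def step(x, i):
        t = i * dt  # current time in [0, 1)
        return x + dt * velocity_fn(x, t), None

    x1, _ = jax.lax.scan(step, x0, jnp.arange(num_steps))
    return x1


# Hypothetical stand-in for the trained network: if the data distribution
# is a single image `target`, the exact flow-matching velocity for
# x_t = (1 - t) * x0 + t * target is (target - x_t) / (1 - t).
target = jnp.full((28, 28, 1), 0.5)


def toy_velocity(x, t):
    return (target - x) / (1.0 - t)


x0 = jax.random.normal(jax.random.PRNGKey(0), (28, 28, 1))  # Gaussian noise
sample = euler_sample(toy_velocity, x0, num_steps=50)
```

Because the toy paths are straight lines, Euler reaches `target` up to floating-point error. A trained network generally produces curved paths, which benefit from more steps or a higher-order solver.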

---

## Usage

### Download and Load Model

```python
import importlib.util

import jax
import jax.numpy as jnp
from flax import nnx, serialization
from huggingface_hub import hf_hub_download

# Download model weights and config
model_path = hf_hub_download(repo_id="jcopo/mnist", filename="model.msgpack")
config_path = hf_hub_download(repo_id="jcopo/mnist", filename="config.py")

# Import config.py as a module; executing it defines the model architecture
spec = importlib.util.spec_from_file_location("model_config", config_path)
config_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config_module)

# Initialize model from config
model = config_module.model

# Load weights as a nested dict of arrays (no target structure needed)
with open(model_path, "rb") as f:
    state_dict = serialization.from_bytes(None, f.read())

# Restore weights into model
graphdef, state = nnx.split(model)
state.replace_by_pure_dict(state_dict)
model = nnx.merge(graphdef, state)
model.eval()  # Set to evaluation mode

print("✅ Model loaded successfully!")
```