jcopo commited on
Commit
bf471c5
·
verified ·
1 Parent(s): f0e2f9c

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +23 -17
README.md CHANGED
@@ -6,31 +6,40 @@ tags:
6
  - diffusion
7
  library_name: diffuse
8
  ---
 
9
 
10
- # mnist
 
 
 
11
 
12
- Diffusion model for use with [diffuse](https://github.com/jcopo/diffuse), a JAX/Flax sampling library.
13
 
14
  ## Model Details
 
 
 
15
 
16
- - **Model Type**: CondUNet2D
17
- - **Checkpoint Step**: 45,000
18
- - **Precision**: float32
19
- - **Framework**: JAX/Flax (NNX)
20
- - **Format**: msgpack
21
 
22
  ## Usage
23
 
 
 
24
  ```python
 
 
 
 
25
  from flax import nnx, serialization
26
  from huggingface_hub import hf_hub_download
27
- import importlib.util
28
 
29
  # Download model weights and config
30
- model_path = hf_hub_download(repo_id="jcopo/mnist", filename="model.msgpack")
31
- config_path = hf_hub_download(repo_id="jcopo/mnist", filename="config.py")
32
 
33
  # Load config to get model architecture
 
34
  spec = importlib.util.spec_from_file_location("model_config", config_path)
35
  config_module = importlib.util.module_from_spec(spec)
36
  spec.loader.exec_module(config_module)
@@ -43,13 +52,10 @@ with open(model_path, "rb") as f:
43
  state_dict = serialization.from_bytes(None, f.read())
44
 
45
  # Restore weights into model
46
- nnx.update(model, state_dict)
 
 
47
  model.eval() # Set to evaluation mode
48
 
49
- # Now use the model for inference
50
- # output = model(input_data)
51
  ```
52
-
53
- ## Training Configuration
54
-
55
- This model was trained with the Triax framework using the configuration saved in the checkpoint.
 
6
  - diffusion
7
  library_name: diffuse
8
  ---
9
+ ---
10
 
11
+ ## Mnist Generation
12
+ Flow matching diffusion model trained for mnist generation.
13
+ Use with [**diffuse**](https://github.com/jcopo/diffuse), a JAX/Flax sampling library.
14
+ Light enough to run on CPU
15
 
16
+ ---
17
 
18
  ## Model Details
19
+ * **Framework:** JAX/Flax (NNX)
20
+ * **Format:** msgpack
21
+ * **Prediction Type:** Velocity (Flow Matching)
22
 
23
+ ---
 
 
 
 
24
 
25
  ## Usage
26
 
27
+ ### Download and Load Model
28
+
29
  ```python
30
+ import os
31
+
32
+ import jax
33
+ import jax.numpy as jnp
34
  from flax import nnx, serialization
35
  from huggingface_hub import hf_hub_download
 
36
 
37
  # Download model weights and config
38
+ model_path = hf_hub_download(repo_id="{jcopo/mnist}", filename="model.msgpack")
39
+ config_path = hf_hub_download(repo_id="{jcopo/mnist}", filename="config.py")
40
 
41
  # Load config to get model architecture
42
+ import importlib.util
43
  spec = importlib.util.spec_from_file_location("model_config", config_path)
44
  config_module = importlib.util.module_from_spec(spec)
45
  spec.loader.exec_module(config_module)
 
52
  state_dict = serialization.from_bytes(None, f.read())
53
 
54
  # Restore weights into model
55
+ graphdef, state = nnx.split(model)
56
+ state.replace_by_pure_dict(state_dict)
57
+ model = nnx.merge(graphdef, state)
58
  model.eval() # Set to evaluation mode
59
 
60
+ print("✅ Model loaded successfully!")
 
61
  ```