jcopo commited on
Commit
a587dc4
·
verified ·
1 Parent(s): aa4dc29

Upload model at step 45000

Browse files
Files changed (2) hide show
  1. README.md +4 -3
  2. config.py +3 -3
README.md CHANGED
@@ -3,17 +3,18 @@ tags:
3
  - jax
4
  - flax
5
  - flax-nnx
6
- library_name: triax
 
7
  ---
8
 
9
  # mnist
10
 
11
- Model trained using [Triax](https://github.com/your-org/triax), a JAX/Flax training framework.
12
 
13
  ## Model Details
14
 
15
  - **Model Type**: CondUNet2D
16
- - **Training Step**: 45,000
17
  - **Precision**: float32
18
  - **Framework**: JAX/Flax (NNX)
19
  - **Format**: msgpack
 
3
  - jax
4
  - flax
5
  - flax-nnx
6
+ - diffusion
7
+ library_name: diffuse
8
  ---
9
 
10
  # mnist
11
 
12
+ Diffusion model for use with [diffuse](https://github.com/jcopo/diffuse), a JAX/Flax sampling library.
13
 
14
  ## Model Details
15
 
16
  - **Model Type**: CondUNet2D
17
+ - **Checkpoint Step**: 45,000
18
  - **Precision**: float32
19
  - **Framework**: JAX/Flax (NNX)
20
  - **Format**: msgpack
config.py CHANGED
@@ -7,7 +7,7 @@ Precision: float32
7
  import jax
8
  import jax.numpy as jnp
9
  from flax import nnx
10
- from triax.models import CondUNet2D
11
 
12
  # Model definition
13
  model = CondUNet2D(
@@ -17,8 +17,8 @@ model = CondUNet2D(
17
  num_res_blocks=variant["num_res_blocks"],
18
  attention_resolutions=variant["attention_resolutions"],
19
  num_heads=variant["num_heads"],
20
- dropout=DROPOUT,
21
- dropout_rate=DROPOUT_RATE,
22
  activation=nnx.swish,
23
  param_dtype=jnp.float32,
24
  dtype=compute_dtype,
 
7
  import jax
8
  import jax.numpy as jnp
9
  from flax import nnx
10
+ from diffuse.models import CondUNet2D
11
 
12
  # Model definition
13
  model = CondUNet2D(
 
17
  num_res_blocks=variant["num_res_blocks"],
18
  attention_resolutions=variant["attention_resolutions"],
19
  num_heads=variant["num_heads"],
20
+ dropout=True,
21
+ dropout_rate=0.1,
22
  activation=nnx.swish,
23
  param_dtype=jnp.float32,
24
  dtype=compute_dtype,