File size: 512 Bytes
f7a97bc
 
 
 
 
 
aa4dc29
0a02da1
 
5a116fc
f7a97bc
aa4dc29
f7a97bc
f0e2f9c
a502ae9
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
"""Model configuration for jcopo/mnist

Training step: 45000
Precision: float32
"""

import jax
import jax.numpy as jnp
from flax import nnx
from diffuse.neural_network import CondUNet2D

# Build the conditional U-Net for this configuration. The hyperparameters are
# presumably the ones the checkpoint named in the module docstring was trained
# with, so they should not be edited independently of that checkpoint.
model = CondUNet2D(
    # Numerics and randomness: float32 parameters, RNG streams seeded with 0.
    param_dtype=jnp.float32,
    rngs=nnx.Rngs(0),
    # Channel layout.
    # NOTE(review): MNIST images are normally single-channel; in_channels=32
    # presumably reflects how CondUNet2D consumes its input — confirm before
    # changing, since it likely has to match the saved weights.
    in_channels=32,
    ch=32,  # base_channels
    ch_mult=(1, 2, 2),  # channel_multipliers, one entry per resolution level
    # Depth and attention.
    num_res_blocks=1,
    num_heads=4,
    attention_resolutions=(2,),
    # Non-linearity and regularisation.
    activation=nnx.swish,
    dropout_rate=0.1,
    dropout=True,
)