mnist / config.py
jcopo's picture
Update config.py
f0e2f9c verified
raw
history blame contribute delete
512 Bytes
"""Model configuration for jcopo/mnist
Training step: 45000
Precision: float32
"""
import jax
import jax.numpy as jnp
from flax import nnx
from diffuse.neural_network import CondUNet2D
# Model definition.
# These hyperparameters presumably have to match the ones used to produce the
# saved weights (module docstring: training step 45000, float32); editing them
# will likely make the checkpoint incompatible. TODO confirm against training.
_unet_hparams = dict(
    # NOTE(review): 32 input channels is unusual for single-channel MNIST
    # images; verify this matches the trained checkpoint before relying on it.
    in_channels=32,
    ch=32,  # base channel count
    ch_mult=(1, 2, 2),  # channel multipliers
    num_res_blocks=1,
    attention_resolutions=(2,),
    num_heads=4,
    dropout=True,
    dropout_rate=0.1,
    activation=nnx.swish,
    param_dtype=jnp.float32,
)
# Fixed seed 0 for the NNX RNG streams used at construction time.
model = CondUNet2D(**_unet_hparams, rngs=nnx.Rngs(0))