"""Model configuration for jcopo/mnist.

Training step: 45000
Precision: float32
"""

import jax
import jax.numpy as jnp
from flax import nnx

from diffuse.neural_network import CondUNet2D

# Model definition.
# NOTE(review): these hyperparameters presumably must match the checkpoint this
# config accompanies (training step 45000); changing them may make the saved
# parameters incompatible with the constructed model — confirm before editing.
model = CondUNet2D(
    # NOTE(review): MNIST images are single-channel, yet in_channels=32 (the same
    # value as `ch`). Left unchanged because it may reflect how the checkpoint was
    # trained (e.g. a non-pixel or conditioned input) — TODO confirm.
    in_channels=32,
    ch=32,  # base_channels
    ch_mult=(1, 2, 2),  # channel_multipliers
    num_res_blocks=1,
    # Presumably the resolution levels at which attention is applied — TODO confirm
    # against CondUNet2D.
    attention_resolutions=(2,),
    num_heads=4,
    dropout=True,
    dropout_rate=0.1,
    activation=nnx.swish,
    param_dtype=jnp.float32,  # matches "Precision: float32" in the module docstring
    rngs=nnx.Rngs(0),  # fixed seed 0 so parameter initialization is reproducible
)