"""Model configuration for jcopo/mnist.

Training step: 45000
Precision: float32
"""
import jax
import jax.numpy as jnp
from flax import nnx

from diffuse.neural_network import CondUNet2D

# Model definition.
# The instance is built at import time and exposed as the module-level
# name ``model``.
model = CondUNet2D(
    # NOTE(review): MNIST images are single-channel; presumably 32 here is
    # intentional (e.g. a latent / embedded input) -- confirm against the
    # checkpoint this configuration describes.
    in_channels=32,
    ch=32,  # base_channels
    ch_mult=(1, 2, 2),  # channel_multipliers
    num_res_blocks=1,
    attention_resolutions=(2,),
    num_heads=4,
    dropout=True,
    dropout_rate=0.1,
    activation=nnx.swish,
    param_dtype=jnp.float32,  # agrees with "Precision: float32" above
    rngs=nnx.Rngs(0),  # fixed seed 0
)