{
  "model_type": "CondUNet2D",
  "checkpoint_step": 45000,
  "precision": "float32",
  "framework": "jax-flax-nnx",
  "format": "msgpack",
  "library_name": "diffuse",
  "architecture": {
    "in_channels": 32,
    "base_channels": 32,
    "channel_multipliers": [1, 2, 2],
    "num_res_blocks": 1,
    "attention_resolutions": [2],
    "num_heads": 4,
    "dropout": true,
    "dropout_rate": 0.1,
    "activation": "swish"
  },
  "model_class": "diffuse.neural_network.CondUNet2D",
  "repo_id": "jcopo/mnist"
}