jcopo commited on
Commit
a502ae9
·
verified ·
1 Parent(s): a587dc4

Update config.py

Browse files
Files changed (1) hide show
  1. config.py +13 -13
config.py CHANGED
@@ -11,16 +11,16 @@ from diffuse.models import CondUNet2D
11
 
12
  # Model definition
13
  model = CondUNet2D(
14
- in_channels=IN_CHANNELS,
15
- ch=variant["base_channels"],
16
- ch_mult=variant["channel_multipliers"],
17
- num_res_blocks=variant["num_res_blocks"],
18
- attention_resolutions=variant["attention_resolutions"],
19
- num_heads=variant["num_heads"],
20
- dropout=True,
21
- dropout_rate=0.1,
22
- activation=nnx.swish,
23
- param_dtype=jnp.float32,
24
- dtype=compute_dtype,
25
- rngs=nnx.Rngs(MODEL_SEED),
26
- )
 
11
 
12
  # Model definition
13
  model = CondUNet2D(
14
+ in_channels=IN_CHANNELS,
15
+ ch=32, # base_channels
16
+ ch_mult=(1, 2, 2), # channel_multipliers
17
+ num_res_blocks=1,
18
+ attention_resolutions=(2,),
19
+ num_heads=4,
20
+ dropout=True,
21
+ dropout_rate=0.1,
22
+ activation=nnx.swish,
23
+ param_dtype=jnp.float32,
24
+ rngs=nnx.Rngs(0),
25
+ )
26
+