jcopo commited on
Commit
5a116fc
·
verified ·
1 Parent(s): a502ae9

Update config.py

Browse files
Files changed (1) hide show
  1. config.py +1 -1
config.py CHANGED
@@ -7,7 +7,7 @@ Precision: float32
7
  import jax
8
  import jax.numpy as jnp
9
  from flax import nnx
10
- from diffuse.models import CondUNet2D
11
 
12
  # Model definition
13
  model = CondUNet2D(
 
7
  import jax
8
  import jax.numpy as jnp
9
  from flax import nnx
10
+ from diffuse.neural_network import CondUNet2D
11
 
12
  # Model definition
13
  model = CondUNet2D(