Update config.py
Browse files
config.py
CHANGED
|
@@ -7,7 +7,7 @@ Precision: float32
|
|
| 7 |
import jax
|
| 8 |
import jax.numpy as jnp
|
| 9 |
from flax import nnx
|
| 10 |
-
from diffuse.
|
| 11 |
|
| 12 |
# Model definition
|
| 13 |
model = CondUNet2D(
|
|
|
|
| 7 |
import jax
|
| 8 |
import jax.numpy as jnp
|
| 9 |
from flax import nnx
|
| 10 |
+
from diffuse.neural_network import CondUNet2D
|
| 11 |
|
| 12 |
# Model definition
|
| 13 |
model = CondUNet2D(
|