Update app.py
Browse files
app.py
CHANGED
|
@@ -17,6 +17,8 @@ def load_model(config_path, ckpt_path):
|
|
| 17 |
with open(ckpt_path, "rb") as f:
|
| 18 |
leaves = pickle.load(f)
|
| 19 |
|
|
|
|
|
|
|
| 20 |
from model import DiT, DiTConfig
|
| 21 |
|
| 22 |
dit_config = DiTConfig(**config["model"])
|
|
@@ -32,13 +34,16 @@ def sample_images(graphdef, state, x0, t):
|
|
| 32 |
flow = nnx.merge(graphdef, state)
|
| 33 |
|
| 34 |
def flow_fn(y, t):
|
|
|
|
|
|
|
| 35 |
o = flow(y, t[None])
|
| 36 |
-
return o
|
| 37 |
|
| 38 |
-
o = ode.odeint(flow_fn, x0, t, rtol=1e-
|
| 39 |
o = jnp.clip(o[-1], 0, 1)
|
| 40 |
return o
|
| 41 |
|
|
|
|
| 42 |
@spaces.GPU
|
| 43 |
def generate_grid(seed, noise_level):
|
| 44 |
# Load model (doing this inside function to avoid global variables)
|
|
@@ -66,8 +71,6 @@ def generate_grid(seed, noise_level):
|
|
| 66 |
return jax.device_get(grid)
|
| 67 |
|
| 68 |
|
| 69 |
-
generate_grid(0, 3)
|
| 70 |
-
|
| 71 |
# Create Gradio interface
|
| 72 |
demo = gr.Interface(
|
| 73 |
fn=generate_grid,
|
|
|
|
| 17 |
with open(ckpt_path, "rb") as f:
|
| 18 |
leaves = pickle.load(f)
|
| 19 |
|
| 20 |
+
leaves = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), leaves)
|
| 21 |
+
|
| 22 |
from model import DiT, DiTConfig
|
| 23 |
|
| 24 |
dit_config = DiTConfig(**config["model"])
|
|
|
|
| 34 |
flow = nnx.merge(graphdef, state)
|
| 35 |
|
| 36 |
def flow_fn(y, t):
|
| 37 |
+
y = y.astype(jnp.bfloat16)
|
| 38 |
+
t = t.astype(jnp.bfloat16)
|
| 39 |
o = flow(y, t[None])
|
| 40 |
+
return o.astype(jnp.float32)
|
| 41 |
|
| 42 |
+
o = ode.odeint(flow_fn, x0, t, rtol=1e-4)
|
| 43 |
o = jnp.clip(o[-1], 0, 1)
|
| 44 |
return o
|
| 45 |
|
| 46 |
+
|
| 47 |
@spaces.GPU
|
| 48 |
def generate_grid(seed, noise_level):
|
| 49 |
# Load model (doing this inside function to avoid global variables)
|
|
|
|
| 71 |
return jax.device_get(grid)
|
| 72 |
|
| 73 |
|
|
|
|
|
|
|
| 74 |
# Create Gradio interface
|
| 75 |
demo = gr.Interface(
|
| 76 |
fn=generate_grid,
|