Spaces:
Runtime error
Runtime error
Commit ·
d9edb33
1
Parent(s): a6df2dc
Update app.py
Browse files
app.py
CHANGED
|
@@ -23,8 +23,8 @@ def infer(prompts, negative_prompts):
|
|
| 23 |
rng = create_key(0)
|
| 24 |
rng = jax.random.split(rng, jax.device_count())
|
| 25 |
|
| 26 |
-
prompt_ids = pipe.
|
| 27 |
-
negative_prompt_ids = pipe.
|
| 28 |
|
| 29 |
p_params = replicate(params)
|
| 30 |
prompt_ids = shard(prompt_ids)
|
|
|
|
| 23 |
rng = create_key(0)
|
| 24 |
rng = jax.random.split(rng, jax.device_count())
|
| 25 |
|
| 26 |
+
prompt_ids = pipe.prepare_inputs([prompts] * num_samples)
|
| 27 |
+
negative_prompt_ids = pipe.prepare_inputs([negative_prompts] * num_samples)
|
| 28 |
|
| 29 |
p_params = replicate(params)
|
| 30 |
prompt_ids = shard(prompt_ids)
|