Update app.py
Browse files
app.py
CHANGED
|
@@ -10,7 +10,8 @@ import cv2
|
|
| 10 |
import os
|
| 11 |
|
| 12 |
|
| 13 |
-
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
def image_grid(imgs, rows, cols):
|
|
@@ -41,7 +42,6 @@ controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
|
|
| 41 |
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
|
| 42 |
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.float32
|
| 43 |
)
|
| 44 |
-
pipe = pipe.to("cuda")
|
| 45 |
|
| 46 |
def infer(prompts, negative_prompts, image):
|
| 47 |
|
|
|
|
| 10 |
import os
|
| 11 |
|
| 12 |
|
| 13 |
+
from jax import device
|
| 14 |
+
jax.config.update('jax_platform_name', 'gpu')
|
| 15 |
|
| 16 |
|
| 17 |
def image_grid(imgs, rows, cols):
|
|
|
|
| 42 |
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
|
| 43 |
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.float32
|
| 44 |
)
|
|
|
|
| 45 |
|
| 46 |
def infer(prompts, negative_prompts, image):
|
| 47 |
|