Spaces:
Runtime error
Runtime error
Update dtype and example
Browse files
app.py
CHANGED
|
@@ -1,12 +1,27 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import jax
|
| 3 |
-
|
|
|
|
| 4 |
from flax.jax_utils import replicate
|
| 5 |
from flax.training.common_utils import shard
|
| 6 |
|
|
|
|
|
|
|
| 7 |
pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(
|
| 8 |
"bguisard/stable-diffusion-nano-2-1",
|
|
|
|
| 9 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
def generate_image(prompt: str, inference_steps: int = 30, prng_seed: int = 0):
|
|
@@ -51,7 +66,17 @@ app = gr.Interface(
|
|
| 51 |
"Stable Diffusion Nano allows for fast prototyping of diffusion models, "
|
| 52 |
"enabling quick experimentation with easily available hardware."
|
| 53 |
),
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
)
|
| 56 |
|
| 57 |
app.launch()
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import jax
|
| 3 |
+
import jax.numpy as jnp
|
| 4 |
+
from diffusers import FlaxPNDMScheduler, FlaxStableDiffusionPipeline
|
| 5 |
from flax.jax_utils import replicate
|
| 6 |
from flax.training.common_utils import shard
|
| 7 |
|
| 8 |
+
DTYPE = jnp.bfloat16
|
| 9 |
+
|
| 10 |
pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(
|
| 11 |
"bguisard/stable-diffusion-nano-2-1",
|
| 12 |
+
dtype=DTYPE,
|
| 13 |
)
|
| 14 |
+
if DTYPE != jnp.float32:
|
| 15 |
+
# There is a known issue with schedulers when loading from a pre trained
|
| 16 |
+
# pipeline. We need the schedulers to always use float32.
|
| 17 |
+
# See: https://github.com/huggingface/diffusers/issues/2155
|
| 18 |
+
scheduler, scheduler_params = FlaxPNDMScheduler.from_pretrained(
|
| 19 |
+
pretrained_model_name_or_path="bguisard/stable-diffusion-nano-2-1",
|
| 20 |
+
subfolder="scheduler",
|
| 21 |
+
dtype=jnp.float32,
|
| 22 |
+
)
|
| 23 |
+
pipeline_params["scheduler"] = scheduler_params
|
| 24 |
+
pipeline.scheduler = scheduler
|
| 25 |
|
| 26 |
|
| 27 |
def generate_image(prompt: str, inference_steps: int = 30, prng_seed: int = 0):
|
|
|
|
| 66 |
"Stable Diffusion Nano allows for fast prototyping of diffusion models, "
|
| 67 |
"enabling quick experimentation with easily available hardware."
|
| 68 |
),
|
| 69 |
+
# Some examples were copied from hf.co/spaces/stabilityai/stable-diffusion
|
| 70 |
+
examples=[
|
| 71 |
+
# ["A watercolor painting of a bird", 30, 0],
|
| 72 |
+
[
|
| 73 |
+
"A small cabin on top of a snowy mountain in the style of Disney, artstation",
|
| 74 |
+
25,
|
| 75 |
+
3129302,
|
| 76 |
+
],
|
| 77 |
+
# ["A mecha robot in a favela in expressionist style", 30, 827198341273],
|
| 78 |
+
],
|
| 79 |
)
|
| 80 |
|
| 81 |
app.launch()
|
| 82 |
+
|