Commit
·
90816e9
1
Parent(s):
733669b
Cleanup
Browse files- app.py +4 -0
- pre-requirements.txt +5 -0
- requirements.txt +5 -3
app.py
CHANGED
|
@@ -1,11 +1,15 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import jax
|
|
|
|
| 3 |
from diffusers import FlaxStableDiffusionPipeline
|
| 4 |
from flax.jax_utils import replicate
|
| 5 |
from flax.training.common_utils import shard
|
| 6 |
|
| 7 |
pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(
|
| 8 |
"bguisard/stable-diffusion-nano",
|
|
|
|
|
|
|
|
|
|
| 9 |
)
|
| 10 |
|
| 11 |
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import jax
|
| 3 |
+
import jax.numpy as jnp
|
| 4 |
from diffusers import FlaxStableDiffusionPipeline
|
| 5 |
from flax.jax_utils import replicate
|
| 6 |
from flax.training.common_utils import shard
|
| 7 |
|
| 8 |
pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(
|
| 9 |
"bguisard/stable-diffusion-nano",
|
| 10 |
+
dtype=jnp.float16,
|
| 11 |
+
resume_download=True,
|
| 12 |
+
use_memory_efficient_attention=True
|
| 13 |
)
|
| 14 |
|
| 15 |
|
pre-requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pip
|
| 2 |
+
setuptools
|
| 3 |
+
wheel
|
| 4 |
+
ninja
|
| 5 |
+
cmake
|
requirements.txt
CHANGED
|
@@ -1,4 +1,6 @@
|
|
| 1 |
transformers
|
| 2 |
-
|
| 3 |
-
jax[cuda11_pip]
|
| 4 |
-
|
|
|
|
|
|
|
|
|
| 1 |
transformers
|
| 2 |
+
flax
|
| 3 |
+
jax[cuda11_pip]
|
| 4 |
+
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
| 5 |
+
jaxlib
|
| 6 |
+
git+https://github.com/huggingface/diffusers@main
|