Spaces:
Runtime error
Runtime error
Commit ·
90ec6ed
1
Parent(s): ef7a9a3
updated for cpu
Browse files
README.md
CHANGED
|
@@ -7,6 +7,11 @@ sdk: gradio
|
|
| 7 |
sdk_version: "3.50.2"
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
# Style-Guided Image Generation with Purple Enhancement
|
|
|
|
| 7 |
sdk_version: "3.50.2"
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
+
hardware: true
|
| 11 |
+
resources:
|
| 12 |
+
cpu: 1
|
| 13 |
+
memory: "16Gi"
|
| 14 |
+
gpu: 1
|
| 15 |
---
|
| 16 |
|
| 17 |
# Style-Guided Image Generation with Purple Enhancement
|
app.py
CHANGED
|
@@ -40,10 +40,14 @@ styles = {
|
|
| 40 |
|
| 41 |
def load_pipeline():
|
| 42 |
"""Load and prepare the pipeline with all style embeddings"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
pipe = StableDiffusionPipeline.from_pretrained(
|
| 44 |
"CompVis/stable-diffusion-v1-4",
|
| 45 |
-
torch_dtype=
|
| 46 |
-
).to(
|
| 47 |
|
| 48 |
# Load all embeddings
|
| 49 |
for style_info in styles.values():
|
|
@@ -71,8 +75,9 @@ def generate_image(prompt, style, seed, apply_guidance, guidance_strength=0.5):
|
|
| 71 |
# Get style info
|
| 72 |
style_info = styles[style]
|
| 73 |
|
| 74 |
-
# Prepare generator
|
| 75 |
-
|
|
|
|
| 76 |
|
| 77 |
# Create styled prompt
|
| 78 |
styled_prompt = f"{prompt} {style_info['token']}"
|
|
|
|
| 40 |
|
| 41 |
def load_pipeline():
|
| 42 |
"""Load and prepare the pipeline with all style embeddings"""
|
| 43 |
+
# Check if CUDA is available
|
| 44 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 45 |
+
dtype = torch.float16 if device == "cuda" else torch.float32
|
| 46 |
+
|
| 47 |
pipe = StableDiffusionPipeline.from_pretrained(
|
| 48 |
"CompVis/stable-diffusion-v1-4",
|
| 49 |
+
torch_dtype=dtype
|
| 50 |
+
).to(device)
|
| 51 |
|
| 52 |
# Load all embeddings
|
| 53 |
for style_info in styles.values():
|
|
|
|
| 75 |
# Get style info
|
| 76 |
style_info = styles[style]
|
| 77 |
|
| 78 |
+
# Prepare generator with appropriate device
|
| 79 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 80 |
+
generator = torch.Generator(device).manual_seed(int(seed))
|
| 81 |
|
| 82 |
# Create styled prompt
|
| 83 |
styled_prompt = f"{prompt} {style_info['token']}"
|