Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -13,8 +13,9 @@ import torch
|
|
| 13 |
from diffusers import DiffusionPipeline
|
| 14 |
from typing import Tuple
|
| 15 |
|
| 16 |
-
#
|
| 17 |
-
device =
|
|
|
|
| 18 |
|
| 19 |
# Setup rules for bad words (ensure the prompts are kid-friendly)
|
| 20 |
bad_words = json.loads(os.getenv('BAD_WORDS', '["violence", "blood", "scary", "death", "ghost"]'))
|
|
@@ -68,21 +69,21 @@ DESCRIPTION = """## Children's Sticker Generator
|
|
| 68 |
Generate fun and playful stickers for children using AI.
|
| 69 |
"""
|
| 70 |
|
| 71 |
-
if not torch.cuda.is_available():
|
| 72 |
-
DESCRIPTION += "\n<p>⚠️Running on CPU, This may not work on CPU.</p>"
|
| 73 |
-
|
| 74 |
MAX_SEED = np.iinfo(np.int32).max
|
| 75 |
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "0") == "1"
|
| 76 |
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
-
# Initialize
|
| 80 |
-
|
| 81 |
-
"SG161222/RealVisXL_V3.0_Turbo", # or any model of your choice
|
| 82 |
-
torch_dtype=torch.float16,
|
| 83 |
-
use_safetensors=True,
|
| 84 |
-
variant="fp16"
|
| 85 |
-
).to(device)
|
| 86 |
|
| 87 |
# Convert mm to pixels for a specific DPI (300) and ensure divisible by 8
|
| 88 |
def mm_to_pixels(mm, dpi=300):
|
|
@@ -134,8 +135,15 @@ def generate(
|
|
| 134 |
guidance_scale: float = 3,
|
| 135 |
randomize_seed: bool = False,
|
| 136 |
background: str = "transparent",
|
|
|
|
| 137 |
progress=gr.Progress(track_tqdm=True),
|
| 138 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
if check_text(prompt, negative_prompt):
|
| 140 |
raise ValueError("Prompt contains restricted words.")
|
| 141 |
|
|
@@ -145,7 +153,7 @@ def generate(
|
|
| 145 |
# Apply style
|
| 146 |
prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
|
| 147 |
seed = int(randomize_seed_fn(seed, randomize_seed))
|
| 148 |
-
generator = torch.Generator().manual_seed(seed)
|
| 149 |
|
| 150 |
# Ensure we have only white or transparent background options
|
| 151 |
width, height = size_map.get(size, (1024, 1024))
|
|
@@ -241,6 +249,11 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
|
|
| 241 |
step=0.1,
|
| 242 |
value=15.7,
|
| 243 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
|
| 245 |
gr.Examples(
|
| 246 |
examples=examples,
|
|
@@ -267,6 +280,7 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
|
|
| 267 |
guidance_scale,
|
| 268 |
randomize_seed,
|
| 269 |
background_selection,
|
|
|
|
| 270 |
],
|
| 271 |
outputs=[result, seed],
|
| 272 |
api_name="run",
|
|
|
|
| 13 |
from diffusers import DiffusionPipeline
|
| 14 |
from typing import Tuple
|
| 15 |
|
| 16 |
+
# Initialize device to None
|
| 17 |
+
device = None
|
| 18 |
+
pipe = None
|
| 19 |
|
| 20 |
# Setup rules for bad words (ensure the prompts are kid-friendly)
|
| 21 |
bad_words = json.loads(os.getenv('BAD_WORDS', '["violence", "blood", "scary", "death", "ghost"]'))
|
|
|
|
| 69 |
Generate fun and playful stickers for children using AI.
|
| 70 |
"""
|
| 71 |
|
|
|
|
|
|
|
|
|
|
| 72 |
MAX_SEED = np.iinfo(np.int32).max
|
| 73 |
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "0") == "1"
|
| 74 |
|
| 75 |
def initialize_pipeline(device_type):
    """Create (or re-create) the global diffusion pipeline on the given device.

    Args:
        device_type: Either ``"cpu"`` or ``"cuda"``.

    Side effects:
        Rebinds the module-level ``device`` and ``pipe`` globals; downloads
        the model weights on first use.
    """
    global device, pipe
    device = torch.device(device_type)
    use_cpu = device_type == "cpu"
    pipe = DiffusionPipeline.from_pretrained(
        "SG161222/RealVisXL_V3.0_Turbo",
        # fp16 weights are unusable on CPU, so fall back to full precision there.
        torch_dtype=torch.float32 if use_cpu else torch.float16,
        use_safetensors=True,
        # BUG FIX: the repo only publishes an "fp16" weight variant; the
        # default fp32 weights have no variant suffix, so requesting
        # variant="fp32" makes from_pretrained raise (the likely cause of the
        # Space's runtime error). Pass variant=None on CPU to load the
        # default weights instead.
        variant=None if use_cpu else "fp16",
    ).to(device)
|
| 84 |
|
| 85 |
+
# Initialize with CPU by default
|
| 86 |
+
initialize_pipeline("cpu")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
# Convert mm to pixels for a specific DPI (300) and ensure divisible by 8
|
| 89 |
def mm_to_pixels(mm, dpi=300):
|
|
|
|
| 135 |
guidance_scale: float = 3,
|
| 136 |
randomize_seed: bool = False,
|
| 137 |
background: str = "transparent",
|
| 138 |
+
device_type: str = "cpu",
|
| 139 |
progress=gr.Progress(track_tqdm=True),
|
| 140 |
):
|
| 141 |
+
global device, pipe
|
| 142 |
+
|
| 143 |
+
# Switch device if necessary
|
| 144 |
+
if device.type != device_type:
|
| 145 |
+
initialize_pipeline(device_type)
|
| 146 |
+
|
| 147 |
if check_text(prompt, negative_prompt):
|
| 148 |
raise ValueError("Prompt contains restricted words.")
|
| 149 |
|
|
|
|
| 153 |
# Apply style
|
| 154 |
prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
|
| 155 |
seed = int(randomize_seed_fn(seed, randomize_seed))
|
| 156 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
| 157 |
|
| 158 |
# Ensure we have only white or transparent background options
|
| 159 |
width, height = size_map.get(size, (1024, 1024))
|
|
|
|
| 249 |
step=0.1,
|
| 250 |
value=15.7,
|
| 251 |
)
|
| 252 |
+
device_selection = gr.Radio(
|
| 253 |
+
choices=["cpu", "cuda"],
|
| 254 |
+
value="cpu",
|
| 255 |
+
label="Device",
|
| 256 |
+
)
|
| 257 |
|
| 258 |
gr.Examples(
|
| 259 |
examples=examples,
|
|
|
|
| 280 |
guidance_scale,
|
| 281 |
randomize_seed,
|
| 282 |
background_selection,
|
| 283 |
+
device_selection,
|
| 284 |
],
|
| 285 |
outputs=[result, seed],
|
| 286 |
api_name="run",
|