Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -4,7 +4,7 @@ Author: @Raxephion 2025
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
import gradio as gr
|
| 7 |
-
import numpy as np
|
| 8 |
import random
|
| 9 |
import torch
|
| 10 |
from diffusers import StableDiffusionPipeline
|
|
@@ -95,12 +95,21 @@ if INITIAL_MODEL_ID:
|
|
| 95 |
print(f"\nLoading initial model '{INITIAL_MODEL_ID}' on startup...")
|
| 96 |
try:
|
| 97 |
# Load the pipeline onto the initial device and dtype
|
| 98 |
-
|
| 99 |
INITIAL_MODEL_ID,
|
| 100 |
torch_dtype=initial_dtype_to_use,
|
| 101 |
safety_checker=None, # <<< SAFETY CHECKER DISABLED <<<
|
| 102 |
)
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
current_model_id = INITIAL_MODEL_ID
|
| 105 |
current_device_loaded = torch.device(initial_device_to_use)
|
| 106 |
print(f"Initial model loaded successfully on {current_device_loaded}.")
|
|
@@ -146,10 +155,11 @@ def infer(
|
|
| 146 |
size, # From size_dropdown
|
| 147 |
seed, # From seed_input (now a Slider)
|
| 148 |
randomize_seed, # From randomize_seed_checkbox
|
|
|
|
| 149 |
progress=gr.Progress(track_tqdm=True), # Added progress argument from template
|
| 150 |
):
|
| 151 |
"""Generates an image using the selected model and parameters on the chosen device."""
|
| 152 |
-
global current_pipeline, current_model_id, current_device_loaded, SCHEDULER_MAP, MAX_SEED
|
| 153 |
|
| 154 |
# This check is done before parameter parsing so we can determine device/dtype for loading
|
| 155 |
# Need to redo some parameter parsing here to get device_to_use early
|
|
@@ -165,7 +175,6 @@ def infer(
|
|
| 165 |
|
| 166 |
# 1. Load/Switch Model if necessary
|
| 167 |
# Check if the requested model identifier OR the requested device has changed
|
| 168 |
-
# Use string comparison for current_device_loaded as it's a torch.device object
|
| 169 |
if current_pipeline is None or current_model_id != model_identifier or (current_device_loaded is not None and str(current_device_loaded) != temp_device_to_use):
|
| 170 |
|
| 171 |
print(f"Loading model: {model_identifier} onto {temp_device_to_use} with dtype {temp_dtype_to_use}...")
|
|
@@ -180,6 +189,7 @@ def infer(
|
|
| 180 |
print(f"Warning: Failed to move previous pipeline to CPU: {move_e}")
|
| 181 |
del current_pipeline
|
| 182 |
current_pipeline = None # Set to None immediately
|
|
|
|
| 183 |
if str(current_device_loaded) == "cuda":
|
| 184 |
try:
|
| 185 |
torch.cuda.empty_cache()
|
|
@@ -190,7 +200,7 @@ def infer(
|
|
| 190 |
# Ensure the device is actually available if not CPU (redundant with earlier check but safe)
|
| 191 |
if temp_device_to_use == "cuda":
|
| 192 |
if not torch.cuda.is_available():
|
| 193 |
-
raise gr.Error("
|
| 194 |
|
| 195 |
try:
|
| 196 |
pipeline = StableDiffusionPipeline.from_pretrained(
|
|
@@ -198,6 +208,24 @@ def infer(
|
|
| 198 |
torch_dtype=temp_dtype_to_use, # Use the determined dtype for loading
|
| 199 |
safety_checker=None, # DISABLED
|
| 200 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
pipeline = pipeline.to(temp_device_to_use) # Use the determined device
|
| 202 |
|
| 203 |
current_pipeline = pipeline
|
|
@@ -244,6 +272,8 @@ def infer(
|
|
| 244 |
|
| 245 |
# Re-determine device_to_use and dtype_to_use *after* ensuring pipeline is loaded
|
| 246 |
# They should match current_device_loaded and the pipeline's dtype
|
|
|
|
|
|
|
| 247 |
device_to_use = str(current_pipeline.device) if current_pipeline else ("cuda" if selected_device_str == "GPU" and "GPU" in AVAILABLE_DEVICES else "cpu")
|
| 248 |
dtype_to_use = current_pipeline.dtype if current_pipeline else torch.float32 # Fallback if somehow pipeline is still None
|
| 249 |
|
|
@@ -253,6 +283,30 @@ def infer(
|
|
| 253 |
raise gr.Error("Model failed to load during setup or switching. Cannot generate image.")
|
| 254 |
|
| 255 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
# 2. Configure Scheduler
|
| 257 |
selected_scheduler_class = SCHEDULER_MAP.get(scheduler_name)
|
| 258 |
if selected_scheduler_class is None:
|
|
@@ -348,7 +402,7 @@ def infer(
|
|
| 348 |
if width <= 0 or height <= 0:
|
| 349 |
raise ValueError("Image width and height must be positive.")
|
| 350 |
|
| 351 |
-
print(f"Generating: Prompt='{prompt[:80]}{'...' if len(prompt) > 80 else ''}', NegPrompt='{negative_prompt[:80]}{'...' if len(negative_prompt) > 80 else ''}', Steps={num_inference_steps_int}, CFG={guidance_scale_float}, Size={width}x{height}, Scheduler={scheduler_name}, Seed={seed_int if generator else 'System Random'}, Device={device_to_use}, Dtype={dtype_to_use}")
|
| 352 |
start_time = time.time()
|
| 353 |
|
| 354 |
try:
|
|
@@ -367,8 +421,6 @@ def infer(
|
|
| 367 |
|
| 368 |
# Add VAE usage here if needed for specific models that require it
|
| 369 |
# vae=...
|
| 370 |
-
# Potentially add attention slicing/xformers/etc. for memory efficiency
|
| 371 |
-
# enable_attention_slicing="auto", # Can help with VRAM on smaller GPUs
|
| 372 |
# enable_xformers_memory_efficient_attention() # Needs xformers installed & compatible GPU
|
| 373 |
)
|
| 374 |
end_time = time.time()
|
|
@@ -488,6 +540,17 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo: # Added Soft theme from
|
|
| 488 |
seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, interactive=True) # Use 0 as default, interactive initially
|
| 489 |
randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True) # Simplified label
|
| 490 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 491 |
|
| 492 |
generate_button = gr.Button("✨ Generate Image ✨", variant="primary", scale=1) # Added emojis
|
| 493 |
|
|
@@ -520,7 +583,8 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo: # Added Soft theme from
|
|
| 520 |
scheduler_dropdown,
|
| 521 |
size_dropdown,
|
| 522 |
seed_input,
|
| 523 |
-
randomize_seed_checkbox,
|
|
|
|
| 524 |
],
|
| 525 |
outputs=[output_image, actual_seed_output], # Return image and the actual seed used
|
| 526 |
api_name="generate" # Optional: For API access
|
|
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
import gradio as gr
|
| 7 |
+
import numpy as np
|
| 8 |
import random
|
| 9 |
import torch
|
| 10 |
from diffusers import StableDiffusionPipeline
|
|
|
|
| 95 |
print(f"\nLoading initial model '{INITIAL_MODEL_ID}' on startup...")
|
| 96 |
try:
|
| 97 |
# Load the pipeline onto the initial device and dtype
|
| 98 |
+
pipeline = StableDiffusionPipeline.from_pretrained(
|
| 99 |
INITIAL_MODEL_ID,
|
| 100 |
torch_dtype=initial_dtype_to_use,
|
| 101 |
safety_checker=None, # <<< SAFETY CHECKER DISABLED <<<
|
| 102 |
)
|
| 103 |
+
|
| 104 |
+
# --- Apply Optimizations during initial load ---
|
| 105 |
+
# Apply attention slicing by default for memory efficiency on Spaces
|
| 106 |
+
# Can be turned off via UI toggle later, but good default for VRAM
|
| 107 |
+
# We'll add the UI toggle later, for now, just enable it here
|
| 108 |
+
# pipeline.enable_attention_slicing() # Enable by default on initial load
|
| 109 |
+
|
| 110 |
+
pipeline = pipeline.to(initial_device_to_use) # Move to the initial device
|
| 111 |
+
|
| 112 |
+
current_pipeline = pipeline
|
| 113 |
current_model_id = INITIAL_MODEL_ID
|
| 114 |
current_device_loaded = torch.device(initial_device_to_use)
|
| 115 |
print(f"Initial model loaded successfully on {current_device_loaded}.")
|
|
|
|
| 155 |
size, # From size_dropdown
|
| 156 |
seed, # From seed_input (now a Slider)
|
| 157 |
randomize_seed, # From randomize_seed_checkbox
|
| 158 |
+
enable_attention_slicing, # <-- New input for the optimization toggle
|
| 159 |
progress=gr.Progress(track_tqdm=True), # Added progress argument from template
|
| 160 |
):
|
| 161 |
"""Generates an image using the selected model and parameters on the chosen device."""
|
| 162 |
+
global current_pipeline, current_model_id, current_device_loaded, SCHEDULER_MAP, MAX_SEED
|
| 163 |
|
| 164 |
# This check is done before parameter parsing so we can determine device/dtype for loading
|
| 165 |
# Need to redo some parameter parsing here to get device_to_use early
|
|
|
|
| 175 |
|
| 176 |
# 1. Load/Switch Model if necessary
|
| 177 |
# Check if the requested model identifier OR the requested device has changed
|
|
|
|
| 178 |
if current_pipeline is None or current_model_id != model_identifier or (current_device_loaded is not None and str(current_device_loaded) != temp_device_to_use):
|
| 179 |
|
| 180 |
print(f"Loading model: {model_identifier} onto {temp_device_to_use} with dtype {temp_dtype_to_use}...")
|
|
|
|
| 189 |
print(f"Warning: Failed to move previous pipeline to CPU: {move_e}")
|
| 190 |
del current_pipeline
|
| 191 |
current_pipeline = None # Set to None immediately
|
| 192 |
+
# Attempt to clear CUDA cache if using GPU (from the previous device)
|
| 193 |
if str(current_device_loaded) == "cuda":
|
| 194 |
try:
|
| 195 |
torch.cuda.empty_cache()
|
|
|
|
| 200 |
# Ensure the device is actually available if not CPU (redundant with earlier check but safe)
|
| 201 |
if temp_device_to_use == "cuda":
|
| 202 |
if not torch.cuda.is_available():
|
| 203 |
+
raise gr.Error("GPU selected but CUDA is not available to PyTorch on this Space. Please select CPU or ensure the Space is configured with a GPU and the CUDA version of PyTorch is installed.")
|
| 204 |
|
| 205 |
try:
|
| 206 |
pipeline = StableDiffusionPipeline.from_pretrained(
|
|
|
|
| 208 |
torch_dtype=temp_dtype_to_use, # Use the determined dtype for loading
|
| 209 |
safety_checker=None, # DISABLED
|
| 210 |
)
|
| 211 |
+
|
| 212 |
+
# Apply optimizations based on UI input during load
|
| 213 |
+
if enable_attention_slicing and temp_device_to_use == "cuda": # Only apply on GPU
|
| 214 |
+
try:
|
| 215 |
+
pipeline.enable_attention_slicing()
|
| 216 |
+
print("Attention Slicing enabled.")
|
| 217 |
+
except Exception as e:
|
| 218 |
+
print(f"Warning: Failed to enable Attention Slicing: {e}")
|
| 219 |
+
gr.Warning(f"Failed to enable Attention Slicing. Error: {e}")
|
| 220 |
+
else:
|
| 221 |
+
try:
|
| 222 |
+
pipeline.disable_attention_slicing() # Ensure it's off if toggle is off or on CPU
|
| 223 |
+
print("Attention Slicing disabled.")
|
| 224 |
+
except Exception as e:
|
| 225 |
+
# May fail if it wasn't enabled, ignore
|
| 226 |
+
pass
|
| 227 |
+
|
| 228 |
+
|
| 229 |
pipeline = pipeline.to(temp_device_to_use) # Use the determined device
|
| 230 |
|
| 231 |
current_pipeline = pipeline
|
|
|
|
| 272 |
|
| 273 |
# Re-determine device_to_use and dtype_to_use *after* ensuring pipeline is loaded
|
| 274 |
# They should match current_device_loaded and the pipeline's dtype
|
| 275 |
+
# This is crucial because current_pipeline.device and dtype are the definitive source
|
| 276 |
+
# after a potentially successful load or switch.
|
| 277 |
device_to_use = str(current_pipeline.device) if current_pipeline else ("cuda" if selected_device_str == "GPU" and "GPU" in AVAILABLE_DEVICES else "cpu")
|
| 278 |
dtype_to_use = current_pipeline.dtype if current_pipeline else torch.float32 # Fallback if somehow pipeline is still None
|
| 279 |
|
|
|
|
| 283 |
raise gr.Error("Model failed to load during setup or switching. Cannot generate image.")
|
| 284 |
|
| 285 |
|
| 286 |
+
# --- Apply Optimizations *before* generation if model was already loaded ---
|
| 287 |
+
# If the model didn't need reloading, we need to apply/remove slicing here
|
| 288 |
+
if str(current_pipeline.device) == "cuda": # Only attempt on GPU
|
| 289 |
+
if enable_attention_slicing:
|
| 290 |
+
try:
|
| 291 |
+
current_pipeline.enable_attention_slicing()
|
| 292 |
+
# print("Attention Slicing enabled for generation.") # Too verbose
|
| 293 |
+
except Exception as e:
|
| 294 |
+
print(f"Warning: Failed to enable Attention Slicing before generation: {e}")
|
| 295 |
+
gr.Warning(f"Failed to enable Attention Slicing. Error: {e}")
|
| 296 |
+
else:
|
| 297 |
+
try:
|
| 298 |
+
current_pipeline.disable_attention_slicing()
|
| 299 |
+
# print("Attention Slicing disabled for generation.") # Too verbose
|
| 300 |
+
except Exception as e:
|
| 301 |
+
# May fail if it wasn't enabled, ignore
|
| 302 |
+
pass
|
| 303 |
+
else: # Ensure slicing is off on CPU
|
| 304 |
+
try:
|
| 305 |
+
current_pipeline.disable_attention_slicing()
|
| 306 |
+
except Exception as e:
|
| 307 |
+
pass # Ignore
|
| 308 |
+
|
| 309 |
+
|
| 310 |
# 2. Configure Scheduler
|
| 311 |
selected_scheduler_class = SCHEDULER_MAP.get(scheduler_name)
|
| 312 |
if selected_scheduler_class is None:
|
|
|
|
| 402 |
if width <= 0 or height <= 0:
|
| 403 |
raise ValueError("Image width and height must be positive.")
|
| 404 |
|
| 405 |
+
print(f"Generating: Prompt='{prompt[:80]}{'...' if len(prompt) > 80 else ''}', NegPrompt='{negative_prompt[:80]}{'...' if len(negative_prompt) > 80 else ''}', Steps={num_inference_steps_int}, CFG={guidance_scale_float}, Size={width}x{height}, Scheduler={scheduler_name}, Seed={seed_int if generator else 'System Random'}, Device={device_to_use}, Dtype={dtype_to_use}, Slicing Enabled={enable_attention_slicing and device_to_use == 'cuda'}")
|
| 406 |
start_time = time.time()
|
| 407 |
|
| 408 |
try:
|
|
|
|
| 421 |
|
| 422 |
# Add VAE usage here if needed for specific models that require it
|
| 423 |
# vae=...
|
|
|
|
|
|
|
| 424 |
# enable_xformers_memory_efficient_attention() # Needs xformers installed & compatible GPU
|
| 425 |
)
|
| 426 |
end_time = time.time()
|
|
|
|
| 540 |
seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, interactive=True) # Use 0 as default, interactive initially
|
| 541 |
randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True) # Simplified label
|
| 542 |
|
| 543 |
+
# --- New: Memory Optimization Toggle ---
|
| 544 |
+
with gr.Row():
|
| 545 |
+
# Default to enabled if GPU is available, otherwise off
|
| 546 |
+
default_slicing = True if "GPU" in AVAILABLE_DEVICES else False
|
| 547 |
+
enable_attention_slicing_checkbox = gr.Checkbox(
|
| 548 |
+
label="Enable Attention Slicing (Memory Optimization - GPU only)",
|
| 549 |
+
value=default_slicing,
|
| 550 |
+
interactive="GPU" in AVAILABLE_DEVICES # Only interactive if GPU is an option
|
| 551 |
+
)
|
| 552 |
+
gr.Markdown("*(Helps reduce VRAM usage, may slightly affect speed/quality)*")
|
| 553 |
+
|
| 554 |
|
| 555 |
generate_button = gr.Button("✨ Generate Image ✨", variant="primary", scale=1) # Added emojis
|
| 556 |
|
|
|
|
| 583 |
scheduler_dropdown,
|
| 584 |
size_dropdown,
|
| 585 |
seed_input,
|
| 586 |
+
randomize_seed_checkbox,
|
| 587 |
+
enable_attention_slicing_checkbox, # <-- Pass the new checkbox value
|
| 588 |
],
|
| 589 |
outputs=[output_image, actual_seed_output], # Return image and the actual seed used
|
| 590 |
api_name="generate" # Optional: For API access
|