File size: 6,811 Bytes
45fb6b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 |
# app.py
# Text-to-Image Space using Diffusers + Gradio
# Works on CPU (slow) and GPU (recommended). Choose a model in the UI.
import os
import math
import torch
import gradio as gr
from typing import List, Optional
from PIL import Image
from diffusers import (
DiffusionPipeline,
StableDiffusionPipeline,
AutoPipelineForText2Image,
)
# --------- Config ---------
# Dropdown label -> Hugging Face Hub repo id. The labels are what the UI and
# the "generate" API endpoint pass around, so keep them stable.
MODEL_CHOICES = {
    # Solid baseline. Distributed under the CreativeML OpenRAIL-M license (not
    # license-free); accept its terms on the Hub if prompted. The original
    # runwayml repo is no longer on the Hub, so the weights are loaded from the
    # maintained mirror while the label stays unchanged for API compatibility.
    "Stable Diffusion 1.5 (runwayml/stable-diffusion-v1-5)": "stable-diffusion-v1-5/stable-diffusion-v1-5",
    # Very fast for prototyping; outputs can be less detailed. Best with GPU.
    "SDXL Turbo (stabilityai/sdxl-turbo)": "stabilityai/sdxl-turbo",
}
DEFAULT_MODEL_LABEL = "Stable Diffusion 1.5 (runwayml/stable-diffusion-v1-5)"
# Disable safety checker by default (your responsibility). Toggle in UI.
DISABLE_SAFETY_DEFAULT = True
# --------- Runtime helpers ---------
def get_device() -> str:
    """Pick the torch device to run on.

    Returns:
        "cuda" when a CUDA device is available, otherwise "mps" when the Apple
        MPS backend exists and is available, otherwise "cpu".
    """
    if torch.cuda.is_available():
        return "cuda"
    # Spaces hardware never exposes Apple MPS; the probe is kept for local runs.
    mps_backend = getattr(torch.backends, "mps", None)
    return "mps" if (mps_backend and mps_backend.is_available()) else "cpu"
def nearest_multiple_of_8(x: int) -> int:
    """Snap ``x`` to the nearest multiple of 8, never returning less than 64.

    Stable Diffusion pipelines require width/height divisible by 8. Values
    exactly halfway between two multiples round up, so the snapping is
    consistent across the whole range. Float input (e.g. from a slider or the
    API) is accepted; the result is always an ``int``.

    >>> nearest_multiple_of_8(770)
    768
    >>> nearest_multiple_of_8(68)
    72
    >>> nearest_multiple_of_8(10)
    64
    """
    if x < 64:
        return 64
    # Built-in round() uses banker's rounding (68 -> 64 but 76 -> 80);
    # floor(v + 0.5) rounds every half up.
    return int(math.floor(x / 8 + 0.5)) * 8
# Cache pipelines per (model_id, device, fp16) to avoid reloading on each call.
_PIPE_CACHE = {}


def load_pipe(model_id: str, device: str, fp16: bool) -> DiffusionPipeline:
    """Return a text-to-image pipeline for ``model_id`` on ``device`` (cached).

    Args:
        model_id: Hugging Face Hub repo id (a value of ``MODEL_CHOICES``).
        device: Torch device string, as returned by ``get_device``.
        fp16: Request float16 weights; only honoured when ``device == "cuda"``,
            every other device loads float32.

    Returns:
        The loaded pipeline moved to ``device``. Later calls with the same
        arguments return the same cached object.

    Raises:
        Whatever ``StableDiffusionPipeline.from_pretrained`` raises if both the
        AutoPipeline load and the SD fallback fail; the AutoPipeline error is
        attached as the exception's context.
    """
    import logging

    key = (model_id, device, fp16)
    if key in _PIPE_CACHE:
        return _PIPE_CACHE[key]

    dtype = torch.float16 if (fp16 and device == "cuda") else torch.float32

    # AutoPipeline resolves the right pipeline class for most repos; the plain
    # SD pipeline is kept as a fallback for legacy SD 1.5 layouts.
    try:
        pipe = AutoPipelineForText2Image.from_pretrained(
            model_id,
            torch_dtype=dtype,
            use_safetensors=True,
            trust_remote_code=False,
        )
    except Exception as err:
        # Don't hide why the preferred loader failed; the fallback is only
        # appropriate for SD 1.5-style checkpoints.
        logging.getLogger(__name__).warning(
            "AutoPipelineForText2Image failed for %s (%s); "
            "falling back to StableDiffusionPipeline",
            model_id,
            err,
        )
        pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=dtype,
            use_safetensors=True,
        )

    pipe = pipe.to(device)

    if device == "cuda":
        # Best effort: xformers is optional, so if it cannot be enabled the
        # pipeline simply keeps its default attention implementation.
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            pass

    _PIPE_CACHE[key] = pipe
    return pipe
# --------- Inference ---------
def _parse_seed(seed) -> Optional[int]:
    """Convert the Seed value (text from the UI) to an int, or None if blank/invalid."""
    if seed is None or seed == "":
        return None
    try:
        return int(seed)
    except (TypeError, ValueError):
        return None


def _make_generator(seed: Optional[int], device: str) -> Optional[torch.Generator]:
    """Return a torch.Generator seeded with ``seed``, or None for unseeded sampling."""
    if seed is None:
        return None
    # CUDA runs get a CUDA generator; MPS and CPU runs use a CPU generator.
    gen_device = "cuda" if device == "cuda" else "cpu"
    return torch.Generator(device=gen_device).manual_seed(seed)


def _apply_safety_setting(pipe: DiffusionPipeline, disable_safety: bool) -> None:
    """Switch the pipeline's safety checker off or back on, without losing it.

    Pipelines without a ``safety_checker`` attribute are left untouched.
    """
    if not hasattr(pipe, "safety_checker"):
        return
    # The pipeline is cached and reused across requests. Remember its original
    # checker on first sight so that disabling it once is not permanent.
    if not hasattr(pipe, "_original_safety_checker"):
        pipe._original_safety_checker = pipe.safety_checker
    # NOTE(review): the cached pipe is shared by concurrent requests
    # (concurrency_limit=2 on the click handler), so one request's toggle can
    # affect another that is in flight — confirm whether that matters.
    pipe.safety_checker = None if disable_safety else pipe._original_safety_checker


def generate(
    prompt: str,
    negative: str,
    model_label: str,
    steps: int,
    guidance: float,
    width: int,
    height: int,
    seed: Optional[int],
    batch_size: int,
    disable_safety: bool,
) -> List[Image.Image]:
    """Generate ``batch_size`` images for ``prompt`` with the selected model.

    Args:
        prompt: Text prompt; must be non-empty after stripping whitespace.
        negative: Optional negative prompt; blank means none.
        model_label: A key of ``MODEL_CHOICES`` (the dropdown label).
        steps: Denoising steps; clamped to at least 1, and to 1..6 for SDXL Turbo.
        guidance: Classifier-free guidance scale; forced to 0.0 for SDXL Turbo.
        width: Output width, snapped with ``nearest_multiple_of_8``.
        height: Output height, snapped with ``nearest_multiple_of_8``.
        seed: Optional integer seed. The UI passes it as text; blank or
            non-integer text means unseeded.
        batch_size: Number of images, clamped to 1..8.
        disable_safety: When True, the pipeline's safety checker (if it has
            one) is switched off for this call; when False it is restored.

    Returns:
        The pipeline output's ``images`` list.

    Raises:
        gr.Error: If the prompt is empty or ``model_label`` is unknown.
    """
    prompt = (prompt or "").strip()
    if not prompt:
        raise gr.Error("Enter a non-empty prompt.")
    model_id = MODEL_CHOICES.get(model_label)
    if model_id is None:
        raise gr.Error(f"Unknown model: {model_label!r}")
    device = get_device()

    # Slider/API values may arrive as floats; list repetition needs ints.
    steps = max(1, int(steps))
    if "sdxl-turbo" in model_id.lower():
        # SDXL Turbo is meant for very few steps and no CFG, so cap the steps
        # and ignore the guidance slider.
        steps = min(steps, 6)
        guidance = 0.0
    width = nearest_multiple_of_8(width)
    height = nearest_multiple_of_8(height)
    batch_size = max(1, min(int(batch_size), 8))

    pipe = load_pipe(model_id, device, fp16=(device == "cuda"))
    _apply_safety_setting(pipe, disable_safety)
    generator = _make_generator(_parse_seed(seed), device)

    negative = (negative or "").strip()
    prompts = [prompt] * batch_size
    negative_prompts = [negative] * batch_size if negative else None

    # Autocast is only enabled on CUDA; on other devices this is a no-op.
    with torch.autocast("cuda", enabled=(device == "cuda")):
        out = pipe(
            prompt=prompts,
            negative_prompt=negative_prompts,
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
            width=int(width),
            height=int(height),
            generator=generator,
        )
    return out.images
# --------- UI ---------
with gr.Blocks(css="footer {visibility: hidden}") as demo:
    gr.Markdown(
        """
# Text-to-Image (Diffusers)
- **Models:** SD 1.5 and SDXL Turbo
- **Tip:** SD 1.5 = better detail on CPU; Turbo = very fast on GPU, fewer steps.
"""
    )
    with gr.Row():
        model_dd = gr.Dropdown(
            label="Model",
            choices=list(MODEL_CHOICES.keys()),
            value=DEFAULT_MODEL_LABEL,
        )
        steps = gr.Slider(1, 75, value=30, step=1, label="Steps")
        guidance = gr.Slider(0.0, 15.0, value=7.5, step=0.1, label="Guidance (CFG)")
    with gr.Row():
        width = gr.Slider(256, 1024, value=768, step=8, label="Width (multiple of 8)")
        height = gr.Slider(256, 1024, value=768, step=8, label="Height (multiple of 8)")
        batch_size = gr.Slider(1, 4, value=1, step=1, label="Batch size")
    prompt = gr.Textbox(label="Prompt", lines=2, placeholder="a cozy cabin at twilight beside a lake, cinematic lighting")
    negative = gr.Textbox(label="Negative Prompt", lines=1, placeholder="blurry, low quality, distorted")
    with gr.Row():
        seed = gr.Textbox(label="Seed (optional integer)", value="")
        disable_safety = gr.Checkbox(label="Disable safety checker (you are responsible)", value=DISABLE_SAFETY_DEFAULT)
    run_btn = gr.Button("Generate", variant="primary")
    gallery = gr.Gallery(label="Results", columns=2, height=512, preview=True)

    def _on_change_model(label):
        """Nudge the Steps/Guidance sliders to sensible defaults for ``label``.

        Turbo detection uses the model id (same test as ``generate``) rather
        than the display label; an empty/unknown label gets the SD defaults.
        """
        model_id = MODEL_CHOICES.get(label, "")
        if "sdxl-turbo" in model_id.lower():
            return gr.update(value=4), gr.update(value=0.0)
        return gr.update(value=30), gr.update(value=7.5)

    model_dd.change(_on_change_model, inputs=model_dd, outputs=[steps, guidance])
    run_btn.click(
        fn=generate,
        inputs=[prompt, negative, model_dd, steps, guidance, width, height, seed, batch_size, disable_safety],
        outputs=[gallery],
        api_name="generate",
        scroll_to_output=True,
        # At most two generations run at once; they share cached pipelines.
        concurrency_limit=2,
    )
if __name__ == "__main__":
    # Launching the script directly starts the server (this is how Spaces runs
    # it): listen on all interfaces, use $PORT when set (else 7860), and keep
    # debug on so full stack traces are shown.
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.getenv("PORT", 7860)),
        debug=True,
    )
|