# test1 / app.py — Hugging Face Space source file.
# (Header below is scrape residue from the HF web UI, preserved as comments:
#  uploaded by achase25; commit 45fb6b9 "Create app.py", verified.)
# app.py
# Text-to-Image Space using Diffusers + Gradio
# Works on CPU (slow) and GPU (recommended). Choose a model in the UI.
import os
import math
import torch
import gradio as gr
from typing import List, Optional
from PIL import Image
from diffusers import (
DiffusionPipeline,
StableDiffusionPipeline,
AutoPipelineForText2Image,
)
# --------- Config ---------
# Maps the human-readable dropdown label -> Hugging Face model repo id.
MODEL_CHOICES = {
    # Solid baseline, license-free to use after accepting on HF if required.
    "Stable Diffusion 1.5 (runwayml/stable-diffusion-v1-5)": "runwayml/stable-diffusion-v1-5",
    # Very fast for prototyping; outputs can be less detailed. Best with GPU.
    "SDXL Turbo (stabilityai/sdxl-turbo)": "stabilityai/sdxl-turbo",
}
# Must be one of MODEL_CHOICES' keys; pre-selected in the model dropdown.
DEFAULT_MODEL_LABEL = "Stable Diffusion 1.5 (runwayml/stable-diffusion-v1-5)"
# Disable safety checker by default (your responsibility). Toggle in UI.
DISABLE_SAFETY_DEFAULT = True
# --------- Runtime helpers ---------
def get_device() -> str:
    """Return the best available torch device name: "cuda" > "mps" > "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    # Spaces hardware never exposes Apple MPS; kept for local-run completeness.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"
def nearest_multiple_of_8(x: int) -> int:
    """Snap *x* to the nearest multiple of 8, never returning less than 64.

    Diffusion UNets require dimensions divisible by 8; 64 is a sane floor.
    """
    floor = 64
    if x < floor:
        return floor
    snapped = round(x / 8) * 8
    return int(snapped)
# Cache pipelines per (model, device, fp16) to avoid reloading on each call
_PIPE_CACHE = {}


def load_pipe(model_id: str, device: str, fp16: bool) -> DiffusionPipeline:
    """Load a text-to-image pipeline, reusing a cached instance when possible.

    Args:
        model_id: Hugging Face model repo id (a value of MODEL_CHOICES).
        device: Target device string, e.g. "cuda" / "mps" / "cpu".
        fp16: Request float16 weights; only honored when device == "cuda".

    Returns:
        A DiffusionPipeline already moved to *device*.
    """
    key = (model_id, device, fp16)
    if key in _PIPE_CACHE:
        return _PIPE_CACHE[key]
    dtype = torch.float16 if (fp16 and device == "cuda") else torch.float32
    # AutoPipeline resolves the right pipeline class for many models; we fall
    # back to the legacy SD pipeline (e.g. for v1-5) if that fails.
    try:
        # BUG FIX: previously called `AutoPipelineForTextToImage`, which does
        # not exist (the import is AutoPipelineForText2Image). The NameError
        # was swallowed by the broad except, so every model silently took the
        # StableDiffusionPipeline fallback — wrong for SDXL Turbo.
        pipe = AutoPipelineForText2Image.from_pretrained(
            model_id,
            torch_dtype=dtype,
            use_safetensors=True,
            trust_remote_code=False,
        )
    except Exception:
        # Legacy fallback for SD 1.5
        pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=dtype,
            use_safetensors=True,
        )
    # Send to device
    pipe = pipe.to(device)
    # Try memory-efficient attention if available (xformers is optional).
    if device == "cuda":
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            pass
    _PIPE_CACHE[key] = pipe
    return pipe
# --------- Inference ---------
def generate(
    prompt: str,
    negative: str,
    model_label: str,
    steps: int,
    guidance: float,
    width: int,
    height: int,
    seed: Optional[int],
    batch_size: int,
    disable_safety: bool,
) -> List[Image.Image]:
    """Run text-to-image inference and return a batch of PIL images.

    Args:
        prompt: Required text prompt (stripped; empty raises gr.Error).
        negative: Optional negative prompt, applied to the whole batch.
        model_label: Key into MODEL_CHOICES selecting the model.
        steps: Inference steps (clamped to 1-6 for SDXL Turbo).
        guidance: CFG scale (forced to 0.0 for SDXL Turbo).
        width: Target width; snapped to a multiple of 8 (min 64).
        height: Target height; snapped to a multiple of 8 (min 64).
        seed: Optional seed. NOTE: arrives as a *string* from the gr.Textbox
            despite the annotation; non-integer text means "no fixed seed".
        batch_size: Number of images to generate (clamped to 1-8).
        disable_safety: If True, bypass the pipeline's safety checker.

    Raises:
        gr.Error: If the prompt is empty.
    """
    prompt = (prompt or "").strip()
    if not prompt:
        raise gr.Error("Enter a non-empty prompt.")
    model_id = MODEL_CHOICES[model_label]
    device = get_device()

    # SDXL Turbo is distilled for 1-6 steps and guidance-free sampling.
    is_turbo = "sdxl-turbo" in model_id.lower()
    if is_turbo:
        steps = max(1, min(steps, 6))  # turbo is usually 1–6 steps
        guidance = 0.0  # turbo uses guidance-free sampling; CFG does nothing

    width = nearest_multiple_of_8(width)
    height = nearest_multiple_of_8(height)
    batch_size = max(1, min(batch_size, 8))
    pipe = load_pipe(model_id, device, fp16=(device == "cuda"))

    # Safety checker toggle.
    # BUG FIX: pipelines are cached across calls, and the old code assigned
    # `pipe.safety_checker = None` permanently — unchecking "disable" could
    # never restore the checker. Stash the original once so the toggle works
    # in both directions.
    if hasattr(pipe, "safety_checker"):
        if not hasattr(pipe, "_original_safety_checker"):
            pipe._original_safety_checker = pipe.safety_checker
        pipe.safety_checker = None if disable_safety else pipe._original_safety_checker

    # Determinism: CUDA wants a CUDA generator; MPS pipelines use a CPU one.
    generator = None
    if seed is not None and seed != "":
        try:
            seed = int(seed)
        except (TypeError, ValueError):
            seed = None  # non-integer text in the seed box -> random seed
        if seed is not None:
            gen_device = "cuda" if device == "cuda" else "cpu"
            generator = torch.Generator(device=gen_device).manual_seed(seed)

    prompts = [prompt] * batch_size
    negative_prompts = [negative] * batch_size if negative else None

    # autocast is a no-op when enabled=False (CPU/MPS path).
    with torch.autocast("cuda", enabled=(device == "cuda")):
        out = pipe(
            prompt=prompts,
            negative_prompt=negative_prompts,
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
            width=int(width),
            height=int(height),
            generator=generator,
        )
    return out.images
# --------- UI ---------
# Build the Gradio UI. Statement order inside the Blocks context defines the
# rendered layout, so components are declared top-to-bottom as they appear.
with gr.Blocks(css="footer {visibility: hidden}") as demo:
    gr.Markdown(
        """
        # Text-to-Image (Diffusers)
        - **Models:** SD 1.5 and SDXL Turbo
        - **Tip:** SD 1.5 = better detail on CPU; Turbo = very fast on GPU, fewer steps.
        """
    )
    # Row 1: model selection + sampling controls.
    with gr.Row():
        model_dd = gr.Dropdown(
            label="Model",
            choices=list(MODEL_CHOICES.keys()),
            value=DEFAULT_MODEL_LABEL,
        )
        steps = gr.Slider(1, 75, value=30, step=1, label="Steps")
        guidance = gr.Slider(0.0, 15.0, value=7.5, step=0.1, label="Guidance (CFG)")
    # Row 2: output geometry + batch size (sliders step by 8 to match the
    # multiple-of-8 requirement enforced again server-side in generate()).
    with gr.Row():
        width = gr.Slider(256, 1024, value=768, step=8, label="Width (multiple of 8)")
        height = gr.Slider(256, 1024, value=768, step=8, label="Height (multiple of 8)")
        batch_size = gr.Slider(1, 4, value=1, step=1, label="Batch size")
    prompt = gr.Textbox(label="Prompt", lines=2, placeholder="a cozy cabin at twilight beside a lake, cinematic lighting")
    negative = gr.Textbox(label="Negative Prompt", lines=1, placeholder="blurry, low quality, distorted")
    # Row 3: seed (free text — generate() parses/ignores non-integers) + safety toggle.
    with gr.Row():
        seed = gr.Textbox(label="Seed (optional integer)", value="")
        disable_safety = gr.Checkbox(label="Disable safety checker (you are responsible)", value=DISABLE_SAFETY_DEFAULT)
    run_btn = gr.Button("Generate", variant="primary")
    gallery = gr.Gallery(label="Results", columns=2, height=512, preview=True)

    def _on_change_model(label):
        # If Turbo selected, nudge UI to sane defaults
        # (Turbo: few steps, CFG ignored; SD 1.5: standard 30 steps / 7.5 CFG).
        if "Turbo" in label:
            return gr.update(value=4), gr.update(value=0.0)
        else:
            return gr.update(value=30), gr.update(value=7.5)

    # Keep the steps/guidance sliders in sync with the chosen model.
    model_dd.change(_on_change_model, inputs=model_dd, outputs=[steps, guidance])
    # Wire the Generate button to the inference function; exposed via the API
    # as "generate"; at most 2 concurrent runs to protect shared GPU memory.
    run_btn.click(
        fn=generate,
        inputs=[prompt, negative, model_dd, steps, guidance, width, height, seed, batch_size, disable_safety],
        outputs=[gallery],
        api_name="generate",
        scroll_to_output=True,
        concurrency_limit=2,
    )
if __name__ == "__main__":
    # In Spaces, just running the file starts the app. Debug on for clearer stack traces.
    # Bind to all interfaces; honor the PORT env var (Spaces sets it), default 7860.
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)), debug=True)