File size: 6,811 Bytes
45fb6b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
# app.py
# Text-to-Image Space using Diffusers + Gradio
# Works on CPU (slow) and GPU (recommended). Choose a model in the UI.

import os
import math
import torch
import gradio as gr
from typing import List, Optional
from PIL import Image
from diffusers import (
    DiffusionPipeline,
    StableDiffusionPipeline,
    AutoPipelineForText2Image,
)

# --------- Config ---------
# Maps human-readable dropdown labels -> Hugging Face Hub model repo IDs.
# The label strings double as Gradio dropdown choices, so they must match
# DEFAULT_MODEL_LABEL exactly.
MODEL_CHOICES = {
    # Solid baseline, license-free to use after accepting on HF if required.
    "Stable Diffusion 1.5 (runwayml/stable-diffusion-v1-5)": "runwayml/stable-diffusion-v1-5",
    # Very fast for prototyping; outputs can be less detailed. Best with GPU.
    "SDXL Turbo (stabilityai/sdxl-turbo)": "stabilityai/sdxl-turbo",
}

# Initial dropdown selection; must be a key of MODEL_CHOICES.
DEFAULT_MODEL_LABEL = "Stable Diffusion 1.5 (runwayml/stable-diffusion-v1-5)"

# Disable safety checker by default (your responsibility). Toggle in UI.
DISABLE_SAFETY_DEFAULT = True

# --------- Runtime helpers ---------
def get_device() -> str:
    """Return the best available torch device: "cuda", "mps", or "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    # Spaces don't use Apple MPS; kept for local-machine completeness.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend and mps_backend.is_available():
        return "mps"
    return "cpu"

def nearest_multiple_of_8(x: int) -> int:
    """Round *x* to the nearest multiple of 8, flooring at 64 (minimum size)."""
    return 64 if x < 64 else int(round(x / 8) * 8)

# Cache pipelines per model to avoid reloading on each call.
# Keyed by (model_id, device, fp16) so dtype/device variants coexist.
_PIPE_CACHE = {}

def load_pipe(model_id: str, device: str, fp16: bool) -> DiffusionPipeline:
    """Load a text-to-image pipeline, caching it for subsequent calls.

    Args:
        model_id: Hugging Face Hub repo ID (e.g. "stabilityai/sdxl-turbo").
        device: Target device string ("cuda", "mps", or "cpu").
        fp16: Request float16 weights; only honored when device is "cuda".

    Returns:
        A DiffusionPipeline moved to *device*. The same object is returned
        for repeated calls with identical (model_id, device, fp16).
    """
    key = (model_id, device, fp16)
    if key in _PIPE_CACHE:
        return _PIPE_CACHE[key]

    # fp16 is only safe/beneficial on CUDA; everything else gets float32.
    dtype = torch.float16 if (fp16 and device == "cuda") else torch.float32

    # AutoPipeline works for many models; fall back to the explicit SD
    # pipeline for legacy checkpoints.
    # BUG FIX: the original referenced AutoPipelineForTextToImage, but the
    # imported name is AutoPipelineForText2Image — the try branch always
    # raised NameError and silently fell back to StableDiffusionPipeline,
    # which is wrong for SDXL Turbo.
    try:
        pipe = AutoPipelineForText2Image.from_pretrained(
            model_id,
            torch_dtype=dtype,
            use_safetensors=True,
            trust_remote_code=False,
        )
    except Exception:
        # Legacy fallback for SD 1.5-style checkpoints.
        pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=dtype,
            use_safetensors=True,
        )

    # Send to device
    pipe = pipe.to(device)

    # Memory-efficient attention (xformers) is optional; ignore if missing.
    if device == "cuda":
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            pass

    _PIPE_CACHE[key] = pipe
    return pipe

# --------- Inference ---------
def generate(
    prompt: str,
    negative: str,
    model_label: str,
    steps: int,
    guidance: float,
    width: int,
    height: int,
    seed: Optional[str],
    batch_size: int,
    disable_safety: bool,
) -> List[Image.Image]:
    """Run text-to-image inference and return a list of PIL images.

    Args:
        prompt: The text prompt; must be non-empty after stripping.
        negative: Optional negative prompt ("" disables it).
        model_label: A key of MODEL_CHOICES selecting the model.
        steps: Number of inference steps (clamped to 1-6 for SDXL Turbo).
        guidance: CFG scale (forced to 0.0 for SDXL Turbo).
        width, height: Requested dimensions; snapped to multiples of 8.
        seed: Free-text seed from the UI; non-integer values are ignored.
        batch_size: Number of images per call, clamped to 1-8.
        disable_safety: When True, bypass the pipeline's safety checker.

    Raises:
        gr.Error: If the prompt is empty.
    """
    prompt = (prompt or "").strip()
    if not prompt:
        raise gr.Error("Enter a non-empty prompt.")

    model_id = MODEL_CHOICES[model_label]
    device = get_device()

    # SDXL Turbo is distilled for guidance-free, very-low-step sampling.
    is_turbo = "sdxl-turbo" in model_id.lower()
    if is_turbo:
        steps = max(1, min(steps, 6))  # turbo is usually 1–6 steps
        guidance = 0.0  # turbo uses guidance-free sampling; CFG does nothing

    width = nearest_multiple_of_8(width)
    height = nearest_multiple_of_8(height)
    batch_size = max(1, min(batch_size, 8))

    pipe = load_pipe(model_id, device, fp16=(device == "cuda"))

    # Safety checker toggle.
    # BUG FIX: the pipeline is cached across calls, so the original code
    # (pipe.safety_checker = None if disable_safety else pipe.safety_checker)
    # permanently destroyed the checker after the first disabled call —
    # unchecking the box later had no effect. Stash the original checker
    # once so the toggle works in both directions.
    if hasattr(pipe, "safety_checker"):
        if not hasattr(pipe, "_original_safety_checker"):
            pipe._original_safety_checker = pipe.safety_checker
        pipe.safety_checker = None if disable_safety else pipe._original_safety_checker

    # Determinism: seed arrives as free text from a gr.Textbox; silently
    # ignore anything that doesn't parse as an integer.
    generator = None
    if seed is not None and seed != "":
        try:
            parsed_seed = int(seed)
        except (TypeError, ValueError):
            parsed_seed = None
        if parsed_seed is not None:
            # CUDA needs a CUDA generator; MPS and CPU both use a CPU one.
            gen_device = "cuda" if device == "cuda" else "cpu"
            generator = torch.Generator(device=gen_device).manual_seed(parsed_seed)

    prompts = [prompt] * batch_size
    negative_prompts = [negative] * batch_size if negative else None

    # Mixed precision only on CUDA; a disabled autocast is a no-op elsewhere.
    with torch.autocast("cuda", enabled=(device == "cuda")):
        out = pipe(
            prompt=prompts,
            negative_prompt=negative_prompts,
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
            width=int(width),
            height=int(height),
            generator=generator,
        )

    return out.images

# --------- UI ---------
# Gradio Blocks layout. Component creation order defines the on-page order,
# and the component variables are wired to `generate` at the bottom.
with gr.Blocks(css="footer {visibility: hidden}") as demo:
    gr.Markdown(
        """
        # Text-to-Image (Diffusers)
        - **Models:** SD 1.5 and SDXL Turbo
        - **Tip:** SD 1.5 = better detail on CPU; Turbo = very fast on GPU, fewer steps.
        """
    )

    # Row 1: model selection plus the two sampler knobs it influences.
    with gr.Row():
        model_dd = gr.Dropdown(
            label="Model",
            choices=list(MODEL_CHOICES.keys()),
            value=DEFAULT_MODEL_LABEL,
        )
        steps = gr.Slider(1, 75, value=30, step=1, label="Steps")
        guidance = gr.Slider(0.0, 15.0, value=7.5, step=0.1, label="Guidance (CFG)")

    # Row 2: output geometry and batch size (generate() re-clamps all three).
    with gr.Row():
        width = gr.Slider(256, 1024, value=768, step=8, label="Width (multiple of 8)")
        height = gr.Slider(256, 1024, value=768, step=8, label="Height (multiple of 8)")
        batch_size = gr.Slider(1, 4, value=1, step=1, label="Batch size")

    prompt = gr.Textbox(label="Prompt", lines=2, placeholder="a cozy cabin at twilight beside a lake, cinematic lighting")
    negative = gr.Textbox(label="Negative Prompt", lines=1, placeholder="blurry, low quality, distorted")
    # Seed is free text; generate() ignores values that don't parse as int.
    with gr.Row():
        seed = gr.Textbox(label="Seed (optional integer)", value="")
        disable_safety = gr.Checkbox(label="Disable safety checker (you are responsible)", value=DISABLE_SAFETY_DEFAULT)

    run_btn = gr.Button("Generate", variant="primary")
    gallery = gr.Gallery(label="Results", columns=2, height=512, preview=True)

    def _on_change_model(label):
        """Return (steps, guidance) UI updates suited to the selected model."""
        # If Turbo selected, nudge UI to sane defaults
        if "Turbo" in label:
            return gr.update(value=4), gr.update(value=0.0)
        else:
            return gr.update(value=30), gr.update(value=7.5)

    # Re-seed the steps/guidance sliders whenever the model changes.
    model_dd.change(_on_change_model, inputs=model_dd, outputs=[steps, guidance])

    # Main action: input order must match generate()'s parameter order.
    run_btn.click(
        fn=generate,
        inputs=[prompt, negative, model_dd, steps, guidance, width, height, seed, batch_size, disable_safety],
        outputs=[gallery],
        api_name="generate",
        scroll_to_output=True,
        concurrency_limit=2,
    )

if __name__ == "__main__":
    # On Spaces, simply running the file starts the app; debug=True gives
    # clearer stack traces in the container logs.
    port = int(os.environ.get("PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=port, debug=True)