achase25 commited on
Commit
45fb6b9
·
verified ·
1 Parent(s): 684e6c6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +205 -0
app.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ # Text-to-Image Space using Diffusers + Gradio
3
+ # Works on CPU (slow) and GPU (recommended). Choose a model in the UI.
4
+
5
+ import os
6
+ import math
7
+ import torch
8
+ import gradio as gr
9
+ from typing import List, Optional
10
+ from PIL import Image
11
+ from diffusers import (
12
+ DiffusionPipeline,
13
+ StableDiffusionPipeline,
14
+ AutoPipelineForText2Image,
15
+ )
16
+
17
# --------- Config ---------
# Dropdown label (shown in the UI) -> Hugging Face model repo id.
MODEL_CHOICES = {
    # Solid baseline, license-free to use after accepting on HF if required.
    "Stable Diffusion 1.5 (runwayml/stable-diffusion-v1-5)": "runwayml/stable-diffusion-v1-5",
    # Very fast for prototyping; outputs can be less detailed. Best with GPU.
    "SDXL Turbo (stabilityai/sdxl-turbo)": "stabilityai/sdxl-turbo",
}

# Label pre-selected in the model dropdown when the app loads.
DEFAULT_MODEL_LABEL = "Stable Diffusion 1.5 (runwayml/stable-diffusion-v1-5)"

# Disable safety checker by default (your responsibility). Toggle in UI.
DISABLE_SAFETY_DEFAULT = True
29
+
30
+ # --------- Runtime helpers ---------
31
def get_device() -> str:
    """Return the best available torch device: "cuda", "mps", or "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    # Spaces don't use Apple MPS; kept for completeness when running locally.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"
38
+
39
def nearest_multiple_of_8(x: int) -> int:
    """Round *x* to a multiple of 8 (Python's round -> banker's rounding at .5),
    with a floor of 64 — diffusion pipelines need dimensions divisible by 8."""
    return 64 if x < 64 else int(round(x / 8) * 8)
43
+
44
# Cache pipelines per model to avoid reloading on each call.
# Keyed by (model_id, device, fp16); entries live for the process lifetime.
_PIPE_CACHE: dict = {}
46
+
47
def load_pipe(model_id: str, device: str, fp16: bool) -> DiffusionPipeline:
    """Load (or return a cached) text-to-image pipeline for *model_id*.

    Args:
        model_id: Hugging Face model repo id.
        device: Torch device string ("cuda", "mps", or "cpu").
        fp16: Request float16 weights (only honored when device == "cuda").

    Returns:
        A ``DiffusionPipeline`` moved to *device*, cached per
        (model_id, device, fp16) so repeated calls don't re-download weights.
    """
    key = (model_id, device, fp16)
    if key in _PIPE_CACHE:
        return _PIPE_CACHE[key]

    # fp16 is only safe/useful on CUDA; everything else runs in float32.
    dtype = torch.float16 if (fp16 and device == "cuda") else torch.float32

    # AutoPipeline works for many models; we fall back to SD pipeline for v1-5.
    try:
        # BUG FIX: the imported class is AutoPipelineForText2Image. The old
        # name "AutoPipelineForTextToImage" raised NameError, which the broad
        # except swallowed — silently forcing every model (including SDXL
        # Turbo) through the legacy StableDiffusionPipeline fallback.
        pipe = AutoPipelineForText2Image.from_pretrained(
            model_id,
            torch_dtype=dtype,
            use_safetensors=True,
            trust_remote_code=False,
        )
    except Exception:
        # Legacy fallback for SD 1.5-style repos AutoPipeline can't resolve.
        pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=dtype,
            use_safetensors=True,
        )

    # Send to device
    pipe = pipe.to(device)

    # Try memory-efficient attention if available (xformers is optional).
    if device == "cuda":
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            pass

    _PIPE_CACHE[key] = pipe
    return pipe
82
+
83
+ # --------- Inference ---------
84
def generate(
    prompt: str,
    negative: str,
    model_label: str,
    steps: int,
    guidance: float,
    width: int,
    height: int,
    seed: Optional[int],
    batch_size: int,
    disable_safety: bool,
) -> List[Image.Image]:
    """Run text-to-image inference and return a batch of PIL images.

    Args:
        prompt: Positive prompt; must be non-empty after stripping.
        negative: Negative prompt; empty string disables it.
        model_label: Key into MODEL_CHOICES (UI dropdown value).
        steps: Inference steps (clamped to 1-6 for SDXL Turbo).
        guidance: CFG scale (forced to 0.0 for SDXL Turbo).
        width, height: Requested size; snapped to multiples of 8.
        seed: Optional seed; arrives as a string from the UI textbox.
        batch_size: Number of images (clamped to 1-8).
        disable_safety: If True, bypass the pipeline's safety checker.

    Raises:
        gr.Error: If the prompt is empty.
    """
    prompt = (prompt or "").strip()
    if not prompt:
        raise gr.Error("Enter a non-empty prompt.")

    model_id = MODEL_CHOICES[model_label]
    device = get_device()

    # SDXL Turbo ignores CFG and uses very low steps; keep sensible defaults
    is_turbo = "sdxl-turbo" in model_id.lower()
    if is_turbo:
        steps = max(1, min(steps, 6))  # turbo is usually 1–6 steps
        guidance = 0.0  # turbo uses guidance-free sampling; CFG does nothing

    width = nearest_multiple_of_8(width)
    height = nearest_multiple_of_8(height)
    batch_size = max(1, min(batch_size, 8))

    pipe = load_pipe(model_id, device, fp16=(device == "cuda"))

    # Safety checker toggle.
    # BUG FIX: pipelines are cached across calls, so assigning None used to
    # disable the checker permanently — unticking the box later could never
    # restore it. Stash the original checker once so it can be re-attached.
    if hasattr(pipe, "safety_checker"):
        if not hasattr(pipe, "_original_safety_checker"):
            pipe._original_safety_checker = pipe.safety_checker
        pipe.safety_checker = None if disable_safety else pipe._original_safety_checker

    # Determinism: the UI textbox delivers a string; parse it best-effort.
    generator = None
    if seed is not None and seed != "":
        try:
            seed = int(seed)
        except ValueError:
            seed = None  # non-integer input -> fall back to random seed
    if seed is not None:
        # torch has no MPS generator; a CPU generator serves mps and cpu alike.
        gen_device = "cuda" if device == "cuda" else "cpu"
        generator = torch.Generator(device=gen_device).manual_seed(seed)

    prompts = [prompt] * batch_size
    negative_prompts = [negative] * batch_size if negative else None

    # Run (autocast is a no-op when enabled=False, i.e. off-GPU)
    with torch.autocast("cuda", enabled=(device == "cuda")):
        out = pipe(
            prompt=prompts,
            negative_prompt=negative_prompts,
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
            width=int(width),
            height=int(height),
            generator=generator,
        )

    return out.images
151
+
152
+ # --------- UI ---------
153
# Gradio Blocks UI. Component creation order defines the layout; `demo` is the
# app object launched below (Spaces also auto-launches it on import).
with gr.Blocks(css="footer {visibility: hidden}") as demo:
    gr.Markdown(
        """
        # Text-to-Image (Diffusers)
        - **Models:** SD 1.5 and SDXL Turbo
        - **Tip:** SD 1.5 = better detail on CPU; Turbo = very fast on GPU, fewer steps.
        """
    )

    # Row 1: model selection plus the two sampler knobs it influences.
    with gr.Row():
        model_dd = gr.Dropdown(
            label="Model",
            choices=list(MODEL_CHOICES.keys()),
            value=DEFAULT_MODEL_LABEL,
        )
        steps = gr.Slider(1, 75, value=30, step=1, label="Steps")
        guidance = gr.Slider(0.0, 15.0, value=7.5, step=0.1, label="Guidance (CFG)")

    # Row 2: output geometry and batch size (step=8 keeps sizes pipeline-safe).
    with gr.Row():
        width = gr.Slider(256, 1024, value=768, step=8, label="Width (multiple of 8)")
        height = gr.Slider(256, 1024, value=768, step=8, label="Height (multiple of 8)")
        batch_size = gr.Slider(1, 4, value=1, step=1, label="Batch size")

    prompt = gr.Textbox(label="Prompt", lines=2, placeholder="a cozy cabin at twilight beside a lake, cinematic lighting")
    negative = gr.Textbox(label="Negative Prompt", lines=1, placeholder="blurry, low quality, distorted")
    # Row 3: reproducibility seed (free-text, parsed in generate) and safety toggle.
    with gr.Row():
        seed = gr.Textbox(label="Seed (optional integer)", value="")
        disable_safety = gr.Checkbox(label="Disable safety checker (you are responsible)", value=DISABLE_SAFETY_DEFAULT)

    run_btn = gr.Button("Generate", variant="primary")
    gallery = gr.Gallery(label="Results", columns=2, height=512, preview=True)

    def _on_change_model(label):
        # If Turbo selected, nudge UI to sane defaults
        # (Turbo wants few steps and no CFG; SD 1.5 wants the classic 30 / 7.5).
        if "Turbo" in label:
            return gr.update(value=4), gr.update(value=0.0)
        else:
            return gr.update(value=30), gr.update(value=7.5)

    # Sync the sliders whenever the model dropdown changes.
    model_dd.change(_on_change_model, inputs=model_dd, outputs=[steps, guidance])

    # Main action: wire the button to generate(); input order must match the
    # generate() signature exactly.
    run_btn.click(
        fn=generate,
        inputs=[prompt, negative, model_dd, steps, guidance, width, height, seed, batch_size, disable_safety],
        outputs=[gallery],
        api_name="generate",
        scroll_to_output=True,
        concurrency_limit=2,
    )
202
+
203
if __name__ == "__main__":
    # In Spaces, just running the file starts the app. Debug on for clearer stack traces.
    # PORT env var (if set by the host) overrides the default Gradio port 7860;
    # 0.0.0.0 binds all interfaces so the container's port mapping works.
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)), debug=True)