Theloomvale commited on
Commit
0371a09
·
verified ·
1 Parent(s): fab8b33

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +307 -73
app.py CHANGED
@@ -1,100 +1,334 @@
 
 
1
  import os
2
- label="Model",
3
- choices=list(DEFAULT_MODELS.keys()),
4
- value="Stable Diffusion 1.5 (fastest)",
5
- )
6
- model_id_state = gr.State(DEFAULT_MODELS["Stable Diffusion 1.5 (fastest)"])
7
-
8
-
9
- script = gr.Textbox(
10
- label="Prompt or Multi-Scene Script",
11
- lines=6,
12
- placeholder=(
13
- "Optional ambience on top...\n\n"
14
- "Scene 1: A cozy studio filled with soft morning light\n"
15
- "Scene 2: A minimalist desk with a steaming cup of tea\n"
16
- "Scene 3: ..."
17
- ),
18
  )
 
 
19
 
 
 
 
 
 
 
 
20
 
21
- negative = gr.Textbox(
22
- label="Negative Prompt (optional)",
23
- placeholder="blurry, low quality, watermark, text, nsfw",
24
- value="blurry, low quality, watermark, text",
25
- )
26
 
 
27
 
28
- w = gr.Slider(384, 1024, value=512, step=8, label="Width")
29
- h = gr.Slider(512, 1280, value=768, step=8, label="Height")
30
- steps = gr.Slider(10, 60, value=28, step=1, label="Steps")
31
- guidance = gr.Slider(1.0, 12.0, value=7.0, step=0.1, label="Guidance Scale")
32
- seed = gr.Number(value=-1, label="Seed (-1 = random)")
33
 
34
 
35
- # Decide if saving to repo is available
36
- can_save = bool(HF_TOKEN and SPACE_ID)
37
- save_to_repo = gr.Checkbox(
38
- label=f"Save generated images to this Space repo ({SPACE_ID})",
39
- value=can_save,
40
- interactive=can_save,
41
- visible=True,
42
- )
43
 
 
 
 
 
 
44
 
45
- btn = gr.Button("Generate Images", variant="primary")
46
- gallery = gr.Gallery(label="Images", columns=5, rows=1, height="auto")
47
- status = gr.Markdown(visible=True)
 
48
 
 
 
49
 
50
- def _sync_model_choice(choice):
51
- mid = DEFAULT_MODELS[choice]
52
- base_w, base_h = DEFAULT_W_H[mid]
53
- return mid, gr.update(value=base_w), gr.update(value=base_h)
54
 
 
 
 
 
 
 
 
 
 
55
 
56
- model.change(_sync_model_choice, inputs=model, outputs=[model_id_state, w, h])
 
 
57
 
 
 
 
 
 
 
 
58
 
59
- async def _on_click(script_text, negative_prompt, _model_choice, _model_id, width, height, steps_, guidance_, seed_, save_flag):
60
- imgs = await generate_per_scene(
61
- script_text=script_text,
62
- negative_prompt=negative_prompt,
63
- model_id=_model_id,
64
- width=int(width),
65
- height=int(height),
66
- steps=int(steps_),
67
- guidance=float(guidance_),
68
- seed=int(seed_),
69
- )
70
- msg = f"✅ Generated {len(imgs)} image(s)."
71
 
 
 
72
 
73
- links = []
74
- if save_flag:
75
- try:
76
- links = _save_images_to_repo(imgs)
77
- if links:
78
- msg += "\\nSaved: " + ", ".join(links)
79
- except Exception as e:
80
- print("[save_error]", e)
81
- msg += "\\n⚠️ Save failed (see logs)."
82
 
83
 
84
- return imgs, msg
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
 
87
- btn.click(
88
- _on_click,
89
- inputs=[script, negative, model, model_id_state, w, h, steps, guidance, seed, save_to_repo],
90
- outputs=[gallery, status],
91
- concurrency_limit=1,
92
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
 
 
 
 
 
 
94
 
 
 
95
 
 
96
 
97
- # Queued interface is important for CPU workloads
98
  if __name__ == "__main__":
99
- demo.queue()
100
- demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
 
1
+ # app.py
2
+
3
  import os
4
+ import io
5
+ import re
6
+ import random
7
+ import asyncio
8
+ from typing import List, Optional, Tuple
9
+ from datetime import datetime
10
+
11
+ import torch
12
+ import gradio as gr
13
+ from diffusers import (
14
+ StableDiffusionPipeline,
15
+ StableDiffusionXLPipeline,
 
 
 
 
16
  )
17
+ from huggingface_hub import HfApi
18
+ from PIL import Image
19
 
20
+ # ----------------------
21
+ # Constants & Utilities
22
+ # ----------------------
23
+ DEFAULT_MODELS = {
24
+ "Stable Diffusion 1.5 (fastest)": "runwayml/stable-diffusion-v1-5",
25
+ "Stable Diffusion XL Base 1.0": "stabilityai/stable-diffusion-xl-base-1.0",
26
+ }
27
 
28
+ # CPU-friendly defaults; auto-updated on model switch.
29
+ DEFAULT_W_H = {
30
+ "runwayml/stable-diffusion-v1-5": (512, 768),
31
+ "stabilityai/stable-diffusion-xl-base-1.0": (768, 1024),
32
+ }
33
 
34
+ SCENE_HEADER = re.compile(r"^\s*Scene\s*\d+\s*[:\-–]", re.IGNORECASE | re.MULTILINE)
35
 
36
+ PIPELINES = {}
37
+ API = HfApi()
38
+ HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
39
+ SPACE_ID = os.environ.get("SPACE_ID") or os.environ.get("SPACE_REPO")
 
40
 
41
 
42
+ def get_pipeline(model_id: str):
43
+ """Load & cache a pipeline for CPU usage."""
44
+ if model_id in PIPELINES:
45
+ return PIPELINES[model_id]
 
 
 
 
46
 
47
+ dtype = torch.float32 # CPU-safe
48
+ if "stable-diffusion-xl" in model_id:
49
+ pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=dtype)
50
+ else:
51
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)
52
 
53
+ pipe = pipe.to("cpu")
54
+ pipe.enable_attention_slicing()
55
+ pipe.enable_vae_slicing()
56
+ pipe.safety_checker = None # assuming safe usage/content policy is handled upstream
57
 
58
+ PIPELINES[model_id] = pipe
59
+ return pipe
60
 
 
 
 
 
61
 
62
+ def split_into_scene_prompts(text: str) -> List[str]:
63
+ """Split input script into up to 5 scene prompts.
64
+ - If no explicit Scene headers are found, repeat the whole text to make 5 prompts.
65
+ - If fewer than 5 scenes, pad with the last scene.
66
+ - If more than 5, truncate to 5.
67
+ """
68
+ text = (text or "").strip()
69
+ if not text:
70
+ return []
71
 
72
+ headers = list(SCENE_HEADER.finditer(text))
73
+ if not headers:
74
+ return [text] * 5
75
 
76
+ ambience = text[: headers[0].start()].strip()
77
+ blocks = []
78
+ for i, m in enumerate(headers):
79
+ start = m.start()
80
+ end = headers[i + 1].start() if i + 1 < len(headers) else len(text)
81
+ block = text[start:end].strip()
82
+ blocks.append(block)
83
 
84
+ if len(blocks) < 5 and blocks:
85
+ blocks += [blocks[-1]] * (5 - len(blocks))
86
+ elif len(blocks) > 5:
87
+ blocks = blocks[:5]
 
 
 
 
 
 
 
 
88
 
89
+ if ambience:
90
+ blocks = [f"{ambience}\n\n{b}" for b in blocks]
91
 
92
+ return blocks
 
 
 
 
 
 
 
 
93
 
94
 
95
+ def clamp_size(model_id: str, width: int, height: int) -> Tuple[int, int]:
96
+ """Keep sizes reasonable for CPU and aligned to multiples of 8."""
97
+ w, h = int(width), int(height)
98
+ w -= w % 8
99
+ h -= h % 8
100
+ if "stable-diffusion-xl" in model_id:
101
+ # SDXL works best with longer edge >= ~768; constrain for CPU
102
+ w = max(640, min(w, 1152))
103
+ h = max(640, min(h, 1152))
104
+ else:
105
+ # SD 1.5 sweet spot; keep safe caps for CPU
106
+ w = max(384, min(w, 896))
107
+ h = max(384, min(h, 1152))
108
+ return w, h
109
 
110
 
111
+ def _seed_everything(seed: Optional[int]):
112
+ if seed is None or seed < 0:
113
+ seed = random.randint(0, 2**32 - 1)
114
+ generator = torch.Generator(device="cpu").manual_seed(seed)
115
+ return seed, generator
116
+
117
+
118
+ def _generate_one(
119
+ prompt: str,
120
+ negative_prompt: str,
121
+ model_id: str,
122
+ width: int,
123
+ height: int,
124
+ steps: int,
125
+ guidance: float,
126
+ seed: int,
127
+ ) -> Image.Image:
128
+ seed, generator = _seed_everything(seed)
129
+ pipe = get_pipeline(model_id)
130
+ with torch.inference_mode():
131
+ image = pipe(
132
+ prompt=prompt,
133
+ negative_prompt=negative_prompt or None,
134
+ width=width,
135
+ height=height,
136
+ num_inference_steps=steps,
137
+ guidance_scale=guidance,
138
+ generator=generator,
139
+ ).images[0]
140
+ return image
141
+
142
+
143
+ async def _generate_one_async(**kwargs) -> Image.Image:
144
+ return await asyncio.to_thread(_generate_one, **kwargs)
145
+
146
+
147
+ async def generate_per_scene(
148
+ script_text: str,
149
+ negative_prompt: str,
150
+ model_id: str,
151
+ width: int,
152
+ height: int,
153
+ steps: int,
154
+ guidance: float,
155
+ seed: int,
156
+ ):
157
+ """Sequential generation (CPU-friendly) with progress feedback."""
158
+ prompts = split_into_scene_prompts(script_text)
159
+ if not prompts:
160
+ raise gr.Error("Please enter a prompt or scene script.")
161
+
162
+ images: List[Image.Image] = []
163
+ total = len(prompts)
164
+ progress = gr.Progress(track_tqdm=True)
165
+
166
+ for i, p in enumerate(prompts, start=1):
167
+ progress(i / total, desc=f"Generating scene {i}/{total}")
168
+ try:
169
+ img = await _generate_one_async(
170
+ prompt=p,
171
+ negative_prompt=negative_prompt,
172
+ model_id=model_id,
173
+ width=width,
174
+ height=height,
175
+ steps=steps,
176
+ guidance=guidance,
177
+ seed=seed + (i - 1) if seed >= 0 else seed,
178
+ )
179
+ except Exception as e:
180
+ print(f"[error] scene {i} failed:", e)
181
+ img = Image.new("RGB", (width, height), color=(220, 220, 220))
182
+ images.append(img)
183
+
184
+ return images
185
+
186
+
187
+ def _save_images_to_repo(imgs: List[Image.Image], subdir: str = "outputs") -> List[str]:
188
+ """Save to the Space repo if HF_TOKEN & SPACE_ID are set. Returns repo paths."""
189
+ if not (HF_TOKEN and SPACE_ID):
190
+ return []
191
+ ts = datetime.utcnow().strftime("%Y%m%d-%H%M%S")
192
+ paths = []
193
+ for idx, img in enumerate(imgs, start=1):
194
+ buf = io.BytesIO()
195
+ img.save(buf, format="PNG")
196
+ buf.seek(0)
197
+ remote_path = f"{subdir}/{ts}_scene{idx}.png"
198
+ API.upload_file(
199
+ path_or_fileobj=buf,
200
+ path_in_repo=remote_path,
201
+ repo_id=SPACE_ID,
202
+ repo_type="space",
203
+ )
204
+ paths.append(remote_path)
205
+ return paths
206
+
207
+
208
+ def validate_inputs(script_text: str, steps: int, guidance: float):
209
+ if not script_text or not script_text.strip():
210
+ raise gr.Error("Please enter a prompt or scene script.")
211
+ if not (10 <= int(steps) <= 60):
212
+ raise gr.Error("Steps must be between 10 and 60.")
213
+ if not (1.0 <= float(guidance) <= 12.0):
214
+ raise gr.Error("Guidance must be between 1.0 and 12.0.")
215
+
216
+
217
+ with gr.Blocks(title="Loomvale Image Lab — CPU") as demo:
218
+ gr.Markdown("""
219
+ # Loomvale Image Lab — CPU
220
+ Enter a single prompt or a multi-scene script using headings like **Scene 1: ...**, **Scene 2: ...**.
221
+ The app will generate up to **5** images (padding/truncating as needed).
222
+ """)
223
+
224
+ with gr.Row():
225
+ model = gr.Dropdown(
226
+ label="Model",
227
+ choices=list(DEFAULT_MODELS.keys()),
228
+ value="Stable Diffusion 1.5 (fastest)",
229
+ )
230
+ model_id_state = gr.State(DEFAULT_MODELS["Stable Diffusion 1.5 (fastest)"])
231
+
232
+ script = gr.Textbox(
233
+ label="Prompt or Multi-Scene Script",
234
+ lines=6,
235
+ placeholder=(
236
+ "Optional ambience on top...\n\n"
237
+ "Scene 1: A cozy studio filled with soft morning light\n"
238
+ "Scene 2: A minimalist desk with a steaming cup of tea\n"
239
+ "Scene 3: ..."
240
+ ),
241
+ )
242
+
243
+ negative = gr.Textbox(
244
+ label="Negative Prompt (optional)",
245
+ placeholder="blurry, low quality, watermark, text, nsfw",
246
+ value="blurry, low quality, watermark, text, worst quality, lowres",
247
+ )
248
+
249
+ w = gr.Slider(384, 1024, value=512, step=8, label="Width")
250
+ h = gr.Slider(512, 1280, value=768, step=8, label="Height")
251
+ steps = gr.Slider(10, 60, value=28, step=1, label="Steps")
252
+ guidance = gr.Slider(1.0, 12.0, value=7.0, step=0.1, label="Guidance Scale")
253
+ seed = gr.Number(value=-1, label="Seed (-1 = random)")
254
+
255
+ can_save = bool(HF_TOKEN and SPACE_ID)
256
+ save_to_repo = gr.Checkbox(
257
+ label=f"Save generated images to this Space repo ({SPACE_ID})",
258
+ value=can_save,
259
+ interactive=can_save,
260
+ visible=True,
261
+ )
262
+
263
+ btn = gr.Button("Generate Images", variant="primary")
264
+ btn_clear = gr.Button("Clear")
265
+ gallery = gr.Gallery(label="Images", columns=5, rows=1, height="auto", allow_preview=True)
266
+ gallery.style(grid=5, preview=True, object_fit="contain") # keep layout tidy
267
+ status = gr.Markdown(visible=True)
268
+
269
+ # Examples for quick testing
270
+ gr.Examples(
271
+ examples=[
272
+ ["Ambient: gentle morning light\n\nScene 1: pastel living room\nScene 2: sunlight on linen curtains\nScene 3: ceramic mug on wooden table"],
273
+ ["Scene 1: cyberpunk alley, neon reflections\nScene 2: rooftop garden at dusk\nScene 3: rainy crosswalk with umbrellas"],
274
+ ],
275
+ inputs=[script],
276
+ label="Examples",
277
+ )
278
+
279
+ def _sync_model_choice(choice):
280
+ mid = DEFAULT_MODELS[choice]
281
+ base_w, base_h = DEFAULT_W_H[mid]
282
+ return mid, gr.update(value=base_w), gr.update(value=base_h)
283
+
284
+ model.change(_sync_model_choice, inputs=model, outputs=[model_id_state, w, h])
285
+
286
+ async def _on_click(
287
+ script_text, negative_prompt, _model_choice, _model_id, width, height, steps_, guidance_, seed_, save_flag
288
+ ):
289
+ validate_inputs(script_text, steps_, guidance_)
290
+ w_clamped, h_clamped = clamp_size(_model_id, int(width), int(height))
291
+
292
+ imgs = await generate_per_scene(
293
+ script_text=script_text,
294
+ negative_prompt=negative_prompt,
295
+ model_id=_model_id,
296
+ width=w_clamped,
297
+ height=h_clamped,
298
+ steps=int(steps_),
299
+ guidance=float(guidance_),
300
+ seed=int(seed_),
301
+ )
302
+
303
+ msg = f"✅ Generated {len(imgs)} image(s) at {w_clamped}×{h_clamped}."
304
+
305
+ links = []
306
+ if save_flag:
307
+ try:
308
+ links = _save_images_to_repo(imgs)
309
+ if links:
310
+ saved_list = "\n".join(f"- {p}" for p in links)
311
+ msg += f"\nSaved:\n{saved_list}"
312
+ else:
313
+ msg += "\nℹ️ Skipped saving (token/repo not configured)."
314
+ except Exception as e:
315
+ print("[save_error]", e)
316
+ msg += "\n⚠️ Save failed (see logs)."
317
+
318
+ return imgs, msg
319
 
320
+ btn.click(
321
+ _on_click,
322
+ inputs=[script, negative, model, model_id_state, w, h, steps, guidance, seed, save_to_repo],
323
+ outputs=[gallery, status],
324
+ concurrency_limit=1,
325
+ )
326
 
327
+ def _on_clear():
328
+ return None, ""
329
 
330
+ btn_clear.click(_on_clear, outputs=[gallery, status])
331
 
 
332
  if __name__ == "__main__":
333
+ demo.queue()
334
+ demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))