ICGenAIShare07 commited on
Commit
3d2b58b
·
verified ·
1 Parent(s): ecc0c37

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +435 -0
app.py ADDED
@@ -0,0 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dataclasses import dataclass
3
+ from PIL import Image
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import gradio as gr
8
+ import torch
9
+
10
+ import spaces # type: ignore
11
+
12
+ from huggingface_hub import hf_hub_download
13
+ from safetensors.torch import load_file
14
+
15
+ from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
16
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
17
+ from diffusers.models.controlnets.controlnet import ControlNetModel
18
+ from diffusers.pipelines.controlnet.pipeline_controlnet import StableDiffusionControlNetPipeline
19
+ from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
20
+ from transformers import CLIPTextModel, CLIPTokenizer
21
+
22
+ BIG_CSS = """
23
+ /* Global bump */
24
+ .gradio-container {
25
+ font-size: 18px !important;
26
+ }
27
+
28
+ /* Force most UI text bigger */
29
+ .gradio-container * {
30
+ font-size: 18px !important;
31
+ }
32
+
33
+ /* Keep markdown headings bigger */
34
+ .gradio-container h1 { font-size: 28px !important; }
35
+ .gradio-container h2 { font-size: 24px !important; }
36
+ .gradio-container h3 { font-size: 20px !important; }
37
+
38
+ /* Slightly smaller helper/info text if you want */
39
+ .gradio-container .info,
40
+ .gradio-container .prose p,
41
+ .gradio-container .prose li {
42
+ font-size: 16px !important;
43
+ line-height: 1.35 !important;
44
+ }
45
+ """
46
+
47
+ # -----------------------------
48
+ # Pipeline builder
49
+ # -----------------------------
50
+ def build_controlnet_pipe(
51
+ base_model_name: str,
52
+ controlnet: ControlNetModel,
53
+ vae: AutoencoderKL,
54
+ unet: UNet2DConditionModel,
55
+ text_encoder: CLIPTextModel,
56
+ tokenizer: CLIPTokenizer,
57
+ device: torch.device,
58
+ weight_dtype: torch.dtype,
59
+ use_unipc: bool = True,
60
+ ) -> StableDiffusionControlNetPipeline:
61
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
62
+ base_model_name,
63
+ vae=vae,
64
+ text_encoder=text_encoder,
65
+ tokenizer=tokenizer,
66
+ unet=unet,
67
+ controlnet=controlnet,
68
+ safety_checker=None,
69
+ torch_dtype=weight_dtype,
70
+ )
71
+ if use_unipc:
72
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
73
+ pipe = pipe.to(device)
74
+ pipe.set_progress_bar_config(disable=True)
75
+ return pipe
76
+
77
+ @dataclass
78
+ class CannyCFG:
79
+ use_clahe: bool = True
80
+ clahe_clip: float = 2.0
81
+ clahe_grid: int = 8
82
+
83
+ gaussian_ksize: int = 5
84
+ gaussian_sigma: float = 1.2
85
+
86
+ high_pct: float = 90.0 # higher -> fewer edges (stricter)
87
+ low_ratio: float = 0.4 # low = low_ratio * high
88
+
89
+ aperture_size: int = 3
90
+ l2_gradient: bool = True
91
+
92
+
93
+ def canny_percentile(pil_img: Image.Image, cfg: CannyCFG) -> Image.Image:
94
+ gray = np.array(pil_img.convert("L"), dtype=np.uint8)
95
+
96
+ if cfg.use_clahe:
97
+ clahe = cv2.createCLAHE(
98
+ clipLimit=float(cfg.clahe_clip),
99
+ tileGridSize=(int(cfg.clahe_grid), int(cfg.clahe_grid)),
100
+ )
101
+ gray = clahe.apply(gray)
102
+
103
+ k = int(cfg.gaussian_ksize) | 1 # ensure odd
104
+ blur = cv2.GaussianBlur(gray, (k, k), float(cfg.gaussian_sigma))
105
+
106
+ gx = cv2.Sobel(blur, cv2.CV_32F, 1, 0, ksize=3)
107
+ gy = cv2.Sobel(blur, cv2.CV_32F, 0, 1, ksize=3)
108
+ mag = cv2.magnitude(gx, gy)
109
+
110
+ high = float(np.percentile(mag, float(cfg.high_pct)))
111
+ low = float(cfg.low_ratio) * high
112
+ if high <= low:
113
+ high = low + 1.0
114
+
115
+ ap = int(cfg.aperture_size)
116
+ if ap not in (3, 5, 7):
117
+ ap = 3
118
+
119
+ edges = cv2.Canny(
120
+ blur,
121
+ threshold1=low,
122
+ threshold2=high,
123
+ apertureSize=ap,
124
+ L2gradient=bool(cfg.l2_gradient),
125
+ )
126
+ return Image.fromarray(edges, mode="L")
127
+
128
+
129
+ # -----------------------------
130
+ # Config
131
+ # -----------------------------
132
+ BASE_MODEL = "sd-legacy/stable-diffusion-v1-5"
133
+ WEIGHTS_REPO = "mvp-lab/ControlNet_Weight"
134
+ WEIGHTS_FILENAME = "diffusion_pytorch_model_1.safetensors"
135
+
136
+ LOCAL_WEIGHTS = os.getenv(
137
+ "CONTROLNET_WEIGHTS",
138
+ "/home/nik/ImperialWork/GenerativeAi/sd15-controlnet-trainer/controlnet_laion/final/diffusion_pytorch_model.safetensors",
139
+ )
140
+ if os.path.isfile(LOCAL_WEIGHTS):
141
+ CONTROLNET_PATH = LOCAL_WEIGHTS
142
+ else:
143
+ CONTROLNET_PATH = hf_hub_download(repo_id=WEIGHTS_REPO, filename=WEIGHTS_FILENAME, repo_type="model")
144
+
145
+ DTYPE = torch.float32
146
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
147
+
148
+
149
+ # -----------------------------
150
+ # Model load (once)
151
+ # -----------------------------
152
+ vae = AutoencoderKL.from_pretrained(BASE_MODEL, subfolder="vae", torch_dtype=DTYPE)
153
+ unet = UNet2DConditionModel.from_pretrained(BASE_MODEL, subfolder="unet", torch_dtype=DTYPE)
154
+ tokenizer = CLIPTokenizer.from_pretrained(BASE_MODEL, subfolder="tokenizer")
155
+ text_encoder = CLIPTextModel.from_pretrained(BASE_MODEL, subfolder="text_encoder", torch_dtype=DTYPE)
156
+
157
+ vae.requires_grad_(False)
158
+ unet.requires_grad_(False)
159
+ text_encoder.requires_grad_(False)
160
+
161
+ controlnet = ControlNetModel.from_unet(unet, conditioning_channels=3)
162
+ state = load_file(CONTROLNET_PATH)
163
+ missing, unexpected = controlnet.load_state_dict(state, strict=False)
164
+
165
+ pipe = build_controlnet_pipe(
166
+ base_model_name=BASE_MODEL,
167
+ controlnet=controlnet,
168
+ vae=vae,
169
+ unet=unet,
170
+ text_encoder=text_encoder,
171
+ tokenizer=tokenizer,
172
+ device=DEVICE,
173
+ weight_dtype=DTYPE,
174
+ use_unipc=True,
175
+ )
176
+
177
+
178
+ # -----------------------------
179
+ # Helpers: fixed resize policy (longest side = 512, keep aspect, divisible by 8)
180
+ # -----------------------------
181
+ def round_down_to_multiple(x: int, m: int = 8) -> int:
182
+ return max(m, (x // m) * m)
183
+
184
+ def resize_longest_side_div8(img: Image.Image, longest: int = 512) -> tuple[Image.Image, int, int]:
185
+ w, h = img.size
186
+ if w <= 0 or h <= 0:
187
+ raise ValueError("Invalid image size")
188
+
189
+ scale = float(longest) / float(max(w, h))
190
+ tw = int(round(w * scale))
191
+ th = int(round(h * scale))
192
+
193
+ tw = round_down_to_multiple(tw, 8)
194
+ th = round_down_to_multiple(th, 8)
195
+
196
+ tw = max(8, tw)
197
+ th = max(8, th)
198
+
199
+ resized = img.resize((tw, th), resample=Image.BICUBIC) # type: ignore
200
+ return resized, tw, th
201
+
202
+ def compute_canny_rgb(img_rgb_resized: Image.Image, use_clahe: bool, edge_amount: float, smoothing: float) -> Image.Image:
203
+ high_pct = 95.0 - 20.0 * float(edge_amount) # 0 => 95 (few), 1 => 75 (many)
204
+ high_pct = float(np.clip(high_pct, 70.0, 99.0))
205
+
206
+ gaussian_sigma = 0.6 + 2.2 * float(smoothing) # 0 => 0.6, 1 => 2.8
207
+
208
+ cfg = CannyCFG(
209
+ use_clahe=bool(use_clahe),
210
+ clahe_clip=2.0,
211
+ clahe_grid=8,
212
+ gaussian_ksize=5,
213
+ gaussian_sigma=float(gaussian_sigma),
214
+ high_pct=float(high_pct),
215
+ low_ratio=0.4,
216
+ aperture_size=3,
217
+ l2_gradient=True,
218
+ )
219
+ edges_l = canny_percentile(img_rgb_resized, cfg)
220
+ return edges_l.convert("RGB")
221
+
222
+ def update_canny_preview(input_image, use_clahe, edge_amount, smoothing):
223
+ if input_image is None:
224
+ return None, None, 512, 512
225
+
226
+ if not isinstance(input_image, Image.Image):
227
+ input_image = Image.fromarray(input_image)
228
+
229
+ img_rgb0 = input_image.convert("RGB")
230
+ img_rgb, width, height = resize_longest_side_div8(img_rgb0, longest=512)
231
+
232
+ canny = compute_canny_rgb(
233
+ img_rgb,
234
+ use_clahe=use_clahe,
235
+ edge_amount=float(edge_amount),
236
+ smoothing=float(smoothing),
237
+ )
238
+ return canny, canny, width, height
239
+
240
+
241
+ @spaces.GPU
242
+ @torch.inference_mode()
243
+ def generate_from_canny(
244
+ canny: Image.Image,
245
+ width: int,
246
+ height: int,
247
+ prompt: str,
248
+ negative_prompt: str,
249
+ guidance_scale: float,
250
+ num_inference_steps: int,
251
+ num_images: int,
252
+ controlnet_conditioning_scale: float,
253
+ ):
254
+ if canny is None:
255
+ raise gr.Error("Canny conditioning image missing. Upload an image first.")
256
+ if int(num_images) < 1:
257
+ raise gr.Error("num_images must be >= 1")
258
+
259
+ gens = [torch.Generator(device=DEVICE).manual_seed(i) for i in range(int(num_images))]
260
+
261
+ imgs = pipe(
262
+ prompt=[prompt] * int(num_images),
263
+ negative_prompt=[negative_prompt] * int(num_images),
264
+ image=[canny] * int(num_images),
265
+ num_inference_steps=int(num_inference_steps),
266
+ guidance_scale=float(guidance_scale),
267
+ height=int(height),
268
+ width=int(width),
269
+ generator=gens,
270
+ controlnet_conditioning_scale=float(controlnet_conditioning_scale),
271
+ ).images # type: ignore
272
+
273
+ first = imgs[0] if imgs else None
274
+ return first, imgs
275
+
276
+ def next_image(images, idx):
277
+ if not images:
278
+ return None, 0, "0 / 0"
279
+ idx = (int(idx) + 1) % len(images)
280
+ return images[idx], idx, f"{idx + 1} / {len(images)}"
281
+
282
+ def prev_image(images, idx):
283
+ if not images:
284
+ return None, 0, "0 / 0"
285
+ idx = (int(idx) - 1) % len(images)
286
+ return images[idx], idx, f"{idx + 1} / {len(images)}"
287
+
288
+
289
+ # -----------------------------
290
+ # UI
291
+ # -----------------------------
292
+ IMG_H = 360 # uniform-ish size for both preview boxes
293
+
294
+ with gr.Blocks(css=BIG_CSS) as demo:
295
+ gr.Markdown("# Canny-Edge ControlNet Demo")
296
+ gr.Markdown("**Note:** Trained on aesthetic/artistic images — best results come from similar, stylised inputs.")
297
+
298
+ # state
299
+ canny_state = gr.State(None)
300
+ width_state = gr.State(512)
301
+ height_state = gr.State(512)
302
+
303
+ gen_images_state = gr.State([]) # list[PIL]
304
+ gen_index_state = gr.State(0)
305
+
306
+ with gr.Row():
307
+ # ---- Left: Canny + Canny controls ----
308
+ with gr.Column(scale=1):
309
+ input_image = gr.Image(
310
+ label="Input Image",
311
+ type="pil",
312
+ image_mode="RGB",
313
+ height=IMG_H,
314
+ )
315
+
316
+ canny_preview = gr.Image(
317
+ label="Canny edges",
318
+ type="pil",
319
+ height=IMG_H,
320
+ )
321
+
322
+ gr.Markdown("### Edge controls")
323
+ use_clahe = gr.Checkbox(
324
+ label="Stabilise contrast (CLAHE)",
325
+ value=True,
326
+ info="Helps edges stay consistent under different lighting/contrast.",
327
+ )
328
+ edge_amount = gr.Slider(
329
+ label="Edge Amount",
330
+ minimum=0.0, maximum=1.0, value=0.6, step=0.01,
331
+ info="More = detect more edges (more detail). Less = cleaner outline.",
332
+ )
333
+ smoothing = gr.Slider(
334
+ label="Smoothing",
335
+ minimum=0.0, maximum=1.0, value=0.4, step=0.01,
336
+ info="More = reduce tiny texture/noise edges, cleaner structure.",
337
+ )
338
+
339
+ # ---- Right: Generated output + generation controls ----
340
+ with gr.Column(scale=1):
341
+ generated = gr.Image(
342
+ label="Generated image",
343
+ type="pil",
344
+ height=IMG_H,
345
+ )
346
+
347
+ with gr.Row():
348
+ prev_btn = gr.Button("◀ Prev")
349
+ page_label = gr.Markdown("0 / 0")
350
+ next_btn = gr.Button("Next ▶")
351
+
352
+ gr.Markdown("### Generation controls")
353
+ positive_prompt = gr.Textbox(
354
+ label="Positive Prompt",
355
+ value="",
356
+ lines=2,
357
+ info="Describe what you want. The edges guide the structure.",
358
+ )
359
+ negative_prompt = gr.Textbox(
360
+ label="Negative Prompt",
361
+ value="",
362
+ lines=2,
363
+ info="Things to avoid (e.g. blurry, deformed, low quality).",
364
+ )
365
+
366
+ with gr.Row():
367
+ guidance_scale = gr.Slider(
368
+ label="Guidance Scale",
369
+ minimum=1.0, maximum=15.0, value=7.5, step=0.1,
370
+ info="Higher = follow text prompt more strongly (can drift from edges).",
371
+ )
372
+ controlnet_conditioning_scale = gr.Slider(
373
+ label="Control Strength",
374
+ minimum=0.0, maximum=2.0, value=1.0, step=0.05,
375
+ info="Higher = follow edges more strongly. Too high can reduce creativity.",
376
+ )
377
+
378
+ with gr.Row():
379
+ num_inference_steps = gr.Slider(
380
+ label="Steps",
381
+ minimum=10, maximum=80, value=50, step=1,
382
+ info="More steps can improve quality but is slower.",
383
+ )
384
+ num_images = gr.Slider(
385
+ label="Samples",
386
+ minimum=1, maximum=8, value=4, step=1,
387
+ info="How many images to generate.",
388
+ )
389
+
390
+ run_btn = gr.Button("Generate", variant="primary")
391
+
392
+ # Auto-update Canny preview on changes (CPU)
393
+ auto_inputs = [input_image, use_clahe, edge_amount, smoothing]
394
+ for c in auto_inputs:
395
+ c.change(
396
+ fn=update_canny_preview,
397
+ inputs=auto_inputs,
398
+ outputs=[canny_preview, canny_state, width_state, height_state],
399
+ )
400
+
401
+ # Generate (GPU) -> store list -> show first -> update paging label
402
+ run_btn.click(
403
+ fn=generate_from_canny,
404
+ inputs=[
405
+ canny_state,
406
+ width_state,
407
+ height_state,
408
+ positive_prompt,
409
+ negative_prompt,
410
+ guidance_scale,
411
+ num_inference_steps,
412
+ num_images,
413
+ controlnet_conditioning_scale,
414
+ ],
415
+ outputs=[generated, gen_images_state], # visible output first => proper "Generating..." UX
416
+ ).then(
417
+ fn=lambda imgs: (0, f"1 / {len(imgs)}") if imgs else (0, "0 / 0"),
418
+ inputs=[gen_images_state],
419
+ outputs=[gen_index_state, page_label],
420
+ )
421
+
422
+ # Paging buttons (CPU)
423
+ next_btn.click(
424
+ fn=next_image,
425
+ inputs=[gen_images_state, gen_index_state],
426
+ outputs=[generated, gen_index_state, page_label],
427
+ )
428
+ prev_btn.click(
429
+ fn=prev_image,
430
+ inputs=[gen_images_state, gen_index_state],
431
+ outputs=[generated, gen_index_state, page_label],
432
+ )
433
+
434
+ if __name__ == "__main__":
435
+ demo.launch()