priyadip committed on
Commit
affbfcd
·
verified ·
1 Parent(s): 1c41084

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +391 -278
app.py CHANGED
@@ -5,20 +5,16 @@ import matplotlib.pyplot as plt
5
  import warnings
6
  warnings.filterwarnings("ignore")
7
 
8
- # Monkey-patch matplotlib.use so that graph_cut_segmentation.py's
9
- # module-level `matplotlib.use("TkAgg")` call becomes a no-op on import.
10
- # We cannot modify that file, so we intercept here instead.
11
  _real_use = matplotlib.use
12
  matplotlib.use = lambda *a, **kw: None
13
 
14
- # ── Standard imports ───────────────────────────────────────────────────
15
  import gradio as gr
16
  import numpy as np
17
  import cv2
18
  import io
19
  from PIL import Image
20
 
21
- # ── Pipeline imports (TkAgg call inside module is now intercepted) ─────
22
  from graph_cut_segmentation import (
23
  iterative_graph_cut,
24
  refine_segmentation,
@@ -29,8 +25,7 @@ from graph_cut_segmentation import (
29
  generate_auto_annotations,
30
  )
31
 
32
- # Restore real matplotlib.use now that the problematic import is done
33
- matplotlib.use = _real_use
34
 
35
 
36
  # ══════════════════════════════════════════════════════════════════════
@@ -46,20 +41,13 @@ def to_numpy(img):
46
 
47
 
48
  def extract_mask(editor_out, target_hw):
49
- """
50
- Pull a binary mask from a Gradio ImageEditor output dict.
51
- Any drawn pixel (alpha > 10) becomes foreground in the mask.
52
- """
53
  h, w = target_hw
54
  blank = np.zeros((h, w), dtype=np.uint8)
55
-
56
  if editor_out is None:
57
  return blank
58
-
59
  layers = editor_out.get("layers", []) if isinstance(editor_out, dict) else [editor_out]
60
  if not layers:
61
  return blank
62
-
63
  combined = blank.copy()
64
  for layer in layers:
65
  if layer is None:
@@ -67,7 +55,6 @@ def extract_mask(editor_out, target_hw):
67
  arr = to_numpy(layer)
68
  if arr is None:
69
  continue
70
- # RGBA: use alpha; RGB: use any non-black pixel
71
  if arr.ndim == 3 and arr.shape[2] == 4:
72
  alpha = arr[:, :, 3]
73
  elif arr.ndim == 3:
@@ -78,20 +65,18 @@ def extract_mask(editor_out, target_hw):
78
  if alpha.shape != (h, w):
79
  alpha = cv2.resize(alpha, (w, h), interpolation=cv2.INTER_NEAREST)
80
  combined = np.maximum(combined, (alpha > 10).astype(np.uint8))
81
-
82
  return combined
83
 
84
 
85
  def make_energy_plot(energies):
86
  fig, ax = plt.subplots(figsize=(7, 4), facecolor="#FFF8F3")
87
  ax.set_facecolor("#FFF8F3")
88
- iters = range(1, len(energies) + 1)
89
- ax.plot(list(iters), energies, "o-",
90
- color="#E8845A", linewidth=2.5, markersize=9,
91
  markerfacecolor="#C85E35", markeredgecolor="white", markeredgewidth=1.5)
92
  best_i = int(np.argmin(energies))
93
- ax.axvline(best_i + 1, color="#A0522D", linestyle="--",
94
- alpha=0.65, label=f"Best iteration: {best_i + 1}")
95
  ax.legend(fontsize=10, framealpha=0.7, edgecolor="#D4B896")
96
  ax.set_xlabel("Iteration", fontsize=12, color="#3D2B1F")
97
  ax.set_ylabel("Total Energy", fontsize=12, color="#3D2B1F")
@@ -119,8 +104,7 @@ def make_iterations_plot(all_masks, refined_mask):
119
  axes[i].set_title(f"Iteration {i + 1}", fontsize=11, color="#3D2B1F")
120
  axes[i].axis("off")
121
  axes[n].imshow(refined_mask, cmap="gray")
122
- axes[n].set_title("Post-Processed", fontsize=11,
123
- color="#C85E35", fontweight="bold")
124
  axes[n].axis("off")
125
  plt.tight_layout()
126
  buf = io.BytesIO()
@@ -131,22 +115,18 @@ def make_iterations_plot(all_masks, refined_mask):
131
 
132
 
133
  # ══════════════════════════════════════════════════════════════════════
134
- # Core segmentation function
135
  # ══════════════════════════════════════════════════════════════════════
136
 
137
  def run_segmentation(fg_editor, bg_editor, uploaded_image,
138
  max_dim, iterations, gamma, n_components, use_auto):
139
-
140
  if uploaded_image is None:
141
  raise gr.Error("Please upload an image first.")
142
-
143
  img_arr = to_numpy(uploaded_image)
144
  if img_arr.ndim == 2:
145
  img_arr = cv2.cvtColor(img_arr, cv2.COLOR_GRAY2RGB)
146
-
147
  image_bgr = cv2.cvtColor(img_arr, cv2.COLOR_RGB2BGR)
148
 
149
- # Resize
150
  h, w = image_bgr.shape[:2]
151
  max_dim = int(max_dim)
152
  if max(h, w) > max_dim:
@@ -155,7 +135,6 @@ def run_segmentation(fg_editor, bg_editor, uploaded_image,
155
  interpolation=cv2.INTER_AREA)
156
  h, w = image_bgr.shape[:2]
157
 
158
- # Annotations
159
  if use_auto:
160
  fg_mask, bg_mask = generate_auto_annotations(image_bgr)
161
  else:
@@ -163,12 +142,10 @@ def run_segmentation(fg_editor, bg_editor, uploaded_image,
163
  bg_mask = extract_mask(bg_editor, (h, w))
164
  if fg_mask.sum() == 0 or bg_mask.sum() == 0:
165
  raise gr.Error(
166
- "Both foreground and background scribbles are required. "
167
- "Draw on the object (left canvas) and the background (right canvas), "
168
- "or tick 'Auto Annotation'."
169
  )
170
 
171
- # Run graph cut pipeline
172
  raw_mask, all_masks, energies = iterative_graph_cut(
173
  image_bgr, fg_mask, bg_mask,
174
  n_iterations=int(iterations),
@@ -177,50 +154,33 @@ def run_segmentation(fg_editor, bg_editor, uploaded_image,
177
  )
178
  refined_mask = refine_segmentation(raw_mask, image_bgr)
179
 
180
- # Naive baselines (aligned to graph-cut label convention)
181
- naive_otsu = naive_thresholding_segmentation(image_bgr)
182
- naive_km = naive_kmeans_segmentation(image_bgr)
183
- naive_otsu = align_naive_to_graphcut(naive_otsu, refined_mask)
184
- naive_km = align_naive_to_graphcut(naive_km, refined_mask)
185
 
186
- # ── Build output visuals ──────────────────────────────────────────
187
-
188
- # Annotation visualisation
189
  annot = image_bgr.copy()
190
  annot[fg_mask == 1] = [0, 255, 0]
191
  annot[bg_mask == 1] = [0, 0, 255]
192
- annot_rgb = cv2.cvtColor(annot, cv2.COLOR_BGR2RGB)
193
 
194
- # Masks → 3-channel for display
195
  def gray3(m):
196
  return cv2.cvtColor((m * 255).astype(np.uint8), cv2.COLOR_GRAY2RGB)
197
 
198
- raw_rgb = gray3(raw_mask)
199
- refined_rgb = gray3(refined_mask)
200
- otsu_rgb = gray3(naive_otsu)
201
- km_rgb = gray3(naive_km)
202
-
203
- # Overlay
204
- overlay_rgb = cv2.cvtColor(
205
- create_overlay(image_bgr, refined_mask), cv2.COLOR_BGR2RGB)
206
-
207
- # Extracted foreground (white background)
208
  ext = image_bgr.copy()
209
  ext[refined_mask == 0] = [255, 255, 255]
210
- ext_rgb = cv2.cvtColor(ext, cv2.COLOR_BGR2RGB)
211
-
212
- # Plots
213
- energy_img = make_energy_plot(energies)
214
- iter_img = make_iterations_plot(all_masks, refined_mask)
215
 
216
- return (annot_rgb, raw_rgb, refined_rgb,
217
- overlay_rgb, ext_rgb,
218
- otsu_rgb, km_rgb,
219
- energy_img, iter_img)
 
 
 
 
 
 
 
220
 
221
 
222
  def update_editors(img):
223
- """Push uploaded image as background into both annotation editors."""
224
  if img is None:
225
  return gr.update(value=None), gr.update(value=None)
226
  pil = Image.fromarray(img.astype(np.uint8))
@@ -228,343 +188,497 @@ def update_editors(img):
228
 
229
 
230
  # ══════════════════════════════════════════════════════════════════════
231
- # UI — warm aesthetic
232
  # ══════════════════════════════════════════════════════════════════════
233
 
234
  CSS = """
235
- /* ── Base ── */
236
- body, .gradio-container {
237
- background-color: #FFF8F3 !important;
238
- font-family: 'Inter', 'Helvetica Neue', Arial, sans-serif;
 
 
 
 
 
 
 
 
 
239
  }
240
 
241
- /* ── Cards ── */
242
- .card {
243
- background: #FFFFFF;
244
- border-radius: 16px;
245
- border: 1px solid #EDD9C8;
246
- box-shadow: 0 4px 18px rgba(180, 110, 60, 0.08);
247
- padding: 22px 24px;
248
- margin-bottom: 18px;
 
 
 
 
249
  }
250
 
251
- /* ── Section title ── */
252
- .sec-title {
253
- display: flex;
254
- align-items: center;
255
- gap: 10px;
256
- font-size: 17px;
257
- font-weight: 700;
258
- color: #3D2B1F;
259
- margin: 0 0 16px 0;
 
 
 
 
 
 
 
260
  }
261
- .step-badge {
262
- background: linear-gradient(135deg, #E8845A, #C85E35);
263
- color: white;
264
- border-radius: 50%;
265
- width: 30px; height: 30px;
266
- display: inline-flex;
267
- align-items: center; justify-content: center;
268
- font-size: 13px; font-weight: 800;
269
- flex-shrink: 0;
270
  }
271
 
272
- /* ── Run button ── */
273
- #run-btn {
274
- background: linear-gradient(135deg, #E8845A 0%, #C85E35 100%) !important;
275
- color: white !important;
276
- border: none !important;
277
- border-radius: 14px !important;
278
- font-size: 17px !important;
279
  font-weight: 700 !important;
280
- padding: 15px 0 !important;
281
- letter-spacing: 0.4px !important;
282
- box-shadow: 0 6px 22px rgba(200, 94, 53, 0.38) !important;
283
- transition: transform 0.18s ease, box-shadow 0.18s ease !important;
284
- width: 100% !important;
285
  }
286
- #run-btn:hover {
287
- transform: translateY(-2px) !important;
288
- box-shadow: 0 10px 28px rgba(200, 94, 53, 0.48) !important;
 
 
289
  }
290
 
291
- /* ── Labels ── */
292
- label > span, .label-wrap span {
293
- color: #5C3D2E !important;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
  font-weight: 600 !important;
295
  }
296
- input[type=range] { accent-color: #E8845A; }
297
- input[type=checkbox] { accent-color: #E8845A; }
298
-
299
- /* ── Output image labels ── */
300
- .out-lbl {
301
- font-size: 11.5px;
302
- font-weight: 700;
303
- color: #A07060;
304
- text-align: center;
305
- text-transform: uppercase;
306
- letter-spacing: 0.6px;
307
- margin-bottom: 5px;
308
  }
309
- .out-lbl.gc { color: #C85E35; }
310
- .out-lbl.naive { color: #888; }
311
-
312
- /* ── Hint text ── */
313
- .hint {
314
- font-size: 12.5px;
315
- color: #A07060;
316
- margin-top: 8px;
317
- line-height: 1.55;
318
  }
319
 
320
- /* ── Hero ── */
321
- .hero {
 
 
 
322
  text-align: center;
323
- padding: 36px 20px 24px;
 
 
 
 
324
  }
325
  .hero-badge {
326
  display: inline-block;
327
- background: #F2C4A0;
328
  color: #7A3B1E;
329
- border-radius: 20px;
330
- padding: 5px 16px;
331
- font-size: 11.5px;
332
- font-weight: 700;
333
- margin-bottom: 16px;
334
- letter-spacing: 1px;
335
  text-transform: uppercase;
 
336
  }
337
- .hero h1 {
338
- font-size: 34px;
339
- font-weight: 800;
340
- color: #3D2B1F;
341
- margin: 0 0 10px;
342
- line-height: 1.15;
343
  }
344
- .hero p {
345
- font-size: 15px;
346
- color: #7A4F3A;
347
- max-width: 640px;
348
- margin: 0 auto;
349
- line-height: 1.65;
350
  }
351
 
352
- /* ── Divider ── */
353
- .warm-hr {
354
- border: none;
355
- border-top: 1px solid #EDD9C8;
356
- margin: 4px 0 22px;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357
  }
358
 
359
- /* ── Tips box ── */
360
  .tips-box {
361
- background: #FFF3EC;
362
- border-left: 4px solid #E8845A;
363
- border-radius: 0 10px 10px 0;
364
- padding: 14px 16px;
365
- font-size: 13.5px;
366
- color: #5C3D2E;
367
- line-height: 1.7;
 
368
  }
 
369
 
370
- /* ── Annotation color labels ── */
371
- .fg-label {
372
- font-size: 13px; font-weight: 700;
373
- color: #1A7A1A;
374
- text-align: center; margin-bottom: 6px;
 
 
 
375
  }
376
- .bg-label {
377
- font-size: 13px; font-weight: 700;
378
- color: #B52020;
379
- text-align: center; margin-bottom: 6px;
 
 
 
 
 
380
  }
381
- """
382
 
383
- HERO = """
384
- <div class="hero">
385
- <div class="hero-badge">Graph Cut Β· GMM Β· PyMaxflow Β· Energy Minimisation</div>
386
- <h1>πŸ‚ Graph Cut Image Segmentation</h1>
387
- <p>
388
- Upload an image, paint foreground and background scribbles, then let
389
- energy-minimisation-based Graph Cut precisely isolate your object β€”
390
- guided by Gaussian Mixture Models and iterative refinement.
391
- </p>
392
- </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
  """
394
 
 
 
 
 
395
  with gr.Blocks(
396
- theme=gr.themes.Soft(
 
 
397
  primary_hue=gr.themes.colors.orange,
398
  secondary_hue=gr.themes.colors.amber,
399
  neutral_hue=gr.themes.colors.stone,
400
  font=gr.themes.GoogleFont("Inter"),
401
  ).set(
402
  body_background_fill="#FFF8F3",
 
403
  block_background_fill="#FFFFFF",
404
  block_border_color="#EDD9C8",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
405
  button_primary_background_fill="#E8845A",
406
  button_primary_background_fill_hover="#C85E35",
407
- button_primary_text_color="white",
 
 
 
 
 
 
 
 
 
408
  ),
409
- css=CSS,
410
- title="Graph Cut Segmentation",
411
  ) as demo:
412
 
413
- # ── Hero ───────────────────────────────────────────────────────────
414
- gr.HTML(HERO)
415
- gr.HTML('<hr class="warm-hr">')
416
-
417
- # ── Step 1: Upload ─────────────────────────────────────────────────
418
  gr.HTML("""
419
- <div class="card">
420
- <div class="sec-title">
421
- <span class="step-badge">1</span> Upload Image
 
 
 
 
422
  </div>
 
423
  """)
424
- with gr.Row():
 
 
 
 
 
 
 
 
425
  with gr.Column(scale=3):
426
  img_upload = gr.Image(
427
  label="Input Image",
428
  type="numpy",
429
  sources=["upload", "clipboard"],
430
- height=300,
431
  )
432
  with gr.Column(scale=1):
433
  gr.HTML("""
434
  <div class="tips-box">
435
- <b>Tips for best results</b><br>
436
- β€’ Object with clear boundary from background<br>
437
- β€’ Natural photos, portraits, products work great<br>
438
- β€’ Higher contrast = cleaner segmentation<br>
439
- β€’ Any resolution β€” we resize automatically<br>
440
- β€’ JPEG or PNG
 
441
  </div>
442
  """)
443
- gr.HTML("</div>")
444
 
445
- # ── Step 2: Parameters ─────────────────────────────────────────────
 
 
446
  gr.HTML("""
447
- <div class="card">
448
- <div class="sec-title">
449
- <span class="step-badge">2</span> Configure Parameters
450
- </div>
451
  """)
452
  with gr.Row():
453
- max_dim = gr.Slider(200, 800, value=400, step=50, label="Max Dimension (px)",
454
- info="Larger = more detail, slower. 400 is a good default.")
455
- iterations = gr.Slider(1, 10, value=3, step=1, label="Iterations",
456
- info="How many GMM re-estimation rounds. 3–5 recommended.")
457
  with gr.Row():
458
- gamma = gr.Slider(10, 200, value=50, step=5, label="Smoothness Ξ³",
459
- info="Higher = smoother boundary. 50 is a good default.")
460
- n_comp = gr.Slider(2, 10, value=5, step=1, label="GMM Components K",
461
- info="Colour clusters per region. 5 fits most images.")
462
  use_auto = gr.Checkbox(
463
- label="Auto Annotation β€” skip drawing (uses centre/border heuristic)",
464
  value=False,
465
  )
466
- gr.HTML("</div>")
467
 
468
- # ── Step 3: Annotate ───────────────────────────────────────────────
 
 
469
  gr.HTML("""
470
- <div class="card">
471
- <div class="sec-title">
472
- <span class="step-badge">3</span> Annotate &nbsp;
473
- <span style="font-size:13px; font-weight:400; color:#A07060;">
474
- (skip if Auto Annotation is enabled above)
475
- </span>
476
- </div>
477
  """)
478
  with gr.Row():
479
  with gr.Column():
480
- gr.HTML('<div class="fg-label">🟒 FOREGROUND β€” Paint over the object to keep</div>')
481
  fg_editor = gr.ImageEditor(
482
- label="Foreground",
483
  show_label=False,
484
  height=380,
485
  brush=gr.Brush(
486
- default_size=12,
487
  default_color="#00CC44",
488
  colors=["#00CC44", "#00FF00", "#22AA55"],
489
  color_mode="defaults",
490
  ),
491
  )
492
  gr.HTML("""
493
- <div class="hint">
494
- ✏️ Draw <b>green strokes</b> across the foreground object.
495
- Cover several different regions β€” ears, body, edges β€” for a
496
- richer GMM colour model.
497
  </div>
498
  """)
499
  with gr.Column():
500
- gr.HTML('<div class="bg-label">πŸ”΄ BACKGROUND β€” Paint over the background area</div>')
501
  bg_editor = gr.ImageEditor(
502
- label="Background",
503
  show_label=False,
504
  height=380,
505
  brush=gr.Brush(
506
- default_size=12,
507
  default_color="#FF3333",
508
  colors=["#FF3333", "#CC0000", "#FF6666"],
509
  color_mode="defaults",
510
  ),
511
  )
512
  gr.HTML("""
513
- <div class="hint">
514
- ✏️ Draw <b>red strokes</b> on background areas.
515
- Cover different background textures (sky, floor, wall…)
516
- to give the GMM diverse samples.
517
  </div>
518
  """)
519
- gr.HTML("</div>")
520
 
521
- # Sync upload → both editors
522
  img_upload.change(
523
  fn=update_editors,
524
  inputs=img_upload,
525
  outputs=[fg_editor, bg_editor],
526
  )
527
 
528
- # ── Run ────────────────────────────────────────────────────────────
529
- gr.HTML('<div style="padding: 4px 0 20px;">')
 
530
  run_btn = gr.Button(
531
- "β–Ά Run Graph Cut Segmentation",
532
  elem_id="run-btn",
533
  variant="primary",
534
  )
535
- gr.HTML("</div>")
536
 
537
- # ── Step 4: Results ────────────────────────────────────────────────
 
 
538
  gr.HTML("""
539
- <div class="card">
540
- <div class="sec-title">
541
- <span class="step-badge">4</span> Segmentation Results
542
- </div>
543
  """)
544
  with gr.Row():
545
- out_annot = gr.Image(label="Input + Annotations", height=260)
546
- out_raw = gr.Image(label="Raw Graph Cut", height=260)
547
- out_refined = gr.Image(label="Refined Graph Cut", height=260)
548
  with gr.Row():
549
- out_overlay = gr.Image(label="Overlay on Original", height=260)
550
- out_extract = gr.Image(label="Extracted Foreground", height=260)
551
- out_otsu = gr.Image(label="Naive: Otsu", height=260)
552
- out_km = gr.Image(label="Naive: K-Means (k=2)", height=260)
553
- gr.HTML("</div>")
554
 
555
- # ── Step 5: Analysis ───────────────────────────────────────────────
 
 
556
  gr.HTML("""
557
- <div class="card">
558
- <div class="sec-title">
559
- <span class="step-badge">5</span> Convergence &amp; Iteration Analysis
560
- </div>
561
  """)
562
  with gr.Row():
563
- out_energy = gr.Image(label="Energy Convergence", height=340)
564
- out_iters = gr.Image(label="Iterative Mask Progression", height=340)
565
- gr.HTML("</div>")
566
 
567
- # ── Wire up ────────────────────────────────────────────────────────
568
  run_btn.click(
569
  fn=run_segmentation,
570
  inputs=[fg_editor, bg_editor, img_upload,
@@ -576,11 +690,10 @@ with gr.Blocks(
576
  show_progress="full",
577
  )
578
 
579
- # ── Footer ─────────────────────────────────────────────────────────
580
  gr.HTML("""
581
- <div style="text-align:center; padding: 24px 0 10px;
582
- font-size: 13px; color: #B09080;">
583
- CSL7360: Computer Vision Β· Assignment 2 Β·
584
  Graph Cut Segmentation via PyMaxflow &amp; GMMs
585
  </div>
586
  """)
 
5
  import warnings
6
  warnings.filterwarnings("ignore")
7
 
8
+ # Monkey-patch matplotlib.use so graph_cut_segmentation.py's TkAgg call is a no-op
 
 
9
  _real_use = matplotlib.use
10
  matplotlib.use = lambda *a, **kw: None
11
 
 
12
  import gradio as gr
13
  import numpy as np
14
  import cv2
15
  import io
16
  from PIL import Image
17
 
 
18
  from graph_cut_segmentation import (
19
  iterative_graph_cut,
20
  refine_segmentation,
 
25
  generate_auto_annotations,
26
  )
27
 
28
+ matplotlib.use = _real_use # restore after import
 
29
 
30
 
31
  # ══════════════════════════════════════════════════════════════════════
 
41
 
42
 
43
  def extract_mask(editor_out, target_hw):
 
 
 
 
44
  h, w = target_hw
45
  blank = np.zeros((h, w), dtype=np.uint8)
 
46
  if editor_out is None:
47
  return blank
 
48
  layers = editor_out.get("layers", []) if isinstance(editor_out, dict) else [editor_out]
49
  if not layers:
50
  return blank
 
51
  combined = blank.copy()
52
  for layer in layers:
53
  if layer is None:
 
55
  arr = to_numpy(layer)
56
  if arr is None:
57
  continue
 
58
  if arr.ndim == 3 and arr.shape[2] == 4:
59
  alpha = arr[:, :, 3]
60
  elif arr.ndim == 3:
 
65
  if alpha.shape != (h, w):
66
  alpha = cv2.resize(alpha, (w, h), interpolation=cv2.INTER_NEAREST)
67
  combined = np.maximum(combined, (alpha > 10).astype(np.uint8))
 
68
  return combined
69
 
70
 
71
  def make_energy_plot(energies):
72
  fig, ax = plt.subplots(figsize=(7, 4), facecolor="#FFF8F3")
73
  ax.set_facecolor("#FFF8F3")
74
+ iters = list(range(1, len(energies) + 1))
75
+ ax.plot(iters, energies, "o-", color="#E8845A", linewidth=2.5, markersize=9,
 
76
  markerfacecolor="#C85E35", markeredgecolor="white", markeredgewidth=1.5)
77
  best_i = int(np.argmin(energies))
78
+ ax.axvline(best_i + 1, color="#A0522D", linestyle="--", alpha=0.65,
79
+ label=f"Best iteration: {best_i + 1}")
80
  ax.legend(fontsize=10, framealpha=0.7, edgecolor="#D4B896")
81
  ax.set_xlabel("Iteration", fontsize=12, color="#3D2B1F")
82
  ax.set_ylabel("Total Energy", fontsize=12, color="#3D2B1F")
 
104
  axes[i].set_title(f"Iteration {i + 1}", fontsize=11, color="#3D2B1F")
105
  axes[i].axis("off")
106
  axes[n].imshow(refined_mask, cmap="gray")
107
+ axes[n].set_title("Post-Processed", fontsize=11, color="#C85E35", fontweight="bold")
 
108
  axes[n].axis("off")
109
  plt.tight_layout()
110
  buf = io.BytesIO()
 
115
 
116
 
117
  # ══════════════════════════════════════════════════════════════════════
118
+ # Core segmentation
119
  # ══════════════════════════════════════════════════════════════════════
120
 
121
  def run_segmentation(fg_editor, bg_editor, uploaded_image,
122
  max_dim, iterations, gamma, n_components, use_auto):
 
123
  if uploaded_image is None:
124
  raise gr.Error("Please upload an image first.")
 
125
  img_arr = to_numpy(uploaded_image)
126
  if img_arr.ndim == 2:
127
  img_arr = cv2.cvtColor(img_arr, cv2.COLOR_GRAY2RGB)
 
128
  image_bgr = cv2.cvtColor(img_arr, cv2.COLOR_RGB2BGR)
129
 
 
130
  h, w = image_bgr.shape[:2]
131
  max_dim = int(max_dim)
132
  if max(h, w) > max_dim:
 
135
  interpolation=cv2.INTER_AREA)
136
  h, w = image_bgr.shape[:2]
137
 
 
138
  if use_auto:
139
  fg_mask, bg_mask = generate_auto_annotations(image_bgr)
140
  else:
 
142
  bg_mask = extract_mask(bg_editor, (h, w))
143
  if fg_mask.sum() == 0 or bg_mask.sum() == 0:
144
  raise gr.Error(
145
+ "Both foreground (green) and background (red) scribbles are required. "
146
+ "Draw on each canvas, or enable Auto Annotation."
 
147
  )
148
 
 
149
  raw_mask, all_masks, energies = iterative_graph_cut(
150
  image_bgr, fg_mask, bg_mask,
151
  n_iterations=int(iterations),
 
154
  )
155
  refined_mask = refine_segmentation(raw_mask, image_bgr)
156
 
157
+ naive_otsu = align_naive_to_graphcut(naive_thresholding_segmentation(image_bgr), refined_mask)
158
+ naive_km = align_naive_to_graphcut(naive_kmeans_segmentation(image_bgr), refined_mask)
 
 
 
159
 
 
 
 
160
  annot = image_bgr.copy()
161
  annot[fg_mask == 1] = [0, 255, 0]
162
  annot[bg_mask == 1] = [0, 0, 255]
 
163
 
 
164
  def gray3(m):
165
  return cv2.cvtColor((m * 255).astype(np.uint8), cv2.COLOR_GRAY2RGB)
166
 
 
 
 
 
 
 
 
 
 
 
167
  ext = image_bgr.copy()
168
  ext[refined_mask == 0] = [255, 255, 255]
 
 
 
 
 
169
 
170
+ return (
171
+ cv2.cvtColor(annot, cv2.COLOR_BGR2RGB),
172
+ gray3(raw_mask),
173
+ gray3(refined_mask),
174
+ cv2.cvtColor(create_overlay(image_bgr, refined_mask), cv2.COLOR_BGR2RGB),
175
+ cv2.cvtColor(ext, cv2.COLOR_BGR2RGB),
176
+ gray3(naive_otsu),
177
+ gray3(naive_km),
178
+ make_energy_plot(energies),
179
+ make_iterations_plot(all_masks, refined_mask),
180
+ )
181
 
182
 
183
  def update_editors(img):
 
184
  if img is None:
185
  return gr.update(value=None), gr.update(value=None)
186
  pil = Image.fromarray(img.astype(np.uint8))
 
188
 
189
 
190
  # ══════════════════════════════════════════════════════════════════════
191
+ # CSS — forces light warm theme over every Gradio 6.x element
192
  # ══════════════════════════════════════════════════════════════════════
193
 
194
  CSS = """
195
+ /* ─── Google Font ─── */
196
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700;800&display=swap');
197
+
198
+ /* ─── Root tokens ─── */
199
+ :root {
200
+ --warm-bg: #FFF8F3;
201
+ --warm-card: #FFFFFF;
202
+ --warm-border: #EDD9C8;
203
+ --warm-text: #3D2B1F;
204
+ --warm-muted: #7A4F3A;
205
+ --warm-accent: #E8845A;
206
+ --warm-accent2: #C85E35;
207
+ --warm-light: #FFF3EC;
208
  }
209
 
210
+ /* ─── Base ─── */
211
+ *, *::before, *::after { box-sizing: border-box; }
212
+
213
+ body,
214
+ .gradio-container,
215
+ .gradio-container > .main,
216
+ .gradio-container > .main > .wrap,
217
+ .gap,
218
+ footer {
219
+ background-color: var(--warm-bg) !important;
220
+ font-family: 'Inter', sans-serif !important;
221
+ color: var(--warm-text) !important;
222
  }
223
 
224
+ /* ─── Remove dark panel backgrounds ─── */
225
+ .block,
226
+ .block.padded,
227
+ .panel,
228
+ .form,
229
+ .box,
230
+ .contain,
231
+ .wrap,
232
+ .inner-wrap,
233
+ .input-wrapper,
234
+ .output-class,
235
+ .image-container,
236
+ .preview {
237
+ background: var(--warm-card) !important;
238
+ border-color: var(--warm-border) !important;
239
+ color: var(--warm-text) !important;
240
  }
241
+
242
+ /* ─── All text nodes ─── */
243
+ p, span, div, h1, h2, h3, h4, label, legend, li {
244
+ color: var(--warm-text) !important;
245
+ font-family: 'Inter', sans-serif !important;
 
 
 
 
246
  }
247
 
248
+ /* ─── Labels above components ─── */
249
+ .label-wrap > span,
250
+ label > span,
251
+ .block > label > span {
252
+ font-size: 13.5px !important;
 
 
253
  font-weight: 700 !important;
254
+ color: var(--warm-text) !important;
255
+ letter-spacing: 0.15px !important;
 
 
 
256
  }
257
+
258
+ /* ─── Info/description text below sliders ─── */
259
+ .info, .description, [class*="description"] {
260
+ font-size: 12px !important;
261
+ color: var(--warm-muted) !important;
262
  }
263
 
264
+ /* ─── Inputs / numbers ─── */
265
+ input[type="number"],
266
+ input[type="text"],
267
+ textarea {
268
+ background: var(--warm-light) !important;
269
+ border: 1.5px solid var(--warm-border) !important;
270
+ color: var(--warm-text) !important;
271
+ border-radius: 8px !important;
272
+ }
273
+ input[type="number"]:focus,
274
+ input[type="text"]:focus,
275
+ textarea:focus {
276
+ border-color: var(--warm-accent) !important;
277
+ outline: none !important;
278
+ box-shadow: 0 0 0 3px rgba(232,132,90,0.15) !important;
279
+ }
280
+
281
+ /* ─── Sliders ─── */
282
+ input[type="range"] {
283
+ accent-color: var(--warm-accent) !important;
284
+ cursor: pointer !important;
285
+ }
286
+
287
+ /* ─── Checkbox ─── */
288
+ input[type="checkbox"] {
289
+ accent-color: var(--warm-accent) !important;
290
+ width: 16px !important;
291
+ height: 16px !important;
292
+ }
293
+ .checkbox-wrap label span,
294
+ .checkbox-label {
295
+ color: var(--warm-text) !important;
296
  font-weight: 600 !important;
297
  }
298
+
299
+ /* ─── Image upload area ─── */
300
+ .upload-container,
301
+ .upload-btn,
302
+ [data-testid="image"] {
303
+ background: var(--warm-light) !important;
304
+ border: 2px dashed var(--warm-border) !important;
305
+ border-radius: 12px !important;
 
 
 
 
306
  }
307
+ .upload-container:hover { border-color: var(--warm-accent) !important; }
308
+
309
+ /* ─── ImageEditor dark canvas area ─── */
310
+ .image-editor,
311
+ .image-editor > *,
312
+ [data-testid="image-editor"],
313
+ [data-testid="image-editor"] > * {
314
+ background: #F5EDE4 !important;
 
315
  }
316
 
317
+ /* ─── Row/Column spacers ─── */
318
+ .row, .column { background: transparent !important; }
319
+
320
+ /* ─── Hero section ─── */
321
+ .hero-wrap {
322
  text-align: center;
323
+ padding: 40px 24px 28px;
324
+ background: linear-gradient(160deg, #FFF8F3 0%, #FDECD8 100%);
325
+ border-radius: 20px;
326
+ border: 1px solid var(--warm-border);
327
+ margin-bottom: 20px;
328
  }
329
  .hero-badge {
330
  display: inline-block;
331
+ background: linear-gradient(135deg, #F2C4A0, #EDA882);
332
  color: #7A3B1E;
333
+ border-radius: 30px;
334
+ padding: 6px 20px;
335
+ font-size: 11px;
336
+ font-weight: 800;
337
+ letter-spacing: 1.2px;
 
338
  text-transform: uppercase;
339
+ margin-bottom: 18px;
340
  }
341
+ .hero-title {
342
+ font-size: 38px !important;
343
+ font-weight: 800 !important;
344
+ color: #3D2B1F !important;
345
+ margin: 0 0 12px !important;
346
+ line-height: 1.1 !important;
347
  }
348
+ .hero-sub {
349
+ font-size: 15.5px !important;
350
+ color: #7A4F3A !important;
351
+ max-width: 600px;
352
+ margin: 0 auto !important;
353
+ line-height: 1.7 !important;
354
  }
355
 
356
+ /* ─── Section header ─── */
357
+ .sec-header {
358
+ display: flex;
359
+ align-items: center;
360
+ gap: 12px;
361
+ padding: 18px 0 14px;
362
+ border-bottom: 2px solid var(--warm-border);
363
+ margin-bottom: 18px;
364
+ }
365
+ .step-num {
366
+ width: 32px; height: 32px;
367
+ background: linear-gradient(135deg, var(--warm-accent), var(--warm-accent2));
368
+ color: white;
369
+ border-radius: 50%;
370
+ display: inline-flex;
371
+ align-items: center; justify-content: center;
372
+ font-size: 14px; font-weight: 800;
373
+ flex-shrink: 0;
374
+ box-shadow: 0 3px 10px rgba(200,94,53,0.35);
375
+ }
376
+ .sec-title-text {
377
+ font-size: 18px !important;
378
+ font-weight: 800 !important;
379
+ color: #3D2B1F !important;
380
+ margin: 0 !important;
381
+ }
382
+ .sec-sub {
383
+ font-size: 13px !important;
384
+ color: var(--warm-muted) !important;
385
+ margin: 0 !important;
386
+ font-weight: 400 !important;
387
  }
388
 
389
+ /* ─── Tips box ─── */
390
  .tips-box {
391
+ background: var(--warm-light);
392
+ border-left: 4px solid var(--warm-accent);
393
+ border-radius: 0 12px 12px 0;
394
+ padding: 16px 18px;
395
+ font-size: 13.5px !important;
396
+ color: #5C3D2E !important;
397
+ line-height: 1.75;
398
+ height: 100%;
399
  }
400
+ .tips-box b { color: var(--warm-accent2) !important; }
401
 
402
+ /* ─── Annotation labels ─── */
403
+ .anno-label {
404
+ text-align: center;
405
+ font-size: 13.5px !important;
406
+ font-weight: 800 !important;
407
+ padding: 10px 0 8px;
408
+ border-radius: 8px;
409
+ margin-bottom: 8px;
410
  }
411
+ .anno-fg { background: #E8F5E9; color: #1B5E20 !important; border: 1.5px solid #A5D6A7; }
412
+ .anno-bg { background: #FFEBEE; color: #B71C1C !important; border: 1.5px solid #EF9A9A; }
413
+
414
+ /* ─── Hint text ─── */
415
+ .hint-text {
416
+ font-size: 12.5px !important;
417
+ color: var(--warm-muted) !important;
418
+ line-height: 1.6;
419
+ padding: 10px 4px 0;
420
  }
 
421
 
422
+ /* ─── Divider ─── */
423
+ .warm-divider {
424
+ border: none;
425
+ border-top: 1.5px solid var(--warm-border);
426
+ margin: 6px 0 24px;
427
+ }
428
+
429
+ /* ─── RUN BUTTON ─── */
430
+ #run-btn {
431
+ background: linear-gradient(135deg, #E8845A 0%, #C85E35 100%) !important;
432
+ color: #FFFFFF !important;
433
+ border: none !important;
434
+ border-radius: 14px !important;
435
+ font-size: 18px !important;
436
+ font-weight: 800 !important;
437
+ padding: 18px 0 !important;
438
+ letter-spacing: 0.5px !important;
439
+ box-shadow: 0 8px 28px rgba(200,94,53,0.42) !important;
440
+ transition: all 0.2s ease !important;
441
+ width: 100% !important;
442
+ cursor: pointer !important;
443
+ }
444
+ #run-btn:hover {
445
+ transform: translateY(-2px) !important;
446
+ box-shadow: 0 12px 36px rgba(200,94,53,0.55) !important;
447
+ }
448
+ #run-btn:active { transform: translateY(0) !important; }
449
+
450
+ /* ─── Output image panels ─── */
451
+ .output-panel {
452
+ background: var(--warm-card) !important;
453
+ border-radius: 14px !important;
454
+ border: 1px solid var(--warm-border) !important;
455
+ overflow: hidden !important;
456
+ box-shadow: 0 2px 12px rgba(180,110,60,0.07) !important;
457
+ }
458
+
459
+ /* ─── Progress bar ─── */
460
+ .progress-bar { background: var(--warm-accent) !important; }
461
+
462
+ /* ─── Footer ─── */
463
+ .footer-wrap {
464
+ text-align: center;
465
+ padding: 28px 0 12px;
466
+ font-size: 13px !important;
467
+ color: #B09080 !important;
468
+ border-top: 1px solid var(--warm-border);
469
+ margin-top: 12px;
470
+ }
471
  """
472
 
473
+ # ══════════════════════════════════════════════════════════════════════
474
+ # UI
475
+ # ══════════════════════════════════════════════════════════════════════
476
+
477
  with gr.Blocks(
478
+ css=CSS,
479
+ title="Graph Cut Segmentation",
480
+ theme=gr.themes.Base(
481
  primary_hue=gr.themes.colors.orange,
482
  secondary_hue=gr.themes.colors.amber,
483
  neutral_hue=gr.themes.colors.stone,
484
  font=gr.themes.GoogleFont("Inter"),
485
  ).set(
486
  body_background_fill="#FFF8F3",
487
+ body_text_color="#3D2B1F",
488
  block_background_fill="#FFFFFF",
489
  block_border_color="#EDD9C8",
490
+ block_label_background_fill="#FFF3EC",
491
+ block_label_text_color="#3D2B1F",
492
+ block_label_text_weight="700",
493
+ block_title_text_color="#3D2B1F",
494
+ block_title_text_weight="700",
495
+ input_background_fill="#FFF3EC",
496
+ input_border_color="#EDD9C8",
497
+ input_border_color_focus="#E8845A",
498
+ input_placeholder_color="#B09080",
499
+ checkbox_background_color="#FFF3EC",
500
+ checkbox_background_color_selected="#E8845A",
501
+ checkbox_border_color="#EDD9C8",
502
+ checkbox_label_text_color="#3D2B1F",
503
+ slider_color="#E8845A",
504
  button_primary_background_fill="#E8845A",
505
  button_primary_background_fill_hover="#C85E35",
506
+ button_primary_text_color="#FFFFFF",
507
+ button_primary_border_color="transparent",
508
+ button_secondary_background_fill="#FFF3EC",
509
+ button_secondary_text_color="#3D2B1F",
510
+ border_color_primary="#EDD9C8",
511
+ border_color_accent="#E8845A",
512
+ shadow_drop="0 2px 12px rgba(180,110,60,0.08)",
513
+ color_accent="#E8845A",
514
+ color_accent_soft="#FFF3EC",
515
+ link_text_color="#E8845A",
516
  ),
 
 
517
  ) as demo:
518
 
519
+ # ── Hero ──────────────────────────────────────────────────────────
 
 
 
 
520
  gr.HTML("""
521
+ <div class="hero-wrap">
522
+ <div class="hero-badge">Graph Cut &nbsp;·&nbsp; GMM &nbsp;·&nbsp; PyMaxflow &nbsp;·&nbsp; Energy Minimisation</div>
523
+ <div class="hero-title">πŸ‚ Graph Cut Image Segmentation</div>
524
+ <div class="hero-sub">
525
+ Upload an image, paint foreground &amp; background scribbles, and let
526
+ energy-minimisation Graph Cut isolate your object — powered by
527
+ Gaussian Mixture Models and iterative refinement.
528
  </div>
529
+ </div>
530
  """)
531
+
532
+ # ── STEP 1: Upload ────────────────────────────────────────────────
533
+ gr.HTML("""
534
+ <div class="sec-header">
535
+ <span class="step-num">1</span>
536
+ <span class="sec-title-text">Upload Image</span>
537
+ </div>
538
+ """)
539
+ with gr.Row(equal_height=True):
540
  with gr.Column(scale=3):
541
  img_upload = gr.Image(
542
  label="Input Image",
543
  type="numpy",
544
  sources=["upload", "clipboard"],
545
+ height=280,
546
  )
547
  with gr.Column(scale=1):
548
  gr.HTML("""
549
  <div class="tips-box">
550
+ <b>Tips for best results</b><br><br>
551
+ ✅ Clear object boundary from background<br>
552
+ ✅ Natural photos, portraits, products<br>
553
+ ✅ Any resolution — resized automatically<br>
554
+ ✅ JPEG or PNG<br><br>
555
+ ⚡ Higher contrast = cleaner segmentation<br>
556
+ ⚡ Draw scribbles in diverse colour areas
557
  </div>
558
  """)
 
559
 
560
+ gr.HTML('<hr class="warm-divider">')
561
+
562
+ # ── STEP 2: Parameters ────────────────────────────────────────────
563
  gr.HTML("""
564
+ <div class="sec-header">
565
+ <span class="step-num">2</span>
566
+ <span class="sec-title-text">Configure Parameters</span>
567
+ </div>
568
  """)
569
  with gr.Row():
570
+ max_dim = gr.Slider(200, 800, value=400, step=50, label="Max Dimension (px)",
571
+ info="Larger = more detail but slower. 400 recommended.")
572
+ iterations = gr.Slider(1, 10, value=3, step=1, label="Iterations",
573
+ info="GMM re-estimation rounds. 3–5 is optimal.")
574
  with gr.Row():
575
+ gamma = gr.Slider(10, 200, value=50, step=5, label="Smoothness γ",
576
+ info="Higher = smoother boundary. Default 50.")
577
+ n_comp = gr.Slider(2, 10, value=5, step=1, label="GMM Components K",
578
+ info="Colour clusters per region. 5 fits most images.")
579
  use_auto = gr.Checkbox(
580
+ label="⚡ Auto Annotation — skip drawing (uses centre/border heuristic)",
581
  value=False,
582
  )
 
583
 
584
+ gr.HTML('<hr class="warm-divider">')
585
+
586
+ # ── STEP 3: Annotate ──────────────────────────────────────────────
587
  gr.HTML("""
588
+ <div class="sec-header">
589
+ <span class="step-num">3</span>
590
+ <span class="sec-title-text">Annotate</span>
591
+ <span class="sec-sub">— skip this step if Auto Annotation is enabled above</span>
592
+ </div>
 
 
593
  """)
594
  with gr.Row():
595
  with gr.Column():
596
+ gr.HTML('<div class="anno-label anno-fg">🟢 FOREGROUND &nbsp;— paint over the object to keep</div>')
597
  fg_editor = gr.ImageEditor(
598
+ label="Foreground Canvas",
599
  show_label=False,
600
  height=380,
601
  brush=gr.Brush(
602
+ default_size=14,
603
  default_color="#00CC44",
604
  colors=["#00CC44", "#00FF00", "#22AA55"],
605
  color_mode="defaults",
606
  ),
607
  )
608
  gr.HTML("""
609
+ <div class="hint-text">
610
+ ✏️ Draw <strong>green strokes</strong> across different parts of the object
611
+ (body, edges, texture areas) for a richer GMM colour model.
 
612
  </div>
613
  """)
614
  with gr.Column():
615
+ gr.HTML('<div class="anno-label anno-bg">🔴 BACKGROUND &nbsp;— paint over background areas</div>')
616
  bg_editor = gr.ImageEditor(
617
+ label="Background Canvas",
618
  show_label=False,
619
  height=380,
620
  brush=gr.Brush(
621
+ default_size=14,
622
  default_color="#FF3333",
623
  colors=["#FF3333", "#CC0000", "#FF6666"],
624
  color_mode="defaults",
625
  ),
626
  )
627
  gr.HTML("""
628
+ <div class="hint-text">
629
+ ✏️ Draw <strong>red strokes</strong> on background regions.
630
+ Cover varied textures (sky, floor, wall…) for better discrimination.
 
631
  </div>
632
  """)
 
633
 
 
634
  img_upload.change(
635
  fn=update_editors,
636
  inputs=img_upload,
637
  outputs=[fg_editor, bg_editor],
638
  )
639
 
640
+ gr.HTML('<hr class="warm-divider">')
641
+
642
+ # ── RUN ───────────────────────────────────────────────────────────
643
  run_btn = gr.Button(
644
+ "▶ Run Graph Cut Segmentation",
645
  elem_id="run-btn",
646
  variant="primary",
647
  )
 
648
 
649
+ gr.HTML('<hr class="warm-divider">')
650
+
651
+ # ── STEP 4: Results ───────────────────────────────────────────────
652
  gr.HTML("""
653
+ <div class="sec-header">
654
+ <span class="step-num">4</span>
655
+ <span class="sec-title-text">Segmentation Results</span>
656
+ </div>
657
  """)
658
  with gr.Row():
659
+ out_annot = gr.Image(label="📌 Input + Annotations", height=260)
660
+ out_raw = gr.Image(label="✂️ Raw Graph Cut", height=260)
661
+ out_refined = gr.Image(label="✨ Refined Graph Cut", height=260)
662
  with gr.Row():
663
+ out_overlay = gr.Image(label="🎨 Overlay on Original", height=260)
664
+ out_extract = gr.Image(label="🖼️ Extracted Foreground", height=260)
665
+ out_otsu = gr.Image(label="📊 Naive: Otsu", height=260)
666
+ out_km = gr.Image(label="📊 Naive: K-Means (k=2)", height=260)
 
667
 
668
+ gr.HTML('<hr class="warm-divider">')
669
+
670
+ # ── STEP 5: Analysis ──────────────────────────────────────────────
671
  gr.HTML("""
672
+ <div class="sec-header">
673
+ <span class="step-num">5</span>
674
+ <span class="sec-title-text">Convergence &amp; Iteration Analysis</span>
675
+ </div>
676
  """)
677
  with gr.Row():
678
+ out_energy = gr.Image(label="📈 Energy Convergence", height=360)
679
+ out_iters = gr.Image(label="🔄 Iterative Mask Progression", height=360)
 
680
 
681
+ # ── Wire ──────────────────────────────────────────────────────────
682
  run_btn.click(
683
  fn=run_segmentation,
684
  inputs=[fg_editor, bg_editor, img_upload,
 
690
  show_progress="full",
691
  )
692
 
693
+ # ��─ Footer ────────────────────────────────────────────────────────
694
  gr.HTML("""
695
+ <div class="footer-wrap">
696
+ CSL7360: Computer Vision &nbsp;·&nbsp; Assignment 2 &nbsp;·&nbsp;
 
697
  Graph Cut Segmentation via PyMaxflow &amp; GMMs
698
  </div>
699
  """)