"""
========================================================================
Graph Cut Image Segmentation Pipeline
CSL7360: Computer Vision — Assignment 2
========================================================================
This module implements a complete Graph Cut segmentation pipeline:
  1. Interactive annotation (scribbles) via OpenCV GUI
  2. Foreground/Background modeling using GMMs
  3. Graph construction with unary (data) and pairwise (smoothness) terms
  4. Min-Cut / Max-Flow optimization using PyMaxflow
  5. Iterative refinement of GMM models and graph cuts
  6. Artifact mitigation: small-region removal, boundary smoothing,
     intensity-consistency correction
  7. Visualization and comparison of results
========================================================================
"""

import numpy as np
import cv2
import maxflow
import os
import argparse
from sklearn.mixture import GaussianMixture
import matplotlib

# "TkAgg" allows figures to be displayed interactively; switch to the
# non-interactive "Agg" backend for headless runs (e.g. no display available).
matplotlib.use("TkAgg")
import matplotlib.pyplot as plt


# =====================================================================
# Section 1: Interactive Annotation Tool
# =====================================================================

class ScribbleAnnotator:
    """
    Interactive GUI for collecting foreground/background scribbles.
    
    Left mouse button  → Foreground (Green)
    Right mouse button → Background (Red)
    Press 'q' or Enter  → Finish annotation
    Press 'r'           → Reset scribbles
    """

    def __init__(self, image: np.ndarray):
        self.image = image.copy()
        self.display = image.copy()
        self.fg_mask = np.zeros(image.shape[:2], dtype=np.uint8)  # foreground
        self.bg_mask = np.zeros(image.shape[:2], dtype=np.uint8)  # background
        self.drawing = False
        self.mode = None  # 'fg' or 'bg'
        self.brush_size = 5

    def _paint(self, x, y):
        """Stamp a filled brush circle of the active label at (x, y)."""
        if self.mode == "fg":
            cv2.circle(self.fg_mask, (x, y), self.brush_size, 1, -1)
            cv2.circle(self.display, (x, y), self.brush_size, (0, 255, 0), -1)
        elif self.mode == "bg":
            cv2.circle(self.bg_mask, (x, y), self.brush_size, 1, -1)
            cv2.circle(self.display, (x, y), self.brush_size, (0, 0, 255), -1)

    def _mouse_callback(self, event, x, y, flags, param):
        """Handle mouse events for drawing scribbles."""
        if event == cv2.EVENT_LBUTTONDOWN:
            self.drawing = True
            self.mode = "fg"
            self._paint(x, y)  # paint on press too, so a single click marks pixels
        elif event == cv2.EVENT_RBUTTONDOWN:
            self.drawing = True
            self.mode = "bg"
            self._paint(x, y)
        elif event == cv2.EVENT_MOUSEMOVE and self.drawing:
            self._paint(x, y)
        elif event in (cv2.EVENT_LBUTTONUP, cv2.EVENT_RBUTTONUP):
            self.drawing = False

    def run(self) -> tuple:
        """
        Launch annotation window. Returns (fg_mask, bg_mask) as binary arrays.
        """
        win = "Annotate: LEFT=FG(green), RIGHT=BG(red), q=done, r=reset"
        cv2.namedWindow(win, cv2.WINDOW_NORMAL)
        cv2.setMouseCallback(win, self._mouse_callback)

        while True:
            cv2.imshow(win, self.display)
            key = cv2.waitKey(1) & 0xFF
            if key in (ord("q"), 13):  # q or Enter
                break
            elif key == ord("r"):
                self.display = self.image.copy()
                self.fg_mask[:] = 0
                self.bg_mask[:] = 0

        cv2.destroyAllWindows()
        return self.fg_mask, self.bg_mask
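
# Sketch (not part of the original pipeline): if the user paints the same pixel
# with both mouse buttons, fg_mask and bg_mask overlap. compute_unary_costs()
# applies the background hard constraint last, so such pixels silently become
# background, and iterative_graph_cut() adds them to both GMM training sets.
# One minimal option, assuming "ambiguous means unconstrained" is acceptable,
# is to clear the overlap from both masks before modelling. The helper name
# resolve_scribble_conflicts is hypothetical.
import numpy as np


def resolve_scribble_conflicts(fg_mask: np.ndarray, bg_mask: np.ndarray) -> tuple:
    """Return copies of (fg_mask, bg_mask) with pixels marked in both cleared."""
    overlap = (fg_mask == 1) & (bg_mask == 1)
    fg_clean = fg_mask.copy()
    bg_clean = bg_mask.copy()
    fg_clean[overlap] = 0   # ambiguous pixels get no hard constraint...
    bg_clean[overlap] = 0   # ...and are excluded from both GMMs
    return fg_clean, bg_clean

# Usage (sketch): fg_mask, bg_mask = resolve_scribble_conflicts(*ScribbleAnnotator(img).run())
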


def load_annotations_from_file(image_shape, fg_path, bg_path):
    """
    Load pre-saved annotation masks from disk (for non-interactive / headless mode).
    Masks should be single-channel images where nonzero = annotated.
    """
    h, w = image_shape[:2]
    fg_mask = np.zeros((h, w), dtype=np.uint8)
    bg_mask = np.zeros((h, w), dtype=np.uint8)

    if os.path.exists(fg_path):
        fg_img = cv2.imread(fg_path, cv2.IMREAD_GRAYSCALE)
        if fg_img is not None:
            fg_mask = (cv2.resize(fg_img, (w, h)) > 127).astype(np.uint8)
    if os.path.exists(bg_path):
        bg_img = cv2.imread(bg_path, cv2.IMREAD_GRAYSCALE)
        if bg_img is not None:
            bg_mask = (cv2.resize(bg_img, (w, h)) > 127).astype(np.uint8)

    return fg_mask, bg_mask


def generate_auto_annotations(image: np.ndarray):
    """
    Automatically generate rough foreground/background scribbles.
    Foreground: a thin cross at the center of the image.
    Background: strips along the image border.
    This is useful for headless / automated runs.
    """
    h, w = image.shape[:2]
    fg_mask = np.zeros((h, w), dtype=np.uint8)
    bg_mask = np.zeros((h, w), dtype=np.uint8)

    # Foreground: a small central cross (mimicking user scribbles), kept tight
    # so that only object pixels are included
    cy, cx = h // 2, w // 2
    rh, rw = h // 10, w // 10          # half-length of each bar: 10% of the image size
    t = max(h // 30, 4)                 # half-thickness of each bar, in pixels
    # Horizontal bar
    fg_mask[cy - t:cy + t, cx - rw:cx + rw] = 1
    # Vertical bar
    fg_mask[cy - rh:cy + rh, cx - t:cx + t] = 1

    # Background: border strips (10% from each edge)
    bh, bw = max(h // 10, 5), max(w // 10, 5)
    bg_mask[:bh, :] = 1
    bg_mask[-bh:, :] = 1
    bg_mask[:, :bw] = 1
    bg_mask[:, -bw:] = 1

    return fg_mask, bg_mask
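
# Sketch: for small images the fixed minimum sizes used above (bar
# half-thickness >= 4 px, border strips >= 5 px) can make the foreground cross
# reach the background border, and when h or w is below 10 the 10% half-length
# rounds to zero, leaving one of the bars empty. The hypothetical helper below
# rebuilds the same geometry and reports whether the two masks share a pixel,
# which a caller could use to fall back to interactive annotation.
import numpy as np


def auto_scribbles_overlap(h: int, w: int) -> bool:
    """Return True if the auto FG cross and BG border strips share any pixel."""
    fg = np.zeros((h, w), dtype=bool)
    bg = np.zeros((h, w), dtype=bool)
    cy, cx = h // 2, w // 2
    rh, rw = h // 10, w // 10
    t = max(h // 30, 4)
    fg[cy - t:cy + t, cx - rw:cx + rw] = True     # horizontal bar
    fg[cy - rh:cy + rh, cx - t:cx + t] = True     # vertical bar
    bh, bw = max(h // 10, 5), max(w // 10, 5)
    bg[:bh, :] = True
    bg[-bh:, :] = True
    bg[:, :bw] = True
    bg[:, -bw:] = True
    return bool((fg & bg).any())

# Worked example: a 300x400 image gives a cross within rows 120-179 and
# columns 160-239, clear of the 30/40-px border strips, so the function returns
# False; for a 12x12 image the horizontal bar reaches the top strip (True).
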


# =====================================================================
# Section 2: Foreground / Background Modeling (GMM)
# =====================================================================

class PixelGMMModel:
    """
    Gaussian Mixture Model for foreground or background pixel distribution.
    
    Fits a GMM to the color values of annotated/labelled pixels and
    returns log-likelihood scores for any query pixel.
    """

    def __init__(self, n_components: int = 5):
        self.n_components = n_components
        self.gmm = GaussianMixture(
            n_components=n_components,
            covariance_type="full",
            max_iter=200,
            random_state=42,
        )
        self.fitted = False

    def fit(self, pixels: np.ndarray):
        """
        Fit GMM to pixel samples. pixels: (N, 3) array of BGR values.
        """
        if len(pixels) < self.n_components:
            # Fall back to fewer components if too few samples
            self.gmm = GaussianMixture(
                n_components=max(1, len(pixels)),
                covariance_type="full",
                max_iter=200,
                random_state=42,
            )
        self.gmm.fit(pixels)
        self.fitted = True

    def score_pixels(self, pixels: np.ndarray) -> np.ndarray:
        """
        Return per-sample log-likelihood. pixels: (N, 3).
        Higher = more likely to belong to this model.
        """
        if not self.fitted:
            return np.zeros(len(pixels))
        return self.gmm.score_samples(pixels)
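
# Sketch: score_samples() of a fitted sklearn GaussianMixture returns, per
# sample, the mixture log-density
#   log p(x) = logsumexp_k [log w_k + log N(x | mu_k, Sigma_k)].
# The NumPy function below recomputes that quantity from explicit parameters
# (for a fitted model with covariance_type="full" they would be gmm.weights_,
# gmm.means_ and gmm.covariances_). It only illustrates what the unary costs
# are built from; gmm_log_density is a hypothetical name, not an sklearn API.
import numpy as np


def gmm_log_density(x, weights, means, covs) -> np.ndarray:
    """Log-density of a full-covariance GMM for samples x of shape (N, D)."""
    x = np.atleast_2d(np.asarray(x, dtype=np.float64))
    n, d = x.shape
    comp = np.empty((n, len(weights)))
    for k, (w_k, mu, cov) in enumerate(zip(weights, means, covs)):
        chol = np.linalg.cholesky(cov)                 # Sigma_k = L L^T
        z = np.linalg.solve(chol, (x - mu).T)          # L^{-1} (x - mu), shape (D, N)
        maha = np.sum(z ** 2, axis=0)                  # squared Mahalanobis distance
        log_det = 2.0 * np.sum(np.log(np.diag(chol)))  # log |Sigma_k|
        comp[:, k] = np.log(w_k) - 0.5 * (d * np.log(2.0 * np.pi) + log_det + maha)
    # Log-sum-exp over components, stabilised by subtracting the row maximum
    m = comp.max(axis=1, keepdims=True)
    return (m + np.log(np.exp(comp - m).sum(axis=1, keepdims=True))).ravel()
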


def build_gmm_models(image: np.ndarray, fg_mask: np.ndarray, bg_mask: np.ndarray,
                      n_components: int = 5):
    """
    Build foreground and background GMMs from annotated pixels.
    Returns (fg_model, bg_model).
    """
    fg_pixels = image[fg_mask == 1].reshape(-1, 3).astype(np.float64)
    bg_pixels = image[bg_mask == 1].reshape(-1, 3).astype(np.float64)

    fg_model = PixelGMMModel(n_components)
    bg_model = PixelGMMModel(n_components)

    if len(fg_pixels) > 0:
        fg_model.fit(fg_pixels)
    if len(bg_pixels) > 0:
        bg_model.fit(bg_pixels)

    return fg_model, bg_model


# =====================================================================
# Section 3: Energy Formulation & Graph Construction
# =====================================================================

def compute_unary_costs(image: np.ndarray, fg_model: PixelGMMModel,
                        bg_model: PixelGMMModel,
                        fg_mask: np.ndarray, bg_mask: np.ndarray,
                        hard_constraint_weight: float = 1e9) -> tuple:
    """
    Compute unary (data) costs for each pixel.
    
    E_data(x_p) = -log P(I_p | label)
    
    For annotated pixels, we assign a very high cost to the opposite label
    (hard constraints).
    
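    Returns:
        source_cap: (H, W) capacities of the S->pixel edges, i.e. the cost paid
                    when the pixel is labelled background (sink side)
        sink_cap:   (H, W) capacities of the pixel->T edges, i.e. the cost paid
                    when the pixel is labelled foreground (source side)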
    """
    h, w = image.shape[:2]
    pixels = image.reshape(-1, 3).astype(np.float64)

    # Log-likelihoods from GMMs
    fg_ll = fg_model.score_pixels(pixels).reshape(h, w)
    bg_ll = bg_model.score_pixels(pixels).reshape(h, w)

    # Convert log-likelihoods to costs: cost = -log P(I_p | label).
    #
    # PyMaxflow convention (S = foreground, T = background):
    #   - a pixel that ends on the source side has its pixel->T edge cut,
    #     so labelling it foreground costs sink_cap;
    #   - a pixel that ends on the sink side has its S->pixel edge cut,
    #     so labelling it background costs source_cap.
    # Hence:
    #   source_cap = -log P(I_p | BG)  (high when the pixel is unlikely to be BG)
    #   sink_cap   = -log P(I_p | FG)  (high when the pixel is unlikely to be FG)

    source_cap = -bg_ll  # penalty for assigning to background
    sink_cap = -fg_ll    # penalty for assigning to foreground

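    # Shift to ensure non-negative capacities. Every pixel pays exactly one of
    # its two terminal costs, so adding the same constant to all of them raises
    # every cut by the same amount and leaves the minimum cut unchanged.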
    min_val = min(source_cap.min(), sink_cap.min())
    if min_val < 0:
        source_cap -= min_val
        sink_cap -= min_val

    # Hard constraints for annotated pixels
    source_cap[fg_mask == 1] = hard_constraint_weight
    sink_cap[fg_mask == 1] = 0

    source_cap[bg_mask == 1] = 0
    sink_cap[bg_mask == 1] = hard_constraint_weight

    return source_cap, sink_cap
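
# Sketch: a brute-force check of the terminal-edge convention used above, on a
# tiny 1-D chain. A pixel labelled foreground pays sink_cap, a background pixel
# pays source_cap, and each pair of neighbors with different labels pays
# smooth_w. Enumerating all 2^n labellings (feasible only for a handful of
# pixels) shows that (a) each pixel prefers the model under which it is more
# likely, and (b) adding one constant to every capacity, as the non-negativity
# shift does, leaves the best labelling unchanged. brute_force_labels and the
# chain setting are assumptions made for illustration.
import itertools


def brute_force_labels(source_cap, sink_cap, smooth_w: float = 0.0):
    """Return (labels, cost) minimising the graph-cut energy on a 1-D chain (1 = FG)."""
    n = len(source_cap)
    best, best_cost = None, float("inf")
    for labels in itertools.product((0, 1), repeat=n):
        cost = sum(sink_cap[i] if lab == 1 else source_cap[i]
                   for i, lab in enumerate(labels))
        cost += smooth_w * sum(labels[i] != labels[i + 1] for i in range(n - 1))
        if cost < best_cost:
            best, best_cost = labels, cost
    return best, best_cost

# Example: fg_ll = [-2, -8] and bg_ll = [-9, -1] give source_cap = -bg_ll = [9, 1]
# and sink_cap = -fg_ll = [2, 8]; the optimum is (1, 0) with cost 3.
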


def compute_pairwise_costs(image: np.ndarray, beta: float = None,
                            gamma: float = 50.0) -> tuple:
    """
    Compute pairwise (smoothness) costs between neighboring pixels.
    
    E_smooth(x_p, x_q) = gamma * exp(-beta * ||I_p - I_q||^2)  if x_p ≠ x_q
                        = 0                                      if x_p == x_q
    
    beta = 1 / (2 * <||I_p - I_q||^2>)  (average over all neighbor pairs)
    
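    Weights are computed for the 4-connected neighborhood; each undirected
    edge is stored once, via its right or down neighbor.

    Returns:
        right_weights: (H, W-1) weight of the edge between (i, j) and (i, j+1)
        down_weights:  (H-1, W) weight of the edge between (i, j) and (i+1, j)
        beta:          the contrast parameter that was used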
    """
    img = image.astype(np.float64)
    h, w = img.shape[:2]

    # Compute differences for right and down neighbors
    diff_right = img[:, 1:, :] - img[:, :-1, :]   # (H, W-1, 3)
    diff_down = img[1:, :, :] - img[:-1, :, :]     # (H-1, W, 3)

    dist_right = np.sum(diff_right ** 2, axis=2)    # (H, W-1)
    dist_down = np.sum(diff_down ** 2, axis=2)      # (H-1, W)

    # Compute beta from average squared color distance
    if beta is None:
        total_sum = dist_right.sum() + dist_down.sum()
        total_count = dist_right.size + dist_down.size
        avg_dist = total_sum / total_count if total_count > 0 else 1.0
        beta = 1.0 / (2.0 * avg_dist) if avg_dist > 0 else 0.0

    # Smoothness weights
    right_weights = gamma * np.exp(-beta * dist_right)
    down_weights = gamma * np.exp(-beta * dist_down)

    return right_weights, down_weights, beta
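
# Sketch: a worked example of the contrast term on a single row of pixels.
# Because beta = 1 / (2 <||I_p - I_q||^2>), a neighbor pair whose squared color
# difference equals the average receives weight gamma * exp(-1/2) (about
# 0.61 * gamma), identical neighbors receive the full gamma, and sharper than
# average edges become cheap to cut. row_pairwise_weights is a hypothetical,
# simplified 1-D restatement of compute_pairwise_costs().
import numpy as np


def row_pairwise_weights(row, gamma: float = 50.0):
    """Return (weights, beta) for horizontal neighbors in a (W, 3) pixel row."""
    row = np.asarray(row, dtype=np.float64)
    dist = np.sum(np.diff(row, axis=0) ** 2, axis=1)   # squared color differences
    avg = dist.mean() if dist.size else 0.0
    beta = 1.0 / (2.0 * avg) if avg > 0 else 0.0
    return gamma * np.exp(-beta * dist), beta

# Example: the row [[0,0,0], [10,0,0], [20,0,0]] has squared differences
# [100, 100], so beta = 0.005 and both weights equal 50 * exp(-0.5), about 30.33.
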


def build_graph_and_cut(source_cap: np.ndarray, sink_cap: np.ndarray,
                         right_weights: np.ndarray, down_weights: np.ndarray) -> np.ndarray:
    """
    Construct the graph using PyMaxflow and solve the min-cut / max-flow.
    
    Graph structure:
      - Source node S represents Foreground
      - Sink node T represents Background
      - Each pixel is a node
      - Terminal edges: S→pixel (source_cap), pixel→T (sink_cap)
      - Neighbor edges: between adjacent pixels (pairwise smoothness)
    
    The min-cut partitions pixels into S-set (foreground) and T-set (background).
    
    Returns:
        labels: (H, W) binary mask — 1 = foreground, 0 = background
    """
    h, w = source_cap.shape

    # Create graph (the arguments are size hints for the number of nodes and
    # of neighbor edges)
    g = maxflow.Graph[float](h * w, h * w * 2)
    g.add_nodes(h * w)

    # Add terminal edges (unary / data costs)
    for i in range(h):
        for j in range(w):
            idx = i * w + j
            g.add_tedge(idx, source_cap[i, j], sink_cap[i, j])

    # Add pairwise (smoothness) edges — 4-connected neighborhood
    # Right neighbors
    for i in range(h):
        for j in range(w - 1):
            idx1 = i * w + j
            idx2 = i * w + (j + 1)
            weight = right_weights[i, j]
            g.add_edge(idx1, idx2, weight, weight)

    # Down neighbors
    for i in range(h - 1):
        for j in range(w):
            idx1 = i * w + j
            idx2 = (i + 1) * w + j
            weight = down_weights[i, j]
            g.add_edge(idx1, idx2, weight, weight)

    # Solve min-cut / max-flow
    flow = g.maxflow()
    print(f"    Max-flow value: {flow:.2f}")

    # get_segment() returns 0 for the source side (FG) and 1 for the sink side
    # (BG); flip so that 1 = foreground and 0 = background.
    segments = np.array([g.get_segment(idx) for idx in range(h * w)])
    labels = 1 - segments.reshape(h, w)

    return labels
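
# Sketch: build_graph_and_cut() adds edges in Python loops, which is slow for
# large images. Since pixel (i, j) has node id i * w + j, all 4-connected edges
# can be listed with array slicing; the weights line up because right_weights
# has shape (H, W-1) and down_weights has shape (H-1, W). The resulting arrays
# could feed a single loop over add_edge() or another solver. PyMaxflow also
# offers grid helpers (add_grid_nodes, add_grid_edges, add_grid_tedges) whose
# exact semantics should be checked in its documentation. grid_edges_4 is a
# hypothetical helper.
import numpy as np


def grid_edges_4(h: int, w: int, right_weights, down_weights):
    """Return (src, dst, weight) arrays for every 4-connected edge of an h x w grid."""
    idx = np.arange(h * w).reshape(h, w)                      # node id per pixel
    src = np.concatenate([idx[:, :-1].ravel(), idx[:-1, :].ravel()])
    dst = np.concatenate([idx[:, 1:].ravel(), idx[1:, :].ravel()])
    wts = np.concatenate([np.asarray(right_weights).ravel(),
                          np.asarray(down_weights).ravel()])
    return src, dst, wts

# Usage (sketch): for a, b, c in zip(src, dst, wts): g.add_edge(int(a), int(b), c, c)
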


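# Sketch: a dependency-free cross-check of the graph construction. The
# hypothetical tiny_grid_min_cut() builds the same network as
# build_graph_and_cut() (S->pixel with source_cap, pixel->T with sink_cap,
# symmetric neighbor edges) but solves it with the textbook Edmonds-Karp
# algorithm (shortest augmenting paths) instead of the Boykov-Kolmogorov solver
# used by PyMaxflow. Pixels still reachable from S in the final residual graph
# form the source side and are labelled 1 (foreground). It stores a dense
# capacity matrix, so it is only meant for grids of a few dozen pixels.
from collections import deque

import numpy as np


def tiny_grid_min_cut(source_cap, sink_cap, right_w, down_w, eps=1e-12):
    """Return (labels, max_flow) for a small grid; labels use 1 = FG, 0 = BG."""
    h, w = source_cap.shape
    n = h * w
    s, t = n, n + 1                                   # terminal node ids
    cap = [[0.0] * (n + 2) for _ in range(n + 2)]     # residual capacities
    for i in range(h):
        for j in range(w):
            p = i * w + j
            cap[s][p] += float(source_cap[i, j])
            cap[p][t] += float(sink_cap[i, j])
            if j + 1 < w:                             # right neighbor edge
                cap[p][p + 1] += float(right_w[i, j])
                cap[p + 1][p] += float(right_w[i, j])
            if i + 1 < h:                             # down neighbor edge
                cap[p][p + w] += float(down_w[i, j])
                cap[p + w][p] += float(down_w[i, j])

    def bfs_parents():
        """Breadth-first search from s over edges with spare capacity."""
        parent = [-1] * (n + 2)
        parent[s] = s
        queue = deque([s])
        while queue:
            u = queue.popleft()
            for v in range(n + 2):
                if parent[v] == -1 and cap[u][v] > eps:
                    parent[v] = u
                    queue.append(v)
        return parent

    flow = 0.0
    while True:
        parent = bfs_parents()
        if parent[t] == -1:                           # no augmenting path left
            break
        bottleneck, v = float("inf"), t
        while v != s:
            bottleneck = min(bottleneck, cap[parent[v]][v])
            v = parent[v]
        v = t
        while v != s:                                 # push flow along the path
            u = parent[v]
            cap[u][v] -= bottleneck
            cap[v][u] += bottleneck
            v = u
        flow += bottleneck

    reachable = bfs_parents()                         # source side of the min cut
    labels = np.array([reachable[p] != -1 for p in range(n)], dtype=np.uint8)
    return labels.reshape(h, w), flow
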
# =====================================================================
# Section 4: Iterative Graph Cut Optimization
# =====================================================================

def iterative_graph_cut(image: np.ndarray, fg_mask: np.ndarray, bg_mask: np.ndarray,
                         n_iterations: int = 3, n_components: int = 5,
                         gamma: float = 50.0) -> tuple:
    """
    Perform iterative graph cut segmentation:
      1. Build initial GMMs from user scribbles.
      2. Construct graph and compute min-cut.
      3. Update GMMs using newly labelled pixels.
      4. Repeat for n_iterations.
    
    Returns:
        final_mask: (H, W) binary segmentation from the lowest-energy iteration
        all_masks:  list of masks at each iteration (for comparison)
        energies:   list of energy values per iteration
    """
    h, w = image.shape[:2]
    current_fg_mask = fg_mask.copy()
    current_bg_mask = bg_mask.copy()
    all_masks = []
    energies = []

    for it in range(n_iterations):
        print(f"  Iteration {it + 1}/{n_iterations}")

        # Step 1: Build / Update GMMs
        fg_model, bg_model = build_gmm_models(image, current_fg_mask, current_bg_mask,
                                                n_components)

        # Step 2: Compute unary costs
        source_cap, sink_cap = compute_unary_costs(image, fg_model, bg_model,
                                                     fg_mask, bg_mask)

        # Step 3: Compute pairwise costs
        right_w, down_w, beta = compute_pairwise_costs(image, gamma=gamma)

        # Step 4: Build graph and solve min-cut
        labels = build_graph_and_cut(source_cap, sink_cap, right_w, down_w)
        all_masks.append(labels.copy())

        # Compute energy for monitoring convergence
        pixels = image.reshape(-1, 3).astype(np.float64)
        fg_ll = fg_model.score_pixels(pixels).reshape(h, w)
        bg_ll = bg_model.score_pixels(pixels).reshape(h, w)
        data_energy = -np.sum(fg_ll[labels == 1]) - np.sum(bg_ll[labels == 0])

        # Smoothness energy: sum of pairwise weights over edges whose endpoints
        # have different labels
        diff_h = (labels[:, 1:] != labels[:, :-1]).astype(float)
        diff_v = (labels[1:, :] != labels[:-1, :]).astype(float)
        smooth_energy = np.sum(diff_h * right_w) + np.sum(diff_v * down_w)
        total_energy = data_energy + smooth_energy
        energies.append(total_energy)
        print(f"    Energy: {total_energy:.2f} (data={data_energy:.2f}, smooth={smooth_energy:.2f})")

        # Step 5: Update masks for next iteration
        current_fg_mask = labels.copy()
        current_bg_mask = (1 - labels).copy()
        # Preserve hard constraints from user annotations
        current_fg_mask[fg_mask == 1] = 1
        current_bg_mask[bg_mask == 1] = 1

    # Return the mask from the lowest-energy iteration (not necessarily the last)
    best_iter = int(np.argmin(energies))
    print(f"  Best iteration: {best_iter + 1} (energy={energies[best_iter]:.2f})")
    return all_masks[best_iter], all_masks, energies
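
# Sketch (an optional addition, not in the original): iterative_graph_cut()
# always runs n_iterations rounds. A common alternative is to stop once the
# segmentation stops changing; the hypothetical helper below measures the
# fraction of pixels whose label flipped between two consecutive masks, and a
# caller could break out of the loop when it falls below a tolerance such as
# 0.1% (the tolerance is an assumption to tune per dataset).
import numpy as np


def mask_change_fraction(prev_mask: np.ndarray, curr_mask: np.ndarray) -> float:
    """Fraction of pixels whose binary label differs between two masks."""
    return float(np.mean(prev_mask.astype(bool) != curr_mask.astype(bool)))

# Inside the loop, after energies.append(total_energy):
#     if it > 0 and mask_change_fraction(all_masks[-2], all_masks[-1]) < 1e-3:
#         break
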


# =====================================================================
# Section 5: Artifact Mitigation & Refinement
# =====================================================================

def remove_small_regions(mask: np.ndarray, min_area: int = 500) -> np.ndarray:
    """
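    Remove small isolated foreground regions and fill small background
    regions (holes) using 8-connected component analysis. Any component with
    fewer than min_area pixels is relabelled; small background regions that
    touch the image border are filled as well.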
    """
    cleaned = mask.copy().astype(np.uint8)

    # Remove small foreground regions
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(cleaned, connectivity=8)
    for i in range(1, num_labels):
        if stats[i, cv2.CC_STAT_AREA] < min_area:
            cleaned[labels == i] = 0

    # Remove small background holes (invert, clean, invert back)
    inv = 1 - cleaned
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(inv, connectivity=8)
    for i in range(1, num_labels):
        if stats[i, cv2.CC_STAT_AREA] < min_area:
            cleaned[labels == i] = 1

    return cleaned
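
# Sketch: remove_small_regions() clears each small component with its own
# full-image comparison (labels == i), which costs O(components x pixels).
# Given the label image from cv2.connectedComponentsWithStats (or any integer
# component labelling where 0 marks pixels outside the analysed mask), the
# same relabelling can be done in one pass with a per-component lookup table;
# np.bincount stands in for the CC_STAT_AREA column here. The helper name
# drop_small_components is hypothetical.
import numpy as np


def drop_small_components(mask: np.ndarray, labels: np.ndarray, min_area: int,
                          fill_value: int) -> np.ndarray:
    """Set pixels of components (label > 0) smaller than min_area to fill_value."""
    areas = np.bincount(labels.ravel())       # pixel count per component id
    small = areas < min_area                   # lookup table: id -> is small?
    small[0] = False                           # label 0 marks pixels outside the mask
    out = mask.copy()
    out[small[labels]] = fill_value            # fancy indexing applies the table
    return out

# Usage (sketch): drop_small_components(cleaned, labels, min_area, 0) for the
# foreground pass, and drop_small_components(cleaned, labels_of_inverse,
# min_area, 1) for the hole-filling pass.
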


def smooth_boundaries(mask: np.ndarray, ksize: int = 5) -> np.ndarray:
    """
    Smooth jagged segmentation boundaries: morphological closing (fills small
    gaps), opening (removes thin protrusions), then Gaussian blur and
    re-thresholding at half intensity.
    """
    m = mask.astype(np.uint8) * 255
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize))

    # Close small gaps
    m = cv2.morphologyEx(m, cv2.MORPH_CLOSE, kernel, iterations=1)
    # Open to remove thin protrusions
    m = cv2.morphologyEx(m, cv2.MORPH_OPEN, kernel, iterations=1)
    # Gaussian blur + threshold for smooth boundary
    m = cv2.GaussianBlur(m, (ksize * 2 + 1, ksize * 2 + 1), 0)
    m = (m > 127).astype(np.uint8)

    return m


def ensure_intensity_consistency(mask: np.ndarray, image: np.ndarray,
                                  threshold: float = 30.0) -> np.ndarray:
    """
    Intensity Consistency: re-label pixels near the boundary whose color is
    significantly closer to the opposite region's mean color.

    For each foreground pixel in a band around the boundary, if its Euclidean
    color distance to the background mean is smaller than its distance to the
    foreground mean by more than `threshold`, flip it to background (and vice
    versa). Both means are computed once from the input mask. This corrects
    visually incoherent pixels that slipped through the graph cut due to weak
    data terms.
    """
    refined = mask.copy().astype(np.uint8)
    img_f = image.astype(np.float32)

    fg_pixels = img_f[refined == 1]
    bg_pixels = img_f[refined == 0]

    if len(fg_pixels) == 0 or len(bg_pixels) == 0:
        return refined

    fg_mean = fg_pixels.mean(axis=0)   # mean FG color (BGR)
    bg_mean = bg_pixels.mean(axis=0)   # mean BG color (BGR)

    # Build a narrow band around the boundary (dilation minus erosion)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
    dilated = cv2.dilate(refined, kernel, iterations=1)
    eroded  = cv2.erode(refined,  kernel, iterations=1)
    band    = (dilated - eroded).astype(bool)   # True = boundary pixels

    for label, correct_mean, wrong_mean in [(1, fg_mean, bg_mean),
                                             (0, bg_mean, fg_mean)]:
        region_band = band & (refined == label)
        coords = np.argwhere(region_band)
        for (r, c) in coords:
            color = img_f[r, c]
            d_correct = float(np.linalg.norm(color - correct_mean))
            d_wrong   = float(np.linalg.norm(color - wrong_mean))
            # Flip only when the pixel is clearly closer to the opposite mean
            if d_wrong < d_correct - threshold:
                refined[r, c] = 1 - label

    return refined
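
# Sketch: a vectorised equivalent of the per-pixel loop above, assuming the
# boundary band has already been computed (e.g. the dilated minus eroded mask,
# as ensure_intensity_consistency() does). Distances to both means are
# evaluated for every pixel at once and both flip sets are taken from the
# unmodified mask. For a non-negative threshold this matches the loop: a pixel
# flipped to background satisfies d_bg < d_fg - threshold, so it can never meet
# the condition for flipping back. flip_band_pixels is a hypothetical name.
import numpy as np


def flip_band_pixels(mask, image, band, threshold: float = 30.0) -> np.ndarray:
    """Flip band pixels clearly closer to the opposite region's mean color."""
    refined = mask.astype(np.uint8).copy()
    img = image.astype(np.float32)
    fg, bg = refined == 1, refined == 0
    if not fg.any() or not bg.any():
        return refined
    d_fg = np.linalg.norm(img - img[fg].mean(axis=0), axis=-1)   # distance to FG mean
    d_bg = np.linalg.norm(img - img[bg].mean(axis=0), axis=-1)   # distance to BG mean
    to_bg = band & fg & (d_bg < d_fg - threshold)
    to_fg = band & bg & (d_fg < d_bg - threshold)
    refined[to_bg] = 0
    refined[to_fg] = 1
    return refined
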


def refine_segmentation(mask: np.ndarray, image: np.ndarray,
                          min_area: int = None, smooth_ksize: int = 3) -> np.ndarray:
    """
    Full refinement pipeline:
      1. Remove small isolated regions (morphological noise removal)
      2. Smooth jagged boundaries
      3. Intensity consistency correction near boundaries
    min_area defaults to 0.1% of the image pixels (at least 50) so that it
    scales with image size; smooth_ksize defaults to 3 to avoid distorting
    fine structures.
    """
    print("  Refining segmentation...")
    if min_area is None:
        min_area = max(50, int(mask.size * 0.001))  # 0.1% of pixels
    refined = remove_small_regions(mask, min_area)
    refined = smooth_boundaries(refined, smooth_ksize)
    refined = ensure_intensity_consistency(refined, image)
    return refined


# =====================================================================
# Section 6: Naive Segmentation (for comparison)
# =====================================================================

def naive_thresholding_segmentation(image: np.ndarray) -> np.ndarray:
    """
    Simple Otsu thresholding as a naive baseline for comparison.
    Returns raw Otsu mask; label alignment to graph cut is done after
    graph cut is computed (see align_naive_to_graphcut).
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    _, mask = cv2.threshold(gray, 0, 1, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return mask


def align_naive_to_graphcut(naive_mask: np.ndarray,
                             reference_mask: np.ndarray) -> np.ndarray:
    """
    Align a naive mask's label convention to match the graph cut reference.
    Checks whether the mask or its inverse has more overlap with the reference,
    and returns whichever agrees more. This handles cases where Otsu/K-Means
    assign FG=1 to the bright region while graph cut assigns FG=1 to the object.
    """
    overlap_normal  = np.sum(naive_mask == reference_mask)
    overlap_inverted = np.sum((1 - naive_mask) == reference_mask)
    if overlap_inverted > overlap_normal:
        return 1 - naive_mask
    return naive_mask
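
# Sketch: the pipeline compares the naive baselines with graph cut visually.
# Where a quantitative agreement score is wanted (for example against a
# hand-drawn ground-truth mask, which this module does not provide), two common
# overlap measures are intersection-over-union and the Dice coefficient,
# Dice = 2|A & B| / (|A| + |B|). iou_and_dice is a hypothetical helper.
import numpy as np


def iou_and_dice(mask_a: np.ndarray, mask_b: np.ndarray) -> tuple:
    """Return (IoU, Dice) of two binary masks; both are 1.0 when both masks are empty."""
    a, b = mask_a.astype(bool), mask_b.astype(bool)
    inter = np.logical_and(a, b).sum()
    union = np.logical_or(a, b).sum()
    total = a.sum() + b.sum()
    iou = inter / union if union else 1.0
    dice = 2.0 * inter / total if total else 1.0
    return float(iou), float(dice)
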


def naive_kmeans_segmentation(image: np.ndarray, k: int = 2) -> np.ndarray:
    """
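    K-Means clustering on BGR colors as another naive baseline.
    The brighter cluster is labelled 1 and the darker 0; this relabelling
    assumes k == 2. Results can vary between runs because the initial
    centers are random.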
    """
    pixels = image.reshape(-1, 3).astype(np.float32)
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.2)
    _, labels, centers = cv2.kmeans(pixels, k, None, criteria, 10,
                                     cv2.KMEANS_RANDOM_CENTERS)
    # Assign the darker cluster as background
    labels = labels.reshape(image.shape[:2])
    if centers[0].mean() > centers[1].mean():
        labels = 1 - labels
    return labels.astype(np.uint8)


# =====================================================================
# Section 7: Visualization
# =====================================================================

def create_overlay(image: np.ndarray, mask: np.ndarray,
                    color: tuple = (0, 255, 0), alpha: float = 0.4) -> np.ndarray:
    """
    Overlay a segmentation mask on the original image.
    """
    overlay = image.copy()
    colored = np.zeros_like(image)
    colored[:] = color
    region = mask.astype(bool)
    if region.any():  # blend only when the mask selects at least one pixel
        overlay[region] = cv2.addWeighted(image[region], 1 - alpha,
                                           colored[region], alpha, 0)
    # Draw contours
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL,
                                    cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(overlay, contours, -1, color, 2)
    return overlay


def visualize_results(image: np.ndarray, fg_mask: np.ndarray, bg_mask: np.ndarray,
                       raw_mask: np.ndarray, refined_mask: np.ndarray,
                       naive_mask: np.ndarray, naive_kmeans_mask: np.ndarray,
                       all_iter_masks: list, energies: list,
                       output_dir: str, img_name: str):
    """
    Generate comprehensive visualization of all results and save to disk.
    """
    # --- Figure 1: Main comparison (2×4 grid) ---
    fig, axes = plt.subplots(2, 4, figsize=(24, 12))
    fig.suptitle(f"Graph Cut Segmentation — {img_name}", fontsize=16, fontweight="bold")

    # Original + scribbles
    scribble_vis = image.copy()
    scribble_vis[fg_mask == 1] = [0, 255, 0]
    scribble_vis[bg_mask == 1] = [0, 0, 255]
    axes[0, 0].imshow(cv2.cvtColor(scribble_vis, cv2.COLOR_BGR2RGB))
    axes[0, 0].set_title("Input + Annotations")
    axes[0, 0].axis("off")

    # Naive segmentation — Otsu
    axes[0, 1].imshow(naive_mask, cmap="gray")
    axes[0, 1].set_title("Naive: Otsu Thresholding")
    axes[0, 1].axis("off")

    # Naive segmentation — K-Means
    axes[0, 2].imshow(naive_kmeans_mask, cmap="gray")
    axes[0, 2].set_title("Naive: K-Means (k=2)")
    axes[0, 2].axis("off")

    # Raw graph cut
    axes[0, 3].imshow(raw_mask, cmap="gray")
    axes[0, 3].set_title("Raw Graph Cut")
    axes[0, 3].axis("off")

    # Refined mask
    axes[1, 0].imshow(refined_mask, cmap="gray")
    axes[1, 0].set_title("Refined Graph Cut")
    axes[1, 0].axis("off")

    # Overlay on original
    overlay = create_overlay(image, refined_mask)
    axes[1, 1].imshow(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))
    axes[1, 1].set_title("Overlay on Original")
    axes[1, 1].axis("off")

    # Extracted foreground
    extracted = image.copy()
    extracted[refined_mask == 0] = [255, 255, 255]
    axes[1, 2].imshow(cv2.cvtColor(extracted, cv2.COLOR_BGR2RGB))
    axes[1, 2].set_title("Extracted Foreground")
    axes[1, 2].axis("off")

    # Side-by-side comparison: Otsu vs refined graph cut, separated by a white divider
    h_cmp = naive_mask.shape[0]
    gap = np.full((h_cmp, 10), 255, dtype=np.uint8)  # white divider
    compare = np.hstack([
        naive_mask.astype(np.uint8) * 255,
        gap,
        refined_mask.astype(np.uint8) * 255
    ])
    axes[1, 3].imshow(compare, cmap="gray")
    axes[1, 3].set_title("Otsu vs Graph Cut (side-by-side)")
    axes[1, 3].axis("off")

    plt.tight_layout()
    fig.savefig(os.path.join(output_dir, f"{img_name}_results.png"), dpi=150,
                bbox_inches="tight")
    plt.close(fig)

    # --- Figure 2: Iteration progression ---
    if len(all_iter_masks) > 1:
        n = len(all_iter_masks)
        fig2, axes2 = plt.subplots(1, n + 1, figsize=(5 * (n + 1), 5))
        fig2.suptitle(f"Iterative Refinement — {img_name}", fontsize=14)

        for i, m in enumerate(all_iter_masks):
            axes2[i].imshow(m, cmap="gray", vmin=0, vmax=1)
            axes2[i].set_title(f"Iteration {i + 1}")
            axes2[i].axis("off")
        axes2[n].imshow(refined_mask, cmap="gray", vmin=0, vmax=1)
        axes2[n].set_title("After Post-Processing")
        axes2[n].axis("off")

        plt.tight_layout()
        fig2.savefig(os.path.join(output_dir, f"{img_name}_iterations.png"), dpi=150,
                     bbox_inches="tight")
        plt.close(fig2)

    # --- Figure 3: Energy convergence ---
    if len(energies) > 1:
        fig3, ax3 = plt.subplots(figsize=(8, 5))
        ax3.plot(range(1, len(energies) + 1), energies, "bo-", linewidth=2, markersize=8)
        ax3.set_xlabel("Iteration", fontsize=12)
        ax3.set_ylabel("Total Energy", fontsize=12)
        ax3.set_title(f"Energy Convergence — {img_name}", fontsize=14)
        ax3.grid(True, alpha=0.3)
        fig3.savefig(os.path.join(output_dir, f"{img_name}_energy.png"), dpi=150,
                     bbox_inches="tight")
        plt.close(fig3)

    print(f"  Visualizations saved to {output_dir}/")

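
# --- Sketch (not called by the pipeline): quantifying the Otsu vs Graph Cut panel ---
# The side-by-side panel in Figure 1 compares the two masks only visually. The
# hypothetical helper below is one way to put a number on that comparison. It
# assumes both masks are binary arrays of the same shape in which 1 means
# foreground, i.e. the naive mask has already been passed through
# align_naive_to_graphcut. It reports the intersection-over-union (IoU) of the
# two foreground regions and the fraction of pixels on which they agree.
import numpy as np  # already imported at the top of the script; repeated so the sketch stands alone


def mask_agreement(mask_a, mask_b):
    """Return {"iou": ..., "agreement": ...} for two same-shaped binary masks."""
    a = np.asarray(mask_a).astype(bool)
    b = np.asarray(mask_b).astype(bool)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    inter = np.logical_and(a, b).sum()
    union = np.logical_or(a, b).sum()
    # Assumed convention: two empty foregrounds count as a perfect match.
    iou = 1.0 if union == 0 else float(inter) / float(union)
    agreement = float(np.mean(a == b))
    return {"iou": iou, "agreement": agreement}

# Usage (hypothetical): mask_agreement(naive_mask, refined_mask)["iou"]
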

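# --- Sketch (not called by the pipeline): reading the energy convergence plot ---
# Figure 3 plots one total-energy value per iteration. In GrabCut-style
# alternation (refit the color models, then re-cut) each step is expected not
# to increase the energy, but whether iterative_graph_cut guarantees this
# depends on how it refits its GMMs, which is not shown here. The hypothetical
# helper below only flags iterations where the energy rose by more than a
# small relative tolerance (an arbitrary choice), so they can be inspected.
def energy_increases(energies, rel_tol=1e-6):
    """Return the 1-based iteration numbers whose energy exceeds the previous one."""
    flagged = []
    for i in range(1, len(energies)):
        prev, cur = float(energies[i - 1]), float(energies[i])
        if cur > prev + rel_tol * max(abs(prev), 1.0):
            flagged.append(i + 1)  # matches the plot's x-axis, which starts at 1
    return flagged

# Usage (hypothetical): energy_increases([10.0, 8.0, 9.0]) returns [3]

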
# =====================================================================
# Section 8: Full Pipeline
# =====================================================================

def run_pipeline(image_path: str, output_dir: str = "outputs",
                  n_iterations: int = 3, n_components: int = 5,
                  gamma: float = 50.0, interactive: bool = True,
                  fg_anno_path: str = None, bg_anno_path: str = None,
                  auto_annotate: bool = False,
                  max_dim: int = 400):
    """
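    Run the complete Graph Cut segmentation pipeline on a single image.

    Annotation sources are tried in this order: the interactive GUI, the
    annotation mask files (if both paths are given), then automatic
    center/border annotations. Automatic annotation is also the fallback when
    no source is given or when either scribble set (FG or BG) is empty.

    Parameters:
        image_path:     Path to input image
        output_dir:     Directory to save results
        n_iterations:   Number of graph-cut iterations
        n_components:   Number of GMM components per model
        gamma:          Smoothness weight
        interactive:    If True, open GUI for scribble annotation
        fg_anno_path:   Path to pre-made foreground annotation mask
        bg_anno_path:   Path to pre-made background annotation mask
        auto_annotate:  If True, generate automatic center/border annotations
        max_dim:        Resize image so largest dimension ≤ max_dim (for speed)

    Returns:
        The refined binary mask, or None if the image could not be loaded.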
    """
    os.makedirs(output_dir, exist_ok=True)
    img_name = os.path.splitext(os.path.basename(image_path))[0]

    print(f"\n{'='*60}")
    print(f"Processing: {image_path}")
    print(f"{'='*60}")

    # Load image
    image = cv2.imread(image_path)
    if image is None:
        print(f"ERROR: Could not load image '{image_path}'")
        return

    # Resize for tractability
    h, w = image.shape[:2]
    if max(h, w) > max_dim:
        scale = max_dim / max(h, w)
        image = cv2.resize(image, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
        print(f"  Resized from {w}x{h} to {image.shape[1]}x{image.shape[0]}")

    # Step 1: Obtain annotations
    print("Step 1: Obtaining annotations...")
    if interactive:
        annotator = ScribbleAnnotator(image)
        fg_mask, bg_mask = annotator.run()
    elif fg_anno_path and bg_anno_path:
        fg_mask, bg_mask = load_annotations_from_file(image.shape, fg_anno_path, bg_anno_path)
    elif auto_annotate:
        fg_mask, bg_mask = generate_auto_annotations(image)
    else:
        print("  No annotation source specified. Using auto-annotation.")
        fg_mask, bg_mask = generate_auto_annotations(image)

    fg_count = fg_mask.sum()
    bg_count = bg_mask.sum()
    print(f"  Foreground scribble pixels: {fg_count}")
    print(f"  Background scribble pixels: {bg_count}")

    if fg_count == 0 or bg_count == 0:
        print("  WARNING: Need both FG and BG annotations. Using auto-annotation.")
        fg_mask, bg_mask = generate_auto_annotations(image)

    # Step 2: Naive segmentation (baseline — both Otsu and K-Means)
    print("Step 2: Computing naive baseline segmentation...")
    naive_mask = naive_thresholding_segmentation(image)
    naive_kmeans_mask = naive_kmeans_segmentation(image)

    # Step 3: Iterative graph cut
    print("Step 3: Running iterative graph cut segmentation...")
    raw_mask, all_masks, energies = iterative_graph_cut(
        image, fg_mask, bg_mask,
        n_iterations=n_iterations,
        n_components=n_components,
        gamma=gamma,
    )

    # Step 4: Refine segmentation
    print("Step 4: Refining segmentation (artifact mitigation)...")
    refined_mask = refine_segmentation(raw_mask, image)

    # Align naive masks to graph cut label convention (FG=1 must mean the same thing)
    naive_mask = align_naive_to_graphcut(naive_mask, refined_mask)
    naive_kmeans_mask = align_naive_to_graphcut(naive_kmeans_mask, refined_mask)

    # Step 5: Save outputs
    print("Step 5: Saving results...")
    cv2.imwrite(os.path.join(output_dir, f"{img_name}_raw_mask.png"),
                (raw_mask * 255).astype(np.uint8))
    cv2.imwrite(os.path.join(output_dir, f"{img_name}_refined_mask.png"),
                (refined_mask * 255).astype(np.uint8))
    overlay = create_overlay(image, refined_mask)
    cv2.imwrite(os.path.join(output_dir, f"{img_name}_overlay.png"), overlay)

    # Step 6: Visualize
    print("Step 6: Generating visualizations...")
    visualize_results(image, fg_mask, bg_mask, raw_mask, refined_mask,
                       naive_mask, naive_kmeans_mask, all_masks, energies, output_dir, img_name)

    print(f"  Done: {img_name}")
    return refined_mask

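
# --- Sketch (not called by the pipeline): why run_pipeline resizes the image ---
# The graph has one node per pixel plus one n-link per pair of neighboring
# pixels, so its size, and with it the max-flow cost, grows with h * w. The
# hypothetical helper below estimates those counts after the same max_dim
# rescaling that run_pipeline applies. Whether the real graph construction
# uses a 4- or 8-connected neighborhood is not shown in this section, so it
# is a parameter here; terminal links (t-links) are not counted.
def estimate_graph_size(h, w, max_dim=400, connectivity=4):
    """Return (h, w, n_nodes, n_links) after scaling so that max(h, w) <= max_dim."""
    if connectivity not in (4, 8):
        raise ValueError("connectivity must be 4 or 8")
    if max(h, w) > max_dim:
        scale = max_dim / max(h, w)
        # Approximates cv2.resize(fx=scale, fy=scale); OpenCV may round one pixel differently.
        h, w = max(1, round(h * scale)), max(1, round(w * scale))
    n_nodes = h * w
    n_links = (h - 1) * w + h * (w - 1)  # vertical + horizontal neighbor pairs
    if connectivity == 8:
        n_links += 2 * (h - 1) * (w - 1)  # both diagonal directions
    return h, w, n_nodes, n_links

# Usage (hypothetical): estimate_graph_size(800, 600) returns (400, 300, 120000, 239300)
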

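# --- Sketch (not called by the pipeline): reading back the saved masks ---
# Step 5 stores each mask as an 8-bit PNG via (mask * 255).astype(np.uint8),
# which maps labels {0, 1} to pixel values {0, 255}, assuming the masks hold
# 0/1 labels. If a saved file is later loaded as grayscale (for example with
# cv2.imread(path, cv2.IMREAD_GRAYSCALE)), thresholding at 127 recovers the
# labels. Both helper names are hypothetical.
import numpy as np  # already imported at the top of the script; repeated so the sketch stands alone


def mask_to_uint8(mask):
    """Map a bool or 0/1 mask to a uint8 image with values {0, 255}."""
    return (np.asarray(mask) > 0).astype(np.uint8) * 255


def uint8_to_mask(img):
    """Map an 8-bit grayscale mask image back to uint8 labels {0, 1}."""
    return (np.asarray(img) > 127).astype(np.uint8)

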
# =====================================================================
# Section 9: Entry Point
# =====================================================================

def main():
    parser = argparse.ArgumentParser(
        description="Graph Cut Image Segmentation Pipeline — CSL7360 Assignment 2",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Interactive annotation (opens GUI window)
  python graph_cut_segmentation.py --images img1.jpg img2.jpg img3.jpg

  # Automatic annotation (headless, no GUI)
  python graph_cut_segmentation.py --images img1.jpg --auto

  # Custom parameters
  python graph_cut_segmentation.py --images img1.jpg --iterations 5 --gamma 80 --gmm-components 7
        """,
    )
    parser.add_argument("--images", nargs="+", required=True,
                        help="Paths to input images (at least 3 recommended)")
    parser.add_argument("--output", default="outputs",
                        help="Output directory (default: outputs)")
    parser.add_argument("--iterations", type=int, default=3,
                        help="Number of iterative optimization steps (default: 3)")
    parser.add_argument("--gmm-components", type=int, default=5,
                        help="Number of GMM components per model (default: 5)")
    parser.add_argument("--gamma", type=float, default=50.0,
                        help="Smoothness weight gamma (default: 50.0)")
    parser.add_argument("--max-dim", type=int, default=400,
                        help="Max image dimension for processing (default: 400)")
    parser.add_argument("--auto", action="store_true",
                        help="Use automatic center/border annotations (no GUI)")
    parser.add_argument("--no-interactive", action="store_true",
                        help="Disable interactive GUI (falls back to automatic annotation)")

    args = parser.parse_args()

    interactive = not (args.auto or args.no_interactive)

    for img_path in args.images:
        run_pipeline(
            image_path=img_path,
            output_dir=args.output,
            n_iterations=args.iterations,
            n_components=args.gmm_components,
            gamma=args.gamma,
            interactive=interactive,
            auto_annotate=args.auto,
            max_dim=args.max_dim,
        )

    print(f"\nAll results saved in '{args.output}/' directory.")
    print("Pipeline complete.")

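
# --- Sketch (not called by the pipeline): which annotation source the CLI selects ---
# main() sets interactive = not (args.auto or args.no_interactive) and never
# passes fg_anno_path or bg_anno_path, so from the command line the result is
# either the GUI or automatic annotation. The hypothetical helper below mirrors
# the branch order of Step 1 in run_pipeline (ignoring its later fallback for
# empty scribbles) to make that precedence explicit.
def resolve_annotation_source(auto, no_interactive, fg_anno_path=None, bg_anno_path=None):
    """Return "interactive", "files" or "auto" following run_pipeline's Step 1 order."""
    interactive = not (auto or no_interactive)
    if interactive:
        return "interactive"
    if fg_anno_path and bg_anno_path:
        return "files"
    # Covers both the explicit --auto branch and the "no source specified" fallback.
    return "auto"

# Usage (hypothetical): resolve_annotation_source(auto=False, no_interactive=True) returns "auto"
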

if __name__ == "__main__":
    main()