Spaces:
Sleeping
Sleeping
| """ | |
| ======================================================================== | |
| Graph Cut Image Segmentation Pipeline | |
| CSL7360: Computer Vision — Assignment 2 | |
| ======================================================================== | |
| This module implements a complete Graph Cut segmentation pipeline: | |
| 1. Interactive annotation (scribbles) via OpenCV GUI | |
| 2. Foreground/Background modeling using GMMs | |
| 3. Graph construction with unary (data) and pairwise (smoothness) terms | |
| 4. Min-Cut / Max-Flow optimization using PyMaxflow | |
| 5. Iterative refinement of GMM models and graph cuts | |
| 6. Artifact mitigation: morphological cleaning, boundary smoothing | |
| 7. Visualization and comparison of results | |
| ======================================================================== | |
| """ | |
| import numpy as np | |
| import cv2 | |
| import maxflow | |
| import os | |
| import argparse | |
| from sklearn.mixture import GaussianMixture | |
| import matplotlib | |
# Figures are only ever written to disk (savefig + close); plt.show() is never
# called, so the non-interactive Agg backend is sufficient. Forcing TkAgg made
# headless runs (e.g. --auto on a server without a display) fail as soon as
# pyplot tried to load Tk. The interactive scribble GUI uses OpenCV, not
# matplotlib, so it is unaffected by this choice.
matplotlib.use("Agg")
| import matplotlib.pyplot as plt | |
| # ===================================================================== | |
| # Section 1: Interactive Annotation Tool | |
| # ===================================================================== | |
class ScribbleAnnotator:
    """
    Interactive OpenCV GUI for collecting foreground/background scribbles.

    Left mouse button  -> Foreground (green)
    Right mouse button -> Background (red)
    Press 'q' or Enter -> Finish annotation
    Press 'r'          -> Reset scribbles

    A single click already leaves a mark. Painting over an existing scribble
    with the other button re-labels it, so the returned masks never overlap.
    """

    # BGR preview colour drawn for each drawing mode.
    _COLORS = {"fg": (0, 255, 0), "bg": (0, 0, 255)}

    def __init__(self, image: np.ndarray):
        self.image = image.copy()
        self.display = image.copy()  # preview image with scribbles painted on it
        self.fg_mask = np.zeros(image.shape[:2], dtype=np.uint8)  # 1 = foreground seed
        self.bg_mask = np.zeros(image.shape[:2], dtype=np.uint8)  # 1 = background seed
        self.drawing = False
        self.mode = None  # 'fg' or 'bg'
        self.brush_size = 5

    def _paint(self, x: int, y: int):
        """Stamp one brush dab at (x, y) for the current mode (no-op if none)."""
        if self.mode == "fg":
            target, other = self.fg_mask, self.bg_mask
        elif self.mode == "bg":
            target, other = self.bg_mask, self.fg_mask
        else:
            return
        cv2.circle(target, (x, y), self.brush_size, 1, -1)
        # Latest stroke wins: clear the opposite label under the brush so a
        # pixel can never be both a foreground and a background seed.
        cv2.circle(other, (x, y), self.brush_size, 0, -1)
        cv2.circle(self.display, (x, y), self.brush_size, self._COLORS[self.mode], -1)

    def _mouse_callback(self, event, x, y, flags, param):
        """Handle mouse events for drawing scribbles."""
        if event == cv2.EVENT_LBUTTONDOWN:
            self.drawing = True
            self.mode = "fg"
            self._paint(x, y)  # paint on press too, so a plain click counts
        elif event == cv2.EVENT_RBUTTONDOWN:
            self.drawing = True
            self.mode = "bg"
            self._paint(x, y)
        elif event == cv2.EVENT_MOUSEMOVE and self.drawing:
            self._paint(x, y)
        elif event in (cv2.EVENT_LBUTTONUP, cv2.EVENT_RBUTTONUP):
            self.drawing = False

    def run(self) -> tuple:
        """
        Launch the annotation window and block until the user finishes.

        Returns:
            (fg_mask, bg_mask): uint8 arrays of shape (H, W) with values 0/1.
        """
        win = "Annotate: LEFT=FG(green), RIGHT=BG(red), q=done, r=reset"
        cv2.namedWindow(win, cv2.WINDOW_NORMAL)
        cv2.setMouseCallback(win, self._mouse_callback)
        while True:
            cv2.imshow(win, self.display)
            key = cv2.waitKey(1) & 0xFF
            if key in (ord("q"), 13):  # q or Enter
                break
            elif key == ord("r"):
                self.display = self.image.copy()
                self.fg_mask[:] = 0
                self.bg_mask[:] = 0
        cv2.destroyAllWindows()
        return self.fg_mask, self.bg_mask
def load_annotations_from_file(image_shape, fg_path, bg_path):
    """
    Read pre-saved scribble masks from disk (non-interactive / headless mode).

    Each file is read as grayscale, resized to the image's (H, W) and
    binarised with ``> 127``. A missing or undecodable file yields an
    all-zero mask instead of an error.

    Returns:
        (fg_mask, bg_mask): uint8 arrays of shape (H, W) with values 0/1.
    """
    height, width = image_shape[:2]

    def _read_mask(path):
        # Any failure to obtain pixels falls back to an empty mask.
        if not os.path.exists(path):
            return np.zeros((height, width), dtype=np.uint8)
        gray = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if gray is None:
            return np.zeros((height, width), dtype=np.uint8)
        return (cv2.resize(gray, (width, height)) > 127).astype(np.uint8)

    return _read_mask(fg_path), _read_mask(bg_path)
def generate_auto_annotations(image: np.ndarray):
    """
    Automatically generate rough foreground/background scribbles.

    Foreground: a thin cross centred in the image (mimicking hand-drawn
    scribbles). Each arm reaches 10% of the image size from the centre, and
    each bar extends ``max(h // 30, 4)`` pixels on either side of the centre
    line, so only pixels close to the centre are seeded.

    Background: strips along all four borders, each ``max(10%, 5)`` pixels
    wide.

    Useful for headless / automated runs.

    Returns:
        (fg_mask, bg_mask): uint8 arrays of shape (H, W) with values 0/1.
    """
    h, w = image.shape[:2]
    fg_mask = np.zeros((h, w), dtype=np.uint8)
    bg_mask = np.zeros((h, w), dtype=np.uint8)

    cy, cx = h // 2, w // 2
    rh, rw = h // 10, w // 10  # arm half-lengths: 10% of each dimension
    t = max(h // 30, 4)        # half-thickness of each bar

    # Clamp the thickness-based slice starts at 0. On small images cy - t or
    # cx - t can be negative, and a negative slice start would wrap around to
    # the far edge instead of starting at row/column 0. (cy - rh and cx - rw
    # are never negative.)
    top, left = max(cy - t, 0), max(cx - t, 0)

    fg_mask[top:cy + t, cx - rw:cx + rw] = 1   # horizontal bar
    fg_mask[cy - rh:cy + rh, left:cx + t] = 1  # vertical bar

    bh, bw = max(h // 10, 5), max(w // 10, 5)
    bg_mask[:bh, :] = 1
    bg_mask[-bh:, :] = 1
    bg_mask[:, :bw] = 1
    bg_mask[:, -bw:] = 1
    return fg_mask, bg_mask
| # ===================================================================== | |
| # Section 2: Foreground / Background Modeling (GMM) | |
| # ===================================================================== | |
class PixelGMMModel:
    """
    Colour model for one pixel class (foreground or background).

    Wraps a full-covariance scikit-learn ``GaussianMixture``. ``fit`` learns
    the colour distribution of labelled pixels; ``score_pixels`` returns the
    log-likelihood of arbitrary pixels under it (higher = more likely).
    """

    def __init__(self, n_components: int = 5):
        self.n_components = n_components
        self.gmm = self._make_gmm(n_components)
        self.fitted = False

    @staticmethod
    def _make_gmm(n_components):
        """Create an unfitted mixture with the fixed settings used by this model."""
        return GaussianMixture(
            n_components=n_components,
            covariance_type="full",
            max_iter=200,
            random_state=42,
        )

    def fit(self, pixels: np.ndarray):
        """
        Fit the mixture to pixel samples of shape (N, 3) (BGR values).

        When N is smaller than the configured component count, the mixture is
        rebuilt with ``max(1, N)`` components, since scikit-learn rejects
        fitting fewer samples than components.
        """
        sample_count = len(pixels)
        if sample_count < self.n_components:
            self.gmm = self._make_gmm(max(1, sample_count))
        self.gmm.fit(pixels)
        self.fitted = True

    def score_pixels(self, pixels: np.ndarray) -> np.ndarray:
        """
        Per-sample log-likelihood for pixels of shape (N, 3).

        An unfitted model scores every pixel as 0.
        """
        if self.fitted:
            return self.gmm.score_samples(pixels)
        return np.zeros(len(pixels))
def build_gmm_models(image: np.ndarray, fg_mask: np.ndarray, bg_mask: np.ndarray,
                     n_components: int = 5):
    """
    Fit one colour GMM to the pixels where ``fg_mask == 1`` and one to the
    pixels where ``bg_mask == 1``.

    A class with no labelled pixels is left as an unfitted PixelGMMModel,
    whose ``score_pixels`` returns zeros.

    Returns:
        (fg_model, bg_model)
    """
    fitted_models = []
    for class_mask in (fg_mask, bg_mask):
        samples = image[class_mask == 1].reshape(-1, 3).astype(np.float64)
        model = PixelGMMModel(n_components)
        if samples.size:
            model.fit(samples)
        fitted_models.append(model)
    fg_model, bg_model = fitted_models
    return fg_model, bg_model
| # ===================================================================== | |
| # Section 3: Energy Formulation & Graph Construction | |
| # ===================================================================== | |
def compute_unary_costs(image: np.ndarray, fg_model: PixelGMMModel,
                        bg_model: PixelGMMModel,
                        fg_mask: np.ndarray, bg_mask: np.ndarray,
                        hard_constraint_weight: float = 1e9) -> tuple:
    """
    Compute terminal-edge capacities (the unary / data term) for every pixel.

    Data term: E_data(x_p) = -log P(I_p | x_p), using GMM log-likelihoods.

    Graph convention: source = foreground, sink = background. A pixel that
    ends on the source side has its sink edge cut, and vice versa, so:
        source_cap = -log P(I_p | BG)   (paid when the pixel is labelled BG)
        sink_cap   = -log P(I_p | FG)   (paid when the pixel is labelled FG)
    Both arrays are then shifted by one common constant so that all
    capacities are non-negative.

    Scribbled pixels are hard-constrained:
      - foreground seeds get source_cap = hard_constraint_weight, sink_cap = 0;
      - background seeds get the reverse.
    Where both masks are set, the background constraint wins because it is
    applied last.

    Returns:
        (source_cap, sink_cap): float arrays of shape (H, W).
    """
    h, w = image.shape[:2]
    flat = image.reshape(-1, 3).astype(np.float64)

    # Penalty for labelling a pixel BG / FG respectively (see docstring).
    source_cap = -bg_model.score_pixels(flat).reshape(h, w)
    sink_cap = -fg_model.score_pixels(flat).reshape(h, w)

    # Each pixel pays exactly one of its two terminal edges. Shifting both by
    # the same amount therefore changes every cut's cost by the same constant
    # and leaves the minimum cut unchanged.
    offset = min(source_cap.min(), sink_cap.min())
    if offset < 0:
        source_cap -= offset
        sink_cap -= offset

    fg_seed = fg_mask == 1
    bg_seed = bg_mask == 1
    source_cap[fg_seed] = hard_constraint_weight
    sink_cap[fg_seed] = 0
    source_cap[bg_seed] = 0
    sink_cap[bg_seed] = hard_constraint_weight
    return source_cap, sink_cap
def compute_pairwise_costs(image: np.ndarray, beta: float = None,
                           gamma: float = 50.0) -> tuple:
    """
    Compute pairwise (smoothness) weights between 4-connected neighbours.

    Contrast-sensitive Potts term:
        E_smooth(x_p, x_q) = gamma * exp(-beta * ||I_p - I_q||^2)   if x_p != x_q
                           = 0                                       otherwise

    If ``beta`` is None it is estimated as 1 / (2 * <||I_p - I_q||^2>), where
    the mean is taken over all horizontal and vertical neighbour pairs. For a
    constant image the estimate is 0, which makes every weight equal to
    ``gamma``.

    Accepts colour (H, W, C) or single-channel (H, W) images.

    Returns:
        right_weights: (H, W-1) weight of the edge (i, j) <-> (i, j+1)
        down_weights:  (H-1, W) weight of the edge (i, j) <-> (i+1, j)
        beta:          the beta actually used (estimated or as passed in)
    """
    img = np.asarray(image, dtype=np.float64)
    if img.ndim == 2:
        # Treat a grayscale image as a one-channel colour image.
        img = img[:, :, np.newaxis]

    # Squared colour distances to the right and lower neighbours.
    dist_right = np.sum((img[:, 1:, :] - img[:, :-1, :]) ** 2, axis=2)  # (H, W-1)
    dist_down = np.sum((img[1:, :, :] - img[:-1, :, :]) ** 2, axis=2)   # (H-1, W)

    if beta is None:
        total_count = dist_right.size + dist_down.size
        avg_dist = ((dist_right.sum() + dist_down.sum()) / total_count
                    if total_count > 0 else 1.0)
        beta = 1.0 / (2.0 * avg_dist) if avg_dist > 0 else 0.0

    right_weights = gamma * np.exp(-beta * dist_right)
    down_weights = gamma * np.exp(-beta * dist_down)
    return right_weights, down_weights, beta
def build_graph_and_cut(source_cap: np.ndarray, sink_cap: np.ndarray,
                        right_weights: np.ndarray, down_weights: np.ndarray) -> np.ndarray:
    """
    Construct the graph with PyMaxflow and solve the min-cut / max-flow.

    Graph structure:
      - Source node S represents Foreground
      - Sink node T represents Background
      - Each pixel is a node (row-major grid)
      - Terminal edges: S -> pixel (source_cap), pixel -> T (sink_cap)
      - Neighbour edges: symmetric edges between 4-connected pixels
        (right_weights is (H, W-1), down_weights is (H-1, W))

    The min-cut partitions pixels into the S-set (foreground) and the T-set
    (background). Edges are added with PyMaxflow's vectorised grid API rather
    than per-pixel Python loops, which would make hundreds of thousands of
    interpreted calls for a 400x400 image.

    Returns:
        labels: (H, W) int mask, 1 = foreground, 0 = background
    """
    h, w = source_cap.shape
    g = maxflow.Graph[float](h * w, h * w * 2)
    nodeids = g.add_grid_nodes((h, w))

    # Terminal edges (unary / data costs).
    g.add_grid_tedges(nodeids, source_cap, sink_cap)

    # Pairwise edges. The grid API takes one weight per node, used for the
    # edge from that node to the neighbour selected by `structure`, so the
    # (H, W-1) / (H-1, W) arrays are zero-padded to (H, W). The padded
    # entries correspond to neighbours outside the grid, and a zero capacity
    # could not affect the cut anyway.
    right_structure = np.array([[0, 0, 0],
                                [0, 0, 1],
                                [0, 0, 0]])
    down_structure = np.array([[0, 0, 0],
                               [0, 0, 0],
                               [0, 1, 0]])
    right_full = np.zeros((h, w), dtype=np.float64)
    right_full[:, :w - 1] = right_weights
    down_full = np.zeros((h, w), dtype=np.float64)
    down_full[:h - 1, :] = down_weights
    # symmetric=True adds the reverse edge with the same capacity, matching
    # add_edge(p, q, weight, weight) in the per-pixel formulation.
    g.add_grid_edges(nodeids, weights=right_full, structure=right_structure,
                     symmetric=True)
    g.add_grid_edges(nodeids, weights=down_full, structure=down_structure,
                     symmetric=True)

    flow = g.maxflow()
    print(f" Max-flow value: {flow:.2f}")

    # get_grid_segments is True for nodes on the sink (background) side;
    # invert so that 1 = foreground (source side), 0 = background.
    in_sink = g.get_grid_segments(nodeids)
    labels = np.logical_not(in_sink).astype(int)
    return labels
| # ===================================================================== | |
| # Section 4: Iterative Graph Cut Optimization | |
| # ===================================================================== | |
def iterative_graph_cut(image: np.ndarray, fg_mask: np.ndarray, bg_mask: np.ndarray,
                        n_iterations: int = 3, n_components: int = 5,
                        gamma: float = 50.0) -> tuple:
    """
    Perform iterative graph cut segmentation:
      1. Build initial GMMs from the user scribbles.
      2. Construct the graph and compute the min-cut.
      3. Re-fit the GMMs on the newly labelled pixels (keeping the scribbles).
      4. Repeat for n_iterations.

    Hard constraints always come from the original ``fg_mask`` / ``bg_mask``.

    Returns:
        final_mask: (H, W) mask from the lowest-energy iteration
                    (not necessarily the last one)
        all_masks:  list of masks, one per iteration (for comparison)
        energies:   list of total energies (data + smoothness) per iteration

    Raises:
        ValueError: if n_iterations < 1.
    """
    if n_iterations < 1:
        raise ValueError(f"n_iterations must be >= 1, got {n_iterations}")

    h, w = image.shape[:2]
    pixels = image.reshape(-1, 3).astype(np.float64)

    # The smoothness term depends only on the image, not on the GMMs, so it
    # is computed once rather than on every iteration.
    right_w, down_w, _beta = compute_pairwise_costs(image, gamma=gamma)

    current_fg_mask = fg_mask.copy()
    current_bg_mask = bg_mask.copy()
    all_masks = []
    energies = []

    for it in range(n_iterations):
        print(f" Iteration {it + 1}/{n_iterations}")

        # Step 1: build / update the colour models.
        fg_model, bg_model = build_gmm_models(image, current_fg_mask, current_bg_mask,
                                              n_components)
        # Step 2: unary costs (hard constraints from the ORIGINAL scribbles).
        source_cap, sink_cap = compute_unary_costs(image, fg_model, bg_model,
                                                   fg_mask, bg_mask)
        # Step 3: min-cut.
        labels = build_graph_and_cut(source_cap, sink_cap, right_w, down_w)
        all_masks.append(labels.copy())

        # Energy for monitoring and for picking the best iteration. It uses
        # the unshifted -log-likelihoods and ignores hard-constraint costs.
        fg_ll = fg_model.score_pixels(pixels).reshape(h, w)
        bg_ll = bg_model.score_pixels(pixels).reshape(h, w)
        data_energy = -np.sum(fg_ll[labels == 1]) - np.sum(bg_ll[labels == 0])
        # Smoothness: total weight of neighbour edges whose labels differ.
        diff_h = (labels[:, 1:] != labels[:, :-1]).astype(float)
        diff_v = (labels[1:, :] != labels[:-1, :]).astype(float)
        smooth_energy = np.sum(diff_h * right_w) + np.sum(diff_v * down_w)
        total_energy = data_energy + smooth_energy
        energies.append(total_energy)
        print(f" Energy: {total_energy:.2f} (data={data_energy:.2f}, smooth={smooth_energy:.2f})")

        # Step 4: training masks for the next iteration, keeping user scribbles.
        current_fg_mask = labels.copy()
        current_bg_mask = 1 - labels
        current_fg_mask[fg_mask == 1] = 1
        current_bg_mask[bg_mask == 1] = 1

    best_iter = int(np.argmin(energies))
    print(f" Best iteration: {best_iter + 1} (energy={energies[best_iter]:.2f})")
    return all_masks[best_iter], all_masks, energies
| # ===================================================================== | |
| # Section 5: Artifact Mitigation & Refinement | |
| # ===================================================================== | |
def remove_small_regions(mask: np.ndarray, min_area: int = 500) -> np.ndarray:
    """
    Remove small isolated foreground blobs and fill small background holes
    using 8-connected component analysis.

    Any non-zero value in ``mask`` counts as foreground. Processing order:
      1. foreground components with area < min_area are set to 0;
      2. on that result, background components (components of the inverted
         mask) with area < min_area are set to 1.

    Returns:
        uint8 mask with values 0/1.
    """

    def small_components(binary):
        """Boolean (H, W) mask of 8-connected components of `binary` with area < min_area."""
        _, labels, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)
        too_small = stats[:, cv2.CC_STAT_AREA] < min_area
        too_small[0] = False  # label 0 is the zero-valued complement, never removed
        # One vectorised lookup instead of one full-image comparison per component.
        return too_small[labels]

    # Normalise to 0/1 first. With a 0/255 input, `1 - mask` would wrap around
    # in uint8 and the hole-filling pass would operate on garbage.
    cleaned = (mask > 0).astype(np.uint8)
    cleaned[small_components(cleaned)] = 0       # drop small foreground blobs
    cleaned[small_components(1 - cleaned)] = 1   # fill small background holes
    return cleaned
def smooth_boundaries(mask: np.ndarray, ksize: int = 5) -> np.ndarray:
    """
    Smooth jagged segmentation boundaries.

    Works on the mask scaled by 255 (so a 0/1 mask is expected):
      - morphological closing, then opening, with a ksize x ksize elliptical
        element;
      - a Gaussian blur with a (2*ksize+1)-square kernel;
      - re-thresholding at 127.

    Returns:
        uint8 mask with values 0/1.
    """
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize))
    blur_side = 2 * ksize + 1
    work = mask.astype(np.uint8) * 255
    # Closing fills small gaps; opening then trims thin protrusions.
    for morph_op in (cv2.MORPH_CLOSE, cv2.MORPH_OPEN):
        work = cv2.morphologyEx(work, morph_op, ellipse, iterations=1)
    work = cv2.GaussianBlur(work, (blur_side, blur_side), 0)
    return (work > 127).astype(np.uint8)
def ensure_intensity_consistency(mask: np.ndarray, image: np.ndarray,
                                 threshold: float = 30.0) -> np.ndarray:
    """
    Intensity consistency: re-label boundary pixels whose colour clearly
    matches the opposite region.

    Mean FG and BG colours are taken from the input mask. The boundary band
    is the morphological gradient of the mask: a 7x7 elliptical dilation
    minus the corresponding erosion. Inside that band:
      - a FG pixel is flipped to BG when its Euclidean colour distance to the
        BG mean is smaller than its distance to the FG mean by more than
        ``threshold``;
      - BG pixels are handled symmetrically.
    All decisions use the input labels, so the result does not depend on
    which region is processed first.

    Accepts colour (H, W, C) or single-channel (H, W) images.

    Returns:
        uint8 mask; a copy of the input if either region is empty.
    """
    refined = mask.copy().astype(np.uint8)
    img_f = image.astype(np.float32)
    if img_f.ndim == 2:
        img_f = img_f[:, :, np.newaxis]

    is_fg = refined == 1
    is_bg = refined == 0
    if not is_fg.any() or not is_bg.any():
        return refined

    fg_mean = img_f[is_fg].mean(axis=0)  # mean FG colour
    bg_mean = img_f[is_bg].mean(axis=0)  # mean BG colour

    # Boundary band = dilation minus erosion (morphological gradient).
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
    dilated = cv2.dilate(refined, kernel, iterations=1)
    eroded = cv2.erode(refined, kernel, iterations=1)
    band = (dilated - eroded).astype(bool)

    # Vectorised colour distances to both means (replaces a per-pixel loop).
    d_fg = np.linalg.norm(img_f - fg_mean, axis=2)
    d_bg = np.linalg.norm(img_f - bg_mean, axis=2)

    # Flip only pixels that are clearly closer to the opposite mean.
    to_bg = band & is_fg & (d_bg < d_fg - threshold)
    to_fg = band & is_bg & (d_fg < d_bg - threshold)
    refined[to_bg] = 0
    refined[to_fg] = 1
    return refined
def refine_segmentation(mask: np.ndarray, image: np.ndarray,
                        min_area: int = None, smooth_ksize: int = 3) -> np.ndarray:
    """
    Post-process a raw graph-cut mask.

    Stages, in order:
      1. remove_small_regions: drop small blobs and fill small holes.
      2. smooth_boundaries: close/open plus blur to smooth the outline.
      3. ensure_intensity_consistency: colour-based correction near the
         boundary.

    Args:
        min_area: component-size cutoff for stage 1. None means
            max(50, 0.1% of the mask's pixel count), so it scales with image
            size.
        smooth_ksize: kernel size for stage 2. It is kept small (3) so that
            fine structures are not distorted.
    """
    print(" Refining segmentation...")
    if min_area is not None:
        area_cutoff = min_area
    else:
        area_cutoff = max(50, int(mask.size * 0.001))  # 0.1% of pixels
    cleaned = remove_small_regions(mask, area_cutoff)
    smoothed = smooth_boundaries(cleaned, smooth_ksize)
    return ensure_intensity_consistency(smoothed, image)
| # ===================================================================== | |
| # Section 6: Naive Segmentation (for comparison) | |
| # ===================================================================== | |
def naive_thresholding_segmentation(image: np.ndarray) -> np.ndarray:
    """
    Naive baseline: one global Otsu threshold on the grayscale image.

    Pixels brighter than the Otsu threshold become 1 and all others 0.
    Which side is "foreground" is not decided here; the label convention is
    aligned to the graph cut later (see align_naive_to_graphcut).
    """
    grayscale = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    otsu_flags = cv2.THRESH_BINARY + cv2.THRESH_OTSU
    _threshold, binary = cv2.threshold(grayscale, 0, 1, otsu_flags)
    return binary
def align_naive_to_graphcut(naive_mask: np.ndarray,
                            reference_mask: np.ndarray) -> np.ndarray:
    """
    Give a naive 0/1 mask the same foreground convention as a reference mask.

    Counts pixel-wise agreement with ``reference_mask`` for both the mask and
    its inverse (1 - mask). The inverse is returned only when it agrees on
    strictly more pixels; on a tie the mask is returned unchanged. This
    handles Otsu/K-Means labelling the bright region as 1 while the graph cut
    labels the object as 1.
    """
    inverted = 1 - naive_mask
    agree_as_is = np.count_nonzero(naive_mask == reference_mask)
    agree_inverted = np.count_nonzero(inverted == reference_mask)
    return inverted if agree_inverted > agree_as_is else naive_mask
def naive_kmeans_segmentation(image: np.ndarray, k: int = 2) -> np.ndarray:
    """
    Naive baseline: K-Means colour clustering turned into a binary mask.

    The cluster whose centre is darkest (lowest mean over the colour
    channels) becomes background (0) and every other cluster becomes
    foreground (1). For k=2 this is "darker cluster = background"; for k > 2
    the output is still a valid 0/1 mask. (The previous ``1 - labels`` trick
    produced values such as 255 and 254 for k > 2.)

    Returns:
        uint8 (H, W) mask with values 0/1.
    """
    pixels = image.reshape(-1, 3).astype(np.float32)
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.2)
    _, labels, centers = cv2.kmeans(pixels, k, None, criteria, 10,
                                    cv2.KMEANS_RANDOM_CENTERS)
    labels = labels.reshape(image.shape[:2])
    # np.argmin picks the first index on ties, which matches the old k=2 rule.
    darkest = int(np.argmin(centers.mean(axis=1)))
    return (labels != darkest).astype(np.uint8)
| # ===================================================================== | |
| # Section 7: Visualization | |
| # ===================================================================== | |
def create_overlay(image: np.ndarray, mask: np.ndarray,
                   color: tuple = (0, 255, 0), alpha: float = 0.4) -> np.ndarray:
    """
    Tint the masked region of ``image`` with ``color`` and outline it.

    Pixels where ``mask`` is non-zero are alpha-blended as
        out = (1 - alpha) * image + alpha * color
    via cv2.addWeighted. The external contours of the mask are then drawn in
    ``color`` with a 2 px line.

    Returns:
        A new image; ``image`` itself is not modified.
    """
    overlay = image.copy()
    region = mask.astype(bool)
    # Skip blending when nothing is selected (e.g. an all-background mask):
    # cv2.addWeighted cannot be relied on for zero-length inputs.
    if region.any():
        selected = image[region]
        tint = np.empty_like(selected)  # only the selected pixels need the colour
        tint[:] = color
        overlay[region] = cv2.addWeighted(selected, 1 - alpha, tint, alpha, 0)
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL,
                                   cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(overlay, contours, -1, color, 2)
    return overlay
def visualize_results(image: np.ndarray, fg_mask: np.ndarray, bg_mask: np.ndarray,
                      raw_mask: np.ndarray, refined_mask: np.ndarray,
                      naive_mask: np.ndarray, naive_kmeans_mask: np.ndarray,
                      all_iter_masks: list, energies: list,
                      output_dir: str, img_name: str):
    """
    Render and save the result figures for one image into ``output_dir``.

    Files written:
      {img_name}_results.png    2x4 grid: input + scribbles, Otsu, K-Means,
                                raw cut, refined cut, overlay, extracted
                                foreground, Otsu vs refined side by side
      {img_name}_iterations.png per-iteration masks plus the post-processed
                                mask (only when there is more than one
                                iteration)
      {img_name}_energy.png     energy per iteration (only when more than
                                one value)
    """

    def panel(ax, picture, title, is_bgr=False):
        """Draw one axis: BGR images are shown in RGB, masks in grayscale."""
        if is_bgr:
            ax.imshow(cv2.cvtColor(picture, cv2.COLOR_BGR2RGB))
        else:
            ax.imshow(picture, cmap="gray")
        ax.set_title(title)
        ax.axis("off")

    # --- Figure 1: main comparison grid ---
    fig, axes = plt.subplots(2, 4, figsize=(24, 12))
    fig.suptitle(f"Graph Cut Segmentation — {img_name}", fontsize=16, fontweight="bold")

    annotated = image.copy()
    annotated[fg_mask == 1] = [0, 255, 0]
    annotated[bg_mask == 1] = [0, 0, 255]

    extracted = image.copy()
    extracted[refined_mask == 0] = [255, 255, 255]

    divider = np.full((naive_mask.shape[0], 10), 255, dtype=np.uint8)  # white gap
    side_by_side = np.hstack([naive_mask.astype(np.uint8) * 255,
                              divider,
                              refined_mask.astype(np.uint8) * 255])

    grid = [
        (annotated, "Input + Annotations", True),
        (naive_mask, "Naive: Otsu Thresholding", False),
        (naive_kmeans_mask, "Naive: K-Means (k=2)", False),
        (raw_mask, "Raw Graph Cut", False),
        (refined_mask, "Refined Graph Cut", False),
        (create_overlay(image, refined_mask), "Overlay on Original", True),
        (extracted, "Extracted Foreground", True),
        (side_by_side, "Otsu vs Graph Cut (side-by-side)", False),
    ]
    # axes.flat walks the 2x4 grid row by row.
    for ax, (picture, title, is_bgr) in zip(axes.flat, grid):
        panel(ax, picture, title, is_bgr)

    plt.tight_layout()
    fig.savefig(os.path.join(output_dir, f"{img_name}_results.png"), dpi=150,
                bbox_inches="tight")
    plt.close(fig)

    # --- Figure 2: mask after every iteration ---
    n_iters = len(all_iter_masks)
    if n_iters > 1:
        fig2, axes2 = plt.subplots(1, n_iters + 1, figsize=(5 * (n_iters + 1), 5))
        fig2.suptitle(f"Iterative Refinement — {img_name}", fontsize=14)
        for pos, iter_mask in enumerate(all_iter_masks):
            panel(axes2[pos], iter_mask, f"Iteration {pos + 1}")
        panel(axes2[n_iters], refined_mask, "After Post-Processing")
        plt.tight_layout()
        fig2.savefig(os.path.join(output_dir, f"{img_name}_iterations.png"), dpi=150,
                     bbox_inches="tight")
        plt.close(fig2)

    # --- Figure 3: energy per iteration ---
    if len(energies) > 1:
        fig3, ax3 = plt.subplots(figsize=(8, 5))
        ax3.plot(range(1, len(energies) + 1), energies, "bo-", linewidth=2, markersize=8)
        ax3.set_xlabel("Iteration", fontsize=12)
        ax3.set_ylabel("Total Energy", fontsize=12)
        ax3.set_title(f"Energy Convergence — {img_name}", fontsize=14)
        ax3.grid(True, alpha=0.3)
        fig3.savefig(os.path.join(output_dir, f"{img_name}_energy.png"), dpi=150,
                     bbox_inches="tight")
        plt.close(fig3)

    print(f" Visualizations saved to {output_dir}/")
| # ===================================================================== | |
| # Section 8: Full Pipeline | |
| # ===================================================================== | |
def run_pipeline(image_path: str, output_dir: str = "outputs",
                 n_iterations: int = 3, n_components: int = 5,
                 gamma: float = 50.0, interactive: bool = True,
                 fg_anno_path: str = None, bg_anno_path: str = None,
                 auto_annotate: bool = False,
                 max_dim: int = 400):
    """
    Run the complete Graph Cut segmentation pipeline on a single image.

    Parameters:
        image_path:    Path to input image
        output_dir:    Directory to save results (created if missing)
        n_iterations:  Number of graph-cut iterations
        n_components:  GMM components per model
        gamma:         Smoothness weight
        interactive:   If True, open the OpenCV GUI for scribble annotation
        fg_anno_path:  Path to a pre-made foreground annotation mask
        bg_anno_path:  Path to a pre-made background annotation mask
        auto_annotate: If True, generate automatic center/border annotations
        max_dim:       Downscale so the largest dimension is <= max_dim (for speed)

    Annotation sources are tried in order: GUI, mask files (both paths
    required), then auto-annotation. If the chosen source yields an empty FG
    or BG set, auto-annotation is used instead.

    Returns:
        The refined (H, W) mask, or None if the image could not be loaded.
    """
    os.makedirs(output_dir, exist_ok=True)
    img_name = os.path.splitext(os.path.basename(image_path))[0]

    print(f"\n{'='*60}")
    print(f"Processing: {image_path}")
    print(f"{'='*60}")

    image = cv2.imread(image_path)
    if image is None:
        print(f"ERROR: Could not load image '{image_path}'")
        return

    # Downscale large images: graph size grows with the pixel count.
    h, w = image.shape[:2]
    longest_side = max(h, w)
    if longest_side > max_dim:
        factor = max_dim / longest_side
        image = cv2.resize(image, None, fx=factor, fy=factor, interpolation=cv2.INTER_AREA)
        print(f" Resized from ({w},{h}) to {image.shape[1]}x{image.shape[0]}")

    # Step 1: annotations.
    print("Step 1: Obtaining annotations...")
    if interactive:
        fg_mask, bg_mask = ScribbleAnnotator(image).run()
    elif fg_anno_path and bg_anno_path:
        fg_mask, bg_mask = load_annotations_from_file(image.shape, fg_anno_path, bg_anno_path)
    else:
        if not auto_annotate:
            print(" No annotation source specified. Using auto-annotation.")
        fg_mask, bg_mask = generate_auto_annotations(image)

    fg_count = fg_mask.sum()
    bg_count = bg_mask.sum()
    print(f" Foreground scribble pixels: {fg_count}")
    print(f" Background scribble pixels: {bg_count}")
    if not (fg_count and bg_count):
        print(" WARNING: Need both FG and BG annotations. Using auto-annotation.")
        fg_mask, bg_mask = generate_auto_annotations(image)

    # Step 2: naive baselines.
    print("Step 2: Computing naive baseline segmentation...")
    naive_mask = naive_thresholding_segmentation(image)
    naive_kmeans_mask = naive_kmeans_segmentation(image)

    # Step 3: iterative graph cut.
    print("Step 3: Running iterative graph cut segmentation...")
    raw_mask, all_masks, energies = iterative_graph_cut(
        image, fg_mask, bg_mask,
        n_iterations=n_iterations,
        n_components=n_components,
        gamma=gamma,
    )

    # Step 4: post-processing, then match the naive masks' FG convention.
    print("Step 4: Refining segmentation (artifact mitigation)...")
    refined_mask = refine_segmentation(raw_mask, image)
    naive_mask = align_naive_to_graphcut(naive_mask, refined_mask)
    naive_kmeans_mask = align_naive_to_graphcut(naive_kmeans_mask, refined_mask)

    # Step 5: raw outputs.
    print("Step 5: Saving results...")
    for suffix, result in (("raw_mask", raw_mask), ("refined_mask", refined_mask)):
        cv2.imwrite(os.path.join(output_dir, f"{img_name}_{suffix}.png"),
                    (result * 255).astype(np.uint8))
    cv2.imwrite(os.path.join(output_dir, f"{img_name}_overlay.png"),
                create_overlay(image, refined_mask))

    # Step 6: figures.
    print("Step 6: Generating visualizations...")
    visualize_results(image, fg_mask, bg_mask, raw_mask, refined_mask,
                      naive_mask, naive_kmeans_mask, all_masks, energies, output_dir, img_name)

    print(f" Done: {img_name}")
    return refined_mask
| # ===================================================================== | |
| # Section 9: Entry Point | |
| # ===================================================================== | |
def main():
    """
    Command-line entry point: parse arguments and run the pipeline per image.

    Annotation source per image:
      - --fg-masks/--bg-masks: pre-drawn mask files, one pair per image, in
        the same order as --images. This implies a non-interactive run.
      - --auto: automatic center/border scribbles.
      - otherwise: interactive OpenCV GUI, unless --no-interactive is given,
        in which case auto-annotation is used.
    """
    parser = argparse.ArgumentParser(
        description="Graph Cut Image Segmentation Pipeline — CSL7360 Assignment 2",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Interactive annotation (opens GUI window)
  python graph_cut_segmentation.py --images img1.jpg img2.jpg img3.jpg
  # Automatic annotation (headless, no GUI)
  python graph_cut_segmentation.py --images img1.jpg --auto
  # Pre-drawn scribble masks (one FG and one BG mask per image)
  python graph_cut_segmentation.py --images img1.jpg --fg-masks img1_fg.png --bg-masks img1_bg.png
  # Custom parameters
  python graph_cut_segmentation.py --images img1.jpg --iterations 5 --gamma 80 --gmm-components 7
""",
    )
    parser.add_argument("--images", nargs="+", required=True,
                        help="Paths to input images (at least 3 recommended)")
    parser.add_argument("--output", default="outputs",
                        help="Output directory (default: outputs)")
    parser.add_argument("--iterations", type=int, default=3,
                        help="Number of iterative optimization steps (default: 3)")
    parser.add_argument("--gmm-components", type=int, default=5,
                        help="Number of GMM components per model (default: 5)")
    parser.add_argument("--gamma", type=float, default=50.0,
                        help="Smoothness weight gamma (default: 50.0)")
    parser.add_argument("--max-dim", type=int, default=400,
                        help="Max image dimension for processing (default: 400)")
    parser.add_argument("--auto", action="store_true",
                        help="Use automatic center/border annotations (no GUI)")
    parser.add_argument("--no-interactive", action="store_true",
                        help="Disable interactive GUI (use --auto or --fg-masks/--bg-masks)")
    parser.add_argument("--fg-masks", nargs="+", default=None,
                        help="Foreground scribble masks (non-zero = FG), one per image, "
                             "in the same order as --images; requires --bg-masks")
    parser.add_argument("--bg-masks", nargs="+", default=None,
                        help="Background scribble masks (non-zero = BG), one per image, "
                             "in the same order as --images; requires --fg-masks")
    args = parser.parse_args()

    use_masks = args.fg_masks is not None or args.bg_masks is not None
    if use_masks:
        if args.fg_masks is None or args.bg_masks is None:
            parser.error("--fg-masks and --bg-masks must be given together")
        if not (len(args.fg_masks) == len(args.bg_masks) == len(args.images)):
            parser.error("--fg-masks and --bg-masks need exactly one mask per image")
        fg_paths, bg_paths = args.fg_masks, args.bg_masks
    else:
        fg_paths = bg_paths = [None] * len(args.images)

    # Mask files imply a non-interactive run; run_pipeline checks the GUI
    # flag before the mask paths, so it must be switched off here.
    interactive = not (args.auto or args.no_interactive or use_masks)
    for img_path, fg_path, bg_path in zip(args.images, fg_paths, bg_paths):
        run_pipeline(
            image_path=img_path,
            output_dir=args.output,
            n_iterations=args.iterations,
            n_components=args.gmm_components,
            gamma=args.gamma,
            interactive=interactive,
            fg_anno_path=fg_path,
            bg_anno_path=bg_path,
            auto_annotate=args.auto,
            max_dim=args.max_dim,
        )
    print(f"\nAll results saved in '{args.output}/' directory.")
    print("Pipeline complete.")


if __name__ == "__main__":
    main()