"""Gradio front-end for the Graph Cut image-segmentation demo.

The user uploads an image, paints foreground/background scribbles (or enables
auto annotation), and the pipeline from ``graph_cut_segmentation`` produces the
raw and refined masks, an overlay, the extracted foreground, two naive
baselines (Otsu and K-Means), the energy curve and the per-iteration masks.
"""

# ── Backend override: must happen before any pyplot import ─────────────
import io
import warnings

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt

warnings.filterwarnings("ignore")

# Monkey-patch matplotlib.use so graph_cut_segmentation.py's TkAgg call is a no-op.
_real_use = matplotlib.use
matplotlib.use = lambda *a, **kw: None
try:
    import gradio as gr
    import numpy as np
    import cv2
    from PIL import Image

    from graph_cut_segmentation import (
        iterative_graph_cut,
        refine_segmentation,
        naive_thresholding_segmentation,
        naive_kmeans_segmentation,
        align_naive_to_graphcut,
        create_overlay,
        generate_auto_annotations,
    )
finally:
    # Restore even if one of the imports above fails, so matplotlib is never
    # left permanently patched.
    matplotlib.use = _real_use

# Warm background colour shared by the matplotlib figures.
_PLOT_BG = "#FFF8F3"


# ══════════════════════════════════════════════════════════════════════
# Helpers
# ══════════════════════════════════════════════════════════════════════
def to_numpy(img):
    """Return *img* (numpy array or PIL image) as a uint8 array, or None if *img* is None."""
    if img is None:
        return None
    return np.asarray(img).astype(np.uint8)


def extract_mask(editor_out, target_hw):
    """Collapse the painted strokes of an ImageEditor value into one binary mask.

    Args:
        editor_out: Value from ``gr.ImageEditor``. Either a dict whose
            ``"layers"`` entry lists the painted layers, or a single image
            that is treated as one layer.
        target_hw: ``(height, width)`` of the mask to return.

    Returns:
        uint8 array of shape ``target_hw`` holding 1 where any layer was
        painted and 0 elsewhere.
    """
    h, w = target_hw
    combined = np.zeros((h, w), dtype=np.uint8)
    if editor_out is None:
        return combined

    if isinstance(editor_out, dict):
        layers = editor_out.get("layers", []) or []
    else:
        layers = [editor_out]

    for layer in layers:
        arr = to_numpy(layer)
        if arr is None:
            continue
        if arr.ndim == 3 and arr.shape[2] == 4:
            # RGBA layer: painted pixels are the non-transparent ones.
            alpha = arr[:, :, 3]
        elif arr.ndim == 3:
            # No alpha channel: treat any clearly non-black pixel as painted.
            alpha = np.any(arr > 20, axis=2).astype(np.uint8) * 255
        else:
            alpha = arr
        alpha = alpha.astype(np.uint8)
        if alpha.shape != (h, w):
            # Layers live at the editor's resolution; the image may have been
            # downscaled, so bring the stroke map to the working size.
            alpha = cv2.resize(alpha, (w, h), interpolation=cv2.INTER_NEAREST)
        combined = np.maximum(combined, (alpha > 10).astype(np.uint8))
    return combined


def _fig_to_pil(fig, dpi):
    """Render *fig* to an in-memory PNG and return it as a detached PIL image.

    The figure is closed even if rendering fails, so repeated runs do not
    accumulate open matplotlib figures.
    """
    try:
        with io.BytesIO() as buf:
            fig.savefig(buf, format="png", dpi=dpi, bbox_inches="tight", facecolor=_PLOT_BG)
            buf.seek(0)
            # .copy() loads the pixel data, so the image outlives the buffer.
            return Image.open(buf).copy()
    finally:
        plt.close(fig)


def make_energy_plot(energies):
    """Plot total energy per Graph Cut iteration and mark the lowest-energy one.

    Args:
        energies: Sequence of per-iteration energies (iteration 1 first).

    Returns:
        PIL image of the plot. When *energies* is empty, the plot is drawn
        without the best-iteration marker instead of raising.
    """
    fig, ax = plt.subplots(figsize=(7, 4), facecolor=_PLOT_BG)
    ax.set_facecolor(_PLOT_BG)
    iters = list(range(1, len(energies) + 1))
    ax.plot(iters, energies, "o-", color="#E8845A", linewidth=2.5,
            markersize=9, markerfacecolor="#C85E35",
            markeredgecolor="white", markeredgewidth=1.5)
    if len(energies) > 0:  # np.argmin raises on an empty sequence
        best_i = int(np.argmin(energies))
        ax.axvline(best_i + 1, color="#A0522D", linestyle="--", alpha=0.65,
                   label=f"Best iteration: {best_i + 1}")
        ax.legend(fontsize=10, framealpha=0.7, edgecolor="#D4B896")
    ax.set_xlabel("Iteration", fontsize=12, color="#3D2B1F")
    ax.set_ylabel("Total Energy", fontsize=12, color="#3D2B1F")
    ax.set_title("Energy Convergence", fontsize=14, fontweight="bold", color="#3D2B1F")
    ax.grid(True, alpha=0.25, color="#C4A882")
    for spine in ax.spines.values():
        spine.set_edgecolor("#D4B896")
    ax.tick_params(colors="#5C3D2E")
    fig.tight_layout()
    return _fig_to_pil(fig, dpi=130)


def make_iterations_plot(all_masks, refined_mask):
    """Show each iteration's mask side by side, followed by the post-processed mask.

    Args:
        all_masks: Per-iteration masks, in order.
        refined_mask: Mask after ``refine_segmentation``.

    Returns:
        PIL image containing ``len(all_masks) + 1`` panels.
    """
    n = len(all_masks)
    cols = n + 1
    # squeeze=False always yields a 2-D axes array, even when cols == 1.
    fig, axes = plt.subplots(1, cols, figsize=(4 * cols, 4), facecolor=_PLOT_BG, squeeze=False)
    row = axes[0]
    for i, (ax, mask) in enumerate(zip(row, all_masks), start=1):
        ax.imshow(mask, cmap="gray")
        ax.set_title(f"Iteration {i}", fontsize=11, color="#3D2B1F")
        ax.axis("off")
    last = row[n]
    last.imshow(refined_mask, cmap="gray")
    last.set_title("Post-Processed", fontsize=11, color="#C85E35", fontweight="bold")
    last.axis("off")
    fig.tight_layout()
    return _fig_to_pil(fig, dpi=120)


def _gray3(mask):
    """Render a mask as a 3-channel RGB image (assumes 0/1 mask values)."""
    return cv2.cvtColor((mask * 255).astype(np.uint8), cv2.COLOR_GRAY2RGB)


def _to_bgr(uploaded_image, max_dim):
    """Convert the uploaded image to BGR and downscale it so its longest side is at most *max_dim*."""
    img_arr = to_numpy(uploaded_image)
    if img_arr.ndim == 3 and img_arr.shape[2] == 1:
        # (H, W, 1) would be rejected by the RGB->BGR conversion below.
        img_arr = img_arr[:, :, 0]
    if img_arr.ndim == 2:
        img_arr = cv2.cvtColor(img_arr, cv2.COLOR_GRAY2RGB)
    image_bgr = cv2.cvtColor(img_arr, cv2.COLOR_RGB2BGR)
    h, w = image_bgr.shape[:2]
    if max(h, w) > max_dim:
        scale = max_dim / max(h, w)
        image_bgr = cv2.resize(image_bgr, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
    return image_bgr


# ══════════════════════════════════════════════════════════════════════
# Core segmentation
# ══════════════════════════════════════════════════════════════════════
def run_segmentation(fg_editor, bg_editor, uploaded_image, max_dim,
                     iterations, gamma, n_components, use_auto):
    """Run the full Graph Cut pipeline and build every output shown in the UI.

    Args:
        fg_editor / bg_editor: ImageEditor values holding the scribbles.
        uploaded_image: The uploaded image (numpy array from ``gr.Image``).
        max_dim: Longest allowed image side; larger inputs are downscaled.
        iterations, gamma, n_components: Forwarded to ``iterative_graph_cut``.
        use_auto: Use ``generate_auto_annotations`` instead of the scribbles.

    Returns:
        9-tuple matching the UI outputs, in this order: annotated input, raw
        mask, refined mask, overlay, extracted foreground, Otsu baseline,
        K-Means baseline, energy plot, iterations plot.

    Raises:
        gr.Error: No image was uploaded, or (in manual mode) one of the
            foreground/background scribble sets is empty.
    """
    if uploaded_image is None:
        raise gr.Error("Please upload an image first.")

    image_bgr = _to_bgr(uploaded_image, int(max_dim))
    h, w = image_bgr.shape[:2]

    if use_auto:
        fg_mask, bg_mask = generate_auto_annotations(image_bgr)
    else:
        fg_mask = extract_mask(fg_editor, (h, w))
        bg_mask = extract_mask(bg_editor, (h, w))
        if not fg_mask.any() or not bg_mask.any():
            raise gr.Error(
                "Both foreground (green) and background (red) scribbles are required. "
                "Draw on each canvas, or enable Auto Annotation."
            )

    raw_mask, all_masks, energies = iterative_graph_cut(
        image_bgr, fg_mask, bg_mask,
        n_iterations=int(iterations),
        n_components=int(n_components),
        gamma=float(gamma),
    )
    refined_mask = refine_segmentation(raw_mask, image_bgr)
    naive_otsu = align_naive_to_graphcut(naive_thresholding_segmentation(image_bgr), refined_mask)
    naive_km = align_naive_to_graphcut(naive_kmeans_segmentation(image_bgr), refined_mask)

    # Scribbles painted onto the image (BGR order: green = fg, red = bg).
    annot = image_bgr.copy()
    annot[fg_mask == 1] = [0, 255, 0]
    annot[bg_mask == 1] = [0, 0, 255]

    # Foreground on a white background.
    ext = image_bgr.copy()
    ext[refined_mask == 0] = [255, 255, 255]

    return (
        cv2.cvtColor(annot, cv2.COLOR_BGR2RGB),
        _gray3(raw_mask),
        _gray3(refined_mask),
        cv2.cvtColor(create_overlay(image_bgr, refined_mask), cv2.COLOR_BGR2RGB),
        cv2.cvtColor(ext, cv2.COLOR_BGR2RGB),
        _gray3(naive_otsu),
        _gray3(naive_km),
        make_energy_plot(energies),
        make_iterations_plot(all_masks, refined_mask),
    )


def update_editors(img):
    """Load the uploaded image into both scribble canvases, or clear them when *img* is None."""
    if img is None:
        return gr.update(value=None), gr.update(value=None)
    pil = Image.fromarray(img.astype(np.uint8))
    return gr.update(value=pil), gr.update(value=pil)


# ══════════════════════════════════════════════════════════════════════
# CSS — forces light warm theme over every Gradio 6.x element
# ══════════════════════════════════════════════════════════════════════
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700;800&display=swap');

/* ══ NUCLEAR LIGHT MODE — overrides every Gradio 6.x dark element ══ */
:root {
    /* ── Warm palette ── */
    --warm-bg: #FFF8F3;
    --warm-card: #FFFFFF;
    --warm-border: #EDD9C8;
    --warm-text: #3D2B1F;
    --warm-muted: #7A4F3A;
    --warm-accent: #E8845A;
    --warm-accent2: #C85E35;
    --warm-light: #FFF3EC;

    /* ── Override ALL Gradio v6 CSS custom-property theme tokens ── */
    --body-background-fill: #FFF8F3;
    --background-fill-primary: #FFFFFF;
    --background-fill-secondary: #FFF8F3;
    --block-background-fill: #FFFFFF;
    --block-border-color: #EDD9C8;
    --block-border-width: 1px;
    --block-label-background-fill: #FFF3EC;
    --block-label-border-color: #EDD9C8;
    --block-label-text-color: #3D2B1F;
    --block-label-text-weight: 700;
    --block-title-text-color: #3D2B1F;
    --block-title-text-weight: 700;
    --panel-background-fill: #FFFFFF;
    --panel-border-color: #EDD9C8;
    --border-color-primary: #EDD9C8;
    --border-color-accent: #E8845A;
    --input-background-fill: #FFF3EC;
    --input-border-color: #EDD9C8;
    --input-border-color-focus: #E8845A;
    --input-placeholder-color: #B09080;
    --color-accent: #E8845A;
    --color-accent-soft: #FFF3EC;
    --button-small-text-color: #3D2B1F;
    --button-secondary-background-fill: #FFF3EC;
    --button-secondary-text-color: #3D2B1F;
    --button-secondary-border-color: #EDD9C8;
    --button-secondary-background-fill-hover: #EDD9C8;
    --neutral-50: #FFFFFF;
    --neutral-100: #FFF8F3;
    --neutral-200: #FFF3EC;
    --neutral-300: #EDD9C8;
    --neutral-400: #D4B896;
    --neutral-500: #B09080;
    --neutral-600: #7A4F3A;
    --neutral-700: #5C3D2E;
    --neutral-800: #3D2B1F;
    --neutral-900: #2A1810;
    --neutral-950: #1A0E08;
    color-scheme: light !important;
}

/* ── Force light color-scheme on everything ── */
*, *::before, *::after {
    box-sizing: border-box;
    color-scheme: light !important;
}

/* ── Critical: neutralise Gradio dark-mode media-query overrides ── */
@media (prefers-color-scheme: dark) {
    :root {
        color-scheme: light !important;
        --body-background-fill: #FFF8F3 !important;
        --background-fill-primary: #FFFFFF !important;
        --background-fill-secondary: #FFF8F3 !important;
        --block-background-fill: #FFFFFF !important;
        --block-border-color: #EDD9C8 !important;
        --block-label-background-fill: #FFF3EC !important;
        --block-label-border-color: #EDD9C8 !important;
        --block-label-text-color: #3D2B1F !important;
        --block-title-text-color: #3D2B1F !important;
        --panel-background-fill: #FFFFFF !important;
        --panel-border-color: #EDD9C8 !important;
        --border-color-primary: #EDD9C8 !important;
        --input-background-fill: #FFF3EC !important;
        --input-border-color: #EDD9C8 !important;
        --neutral-50: #FFFFFF !important;
        --neutral-100: #FFF8F3 !important;
        --neutral-200: #FFF3EC !important;
        --neutral-300: #EDD9C8 !important;
        --neutral-700: #5C3D2E !important;
        --neutral-800: #3D2B1F !important;
        --neutral-900: #2A1810 !important;
        --neutral-950: #1A0E08 !important;
    }
    html, body, .gradio-container, .block, .panel, .form, .wrap {
        background-color: var(--warm-bg) !important;
        color: var(--warm-text) !important;
    }
}

/* ── Page base ── */
html, body, .gradio-container,
.gradio-container *:not(button):not(input):not(canvas):not(svg):not(path) {
    font-family: 'Inter', sans-serif !important;
}
body, .gradio-container, .gradio-container > .main,
.gradio-container > .main > .wrap, .gap, footer {
    background-color: var(--warm-bg) !important;
    color: var(--warm-text) !important;
}

/* ── Kill ALL dark backgrounds on blocks/containers ── */
.block, .block.padded, .panel, .form, .wrap, .inner-wrap, .contain, .box,
.input-wrapper, .output-class, .preview, .image-container, .upload-container,
.component-wrapper, [class*="svelte-"], .row, .col {
    background-color: var(--warm-card) !important;
    border-color: var(--warm-border) !important;
}

/* ── ALL text everywhere ── */
span, p, div, h1, h2, h3, h4, h5, label, legend, li, em, strong {
    color: var(--warm-text) !important;
}

/* ── Gradio block labels (the dark top bar on every component) ── */
.block > label, .block > .label-wrap, .block > div > label,
.block > div > .label-wrap, label[data-testid], .label-wrap, .block label,
label.svelte-1b6s6s, label.svelte-1ydv1sl,
[class*="label-wrap"], [class*="block-label"] {
    background: var(--warm-light) !important;
    background-color: var(--warm-light) !important;
    border-color: var(--warm-border) !important;
    color: var(--warm-text) !important;
}
.block label span, .label-wrap span, .block label > span, [class*="label-wrap"] span {
    color: var(--warm-text) !important;
    font-weight: 700 !important;
    font-size: 13px !important;
}
.block label svg, .label-wrap svg, [class*="label-wrap"] svg {
    color: var(--warm-accent) !important;
    fill: var(--warm-accent) !important;
}

/* ── ALL non-run buttons (toolbar, upload icons, etc.) ── */
.block button:not(#run-btn), .gradio-container button:not(#run-btn),
[class*="toolbar"] button, [class*="toolbox"] button, [class*="tool-"] button {
    background: #FFFFFF !important;
    background-color: #FFFFFF !important;
    color: var(--warm-text) !important;
    border: 1px solid var(--warm-border) !important;
    border-radius: 8px !important;
}
.block button:not(#run-btn):hover, .gradio-container button:not(#run-btn):hover {
    background: var(--warm-light) !important;
    border-color: var(--warm-accent) !important;
}
/* Button SVG icons — make them warm-dark so visible on white bg */
.block button:not(#run-btn) svg, .gradio-container button:not(#run-btn) svg,
[class*="toolbar"] button svg, [class*="toolbox"] button svg {
    fill: var(--warm-text) !important;
    stroke: var(--warm-text) !important;
    color: var(--warm-text) !important;
}

/* ── Toolbar containers (pill strips around icon groups) ── */
[class*="toolbar"], [class*="toolbox"], [class*="tool-bar"],
[role="toolbar"], [role="group"], .tools, .tool-strip, .controls, .actions {
    background: #FFFFFF !important;
    background-color: #FFFFFF !important;
    border: 1px solid var(--warm-border) !important;
    border-radius: 10px !important;
}

/* Transparent pass-through divs inside blocks */
.block > div > div {
    background: transparent !important;
}

/* Inline dark-background style overrides */
.gradio-container [style*="background: rgb(0"],
.gradio-container [style*="background: rgba(0"],
.gradio-container [style*="background-color: rgb(0"],
.gradio-container [style*="background-color: rgba(0"],
.gradio-container [style*="background:#"],
.gradio-container [style*="background: #0"],
.gradio-container [style*="background: #1"],
.gradio-container [style*="background: #2"] {
    background: var(--warm-light) !important;
    background-color: var(--warm-light) !important;
}

/* ── ImageEditor: every layer light ── */
[data-testid*="image"], [data-testid*="editor"], .image-editor,
.gradio-imageeditor, [class*="image-editor"], [class*="imageeditor"] {
    background: #F5EDE4 !important;
}
[data-testid*="image"] canvas, .image-editor canvas {
    background: #FAFAFA !important;
}
/* ImageEditor toolbar pill — the dark strip on the left/top */
[data-testid*="image"] > div, [data-testid*="editor"] > div,
.image-editor > div, [class*="image-editor"] > div {
    background: #FFFFFF !important;
    background-color: #FFFFFF !important;
}
.image-editor button, [class*="image-editor"] button,
[data-testid*="image"] button, [data-testid*="editor"] button {
    background: #FFFFFF !important;
    background-color: #FFFFFF !important;
    color: var(--warm-text) !important;
    border: 1px solid var(--warm-border) !important;
}
.image-editor button svg, [class*="image-editor"] button svg,
[data-testid*="image"] button svg, [data-testid*="editor"] button svg {
    fill: var(--warm-text) !important;
    stroke: var(--warm-text) !important;
    color: var(--warm-text) !important;
}

/* ── Inputs ── */
input[type="number"], input[type="text"], textarea {
    background: var(--warm-light) !important;
    border: 1.5px solid var(--warm-border) !important;
    color: var(--warm-text) !important;
    border-radius: 8px !important;
}
input[type="range"] { accent-color: var(--warm-accent) !important; }
input[type="checkbox"] { accent-color: var(--warm-accent) !important; }

/* ── RUN BUTTON ── */
#run-btn {
    background: linear-gradient(135deg, #E8845A 0%, #C85E35 100%) !important;
    color: #FFFFFF !important;
    border: none !important;
    border-radius: 14px !important;
    font-size: 18px !important;
    font-weight: 800 !important;
    padding: 18px 0 !important;
    box-shadow: 0 8px 28px rgba(200,94,53,0.42) !important;
    transition: all 0.2s ease !important;
    width: 100% !important;
}
#run-btn:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 12px 36px rgba(200,94,53,0.55) !important;
}

/* ── Custom HTML elements ── */
.hero-wrap {
    text-align: center;
    padding: 40px 24px 28px;
    background: linear-gradient(160deg, #FFF8F3 0%, #FDECD8 100%) !important;
    border-radius: 20px;
    border: 1px solid var(--warm-border);
    margin-bottom: 20px;
}
.hero-badge {
    display: inline-block;
    background: linear-gradient(135deg, #F2C4A0, #EDA882) !important;
    color: #7A3B1E !important;
    border-radius: 30px;
    padding: 6px 20px;
    font-size: 11px !important;
    font-weight: 800 !important;
    letter-spacing: 1.2px;
    text-transform: uppercase;
    margin-bottom: 18px;
}
.hero-title {
    font-size: 38px !important;
    font-weight: 800 !important;
    color: #3D2B1F !important;
    margin: 0 0 12px !important;
    line-height: 1.1 !important;
    display: block;
}
.hero-sub {
    font-size: 15.5px !important;
    color: #7A4F3A !important;
    max-width: 600px;
    margin: 0 auto !important;
    line-height: 1.7 !important;
    display: block;
}
.sec-header {
    display: flex !important;
    align-items: center;
    gap: 12px;
    padding: 18px 0 14px;
    border-bottom: 2px solid var(--warm-border);
    margin-bottom: 18px;
    background: transparent !important;
}
.step-num {
    width: 32px;
    height: 32px;
    background: linear-gradient(135deg, #E8845A, #C85E35) !important;
    color: #FFFFFF !important;
    border-radius: 50%;
    display: inline-flex !important;
    align-items: center;
    justify-content: center;
    font-size: 14px !important;
    font-weight: 800 !important;
    flex-shrink: 0;
    box-shadow: 0 3px 10px rgba(200,94,53,0.35);
}
.sec-title-text {
    font-size: 18px !important;
    font-weight: 800 !important;
    color: #3D2B1F !important;
}
.sec-sub {
    font-size: 13px !important;
    color: #7A4F3A !important;
    font-weight: 400 !important;
}
.tips-box {
    background: var(--warm-light) !important;
    border-left: 4px solid var(--warm-accent);
    border-radius: 0 12px 12px 0;
    padding: 16px 18px;
    font-size: 13.5px !important;
    color: #5C3D2E !important;
    line-height: 1.75;
}
.tips-box b { color: var(--warm-accent2) !important; }
.anno-label {
    text-align: center;
    font-size: 13.5px !important;
    font-weight: 800 !important;
    padding: 10px 0 8px;
    border-radius: 8px;
    margin-bottom: 8px;
    display: block !important;
}
.anno-fg {
    background: #E8F5E9 !important;
    color: #1B5E20 !important;
    border: 1.5px solid #A5D6A7;
}
.anno-bg {
    background: #FFEBEE !important;
    color: #B71C1C !important;
    border: 1.5px solid #EF9A9A;
}
.hint-text {
    font-size: 12.5px !important;
    color: var(--warm-muted) !important;
    line-height: 1.6;
    padding: 10px 4px 0;
    display: block;
}
.warm-divider {
    border: none;
    border-top: 1.5px solid var(--warm-border);
    margin: 6px 0 24px;
}
.footer-wrap {
    text-align: center;
    padding: 28px 0 12px;
    font-size: 13px !important;
    color: #B09080 !important;
    border-top: 1px solid var(--warm-border);
    margin-top: 12px;
    background: transparent !important;
}
"""


# ══════════════════════════════════════════════════════════════════════
# UI
# ══════════════════════════════════════════════════════════════════════
with gr.Blocks(
    css=CSS,
    title="Graph Cut Segmentation",
    theme=gr.themes.Soft(
        primary_hue=gr.themes.colors.orange,
        secondary_hue=gr.themes.colors.amber,
        neutral_hue=gr.themes.colors.stone,
        font=gr.themes.GoogleFont("Inter"),
    ).set(
        body_background_fill="#FFF8F3",
        body_text_color="#3D2B1F",
        block_background_fill="#FFFFFF",
        block_border_color="#EDD9C8",
        block_label_background_fill="#FFF3EC",
        block_label_text_color="#3D2B1F",
        block_label_text_weight="700",
        block_title_text_color="#3D2B1F",
        block_title_text_weight="700",
        input_background_fill="#FFF3EC",
        input_border_color="#EDD9C8",
        input_border_color_focus="#E8845A",
        input_placeholder_color="#B09080",
        checkbox_background_color="#FFF3EC",
        checkbox_background_color_selected="#E8845A",
        checkbox_border_color="#EDD9C8",
        checkbox_label_text_color="#3D2B1F",
        slider_color="#E8845A",
        button_primary_background_fill="#E8845A",
        button_primary_background_fill_hover="#C85E35",
        button_primary_text_color="#FFFFFF",
        button_primary_border_color="transparent",
        button_secondary_background_fill="#FFF3EC",
        button_secondary_text_color="#3D2B1F",
        border_color_primary="#EDD9C8",
        border_color_accent="#E8845A",
        shadow_drop="0 2px 12px rgba(180,110,60,0.08)",
        color_accent="#E8845A",
        color_accent_soft="#FFF3EC",
        link_text_color="#E8845A",
    ),
) as demo:
    # NOTE(review): the gr.HTML snippets below contain plain text only, yet CSS
    # defines classes (.hero-wrap, .hero-badge, .sec-header, .step-num,
    # .tips-box, .anno-label/.anno-fg/.anno-bg, .hint-text, .warm-divider,
    # .footer-wrap) that nothing in this file references — the HTML markup
    # appears to have been lost; restore it.

    # ── Hero ──────────────────────────────────────────────────────────
    gr.HTML("""
Graph Cut  ·  GMM  ·  PyMaxflow  ·  Energy Minimisation
🍂 Graph Cut Image Segmentation
Upload an image, paint foreground & background scribbles, and let energy-minimisation Graph Cut isolate your object — powered by Gaussian Mixture Models and iterative refinement.
""")

    # ── STEP 1: Upload ────────────────────────────────────────────────
    gr.HTML("""
1 Upload Image
""")
    with gr.Row(equal_height=True):
        with gr.Column(scale=3):
            img_upload = gr.Image(
                label="Input Image",
                type="numpy",
                sources=["upload", "clipboard"],
                height=280,
            )
        with gr.Column(scale=1):
            gr.HTML("""
Tips for best results

✅ Clear object boundary from background
✅ Natural photos, portraits, products
✅ Any resolution — resized automatically
✅ JPEG or PNG

⚡ Higher contrast = cleaner segmentation
⚡ Draw scribbles in diverse colour areas
""")

    gr.HTML("\n")

    # ── STEP 2: Parameters ────────────────────────────────────────────
    gr.HTML("""
2 Configure Parameters
""")
    with gr.Row():
        max_dim = gr.Slider(200, 800, value=400, step=50, label="Max Dimension (px)",
                            info="Larger = more detail but slower. 400 recommended.")
        iterations = gr.Slider(1, 10, value=3, step=1, label="Iterations",
                               info="GMM re-estimation rounds. 3–5 is optimal.")
    with gr.Row():
        gamma = gr.Slider(10, 200, value=50, step=5, label="Smoothness γ",
                          info="Higher = smoother boundary. Default 50.")
        n_comp = gr.Slider(2, 10, value=5, step=1, label="GMM Components K",
                           info="Colour clusters per region. 5 fits most images.")
    use_auto = gr.Checkbox(
        label="⚡ Auto Annotation — skip drawing (uses centre/border heuristic)",
        value=False,
    )

    gr.HTML("\n")

    # ── STEP 3: Annotate ──────────────────────────────────────────────
    gr.HTML("""
3 Annotate — skip this step if Auto Annotation is enabled above
""")
    with gr.Row():
        with gr.Column():
            gr.HTML("\n🟢 FOREGROUND  — paint over the object to keep\n")
            fg_editor = gr.ImageEditor(
                label="Foreground Canvas",
                show_label=False,
                height=380,
                brush=gr.Brush(
                    default_size=14,
                    default_color="#00CC44",
                    colors=["#00CC44", "#00FF00", "#22AA55"],
                    color_mode="defaults",
                ),
            )
            gr.HTML("""
✏️ Draw green strokes across different parts of the object (body, edges, texture areas) for a richer GMM colour model.
""")
        with gr.Column():
            gr.HTML("\n🔴 BACKGROUND  — paint over background areas\n")
            bg_editor = gr.ImageEditor(
                label="Background Canvas",
                show_label=False,
                height=380,
                brush=gr.Brush(
                    default_size=14,
                    default_color="#FF3333",
                    colors=["#FF3333", "#CC0000", "#FF6666"],
                    color_mode="defaults",
                ),
            )
            gr.HTML("""
✏️ Draw red strokes on background regions. Cover varied textures (sky, floor, wall…) for better discrimination.
""")

    # Copy each new upload into both canvases so scribbles align with the image.
    img_upload.change(
        fn=update_editors,
        inputs=img_upload,
        outputs=[fg_editor, bg_editor],
    )

    gr.HTML("\n")

    # ── RUN ───────────────────────────────────────────────────────────
    run_btn = gr.Button(
        "▶ Run Graph Cut Segmentation",
        elem_id="run-btn",
        variant="primary",
    )

    gr.HTML("\n")

    # ── STEP 4: Results ───────────────────────────────────────────────
    gr.HTML("""
4 Segmentation Results
""")
    with gr.Row():
        out_annot = gr.Image(label="📌 Input + Annotations", height=260)
        out_raw = gr.Image(label="✂️ Raw Graph Cut", height=260)
        out_refined = gr.Image(label="✨ Refined Graph Cut", height=260)
    with gr.Row():
        out_overlay = gr.Image(label="🎨 Overlay on Original", height=260)
        out_extract = gr.Image(label="🖼️ Extracted Foreground", height=260)
        out_otsu = gr.Image(label="📊 Naive: Otsu", height=260)
        out_km = gr.Image(label="📊 Naive: K-Means (k=2)", height=260)

    gr.HTML("\n")

    # ── STEP 5: Analysis ──────────────────────────────────────────────
    gr.HTML("""
5 Convergence & Iteration Analysis
""")
    with gr.Row():
        out_energy = gr.Image(label="📈 Energy Convergence", height=360)
        out_iters = gr.Image(label="🔄 Iterative Mask Progression", height=360)

    # ── Wire ──────────────────────────────────────────────────────────
    # Output order must match the tuple returned by run_segmentation.
    run_btn.click(
        fn=run_segmentation,
        inputs=[fg_editor, bg_editor, img_upload, max_dim, iterations, gamma, n_comp, use_auto],
        outputs=[out_annot, out_raw, out_refined, out_overlay, out_extract,
                 out_otsu, out_km, out_energy, out_iters],
        show_progress="full",
    )

    # ── Footer ────────────────────────────────────────────────────────
    gr.HTML(" ")


if __name__ == "__main__":
    demo.launch()