"""Gradio app: drone bridge image segmentation demo.

Downloads UNet++ weights from the Hugging Face Hub (if missing), loads the
model defensively, and exposes a Blocks UI with several visualization modes
(colored mask, overlay, raw class indices) plus per-class pixel statistics.
"""

import gradio as gr
import numpy as np
import torch
from pathlib import Path
from PIL import Image

from inference import load_model, predict

# -------- Hub / Weights Configuration --------
HUB_REPO_ID = "Suzyloubna/bridge-unetpp"  # Hugging Face model repo (change if you renamed it)
WEIGHTS_FILENAME = "MILESTONE_090_ACHIEVED_iou_0.9077.pth"
WEIGHTS_PATH = Path(WEIGHTS_FILENAME)

# Try to fetch weights from Hub if not present locally.
# (Requires 'huggingface-hub' in requirements.txt.)
try:
    if not WEIGHTS_PATH.exists():
        print(f"Weights file {WEIGHTS_FILENAME} not found locally. Downloading from {HUB_REPO_ID} ...")
        from huggingface_hub import hf_hub_download
        hf_hub_download(
            repo_id=HUB_REPO_ID,
            filename=WEIGHTS_FILENAME,
            local_dir=".",                  # place file in current working directory
            local_dir_use_symlinks=False,   # make a real copy (Spaces friendly); NOTE: deprecated in newer huggingface_hub
        )
        if WEIGHTS_PATH.exists():
            print("Download complete.")
        else:
            print("Download attempted but file still not found.")
except Exception as dl_err:
    # Best-effort download: the app still starts and reports the problem in the UI.
    print(f"WARNING: Could not download weights automatically: {dl_err}")

# ---------------- Runtime / Device ----------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

CLASS_INFO = [
    {"id": 0, "name": "background", "color": (0, 0, 0)},
    {"id": 1, "name": "beton", "color": (0, 114, 189)},
    {"id": 2, "name": "steel", "color": (200, 30, 30)},
]
# COLOR_MAP[class_id] -> RGB triple; indexed directly by a mask array below.
COLOR_MAP = np.array([c["color"] for c in CLASS_INFO], dtype=np.uint8)

# ---------------- Load Model (defensive) ----------------
model_load_error = None
model = None
try:
    if WEIGHTS_PATH.exists():
        model = load_model(str(WEIGHTS_PATH), DEVICE)
    else:
        model_load_error = f"Weight file {WEIGHTS_FILENAME} not found after download attempt."
except Exception as e:
    model_load_error = f"Model failed to load: {e}"


# ---------------- Utility Functions ----------------
def resize_mask_to_original(mask_np_small: np.ndarray, original_shape) -> np.ndarray:
    """Upsample a class-index mask to (H, W) of the original image.

    Uses nearest-neighbour so class ids are never interpolated/blended.
    """
    H, W = original_shape[:2]
    if mask_np_small.shape[:2] == (H, W):
        return mask_np_small
    pil_small = Image.fromarray(mask_np_small.astype(np.uint8))
    # PIL expects size as (width, height).
    pil_big = pil_small.resize((W, H), resample=Image.NEAREST)
    return np.array(pil_big)


def overlay_mask(original_np: np.ndarray, mask_np: np.ndarray, alpha: float = 0.5) -> np.ndarray:
    """Alpha-blend the colorized mask over the original RGB image.

    alpha=0 shows only the image; alpha=1 shows only the colored mask.
    """
    color_mask = COLOR_MAP[mask_np]
    blended = (1 - alpha) * original_np.astype(np.float32) + alpha * color_mask
    return blended.clip(0, 255).astype(np.uint8)


def compute_class_stats(mask_np: np.ndarray) -> list:
    """Return per-class pixel counts and percentages for the legend.

    Each entry is CLASS_INFO[i] augmented with 'count' and 'pct' keys.
    """
    total = mask_np.size
    counts = np.bincount(mask_np.flatten(), minlength=len(COLOR_MAP))
    stats = []
    for info in CLASS_INFO:
        cid = info["id"]
        count = int(counts[cid]) if cid < len(counts) else 0
        pct = (count / total * 100.0) if total else 0.0
        stats.append({**info, "count": count, "pct": pct})
    return stats


def build_legend_html(stats) -> str:
    """Render the class legend (swatch, name, pixel count, percentage) as HTML.

    NOTE(review): the original markup was lost in extraction; class names below
    are reconstructed and should be aligned with style.css.
    """
    rows = []
    for s in stats:
        r, g, b = s["color"]
        rows.append(f"""<div class="legend-row">
  <span class="legend-swatch" style="background: rgb({r},{g},{b});"></span>
  <span class="legend-name">{s['id']}: {s['name']}</span>
  <span class="legend-count">{s['count']} px</span>
  <span class="legend-pct">{s['pct']:.2f}%</span>
</div>""")
    return f"""<div class="legend-box">
  <div class="legend-title">Segmentation Legend</div>
  {''.join(rows)}
</div>"""


def raw_mask_download(mask_np: np.ndarray) -> str:
    """Encode the raw class-index mask as a base64 PNG data URI."""
    from io import BytesIO
    import base64
    img = Image.fromarray(mask_np.astype(np.uint8))
    bio = BytesIO()
    img.save(bio, format="PNG")
    bio.seek(0)
    return "data:image/png;base64," + base64.b64encode(bio.read()).decode()


def make_colored_mask_rgba(mask_np: np.ndarray, bg_opacity: float) -> Image.Image:
    """
    Return an RGBA image where background class (0) has adjustable opacity.
    bg_opacity in [0,1].
    """
    rgb = COLOR_MAP[mask_np]  # (H,W,3)
    H, W = mask_np.shape
    alpha_channel = np.full((H, W), 255, dtype=np.uint8)
    alpha_channel[mask_np == 0] = int(bg_opacity * 255)
    rgba = np.dstack([rgb, alpha_channel]).astype(np.uint8)
    return Image.fromarray(rgba, mode="RGBA")


def run_segmentation(image, view_mode, alpha, show_colored, return_small, bg_opacity):
    """Run the model on `image` and build all four UI outputs.

    Returns (result image, optional colored-mask image, legend HTML,
    download-link HTML). Degrades gracefully when the model failed to load
    or no image was provided.
    """
    if model is None:
        return (
            None,
            None,
            "<div class='status-msg error'>Model not loaded.</div>",
            f"{model_load_error or 'Model error.'}",
        )
    if image is None:
        return (
            None,
            None,
            "<div class='status-msg'>No image yet.</div>",
            "No mask.",
        )

    pred_mask = predict(image, model, DEVICE)
    # .cpu() is a no-op on CPU but required before .numpy() on a CUDA tensor.
    mask_small = pred_mask.cpu().numpy()
    H, W = image.shape[:2]

    if return_small:
        mask_np = mask_small
        if view_mode == "Overlay":
            # Downscale the original so it matches the small mask for blending.
            pil_orig = Image.fromarray(image.astype(np.uint8))
            base_img = np.array(pil_orig.resize(mask_small.shape[::-1], resample=Image.BILINEAR))
        else:
            base_img = image
    else:
        mask_np = resize_mask_to_original(mask_small, (H, W))
        base_img = image

    if view_mode == "Colored Mask":
        out_img = make_colored_mask_rgba(mask_np, bg_opacity)
    elif view_mode == "Overlay":
        blended = overlay_mask(base_img, mask_np, alpha=alpha)
        out_img = Image.fromarray(blended)
    else:  # Raw Class Indices
        # Spread class ids over 0..255 so they are visible as grayscale.
        max_id = len(COLOR_MAP) - 1
        norm = (mask_np / max_id * 255).astype(np.uint8)
        gray_rgb = np.stack([norm, norm, norm], axis=-1)
        out_img = Image.fromarray(gray_rgb)

    colored_only = make_colored_mask_rgba(mask_np, bg_opacity) if show_colored else None

    stats = compute_class_stats(mask_np)
    legend_html = build_legend_html(stats)
    download_link = raw_mask_download(mask_np)
    # Fix: the link must actually embed the data URI (it was dropped before).
    download_html = (
        f"<a href='{download_link}' download='mask.png' class='download-link'>"
        "Download Raw Mask (PNG)</a>"
    )

    return out_img, colored_only, legend_html, download_html


def clear_outputs():
    """Reset all four output components."""
    return None, None, "<div class='status-msg'>Cleared.</div>", ""


# ---------------- Load CSS ----------------
css_path = Path(__file__).parent / "style.css"
css_text = css_path.read_text(encoding="utf-8")

# ---------------- Interface Layout ----------------
with gr.Blocks(css=css_text, title="Hey Inspector • Drone Bridge Image Segmentation") as demo:
    gr.HTML("""<div class="app-header">
  <h1>Hey Inspector • Drone Bridge Image Segmentation</h1>
</div>""")
    if model_load_error:
        gr.HTML(f"<div class='status-msg error'>{model_load_error}</div>")
    gr.HTML("<div class='app-subtitle'>Upload an image and choose how you want to visualize the segmentation.</div>")

    with gr.Row():
        with gr.Column(scale=5, elem_classes="panel glass left-panel"):
            input_image = gr.Image(
                label="Input Image",
                type="numpy",
                image_mode="RGB",
                sources=["upload", "clipboard", "webcam"],
            )
            view_mode = gr.Radio(
                ["Colored Mask", "Overlay", "Raw Class Indices"],
                value="Colored Mask",
                label="View Mode",
                elem_id="view-mode-radio",
            )
            alpha = gr.Slider(
                0.0, 1.0, value=0.5, step=0.05,
                label="Overlay Opacity",
                elem_id="alpha-slider",
            )
            bg_opacity = gr.Slider(
                0.0, 1.0, value=1.0, step=0.05,
                label="Background Opacity (Colored Mask)",
                elem_id="bg-opacity-slider",
            )
            show_colored = gr.Checkbox(value=True, label="Show 'Colored Mask (Always)' panel")
            return_small = gr.Checkbox(value=False, label="Return downsized (256x256) mask instead of original size")
            with gr.Row():
                run_btn = gr.Button("Run Segmentation", elem_id="run-btn", variant="primary")
                clear_btn = gr.Button("Clear", elem_id="clear-btn")

        with gr.Column(scale=7, elem_classes="panel glass right-panel"):
            gr.Markdown("#### Results")
            output_image = gr.Image(label="Result View", type="pil")
            color_mask_output = gr.Image(label="Colored Mask (Always)", type="pil")
            legend_html = gr.HTML("<div class='status-msg'>Legend will appear here after segmentation.</div>")
            download_html = gr.HTML("")
            gr.Markdown("""
**Tips**
- Background Opacity affects only Colored Mask outputs (main and the 'always' panel).
- Set it to 0 to hide background and emphasize target classes.
- Overlay mode ignores the background opacity slider (uses original image + colored mask).
- Raw Class Indices is a grayscale class map.
""")
            gr.HTML(""" """)

    run_btn.click(
        fn=run_segmentation,
        inputs=[input_image, view_mode, alpha, show_colored, return_small, bg_opacity],
        outputs=[output_image, color_mask_output, legend_html, download_html],
    )
    clear_btn.click(
        fn=clear_outputs,
        inputs=None,
        outputs=[output_image, color_mask_output, legend_html, download_html],
    )

if __name__ == "__main__":
    demo.launch()