import gradio as gr
import numpy as np
import torch
from pathlib import Path
from PIL import Image

# Local project module: provides the trained-net loader and the inference entry point.
from inference import load_model, predict

# -------- Hub / Weights Configuration --------
HUB_REPO_ID = "Suzyloubna/bridge-unetpp"  # Hugging Face model repo (change if you renamed it)
WEIGHTS_FILENAME = "MILESTONE_090_ACHIEVED_iou_0.9077.pth"
WEIGHTS_PATH = Path(WEIGHTS_FILENAME)

# Try to fetch weights from Hub if not present locally
# (Requires 'huggingface-hub' in requirements.txt)
# Best-effort: a download failure is reported but does not crash startup —
# the model-load section below handles the still-missing file.
try:
    if not WEIGHTS_PATH.exists():
        print(f"Weights file {WEIGHTS_FILENAME} not found locally. Downloading from {HUB_REPO_ID} ...")
        # Imported lazily so the app still starts when huggingface-hub is absent
        # and the weights are already on disk.
        from huggingface_hub import hf_hub_download
        hf_hub_download(
            repo_id=HUB_REPO_ID,
            filename=WEIGHTS_FILENAME,
            local_dir=".",  # place file in current working directory
            local_dir_use_symlinks=False  # make a real copy (Spaces friendly)
        )
        if WEIGHTS_PATH.exists():
            print("Download complete.")
        else:
            print("Download attempted but file still not found.")
except Exception as dl_err:
    print(f"WARNING: Could not download weights automatically: {dl_err}")

# ---------------- Runtime / Device ----------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Segmentation classes. NOTE(review): "beton" presumably means concrete —
# names appear to mirror the training labels; confirm against the dataset.
CLASS_INFO = [
    {"id": 0, "name": "background", "color": (0, 0, 0)},
    {"id": 1, "name": "beton", "color": (0, 114, 189)},
    {"id": 2, "name": "steel", "color": (200, 30, 30)},
]
# (num_classes, 3) uint8 lookup table: COLOR_MAP[class_id] -> RGB color.
COLOR_MAP = np.array([c["color"] for c in CLASS_INFO], dtype=np.uint8)

# ---------------- Load Model (defensive) ----------------
# On failure the app still launches; model_load_error is surfaced in the UI.
model_load_error = None
model = None
try:
    if WEIGHTS_PATH.exists():
        model = load_model(str(WEIGHTS_PATH), DEVICE)
    else:
        model_load_error = f"Weight file {WEIGHTS_FILENAME} not found after download attempt."
except Exception as e:
    model_load_error = f"Model failed to load: {e}"

# ---------------- Utility Functions ----------------

def resize_mask_to_original(mask_np_small: np.ndarray, original_shape):
    """Upsample a class-index mask to the original image size.

    Uses nearest-neighbour resampling so class ids stay exact integers
    (no interpolation between class values). Returns the input unchanged
    if it already matches the target (H, W).
    """
    H, W = original_shape[:2]
    if mask_np_small.shape[:2] == (H, W):
        return mask_np_small
    pil_small = Image.fromarray(mask_np_small.astype(np.uint8))
    # PIL's resize takes (width, height), i.e. (W, H) — opposite of numpy order.
    pil_big = pil_small.resize((W, H), resample=Image.NEAREST)
    return np.array(pil_big)

def overlay_mask(original_np: np.ndarray, mask_np: np.ndarray, alpha: float = 0.5):
    """Alpha-blend the per-class colors over the original RGB image.

    alpha=0 shows only the image, alpha=1 only the colored mask.
    Returns a uint8 RGB array the same size as the inputs.
    """
    # Fancy indexing: (H, W) class ids -> (H, W, 3) RGB via the lookup table.
    color_mask = COLOR_MAP[mask_np]
    blended = (1 - alpha) * original_np.astype(np.float32) + alpha * color_mask
    return blended.clip(0, 255).astype(np.uint8)

def compute_class_stats(mask_np: np.ndarray):
    """Return CLASS_INFO entries augmented with pixel 'count' and 'pct' per class."""
    total = mask_np.size
    # minlength guarantees one bin per known class even if a class is absent.
    counts = np.bincount(mask_np.flatten(), minlength=len(COLOR_MAP))
    stats = []
    for info in CLASS_INFO:
        cid = info["id"]
        # minlength above already ensures cid < len(counts); guard kept defensively.
        count = int(counts[cid]) if cid < len(counts) else 0
        pct = (count / total * 100.0) if total else 0.0
        stats.append({**info, "count": count, "pct": pct})
    return stats

def build_legend_html(stats):
    """Render the per-class stats as HTML legend rows (template continues below)."""
    rows = []
    for s in stats:
        r, g, b = s["color"]
        rows.append(f"""
Model not loaded.
", f"{model_load_error or 'Model error.'}") if image is None: return (None, None, "No image yet.
", "No mask.") pred_mask = predict(image, model, DEVICE) mask_small = pred_mask.numpy() H, W = image.shape[:2] if return_small: mask_np = mask_small if view_mode == "Overlay": pil_orig = Image.fromarray(image.astype(np.uint8)) base_img = np.array(pil_orig.resize(mask_small.shape[::-1], resample=Image.BILINEAR)) else: base_img = image else: mask_np = resize_mask_to_original(mask_small, (H, W)) base_img = image if view_mode == "Colored Mask": out_img = make_colored_mask_rgba(mask_np, bg_opacity) elif view_mode == "Overlay": blended = overlay_mask(base_img, mask_np, alpha=alpha) out_img = Image.fromarray(blended) else: # Raw Class Indices max_id = len(COLOR_MAP) - 1 norm = (mask_np / max_id * 255).astype(np.uint8) gray_rgb = np.stack([norm, norm, norm], axis=-1) out_img = Image.fromarray(gray_rgb) if show_colored: colored_only = make_colored_mask_rgba(mask_np, bg_opacity) else: colored_only = None stats = compute_class_stats(mask_np) legend_html = build_legend_html(stats) download_link = raw_mask_download(mask_np) download_html = f"Download Raw Mask (PNG)" return out_img, colored_only, legend_html, download_html def clear_outputs(): return None, None, "Cleared.
", "Upload an image and choose how you want to visualize the segmentation.
") with gr.Row(): with gr.Column(scale=5, elem_classes="panel glass left-panel"): input_image = gr.Image( label="Input Image", type="numpy", image_mode="RGB", sources=["upload", "clipboard", "webcam"] ) view_mode = gr.Radio( ["Colored Mask", "Overlay", "Raw Class Indices"], value="Colored Mask", label="View Mode", elem_id="view-mode-radio" ) alpha = gr.Slider( 0.0, 1.0, value=0.5, step=0.05, label="Overlay Opacity", elem_id="alpha-slider" ) bg_opacity = gr.Slider( 0.0, 1.0, value=1.0, step=0.05, label="Background Opacity (Colored Mask)", elem_id="bg-opacity-slider" ) show_colored = gr.Checkbox(value=True, label="Show 'Colored Mask (Always)' panel") return_small = gr.Checkbox(value=False, label="Return downsized (256x256) mask instead of original size") with gr.Row(): run_btn = gr.Button("Run Segmentation", elem_id="run-btn", variant="primary") clear_btn = gr.Button("Clear", elem_id="clear-btn") with gr.Column(scale=7, elem_classes="panel glass right-panel"): gr.Markdown("#### Results") output_image = gr.Image(label="Result View", type="pil") color_mask_output = gr.Image(label="Colored Mask (Always)", type="pil") legend_html = gr.HTML("Legend will appear here after segmentation.
") download_html = gr.HTML("