|
|
import html
from pathlib import Path

import gradio as gr
import numpy as np
import torch
from PIL import Image

from inference import load_model, predict
|
|
|
|
|
|
|
|
# Hugging Face Hub repository and file name of the trained checkpoint.
HUB_REPO_ID = "Suzyloubna/bridge-unetpp"
WEIGHTS_FILENAME = "MILESTONE_090_ACHIEVED_iou_0.9077.pth"

# Relative path: resolved against the process's current working directory,
# not against the directory containing this file.
WEIGHTS_PATH = Path(".") / WEIGHTS_FILENAME
|
|
|
|
|
|
|
|
|
|
|
# Best-effort fetch of the checkpoint from the Hugging Face Hub when it is not
# already on disk. Any failure is only reported here; the model-loading step
# below records the missing file in ``model_load_error`` so the UI can show it.
try:
    if not WEIGHTS_PATH.exists():
        print(f"Weights file {WEIGHTS_FILENAME} not found locally. Downloading from {HUB_REPO_ID} ...")

        # Imported lazily so the dependency is only needed when a download happens.
        from huggingface_hub import hf_hub_download

        # Download into the directory WEIGHTS_PATH points at, so the file lands
        # exactly where the loader looks for it. The former
        # ``local_dir_use_symlinks=False`` argument is deprecated (and ignored)
        # in current huggingface_hub releases, so it is no longer passed.
        hf_hub_download(
            repo_id=HUB_REPO_ID,
            filename=WEIGHTS_FILENAME,
            local_dir=str(WEIGHTS_PATH.parent),
        )

        if WEIGHTS_PATH.exists():
            print("Download complete.")
        else:
            print("Download attempted but file still not found.")
except Exception as dl_err:
    # Deliberately broad: network, auth or missing-package errors must not stop
    # the app from starting; the UI reports the missing model instead.
    print(f"WARNING: Could not download weights automatically: {dl_err}")
|
|
|
|
|
|
|
|
# Run inference on the GPU when PyTorch can see one, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Segmentation classes. Each class id equals its row index, which is what lets
# COLOR_MAP be indexed directly with a mask of class ids.
CLASS_INFO = [
    {"id": class_id, "name": class_name, "color": rgb}
    for class_id, (class_name, rgb) in enumerate(
        (
            ("background", (0, 0, 0)),
            ("beton", (0, 114, 189)),
            ("steel", (200, 30, 30)),
        )
    )
]

# (num_classes, 3) uint8 lookup table: COLOR_MAP[mask] turns class ids into RGB.
COLOR_MAP = np.asarray([entry["color"] for entry in CLASS_INFO], dtype=np.uint8)
|
|
|
|
|
|
|
|
# Load the model once at import time. Instead of aborting start-up, any problem
# is stored in ``model_load_error`` and surfaced in the UI; ``model`` stays None.
model = None
model_load_error = None

try:
    if not WEIGHTS_PATH.exists():
        model_load_error = f"Weight file {WEIGHTS_FILENAME} not found after download attempt."
    else:
        model = load_model(str(WEIGHTS_PATH), DEVICE)
except Exception as load_exc:
    # Broad on purpose: whatever load_model raises should become a UI message.
    model_load_error = f"Model failed to load: {load_exc}"
|
|
|
|
|
|
|
|
def resize_mask_to_original(mask_np_small: np.ndarray, original_shape):
    """Resize a class-id mask to the height/width given by ``original_shape``.

    Nearest-neighbour resampling is used so that no new (interpolated) class ids
    appear. If the mask already has the target size, the very same array is
    returned untouched. Otherwise the mask is cast to uint8 before resizing, so
    the result is a uint8 array (class ids above 255 would wrap).
    """
    target_h, target_w = original_shape[:2]

    if mask_np_small.shape[:2] != (target_h, target_w):
        # PIL's resize takes (width, height).
        resized = Image.fromarray(mask_np_small.astype(np.uint8)).resize(
            (target_w, target_h), resample=Image.NEAREST
        )
        return np.array(resized)

    return mask_np_small
|
|
|
|
|
def overlay_mask(original_np: np.ndarray, mask_np: np.ndarray, alpha: float = 0.5):
    """Blend each pixel's class colour (from COLOR_MAP) onto ``original_np``.

    ``alpha`` is the weight of the class colour: 0 keeps the original image,
    1 shows only the colours. ``original_np`` must broadcast against the
    (H, W, 3) colour image obtained by indexing COLOR_MAP with ``mask_np``.
    The result is clipped to 0..255 and returned as uint8.
    """
    tinted = COLOR_MAP[mask_np]
    base = original_np.astype(np.float32)
    mixed = alpha * tinted + (1 - alpha) * base
    return np.clip(mixed, 0, 255).astype(np.uint8)
|
|
|
|
|
def compute_class_stats(mask_np: np.ndarray):
    """Count how many pixels of ``mask_np`` belong to each class in CLASS_INFO.

    Returns a list in CLASS_INFO order. Each element is a copy of the class dict
    extended with ``count`` (int pixel count) and ``pct`` (share of all pixels,
    0-100; 0.0 for an empty mask). ``mask_np`` must contain non-negative
    integer class ids, as required by ``np.bincount``.
    """
    n_pixels = mask_np.size
    # minlength guarantees one bin per palette entry even if a class is absent.
    histogram = np.bincount(mask_np.ravel(), minlength=len(COLOR_MAP))
    n_bins = len(histogram)

    per_class_counts = [
        int(histogram[info["id"]]) if info["id"] < n_bins else 0
        for info in CLASS_INFO
    ]

    return [
        {**info, "count": count, "pct": (count / n_pixels * 100.0) if n_pixels else 0.0}
        for info, count in zip(CLASS_INFO, per_class_counts)
    ]
|
|
|
|
|
def build_legend_html(stats):
    """Render the segmentation legend as an HTML fragment.

    ``stats`` is a sequence of dicts with ``id``, ``name``, ``color`` (an RGB
    triple), ``count`` and ``pct`` keys, as produced by ``compute_class_stats``.
    The header button calls a global ``toggleLegend()`` JavaScript function; the
    page must define it for collapsing to work.
    """

    def legend_row(entry):
        # One colour swatch plus the class label and its pixel statistics.
        red, green, blue = entry["color"]
        name = entry["name"]
        return f"""
        <div class="legend-item" aria-label="Class {name}">
            <div class="legend-color" style="--c: rgb({red},{green},{blue});"></div>
            <div class="legend-meta">
                <div class="legend-name">{entry['id']}: {name}</div>
                <div class="legend-stats">
                    <span class="legend-count">{entry['count']} px</span>
                    <span class="legend-pct">{entry['pct']:.2f}%</span>
                </div>
            </div>
        </div>
        """

    body = "".join(legend_row(entry) for entry in stats)

    return f"""
    <div class="legend-wrapper" id="legend-wrapper">
        <div class="legend-header">
            <span>Segmentation Legend</span>
            <button onclick="toggleLegend()" class="legend-toggle-btn" aria-label="Collapse legend">⤢</button>
        </div>
        <div id="legend-body" class="legend-body expanded">
            {body}
        </div>
    </div>
    """
|
|
|
|
|
def raw_mask_download(mask_np: np.ndarray):
    """Encode the class-id mask as a PNG ``data:`` URI for an HTML download link.

    Pixel values are the raw class ids cast to uint8, not display colours, so
    the image looks almost black in an ordinary viewer.
    """
    import base64
    from io import BytesIO

    png_buffer = BytesIO()
    Image.fromarray(mask_np.astype(np.uint8)).save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue()).decode()
    return f"data:image/png;base64,{encoded}"
|
|
|
|
|
def make_colored_mask_rgba(mask_np: np.ndarray, bg_opacity: float):
    """Render a 2-D class-id mask as an RGBA PIL image coloured with COLOR_MAP.

    Pixels of the background class (id 0) get alpha ``bg_opacity``; every other
    class is fully opaque.

    Args:
        mask_np: 2-D array of class ids usable as indices into COLOR_MAP.
        bg_opacity: background opacity, expected in [0, 1]. Values outside
            that range are clipped.

    Returns:
        A PIL image in RGBA mode with the same height and width as ``mask_np``.
    """
    rgb = COLOR_MAP[mask_np]
    height, width = mask_np.shape

    # Clip so an out-of-range value cannot overflow the uint8 alpha channel
    # (NumPy may wrap or raise depending on version), and round rather than
    # truncate so e.g. 0.05 maps to the nearest 8-bit alpha.
    bg_alpha = int(round(float(np.clip(bg_opacity, 0.0, 1.0)) * 255))

    alpha_channel = np.full((height, width), 255, dtype=np.uint8)
    alpha_channel[mask_np == 0] = bg_alpha

    rgba = np.dstack([rgb, alpha_channel]).astype(np.uint8, copy=False)
    # An (H, W, 4) uint8 array is inferred as "RGBA"; the explicit ``mode``
    # argument of Image.fromarray is deprecated in recent Pillow releases.
    return Image.fromarray(rgba)
|
|
|
|
|
def _mask_to_numpy(pred_mask) -> np.ndarray:
    """Convert the mask returned by ``predict`` into a host-side NumPy array.

    ``predict`` lives in another module; the original code called ``.numpy()``
    on its result, so a ``torch.Tensor`` is presumably returned. Detaching and
    moving to the CPU first keeps this working when the tensor sits on a CUDA
    ``DEVICE`` or tracks gradients (plain ``.numpy()`` raises in both cases).
    Non-tensor results are passed through ``np.asarray``.
    """
    if isinstance(pred_mask, torch.Tensor):
        return pred_mask.detach().cpu().numpy()
    return np.asarray(pred_mask)


def _render_main_view(view_mode, mask_np, base_img, alpha, bg_opacity):
    """Build the PIL image shown in the "Result View" panel for ``view_mode``."""
    if view_mode == "Colored Mask":
        return make_colored_mask_rgba(mask_np, bg_opacity)
    if view_mode == "Overlay":
        return Image.fromarray(overlay_mask(base_img, mask_np, alpha=alpha))

    # "Raw Class Indices" (and any unrecognised mode): spread the class ids
    # evenly over 0..255 as a grayscale image. max(..., 1) avoids a division
    # by zero should the palette ever contain a single class.
    max_id = max(len(COLOR_MAP) - 1, 1)
    norm = (mask_np / max_id * 255).astype(np.uint8)
    return Image.fromarray(np.stack([norm, norm, norm], axis=-1))


def run_segmentation(image, view_mode, alpha, show_colored, return_small, bg_opacity):
    """Gradio callback: segment ``image`` and build the four output values.

    Args:
        image: input image as a NumPy array (the input component uses
            ``type="numpy"``, ``image_mode="RGB"``), or None when empty.
        view_mode: "Colored Mask", "Overlay" or "Raw Class Indices".
        alpha: class-colour weight used by the Overlay view.
        show_colored: whether to also fill the "Colored Mask (Always)" panel.
        return_small: keep the mask at the model's output size instead of
            resizing it to the input image size.
        bg_opacity: background alpha for the colored-mask renderings.

    Returns:
        ``(result_image, colored_mask_or_None, legend_html, download_html)``.
    """
    if model is None:
        # The message may embed exception text, so escape it before putting it in HTML.
        error_text = html.escape(model_load_error or "Model error.")
        return (
            None,
            None,
            "<p class='legend-empty'>Model not loaded.</p>",
            f"<span style='color:#ff8080'>{error_text}</span>",
        )
    if image is None:
        return (
            None,
            None,
            "<p class='legend-empty'>No image yet.</p>",
            "<span style='opacity:0.6'>No mask.</span>",
        )

    mask_small = _mask_to_numpy(predict(image, model, DEVICE))
    H, W = image.shape[:2]

    if return_small:
        mask_np = mask_small
        if view_mode == "Overlay":
            # Downscale the photo to the mask size so the two can be blended;
            # PIL's resize takes (width, height).
            small_h, small_w = mask_small.shape[:2]
            pil_orig = Image.fromarray(image.astype(np.uint8))
            base_img = np.array(pil_orig.resize((small_w, small_h), resample=Image.BILINEAR))
        else:
            base_img = image
    else:
        mask_np = resize_mask_to_original(mask_small, (H, W))
        base_img = image

    out_img = _render_main_view(view_mode, mask_np, base_img, alpha, bg_opacity)
    colored_only = make_colored_mask_rgba(mask_np, bg_opacity) if show_colored else None

    legend = build_legend_html(compute_class_stats(mask_np))
    download_link = raw_mask_download(mask_np)
    download = (
        f"<a class='download-anchor' href='{download_link}' download='raw_mask.png'>"
        "Download Raw Mask (PNG)</a>"
    )

    return out_img, colored_only, legend, download
|
|
|
|
|
def clear_outputs():
    """Reset the four result components (result view, colored mask, legend, download)."""
    cleared_legend = "<p class='legend-empty'>Cleared.</p>"
    cleared_download = "<div id='download-link'>Cleared.</div>"
    return (None, None, cleared_legend, cleared_download)
|
|
|
|
|
|
|
|
# Custom stylesheet shipped next to this file (resolved relative to the file,
# not the working directory). A missing or unreadable stylesheet only loses the
# custom look, so fall back to Gradio's default styling instead of crashing at
# import time, in line with how model-loading problems are handled above.
css_path = Path(__file__).parent / "style.css"
try:
    css_text = css_path.read_text(encoding="utf-8")
except OSError as css_err:
    print(f"WARNING: Could not read stylesheet {css_path}: {css_err}. Using default styling.")
    css_text = ""
|
|
|
|
|
|
|
|
with gr.Blocks(css=css_text, title="Hey Inspector • Drone Bridge Image Segmentation") as demo:
    gr.HTML(
        """
        <div class="hero-banner floating">
            <h1 class="hero-title">Hey Inspector • Drone Bridge Image Segmentation</h1>
        </div>
        """
    )

    if model_load_error:
        # The message embeds exception text, which may contain "<", ">" or "&".
        gr.HTML(
            "<div style='color:#ff4d4d; font-weight:600; margin-bottom:10px;'>"
            f"{html.escape(model_load_error)}</div>"
        )

    gr.HTML("<p class='intro-tagline'>Upload an image and choose how you want to visualize the segmentation.</p>")

    with gr.Row():
        with gr.Column(scale=5, elem_classes="panel glass left-panel"):
            input_image = gr.Image(
                label="Input Image",
                type="numpy",
                image_mode="RGB",
                sources=["upload", "clipboard", "webcam"],
            )
            view_mode = gr.Radio(
                ["Colored Mask", "Overlay", "Raw Class Indices"],
                value="Colored Mask",
                label="View Mode",
                elem_id="view-mode-radio",
            )
            # Starts disabled because the default view mode is not "Overlay";
            # _slider_interactivity (below) enables it when Overlay is selected.
            alpha = gr.Slider(
                0.0, 1.0, value=0.5, step=0.05,
                label="Overlay Opacity",
                elem_id="alpha-slider",
                interactive=False,
            )
            bg_opacity = gr.Slider(
                0.0, 1.0, value=1.0, step=0.05,
                label="Background Opacity (Colored Mask)",
                elem_id="bg-opacity-slider",
            )
            show_colored = gr.Checkbox(value=True, label="Show 'Colored Mask (Always)' panel")
            return_small = gr.Checkbox(value=False, label="Return downsized (256x256) mask instead of original size")
            with gr.Row():
                run_btn = gr.Button("Run Segmentation", elem_id="run-btn", variant="primary")
                clear_btn = gr.Button("Clear", elem_id="clear-btn")

        with gr.Column(scale=7, elem_classes="panel glass right-panel"):
            gr.Markdown("#### Results")
            output_image = gr.Image(label="Result View", type="pil")
            color_mask_output = gr.Image(label="Colored Mask (Always)", type="pil")
            legend_html = gr.HTML("<p class='legend-empty'>Legend will appear here after segmentation.</p>")
            download_html = gr.HTML("<div id='download-link'>No mask yet.</div>")

    gr.Markdown(
        "**Tips**\n\n"
        "- Background Opacity affects only Colored Mask outputs (main and the 'always' panel).\n"
        "- Set it to 0 to hide background and emphasize target classes.\n"
        "- Overlay mode ignores the background opacity slider (uses original image + colored mask).\n"
        "- Raw Class Indices is a grayscale class map.\n"
    )

    # <script> tags inside gr.HTML are inserted via innerHTML and never run, so
    # slider enabling is done server-side: each slider is interactive only when
    # it influences an output. Background opacity also feeds the
    # "Colored Mask (Always)" panel, so it stays enabled while that panel is on.
    def _slider_interactivity(mode, colored_panel_on):
        """Return updates for (alpha, bg_opacity) matching the selected view."""
        return (
            gr.update(interactive=mode == "Overlay"),
            gr.update(interactive=mode == "Colored Mask" or bool(colored_panel_on)),
        )

    view_mode.change(
        fn=_slider_interactivity,
        inputs=[view_mode, show_colored],
        outputs=[alpha, bg_opacity],
    )
    show_colored.change(
        fn=_slider_interactivity,
        inputs=[view_mode, show_colored],
        outputs=[alpha, bg_opacity],
    )

    # The legend built by build_legend_html has an inline onclick="toggleLegend()"
    # handler; define that global function in the page once it has loaded.
    demo.load(
        fn=None,
        inputs=None,
        outputs=None,
        js="""
        () => {
            window.toggleLegend = () => {
                const body = document.getElementById('legend-body');
                if (body) { body.classList.toggle('collapsed'); }
            };
        }
        """,
    )

    run_btn.click(
        fn=run_segmentation,
        inputs=[input_image, view_mode, alpha, show_colored, return_small, bg_opacity],
        outputs=[output_image, color_mask_output, legend_html, download_html],
    )

    clear_btn.click(
        fn=clear_outputs,
        inputs=None,
        outputs=[output_image, color_mask_output, legend_html, download_html],
    )
|
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed directly;
    # importing the module builds ``demo`` without launching it.
    demo.launch()