# Suzyloubna's picture
# Update app.py
# a208ded verified
import html
from pathlib import Path

import gradio as gr
import numpy as np
import torch
from PIL import Image

from inference import load_model, predict
# -------- Hub / Weights Configuration --------
HUB_REPO_ID = "Suzyloubna/bridge-unetpp"  # Hugging Face model repo (change if you renamed it)
WEIGHTS_FILENAME = "MILESTONE_090_ACHIEVED_iou_0.9077.pth"
# Relative path, i.e. resolved against the current working directory.
WEIGHTS_PATH = Path(WEIGHTS_FILENAME)

# Best-effort fetch of the checkpoint from the Hub when it is not on disk yet
# (requires 'huggingface-hub' in requirements.txt). Every failure -- missing
# package, network problem, unknown repo/file -- is printed and swallowed, so
# importing this module never raises here; the model-loading step then reports
# the missing weight file instead.
try:
    if not WEIGHTS_PATH.exists():
        print(f"Weights file {WEIGHTS_FILENAME} not found locally. Downloading from {HUB_REPO_ID} ...")
        from huggingface_hub import hf_hub_download

        hf_hub_download(
            repo_id=HUB_REPO_ID,
            filename=WEIGHTS_FILENAME,
            local_dir=".",  # place file in current working directory
            # Ask for a real copy rather than a symlink into the HF cache.
            # NOTE(review): recent huggingface_hub releases deprecate and ignore
            # this flag; it can be dropped once the pinned version is new enough.
            local_dir_use_symlinks=False,
        )
        print("Download complete." if WEIGHTS_PATH.exists() else "Download attempted but file still not found.")
except Exception as dl_err:
    print(f"WARNING: Could not download weights automatically: {dl_err}")
# ---------------- Runtime / Device ----------------
# Run inference on CUDA when PyTorch can see a GPU, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Segmentation classes. `id` is the value a pixel takes in the predicted mask,
# `color` is the RGB used to paint it. COLOR_MAP is indexed directly by class
# id, so the list must stay ordered with entry i describing class i.
CLASS_INFO = [
    {"id": 0, "name": "background", "color": (0, 0, 0)},
    {"id": 1, "name": "beton", "color": (0, 114, 189)},
    {"id": 2, "name": "steel", "color": (200, 30, 30)},
]
# (num_classes, 3) uint8 lookup table: COLOR_MAP[mask] turns an id mask into RGB.
COLOR_MAP = np.asarray([entry["color"] for entry in CLASS_INFO], dtype=np.uint8)
# ---------------- Load Model (defensive) ----------------
# The UI reads both names: `model` gates inference, and `model_load_error`
# (when set) is shown to the user instead of crashing the app.
model = None
model_load_error = None
try:
    if not WEIGHTS_PATH.exists():
        model_load_error = f"Weight file {WEIGHTS_FILENAME} not found after download attempt."
    else:
        model = load_model(str(WEIGHTS_PATH), DEVICE)
except Exception as e:
    model_load_error = f"Model failed to load: {e}"
# ---------------- Utility Functions ----------------
def resize_mask_to_original(mask_np_small: np.ndarray, original_shape):
    """Return the class mask resized to the first two dims of *original_shape*.

    If the mask already has that height/width the very same array object is
    returned. Otherwise it is converted to uint8 and resized with
    nearest-neighbour sampling, so no new (interpolated) class ids appear.
    """
    target_h, target_w = original_shape[:2]
    if mask_np_small.shape[:2] != (target_h, target_w):
        as_image = Image.fromarray(mask_np_small.astype(np.uint8))
        # PIL sizes are (width, height).
        resized = as_image.resize((target_w, target_h), resample=Image.NEAREST)
        return np.array(resized)
    return mask_np_small
def overlay_mask(original_np: np.ndarray, mask_np: np.ndarray, alpha: float = 0.5):
    """Blend the class colours of *mask_np* on top of *original_np*.

    Each output pixel is ``(1 - alpha) * image + alpha * class_color``, clipped
    to [0, 255] and truncated to uint8. *original_np* must broadcast against
    ``COLOR_MAP[mask_np]`` (i.e. (H, W, 3) for a 2-D mask of the same size).
    """
    base = original_np.astype(np.float32)
    class_colors = COLOR_MAP[mask_np]
    mixed = base * (1 - alpha) + class_colors * alpha
    return np.clip(mixed, 0, 255).astype(np.uint8)
def compute_class_stats(mask_np: np.ndarray):
    """Count how many pixels of *mask_np* belong to each class in CLASS_INFO.

    Returns one dict per CLASS_INFO entry: a copy of that entry plus
    ``"count"`` (pixel count) and ``"pct"`` (share of all pixels, 0-100;
    0.0 for an empty mask). Mask values must be non-negative integers
    (required by ``np.bincount``).
    """
    n_pixels = mask_np.size
    histogram = np.bincount(mask_np.ravel(), minlength=len(COLOR_MAP))

    def _describe(info):
        class_id = info["id"]
        n = int(histogram[class_id]) if class_id < len(histogram) else 0
        share = (n / n_pixels * 100.0) if n_pixels else 0.0
        return {**info, "count": n, "pct": share}

    return [_describe(info) for info in CLASS_INFO]
def build_legend_html(stats):
    """Render the legend panel HTML from per-class *stats*.

    Each item in *stats* must provide ``id``, ``name``, ``color`` (an RGB
    triple), ``count`` and ``pct``, as produced by ``compute_class_stats``.
    The header button calls a client-side ``toggleLegend()`` function.
    """

    def _legend_row(entry):
        red, green, blue = entry["color"]
        return f"""
<div class="legend-item" aria-label="Class {entry['name']}">
<div class="legend-color" style="--c: rgb({red},{green},{blue});"></div>
<div class="legend-meta">
<div class="legend-name">{entry['id']}: {entry['name']}</div>
<div class="legend-stats">
<span class="legend-count">{entry['count']} px</span>
<span class="legend-pct">{entry['pct']:.2f}%</span>
</div>
</div>
</div>
"""

    body = "".join(_legend_row(entry) for entry in stats)
    return f"""
<div class="legend-wrapper" id="legend-wrapper">
<div class="legend-header">
<span>Segmentation Legend</span>
<button onclick="toggleLegend()" class="legend-toggle-btn" aria-label="Collapse legend">⤢</button>
</div>
<div id="legend-body" class="legend-body expanded">
{body}
</div>
</div>
"""
def raw_mask_download(mask_np: np.ndarray):
    """Encode *mask_np* (cast to uint8) as a PNG and return it as a data URI.

    The result can be used directly as the ``href`` of a download link.
    """
    import base64
    from io import BytesIO

    buffer = BytesIO()
    Image.fromarray(mask_np.astype(np.uint8)).save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode()
    return f"data:image/png;base64,{encoded}"
def make_colored_mask_rgba(mask_np: np.ndarray, bg_opacity: float):
    """
    Return an RGBA image where background class (0) has adjustable opacity.

    Args:
        mask_np: 2-D array of class ids; each id indexes a row of COLOR_MAP.
        bg_opacity: opacity for background pixels, nominally in [0, 1].
            Values outside that range are clamped.

    Returns:
        PIL image in RGBA mode. Non-background pixels are fully opaque.
    """
    rgb = COLOR_MAP[mask_np]  # (H, W, 3) uint8
    H, W = mask_np.shape
    # Clamp so an out-of-range value cannot overflow/wrap the uint8 alpha
    # channel, and round instead of truncating (e.g. 0.05 -> 13, not 12).
    bg_alpha = int(round(min(max(float(bg_opacity), 0.0), 1.0) * 255))
    alpha_channel = np.full((H, W), 255, dtype=np.uint8)
    alpha_channel[mask_np == 0] = bg_alpha
    rgba = np.dstack([rgb, alpha_channel]).astype(np.uint8)
    # Pillow infers "RGBA" from an (H, W, 4) uint8 array; passing `mode=`
    # explicitly is deprecated since Pillow 11.3.
    return Image.fromarray(rgba)
def run_segmentation(image, view_mode, alpha, show_colored, return_small, bg_opacity):
    """
    Segment *image* and build the four outputs shown in the Results panel.

    Args:
        image: numpy image from the input component, or None when empty.
        view_mode: "Colored Mask" or "Overlay"; any other value renders the
            class-index map as grayscale ("Raw Class Indices").
        alpha: weight of the class colours in "Overlay" mode.
        show_colored: also return the stand-alone colored mask image.
        return_small: keep the mask at the model's output resolution instead of
            resizing it to the input image size.
        bg_opacity: background opacity for the colored-mask renderings.

    Returns:
        Tuple ``(result_image, colored_mask_or_None, legend_html, download_html)``.
    """
    if model is None:
        # Escape: exception text often contains "<...>" (e.g. "<class ...>"),
        # which the browser would otherwise swallow as markup.
        error_text = html.escape(model_load_error or 'Model error.')
        return (None, None, "<p class='legend-empty'>Model not loaded.</p>",
                f"<span style='color:#ff8080'>{error_text}</span>")
    if image is None:
        return (None, None, "<p class='legend-empty'>No image yet.</p>",
                "<span style='opacity:0.6'>No mask.</span>")

    pred_mask = predict(image, model, DEVICE)
    # `predict` may hand back a tensor still living on DEVICE (possibly CUDA)
    # or carrying autograd history; `.numpy()` needs a detached CPU tensor.
    if torch.is_tensor(pred_mask):
        mask_small = pred_mask.detach().cpu().numpy()
    else:
        mask_small = np.asarray(pred_mask)
    H, W = image.shape[:2]

    if return_small:
        mask_np = mask_small
        if view_mode == "Overlay":
            # Resize the photo to the mask's (W, H) so the blend shapes agree
            # (assumes a 2-D mask).
            pil_orig = Image.fromarray(image.astype(np.uint8))
            base_img = np.array(pil_orig.resize(mask_small.shape[::-1], resample=Image.BILINEAR))
        else:
            base_img = image
    else:
        mask_np = resize_mask_to_original(mask_small, (H, W))
        base_img = image

    if view_mode == "Colored Mask":
        out_img = make_colored_mask_rgba(mask_np, bg_opacity)
    elif view_mode == "Overlay":
        blended = overlay_mask(base_img, mask_np, alpha=alpha)
        out_img = Image.fromarray(blended)
    else:  # Raw Class Indices
        # Spread ids 0..max_id over 0..255 so the classes are distinguishable.
        max_id = len(COLOR_MAP) - 1
        norm = (mask_np / max_id * 255).astype(np.uint8)
        gray_rgb = np.stack([norm, norm, norm], axis=-1)
        out_img = Image.fromarray(gray_rgb)

    colored_only = make_colored_mask_rgba(mask_np, bg_opacity) if show_colored else None

    stats = compute_class_stats(mask_np)
    legend_html = build_legend_html(stats)
    download_link = raw_mask_download(mask_np)
    download_html = f"<a class='download-anchor' href='{download_link}' download='raw_mask.png'>Download Raw Mask (PNG)</a>"
    return out_img, colored_only, legend_html, download_html
def clear_outputs():
    """Reset the four Results widgets: two empty images and two placeholders."""
    legend_placeholder = "<p class='legend-empty'>Cleared.</p>"
    download_placeholder = "<div id='download-link'>Cleared.</div>"
    return None, None, legend_placeholder, download_placeholder
# ---------------- Load CSS ----------------
# Stylesheet shipped next to this file. A missing or unreadable file should not
# take the whole demo down (the rest of the file is defensive in the same way),
# so fall back to Gradio's default styling and say so.
css_path = Path(__file__).parent / "style.css"
try:
    css_text = css_path.read_text(encoding="utf-8")
except OSError as css_err:
    print(f"WARNING: Could not read {css_path}: {css_err}. Using default styling.")
    css_text = ""
# ---------------- Interface Layout ----------------
# Client-side helpers: collapse/expand the legend, and grey out the sliders
# that do not apply to the selected view mode.
# NOTE: a <script> tag inside gr.HTML is inserted as innerHTML, which browsers
# never execute, so these helpers are registered through `demo.load(js=...)`
# instead. They are attached to `window` so the inline onclick="toggleLegend()"
# in the legend HTML can resolve them.
_CLIENT_JS = """
() => {
  window.toggleLegend = function () {
    const b = document.getElementById('legend-body');
    if (b) { b.classList.toggle('collapsed'); }
  };
  window.syncAlphaVisibility = function () {
    const radios = document.querySelectorAll("#view-mode-radio input");
    let mode = "Colored Mask";
    radios.forEach(r => { if (r.checked) mode = r.value; });
    const overlayWrap = document.querySelector("#alpha-slider")?.closest(".gr-form");
    const overlayRange = document.querySelector("#alpha-slider input[type=range]");
    const bgWrap = document.querySelector("#bg-opacity-slider")?.closest(".gr-form");
    if (overlayRange) {
      const on = mode === "Overlay";
      overlayRange.disabled = !on;
      if (overlayWrap) overlayWrap.style.opacity = on ? "1" : "0.35";
    }
    const bgRange = document.querySelector("#bg-opacity-slider input[type=range]");
    if (bgRange) {
      const on = mode === "Colored Mask";
      bgRange.disabled = !on;
      if (bgWrap) bgWrap.style.opacity = on ? "1" : "0.35";
    }
  };
  document.addEventListener("change", e => {
    if (e.target && e.target.closest("#view-mode-radio")) window.syncAlphaVisibility();
  });
  // The window 'load' event has already fired when this runs, so sync now.
  window.syncAlphaVisibility();
}
"""

with gr.Blocks(css=css_text, title="Hey Inspector • Drone Bridge Image Segmentation") as demo:
    gr.HTML("""
<div class="hero-banner floating">
<h1 class="hero-title">Hey Inspector • Drone Bridge Image Segmentation</h1>
</div>
""")
    if model_load_error:
        # Escape: exception text often contains "<...>" (e.g. "<class ...>"),
        # which would otherwise be swallowed as markup.
        gr.HTML(f"<div style='color:#ff4d4d; font-weight:600; margin-bottom:10px;'>{html.escape(model_load_error)}</div>")
    gr.HTML("<p class='intro-tagline'>Upload an image and choose how you want to visualize the segmentation.</p>")
    with gr.Row():
        with gr.Column(scale=5, elem_classes="panel glass left-panel"):
            input_image = gr.Image(
                label="Input Image",
                type="numpy",
                image_mode="RGB",
                sources=["upload", "clipboard", "webcam"],
            )
            view_mode = gr.Radio(
                ["Colored Mask", "Overlay", "Raw Class Indices"],
                value="Colored Mask",
                label="View Mode",
                elem_id="view-mode-radio",
            )
            alpha = gr.Slider(
                0.0, 1.0, value=0.5, step=0.05,
                label="Overlay Opacity",
                elem_id="alpha-slider",
            )
            bg_opacity = gr.Slider(
                0.0, 1.0, value=1.0, step=0.05,
                label="Background Opacity (Colored Mask)",
                elem_id="bg-opacity-slider",
            )
            show_colored = gr.Checkbox(value=True, label="Show 'Colored Mask (Always)' panel")
            return_small = gr.Checkbox(value=False, label="Return downsized (256x256) mask instead of original size")
            with gr.Row():
                run_btn = gr.Button("Run Segmentation", elem_id="run-btn", variant="primary")
                clear_btn = gr.Button("Clear", elem_id="clear-btn")
        with gr.Column(scale=7, elem_classes="panel glass right-panel"):
            gr.Markdown("#### Results")
            output_image = gr.Image(label="Result View", type="pil")
            color_mask_output = gr.Image(label="Colored Mask (Always)", type="pil")
            legend_html = gr.HTML("<p class='legend-empty'>Legend will appear here after segmentation.</p>")
            download_html = gr.HTML("<div id='download-link'>No mask yet.</div>")
    gr.Markdown("""
**Tips**
- Background Opacity affects only Colored Mask outputs (main and the 'always' panel).
- Set it to 0 to hide background and emphasize target classes.
- Overlay mode ignores the background opacity slider (uses original image + colored mask).
- Raw Class Indices is a grayscale class map.
""")
    run_btn.click(
        fn=run_segmentation,
        inputs=[input_image, view_mode, alpha, show_colored, return_small, bg_opacity],
        outputs=[output_image, color_mask_output, legend_html, download_html],
    )
    clear_btn.click(
        fn=clear_outputs,
        inputs=None,
        outputs=[output_image, color_mask_output, legend_html, download_html],
    )
    # Install the client-side helpers once the page has mounted (no Python fn).
    demo.load(fn=None, inputs=None, outputs=None, js=_CLIENT_JS)
def main():
    """Start the Gradio server for `demo` when this file is run as a script."""
    demo.launch()


if __name__ == "__main__":
    main()