# app.py
# Gradio app that LOADS a saved scikit-learn model bundle (joblib)
# and uses Roboflow segmentation at runtime (no regressor training here).
# --- Standard Library ---
import base64
import os
import tempfile
from io import BytesIO
from pathlib import Path
# --- Third-party Libraries ---
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib import font_manager
from matplotlib.figure import Figure
from PIL import Image
import gradio as gr
import seaborn as sns
import joblib
from roboflow import Roboflow
# ============================================================
# 0) Paths (joblib + assets are in the SAME folder as app.py)
# ============================================================
# Assets are located relative to this file (symlinks resolved), never relative
# to the current working directory.
APP_DIR = Path(__file__).resolve().parent
MODEL_BUNDLE_PATH = APP_DIR.joinpath("progress_regressor.joblib")
BANNER_PATH = APP_DIR.joinpath("strive_banner.png")  # optional; a Markdown title is shown when absent
# ============================================================
# 1) Global Styling Setup (Ruda + seaborn white)
# ============================================================
# The Ruda font is optional. When Ruda-Regular.ttf sits beside app.py it is
# registered with Matplotlib and used by Matplotlib and seaborn; otherwise the
# generic "sans-serif" family is used. Loading is best-effort: any failure is
# reported on stdout (with the actual error) and the app keeps starting.
ruda_font = None
font_path = APP_DIR / "Ruda-Regular.ttf"  # optional: place beside app.py
if not font_path.exists():
    print("--- FONT WARNING ---")
    print("Ruda font not found. Plots will use Matplotlib default font.")
else:
    try:
        font_manager.fontManager.addfont(str(font_path))
        _ruda_candidate = font_manager.FontProperties(fname=str(font_path))
        _ruda_name = _ruda_candidate.get_name()
    except Exception as exc:  # best-effort: a bad font file must not block startup
        print("--- FONT WARNING ---")
        print(f"Could not load {font_path.name} ({exc!r}). Plots will use Matplotlib default font.")
    else:
        # Publish the font only after every step succeeded, so the theme setup
        # below never re-queries a font that already failed once.
        ruda_font = _ruda_candidate
        print(f"Successfully loaded font: {_ruda_name}")

if ruda_font is not None:
    plt.rcParams["font.family"] = ruda_font.get_name()
    sns.set_theme(style="white", font=ruda_font.get_name())
else:
    plt.rcParams["font.family"] = "sans-serif"
    sns.set_theme(style="white")
ACCENT_COLOR = "#111827"  # dark accent used for the emphasized bottom spine in _style_axes
# Global Matplotlib defaults: no top/right spines and compact typography.
plt.rc("axes.spines", top=False, right=False)
plt.rc("axes", titlesize=10, labelsize=9)
plt.rc(("xtick", "ytick"), labelsize=8)
plt.rc("legend", fontsize=8)
def _style_axes(ax):
ax.set_facecolor("white")
for s in ["top", "right", "left"]:
if s in ax.spines:
ax.spines[s].set_visible(False)
if "bottom" in ax.spines:
ax.spines["bottom"].set_visible(True)
ax.spines["bottom"].set_linewidth(2)
ax.spines["bottom"].set_color(ACCENT_COLOR)
# ============================================================
# 2) Config: colors, indices, Roboflow model
# ============================================================
# RGBA lookup table indexed by the class id found in the segmentation mask.
# Class ids 1-9 are laid out group-major: (beam, columns, wall) x (concrete, formwork, rebar).
colors = np.array(
    [
        (0, 0, 0, 80),        # 0 background (semi-transparent)
        (255, 0, 0, 128),     # 1 beam-concrete
        (255, 128, 0, 128),   # 2 beam-formwork
        (255, 255, 0, 128),   # 3 beam-rebar
        (0, 255, 0, 128),     # 4 columns-concrete
        (0, 255, 255, 128),   # 5 columns-formwork
        (0, 128, 255, 128),   # 6 columns-rebar
        (0, 0, 255, 128),     # 7 wall-concrete
        (128, 0, 255, 128),   # 8 wall-formwork
        (255, 0, 255, 128),   # 9 wall-rebar
    ],
    dtype=np.uint8,
)
NUM_CLASSES = colors.shape[0]  # 10

# Indices by stage type: every third class id, starting at the stage's position.
CONCRETE_IDX = list(range(1, NUM_CLASSES, 3))  # [1, 4, 7]
FORMWORK_IDX = list(range(2, NUM_CLASSES, 3))  # [2, 5, 8]
REBAR_IDX = list(range(3, NUM_CLASSES, 3))     # [3, 6, 9]
# Indices by structural group: consecutive runs of three class ids.
BEAM_IDX = list(range(1, 4))     # [1, 2, 3]
COLUMNS_IDX = list(range(4, 7))  # [4, 5, 6]
WALL_IDX = list(range(7, 10))    # [7, 8, 9]
# --- Roboflow model ---
# NOTE: segmentation still runs online via Roboflow for every analyzed image.
# Connection settings can be overridden through environment variables; the
# literals below are only fallbacks so existing deployments keep working.
# SECURITY: an API key committed to source control must be treated as exposed.
# Set ROBOFLOW_API_KEY in the environment, rotate the key, then delete the
# fallback literal.
ROBOFLOW_API_KEY = os.environ.get("ROBOFLOW_API_KEY") or "9voC8YnnNJ4DQRry6gfd"
ROBOFLOW_PROJECT = os.environ.get("ROBOFLOW_PROJECT") or "eagle.ai-str-components-v2-vhblf"
ROBOFLOW_VERSION = int(os.environ.get("ROBOFLOW_VERSION") or 8)
rf = Roboflow(api_key=ROBOFLOW_API_KEY)
project = rf.workspace().project(ROBOFLOW_PROJECT)
seg_model = project.version(ROBOFLOW_VERSION).model
# ============================================================
# 3) Load saved regressor (joblib)
# ============================================================
if not MODEL_BUNDLE_PATH.exists():
    raise FileNotFoundError(
        f"Could not find saved model bundle:\n {MODEL_BUNDLE_PATH}\n"
        "Make sure 'progress_regressor.joblib' is in the same folder as app.py."
    )
# SECURITY: joblib.load unpickles arbitrary Python objects (and can execute code
# while doing so); only load bundles produced by a trusted training run.
bundle = joblib.load(MODEL_BUNDLE_PATH)
# The bundle must map "model" to an object with .predict() and "feature_cols"
# to the ordered feature names used to build that model's input row.
try:
    best_model = bundle["model"]
    feat_cols = bundle["feature_cols"]
except (KeyError, TypeError, IndexError) as err:
    raise ValueError(
        f"Invalid model bundle {MODEL_BUNDLE_PATH}: expected a dict with 'model' and "
        f"'feature_cols' entries, got {type(bundle).__name__}."
    ) from err
print(f"[OK] Loaded regressor from: {MODEL_BUNDLE_PATH}")
# ============================================================
# 4) Utility functions: image prep, mask decoding, legend
# ============================================================
def _prepare_image_for_roboflow(path: str) -> str:
    """
    Re-encode an image as an RGB JPEG in the system temp dir for Roboflow.

    Images with transparency (RGBA/LA/PA modes, or P/L/RGB images carrying a
    ``transparency`` entry) are flattened onto a white background first, so
    transparent regions come out white.

    Args:
        path: path of the source image.

    Returns:
        Path (as str) of ``<tempdir>/<stem>_rf.jpg``. The name depends only on
        the source file's stem, so a later call for a same-named file
        overwrites it.
    """
    p = Path(path)
    # The context manager releases the source file handle even if decoding fails.
    with Image.open(p) as src:
        has_alpha = src.mode in ("RGBA", "LA", "PA") or (
            src.mode in ("P", "L", "RGB") and "transparency" in src.info
        )
        if has_alpha:
            rgba = src.convert("RGBA")
            im = Image.new("RGB", rgba.size, (255, 255, 255))
            im.paste(rgba, mask=rgba.getchannel("A"))
        else:
            im = src.convert("RGB")
    tmp_jpg = Path(tempfile.gettempdir()) / f"{p.stem}_rf.jpg"
    im.save(tmp_jpg, format="JPEG", quality=90)
    return str(tmp_jpg)
def _roboflow_ready_path(original_path: str) -> str:
p = Path(original_path)
ext = p.suffix.lower()
if ext in (".jpg", ".jpeg"):
return str(p)
return _prepare_image_for_roboflow(str(p))
def _decode_mask_to_array(result_json) -> np.ndarray:
preds = result_json.get("predictions", [])
if not preds:
raise ValueError("No predictions returned by the segmentation model.")
mask_base64 = preds[0]["segmentation_mask"]
mask_bytes = base64.b64decode(mask_base64)
mask_img = Image.open(BytesIO(mask_bytes))
return np.array(mask_img)
def _make_legend(class_map, colors_lut: np.ndarray):
    """
    Build grouped legend handles with spacing: Beams, Columns, Walls.

    Args:
        class_map: mapping of class id (int or numeric string) to a label such
            as "beam-concrete". Labels containing none of "beam"/"column"/"wall"
            (case-insensitive) are left out of the legend.
        colors_lut: RGBA lookup table indexed by class id; only RGB is used.

    Returns:
        List of matplotlib ``Patch`` handles. For every group that has at least
        one class: a blank spacer row, an invisible header row, then one colored
        patch per class (sorted by id) labelled with the capitalized material.
    """
    def pretty_material(label: str) -> str:
        # "beam-concrete" -> "Concrete"; a label without "-" is used whole.
        _, sep, material = label.partition("-")
        return (material if sep else label).capitalize()

    def make_patch(idx: int, label: str) -> mpatches.Patch:
        rgb = np.asarray(colors_lut[idx][:3]) / 255.0
        return mpatches.Patch(color=rgb, label=label)

    # (keyword matched in the label, header shown in the legend); the first
    # matching keyword wins, so "beam" is checked before "column" and "wall".
    groups = (("beam", "Beams"), ("column", "Columns"), ("wall", "Walls"))
    members = {keyword: [] for keyword, _ in groups}
    for key, label in class_map.items():
        low = label.lower()
        keyword = next((kw for kw, _ in groups if kw in low), None)
        if keyword is not None:
            members[keyword].append((int(key), label))

    handles = []
    for keyword, header in groups:
        entries = sorted(members[keyword], key=lambda item: item[0])
        if not entries:
            continue  # no spacer/header for a group with no classes
        handles.append(mpatches.Patch(color=(0, 0, 0, 0), label=" "))  # blank spacer row
        handles.append(mpatches.Patch(color="none", label=header))
        handles.extend(make_patch(idx, " " + pretty_material(lbl)) for idx, lbl in entries)
    return handles
# ============================================================
# 5) Segmentation & overlay helpers
# ============================================================
def get_mask_from_image(img_path: str):
    """
    Segment one photo with the Roboflow model.

    Returns:
        (mask_array, result_json): the mask decoded from the first prediction
        and the raw JSON response (its ``class_map`` is used for legends).
    """
    response = seg_model.predict(_roboflow_ready_path(img_path)).json()
    return _decode_mask_to_array(response), response
def make_overlay_image(img_path: str, mask_array: np.ndarray, result_json, alpha_blend: bool = True) -> Image.Image:
    """
    Render the photo with its colorized segmentation mask and a class legend.

    The mask is colorized through the module-level ``colors`` table, resized
    (nearest neighbour) to the photo size, alpha-composited over the photo and
    drawn in a Matplotlib figure with a grouped legend on the right.

    Args:
        img_path: path of the original photo.
        mask_array: integer class-id mask; values index ``colors``.
        result_json: Roboflow response; ``predictions[0]["class_map"]`` labels the legend.
        alpha_blend: if the colorized mask were RGB-only or fully opaque, make
            it half transparent (alpha 128) instead of opaque. The ``colors``
            table has its own alpha channel (80 for background, 128 otherwise),
            so with that table this flag has no effect.

    Returns:
        An RGB ``PIL.Image`` of the rendered figure, ready for Gradio.

    Raises:
        IndexError: if the mask contains a class id outside the ``colors`` table.
    """
    with Image.open(img_path) as src:
        original_img = src.convert("RGBA")

    max_class = int(mask_array.max())
    if max_class >= len(colors):
        raise IndexError(f"Mask contains class index {max_class} but colors size is {len(colors)}.")

    # Fancy indexing returns a copy, so editing alpha below never touches `colors`.
    color_mask = colors[mask_array]
    if color_mask.shape[-1] == 3:
        alpha = np.full(color_mask.shape[:2] + (1,), 128 if alpha_blend else 255, dtype=np.uint8)
        color_mask = np.concatenate([color_mask, alpha], axis=-1)
    elif alpha_blend and np.all(color_mask[..., 3] == 255):
        color_mask[..., 3] = 128

    # A uint8 (H, W, 4) array is inferred as RGBA by Pillow.
    mask_colored = Image.fromarray(color_mask).resize(original_img.size, Image.NEAREST)
    overlay = Image.alpha_composite(original_img, mask_colored)

    handles = _make_legend(result_json["predictions"][0]["class_map"], colors)

    # Build the figure without pyplot: nothing is registered in pyplot's global
    # figure list (so nothing leaks if rendering fails), and layout is applied
    # to this figure rather than to pyplot's process-wide "current figure".
    fig = Figure(figsize=(8, 6))
    ax = fig.subplots()
    ax.imshow(overlay)
    ax.axis("off")
    ax.legend(
        handles=handles,
        loc="center left",
        bbox_to_anchor=(1.01, 0.5),
        borderaxespad=0.2,
        frameon=False,
        title="Classes",
        title_fontsize=7,
        prop={"size": 7},
        labelspacing=0.2,
        handlelength=0.8,
        handleheight=0.8,
        handletextpad=0.4,
    )
    fig.tight_layout()
    buf = BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
    buf.seek(0)
    with Image.open(buf) as rendered:
        return rendered.convert("RGB")
# ============================================================
# 6) Feature extraction from mask
# ============================================================
def extract_class_features(mask_array: np.ndarray, num_classes: int = NUM_CLASSES):
    """
    Count pixels per class id and convert the counts into pixel shares.

    Args:
        mask_array: array of non-negative integer class ids (any shape).
        num_classes: minimum length of the returned arrays.

    Returns:
        (counts, ratios): ``counts[i]`` is the number of pixels labelled ``i``;
        ``ratios[i]`` is ``counts[i]`` divided by the total pixel count, or a
        float array of zeros when the mask is empty.
    """
    labels = np.ravel(mask_array)
    counts = np.bincount(labels, minlength=num_classes)
    n_pixels = mask_array.size
    if n_pixels == 0:
        return counts, np.zeros_like(counts, dtype=float)
    return counts, counts / n_pixels
def aggregate_stage_features(ratios: np.ndarray):
    """
    Collapse per-class pixel ratios into stage-level and group-level features.

    Args:
        ratios: per-class pixel shares indexed by class id.

    Returns:
        Dict of Python floats: summed shares per stage (concrete / formwork /
        rebar) and per structural group (beams / columns / walls), "finished"
        (= concrete), "in progress" (= formwork + rebar), and the stage ratios
        C/F, F/R and R/C, whose denominators get +1e-6 to avoid division by zero.
    """
    eps = 1e-6
    conc, form, rebar = (ratios[idx].sum() for idx in (CONCRETE_IDX, FORMWORK_IDX, REBAR_IDX))
    beams, columns, walls = (ratios[idx].sum() for idx in (BEAM_IDX, COLUMNS_IDX, WALL_IDX))
    raw = {
        "ratio_concrete": conc,
        "ratio_formwork": form,
        "ratio_rebar": rebar,
        "ratio_beams": beams,
        "ratio_columns": columns,
        "ratio_walls": walls,
        "ratio_finished": conc,
        "ratio_in_progress": form + rebar,
        "ratio_cf": conc / (form + eps),
        "ratio_fr": form / (rebar + eps),
        "ratio_rc": rebar / (conc + eps),
    }
    return {name: float(value) for name, value in raw.items()}
# ============================================================
# 7) Aggregate features over any number of images
# ============================================================
def aggregate_features_over_images(image_paths, feature_cols):
    """
    Segment several photos and average their features into one model input row.

    Each image is segmented and rendered as an overlay. Stage features and
    per-class pixel ratios are averaged over the images; raw per-class pixel
    counts are summed.

    Args:
        image_paths: iterable of image file paths (consumed once).
        feature_cols: ordered feature names that make up the model input row.

    Returns:
        (feat_vector, agg_avg, per_class_avg, class_counts_sum, overlays,
        class_map_first, n_used): ``feat_vector`` has shape
        (1, len(feature_cols)); ``overlays`` holds one rendered image per input;
        ``class_map_first`` is the class map of the first image; ``n_used`` is
        the number of images processed.

    Raises:
        ValueError: if no image paths are given, or an image yields no predictions.
        KeyError: if a requested feature column is not produced here.
    """
    paths = list(image_paths)  # accept any iterable, not only sequences
    n_used = len(paths)
    if n_used == 0:
        raise ValueError("No image paths provided for aggregation.")

    stage_rows = []
    per_class_sums = np.zeros(NUM_CLASSES, dtype=float)
    class_counts_sum = np.zeros(NUM_CLASSES, dtype=float)
    overlays = []
    class_map_first = None
    for img_path in paths:
        mask, result_json = get_mask_from_image(img_path)
        counts, ratios = extract_class_features(mask, num_classes=NUM_CLASSES)
        overlays.append(make_overlay_image(img_path, mask, result_json))
        if class_map_first is None:
            class_map_first = result_json["predictions"][0]["class_map"]
        stage_rows.append(aggregate_stage_features(ratios))
        per_class_sums += ratios
        class_counts_sum += counts

    agg_avg = {key: sum(float(row[key]) for row in stage_rows) / n_used for key in stage_rows[0]}
    per_class_avg = {f"ratio_class_{i}": float(per_class_sums[i] / n_used) for i in range(1, NUM_CLASSES)}
    feat_dict = {**agg_avg, **per_class_avg}

    cols = list(feature_cols)  # iterated twice below
    missing = [c for c in cols if c not in feat_dict]
    if missing:
        raise KeyError(f"Feature column(s) not produced by the app: {missing}")
    feat_vector = np.array([[feat_dict[c] for c in cols]])
    return (
        feat_vector,
        agg_avg,
        per_class_avg,
        class_counts_sum,
        overlays,
        class_map_first,
        n_used,
    )
# ============================================================
# 8) Single-image prediction (Tab 1) + chart helpers shared with Tab 2
# ============================================================
# Charts are built on matplotlib.figure.Figure rather than plt.subplots():
# pyplot keeps every figure it creates alive in a global registry until
# plt.close() is called (these figures are handed to Gradio and never closed),
# and pyplot's "current figure" is process-wide state that concurrent
# requests would share.
_STAGE_LABELS = ["Concrete", "Formwork", "Rebar"]
_GROUP_LABELS = ["Beams", "Columns", "Walls"]
_STAGE_PALETTE = {"Concrete": "#9e9e9e", "Formwork": "#d97706", "Rebar": "#b7410e"}
_RATIO_PALETTE = {"C/F": "#9e9e9e", "F/R": "#d97706", "R/C": "#b7410e"}
# Class id shown in each heatmap cell (rows: groups, columns: stages).
_HEATMAP_IDX_GRID = np.array([[1, 2, 3],
                              [4, 5, 6],
                              [7, 8, 9]])


def _new_square_figure():
    """Return (fig, ax) for a 3x3-inch chart that is not registered with pyplot."""
    fig = Figure(figsize=(3.0, 3.0))
    return fig, fig.subplots()


def _stage_shares_pct(stage_feats):
    """Concrete/Formwork/Rebar shares (in %) of their combined ratio; zeros if nothing detected."""
    conc = stage_feats["ratio_concrete"]
    form = stage_feats["ratio_formwork"]
    reb = stage_feats["ratio_rebar"]
    det_sum = conc + form + reb
    if det_sum > 0:
        return [conc / det_sum * 100.0, form / det_sum * 100.0, reb / det_sum * 100.0]
    return [0.0, 0.0, 0.0]


def _plot_stage_coverage(stage_feats):
    """Pie chart of the Concrete/Formwork/Rebar split, or a placeholder text when empty."""
    values = _stage_shares_pct(stage_feats)
    fig, ax = _new_square_figure()
    if sum(values) > 0:
        _, _, autotexts = ax.pie(
            values,
            labels=_STAGE_LABELS,
            colors=[_STAGE_PALETTE[label] for label in _STAGE_LABELS],
            autopct="%1.1f%%",
            pctdistance=0.78,
            labeldistance=1.1,
            startangle=90,
            textprops={"fontsize": 8},
        )
        for autotext in autotexts:
            autotext.set_fontsize(7)
        ax.axis("equal")
    else:
        ax.text(0.5, 0.5, "No detected objects", ha="center", va="center", fontsize=8)
        ax.axis("off")
    _style_axes(ax)
    fig.tight_layout()
    return fig


def _plot_stage_ratios(stage_feats):
    """Bar chart of the C/F, F/R and R/C stage ratios."""
    df_ratios = pd.DataFrame({
        "Ratio": ["C/F", "F/R", "R/C"],
        "Value": [stage_feats["ratio_cf"], stage_feats["ratio_fr"], stage_feats["ratio_rc"]],
    })
    fig, ax = _new_square_figure()
    sns.barplot(
        data=df_ratios,
        x="Ratio",
        y="Value",
        ax=ax,
        palette=[_RATIO_PALETTE[r] for r in df_ratios["Ratio"]],
    )
    ax.set_ylabel("Ratio", fontsize=8)
    ax.set_xlabel("", fontsize=8)
    ax.tick_params(axis="both", labelsize=8)
    # Invisible patches: the legend only spells out the C/F/R abbreviations.
    legend_patches = [
        mpatches.Patch(color="none", label=text)
        for text in ("C = Concrete", "F = Formwork", "R = Rebar")
    ]
    ax.legend(handles=legend_patches, loc="upper right", frameon=False, fontsize=7)
    _style_axes(ax)
    fig.tight_layout()
    return fig


def _object_heat_pct(counts):
    """
    3x3 grid (rows Beams/Columns/Walls, columns Concrete/Formwork/Rebar) of each
    class's share, in %, of all non-background pixels in ``counts``.
    """
    heat_counts = np.zeros((3, 3), dtype=float)
    object_total = int(sum(counts[1:]))
    if object_total <= 0:
        return heat_counts
    group_idx = (BEAM_IDX, COLUMNS_IDX, WALL_IDX)
    stage_idx = (CONCRETE_IDX, FORMWORK_IDX, REBAR_IDX)
    for idx in range(1, NUM_CLASSES):
        c_val = counts[idx]
        if c_val <= 0:
            continue
        row = next((r for r, members in enumerate(group_idx) if idx in members), None)
        col = next((c for c, members in enumerate(stage_idx) if idx in members), None)
        if row is None or col is None:
            continue
        heat_counts[row, col] += c_val
    return (heat_counts / object_total) * 100.0


def _plot_objects_heatmap(counts):
    """Heatmap of _object_heat_pct; each cell is tinted toward its class color (white = 0%)."""
    heat_pct = _object_heat_pct(counts)
    white = np.array([1.0, 1.0, 1.0])
    rgb_img = np.zeros((3, 3, 3), dtype=float)
    for r in range(3):
        for c in range(3):
            base_rgb = colors[_HEATMAP_IDX_GRID[r, c]][:3] / 255.0
            alpha = np.clip(heat_pct[r, c] / 100.0, 0.0, 1.0)
            rgb_img[r, c, :] = (1 - alpha) * white + (alpha * base_rgb)

    fig, ax = _new_square_figure()
    ax.imshow(rgb_img, aspect="equal", extent=(-0.5, 2.5, 2.5, -0.5))
    ax.set_xlim(-0.5, 2.5)
    ax.set_ylim(2.5, -0.5)
    grid_positions = np.arange(-0.5, 3.0, 1.0)
    for x in grid_positions:
        ax.axvline(x, color="#d1d5db", linewidth=0.8, zorder=3, clip_on=False)
    for y in grid_positions:
        ax.axhline(y, color="#d1d5db", linewidth=0.8, zorder=3, clip_on=False)
    ax.set_xticks(np.arange(3))
    ax.set_yticks(np.arange(3))
    ax.set_xticklabels(_STAGE_LABELS, fontsize=8)
    ax.set_yticklabels(_GROUP_LABELS, fontsize=8)
    ax.tick_params(which="both", length=0)
    for r in range(3):
        for c in range(3):
            ax.text(c, r, f"{heat_pct[r, c]:.1f}%", ha="center", va="center", fontsize=7, color="black", zorder=4)
    ax.set_xlabel("Stage", fontsize=8)
    ax.set_ylabel("Structural group", fontsize=8)
    _style_axes(ax)
    fig.tight_layout()
    return fig


def analyze_image(image_path):
    """
    Tab 1 handler: segment one photo and predict construction progress.

    Returns:
        (overlay_img, summary_html, stage_coverage_fig, stage_ratios_fig,
        objects_heatmap_fig). When no image is given, everything except the
        message is None.
    """
    if image_path is None:
        return (None, "Please upload an image.", None, None, None)

    mask, result_json = get_mask_from_image(image_path)
    overlay_img = make_overlay_image(image_path, mask, result_json)
    counts, ratios = extract_class_features(mask, num_classes=NUM_CLASSES)
    agg = aggregate_stage_features(ratios)
    per_class_feats = {f"ratio_class_{i}": float(ratios[i]) for i in range(1, NUM_CLASSES)}
    feat_dict = {**agg, **per_class_feats}
    x = np.array([[feat_dict[c] for c in feat_cols]])
    pred = float(best_model.predict(x)[0])

    summary_html = (
        "\n"
        "Predicted progress\n"
        f"{pred:.2f}%\n"
        "Stage coverage – share of detected pixels in Concrete/Formwork/Rebar.\n"
        "Stage ratios (C/F, F/R, R/C) – ratios describing stage advancement.\n"
        "Objects heatmap – where detections concentrate (Beams/Columns/Walls × stages).\n"
    )
    return (
        overlay_img,
        summary_html,
        _plot_stage_coverage(agg),
        _plot_stage_ratios(agg),
        _plot_objects_heatmap(counts),
    )


# ============================================================
# 9) Multi-image aggregated prediction (Tab 2)
# ============================================================
def _upload_to_path(item):
    """
    Normalize one gr.Files entry to a file path.

    Accepts str/PathLike (returned unchanged), dicts with a "name" or "path"
    key, or objects exposing a string ``name`` attribute (e.g. temp-file
    wrappers). Anything else is returned unchanged.
    """
    if isinstance(item, (str, os.PathLike)):
        return item
    if isinstance(item, dict):
        for key in ("name", "path"):
            if key in item:
                return item[key]
        return item
    name = getattr(item, "name", None)
    return name if isinstance(name, str) else item


def analyze_images(image_paths):
    """
    Tab 2 handler: segment several photos and predict progress from averaged features.

    Returns:
        (overlays, summary_html, stage_coverage_update, stage_ratios_update,
        objects_heatmap_update); the three chart slots are gr.update objects,
        hidden when no images are given.
    """
    if not image_paths:
        return (
            [],
            "Please upload at least one image.",
            gr.update(value=None, visible=False),
            gr.update(value=None, visible=False),
            gr.update(value=None, visible=False),
        )

    img_paths = [_upload_to_path(item) for item in image_paths]
    (
        feat_vector,
        agg_avg,
        _,
        class_counts_sum,
        overlays,
        _,
        n_used,
    ) = aggregate_features_over_images(img_paths, feat_cols)
    pred = float(best_model.predict(feat_vector)[0])

    summary_html = (
        "\n"
        f"Predicted progress averaged over {n_used} photo(s)\n"
        f"{pred:.2f}%\n"
    )
    return (
        overlays,
        summary_html,
        gr.update(value=_plot_stage_coverage(agg_avg), visible=True),
        gr.update(value=_plot_stage_ratios(agg_avg), visible=True),
        gr.update(value=_plot_objects_heatmap(class_counts_sum), visible=True),
    )
# ============================================================
# 10) Gradio UI with two tabs
# ============================================================
# Components are laid out in creation order inside the nested `with` blocks.
# NOTE(review): the indentation of this section was lost; the nesting below
# (plot rows placed under the output column) is a reconstruction — confirm
# the intended layout.
with gr.Blocks(
    # Gradient styling for elements with the `primary` CSS class (presumably
    # the buttons created with variant="primary" below — verify).
    css="""
button.primary {
background: linear-gradient(90deg,#9333ea 0%,#dc2626 100%) !important;
border: none !important;
color: white !important;
font-weight: 600;
transition: all 0.2s ease;
}
button.primary:hover { filter: brightness(1.05); }
button.primary:active { filter: brightness(0.95); }
"""
) as demo:
    # banner (optional): show the image when present, else a Markdown title.
    if BANNER_PATH.exists():
        gr.Image(value=str(BANNER_PATH), show_label=False, type="filepath")
    else:
        gr.Markdown("### STRIVE Progress Estimator")
    # ---------------- Tab 1: Single image -----------------
    with gr.Tab("Single image"):
        with gr.Row():
            with gr.Column(scale=1):
                img_in_single = gr.Image(type="filepath", label="Upload construction photo")
                run_btn_single = gr.Button("Analyze", variant="primary")
                summary_box_single = gr.HTML(label="Predicted progress")
            with gr.Column(scale=2):
                img_out_single = gr.Image(label="Overlayed segmentation + legend")
                with gr.Row():
                    stage_cov_plot_single = gr.Plot(label="Stage coverage")
                    stage_ratio_plot_single = gr.Plot(label="Stage ratios")
                    objects_plot_single = gr.Plot(label="Objects heatmap")
        # Output order must match the 5-tuple returned by analyze_image.
        run_btn_single.click(
            fn=analyze_image,
            inputs=[img_in_single],
            outputs=[
                img_out_single,
                summary_box_single,
                stage_cov_plot_single,
                stage_ratio_plot_single,
                objects_plot_single,
            ],
        )
    # ---------------- Tab 2: Multiple images -----------------
    with gr.Tab("Multiple images"):
        with gr.Row():
            with gr.Column(scale=1):
                img_in_multi = gr.Files(label="Upload multiple construction photos", file_types=["image"])
                run_btn_multi = gr.Button("Analyze all", variant="primary")
                summary_box_multi = gr.HTML(label="Predicted progress (averaged)")
            with gr.Column(scale=2):
                overlays_gallery = gr.Gallery(label="Overlays", show_label=True, columns=3, height="auto")
                with gr.Row():
                    stage_cov_plot_multi = gr.Plot(label="Stage coverage (avg)")
                    stage_ratio_plot_multi = gr.Plot(label="Stage ratios (avg)")
                    objects_plot_multi = gr.Plot(label="Objects heatmap (avg)")
        # Output order must match the 5-tuple returned by analyze_images
        # (gallery images, summary, then the three chart updates).
        run_btn_multi.click(
            fn=analyze_images,
            inputs=[img_in_multi],
            outputs=[
                overlays_gallery,
                summary_box_multi,
                stage_cov_plot_multi,
                stage_ratio_plot_multi,
                objects_plot_multi,
            ],
        )
if __name__ == "__main__":
    # Start the local server and open the app in the default browser.
    demo.launch(inbrowser=True)