# Box-Detection / app.py
import gradio as gr
import os
import json
import cv2
import numpy as np
from sklearn.cluster import KMeans
# ----------------- Config -----------------
RESIZE_MAX = 1600          # longest image side after load_and_resize()
MIN_AREA = 300             # candidates with contour/box area below this are dropped
MAX_AREA = 120000          # candidates with area above this are dropped
APPROX_EPS = 0.06          # approxPolyDP epsilon, as a fraction of contour perimeter
IOU_NMS = 0.25             # IoU threshold used by non_max_suppression()
COLOR_CLUSTER_N = 6        # number of KMeans color clusters in color_cluster_masks()
SAT_MIN = 20               # minimum HSV saturation kept by refine_mask_by_hsv()
VAL_MIN = 20               # minimum HSV value kept by refine_mask_by_hsv()
ROW_TOL = 0.75             # row/column grouping tolerance (fraction of box height/width)
AREA_FILTER_THRESH = 0.35  # +/- fraction around the mean area kept by filter_by_area()
# ----------------- Utility Functions -----------------
def load_and_resize(img_or_path, max_dim=RESIZE_MAX):
    """Load an image from a file path, or accept an already-loaded numpy
    array, and shrink it so its longest side is at most ``max_dim``
    (aspect ratio preserved). Always returns a fresh array.

    Raises:
        FileNotFoundError: if a path is given but cv2.imread fails.
        ValueError: if the input is neither a str nor a numpy array.
    """
    if isinstance(img_or_path, str):
        image = cv2.imread(img_or_path)
        if image is None:
            raise FileNotFoundError(f"Image not found: {img_or_path}")
    elif isinstance(img_or_path, np.ndarray):
        image = img_or_path.copy()
    else:
        raise ValueError("Input must be a file path or a numpy array")
    height, width = image.shape[:2]
    longest = max(height, width)
    if longest > max_dim:
        scale = max_dim / float(longest)
        new_size = (int(width * scale), int(height * scale))
        image = cv2.resize(image, new_size, interpolation=cv2.INTER_AREA)
    return image
def non_max_suppression(boxes, iou_thresh=IOU_NMS):
    """Greedy NMS over (x, y, w, h) boxes, largest-area first.

    Any box whose IoU with an already-kept box exceeds ``iou_thresh`` is
    discarded. Returns a list of integer (x, y, w, h) tuples.
    """
    if not boxes:
        return []
    rects = np.array(boxes, dtype=float)
    left, top = rects[:, 0], rects[:, 1]
    right = left + rects[:, 2]
    bottom = top + rects[:, 3]
    areas = (right - left) * (bottom - top)
    remaining = areas.argsort()[::-1]  # largest area acts as the "score"
    kept = []
    while remaining.size > 0:
        best = remaining[0]
        kept.append(tuple(rects[best].astype(int)))
        rest = remaining[1:]
        # Intersection of the best box with every remaining box.
        ix1 = np.maximum(left[best], left[rest])
        iy1 = np.maximum(top[best], top[rest])
        ix2 = np.minimum(right[best], right[rest])
        iy2 = np.minimum(bottom[best], bottom[rest])
        inter = np.maximum(0.0, ix2 - ix1) * np.maximum(0.0, iy2 - iy1)
        union = areas[best] + areas[rest] - inter
        iou = inter / (union + 1e-8)
        remaining = rest[iou <= iou_thresh]
    return kept
def color_cluster_masks(img, n_clusters=COLOR_CLUSTER_N):
    """Cluster the image's LAB colors with KMeans and return one binary
    mask (uint8, values 0/255) per color cluster."""
    lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
    h, w = lab.shape[:2]
    flat = lab.reshape(-1, 3).astype(np.float32)
    # Fit on a bounded random subset so KMeans stays fast on large images.
    max_samples = 30000
    if flat.shape[0] > max_samples:
        rng = np.random.default_rng(0)
        subset = flat[rng.choice(flat.shape[0], max_samples, replace=False)]
    else:
        subset = flat
    n_clusters = min(n_clusters, max(1, subset.shape[0]))
    model = KMeans(n_clusters=n_clusters, random_state=0, n_init=8).fit(subset)
    centers = model.cluster_centers_.astype(np.float32).reshape(1, 1, n_clusters, 3)
    # Assign every pixel to its nearest center (Euclidean distance in LAB).
    pixels = lab.astype(np.float32).reshape(h, w, 1, 3)
    dist = np.linalg.norm(pixels - centers, axis=3)
    labels = np.argmin(dist, axis=2).astype(np.int32)
    return [(labels == k).astype(np.uint8) * 255 for k in range(n_clusters)]
def refine_mask_by_hsv(mask, img, sat_min=SAT_MIN, val_min=VAL_MIN):
    """Zero out mask pixels whose HSV saturation or value is too low, then
    clean the result with a morphological close followed by an open."""
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    colorful = (hsv[:, :, 1] >= sat_min) & (hsv[:, :, 2] >= val_min)
    cleaned = mask.copy()
    cleaned[~colorful] = 0
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
    cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_CLOSE, kernel, iterations=2)
    return cv2.morphologyEx(cleaned, cv2.MORPH_OPEN, kernel, iterations=1)
def contours_from_mask(mask):
    """Return bounding rects (x, y, w, h) of external contours whose area
    and aspect ratio look box-like."""
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    boxes = []
    for contour in contours:
        if not (MIN_AREA <= cv2.contourArea(contour) <= MAX_AREA):
            continue
        perimeter = cv2.arcLength(contour, True)
        poly = cv2.approxPolyDP(contour, APPROX_EPS * perimeter, True)
        x, y, w, h = cv2.boundingRect(poly)
        if w == 0 or h == 0:
            continue
        aspect = w / float(h)
        # Reject extremely skinny or extremely flat shapes.
        if 0.12 < aspect < 8:
            boxes.append((x, y, w, h))
    return boxes
def mser_candidates(img):
    """Generate additional box candidates via MSER on the grayscale image,
    filtered by area and aspect ratio."""
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    detector = cv2.MSER_create()
    detector.setMinArea(60)
    detector.setMaxArea(MAX_AREA)
    regions, _ = detector.detectRegions(gray)
    boxes = []
    for region in regions:
        x, y, w, h = cv2.boundingRect(region.reshape(-1, 1, 2))
        if not (MIN_AREA <= w * h <= MAX_AREA):
            continue
        aspect = w / float(h) if h > 0 else 0
        if 0.25 < aspect < 4.0:
            boxes.append((x, y, w, h))
    return boxes
def collect_candidates(img):
    """Gather box candidates from the color-cluster masks and from MSER,
    then merge them with non-max suppression."""
    candidates = []
    for mask in color_cluster_masks(img, n_clusters=COLOR_CLUSTER_N):
        refined = refine_mask_by_hsv(mask, img)
        candidates.extend(contours_from_mask(refined))
    candidates += mser_candidates(img)
    return non_max_suppression(candidates, IOU_NMS)
def filter_by_area(rects):
    """Keep only rects whose area lies within +/- AREA_FILTER_THRESH of
    the mean area of all rects."""
    if not rects:
        return rects
    areas = np.array([w * h for (_, _, w, h) in rects], dtype=float)
    mean_area = np.mean(areas)
    lo = mean_area * (1.0 - AREA_FILTER_THRESH)
    hi = mean_area * (1.0 + AREA_FILTER_THRESH)
    return [rect for rect, area in zip(rects, areas) if lo <= area <= hi]
def group_rows(rects, tol=ROW_TOL):
    """Group rects into rows: sort by y, then start a new row whenever a
    rect's vertical center is farther than ``tol`` average-heights from
    the previous rect's center."""
    if not rects:
        return []
    ordered = sorted(rects, key=lambda b: b[1])
    rows = [[ordered[0]]]
    for box in ordered[1:]:
        anchor = rows[-1][-1]
        prev_center = anchor[1] + anchor[3] / 2.0
        cur_center = box[1] + box[3] / 2.0
        mean_h = (anchor[3] + box[3]) / 2.0
        if abs(prev_center - cur_center) <= tol * mean_h:
            rows[-1].append(box)
        else:
            rows.append([box])
    return rows
def group_columns(rects, tol=ROW_TOL):
    """Group rects into columns: sort by x, then start a new column when a
    rect's horizontal center is farther than ``tol`` average-widths from
    the previous rect's center."""
    if not rects:
        return []
    ordered = sorted(rects, key=lambda b: b[0])
    cols = [[ordered[0]]]
    for box in ordered[1:]:
        anchor = cols[-1][-1]
        prev_center = anchor[0] + anchor[2] / 2.0
        cur_center = box[0] + box[2] / 2.0
        mean_w = (anchor[2] + box[2]) / 2.0
        if abs(prev_center - cur_center) <= tol * mean_w:
            cols[-1].append(box)
        else:
            cols.append([box])
    return cols
def fill_missing_boxes(img, reference_rects, row_tol=ROW_TOL, col_tol=ROW_TOL):
    """Infer the full grid implied by the detected rects and synthesize a
    box for every empty grid cell.

    Returns a list of dicts {'x','y','w','h','synthetic': True}, one per
    grid position (row center x column center) with no detected box
    nearby; [] if there are no reference rects or no rows/columns.

    NOTE(review): ``img`` is unused here; kept for interface stability.
    """
    if not reference_rects:
        return []
    # Template size for synthesized boxes: the detected box whose area is
    # closest to the most common area (areas bucketed to the nearest 100).
    areas = [w * h for (_, _, w, h) in reference_rects]
    rounded_areas = [int(a // 100) * 100 for a in areas]
    unique, counts = np.unique(rounded_areas, return_counts=True)
    most_common_area = unique[np.argmax(counts)]
    closest_box = min(reference_rects, key=lambda r: abs((r[2]*r[3]) - most_common_area))
    avg_w, avg_h = int(closest_box[2]), int(closest_box[3])
    # Grid geometry: mean center of each detected row and column.
    rows = group_rows(reference_rects, tol=row_tol)
    cols = group_columns(reference_rects, tol=col_tol)
    if not rows or not cols:
        return []
    row_ys = [int(np.mean([y+h/2.0 for (x,y,w,h) in r])) for r in rows]
    col_xs = [int(np.mean([x+w/2.0 for (x,y,w,h) in c])) for c in cols]
    centers_existing = [(int(x+w/2), int(y+h/2)) for (x,y,w,h) in reference_rects]
    synth_boxes = []
    # A grid cell counts as occupied if a detected center lies within 45%
    # of the template size of the cell center.
    tol_x = avg_w * 0.45
    tol_y = avg_h * 0.45
    for ry in row_ys:
        for cx in col_xs:
            exists = any(abs(ex[0]-cx)<tol_x and abs(ex[1]-ry)<tol_y for ex in centers_existing)
            if not exists:
                x = int(cx - avg_w/2)
                y = int(ry - avg_h/2)
                synth_boxes.append({'x': x, 'y': y, 'w': avg_w, 'h': avg_h, 'synthetic': True})
    return synth_boxes
def split_left_right(rects, img, left_frac):
    """Partition rects into left/right groups by horizontal center.

    A rect is "left" when its center x is below ``left_frac`` times the
    image width. Left rects are returned sorted top-to-bottom; right
    rects sorted top-to-bottom, then left-to-right.

    Fix: partition in one pass instead of rebuilding the right list with
    an O(n^2) ``r not in left`` list-membership test.
    """
    if not rects:
        return [], []
    boundary = left_frac * img.shape[1]
    left, right = [], []
    for rect in rects:
        center_x = rect[0] + rect[2] / 2.0
        (left if center_x < boundary else right).append(rect)
    return sorted(left, key=lambda b: b[1]), sorted(right, key=lambda b: (b[1], b[0]))
def match_left_to_right(left_boxes, right_boxes):
    """Build a mapping test_box_i -> {test_box, matched_refs}.

    A right box matches a left box when their vertical centers are within
    ROW_TOL times the mean of the two heights. ``right_boxes`` are dicts
    with keys x/y/w/h/synthetic; the synthetic flag is stripped from the
    emitted matches.
    """
    mapping = {}
    for idx, (tx, ty, tw, th) in enumerate(left_boxes):
        entry = {"test_box": [int(tx), int(ty), int(tw), int(th)], "matched_refs": []}
        test_cy = ty + th / 2.0
        for ref in right_boxes:
            ref_cy = ref['y'] + ref['h'] / 2.0
            mean_h = (th + ref['h']) / 2.0
            if abs(test_cy - ref_cy) <= ROW_TOL * mean_h:
                entry["matched_refs"].append(
                    {k: int(v) for k, v in ref.items() if k != "synthetic"}
                )
        mapping[f"test_box_{idx+1}"] = entry
    return mapping
def visualize(img, left, right, mapping):
    """Render the matching result: blue test boxes with T-labels, magenta
    reference boxes (synthetic ones skipped), and red lines joining the
    centers of matched pairs. Returns an RGB image for Gradio."""
    canvas = img.copy()
    # Left (test) boxes in blue, labeled T1, T2, ...
    for idx, (x, y, w, h) in enumerate(left):
        cv2.rectangle(canvas, (x, y), (x + w, y + h), (255, 0, 0), 2)
        cv2.putText(canvas, f"T{idx+1}", (x, y - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.45, (255, 0, 0), 1)
    # Right (reference) boxes in magenta; synthetic boxes are not drawn.
    for ref in right:
        if ref.get('synthetic', False):
            continue
        cv2.rectangle(canvas,
                      (ref['x'], ref['y']),
                      (ref['x'] + ref['w'], ref['y'] + ref['h']),
                      (255, 0, 255), 1)
    # Red center-to-center lines for every match recorded in the mapping.
    for idx, (x, y, w, h) in enumerate(left):
        start = (int(x + w / 2), int(y + h / 2))
        for ref in mapping.get(f"test_box_{idx+1}", {}).get("matched_refs", []):
            end = (int(ref["x"] + ref["w"] / 2), int(ref["y"] + ref["h"] / 2))
            cv2.line(canvas, start, end, (0, 0, 255), 1)
    # OpenCV draws in BGR; Gradio expects RGB.
    return cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
def keep_one_box_per_row(rects, reference_rects=None, row_tol=ROW_TOL):
    """
    Keep only one representative box per row.

    The per-row selection score combines (lower is better):
      - closeness of box area to the expected (median) area,
      - closeness of aspect ratio (w/h) to the expected aspect ratio,
      - a penalty for very skinny or very flat boxes,
      - distance from the row's median horizontal center.

    Args:
        rects: candidate boxes [(x, y, w, h), ...].
        reference_rects: boxes used to estimate the expected area and
            aspect ratio (falls back to ``rects`` when None/empty).
        row_tol: row-grouping tolerance as a fraction of box height.

    Returns:
        One box per row, sorted top-to-bottom.

    Fixes vs. previous version: rows are built with group_rows() instead
    of duplicating its grouping loop, and the row-median center (a
    loop-invariant) is computed once per row instead of once per
    candidate box.
    """
    if not rects:
        return rects
    # Robust expected statistics (medians) from the reference set.
    ref = reference_rects if reference_rects else rects
    areas_ref = np.array([w * h for (_, _, w, h) in ref], dtype=float)
    ars_ref = np.array([w / float(h) for (_, _, w, h) in ref], dtype=float)
    expected_area = float(np.median(areas_ref))
    expected_ar = float(np.median(ars_ref))
    # Safety floors for degenerate inputs.
    if expected_area <= 0:
        expected_area = np.mean(areas_ref) if len(areas_ref) else 1.0
    if expected_ar <= 0:
        expected_ar = 1.0
    kept = []
    for group in group_rows(rects, tol=row_tol):
        if len(group) == 1:
            kept.append(group[0])
            continue
        # Loop-invariants for this row: median horizontal center and the
        # normalizer for the center-distance term.
        median_cx = float(np.median([g[0] + g[2] / 2.0 for g in group]))
        center_norm = expected_ar * np.sqrt(expected_area) + 1.0
        scores = []
        for (x, y, w, h) in group:
            area = w * h
            ar = w / float(h) if h > 0 else 0.0
            # Log-ratio keeps relative area differences symmetric.
            area_score = abs(np.log((area + 1e-6) / (expected_area + 1e-6)))
            ar_score = abs(ar - expected_ar) / (expected_ar + 1e-6)
            penalty = 0.0
            if ar < 0.25:   # very skinny / tall
                penalty += 1.0
            if ar > 4.0:    # very wide / flat (defensive)
                penalty += 0.6
            center_score = abs(x + w / 2.0 - median_cx) / center_norm
            # Weighted combination; weights are tunable.
            scores.append(2.0 * area_score + 1.2 * ar_score
                          + 0.5 * center_score + penalty)
        kept.append(group[int(np.argmin(scores))])
    kept = sorted(kept, key=lambda b: b[1])
    print(f"Kept {len(kept)} boxes (one per row) out of {len(rects)} candidates.")
    return kept
def clean_mapping(mapping, left_boxes):
    """
    Clean a left-to-right mapping.

    1. Drop matched_refs that duplicate any test box in ``left_boxes``,
       and de-duplicate refs within each entry.
    2. Drop entries whose test_box duplicates an earlier entry's.

    Args:
        mapping (dict): Original mapping from match_left_to_right.
        left_boxes (list of tuples): Test boxes [(x, y, w, h), ...].

    Returns:
        dict: Cleaned mapping. Note that entries of ``mapping`` are also
        modified in place.
    """
    test_box_set = {tuple(tb) for tb in left_boxes}
    # Step 1: strip refs that collide with a test box (plus exact dupes).
    for entry in mapping.values():
        deduped = []
        for ref in entry.get("matched_refs", []):
            box = (ref["x"], ref["y"], ref["w"], ref["h"])
            if box not in test_box_set and box not in deduped:
                deduped.append(box)
        entry["matched_refs"] = [dict(zip("xywh", box)) for box in deduped]
    # Step 2: keep only the first entry for each distinct test_box.
    seen = set()
    result = {}
    for key, entry in mapping.items():
        signature = tuple(entry["test_box"])
        if signature not in seen:
            seen.add(signature)
            result[key] = entry
    return result
# ----------------- Pipeline -----------------
def process_image(image, left_frac):
    """Full pipeline: detect candidate boxes, synthesize missing grid
    cells, split into left/right groups, match them, and render the
    visualization (returned as RGB).

    NOTE(review): Gradio delivers RGB arrays for type="numpy", while the
    pipeline's cv2 calls assume BGR throughout — confirm whether the
    channel order matters for the detections.
    """
    frame = load_and_resize(image)
    detected = filter_by_area(collect_candidates(frame))
    synthetic = fill_missing_boxes(frame, detected)
    combined = detected + [(b['x'], b['y'], b['w'], b['h']) for b in synthetic]
    left, right = split_left_right(combined, frame, left_frac)
    left = keep_one_box_per_row(left)
    right_boxes = [{'x': x, 'y': y, 'w': w, 'h': h, 'synthetic': False}
                   for (x, y, w, h) in right] + synthetic
    mapping = clean_mapping(match_left_to_right(left, right_boxes), left)
    return visualize(frame, left, right_boxes, mapping)
# ----------------- Gradio App -----------------
title = "Grid Detection & Matching Viewer"
description = "Upload an image, adjust the Left/Right threshold, and view final matching visualization."

# Example inputs shipped alongside this script: [image path, left_frac].
examples = [
    ["2.png", 0.28],
    ["4.jpg", 0.35],
    ["5.jpg", 0.35],
]

iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="numpy", label="Upload Image"),
        gr.Slider(0.1, 0.9, value=0.35, step=0.01,
                  label="Left Fraction Threshold (LEFT_FRAC_FALLBACK)"),
    ],
    outputs=gr.Image(label="Matched Output", type="numpy"),
    title=title,
    description=description,
    examples=examples,
)

if __name__ == "__main__":
    iface.launch()