# (Hugging Face Spaces page-status banner captured with the source — not app code.)
| import gradio as gr | |
| import os | |
| import json | |
| import cv2 | |
| import numpy as np | |
| from sklearn.cluster import KMeans | |
# ----------------- Config -----------------
RESIZE_MAX = 1600          # longest image side after load_and_resize (pixels)
MIN_AREA = 300             # candidate boxes with area below this are discarded (px^2)
MAX_AREA = 120000          # candidate boxes with area above this are discarded (px^2)
APPROX_EPS = 0.06          # approxPolyDP epsilon, as a fraction of contour perimeter
IOU_NMS = 0.25             # IoU threshold used by non_max_suppression
COLOR_CLUSTER_N = 6        # number of KMeans color clusters (LAB space)
SAT_MIN = 20               # minimum HSV saturation kept by refine_mask_by_hsv
VAL_MIN = 20               # minimum HSV value (brightness) kept by refine_mask_by_hsv
ROW_TOL = 0.75             # row/column grouping tolerance, fraction of average box size
AREA_FILTER_THRESH = 0.35  # filter_by_area keeps areas within +/- this fraction of the mean
| # ----------------- Utility Functions ----------------- | |
def load_and_resize(img_or_path, max_dim=RESIZE_MAX):
    """Return a BGR image from a path or array, downscaled so max(h, w) <= max_dim.

    Args:
        img_or_path: file path (str) or an already-loaded numpy image (copied).
        max_dim: longest allowed side in pixels; larger images are shrunk
            proportionally with INTER_AREA (good for downscaling).

    Raises:
        FileNotFoundError: if a path is given and cv2.imread fails.
        ValueError: if the input is neither a str nor an ndarray.
    """
    if isinstance(img_or_path, str):
        loaded = cv2.imread(img_or_path)
        if loaded is None:
            raise FileNotFoundError(f"Image not found: {img_or_path}")
    elif isinstance(img_or_path, np.ndarray):
        loaded = img_or_path.copy()
    else:
        raise ValueError("Input must be a file path or a numpy array")

    height, width = loaded.shape[:2]
    longest = max(height, width)
    if longest > max_dim:
        factor = max_dim / float(longest)
        new_size = (int(width * factor), int(height * factor))
        loaded = cv2.resize(loaded, new_size, interpolation=cv2.INTER_AREA)
    return loaded
def non_max_suppression(boxes, iou_thresh=IOU_NMS):
    """Greedy non-max suppression over (x, y, w, h) boxes.

    Boxes are visited largest-area first; any remaining box whose IoU with the
    kept box exceeds iou_thresh is dropped. Returns a list of int tuples.
    """
    if not boxes:
        return []
    data = np.array(boxes, dtype=float)
    left = data[:, 0]
    top = data[:, 1]
    right = data[:, 0] + data[:, 2]
    bottom = data[:, 1] + data[:, 3]
    areas = (right - left) * (bottom - top)

    order = areas.argsort()[::-1]  # descending area
    kept = []
    while order.size > 0:
        best = order[0]
        kept.append(tuple(data[best].astype(int)))
        rest = order[1:]
        # Intersection of the best box with every remaining candidate.
        ix1 = np.maximum(left[best], left[rest])
        iy1 = np.maximum(top[best], top[rest])
        ix2 = np.minimum(right[best], right[rest])
        iy2 = np.minimum(bottom[best], bottom[rest])
        overlap = np.maximum(0.0, ix2 - ix1) * np.maximum(0.0, iy2 - iy1)
        iou = overlap / (areas[best] + areas[rest] - overlap + 1e-8)
        order = rest[iou <= iou_thresh]
    return kept
def color_cluster_masks(img, n_clusters=COLOR_CLUSTER_N):
    """Split the image into per-color binary masks via KMeans in LAB space.

    KMeans is fitted on a random subsample (<= 30k pixels, fixed seed for
    determinism); every pixel is then assigned to its nearest cluster center.
    Returns a list of n_clusters uint8 masks (255 where the pixel belongs).
    """
    lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
    h, w = lab.shape[:2]
    pixels = lab.reshape(-1, 3).astype(np.float32)

    # Subsample for the fit so KMeans stays fast on large images.
    max_samples = 30000
    if pixels.shape[0] > max_samples:
        rng = np.random.default_rng(0)
        sample = pixels[rng.choice(pixels.shape[0], max_samples, replace=False)]
    else:
        sample = pixels

    n_clusters = min(n_clusters, max(1, sample.shape[0]))
    kmeans = KMeans(n_clusters=n_clusters, random_state=0, n_init=8).fit(sample)
    centers = kmeans.cluster_centers_.astype(np.float32)

    # Assign every pixel to its nearest center (Euclidean distance in LAB).
    flat = lab.astype(np.float32).reshape(-1, 3)
    distances = np.linalg.norm(flat[:, None, :] - centers[None, :, :], axis=2)
    labels = distances.argmin(axis=1).astype(np.int32).reshape(h, w)

    return [(labels == k).astype(np.uint8) * 255 for k in range(n_clusters)]
def refine_mask_by_hsv(mask, img, sat_min=SAT_MIN, val_min=VAL_MIN):
    """Zero out near-gray/dark pixels in mask, then smooth it morphologically.

    Pixels whose HSV saturation or value fall below the thresholds are removed;
    a close (fill small holes) followed by an open (drop specks) cleans the rest.
    """
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    colorful = (hsv[:, :, 1] >= sat_min) & (hsv[:, :, 2] >= val_min)
    cleaned = mask.copy()
    cleaned[~colorful] = 0
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
    cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_CLOSE, kernel, iterations=2)
    return cv2.morphologyEx(cleaned, cv2.MORPH_OPEN, kernel, iterations=1)
def contours_from_mask(mask):
    """Extract plausible box rects (x, y, w, h) from a binary mask.

    External contours are filtered by area [MIN_AREA, MAX_AREA], simplified
    with approxPolyDP, and kept only when their aspect ratio is sane.
    """
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    boxes = []
    for contour in contours:
        if not (MIN_AREA <= cv2.contourArea(contour) <= MAX_AREA):
            continue
        perimeter = cv2.arcLength(contour, True)
        poly = cv2.approxPolyDP(contour, APPROX_EPS * perimeter, True)
        x, y, w, h = cv2.boundingRect(poly)
        if w == 0 or h == 0:
            continue
        aspect = w / float(h)
        if 0.12 < aspect < 8:  # reject extreme slivers
            boxes.append((x, y, w, h))
    return boxes
def mser_candidates(img):
    """Detect MSER stable regions on the grayscale image and return their
    bounding rects (x, y, w, h), filtered by area and aspect ratio."""
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    detector = cv2.MSER_create()
    detector.setMinArea(60)
    detector.setMaxArea(MAX_AREA)
    regions, _ = detector.detectRegions(gray)

    boxes = []
    for region in regions:
        x, y, w, h = cv2.boundingRect(region.reshape(-1, 1, 2))
        if not (MIN_AREA <= w * h <= MAX_AREA):
            continue
        aspect = w / float(h) if h > 0 else 0
        if 0.25 < aspect < 4.0:  # tighter than the contour filter: MSER is noisier
            boxes.append((x, y, w, h))
    return boxes
def collect_candidates(img):
    """Gather candidate box rects from color-cluster contours plus MSER,
    then merge duplicates with non-max suppression."""
    candidates = []
    for mask in color_cluster_masks(img, n_clusters=COLOR_CLUSTER_N):
        refined = refine_mask_by_hsv(mask, img)
        candidates.extend(contours_from_mask(refined))
    candidates.extend(mser_candidates(img))
    return non_max_suppression(candidates, IOU_NMS)
def filter_by_area(rects):
    """Keep only rects whose area lies within +/- AREA_FILTER_THRESH of the mean.

    NOTE(review): the mean is outlier-sensitive; with one extreme box the band
    can exclude every rect — callers should tolerate an empty result.
    """
    if not rects:
        return rects
    areas = np.array([w * h for (_, _, w, h) in rects], dtype=float)
    mean_area = np.mean(areas)
    lo = mean_area * (1.0 - AREA_FILTER_THRESH)
    hi = mean_area * (1.0 + AREA_FILTER_THRESH)
    return [box for box, area in zip(rects, areas) if lo <= area <= hi]
def group_rows(rects, tol=ROW_TOL):
    """Cluster (x, y, w, h) rects into rows by vertical-center proximity.

    Rects are scanned in ascending-y order; each one joins the current row if
    its center is within tol * (average height) of the row's last member,
    otherwise it starts a new row. Returns a list of rows (lists of rects).
    """
    if not rects:
        return []
    ordered = sorted(rects, key=lambda b: b[1])
    rows = [[ordered[0]]]
    for box in ordered[1:]:
        anchor = rows[-1][-1]
        gap = abs((anchor[1] + anchor[3] / 2.0) - (box[1] + box[3] / 2.0))
        if gap <= tol * ((anchor[3] + box[3]) / 2.0):
            rows[-1].append(box)
        else:
            rows.append([box])
    return rows
def group_columns(rects, tol=ROW_TOL):
    """Cluster (x, y, w, h) rects into columns by horizontal-center proximity.

    Mirror of group_rows: scan in ascending-x order and start a new column
    whenever the center gap exceeds tol * (average width).
    """
    if not rects:
        return []
    ordered = sorted(rects, key=lambda b: b[0])
    cols = [[ordered[0]]]
    for box in ordered[1:]:
        anchor = cols[-1][-1]
        gap = abs((anchor[0] + anchor[2] / 2.0) - (box[0] + box[2] / 2.0))
        if gap <= tol * ((anchor[2] + box[2]) / 2.0):
            cols[-1].append(box)
        else:
            cols.append([box])
    return cols
def fill_missing_boxes(img, reference_rects, row_tol=ROW_TOL, col_tol=ROW_TOL):
    """Synthesize boxes at grid intersections not covered by a detection.

    The detected rects define a grid (rows x columns of centers); every empty
    intersection gets a synthetic box sized like the most typical real box.
    Returns a list of dicts {'x', 'y', 'w', 'h', 'synthetic': True}.

    NOTE(review): `img` is accepted but never read in this body — kept for
    interface compatibility; confirm whether callers expect it to matter.
    """
    if not reference_rects:
        return []

    # Template size: bucket areas to the 100-px^2 floor, take the modal bucket,
    # then copy the dimensions of the real box closest to that modal area.
    buckets = [int((w * h) // 100) * 100 for (_, _, w, h) in reference_rects]
    values, freq = np.unique(buckets, return_counts=True)
    dominant_area = values[np.argmax(freq)]
    template = min(reference_rects, key=lambda r: abs((r[2] * r[3]) - dominant_area))
    box_w, box_h = int(template[2]), int(template[3])

    rows = group_rows(reference_rects, tol=row_tol)
    cols = group_columns(reference_rects, tol=col_tol)
    if not rows or not cols:
        return []

    row_centers = [int(np.mean([y + h / 2.0 for (x, y, w, h) in row])) for row in rows]
    col_centers = [int(np.mean([x + w / 2.0 for (x, y, w, h) in col])) for col in cols]
    occupied = [(int(x + w / 2), int(y + h / 2)) for (x, y, w, h) in reference_rects]

    tol_x = box_w * 0.45
    tol_y = box_h * 0.45
    synthesized = []
    for cy in row_centers:
        for cx in col_centers:
            taken = any(abs(px - cx) < tol_x and abs(py - cy) < tol_y
                        for (px, py) in occupied)
            if not taken:
                synthesized.append({'x': int(cx - box_w / 2), 'y': int(cy - box_h / 2),
                                    'w': box_w, 'h': box_h, 'synthetic': True})
    return synthesized
def split_left_right(rects, img, left_frac):
    """Partition rects into (left, right) of a vertical boundary.

    A box is "left" when its x-center falls before left_frac * image_width.
    Left boxes are returned sorted by y; right boxes by (y, x).

    Fix: the original built the right half with `r not in left`, an O(n^2)
    list-membership scan; both halves are now produced in a single O(n) pass
    over the same center-x test, which yields identical results (the test is
    deterministic per box value, so equal boxes always land on the same side).
    """
    if not rects:
        return [], []
    boundary = left_frac * img.shape[1]
    left, right = [], []
    for box in rects:
        if (box[0] + box[2] / 2.0) < boundary:
            left.append(box)
        else:
            right.append(box)
    return sorted(left, key=lambda b: b[1]), sorted(right, key=lambda b: (b[1], b[0]))
def match_left_to_right(left_boxes, right_boxes):
    """Match each left (test) box to right (reference) boxes on the same row.

    A reference box matches when the vertical distance between centers is at
    most ROW_TOL times the average of the two heights. Returns
    {"test_box_N": {"test_box": [x, y, w, h], "matched_refs": [...]}} with the
    'synthetic' flag stripped from each matched ref.
    """
    result = {}
    for idx, (tx, ty, tw, th) in enumerate(left_boxes, start=1):
        entry = {"test_box": [int(tx), int(ty), int(tw), int(th)], "matched_refs": []}
        test_cy = ty + th / 2.0
        for ref in right_boxes:
            ref_cy = ref['y'] + ref['h'] / 2.0
            if abs(test_cy - ref_cy) <= ROW_TOL * ((th + ref['h']) / 2.0):
                entry["matched_refs"].append(
                    {k: int(v) for k, v in ref.items() if k != "synthetic"})
        result[f"test_box_{idx}"] = entry
    return result
def visualize(img, left, right, mapping):
    """Render the matching result: blue test boxes with "T<n>" labels, magenta
    real reference boxes (synthetic ones skipped), and red center-to-center
    lines for each match. Returns an RGB copy for Gradio display."""
    canvas = img.copy()

    # Test (left) boxes in blue, labeled T1, T2, ...
    for idx, (x, y, w, h) in enumerate(left):
        cv2.rectangle(canvas, (x, y), (x + w, y + h), (255, 0, 0), 2)
        cv2.putText(canvas, f"T{idx+1}", (x, y - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.45, (255, 0, 0), 1)

    # Real reference boxes in magenta; synthetic boxes are matching aids only.
    for ref in right:
        if ref.get('synthetic', False):
            continue
        cv2.rectangle(canvas,
                      (ref['x'], ref['y']),
                      (ref['x'] + ref['w'], ref['y'] + ref['h']),
                      (255, 0, 255), 1)

    # Red lines from each test-box center to every matched reference center.
    for idx, (x, y, w, h) in enumerate(left):
        source = (int(x + w / 2), int(y + h / 2))
        matches = mapping.get(f"test_box_{idx+1}", {}).get("matched_refs", [])
        for ref in matches:
            target = (int(ref["x"] + ref["w"] / 2), int(ref["y"] + ref["h"] / 2))
            cv2.line(canvas, source, target, (0, 0, 255), 1)

    # Gradio expects RGB, OpenCV works in BGR.
    return cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
def keep_one_box_per_row(rects, reference_rects=None, row_tol=ROW_TOL):
    """Keep only one representative box per row.

    Each box in a row is scored (lower is better) by:
      - closeness of its area to the expected (median) area, via log-ratio;
      - closeness of its aspect ratio (w/h) to the expected (median) ratio;
      - distance of its x-center from the row's median x-center;
      - a flat penalty for very skinny (ar < 0.25) or very flat (ar > 4.0) boxes.

    Args:
        rects: candidate boxes [(x, y, w, h), ...] to thin out.
        reference_rects: boxes used to estimate the expected area/aspect ratio;
            falls back to `rects` when None or empty.
        row_tol: vertical grouping tolerance (fraction of average box height).

    Returns:
        One box per row, sorted by y.

    Fix: the row's median x-center (and the constant score normalizer) were
    recomputed inside the per-candidate loop although they only depend on the
    row; both are now hoisted out of the loop — scores are unchanged.
    """
    if not rects:
        return rects

    # Expected statistics from reference_rects (fallback to rects).
    ref = reference_rects if (reference_rects and len(reference_rects) > 0) else rects
    areas_ref = np.array([w * h for (_, _, w, h) in ref], dtype=float)
    ars_ref = np.array([w / float(h) for (_, _, w, h) in ref], dtype=float)
    expected_area = float(np.median(areas_ref))
    expected_ar = float(np.median(ars_ref))
    # Safety floors against degenerate inputs.
    if expected_area <= 0:
        expected_area = np.mean(areas_ref) if len(areas_ref) else 1.0
    if expected_ar <= 0:
        expected_ar = 1.0

    # Group rectangles into rows by vertical-center proximity (same rule as group_rows).
    rects_sorted = sorted(rects, key=lambda b: b[1])
    rows = [[rects_sorted[0]]]
    for r in rects_sorted[1:]:
        y_center = r[1] + r[3] / 2.0
        last = rows[-1][-1]
        last_center = last[1] + last[3] / 2.0
        avg_h = (r[3] + last[3]) / 2.0
        if abs(y_center - last_center) <= row_tol * avg_h:
            rows[-1].append(r)
        else:
            rows.append([r])

    # Row-invariant normalizer for the horizontal-center score.
    center_norm = expected_ar * np.sqrt(expected_area) + 1.0

    kept = []
    for group in rows:
        if len(group) == 1:
            kept.append(group[0])
            continue
        # Hoisted: median x-center of the row (was recomputed per candidate).
        median_cx = float(np.median([g[0] + g[2] / 2.0 for g in group]))
        scores = []
        for (x, y, w, h) in group:
            area = w * h
            ar = w / float(h) if h > 0 else 0.0
            # Log-ratio makes relative area differences symmetric.
            area_score = abs(np.log((area + 1e-6) / (expected_area + 1e-6)))
            ar_score = abs(ar - expected_ar) / (expected_ar + 1e-6)
            penalty = 0.0
            if ar < 0.25:   # very skinny/tall
                penalty += 1.0
            if ar > 4.0:    # very wide/flat (defensive)
                penalty += 0.6
            center_score = abs((x + w / 2.0) - median_cx) / center_norm
            # Weighted combination; tune weights if needed.
            scores.append((2.0 * area_score) + (1.2 * ar_score)
                          + (0.5 * center_score) + penalty)
        kept.append(group[int(np.argmin(scores))])

    kept = sorted(kept, key=lambda b: b[1])
    print(f"Kept {len(kept)} boxes (one per row) out of {len(rects)} candidates.")
    return kept
def clean_mapping(mapping, left_boxes):
    """Deduplicate the left-to-right mapping.

    1. Drops matched_refs that coincide with any test box in left_boxes, and
       duplicate refs within the same entry (matched_refs lists are rewritten
       in place on the given mapping).
    2. Drops mapping entries whose test_box duplicates an earlier entry's.

    Args:
        mapping: output of match_left_to_right.
        left_boxes: test boxes [(x, y, w, h), ...].

    Returns:
        A new dict with duplicate test-box entries removed.
    """
    test_set = {tuple(tb) for tb in left_boxes}

    # Step 1: scrub each entry's matched_refs.
    for entry in mapping.values():
        unique_refs = []
        for ref in entry.get("matched_refs", []):
            box = (ref["x"], ref["y"], ref["w"], ref["h"])
            if box not in test_set and box not in unique_refs:
                unique_refs.append(box)
        entry["matched_refs"] = [{"x": x, "y": y, "w": w, "h": h}
                                 for (x, y, w, h) in unique_refs]

    # Step 2: keep only the first entry for each distinct test_box.
    seen = set()
    deduped = {}
    for key, entry in mapping.items():
        signature = tuple(entry["test_box"])
        if signature not in seen:
            seen.add(signature)
            deduped[key] = entry
    return deduped
| # ----------------- Pipeline ----------------- | |
def process_image(image, left_frac):
    """Gradio entry point: run the full detect/match pipeline on one image.

    Steps: load & resize -> collect candidate boxes -> area filter ->
    synthesize missing grid boxes -> split left/right at left_frac ->
    keep one test box per row -> match and clean -> draw. Returns an RGB image.
    """
    frame = load_and_resize(image)
    detections = filter_by_area(collect_candidates(frame))
    synthetic = fill_missing_boxes(frame, detections)
    combined = detections + [(b['x'], b['y'], b['w'], b['h']) for b in synthetic]
    left, right = split_left_right(combined, frame, left_frac)
    left = keep_one_box_per_row(left)
    references = [{'x': x, 'y': y, 'w': w, 'h': h, 'synthetic': False}
                  for (x, y, w, h) in right] + synthetic
    mapping = clean_mapping(match_left_to_right(left, references), left)
    return visualize(frame, left, references, mapping)
# ----------------- Gradio App -----------------
title = "Grid Detection & Matching Viewer"
description = "Upload an image, adjust the Left/Right threshold, and view final matching visualization."
# Example inputs shown under the UI: [image, slider value for left_frac].
# NOTE(review): these relative paths must exist next to this script or the
# examples will fail to load — confirm the files are shipped with the Space.
examples = [
    ["2.png", 0.28],
    ["4.jpg", 0.35],
    ["5.jpg", 0.35],
]
iface = gr.Interface(
    fn=process_image,  # pipeline entry point defined above
    inputs=[
        gr.Image(type="numpy", label="Upload Image"),
        gr.Slider(0.1, 0.9, value=0.35, step=0.01, label="Left Fraction Threshold (LEFT_FRAC_FALLBACK)")
    ],
    outputs=gr.Image(label="Matched Output", type="numpy"),
    title=title,
    description=description,
    examples=examples
)
if __name__ == "__main__":
    iface.launch()