File size: 15,649 Bytes
4683e68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a6febf
 
4683e68
9a6febf
 
 
 
 
 
 
 
4683e68
9a6febf
 
 
 
 
 
 
 
 
 
 
4683e68
 
 
9a6febf
 
 
 
 
 
 
4683e68
 
9a6febf
4683e68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e90325
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4683e68
 
 
 
 
 
 
 
 
 
 
 
8e90325
4683e68
 
 
8facb11
4683e68
 
 
 
8facb11
 
 
e9dce51
8facb11
 
 
4683e68
 
 
 
 
 
 
 
 
8facb11
4683e68
 
 
 
8facb11
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
import gradio as gr
import os
import json
import cv2
import numpy as np
from sklearn.cluster import KMeans

# ----------------- Config -----------------
RESIZE_MAX = 1600            # longest image side after load_and_resize (px)
MIN_AREA = 300               # minimum accepted candidate-box area (px^2)
MAX_AREA = 120000            # maximum accepted candidate-box area (px^2)
APPROX_EPS = 0.06            # approxPolyDP epsilon, as a fraction of contour perimeter
IOU_NMS = 0.25               # IoU threshold for non-maximum suppression
COLOR_CLUSTER_N = 6          # number of K-means clusters in LAB color space
SAT_MIN = 20                 # minimum HSV saturation kept by refine_mask_by_hsv
VAL_MIN = 20                 # minimum HSV value kept by refine_mask_by_hsv
ROW_TOL = 0.75               # row/column grouping tolerance (multiple of avg box size)
AREA_FILTER_THRESH = 0.35    # allowed relative deviation from the mean box area

# ----------------- Utility Functions -----------------
def load_and_resize(img_or_path, max_dim=RESIZE_MAX):
    """Load an image from disk (or accept one already in memory) and shrink
    it so its longest side is at most ``max_dim`` pixels.

    Args:
        img_or_path: filesystem path to an image, or a BGR numpy array.
        max_dim: upper bound on the longer image dimension.

    Returns:
        BGR image array, downscaled with INTER_AREA when needed.

    Raises:
        FileNotFoundError: the path did not resolve to a readable image.
        ValueError: the input was neither a path nor an ndarray.
    """
    if isinstance(img_or_path, np.ndarray):  # already-loaded image
        image = img_or_path.copy()
    elif isinstance(img_or_path, str):  # path on disk
        image = cv2.imread(img_or_path)
        if image is None:
            raise FileNotFoundError(f"Image not found: {img_or_path}")
    else:
        raise ValueError("Input must be a file path or a numpy array")

    height, width = image.shape[:2]
    longest = max(height, width)
    if longest > max_dim:
        factor = max_dim / float(longest)
        new_size = (int(width * factor), int(height * factor))
        image = cv2.resize(image, new_size, interpolation=cv2.INTER_AREA)
    return image


def non_max_suppression(boxes, iou_thresh=IOU_NMS):
    """Greedy non-maximum suppression over (x, y, w, h) boxes.

    Boxes are visited from largest area to smallest; each kept box
    suppresses every remaining box whose IoU with it exceeds
    ``iou_thresh``. Returns kept boxes as tuples of ints.
    """
    if not boxes:
        return []
    data = np.asarray(boxes, dtype=float)
    lefts, tops = data[:, 0], data[:, 1]
    rights = lefts + data[:, 2]
    bottoms = tops + data[:, 3]
    areas = (rights - lefts) * (bottoms - tops)

    remaining = areas.argsort()[::-1]  # largest-area candidate first
    kept = []
    while remaining.size:
        best = remaining[0]
        kept.append(tuple(data[best].astype(int)))
        rest = remaining[1:]
        # Intersection of `best` with every remaining candidate.
        ix1 = np.maximum(lefts[best], lefts[rest])
        iy1 = np.maximum(tops[best], tops[rest])
        ix2 = np.minimum(rights[best], rights[rest])
        iy2 = np.minimum(bottoms[best], bottoms[rest])
        inter = np.maximum(0.0, ix2 - ix1) * np.maximum(0.0, iy2 - iy1)
        # Small epsilon guards against division by a zero union.
        iou = inter / (areas[best] + areas[rest] - inter + 1e-8)
        remaining = rest[iou <= iou_thresh]
    return kept

def color_cluster_masks(img, n_clusters=COLOR_CLUSTER_N):
    """Segment the image into binary masks via K-means in LAB color space.

    A subsample of at most 30k pixels is clustered (seeded for
    determinism); every pixel is then assigned to its nearest cluster
    center, and one 0/255 uint8 mask per cluster is returned.
    """
    lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
    height, width = lab.shape[:2]
    flat = lab.reshape(-1, 3).astype(np.float32)

    # Cap the pixel count fed to KMeans to keep fitting fast.
    cap = 30000
    if flat.shape[0] > cap:
        rng = np.random.default_rng(0)
        subset = flat[rng.choice(flat.shape[0], cap, replace=False)]
    else:
        subset = flat

    n_clusters = min(n_clusters, max(1, subset.shape[0]))
    model = KMeans(n_clusters=n_clusters, random_state=0, n_init=8).fit(subset)
    centers = model.cluster_centers_

    # Assign every pixel to the closest center (Euclidean distance in LAB).
    center_grid = centers.astype(np.float32).reshape(1, 1, n_clusters, 3)
    pixel_grid = lab.astype(np.float32).reshape(height, width, 1, 3)
    dist = np.linalg.norm(pixel_grid - center_grid, axis=3)
    labels = np.argmin(dist, axis=2).astype(np.int32)

    return [(labels == k).astype(np.uint8) * 255 for k in range(n_clusters)]

def refine_mask_by_hsv(mask, img, sat_min=SAT_MIN, val_min=VAL_MIN):
    """Zero out mask pixels that are too unsaturated or too dark, then
    clean the result with a morphological close followed by an open."""
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    saturation, value = hsv[:, :, 1], hsv[:, :, 2]
    keep = (saturation >= sat_min) & (value >= val_min)

    out = mask.copy()
    out[~keep] = 0

    # Close small holes first, then remove speckle noise.
    element = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
    out = cv2.morphologyEx(out, cv2.MORPH_CLOSE, element, iterations=2)
    return cv2.morphologyEx(out, cv2.MORPH_OPEN, element, iterations=1)

def contours_from_mask(mask):
    """Extract bounding rects of external contours whose area and aspect
    ratio fall within the configured plausibility limits."""
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    boxes = []
    for contour in contours:
        if not (MIN_AREA <= cv2.contourArea(contour) <= MAX_AREA):
            continue
        perimeter = cv2.arcLength(contour, True)
        poly = cv2.approxPolyDP(contour, APPROX_EPS * perimeter, True)
        x, y, w, h = cv2.boundingRect(poly)
        if w == 0 or h == 0:
            continue
        aspect = w / float(h)
        # Reject shapes that are implausibly skinny or flat.
        if 0.12 < aspect < 8:
            boxes.append((x, y, w, h))
    return boxes

def mser_candidates(img):
    """Propose box candidates from MSER regions of the grayscale image."""
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    detector = cv2.MSER_create()
    detector.setMinArea(60)
    detector.setMaxArea(MAX_AREA)
    regions, _ = detector.detectRegions(gray)

    boxes = []
    for region in regions:
        x, y, w, h = cv2.boundingRect(region.reshape(-1, 1, 2))
        if not (MIN_AREA <= w * h <= MAX_AREA):
            continue
        aspect = w / float(h) if h > 0 else 0
        # MSER candidates get a tighter aspect-ratio band than contours.
        if 0.25 < aspect < 4.0:
            boxes.append((x, y, w, h))
    return boxes

def collect_candidates(img):
    """Gather box candidates from color clustering and MSER, then merge
    overlapping ones with non-maximum suppression."""
    candidates = []
    for mask in color_cluster_masks(img, n_clusters=COLOR_CLUSTER_N):
        refined = refine_mask_by_hsv(mask, img)
        candidates.extend(contours_from_mask(refined))
    candidates.extend(mser_candidates(img))
    return non_max_suppression(candidates, IOU_NMS)

def filter_by_area(rects):
    """Keep only rects whose area lies within ±AREA_FILTER_THRESH of the
    mean area; an empty input is returned unchanged."""
    if not rects:
        return rects
    areas = np.array([w * h for (_, _, w, h) in rects], dtype=float)
    mean_area = areas.mean()
    low = mean_area * (1.0 - AREA_FILTER_THRESH)
    high = mean_area * (1.0 + AREA_FILTER_THRESH)
    return [rect for rect, area in zip(rects, areas) if low <= area <= high]

def group_rows(rects, tol=ROW_TOL):
    """Group boxes into rows.

    Boxes are walked in y order; a new row starts whenever a box's
    vertical center jumps more than ``tol`` times the average height
    away from the previous box.
    """
    if not rects:
        return []
    ordered = sorted(rects, key=lambda b: b[1])
    rows = [[ordered[0]]]
    for box in ordered[1:]:
        anchor = rows[-1][-1]
        anchor_cy = anchor[1] + anchor[3] / 2.0
        box_cy = box[1] + box[3] / 2.0
        mean_h = (anchor[3] + box[3]) / 2.0
        if abs(anchor_cy - box_cy) > tol * mean_h:
            rows.append([box])
        else:
            rows[-1].append(box)
    return rows

def group_columns(rects, tol=ROW_TOL):
    """Group boxes into columns.

    Mirror of group_rows along the x axis: boxes are walked in x order
    and a new column starts when a box's horizontal center jumps more
    than ``tol`` times the average width away from the previous box.
    """
    if not rects:
        return []
    ordered = sorted(rects, key=lambda b: b[0])
    cols = [[ordered[0]]]
    for box in ordered[1:]:
        anchor = cols[-1][-1]
        anchor_cx = anchor[0] + anchor[2] / 2.0
        box_cx = box[0] + box[2] / 2.0
        mean_w = (anchor[2] + box[2]) / 2.0
        if abs(anchor_cx - box_cx) > tol * mean_w:
            cols.append([box])
        else:
            cols[-1].append(box)
    return cols

def fill_missing_boxes(img, reference_rects, row_tol=ROW_TOL, col_tol=ROW_TOL):
    """Synthesize boxes at grid intersections where no detection exists.

    The grid is inferred from the detected rows and columns; the synthetic
    box size comes from the detected box whose area is closest to the most
    common area (rounded down to the nearest 100).

    NOTE(review): ``img`` is never read here — kept only for signature
    compatibility with callers.

    Returns:
        list of dicts {'x', 'y', 'w', 'h', 'synthetic': True}.
    """
    if not reference_rects:
        return []

    # Template box: closest in area to the dominant rounded-area bucket.
    areas = [w * h for (_, _, w, h) in reference_rects]
    buckets = [int(a // 100) * 100 for a in areas]
    values, freq = np.unique(buckets, return_counts=True)
    dominant = values[np.argmax(freq)]
    template = min(reference_rects, key=lambda r: abs((r[2] * r[3]) - dominant))
    box_w, box_h = int(template[2]), int(template[3])

    rows = group_rows(reference_rects, tol=row_tol)
    cols = group_columns(reference_rects, tol=col_tol)
    if not rows or not cols:
        return []

    # Grid line positions = mean centers of each row / column group.
    row_centers = [int(np.mean([y + h / 2.0 for (x, y, w, h) in row])) for row in rows]
    col_centers = [int(np.mean([x + w / 2.0 for (x, y, w, h) in col])) for col in cols]
    known = [(int(x + w / 2), int(y + h / 2)) for (x, y, w, h) in reference_rects]

    slack_x = box_w * 0.45
    slack_y = box_h * 0.45
    synthesized = []
    for cy in row_centers:
        for cx in col_centers:
            occupied = any(abs(kx - cx) < slack_x and abs(ky - cy) < slack_y
                           for (kx, ky) in known)
            if not occupied:
                synthesized.append({'x': int(cx - box_w / 2), 'y': int(cy - box_h / 2),
                                    'w': box_w, 'h': box_h, 'synthetic': True})
    return synthesized

def split_left_right(rects, img, left_frac):
    """Partition boxes into left/right groups by horizontal center.

    A box belongs to the left group when its center x lies strictly left
    of ``left_frac * image_width``; every other box goes right.

    Args:
        rects: list of (x, y, w, h) boxes.
        img: BGR image (only its width is used).
        left_frac: fraction of the image width marking the split line.

    Returns:
        (left, right): left sorted by y; right sorted by (y, x).
    """
    if not rects:
        return [], []
    h, w = img.shape[:2]
    threshold = left_frac * w
    # Classify each box once by its center. This replaces the previous
    # `r not in left` membership scan, which was O(n^2); the negated
    # condition is exactly equivalent since the test is deterministic
    # per tuple.
    left = [r for r in rects if (r[0] + r[2] / 2.0) < threshold]
    right = [r for r in rects if (r[0] + r[2] / 2.0) >= threshold]
    return sorted(left, key=lambda b: b[1]), sorted(right, key=lambda b: (b[1], b[0]))

def match_left_to_right(left_boxes, right_boxes):
    """Pair each left (test) box with every right (reference) box whose
    vertical center lies within ROW_TOL average-heights of it.

    Args:
        left_boxes: list of (x, y, w, h) tuples.
        right_boxes: list of dicts with keys 'x', 'y', 'w', 'h' (and
            optionally 'synthetic', which is stripped from the output).

    Returns:
        {"test_box_i": {"test_box": [x, y, w, h], "matched_refs": [...]}}
    """
    result = {}
    for idx, (tx, ty, tw, th) in enumerate(left_boxes, start=1):
        test_cy = ty + th / 2.0
        matches = []
        for ref in right_boxes:
            ref_cy = ref['y'] + ref['h'] / 2.0
            mean_h = (th + ref['h']) / 2.0
            if abs(test_cy - ref_cy) <= ROW_TOL * mean_h:
                matches.append({k: int(v) for k, v in ref.items() if k != "synthetic"})
        result[f"test_box_{idx}"] = {
            "test_box": [int(tx), int(ty), int(tw), int(th)],
            "matched_refs": matches,
        }
    return result

def visualize(img, left, right, mapping):
    """Render detection results onto a copy of the image.

    Blue rectangles with "T<i>" labels mark the left (test) boxes,
    magenta rectangles mark the non-synthetic right (reference) boxes,
    and red lines connect each test box to its matched references.
    Returns the image converted to RGB for Gradio display.
    """
    canvas = img.copy()
    blue, magenta, red = (255, 0, 0), (255, 0, 255), (0, 0, 255)

    # Left (test) boxes, labelled T1, T2, ...
    for idx, (x, y, w, h) in enumerate(left):
        cv2.rectangle(canvas, (x, y), (x + w, y + h), blue, 2)
        cv2.putText(canvas, f"T{idx+1}", (x, y - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.45, blue, 1)

    # Right (reference) boxes; synthetic ones are skipped to avoid
    # double-drawing.
    for ref in right:
        if ref.get('synthetic', False):
            continue
        cv2.rectangle(canvas,
                      (ref['x'], ref['y']),
                      (ref['x'] + ref['w'], ref['y'] + ref['h']),
                      magenta, 1)

    # Red lines from each test-box center to its matched reference centers.
    for idx, (x, y, w, h) in enumerate(left):
        start = (int(x + w / 2), int(y + h / 2))
        for ref in mapping.get(f"test_box_{idx+1}", {}).get("matched_refs", []):
            end = (int(ref["x"] + ref["w"] / 2), int(ref["y"] + ref["h"] / 2))
            cv2.line(canvas, start, end, red, 1)

    # OpenCV draws in BGR; Gradio expects RGB.
    return cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)

def keep_one_box_per_row(rects, reference_rects=None, row_tol=ROW_TOL):
    """
    Keep only one representative box per row.

    Selection score for each box in a row is based on:
      - closeness of box area to the expected (median) area
      - closeness of aspect ratio (w/h) to the expected aspect ratio
      - closeness of the box's x-center to the row's median x-center
      - a small penalty for very skinny or very flat boxes

    Args:
        rects: candidate boxes [(x, y, w, h), ...].
        reference_rects: all detected rects, used to compute the expected
            area / aspect ratio; falls back to ``rects`` when None/empty.
        row_tol: row-grouping tolerance as a multiple of average box height.

    Returns:
        One box per row, sorted by y.
    """
    if not rects:
        return rects

    # Expected statistics from reference_rects (fallback to rects).
    ref = reference_rects if (reference_rects and len(reference_rects) > 0) else rects
    areas_ref = np.array([w * h for (_, _, w, h) in ref], dtype=float)
    ars_ref = np.array([w / float(h) for (_, _, w, h) in ref], dtype=float)

    # Robust central estimates (median), with safety floors.
    expected_area = float(np.median(areas_ref))
    expected_ar = float(np.median(ars_ref))
    if expected_area <= 0:
        expected_area = np.mean(areas_ref) if len(areas_ref) else 1.0
    if expected_ar <= 0:
        expected_ar = 1.0

    # Group rectangles into rows by vertical center proximity.
    rects_sorted = sorted(rects, key=lambda b: b[1])
    rows = [[rects_sorted[0]]]
    for r in rects_sorted[1:]:
        y_center = r[1] + r[3] / 2.0
        last = rows[-1][-1]
        last_center = last[1] + last[3] / 2.0
        avg_h = (r[3] + last[3]) / 2.0
        if abs(y_center - last_center) <= row_tol * avg_h:
            rows[-1].append(r)
        else:
            rows.append([r])

    # Normalizer for the horizontal-center score; identical for all rows,
    # so compute it once instead of per candidate.
    center_norm = expected_ar * np.sqrt(expected_area) + 1.0

    kept = []
    for group in rows:
        if len(group) == 1:
            kept.append(group[0])
            continue

        # The row's median x-center is loop-invariant: it was previously
        # recomputed for every candidate (O(k^2) per row); hoist it.
        median_cx = float(np.median([g[0] + g[2] / 2.0 for g in group]))

        # Compute a score for each candidate; lower is better.
        scores = []
        for (x, y, w, h) in group:
            area = w * h
            ar = w / float(h) if h > 0 else 0.0

            # Area closeness: log-ratio so relative differences are symmetric.
            area_score = abs(np.log((area + 1e-6) / (expected_area + 1e-6)))

            # Aspect-ratio closeness (normalized).
            ar_score = abs(ar - expected_ar) / (expected_ar + 1e-6)

            # Penalty for extremely skinny or extremely flat boxes.
            penalty = 0.0
            if ar < 0.25:     # very skinny tall
                penalty += 1.0
            if ar > 4.0:      # very wide flat (defensive)
                penalty += 0.6

            # Preference toward boxes near the row's median x-center.
            center_score = abs(x + w / 2.0 - median_cx) / center_norm

            # Weighted combination (tune weights if needed).
            scores.append(2.0 * area_score + 1.2 * ar_score
                          + 0.5 * center_score + penalty)

        kept.append(group[int(np.argmin(scores))])

    kept = sorted(kept, key=lambda b: b[1])
    print(f"Kept {len(kept)} boxes (one per row) out of {len(rects)} candidates.")
    return kept
def clean_mapping(mapping, left_boxes):
    """
    Clean the mapping produced by match_left_to_right.

    Two passes:
      1. Drop matched refs that coincide with any test box, and drop
         duplicate refs within the same entry. NOTE: this pass mutates
         ``mapping`` in place.
      2. Keep only the first entry for each distinct test box.

    Args:
        mapping (dict): Original mapping from match_left_to_right.
        left_boxes (list of tuples): Test boxes [(x, y, w, h), ...].

    Returns:
        dict: Cleaned mapping.
    """
    test_box_set = {tuple(tb) for tb in left_boxes}

    # Pass 1: filter each entry's matched_refs in place.
    for entry in mapping.values():
        unique_refs = []
        for ref in entry.get("matched_refs", []):
            box = (ref["x"], ref["y"], ref["w"], ref["h"])
            if box in test_box_set or box in unique_refs:
                continue
            unique_refs.append(box)
        entry["matched_refs"] = [
            dict(zip(("x", "y", "w", "h"), box)) for box in unique_refs
        ]

    # Pass 2: de-duplicate test boxes across entries.
    cleaned = {}
    seen = set()
    for key, entry in mapping.items():
        signature = tuple(entry["test_box"])
        if signature in seen:
            continue
        seen.add(signature)
        cleaned[key] = entry

    return cleaned

# ----------------- Pipeline -----------------
def process_image(image, left_frac):
    """Full pipeline: detect grid boxes, fill gaps with synthetic boxes,
    split into test (left) and reference (right) sets, match by row, and
    render the result as an RGB visualization."""
    frame = load_and_resize(image)

    detected = filter_by_area(collect_candidates(frame))
    synthetic = fill_missing_boxes(frame, detected)
    combined = detected + [(s['x'], s['y'], s['w'], s['h']) for s in synthetic]

    left, right = split_left_right(combined, frame, left_frac)
    left = keep_one_box_per_row(left)

    # References = real right-side boxes plus the synthesized gap-fillers.
    references = [{'x': x, 'y': y, 'w': w, 'h': h, 'synthetic': False}
                  for (x, y, w, h) in right] + synthetic

    mapping = clean_mapping(match_left_to_right(left, references), left)
    return visualize(frame, left, references, mapping)

# ----------------- Gradio App -----------------
title = "Grid Detection & Matching Viewer"
description = "Upload an image, adjust the Left/Right threshold, and view final matching visualization."

# Example inputs shown below the interface. The files must exist next to
# this script (or be given as full paths) for the examples to load.
examples = [
    ["2.png", 0.28],
    ["4.jpg", 0.35],
    ["5.jpg", 0.35],
]

# Image input arrives as a numpy array (RGB from Gradio); the slider feeds
# the left/right split fraction straight into process_image.
iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="numpy", label="Upload Image"),
        gr.Slider(0.1, 0.9, value=0.35, step=0.01, label="Left Fraction Threshold (LEFT_FRAC_FALLBACK)")
    ],
    outputs=gr.Image(label="Matched Output", type="numpy"),
    title=title,
    description=description,
    examples=examples
)

# Launch the UI only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()