smiler488 committed on
Commit
bf8d4d7
·
verified ·
1 Parent(s): 371cc99

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +242 -83
app.py CHANGED
@@ -59,93 +59,85 @@ def detect_reference(
59
  mode: str,
60
  ref_size_mm: Optional[float],
61
  ) -> Tuple[float, Optional[Tuple[int, int]], Optional[str], Optional[Tuple[int, int, int, int]]]:
62
- """Detect reference object (coin / square) in top-left ROI.
63
 
64
- Parameters
65
- ----------
66
- img_bgr : np.ndarray
67
- Input BGR image.
68
- mode : {"auto", "coin", "square"}
69
- Detection strategy.
70
- ref_size_mm : float or None
71
- Real-world size (diameter for coin / side length for square).
72
-
73
- Returns
74
- -------
75
- px_per_mm : float
76
- Pixels per millimeter. Always > 0.
77
- center : (int, int) or None
78
- Reference object center in image coordinates.
79
- ref_type : str or None
80
- "coin" or "square" if detected, otherwise None.
81
  """
82
  h, w = img_bgr.shape[:2]
83
 
84
- # Use a ROI at the top-left to limit search cost
85
- roi_w = int(w * 0.25)
86
- roi_h = int(h * 0.25)
87
- roi = img_bgr[0:roi_h, 0:roi_w]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
- gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
90
- gray = cv2.medianBlur(gray, 5)
91
 
92
- px_per_mm: Optional[float] = None
93
  center: Optional[Tuple[int, int]] = None
94
  ref_type: Optional[str] = None
95
  bbox: Optional[Tuple[int, int, int, int]] = None
96
 
97
- # ----------------- coin detection -----------------
98
- if mode in ("auto", "coin"):
99
- circles = cv2.HoughCircles(
100
- gray,
101
- cv2.HOUGH_GRADIENT,
102
- dp=1.2,
103
- minDist=20,
104
- param1=120,
105
- param2=35,
106
- minRadius=8,
107
- maxRadius=min(roi_w, roi_h) // 2,
108
- )
109
- if circles is not None and len(circles) > 0:
110
- c = circles[0][0]
111
- r = float(c[2])
112
- d_px = 2.0 * r
113
- d_mm = ref_size_mm if ref_size_mm and ref_size_mm > 0 else 25.0
114
- px_per_mm = max(d_px / d_mm, 1e-6)
115
- center = (int(c[0]), int(c[1]))
116
- ref_type = "coin"
117
- bbox = (int(c[0] - r), int(c[1] - r), int(2 * r), int(2 * r))
118
-
119
- # ----------------- square detection -----------------
120
- if px_per_mm is None and mode in ("auto", "square"):
121
- edges = cv2.Canny(gray, 80, 160)
122
- cnts, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
123
- best = None
124
- best_score = 0.0
125
- for cnt in cnts:
126
- x, y, ww, hh = cv2.boundingRect(cnt)
127
- area = ww * hh
128
- if area < 225:
129
- continue
130
- score = min(ww, hh) / max(ww, hh)
131
- if score > best_score:
132
- best_score = score
133
- best = (x, y, ww, hh)
134
- if best is not None and best_score > 0.6:
135
- x, y, ww, hh = best
136
- s_px = float(max(ww, hh))
137
- s_mm = ref_size_mm if ref_size_mm and ref_size_mm > 0 else 20.0
138
- px_per_mm = max(s_px / s_mm, 1e-6)
139
- center = (x + ww // 2, y + hh // 2)
140
- ref_type = "square"
141
- bbox = (x, y, ww, hh)
142
 
143
- # ----------------- fallback -----------------
144
- if px_per_mm is None:
145
- # Fallback: approximate scale if no reference detected.
146
- # For typical scanner/phone images this is a safe range.
147
- # Use a slightly conservative default so the values are not exaggerated.
148
  px_per_mm = 4.0
 
 
 
149
 
150
  return px_per_mm, center, ref_type, bbox
151
 
@@ -394,10 +386,10 @@ def analyze(
394
  color_tol: int,
395
  hsv_low_h: int,
396
  hsv_high_h: int,
397
- ) -> Tuple[Optional[np.ndarray], pd.DataFrame, Optional[str], List[Dict[str, Any]]]:
398
  try:
399
  if image is None:
400
- return None, pd.DataFrame(), None, []
401
  img_rgb = np.array(image)
402
  img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
403
  img_bgr, scale = downscale_bgr(img_bgr)
@@ -416,13 +408,149 @@ def analyze(
416
  tmp.write(csv.encode("utf-8"))
417
  tmp.close()
418
  js = df.to_dict(orient="records")
419
- return overlay, df, tmp.name, js
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420
  except Exception as e:
421
- return None, pd.DataFrame(), None, [{"error": str(e)}]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
422
 
423
 
424
  with gr.Blocks(theme=gr.themes.Default()) as demo:
425
  gr.Markdown("# Biological Sample Quantifier (Leaves / Seeds)")
 
426
  with gr.Row():
427
  with gr.Column(scale=1):
428
  image = gr.Image(type="numpy", label="Upload image")
@@ -435,17 +563,48 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
435
  color_tol = gr.Slider(5, 100, value=40, step=1, label="Color tolerance")
436
  hsv_low = gr.Slider(0, 179, value=35, step=1, label="HSV H lower (leaves)")
437
  hsv_high = gr.Slider(0, 179, value=85, step=1, label="HSV H upper (leaves)")
 
 
 
 
 
438
  run = gr.Button("Analyze")
439
  reset = gr.Button("Reset")
440
  with gr.Column(scale=2):
441
- overlay = gr.Image(label="Annotated")
442
  table = gr.Dataframe(label="Metrics", wrap=True)
443
  csv_out = gr.File(label="CSV export")
444
  json_out = gr.JSON(label="JSON preview")
445
  def _analyze(image, sample_type, expected, ref_mode, ref_size, min_area, max_area, color_tol, hsv_low, hsv_high):
446
- return analyze(image, sample_type, expected, ref_mode, ref_size, min_area, max_area, color_tol, hsv_low, hsv_high)
447
- run.click(_analyze, [image, sample_type, expected, ref_mode, ref_size, min_area, max_area, color_tol, hsv_low, hsv_high], [overlay, table, csv_out, json_out])
448
- reset.click(lambda: (None, pd.DataFrame(), None, []), None, [overlay, table, csv_out, json_out])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
449
 
450
  if __name__ == "__main__":
451
  demo.launch()
 
59
  mode: str,
60
  ref_size_mm: Optional[float],
61
  ) -> Tuple[float, Optional[Tuple[int, int]], Optional[str], Optional[Tuple[int, int, int, int]]]:
62
+ """Detect reference object (circle or square) using connected components.
63
 
64
+ Assumptions:
65
+ - White or near-white uniform background
66
+ - A single reference object is placed in the top-left region
67
+ - Reference is approximately square in its bounding box (square card or coin)
68
+ - ref_size_mm is the real diameter (coin) or side length (square)
 
 
 
 
 
 
 
 
 
 
 
 
69
  """
70
  h, w = img_bgr.shape[:2]
71
 
72
+ # 1. Estimate background color in LAB space and build "non-background" mask
73
+ # This works for any solid-color background as long as the reference object
74
+ # has a noticeable color difference from the background.
75
+ lab = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2LAB).astype(np.float32)
76
+ # Use the median color of the whole image as background estimate (robust to small objects)
77
+ bg_color = np.median(lab.reshape(-1, 3), axis=0)
78
+
79
+ # Compute per-pixel Euclidean distance in LAB space
80
+ diff = lab - bg_color # shape (H, W, 3)
81
+ dist = np.sqrt(np.sum(diff * diff, axis=2)).astype(np.float32)
82
+
83
+ # Threshold on color distance: pixels far from background color are foreground
84
+ # You can tune 8.0 -> 6.0 or 10.0 depending on image contrast.
85
+ _, mask = cv2.threshold(dist, 8.0, 255, cv2.THRESH_BINARY)
86
+ mask = mask.astype(np.uint8)
87
+
88
+ # 2. Small morphological opening to remove noise
89
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
90
+ mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=1)
91
+
92
+ # 3. Connected components
93
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(mask)
94
+
95
+ # stats[i] = [x, y, w, h, area]
96
+ candidates = []
97
+ for i in range(1, num_labels): # skip label 0 (background)
98
+ x, y, ww, hh, area = stats[i]
99
+ if area < 400:
100
+ # too small, likely noise
101
+ continue
102
+
103
+ # Only consider objects in the upper-left region
104
+ if x > w * 0.6 or y > h * 0.6:
105
+ continue
106
+
107
+ # Require roughly square bounding box: circles and squares both satisfy this
108
+ ar = ww / float(hh + 1e-6)
109
+ if ar < 0.7 or ar > 1.3:
110
+ continue
111
 
112
+ candidates.append((i, x, y, ww, hh, area))
 
113
 
114
+ px_per_mm: float
115
  center: Optional[Tuple[int, int]] = None
116
  ref_type: Optional[str] = None
117
  bbox: Optional[Tuple[int, int, int, int]] = None
118
 
119
+ if candidates:
120
+ # 4. Pick the one closest to the top-left corner (smallest x + y)
121
+ label_id, x, y, ww, hh, area = min(candidates, key=lambda t: t[1] + t[2])
122
+ bbox = (int(x), int(y), int(ww), int(hh))
123
+ center = (int(x + ww // 2), int(y + hh // 2))
124
+
125
+ # Real-world size: diameter (coin) or side length (square)
126
+ ref_mm = ref_size_mm if ref_size_mm and ref_size_mm > 0 else 20.0
127
+
128
+ # For both circles and squares, the max side of the bounding box
129
+ # can be treated as "diameter/side" in pixels.
130
+ side_or_diam_px = float(max(ww, hh))
131
+ px_per_mm = max(side_or_diam_px / ref_mm, 1e-6)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
+ # Roughly classify reference type; optional, not used in scaling
134
+ ref_type = "square"
135
+ else:
136
+ # If no reference found, use a safe default scale to avoid division by zero.
 
137
  px_per_mm = 4.0
138
+ center = None
139
+ ref_type = None
140
+ bbox = None
141
 
142
  return px_per_mm, center, ref_type, bbox
143
 
 
386
  color_tol: int,
387
  hsv_low_h: int,
388
  hsv_high_h: int,
389
+ ) -> Tuple[Optional[np.ndarray], pd.DataFrame, Optional[str], List[Dict[str, Any]], Dict[str, Any]]:
390
  try:
391
  if image is None:
392
+ return None, pd.DataFrame(), None, [], {}
393
  img_rgb = np.array(image)
394
  img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
395
  img_bgr, scale = downscale_bgr(img_bgr)
 
408
  tmp.write(csv.encode("utf-8"))
409
  tmp.close()
410
  js = df.to_dict(orient="records")
411
+
412
+ # Store state for interactive correction
413
+ state_dict: Dict[str, Any] = {
414
+ "img_bgr": img_bgr,
415
+ "sample_type": sample_type,
416
+ "px_per_mm": px_per_mm,
417
+ "ref_center": ref_center,
418
+ "ref_type": ref_type,
419
+ "ref_bbox": ref_bbox,
420
+ "components": comps,
421
+ "expected_count": expected_count,
422
+ "ref_size_mm": ref_size_mm,
423
+ }
424
+ # By default, all components are active samples
425
+ state_dict["active_indices"] = list(range(len(comps)))
426
+
427
+ return overlay, df, tmp.name, js, state_dict
428
  except Exception as e:
429
+ return None, pd.DataFrame(), None, [{"error": str(e)}], {}
430
+
431
+
432
+ # --- Interactive correction helper ---
433
+ def apply_corrections(
434
+ click_event,
435
+ state_dict: Dict[str, Any],
436
+ correction_mode: str,
437
+ ) -> Tuple[Dict[str, Any], Optional[np.ndarray], pd.DataFrame, Optional[str], List[Dict[str, Any]]]:
438
+ """
439
+ Apply interactive corrections based on a click on the annotated image.
440
+
441
+ correction_mode:
442
+ - "none": do nothing
443
+ - "set-ref": treat the clicked object as the new reference
444
+ - "toggle-sample": toggle the clicked object between active/inactive sample
445
+ """
446
+ # If no valid state or no correction requested, do nothing
447
+ if not state_dict or "img_bgr" not in state_dict or correction_mode == "none" or click_event is None:
448
+ return state_dict, None, pd.DataFrame(), None, []
449
+
450
+ try:
451
+ # Gradio SelectData usually provides (x, y) in .index
452
+ if hasattr(click_event, "index"):
453
+ x, y = click_event.index
454
+ else:
455
+ # Fallback: assume click_event is a tuple
456
+ x, y = click_event
457
+
458
+ img_bgr = state_dict["img_bgr"]
459
+ components: List[Dict[str, Any]] = state_dict.get("components", [])
460
+ if not components:
461
+ return state_dict, None, pd.DataFrame(), None, []
462
+
463
+ # Find nearest component center to the click
464
+ min_dist = 1e9
465
+ nearest_idx = -1
466
+ for i, comp in enumerate(components):
467
+ cx, cy = comp["center"]
468
+ d = (cx - x) ** 2 + (cy - y) ** 2
469
+ if d < min_dist:
470
+ min_dist = d
471
+ nearest_idx = i
472
+
473
+ if nearest_idx < 0:
474
+ return state_dict, None, pd.DataFrame(), None, []
475
+
476
+ px_per_mm = state_dict.get("px_per_mm", 4.0)
477
+ ref_center = state_dict.get("ref_center")
478
+ ref_type = state_dict.get("ref_type", "square")
479
+ ref_bbox = state_dict.get("ref_bbox")
480
+ ref_size_mm = state_dict.get("ref_size_mm", 20.0)
481
+ sample_type = state_dict.get("sample_type", "leaves")
482
+
483
+ active_indices = state_dict.get("active_indices", list(range(len(components))))
484
+
485
+ if correction_mode == "set-ref":
486
+ # Use this component as the new reference object
487
+ comp = components[nearest_idx]
488
+ box = comp["box"]
489
+ xs = box[:, 0]
490
+ ys = box[:, 1]
491
+ x0, y0 = int(xs.min()), int(ys.min())
492
+ w0, h0 = int(xs.max() - xs.min()), int(ys.max() - ys.min())
493
+ ref_bbox = (x0, y0, w0, h0)
494
+ ref_center = (int(comp["center"][0]), int(comp["center"][1]))
495
+
496
+ # Update px_per_mm using the largest side as diameter/side length
497
+ side_px = float(max(w0, h0))
498
+ px_per_mm = max(side_px / (ref_size_mm if ref_size_mm > 0 else 20.0), 1e-6)
499
+ ref_type = "square"
500
+
501
+ # Remove this component from active samples (reference is not a sample)
502
+ new_components = []
503
+ for i, c in enumerate(components):
504
+ if i != nearest_idx:
505
+ new_components.append(c)
506
+ components = new_components
507
+ # Rebuild active_indices to cover all remaining components
508
+ active_indices = list(range(len(components)))
509
+
510
+ state_dict["components"] = components
511
+ state_dict["ref_bbox"] = ref_bbox
512
+ state_dict["ref_center"] = ref_center
513
+ state_dict["px_per_mm"] = px_per_mm
514
+ state_dict["ref_type"] = ref_type
515
+ state_dict["active_indices"] = active_indices
516
+
517
+ elif correction_mode == "toggle-sample":
518
+ # Toggle this component in/out of the active sample set
519
+ if nearest_idx in active_indices:
520
+ active_indices = [idx for idx in active_indices if idx != nearest_idx]
521
+ else:
522
+ active_indices.append(nearest_idx)
523
+ active_indices = sorted(set(active_indices))
524
+ state_dict["active_indices"] = active_indices
525
+
526
+ # Rebuild the list of active components
527
+ active_components = [components[i] for i in active_indices]
528
+
529
+ # Recompute metrics and overlay using the updated state
530
+ df = compute_metrics(img_bgr, active_components, px_per_mm)
531
+ overlay = render_overlay(
532
+ img_bgr.copy(),
533
+ px_per_mm,
534
+ (state_dict.get("ref_center"), state_dict.get("ref_type")),
535
+ active_components,
536
+ df,
537
+ state_dict.get("ref_bbox"),
538
+ )
539
+ csv = df.to_csv(index=False)
540
+ tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
541
+ tmp.write(csv.encode("utf-8"))
542
+ tmp.close()
543
+ js = df.to_dict(orient="records")
544
+
545
+ return state_dict, overlay, df, tmp.name, js
546
+ except Exception:
547
+ # In case of any error, do not break the app; just keep current state
548
+ return state_dict, None, pd.DataFrame(), None, []
549
 
550
 
551
  with gr.Blocks(theme=gr.themes.Default()) as demo:
552
  gr.Markdown("# Biological Sample Quantifier (Leaves / Seeds)")
553
+ state = gr.State({})
554
  with gr.Row():
555
  with gr.Column(scale=1):
556
  image = gr.Image(type="numpy", label="Upload image")
 
563
  color_tol = gr.Slider(5, 100, value=40, step=1, label="Color tolerance")
564
  hsv_low = gr.Slider(0, 179, value=35, step=1, label="HSV H lower (leaves)")
565
  hsv_high = gr.Slider(0, 179, value=85, step=1, label="HSV H upper (leaves)")
566
+ correction_mode = gr.Radio(
567
+ ["none", "set-ref", "toggle-sample"],
568
+ value="none",
569
+ label="Correction mode (click on image)"
570
+ )
571
  run = gr.Button("Analyze")
572
  reset = gr.Button("Reset")
573
  with gr.Column(scale=2):
574
+ overlay = gr.Image(label="Annotated", interactive=True)
575
  table = gr.Dataframe(label="Metrics", wrap=True)
576
  csv_out = gr.File(label="CSV export")
577
  json_out = gr.JSON(label="JSON preview")
578
  def _analyze(image, sample_type, expected, ref_mode, ref_size, min_area, max_area, color_tol, hsv_low, hsv_high):
579
+ overlay_img, df, csv_path, js, state_dict = analyze(
580
+ image, sample_type, expected, ref_mode, ref_size, min_area, max_area, color_tol, hsv_low, hsv_high
581
+ )
582
+ return overlay_img, df, csv_path, js, state_dict
583
+
584
+ run.click(
585
+ _analyze,
586
+ [image, sample_type, expected, ref_mode, ref_size, min_area, max_area, color_tol, hsv_low, hsv_high],
587
+ [overlay, table, csv_out, json_out, state],
588
+ )
589
+
590
+ def _reset():
591
+ return None, pd.DataFrame(), None, [], {}
592
+
593
+ reset.click(_reset, None, [overlay, table, csv_out, json_out, state])
594
+
595
+ def _on_select(evt, current_state, correction_mode):
596
+ # Apply corrections based on a click on the annotated image
597
+ new_state, overlay_img, df, csv_path, js = apply_corrections(evt, current_state or {}, correction_mode)
598
+ # If overlay_img is None, keep the existing outputs unchanged by returning gr.update()
599
+ if overlay_img is None:
600
+ return gr.update(), gr.update(), gr.update(), gr.update(), new_state
601
+ return overlay_img, df, csv_path, js, new_state
602
+
603
+ overlay.select(
604
+ _on_select,
605
+ [state, correction_mode],
606
+ [overlay, table, csv_out, json_out, state],
607
+ )
608
 
609
  if __name__ == "__main__":
610
  demo.launch()