mlbench123 commited on
Commit
b57f874
Β·
verified Β·
1 Parent(s): 6ad3a74

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +139 -185
app.py CHANGED
@@ -15,7 +15,7 @@ CLASS_COLORS = {0: (34, 197, 94), 1: (239, 68, 68)} # green, red
15
  SAMPLE_PATHS = ["image1.jpg", "image2.jpg"]
16
 
17
  # ─── Paper reference ──────────────────────────────────────────────────────────
18
- PAPER_REAL_MM = 40.0 # white 4Γ—4 cm square = 40 mm per side
19
 
20
  def detect_paper_pixels(img_np):
21
  gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
@@ -34,7 +34,7 @@ def detect_paper_pixels(img_np):
34
  if 0.5 < (w / max(h, 1)) < 2.0 and area > best_area:
35
  best_area = area
36
  best = (h, w)
37
- return best # (h_px, w_px) or None
38
 
39
  def px_to_mm(px, paper_px_dim):
40
  if not paper_px_dim:
@@ -59,36 +59,25 @@ def _text_size(draw, text, font):
59
  return bbox[2] - bbox[0], bbox[3] - bbox[1]
60
 
61
 
62
- # ─── Visual-only helpers ───────────────────────────────────────────────────────
63
- def _make_visual_mask(mask_raw_float, w, h):
64
- """
65
- Upsample the raw float mask (0..1) with bilinear interpolation and apply
66
- morphological smoothing. Result is a uint8 binary mask for DISPLAY ONLY
67
- β€” backend mask_np (INTER_NEAREST) is kept separate for measurements.
68
- """
69
- # Bilinear upscale β†’ much better edge alignment than NEAREST
70
- vis = cv2.resize(mask_raw_float, (w, h), interpolation=cv2.INTER_LINEAR)
71
- vis = (vis > 0.45).astype(np.uint8)
72
-
73
- # Elliptical close then open: fills jagged notches, shaves sharp spurs
74
- k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
75
- vis = cv2.morphologyEx(vis, cv2.MORPH_CLOSE, k)
76
- vis = cv2.morphologyEx(vis, cv2.MORPH_OPEN, k)
77
- return vis
78
 
79
-
80
- def _smooth_contour(contour, window=9):
81
  """
82
- Circular sliding-window average over contour points.
83
- Smooths wiggly edges while preserving overall shape / corners.
84
- Returns an int32 contour array compatible with cv2.polylines.
 
 
85
  """
86
- pts = contour[:, 0, :].astype(np.float32)
87
  n = len(pts)
88
- if n < window * 2 + 1:
89
- return contour.astype(np.int32)
 
 
 
 
90
 
91
- half = window // 2
92
  padded = np.vstack([pts[-half:], pts, pts[:half]])
93
  weights = np.hanning(window).astype(np.float32)
94
  weights /= weights.sum()
@@ -100,48 +89,58 @@ def _smooth_contour(contour, window=9):
100
  return smoothed.astype(np.int32).reshape(-1, 1, 2)
101
 
102
 
 
 
 
 
 
 
 
 
 
 
 
103
  # ─────────────────────────────────────────────────────────────────────────────
104
- # STEP 1 β€” Fast segmentation + classification output image (no text labels)
105
- # + auto-crop/zoom to grain bounding box
 
 
 
 
 
 
 
106
  # ─────────────────────────────────────────────────────────────────────────────
107
  def run_segmentation(img_np):
108
- """
109
- Run YOLO, draw colored masks + smooth contours (NO text labels on image).
110
- Returns:
111
- annotated : full-size RGB annotated image (numpy)
112
- zoomed_pil : PIL image cropped/zoomed to grain region
113
- grain_boxes : list of dicts {cls_id, cls_name, mask_np, bbox}
114
- mask_np is INTER_NEAREST β€” used only for measurements.
115
- counts : {"Full": int, "Broken": int}
116
- """
117
  h, w = img_np.shape[:2]
118
  results = model(img_np, imgsz=1280, conf=0.25)[0]
119
 
120
- annotated = img_np.copy()
121
- overlay = img_np.copy()
122
- counts = {"Full": 0, "Broken": 0}
123
  grain_boxes = []
124
 
125
- # Bounding box of all grains combined (for zoom) β€” uses backend mask
126
  all_x1, all_y1, all_x2, all_y2 = w, h, 0, 0
127
 
128
  if results.masks is not None:
129
- for mask_tensor, box in zip(results.masks.data, results.boxes):
 
 
 
 
 
130
  cls_id = int(box.cls[0])
131
  cls_name = CLASS_NAMES.get(cls_id, "?")
132
  color = CLASS_COLORS.get(cls_id, (200, 200, 200))
133
  counts[cls_name] += 1
134
 
135
- # ── Backend mask (INTER_NEAREST) β€” used for measurements only ──
136
- mask_raw = mask_tensor.cpu().numpy() # float32, 0..1
137
- mask_np = cv2.resize(
138
- mask_raw, (w, h), interpolation=cv2.INTER_NEAREST
139
- ).astype(np.uint8)
140
 
141
- # ── Visual mask (bilinear + morphological smooth) ──────────────
142
- mask_vis = _make_visual_mask(mask_raw, w, h)
143
 
144
- # Update combined grain bounding box from backend mask
145
  ys, xs = np.where(mask_np == 1)
146
  if len(xs) > 0:
147
  all_x1 = min(all_x1, int(xs.min()))
@@ -149,49 +148,36 @@ def run_segmentation(img_np):
149
  all_x2 = max(all_x2, int(xs.max()))
150
  all_y2 = max(all_y2, int(ys.max()))
151
 
152
- # Fill overlay using visual mask (smooth region)
153
- overlay[mask_vis == 1] = color
154
 
155
  grain_boxes.append({
156
- "cls_id": cls_id,
157
- "cls_name": cls_name,
158
- "mask_np": mask_np, # ← backend only, never touched for visuals
159
- "mask_vis": mask_vis, # ← visual only, never used for measurements
160
  })
161
 
162
- # Blend fill (overlay) into annotated
163
  annotated = cv2.addWeighted(annotated, 0.72, overlay, 0.28, 0)
164
 
165
- # ── Draw smooth, anti-aliased contours on top ──────────────────────────
166
  for g in grain_boxes:
167
  color = CLASS_COLORS[g["cls_id"]]
168
- cnts, _ = cv2.findContours(
169
- g["mask_vis"], cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
 
 
170
  )
171
- smooth_cnts = []
172
- for c in cnts:
173
- if len(c) >= 10:
174
- smooth_cnts.append(_smooth_contour(c, window=9))
175
- elif len(c) >= 4:
176
- smooth_cnts.append(c.astype(np.int32))
177
-
178
- if smooth_cnts:
179
- # polylines with LINE_AA gives sub-pixel anti-aliased edges
180
- cv2.polylines(
181
- annotated, smooth_cnts,
182
- isClosed=True, color=color, thickness=2,
183
- lineType=cv2.LINE_AA,
184
- )
185
 
186
- # ── Zoom to grain region with padding ─────────────────────────────────
187
  if all_x2 > all_x1 and all_y2 > all_y1:
188
- pad = max(30, int(max(all_x2 - all_x1, all_y2 - all_y1) * 0.08))
189
- cx1 = max(0, all_x1 - pad)
190
- cy1 = max(0, all_y1 - pad)
191
- cx2 = min(w, all_x2 + pad)
192
- cy2 = min(h, all_y2 + pad)
193
- crop = annotated[cy1:cy2, cx1:cx2]
194
- zoomed_pil = Image.fromarray(crop)
195
  else:
196
  zoomed_pil = Image.fromarray(annotated)
197
 
@@ -199,21 +185,14 @@ def run_segmentation(img_np):
199
 
200
 
201
  # ─────────────────────────────────────────────────────────────────────────────
202
- # STEP 2 β€” Measure grain height & width from YOLO masks
203
  # ─────────────────────────────────────────────────────────────────────────────
204
  def measure_grains_from_boxes(grain_boxes, img_shape, paper_dims):
205
- """
206
- For each grain mask, compute oriented bounding box (minAreaRect).
207
- h_px = long axis, w_px = short axis.
208
- Converts to mm if paper_dims is provided.
209
- Returns list of measurement dicts.
210
- NOTE: uses mask_np (INTER_NEAREST) β€” measurements are unaffected by visual smoothing.
211
- """
212
  paper_px = (paper_dims[0] + paper_dims[1]) / 2.0 if paper_dims else None
213
  measurements = []
214
 
215
  for idx, g in enumerate(grain_boxes):
216
- mask_np = g["mask_np"] # backend mask only
217
  pts_y, pts_x = np.where(mask_np == 1)
218
  if len(pts_x) < 5:
219
  continue
@@ -222,10 +201,10 @@ def measure_grains_from_boxes(grain_boxes, img_shape, paper_dims):
222
  rect = cv2.minAreaRect(pts)
223
  (cx, cy), (rw, rh), _ = rect
224
 
225
- h_px = float(max(rw, rh))
226
- w_px = float(min(rw, rh))
227
- h_mm = px_to_mm(h_px, paper_px)
228
- w_mm = px_to_mm(w_px, paper_px)
229
  area_mm2 = (h_mm * w_mm) if (h_mm and w_mm) else None
230
 
231
  measurements.append({
@@ -244,7 +223,7 @@ def measure_grains_from_boxes(grain_boxes, img_shape, paper_dims):
244
 
245
 
246
  # ─────────────────────────────────────────────────────────────────────────────
247
- # STEP 2b β€” Build measurement DataFrames (same style as previous.py)
248
  # ─────────────────────────────────────────────────────────────────────────────
249
  def build_table_data(measurements, paper_px, counts):
250
  has_mm = paper_px is not None
@@ -252,28 +231,27 @@ def build_table_data(measurements, paper_px, counts):
252
 
253
  rows = []
254
  for g in measurements:
255
- h_val = round(g["h_mm"], 2) if (has_mm and g["h_mm"]) else round(g["h_px"], 1)
256
- w_val = round(g["w_mm"], 2) if (has_mm and g["w_mm"]) else round(g["w_px"], 1)
257
  area_val = round(g["area_mm2"], 2) if g["area_mm2"] else None
258
  rows.append({
259
- "#": g["label"],
260
- "Class": g["cls_name"],
261
- f"Height ({unit})": h_val,
262
- f"Width ({unit})": w_val,
263
- f"Area (mmΒ²)" if has_mm else "Area": area_val,
264
  })
265
  grain_df = pd.DataFrame(rows)
266
 
267
- # Summary
268
- h_key = "h_mm" if has_mm else "h_px"
269
- w_key = "w_mm" if has_mm else "w_px"
270
  heights = [(g["label"], g[h_key]) for g in measurements if g.get(h_key)]
271
  widths = [(g["label"], g[w_key]) for g in measurements if g.get(w_key)]
272
 
273
- max_h = max(heights, key=lambda x: x[1]) if heights else (0, 0)
274
- min_h = min(heights, key=lambda x: x[1]) if heights else (0, 0)
275
- max_w = max(widths, key=lambda x: x[1]) if widths else (0, 0)
276
- min_w = min(widths, key=lambda x: x[1]) if widths else (0, 0)
277
  interval = (max_h[1] - min_h[1]) / 10.0 if (heights and max_h[1] != min_h[1]) else 0.0
278
 
279
  n_full = counts.get("Full", 0)
@@ -281,64 +259,49 @@ def build_table_data(measurements, paper_px, counts):
281
  total = n_full + n_broken
282
 
283
  summary_rows = [
284
- {"Metric": "Total Grains", "Value": str(total)},
285
- {"Metric": "🟒 Full Grains", "Value": str(n_full)},
286
- {"Metric": "πŸ”΄ Broken Grains", "Value": str(n_broken)},
287
- {"Metric": "Paper Reference", "Value": f"βœ… Found ({unit} mode)" if has_mm else "❌ Not found (px only)"},
288
- {"Metric": f"Max Height (Grain #{max_h[0]})", "Value": f"{max_h[1]:.2f} {unit}"},
289
- {"Metric": f"Min Height (Grain #{min_h[0]})", "Value": f"{min_h[1]:.2f} {unit}"},
290
- {"Metric": f"Max Width (Grain #{max_w[0]})", "Value": f"{max_w[1]:.2f} {unit}"},
291
- {"Metric": f"Min Width (Grain #{min_w[0]})", "Value": f"{min_w[1]:.2f} {unit}"},
292
- {"Metric": f"Mean Height", "Value": f"{np.mean([v for _, v in heights]):.2f} {unit}" if heights else "β€”"},
293
- {"Metric": f"Mean Width", "Value": f"{np.mean([v for _, v in widths]):.2f} {unit}" if widths else "β€”"},
294
- {"Metric": "Bin Interval (maxβˆ’min)/10", "Value": f"{interval:.3f} {unit}"},
295
  ]
296
  summary_df = pd.DataFrame(summary_rows)
297
  return grain_df, summary_df
298
 
299
 
300
  # ─────────────────────────────────────────────────────────────────────────────
301
- # GRADIO β€” TWO-STAGE predict
302
- # Stage 1: fast β€” returns segmentation image + count summary (no tables yet)
303
- # Stage 2: slower β€” returns measurement tables
304
  # ─────────────────────────────────────────────────────────────────────────────
305
  def predict_stage1(image: Image.Image):
306
- """Returns zoomed segmentation image + summary + placeholder table message."""
307
  if image is None:
308
  return None, "", "", None, None
309
-
310
  img_np = np.array(image.convert("RGB"))
311
-
312
- # Run segmentation (fast)
313
  _, zoomed_pil, grain_boxes, counts = run_segmentation(img_np)
314
-
315
- total = counts["Full"] + counts["Broken"]
316
- summary = f"βœ… {total} grains detected Β· 🟒 Full: {counts['Full']} Β· πŸ”΄ Broken: {counts['Broken']}"
317
-
318
- count_md = f"""| | Count |
319
- |---|---|
320
- | 🌾 Total Grains | **{total}** |
321
- | 🟒 Full Grains | **{counts['Full']}** |
322
- | πŸ”΄ Broken Grains | **{counts['Broken']}** |"""
323
-
324
- # Return image + summary immediately; tables = loading placeholder
325
  loading_df = pd.DataFrame([{"Status": "⏳ Calculating height & width of all grains..."}])
326
  return zoomed_pil, summary, count_md, loading_df, loading_df
327
 
328
 
329
  def predict_stage2(image: Image.Image):
330
- """Full pipeline β€” returns everything including measurement tables."""
331
  if image is None:
332
  return None, "", "", None, None
333
-
334
  img_np = np.array(image.convert("RGB"))
335
 
336
- # Run segmentation + paper detection in parallel
337
- def _seg():
338
- return run_segmentation(img_np)
339
-
340
- def _paper():
341
- return detect_paper_pixels(img_np)
342
 
343
  with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
344
  fut_seg = pool.submit(_seg)
@@ -346,26 +309,24 @@ def predict_stage2(image: Image.Image):
346
  _, zoomed_pil, grain_boxes, counts = fut_seg.result()
347
  paper_dims = fut_paper.result()
348
 
349
- # Measure grains
350
  measurements, paper_px = measure_grains_from_boxes(grain_boxes, img_np.shape, paper_dims)
351
-
352
- total = counts["Full"] + counts["Broken"]
353
- summary = (f"βœ… {total} grains detected Β· 🟒 Full: {counts['Full']} Β· πŸ”΄ Broken: {counts['Broken']}"
354
- + (f" Β· πŸ“ Paper found β€” measurements in mm" if paper_px else " Β· ⚠️ No paper β€” measurements in px"))
355
-
356
- count_md = f"""| | Count |
357
- |---|---|
358
- | 🌾 Total Grains | **{total}** |
359
- | 🟒 Full Grains | **{counts['Full']}** |
360
- | πŸ”΄ Broken Grains | **{counts['Broken']}** |"""
361
-
362
  grain_df, summary_df = build_table_data(measurements, paper_px, counts)
363
-
364
  return zoomed_pil, summary, count_md, grain_df, summary_df
365
 
366
 
367
  # ─────────────────────────────────────────────────────────────────────────────
368
- # UI (theme + styling identical to previous.py)
369
  # ─────────────────────────────────────────────────────────────────────────────
370
  THEME = gr.themes.Soft(
371
  primary_hue="violet",
@@ -377,7 +338,6 @@ THEME = gr.themes.Soft(
377
  CSS = """
378
  #run-btn { margin-top: 6px; }
379
  #status-box textarea { font-size: 0.92rem; }
380
- .gr-dataframe table { font-size: 0.88rem; }
381
  #count-box { font-size: 0.95rem; }
382
  """
383
 
@@ -397,33 +357,25 @@ with gr.Blocks(theme=THEME, title="GrainVision", css=CSS) as demo:
397
  """)
398
 
399
  with gr.Row(equal_height=False):
400
-
401
- # ── LEFT: Input ───────────────────────────────────────────────────────
402
  with gr.Column(scale=1):
403
  inp_image = gr.Image(type="pil", label="Upload Rice Image", height=280)
404
-
405
- run_btn = gr.Button("πŸ” Analyse Grains",
406
- variant="primary", size="lg", elem_id="run-btn")
407
-
408
  gr.Markdown("_Upload an image then press **Analyse**. Segmentation appears first, measurements follow._")
409
-
410
  status_box = gr.Textbox(
411
  label="Status", value="", interactive=False,
412
  visible=True, max_lines=3, elem_id="status-box",
413
  )
414
-
415
  gr.Markdown("### Example Images _(click to load)_")
416
  gr.Examples(
417
  examples=[[p] for p in SAMPLE_PATHS],
418
  inputs=inp_image, label="", examples_per_page=6,
419
  )
420
 
421
- # ── RIGHT: Output ─────────────────────────────────────────────────────
422
  with gr.Column(scale=1):
423
  gr.Markdown("### Segmentation Output *(zoomed to grains)*")
424
  seg_out = gr.Image(label="", interactive=False)
425
 
426
- # ── Count summary ─────────────────────────────────────────────────────────
427
  gr.Markdown("---")
428
  with gr.Row():
429
  with gr.Column(scale=1):
@@ -439,30 +391,32 @@ with gr.Blocks(theme=THEME, title="GrainVision", css=CSS) as demo:
439
  elem_id="count-box",
440
  )
441
 
442
- # ── Measurement tables ────────────────────────────────────────────────────
443
  gr.Markdown("---")
444
  gr.Markdown("### Grain Measurements Table")
445
  with gr.Row():
446
  with gr.Column(scale=2):
447
  gr.Markdown("#### Per-Grain Measurements")
448
- grain_table_out = gr.DataFrame(label="", interactive=False, wrap=False)
 
 
 
 
 
449
  with gr.Column(scale=1):
450
  gr.Markdown("#### Summary Statistics")
451
- summary_table_out = gr.DataFrame(label="", interactive=False, wrap=False)
 
 
 
 
 
452
 
453
- # ── Two-stage wiring ──────────────────────────────────────────────────────
454
  OUTPUTS = [seg_out, summary_box, count_md, grain_table_out, summary_table_out]
455
 
456
- # Stage 1: fast β€” fires on click, shows image + counts + loading placeholder
457
  run_btn.click(
458
- fn = predict_stage1,
459
- inputs = [inp_image],
460
- outputs = OUTPUTS,
461
  ).then(
462
- # Stage 2: fires immediately after stage 1 completes, fills in tables
463
- fn = predict_stage2,
464
- inputs = [inp_image],
465
- outputs = OUTPUTS,
466
  )
467
 
468
 
 
15
  SAMPLE_PATHS = ["image1.jpg", "image2.jpg"]
16
 
17
  # ─── Paper reference ──────────────────────────────────────────────────────────
18
+ PAPER_REAL_MM = 40.0 # white 4x4 cm square = 40 mm per side
19
 
20
  def detect_paper_pixels(img_np):
21
  gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
 
34
  if 0.5 < (w / max(h, 1)) < 2.0 and area > best_area:
35
  best_area = area
36
  best = (h, w)
37
+ return best
38
 
39
  def px_to_mm(px, paper_px_dim):
40
  if not paper_px_dim:
 
59
  return bbox[2] - bbox[0], bbox[3] - bbox[1]
60
 
61
 
62
+ # ─── Visual polygon helpers ────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
+ def _smooth_polygon(pts_xy, window=11):
 
65
  """
66
+ Hanning-weighted circular sliding-window average on polygon vertices.
67
+ pts_xy : numpy (N, 2) float β€” already in original image coordinates
68
+ from results.masks.xy, so NO resize drift whatsoever.
69
+ Returns : int32 array (N, 1, 2) for cv2.polylines / fillPoly.
70
+ VISUAL ONLY β€” backend mask uses raw polygon via _polygon_to_mask().
71
  """
72
+ pts = pts_xy.astype(np.float32)
73
  n = len(pts)
74
+ if n < 6:
75
+ return pts.astype(np.int32).reshape(-1, 1, 2)
76
+
77
+ # Window must be odd and fit inside the polygon
78
+ window = min(window | 1, (n - 1) | 1)
79
+ half = window // 2
80
 
 
81
  padded = np.vstack([pts[-half:], pts, pts[:half]])
82
  weights = np.hanning(window).astype(np.float32)
83
  weights /= weights.sum()
 
89
  return smoothed.astype(np.int32).reshape(-1, 1, 2)
90
 
91
 
92
+ def _polygon_to_mask(pts_xy, h, w):
93
+ """
94
+ Rasterise raw polygon to binary uint8 mask.
95
+ Used ONLY for backend measurements β€” never for visuals.
96
+ """
97
+ mask = np.zeros((h, w), dtype=np.uint8)
98
+ if len(pts_xy) >= 3:
99
+ cv2.fillPoly(mask, [pts_xy.astype(np.int32)], 1)
100
+ return mask
101
+
102
+
103
  # ─────────────────────────────────────────────────────────────────────────────
104
+ # STEP 1 β€” Segmentation + visual output
105
+ #
106
+ # ROOT-CAUSE FIX FOR MISALIGNMENT:
107
+ # Previously used results.masks.data (low-res tensor) + cv2.resize
108
+ # which introduces sub-pixel drift at every grain boundary.
109
+ #
110
+ # Now uses results.masks.xy β€” ultralytics already maps each polygon
111
+ # to ORIGINAL image pixel coordinates, so alignment is exact.
112
+ # No resize, no drift, no displacement.
113
  # ─────────────────────────────────────────────────────────────────────────────
114
  def run_segmentation(img_np):
 
 
 
 
 
 
 
 
 
115
  h, w = img_np.shape[:2]
116
  results = model(img_np, imgsz=1280, conf=0.25)[0]
117
 
118
+ annotated = img_np.copy()
119
+ overlay = img_np.copy()
120
+ counts = {"Full": 0, "Broken": 0}
121
  grain_boxes = []
122
 
 
123
  all_x1, all_y1, all_x2, all_y2 = w, h, 0, 0
124
 
125
  if results.masks is not None:
126
+ xy_list = results.masks.xy # list of (N_i, 2) arrays, original coords
127
+
128
+ for poly_xy, box in zip(xy_list, results.boxes):
129
+ if len(poly_xy) < 3:
130
+ continue
131
+
132
  cls_id = int(box.cls[0])
133
  cls_name = CLASS_NAMES.get(cls_id, "?")
134
  color = CLASS_COLORS.get(cls_id, (200, 200, 200))
135
  counts[cls_name] += 1
136
 
137
+ # Backend mask: raw polygon fill β€” for measurements only
138
+ mask_np = _polygon_to_mask(poly_xy, h, w)
 
 
 
139
 
140
+ # Visual polygon: Hanning-smoothed β€” for display only
141
+ vis_poly = _smooth_polygon(poly_xy, window=11)
142
 
143
+ # Update zoom bounding box from backend mask
144
  ys, xs = np.where(mask_np == 1)
145
  if len(xs) > 0:
146
  all_x1 = min(all_x1, int(xs.min()))
 
148
  all_x2 = max(all_x2, int(xs.max()))
149
  all_y2 = max(all_y2, int(ys.max()))
150
 
151
+ # Fill overlay with smooth visual polygon
152
+ cv2.fillPoly(overlay, [vis_poly], color)
153
 
154
  grain_boxes.append({
155
+ "cls_id": cls_id,
156
+ "cls_name": cls_name,
157
+ "mask_np": mask_np, # backend only
158
+ "vis_poly": vis_poly, # visual only
159
  })
160
 
161
+ # Blend fill
162
  annotated = cv2.addWeighted(annotated, 0.72, overlay, 0.28, 0)
163
 
164
+ # Draw smooth anti-aliased outlines on top
165
  for g in grain_boxes:
166
  color = CLASS_COLORS[g["cls_id"]]
167
+ cv2.polylines(
168
+ annotated, [g["vis_poly"]],
169
+ isClosed=True, color=color, thickness=2,
170
+ lineType=cv2.LINE_AA,
171
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
+ # Zoom to grain region
174
  if all_x2 > all_x1 and all_y2 > all_y1:
175
+ pad = max(30, int(max(all_x2 - all_x1, all_y2 - all_y1) * 0.08))
176
+ cx1 = max(0, all_x1 - pad)
177
+ cy1 = max(0, all_y1 - pad)
178
+ cx2 = min(w, all_x2 + pad)
179
+ cy2 = min(h, all_y2 + pad)
180
+ zoomed_pil = Image.fromarray(annotated[cy1:cy2, cx1:cx2])
 
181
  else:
182
  zoomed_pil = Image.fromarray(annotated)
183
 
 
185
 
186
 
187
  # ─────────────────────────────────────────────────────────────────────────────
188
+ # STEP 2 β€” Measure grains (unchanged β€” uses backend mask_np only)
189
  # ─────────────────────────────────────────────────────────────────────────────
190
  def measure_grains_from_boxes(grain_boxes, img_shape, paper_dims):
 
 
 
 
 
 
 
191
  paper_px = (paper_dims[0] + paper_dims[1]) / 2.0 if paper_dims else None
192
  measurements = []
193
 
194
  for idx, g in enumerate(grain_boxes):
195
+ mask_np = g["mask_np"]
196
  pts_y, pts_x = np.where(mask_np == 1)
197
  if len(pts_x) < 5:
198
  continue
 
201
  rect = cv2.minAreaRect(pts)
202
  (cx, cy), (rw, rh), _ = rect
203
 
204
+ h_px = float(max(rw, rh))
205
+ w_px = float(min(rw, rh))
206
+ h_mm = px_to_mm(h_px, paper_px)
207
+ w_mm = px_to_mm(w_px, paper_px)
208
  area_mm2 = (h_mm * w_mm) if (h_mm and w_mm) else None
209
 
210
  measurements.append({
 
223
 
224
 
225
  # ─────────────────────────────────────────────────────────────────────────────
226
+ # STEP 2b β€” Build DataFrames
227
  # ─────────────────────────────────────────────────────────────────────────────
228
  def build_table_data(measurements, paper_px, counts):
229
  has_mm = paper_px is not None
 
231
 
232
  rows = []
233
  for g in measurements:
234
+ h_val = round(g["h_mm"], 2) if (has_mm and g["h_mm"]) else round(g["h_px"], 1)
235
+ w_val = round(g["w_mm"], 2) if (has_mm and g["w_mm"]) else round(g["w_px"], 1)
236
  area_val = round(g["area_mm2"], 2) if g["area_mm2"] else None
237
  rows.append({
238
+ "#": g["label"],
239
+ "Class": g["cls_name"],
240
+ f"Height ({unit})": h_val,
241
+ f"Width ({unit})": w_val,
242
+ "Area (mm\u00b2)" if has_mm else "Area": area_val,
243
  })
244
  grain_df = pd.DataFrame(rows)
245
 
246
+ h_key = "h_mm" if has_mm else "h_px"
247
+ w_key = "w_mm" if has_mm else "w_px"
 
248
  heights = [(g["label"], g[h_key]) for g in measurements if g.get(h_key)]
249
  widths = [(g["label"], g[w_key]) for g in measurements if g.get(w_key)]
250
 
251
+ max_h = max(heights, key=lambda x: x[1]) if heights else (0, 0)
252
+ min_h = min(heights, key=lambda x: x[1]) if heights else (0, 0)
253
+ max_w = max(widths, key=lambda x: x[1]) if widths else (0, 0)
254
+ min_w = min(widths, key=lambda x: x[1]) if widths else (0, 0)
255
  interval = (max_h[1] - min_h[1]) / 10.0 if (heights and max_h[1] != min_h[1]) else 0.0
256
 
257
  n_full = counts.get("Full", 0)
 
259
  total = n_full + n_broken
260
 
261
  summary_rows = [
262
+ {"Metric": "Total Grains", "Value": str(total)},
263
+ {"Metric": "🟒 Full Grains", "Value": str(n_full)},
264
+ {"Metric": "πŸ”΄ Broken Grains", "Value": str(n_broken)},
265
+ {"Metric": "Paper Reference", "Value": f"βœ… Found ({unit} mode)" if has_mm else "❌ Not found (px only)"},
266
+ {"Metric": f"Max Height (Grain #{max_h[0]})", "Value": f"{max_h[1]:.2f} {unit}"},
267
+ {"Metric": f"Min Height (Grain #{min_h[0]})", "Value": f"{min_h[1]:.2f} {unit}"},
268
+ {"Metric": f"Max Width (Grain #{max_w[0]})", "Value": f"{max_w[1]:.2f} {unit}"},
269
+ {"Metric": f"Min Width (Grain #{min_w[0]})", "Value": f"{min_w[1]:.2f} {unit}"},
270
+ {"Metric": "Mean Height", "Value": f"{np.mean([v for _, v in heights]):.2f} {unit}" if heights else "β€”"},
271
+ {"Metric": "Mean Width", "Value": f"{np.mean([v for _, v in widths]):.2f} {unit}" if widths else "β€”"},
272
+ {"Metric": "Bin Interval (max-min)/10", "Value": f"{interval:.3f} {unit}"},
273
  ]
274
  summary_df = pd.DataFrame(summary_rows)
275
  return grain_df, summary_df
276
 
277
 
278
  # ─────────────────────────────────────────────────────────────────────────────
279
+ # GRADIO β€” two-stage predict
 
 
280
  # ─────────────────────────────────────────────────────────────────────────────
281
  def predict_stage1(image: Image.Image):
 
282
  if image is None:
283
  return None, "", "", None, None
 
284
  img_np = np.array(image.convert("RGB"))
 
 
285
  _, zoomed_pil, grain_boxes, counts = run_segmentation(img_np)
286
+ total = counts["Full"] + counts["Broken"]
287
+ summary = f"βœ… {total} grains detected Β· 🟒 Full: {counts['Full']} Β· πŸ”΄ Broken: {counts['Broken']}"
288
+ count_md = (
289
+ f"| | Count |\n|---|---|\n"
290
+ f"| 🌾 Total Grains | **{total}** |\n"
291
+ f"| 🟒 Full Grains | **{counts['Full']}** |\n"
292
+ f"| πŸ”΄ Broken Grains | **{counts['Broken']}** |"
293
+ )
 
 
 
294
  loading_df = pd.DataFrame([{"Status": "⏳ Calculating height & width of all grains..."}])
295
  return zoomed_pil, summary, count_md, loading_df, loading_df
296
 
297
 
298
  def predict_stage2(image: Image.Image):
 
299
  if image is None:
300
  return None, "", "", None, None
 
301
  img_np = np.array(image.convert("RGB"))
302
 
303
+ def _seg(): return run_segmentation(img_np)
304
+ def _paper(): return detect_paper_pixels(img_np)
 
 
 
 
305
 
306
  with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
307
  fut_seg = pool.submit(_seg)
 
309
  _, zoomed_pil, grain_boxes, counts = fut_seg.result()
310
  paper_dims = fut_paper.result()
311
 
 
312
  measurements, paper_px = measure_grains_from_boxes(grain_boxes, img_np.shape, paper_dims)
313
+ total = counts["Full"] + counts["Broken"]
314
+ summary = (
315
+ f"βœ… {total} grains detected Β· 🟒 Full: {counts['Full']} Β· πŸ”΄ Broken: {counts['Broken']}"
316
+ + (f" Β· πŸ“ Paper found β€” measurements in mm" if paper_px else " Β· ⚠️ No paper β€” measurements in px")
317
+ )
318
+ count_md = (
319
+ f"| | Count |\n|---|---|\n"
320
+ f"| 🌾 Total Grains | **{total}** |\n"
321
+ f"| 🟒 Full Grains | **{counts['Full']}** |\n"
322
+ f"| πŸ”΄ Broken Grains | **{counts['Broken']}** |"
323
+ )
324
  grain_df, summary_df = build_table_data(measurements, paper_px, counts)
 
325
  return zoomed_pil, summary, count_md, grain_df, summary_df
326
 
327
 
328
  # ─────────────────────────────────────────────────────────────────────────────
329
+ # UI
330
  # ─────────────────────────────────────────────────────────────────────────────
331
  THEME = gr.themes.Soft(
332
  primary_hue="violet",
 
338
  CSS = """
339
  #run-btn { margin-top: 6px; }
340
  #status-box textarea { font-size: 0.92rem; }
 
341
  #count-box { font-size: 0.95rem; }
342
  """
343
 
 
357
  """)
358
 
359
  with gr.Row(equal_height=False):
 
 
360
  with gr.Column(scale=1):
361
  inp_image = gr.Image(type="pil", label="Upload Rice Image", height=280)
362
+ run_btn = gr.Button("πŸ” Analyse Grains",
363
+ variant="primary", size="lg", elem_id="run-btn")
 
 
364
  gr.Markdown("_Upload an image then press **Analyse**. Segmentation appears first, measurements follow._")
 
365
  status_box = gr.Textbox(
366
  label="Status", value="", interactive=False,
367
  visible=True, max_lines=3, elem_id="status-box",
368
  )
 
369
  gr.Markdown("### Example Images _(click to load)_")
370
  gr.Examples(
371
  examples=[[p] for p in SAMPLE_PATHS],
372
  inputs=inp_image, label="", examples_per_page=6,
373
  )
374
 
 
375
  with gr.Column(scale=1):
376
  gr.Markdown("### Segmentation Output *(zoomed to grains)*")
377
  seg_out = gr.Image(label="", interactive=False)
378
 
 
379
  gr.Markdown("---")
380
  with gr.Row():
381
  with gr.Column(scale=1):
 
391
  elem_id="count-box",
392
  )
393
 
 
394
  gr.Markdown("---")
395
  gr.Markdown("### Grain Measurements Table")
396
  with gr.Row():
397
  with gr.Column(scale=2):
398
  gr.Markdown("#### Per-Grain Measurements")
399
+ grain_table_out = gr.DataFrame(
400
+ label="",
401
+ interactive=False,
402
+ wrap=False,
403
+ height=500, # shows all rows; scrollable if > 500px
404
+ )
405
  with gr.Column(scale=1):
406
  gr.Markdown("#### Summary Statistics")
407
+ summary_table_out = gr.DataFrame(
408
+ label="",
409
+ interactive=False,
410
+ wrap=False,
411
+ height=420, # fits all 11 summary rows without scroll
412
+ )
413
 
 
414
  OUTPUTS = [seg_out, summary_box, count_md, grain_table_out, summary_table_out]
415
 
 
416
  run_btn.click(
417
+ fn=predict_stage1, inputs=[inp_image], outputs=OUTPUTS,
 
 
418
  ).then(
419
+ fn=predict_stage2, inputs=[inp_image], outputs=OUTPUTS,
 
 
 
420
  )
421
 
422