mlbench123 commited on
Commit
6ad3a74
Β·
verified Β·
1 Parent(s): 5c8507d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -17
app.py CHANGED
@@ -58,17 +58,60 @@ def _text_size(draw, text, font):
58
  bbox = draw.textbbox((0, 0), text, font=font)
59
  return bbox[2] - bbox[0], bbox[3] - bbox[1]
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  # ─────────────────────────────────────────────────────────────────────────────
62
  # STEP 1 β€” Fast segmentation + classification output image (no text labels)
63
  # + auto-crop/zoom to grain bounding box
64
  # ─────────────────────────────────────────────────────────────────────────────
65
  def run_segmentation(img_np):
66
  """
67
- Run YOLO, draw colored masks + contours (NO text labels on image).
68
  Returns:
69
  annotated : full-size RGB annotated image (numpy)
70
  zoomed_pil : PIL image cropped/zoomed to grain region
71
  grain_boxes : list of dicts {cls_id, cls_name, mask_np, bbox}
 
72
  counts : {"Full": int, "Broken": int}
73
  """
74
  h, w = img_np.shape[:2]
@@ -79,7 +122,7 @@ def run_segmentation(img_np):
79
  counts = {"Full": 0, "Broken": 0}
80
  grain_boxes = []
81
 
82
- # Bounding box of all grains combined (for zoom)
83
  all_x1, all_y1, all_x2, all_y2 = w, h, 0, 0
84
 
85
  if results.masks is not None:
@@ -89,10 +132,16 @@ def run_segmentation(img_np):
89
  color = CLASS_COLORS.get(cls_id, (200, 200, 200))
90
  counts[cls_name] += 1
91
 
92
- mask_np = mask_tensor.cpu().numpy().astype(np.uint8)
93
- mask_np = cv2.resize(mask_np, (w, h), interpolation=cv2.INTER_NEAREST)
 
 
 
 
 
 
94
 
95
- # Update combined grain bounding box
96
  ys, xs = np.where(mask_np == 1)
97
  if len(xs) > 0:
98
  all_x1 = min(all_x1, int(xs.min()))
@@ -100,25 +149,41 @@ def run_segmentation(img_np):
100
  all_x2 = max(all_x2, int(xs.max()))
101
  all_y2 = max(all_y2, int(ys.max()))
102
 
103
- overlay[mask_np == 1] = color
104
- cnts, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
105
- cv2.drawContours(annotated, cnts, -1, color, 2)
106
 
107
  grain_boxes.append({
108
- "cls_id": cls_id,
109
- "cls_name": cls_name,
110
- "mask_np": mask_np,
 
111
  })
112
 
113
- # Blend overlays
114
  annotated = cv2.addWeighted(annotated, 0.72, overlay, 0.28, 0)
115
 
116
- # Redraw contours sharp on top
117
  for g in grain_boxes:
118
- cnts, _ = cv2.findContours(g["mask_np"], cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
119
- cv2.drawContours(annotated, cnts, -1, CLASS_COLORS[g["cls_id"]], 2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
- # ── Zoom to grain region with padding ─────────────────────────────────────
122
  if all_x2 > all_x1 and all_y2 > all_y1:
123
  pad = max(30, int(max(all_x2 - all_x1, all_y2 - all_y1) * 0.08))
124
  cx1 = max(0, all_x1 - pad)
@@ -142,12 +207,13 @@ def measure_grains_from_boxes(grain_boxes, img_shape, paper_dims):
142
  h_px = long axis, w_px = short axis.
143
  Converts to mm if paper_dims is provided.
144
  Returns list of measurement dicts.
 
145
  """
146
  paper_px = (paper_dims[0] + paper_dims[1]) / 2.0 if paper_dims else None
147
  measurements = []
148
 
149
  for idx, g in enumerate(grain_boxes):
150
- mask_np = g["mask_np"]
151
  pts_y, pts_x = np.where(mask_np == 1)
152
  if len(pts_x) < 5:
153
  continue
 
58
  bbox = draw.textbbox((0, 0), text, font=font)
59
  return bbox[2] - bbox[0], bbox[3] - bbox[1]
60
 
61
+
62
+ # ─── Visual-only helpers ───────────────────────────────────────────────────────
63
+ def _make_visual_mask(mask_raw_float, w, h):
64
+ """
65
+ Upsample the raw float mask (0..1) with bilinear interpolation and apply
66
+ morphological smoothing. Result is a uint8 binary mask for DISPLAY ONLY
67
+ β€” backend mask_np (INTER_NEAREST) is kept separate for measurements.
68
+ """
69
+ # Bilinear upscale β†’ much better edge alignment than NEAREST
70
+ vis = cv2.resize(mask_raw_float, (w, h), interpolation=cv2.INTER_LINEAR)
71
+ vis = (vis > 0.45).astype(np.uint8)
72
+
73
+ # Elliptical close then open: fills jagged notches, shaves sharp spurs
74
+ k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
75
+ vis = cv2.morphologyEx(vis, cv2.MORPH_CLOSE, k)
76
+ vis = cv2.morphologyEx(vis, cv2.MORPH_OPEN, k)
77
+ return vis
78
+
79
+
80
+ def _smooth_contour(contour, window=9):
81
+ """
82
+ Circular sliding-window average over contour points.
83
+ Smooths wiggly edges while preserving overall shape / corners.
84
+ Returns an int32 contour array compatible with cv2.polylines.
85
+ """
86
+ pts = contour[:, 0, :].astype(np.float32)
87
+ n = len(pts)
88
+ if n < window * 2 + 1:
89
+ return contour.astype(np.int32)
90
+
91
+ half = window // 2
92
+ padded = np.vstack([pts[-half:], pts, pts[:half]])
93
+ weights = np.hanning(window).astype(np.float32)
94
+ weights /= weights.sum()
95
+
96
+ smoothed = np.zeros_like(pts)
97
+ for i in range(n):
98
+ smoothed[i] = (padded[i : i + window] * weights[:, None]).sum(axis=0)
99
+
100
+ return smoothed.astype(np.int32).reshape(-1, 1, 2)
101
+
102
+
103
  # ─────────────────────────────────────────────────────────────────────────────
104
  # STEP 1 β€” Fast segmentation + classification output image (no text labels)
105
  # + auto-crop/zoom to grain bounding box
106
  # ─────────────────────────────────────────────────────────────────────────────
107
  def run_segmentation(img_np):
108
  """
109
+ Run YOLO, draw colored masks + smooth contours (NO text labels on image).
110
  Returns:
111
  annotated : full-size RGB annotated image (numpy)
112
  zoomed_pil : PIL image cropped/zoomed to grain region
113
  grain_boxes : list of dicts {cls_id, cls_name, mask_np, bbox}
114
+ mask_np is INTER_NEAREST β€” used only for measurements.
115
  counts : {"Full": int, "Broken": int}
116
  """
117
  h, w = img_np.shape[:2]
 
122
  counts = {"Full": 0, "Broken": 0}
123
  grain_boxes = []
124
 
125
+ # Bounding box of all grains combined (for zoom) β€” uses backend mask
126
  all_x1, all_y1, all_x2, all_y2 = w, h, 0, 0
127
 
128
  if results.masks is not None:
 
132
  color = CLASS_COLORS.get(cls_id, (200, 200, 200))
133
  counts[cls_name] += 1
134
 
135
+ # ── Backend mask (INTER_NEAREST) β€” used for measurements only ──
136
+ mask_raw = mask_tensor.cpu().numpy() # float32, 0..1
137
+ mask_np = cv2.resize(
138
+ mask_raw, (w, h), interpolation=cv2.INTER_NEAREST
139
+ ).astype(np.uint8)
140
+
141
+ # ── Visual mask (bilinear + morphological smooth) ──────────────
142
+ mask_vis = _make_visual_mask(mask_raw, w, h)
143
 
144
+ # Update combined grain bounding box from backend mask
145
  ys, xs = np.where(mask_np == 1)
146
  if len(xs) > 0:
147
  all_x1 = min(all_x1, int(xs.min()))
 
149
  all_x2 = max(all_x2, int(xs.max()))
150
  all_y2 = max(all_y2, int(ys.max()))
151
 
152
+ # Fill overlay using visual mask (smooth region)
153
+ overlay[mask_vis == 1] = color
 
154
 
155
  grain_boxes.append({
156
+ "cls_id": cls_id,
157
+ "cls_name": cls_name,
158
+ "mask_np": mask_np, # ← backend only, never touched for visuals
159
+ "mask_vis": mask_vis, # ← visual only, never used for measurements
160
  })
161
 
162
+ # Blend fill (overlay) into annotated
163
  annotated = cv2.addWeighted(annotated, 0.72, overlay, 0.28, 0)
164
 
165
+ # ── Draw smooth, anti-aliased contours on top ──────────────────────────
166
  for g in grain_boxes:
167
+ color = CLASS_COLORS[g["cls_id"]]
168
+ cnts, _ = cv2.findContours(
169
+ g["mask_vis"], cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
170
+ )
171
+ smooth_cnts = []
172
+ for c in cnts:
173
+ if len(c) >= 10:
174
+ smooth_cnts.append(_smooth_contour(c, window=9))
175
+ elif len(c) >= 4:
176
+ smooth_cnts.append(c.astype(np.int32))
177
+
178
+ if smooth_cnts:
179
+ # polylines with LINE_AA gives sub-pixel anti-aliased edges
180
+ cv2.polylines(
181
+ annotated, smooth_cnts,
182
+ isClosed=True, color=color, thickness=2,
183
+ lineType=cv2.LINE_AA,
184
+ )
185
 
186
+ # ── Zoom to grain region with padding ─────────────────────────────────
187
  if all_x2 > all_x1 and all_y2 > all_y1:
188
  pad = max(30, int(max(all_x2 - all_x1, all_y2 - all_y1) * 0.08))
189
  cx1 = max(0, all_x1 - pad)
 
207
  h_px = long axis, w_px = short axis.
208
  Converts to mm if paper_dims is provided.
209
  Returns list of measurement dicts.
210
+ NOTE: uses mask_np (INTER_NEAREST) β€” measurements are unaffected by visual smoothing.
211
  """
212
  paper_px = (paper_dims[0] + paper_dims[1]) / 2.0 if paper_dims else None
213
  measurements = []
214
 
215
  for idx, g in enumerate(grain_boxes):
216
+ mask_np = g["mask_np"] # backend mask only
217
  pts_y, pts_x = np.where(mask_np == 1)
218
  if len(pts_x) < 5:
219
  continue