import concurrent.futures

import cv2
import gradio as gr
import numpy as np
import pandas as pd
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO

# ─── Model ────────────────────────────────────────────────────────────────────
model = YOLO("best.pt")

CLASS_NAMES = {0: "Full", 1: "Broken"}
CLASS_COLORS = {0: (34, 197, 94), 1: (239, 68, 68)}  # RGB: green, red
SAMPLE_PATHS = ["image1.jpg", "image2.jpg"]

# ─── Paper reference ──────────────────────────────────────────────────────────
PAPER_REAL_MM = 40.0  # white 4x4 cm square = 40 mm per side


def detect_paper_pixels(img_np):
    """Locate the white square reference paper and return its side lengths in px.

    Bright pixels (gray > 180) are thresholded, cleaned with a 7x7 close/open
    pass, and the largest external blob that covers at least 2% of the frame
    and has a side ratio below 2:1 is taken as the paper.

    Args:
        img_np: RGB uint8 image array (converted with COLOR_RGB2GRAY).

    Returns:
        ``(long_side_px, short_side_px)`` of the blob's minimum-area rotated
        rectangle, or ``None`` when no candidate qualifies.
    """
    gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
    _, thresh = cv2.threshold(gray, 180, 255, cv2.THRESH_BINARY)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (7, 7))
    thresh = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)
    thresh = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # NOTE(review): any bright, roughly square blob >= 2% of the frame
    # qualifies (e.g. a white background) — confirm against the imaging setup.
    min_area = img_np.shape[0] * img_np.shape[1] * 0.02
    best, best_area = None, 0.0
    for c in contours:
        area = cv2.contourArea(c)
        if area < min_area or area <= best_area:
            continue
        # minAreaRect instead of boundingRect: an axis-aligned box around a
        # square tilted by 45 degrees is ~1.41x wider than the square, which
        # would skew every mm conversion. This also matches how grains are
        # measured (minAreaRect over pixel centres).
        _, (rw, rh), _ = cv2.minAreaRect(c)
        long_side, short_side = max(rw, rh), min(rw, rh)
        if short_side > 0 and long_side / short_side < 2.0:
            best_area = area
            best = (long_side, short_side)
    return best


def px_to_mm(px, paper_px_dim):
    """Convert a pixel length to millimetres given the paper side length in px.

    Returns None when no usable paper scale is available (None or 0).
    """
    if not paper_px_dim:
        return None
    return px * PAPER_REAL_MM / paper_px_dim


# ─── Font helper ──────────────────────────────────────────────────────────────
def _font(size, bold=False):
    """Return a DejaVu/Liberation TrueType font of *size*, else Pillow's default."""
    candidates = [
        "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf" if bold
        else "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
        "/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf",
    ]
    for path in candidates:
        try:
            return ImageFont.truetype(path, size)
        except (OSError, ImportError):
            # Font file missing/unreadable (OSError) or Pillow built without
            # FreeType (ImportError): try the next candidate.
            continue
    return ImageFont.load_default()


def _text_size(draw, text, font):
    """Return the (width, height) of *text* rendered with *font* on *draw*."""
    bbox = draw.textbbox((0, 0), text, font=font)
    return bbox[2] - bbox[0], bbox[3] - bbox[1]


# ─── Mask helpers ─────────────────────────────────────────────────────────────
def _polygon_to_mask(pts_xy, h, w):
    """Rasterise raw polygon → binary uint8 mask. BACKEND / measurements only.

    Args:
        pts_xy: (N, 2) array of float (x, y) vertices in image pixel coords.
        h, w: output mask height and width.

    Returns:
        (h, w) uint8 array, 1 inside the polygon and 0 elsewhere; all zeros
        when fewer than 3 vertices are given.
    """
    mask = np.zeros((h, w), dtype=np.uint8)
    if len(pts_xy) >= 3:
        # Round (rather than truncate) sub-pixel vertices so the mask is not
        # biased half a pixel towards the top-left corner.
        pts = np.rint(np.asarray(pts_xy, dtype=np.float64)).astype(np.int32)
        cv2.fillPoly(mask, [pts], 1)
    return mask


def _refine_mask_grabcut(img_bgr, coarse_mask):
    """
    Refine a coarse binary mask so it follows the grain's pixels, using GrabCut.

    Used for the on-screen overlay only; measurements use the coarse mask.

    img_bgr     : full BGR uint8 image
    coarse_mask : uint8 binary mask (0/1), same height/width as img_bgr
    Returns     : refined binary uint8 mask (0/1) at full image size. Falls
                  back to ``coarse_mask`` when the mask or crop is too small,
                  GrabCut raises, or GrabCut labels every pixel background.
    """
    ys, xs = np.where(coarse_mask == 1)
    if len(xs) < 5:
        return coarse_mask

    k_sm = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    k_lg = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (11, 11))
    dilate_iters = 2
    # The probable-FG ring grows the mask by radius(k_lg) * iterations px.
    # Pad the crop beyond that ring so a band of definite background is left
    # for GrabCut's background model (a fixed 6-px pad was smaller than the
    # 10-px ring, leaving only the 3-px border strip as background).
    ring = (k_lg.shape[0] // 2) * dilate_iters
    pad = ring + 6

    img_h, img_w = img_bgr.shape[:2]
    x1, y1 = max(0, int(xs.min()) - pad), max(0, int(ys.min()) - pad)
    x2, y2 = min(img_w, int(xs.max()) + pad), min(img_h, int(ys.max()) + pad)
    crop = img_bgr[y1:y2, x1:x2]
    ch, cw = crop.shape[:2]
    if ch < 8 or cw < 8:
        return coarse_mask

    local_fg = coarse_mask[y1:y2, x1:x2]
    # Erode for a definite-FG core; dilate for a probable-FG ring around it.
    def_fg = cv2.erode(local_fg, k_sm, iterations=2)
    prob_fg = cv2.dilate(local_fg, k_lg, iterations=dilate_iters)

    gc_mask = np.full((ch, cw), cv2.GC_BGD, dtype=np.uint8)
    gc_mask[prob_fg == 1] = cv2.GC_PR_FGD
    gc_mask[def_fg == 1] = cv2.GC_FGD
    # Crop border strip is always definite background.
    gc_mask[:3, :] = cv2.GC_BGD
    gc_mask[-3:, :] = cv2.GC_BGD
    gc_mask[:, :3] = cv2.GC_BGD
    gc_mask[:, -3:] = cv2.GC_BGD

    bgd_model = np.zeros((1, 65), np.float64)
    fgd_model = np.zeros((1, 65), np.float64)
    try:
        cv2.grabCut(crop, gc_mask, None, bgd_model, fgd_model, 4, cv2.GC_INIT_WITH_MASK)
    except cv2.error:
        # e.g. no foreground or background samples to fit the colour models
        return coarse_mask
    refined_local = np.isin(gc_mask, (cv2.GC_FGD, cv2.GC_PR_FGD)).astype(np.uint8)

    # Close small holes, then open to remove specks along the edge.
    k_cl = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    refined_local = cv2.morphologyEx(refined_local, cv2.MORPH_CLOSE, k_cl, iterations=2)
    refined_local = cv2.morphologyEx(refined_local, cv2.MORPH_OPEN, k_cl, iterations=1)
    if not refined_local.any():
        # GrabCut discarded the whole grain; keep the detection visible.
        return coarse_mask

    # Put the refined crop back into a full-size mask.
    refined_full = np.zeros_like(coarse_mask)
    refined_full[y1:y2, x1:x2] = refined_local
    return refined_full


def _mask_to_smooth_contour(mask_np):
    """
    Return a smoothed outline of the largest external contour of a binary mask.

    The contour is resampled to 40-120 points spaced evenly by index around
    the closed loop, then smoothed with a wrap-around 9-tap Gaussian.
    Returns an int32 array of shape (N, 1, 2) for ``cv2.polylines``, or None
    when the mask has no contour.
    """
    contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    if not contours:
        return None
    cnt = max(contours, key=cv2.contourArea).reshape(-1, 2).astype(np.float32)
    if len(cnt) < 6:
        return cnt.astype(np.int32).reshape(-1, 1, 2)

    n_target = min(120, max(40, len(cnt)))
    # endpoint=False: on a closed loop the last point neighbours the first,
    # so sampling both would leave an uneven seam.
    indices = np.linspace(0, len(cnt), n_target, endpoint=False).astype(int)
    sampled = cnt[indices]

    window = 9
    half = window // 2
    kernel = cv2.getGaussianKernel(window, 0).flatten()
    kernel /= kernel.sum()
    # Wrap-pad so the smoothing treats the outline as a closed curve; the
    # kernel is symmetric, so convolution equals the sliding weighted sum.
    padded = np.vstack([sampled[-half:], sampled, sampled[:half]])
    smoothed = np.column_stack(
        [np.convolve(padded[:, axis], kernel, mode="valid") for axis in range(2)]
    )
    return np.rint(smoothed).astype(np.int32).reshape(-1, 1, 2)
# ─────────────────────────────────────────────────────────────────────────────
# STEP 1 — Segmentation + visual output
#
# Uses results.masks.xy (polygon in original-image px coords) instead of
# results.masks.data (low-res tensor + resize) → zero resize drift,
# pixel-perfect mask alignment.
# ─────────────────────────────────────────────────────────────────────────────
def run_segmentation(img_np):
    """
    Segment and classify grains, and render an annotated, zoomed preview.

    Args:
        img_np: RGB uint8 image array of shape (H, W, 3).

    Returns:
        annotated   : RGB array with translucent class fills and outlines.
        zoomed_pil  : PIL image of ``annotated`` cropped (with padding) to
                      the extent of all grains, or the full frame if none.
        grain_boxes : list of dicts, one per grain, with keys ``cls_id``,
                      ``cls_name``, ``color``, ``mask_np`` (full-size polygon
                      mask used for measurements), ``vis_mask`` (GrabCut-
                      refined mask) and ``vis_contour`` (outline or None).
        counts      : detections per class name, always containing
                      ``"Full"`` and ``"Broken"``.
    """
    h, w = img_np.shape[:2]
    img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
    # Ultralytics treats numpy input as BGR (cv2.imread order); feeding the
    # RGB array would swap the red/blue channels the model sees.
    results = model(img_bgr, imgsz=1280, conf=0.25)[0]

    overlay = img_np.copy()
    counts = {"Full": 0, "Broken": 0}
    grain_boxes = []
    all_x1, all_y1, all_x2, all_y2 = w, h, 0, 0

    if results.masks is not None:
        # masks.xy: list of (N_i, 2) float arrays in original-image coords
        for poly_xy, box in zip(results.masks.xy, results.boxes):
            if len(poly_xy) < 3:
                continue
            cls_id = int(box.cls[0])
            cls_name = CLASS_NAMES.get(cls_id, "?")
            color = CLASS_COLORS.get(cls_id, (200, 200, 200))
            # .get: an unexpected class id ("?") must not raise KeyError
            counts[cls_name] = counts.get(cls_name, 0) + 1

            # Backend mask: raw polygon fill (used for measurements — never changed)
            mask_np = _polygon_to_mask(poly_xy, h, w)
            # Visual mask: GrabCut-refined → hugs actual grain pixels
            vis_mask = _refine_mask_grabcut(img_bgr, mask_np)
            vis_contour = _mask_to_smooth_contour(vis_mask)

            # Accumulate the backend-mask extent for the zoom crop
            ys, xs = np.where(mask_np == 1)
            if len(xs) > 0:
                all_x1 = min(all_x1, int(xs.min()))
                all_y1 = min(all_y1, int(ys.min()))
                all_x2 = max(all_x2, int(xs.max()))
                all_y2 = max(all_y2, int(ys.max()))

            # Fill overlay using the refined visual mask
            overlay[vis_mask == 1] = color

            grain_boxes.append({
                "cls_id": cls_id,
                "cls_name": cls_name,
                "color": color,
                "mask_np": mask_np,          # backend only — measurements
                "vis_mask": vis_mask,        # refined visual mask
                "vis_contour": vis_contour,  # smooth contour for outline
            })

    # overlay differs from the source only where grains were filled, so the
    # 72/28 blend tints grains and leaves the rest of the image unchanged.
    annotated = cv2.addWeighted(img_np, 0.72, overlay, 0.28, 0)

    # Anti-aliased outlines drawn over the blend
    for g in grain_boxes:
        if g["vis_contour"] is not None:
            cv2.polylines(
                annotated,
                [g["vis_contour"]],
                isClosed=True,
                color=g["color"],
                thickness=2,
                lineType=cv2.LINE_AA,
            )

    # Zoom to the grains with a margin of max(30 px, 8% of the larger extent)
    if all_x2 > all_x1 and all_y2 > all_y1:
        pad = max(30, int(max(all_x2 - all_x1, all_y2 - all_y1) * 0.08))
        cx1, cy1 = max(0, all_x1 - pad), max(0, all_y1 - pad)
        cx2, cy2 = min(w, all_x2 + pad), min(h, all_y2 + pad)
        zoomed_pil = Image.fromarray(annotated[cy1:cy2, cx1:cx2])
    else:
        zoomed_pil = Image.fromarray(annotated)

    return annotated, zoomed_pil, grain_boxes, counts


# ─────────────────────────────────────────────────────────────────────────────
# STEP 2 — Measure grains (backend mask_np only — unaffected by visual changes)
# ─────────────────────────────────────────────────────────────────────────────
def measure_grains_from_boxes(grain_boxes, img_shape, paper_dims):
    """
    Measure each grain's length and width from its backend polygon mask.

    Length/width are the long/short sides of the minimum-area rotated
    rectangle around the mask pixels, converted to mm when a paper scale
    is available.

    Args:
        grain_boxes: dicts from ``run_segmentation``; only ``mask_np`` and
            ``cls_name`` are read.
        img_shape: unused; kept for call compatibility.
        paper_dims: pair of paper side lengths in px, or None.

    Returns:
        ``(measurements, paper_px)`` — a list of per-grain dicts and the mean
        paper side in px (None without paper). Grains with fewer than 5 mask
        pixels are skipped; ``label`` is the 1-based index into grain_boxes.
    """
    paper_px = (paper_dims[0] + paper_dims[1]) / 2.0 if paper_dims else None
    measurements = []
    for label, g in enumerate(grain_boxes, start=1):
        pts_y, pts_x = np.where(g["mask_np"] == 1)
        if len(pts_x) < 5:
            continue
        pts = np.column_stack([pts_x, pts_y]).astype(np.float32)
        (cx, cy), (rw, rh), _ = cv2.minAreaRect(pts)
        h_px = float(max(rw, rh))
        w_px = float(min(rw, rh))
        h_mm = px_to_mm(h_px, paper_px)
        w_mm = px_to_mm(w_px, paper_px)
        # NOTE(review): this "area" is the rotated-rectangle product h*w, not
        # the mask's pixel area — confirm which one the report should show.
        area_mm2 = (h_mm * w_mm) if (h_mm and w_mm) else None
        measurements.append({
            "label": label,
            "cls_name": g["cls_name"],
            "h_px": h_px,
            "w_px": w_px,
            "h_mm": h_mm,
            "w_mm": w_mm,
            "area_mm2": area_mm2,
            "centroid_x": int(cx),
            "centroid_y": int(cy),
        })
    return measurements, paper_px


# ─────────────────────────────────────────────────────────────────────────────
# STEP 2b — Build DataFrames
# ─────────────────────────────────────────────────────────────────────────────
def build_table_data(measurements, paper_px, counts):
    """
    Build the per-grain table and the summary-statistics table.

    Values are in mm when ``paper_px`` is set, otherwise in px. Returns
    ``(grain_df, summary_df)``; ``grain_df`` keeps its column headers even
    when there are no measurements.
    """
    has_mm = paper_px is not None
    unit = "mm" if has_mm else "px"
    height_col = f"Height ({unit})"
    width_col = f"Width ({unit})"
    area_col = "Area (mm\u00b2)" if has_mm else "Area"

    rows = [
        {
            "#": g["label"],
            "Class": g["cls_name"],
            height_col: round(g["h_mm"], 2) if (has_mm and g["h_mm"]) else round(g["h_px"], 1),
            width_col: round(g["w_mm"], 2) if (has_mm and g["w_mm"]) else round(g["w_px"], 1),
            area_col: round(g["area_mm2"], 2) if g["area_mm2"] else None,
        }
        for g in measurements
    ]
    # Explicit columns so an empty result still renders headers.
    grain_df = pd.DataFrame(rows, columns=["#", "Class", height_col, width_col, area_col])

    h_key = "h_mm" if has_mm else "h_px"
    w_key = "w_mm" if has_mm else "w_px"
    heights = [(g["label"], g[h_key]) for g in measurements if g.get(h_key)]
    widths = [(g["label"], g[w_key]) for g in measurements if g.get(w_key)]

    def _extremes(pairs):
        """(max, min) (label, value) pairs by value; (0, 0) placeholders if empty."""
        if not pairs:
            return (0, 0), (0, 0)
        return max(pairs, key=lambda p: p[1]), min(pairs, key=lambda p: p[1])

    max_h, min_h = _extremes(heights)
    max_w, min_w = _extremes(widths)
    interval = (max_h[1] - min_h[1]) / 10.0 if (heights and max_h[1] != min_h[1]) else 0.0

    n_full = counts.get("Full", 0)
    n_broken = counts.get("Broken", 0)
    total = n_full + n_broken

    summary_rows = [
        {"Metric": "Total Grains", "Value": str(total)},
        {"Metric": "🟢 Full Grains", "Value": str(n_full)},
        {"Metric": "🔴 Broken Grains", "Value": str(n_broken)},
        {"Metric": "Paper Reference", "Value": f"✅ Found ({unit} mode)" if has_mm else "❌ Not found (px only)"},
        {"Metric": f"Max Height (Grain #{max_h[0]})", "Value": f"{max_h[1]:.2f} {unit}"},
        {"Metric": f"Min Height (Grain #{min_h[0]})", "Value": f"{min_h[1]:.2f} {unit}"},
        {"Metric": f"Max Width (Grain #{max_w[0]})", "Value": f"{max_w[1]:.2f} {unit}"},
        {"Metric": f"Min Width (Grain #{min_w[0]})", "Value": f"{min_w[1]:.2f} {unit}"},
        {"Metric": "Mean Height", "Value": f"{np.mean([v for _, v in heights]):.2f} {unit}" if heights else "—"},
        {"Metric": "Mean Width", "Value": f"{np.mean([v for _, v in widths]):.2f} {unit}" if widths else "—"},
        {"Metric": "Bin Interval (max-min)/10", "Value": f"{interval:.3f} {unit}"},
    ]
    summary_df = pd.DataFrame(summary_rows)
    return grain_df, summary_df


# ─────────────────────────────────────────────────────────────────────────────
# GRADIO — two-stage predict
# ─────────────────────────────────────────────────────────────────────────────
def _detection_summary(counts):
    """One-line status text with total / full / broken counts."""
    total = counts["Full"] + counts["Broken"]
    return f"✅ {total} grains detected · 🟢 Full: {counts['Full']} · 🔴 Broken: {counts['Broken']}"


def _count_markdown(counts):
    """Markdown table of total / full / broken grain counts."""
    total = counts["Full"] + counts["Broken"]
    return (
        f"| | Count |\n|---|---|\n"
        f"| 🌾 Total Grains | **{total}** |\n"
        f"| 🟢 Full Grains | **{counts['Full']}** |\n"
        f"| 🔴 Broken Grains | **{counts['Broken']}** |"
    )


def predict_stage1(image: Image.Image):
    """First pass: segmentation preview and counts; tables show a loading row.

    Returns (zoomed image, status text, count markdown, grain table, summary
    table) — the same 5-tuple shape as ``predict_stage2``.
    """
    if image is None:
        return None, "", "", None, None

    img_np = np.array(image.convert("RGB"))
    _, zoomed_pil, _, counts = run_segmentation(img_np)

    loading_df = pd.DataFrame([{"Status": "⏳ Calculating height & width of all grains..."}])
    return zoomed_pil, _detection_summary(counts), _count_markdown(counts), loading_df, loading_df


def predict_stage2(image: Image.Image):
    """Second pass: segmentation plus paper detection, then measurement tables.

    Returns (zoomed image, status text, count markdown, grain table, summary
    table).
    """
    if image is None:
        return None, "", "", None, None

    img_np = np.array(image.convert("RGB"))

    # Segmentation and paper detection are independent; run them concurrently.
    # NOTE(review): this repeats the segmentation already done in stage 1.
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
        fut_seg = pool.submit(run_segmentation, img_np)
        fut_paper = pool.submit(detect_paper_pixels, img_np)
        _, zoomed_pil, grain_boxes, counts = fut_seg.result()
        paper_dims = fut_paper.result()

    measurements, paper_px = measure_grains_from_boxes(grain_boxes, img_np.shape, paper_dims)

    paper_note = (
        " · 📏 Paper found — measurements in mm" if paper_px
        else " · ⚠️ No paper — measurements in px"
    )
    summary = _detection_summary(counts) + paper_note

    grain_df, summary_df = build_table_data(measurements, paper_px, counts)
    return zoomed_pil, summary, _count_markdown(counts), grain_df, summary_df


# ─────────────────────────────────────────────────────────────────────────────
# UI — Gradio 6 compatible
#   • theme / css → moved to demo.launch()
#   • gr.DataFrame has no height param → use CSS to expand tables
# ─────────────────────────────────────────────────────────────────────────────
THEME = gr.themes.Soft(
    primary_hue="violet",
    secondary_hue="indigo",
    neutral_hue="slate",
    font=gr.themes.GoogleFont("Inter"),
)
# In Gradio 6 the CSS string is passed to launch(), not Blocks()
CSS = """
#run-btn { margin-top: 6px; }
#status-box textarea,
#summary-box textarea { font-size: 0.92rem; }
#count-box { font-size: 0.95rem; }

/* Make both measurement tables tall enough to show all rows */
#grain-table .table-wrap,
#grain-table .svelte-table,
#summary-table .table-wrap,
#summary-table .svelte-table {
    max-height: none !important;
    overflow-y: visible !important;
}
#grain-table, #summary-table { overflow: visible !important; }
"""

with gr.Blocks(title="GrainVision") as demo:
    gr.HTML("""
🌾 GrainVision

Upload a rice image (with white 4×4 cm reference paper) to segment, classify, measure, and analyse grains.

""")

    with gr.Row(equal_height=False):
        with gr.Column(scale=1):
            inp_image = gr.Image(type="pil", label="Upload Rice Image", height=280)
            run_btn = gr.Button(
                "🔍 Analyse Grains", variant="primary", size="lg", elem_id="run-btn",
            )
            gr.Markdown("_Upload an image then press **Analyse**. "
                        "Segmentation appears first, measurements follow._")
            # NOTE(review): status_box is not in OUTPUTS, so nothing ever
            # writes to it — wire it up or remove it.
            status_box = gr.Textbox(
                label="Status",
                value="",
                interactive=False,
                visible=True,
                max_lines=3,
                elem_id="status-box",
            )
            gr.Markdown("### Example Images _(click to load)_")
            gr.Examples(
                examples=[[p] for p in SAMPLE_PATHS],
                inputs=inp_image,
                label="",
                examples_per_page=6,
            )

        with gr.Column(scale=1):
            gr.Markdown("### Segmentation Output *(zoomed to grains)*")
            seg_out = gr.Image(label="", interactive=False)

    gr.Markdown("---")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("#### Detection Summary")
            # Own elem_id: reusing "status-box" produced duplicate DOM ids.
            summary_box = gr.Textbox(
                label="",
                value="",
                interactive=False,
                max_lines=2,
                elem_id="summary-box",
            )
        with gr.Column(scale=1):
            gr.Markdown("#### Grain Count")
            count_md = gr.Markdown(
                value="| | Count |\n|---|---|\n"
                      "| 🌾 Total | — |\n| 🟢 Full | — |\n| 🔴 Broken | — |",
                elem_id="count-box",
            )

    gr.Markdown("---")
    gr.Markdown("### Grain Measurements Table")

    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("#### Per-Grain Measurements")
            grain_table_out = gr.DataFrame(
                label="",
                interactive=False,
                wrap=False,
                elem_id="grain-table",
            )
        with gr.Column(scale=1):
            gr.Markdown("#### Summary Statistics")
            summary_table_out = gr.DataFrame(
                label="",
                interactive=False,
                wrap=False,
                elem_id="summary-table",
            )

    # Both stages return the same 5-tuple, in this order.
    OUTPUTS = [seg_out, summary_box, count_md, grain_table_out, summary_table_out]

    run_btn.click(
        fn=predict_stage1,
        inputs=[inp_image],
        outputs=OUTPUTS,
    ).then(
        fn=predict_stage2,
        inputs=[inp_image],
        outputs=OUTPUTS,
    )


if __name__ == "__main__":
    # Gradio 6: theme and css passed here, not in gr.Blocks()
    demo.launch(theme=THEME, css=CSS)