| import gradio as gr |
| import cv2 |
| import numpy as np |
| import pandas as pd |
| import concurrent.futures |
| from PIL import Image, ImageDraw, ImageFont |
| from ultralytics import YOLO |
|
|
| |
# Ultralytics YOLO model, loaded once at import time from the local
# "best.pt" weights file and shared by every request.
model = YOLO("best.pt")
|
|
# Class id -> human-readable label used in counts and tables.
CLASS_NAMES = {
    0: "Full",
    1: "Broken",
}

# Class id -> colour triple used for the mask fill and outline.
CLASS_COLORS = {
    0: (34, 197, 94),
    1: (239, 68, 68),
}

# Example images offered in the UI.
SAMPLE_PATHS = [
    "image1.jpg",
    "image2.jpg",
]

# Real-world size (mm) that the detected reference paper is scaled against.
PAPER_REAL_MM = 40.0
|
|
def detect_paper_pixels(img_np):
    """Find the bright reference paper and return its bounding-box size.

    The image is converted from RGB to grey, thresholded at 180, and cleaned
    with a 7x7 close followed by a 7x7 open.  Among the external contours
    that cover at least 2 % of the image and whose bounding box has a
    width/height ratio strictly between 0.5 and 2.0, the one with the
    largest contour area wins.

    Returns:
        ``(height_px, width_px)`` of the winning bounding box, or ``None``
        when no contour qualifies.
    """
    grey = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
    _, binary = cv2.threshold(grey, 180, 255, cv2.THRESH_BINARY)

    square = cv2.getStructuringElement(cv2.MORPH_RECT, (7, 7))
    for operation in (cv2.MORPH_CLOSE, cv2.MORPH_OPEN):
        binary = cv2.morphologyEx(binary, operation, square)

    blobs, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    min_area = img_np.shape[0] * img_np.shape[1] * 0.02
    winner = None
    winner_area = 0
    for blob in blobs:
        blob_area = cv2.contourArea(blob)
        if blob_area < min_area:
            continue
        _, _, box_w, box_h = cv2.boundingRect(blob)
        aspect = box_w / max(box_h, 1)
        if aspect <= 0.5 or aspect >= 2.0:
            continue
        if blob_area > winner_area:
            winner_area = blob_area
            winner = (box_h, box_w)
    return winner
|
|
def px_to_mm(px, paper_px_dim):
    """Scale a pixel length to millimetres using the paper's pixel size.

    Returns ``None`` when ``paper_px_dim`` is falsy (``None`` or zero), i.e.
    when no usable reference is available.
    """
    if paper_px_dim:
        return px * PAPER_REAL_MM / paper_px_dim
    return None
|
|
| |
def _font(size, bold=False):
    """Return a TrueType font of the given size, falling back to PIL's default.

    Two candidates are tried in order: DejaVu Sans (the bold face when
    ``bold`` is true, otherwise the regular face) and then Liberation Sans
    Regular.  Any failure to load a candidate moves on to the next one.
    """
    if bold:
        dejavu = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
    else:
        dejavu = "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf"
    candidates = (
        dejavu,
        "/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf",
    )
    for candidate in candidates:
        try:
            return ImageFont.truetype(candidate, size)
        except Exception:
            continue
    return ImageFont.load_default()
|
|
def _text_size(draw, text, font):
    """Return ``(width, height)`` of ``text`` as reported by ``draw.textbbox``."""
    left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    return right - left, bottom - top
|
|
|
|
| |
|
|
def _polygon_to_mask(pts_xy, h, w):
    """Rasterise a raw polygon into a binary uint8 mask (backend / measurements only).

    Pixels inside the polygon are set to 1 on an ``h`` x ``w`` canvas of
    zeros.  Inputs with fewer than three points yield an all-zero mask.
    """
    canvas = np.zeros((h, w), dtype=np.uint8)
    if len(pts_xy) < 3:
        return canvas
    cv2.fillPoly(canvas, [pts_xy.astype(np.int32)], 1)
    return canvas
|
|
|
|
def _refine_mask_grabcut(img_bgr, coarse_mask):
    """Refine a coarse binary mask towards the grain boundary using GrabCut.

    Args:
        img_bgr: Full BGR image.
        coarse_mask: uint8 binary mask (0/1) with the same height and width
            as ``img_bgr``.

    Returns:
        A refined uint8 binary mask (0/1) of the same shape.  The input
        ``coarse_mask`` itself is returned unchanged when refinement is not
        possible or not useful: fewer than 5 foreground pixels, a crop
        smaller than 8 px on a side, an OpenCV error inside GrabCut, or a
        refinement that ends up with no foreground at all.
    """
    ys, xs = np.where(coarse_mask == 1)
    if len(xs) < 5:
        return coarse_mask

    # Work on a padded crop around the mask so GrabCut only sees the grain
    # and its immediate surroundings.
    margin = 6
    img_h, img_w = img_bgr.shape[:2]
    x1 = max(0, int(xs.min()) - margin)
    y1 = max(0, int(ys.min()) - margin)
    x2 = min(img_w, int(xs.max()) + margin)
    y2 = min(img_h, int(ys.max()) + margin)
    crop = img_bgr[y1:y2, x1:x2]
    ch, cw = crop.shape[:2]
    if ch < 8 or cw < 8:
        return coarse_mask

    local_fg = coarse_mask[y1:y2, x1:x2]

    # Eroded core = definite foreground; dilated halo = probable foreground;
    # everything else (and a 3 px frame around the crop) = definite background.
    k_sm = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    k_lg = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (11, 11))
    def_fg = cv2.erode(local_fg, k_sm, iterations=2)
    prob_fg = cv2.dilate(local_fg, k_lg, iterations=2)

    gc_mask = np.full((ch, cw), cv2.GC_BGD, dtype=np.uint8)
    gc_mask[prob_fg == 1] = cv2.GC_PR_FGD
    gc_mask[def_fg == 1] = cv2.GC_FGD
    gc_mask[:3, :] = cv2.GC_BGD
    gc_mask[-3:, :] = cv2.GC_BGD
    gc_mask[:, :3] = cv2.GC_BGD
    gc_mask[:, -3:] = cv2.GC_BGD

    bgd_model = np.zeros((1, 65), np.float64)
    fgd_model = np.zeros((1, 65), np.float64)
    try:
        cv2.grabCut(crop, gc_mask, None, bgd_model, fgd_model, 4, cv2.GC_INIT_WITH_MASK)
    except cv2.error:
        # Best effort: if GrabCut rejects the seed mask, keep the coarse one.
        return coarse_mask

    is_fg = (gc_mask == cv2.GC_FGD) | (gc_mask == cv2.GC_PR_FGD)
    refined_local = np.where(is_fg, 1, 0).astype(np.uint8)

    # Close small holes, then remove thin spurs.
    k_cl = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    refined_local = cv2.morphologyEx(refined_local, cv2.MORPH_CLOSE, k_cl, iterations=2)
    refined_local = cv2.morphologyEx(refined_local, cv2.MORPH_OPEN, k_cl, iterations=1)

    # A refinement that wipes out the whole grain is worse than no
    # refinement: the grain would be counted but never drawn.
    if not refined_local.any():
        return coarse_mask

    refined_full = np.zeros_like(coarse_mask)
    refined_full[y1:y2, x1:x2] = refined_local
    return refined_full
|
|
|
|
def _mask_to_smooth_contour(mask_np):
    """Return a smoothed outer contour of a binary mask for cv2 drawing.

    The largest external contour is resampled to between 40 and 120 evenly
    spaced points and smoothed with a 9-tap Gaussian that wraps around the
    closed outline.  Contours with fewer than 6 points are returned as-is.

    Returns:
        An int32 array of shape (N, 1, 2), or ``None`` when the mask has no
        contour.
    """
    found, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    if not found:
        return None

    outline = max(found, key=cv2.contourArea).astype(np.float32).reshape(-1, 2)
    n_pts = len(outline)
    if n_pts < 6:
        return outline.astype(np.int32).reshape(-1, 1, 2)

    # Evenly spaced resampling along the contour.
    n_target = min(120, max(40, n_pts))
    picks = np.linspace(0, n_pts - 1, n_target).astype(int)
    ring = outline[picks]

    # Circular padding so the Gaussian window wraps across the seam.
    window = 9
    half = window // 2
    wrapped = np.vstack([ring[-half:], ring, ring[:half]])

    weights = cv2.getGaussianKernel(window, 0).flatten()
    weights /= weights.sum()
    weights_col = weights[:, None]

    smoothed = np.array(
        [
            (wrapped[start:start + window] * weights_col).sum(axis=0)
            for start in range(len(ring))
        ],
        dtype=np.float32,
    )
    return smoothed.astype(np.int32).reshape(-1, 1, 2)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
def run_segmentation(img_np):
    """Segment grains with the YOLO model and render an annotated overlay.

    Args:
        img_np: Image array; it is converted with ``COLOR_RGB2BGR`` and the
            callers in this file pass an RGB array built from a PIL image.

    Returns:
        A 4-tuple ``(annotated, zoomed_pil, grain_boxes, counts)``:

        * ``annotated`` - copy of ``img_np`` with translucent mask fills and
          anti-aliased outlines.
        * ``zoomed_pil`` - PIL image cropped to the padded bounding box of
          all grains, or the whole annotated image when nothing was found.
        * ``grain_boxes`` - one dict per grain with keys ``cls_id``,
          ``cls_name``, ``mask_np`` (raw polygon mask, used for
          measurements), ``vis_mask`` (GrabCut-refined mask) and
          ``vis_contour`` (smoothed outline or ``None``).
        * ``counts`` - per-class counts; always contains ``"Full"`` and
          ``"Broken"``, plus ``"?"`` if the model emits an unknown class id.
    """
    h, w = img_np.shape[:2]
    img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
    # Ultralytics interprets ndarray sources as BGR, so feed it the BGR copy
    # rather than the RGB array (which would swap the red and blue channels).
    results = model(img_bgr, imgsz=1280, conf=0.25)[0]

    fallback_color = (200, 200, 200)
    annotated = img_np.copy()
    overlay = img_np.copy()
    counts = {"Full": 0, "Broken": 0}
    grain_boxes = []

    # Running bounding box over every grain, used for the zoomed crop.
    all_x1, all_y1, all_x2, all_y2 = w, h, 0, 0

    if results.masks is not None:
        for poly_xy, box in zip(results.masks.xy, results.boxes):
            if len(poly_xy) < 3:
                continue

            cls_id = int(box.cls[0])
            cls_name = CLASS_NAMES.get(cls_id, "?")
            color = CLASS_COLORS.get(cls_id, fallback_color)
            # .get() so an unexpected class id is counted instead of raising.
            counts[cls_name] = counts.get(cls_name, 0) + 1

            # Raw polygon mask: kept untouched for measurements.
            mask_np = _polygon_to_mask(poly_xy, h, w)

            # Refined mask / smoothed outline: used for display only.
            vis_mask = _refine_mask_grabcut(img_bgr, mask_np)
            vis_contour = _mask_to_smooth_contour(vis_mask)

            ys, xs = np.where(mask_np == 1)
            if len(xs) > 0:
                all_x1 = min(all_x1, int(xs.min()))
                all_y1 = min(all_y1, int(ys.min()))
                all_x2 = max(all_x2, int(xs.max()))
                all_y2 = max(all_y2, int(ys.max()))

            overlay[vis_mask == 1] = color

            grain_boxes.append({
                "cls_id": cls_id,
                "cls_name": cls_name,
                "mask_np": mask_np,
                "vis_mask": vis_mask,
                "vis_contour": vis_contour,
            })

        # Blend the solid fills into the image once, after all grains are painted.
        annotated = cv2.addWeighted(annotated, 0.72, overlay, 0.28, 0)

        # Outlines go on top of the blend so they stay fully opaque.
        for g in grain_boxes:
            if g["vis_contour"] is None:
                continue
            cv2.polylines(
                annotated, [g["vis_contour"]],
                isClosed=True,
                color=CLASS_COLORS.get(g["cls_id"], fallback_color),
                thickness=2,
                lineType=cv2.LINE_AA,
            )

    if all_x2 > all_x1 and all_y2 > all_y1:
        pad = max(30, int(max(all_x2 - all_x1, all_y2 - all_y1) * 0.08))
        cx1, cy1 = max(0, all_x1 - pad), max(0, all_y1 - pad)
        cx2, cy2 = min(w, all_x2 + pad), min(h, all_y2 + pad)
        zoomed_pil = Image.fromarray(annotated[cy1:cy2, cx1:cx2])
    else:
        zoomed_pil = Image.fromarray(annotated)

    return annotated, zoomed_pil, grain_boxes, counts
|
|
|
|
| |
| |
| |
def measure_grains_from_boxes(grain_boxes, img_shape, paper_dims):
    """Measure each grain from its raw mask via a minimum-area rectangle.

    Args:
        grain_boxes: Dicts carrying at least ``mask_np`` and ``cls_name``.
        img_shape: Accepted for interface compatibility; not used here.
        paper_dims: Two-element sequence whose mean is the reference pixel
            size, or a falsy value when no paper was detected.

    Returns:
        ``(measurements, paper_px)``.  Each measurement holds ``label``
        (1-based position in ``grain_boxes``), ``cls_name``, ``h_px`` /
        ``w_px`` (long / short rectangle side), ``h_mm`` / ``w_mm`` /
        ``area_mm2`` (``None`` without a reference) and the integer
        rectangle centre ``centroid_x`` / ``centroid_y``.  Grains with
        fewer than 5 mask pixels are skipped.
    """
    if paper_dims:
        paper_px = (paper_dims[0] + paper_dims[1]) / 2.0
    else:
        paper_px = None

    measurements = []
    for label, grain in enumerate(grain_boxes, start=1):
        rows, cols = np.where(grain["mask_np"] == 1)
        if len(cols) < 5:
            continue

        pixel_xy = np.column_stack([cols.astype(np.float32), rows.astype(np.float32)])
        (cx, cy), (side_a, side_b), _ = cv2.minAreaRect(pixel_xy)

        h_px = float(max(side_a, side_b))
        w_px = float(min(side_a, side_b))
        h_mm = px_to_mm(h_px, paper_px)
        w_mm = px_to_mm(w_px, paper_px)
        if h_mm and w_mm:
            area_mm2 = h_mm * w_mm
        else:
            area_mm2 = None

        measurements.append({
            "label": label,
            "cls_name": grain["cls_name"],
            "h_px": h_px,
            "w_px": w_px,
            "h_mm": h_mm,
            "w_mm": w_mm,
            "area_mm2": area_mm2,
            "centroid_x": int(cx),
            "centroid_y": int(cy),
        })

    return measurements, paper_px
|
|
|
|
| |
| |
| |
def build_table_data(measurements, paper_px, counts):
    """Build the per-grain table and the summary-statistics table.

    Args:
        measurements: Dicts with ``label``, ``cls_name``, ``h_px``, ``w_px``,
            ``h_mm``, ``w_mm`` and ``area_mm2``.
        paper_px: Reference size in pixels, or ``None`` when no paper was
            found (values are then reported in pixels).
        counts: Mapping read for the ``"Full"`` and ``"Broken"`` counts.

    Returns:
        ``(grain_df, summary_df)`` as pandas DataFrames.  With no
        measurements the extreme rows read "Grain #0" / 0.00 and the mean
        rows show an em dash.
    """
    has_mm = paper_px is not None
    unit = "mm" if has_mm else "px"
    area_col = "Area (mm\u00b2)" if has_mm else "Area"

    rows = []
    for g in measurements:
        h_val = round(g["h_mm"], 2) if (has_mm and g["h_mm"]) else round(g["h_px"], 1)
        w_val = round(g["w_mm"], 2) if (has_mm and g["w_mm"]) else round(g["w_px"], 1)
        area_val = round(g["area_mm2"], 2) if g["area_mm2"] else None
        rows.append({
            "#": g["label"],
            "Class": g["cls_name"],
            f"Height ({unit})": h_val,
            f"Width ({unit})": w_val,
            area_col: area_val,
        })
    grain_df = pd.DataFrame(rows)

    h_key = "h_mm" if has_mm else "h_px"
    w_key = "w_mm" if has_mm else "w_px"
    heights = [(g["label"], g[h_key]) for g in measurements if g.get(h_key)]
    widths = [(g["label"], g[w_key]) for g in measurements if g.get(w_key)]

    def _extreme(pairs, pick):
        # (label, value) of the extreme entry, or a (0, 0) placeholder.
        return pick(pairs, key=lambda pair: pair[1]) if pairs else (0, 0)

    max_h = _extreme(heights, max)
    min_h = _extreme(heights, min)
    max_w = _extreme(widths, max)
    min_w = _extreme(widths, min)
    interval = (max_h[1] - min_h[1]) / 10.0 if (heights and max_h[1] != min_h[1]) else 0.0

    n_full = counts.get("Full", 0)
    n_broken = counts.get("Broken", 0)
    total = n_full + n_broken

    # Non-ASCII glyphs are written as escapes so they survive re-encoding of
    # the source file: green circle, red circle, check mark, cross, em dash.
    paper_status = f"\u2705 Found ({unit} mode)" if has_mm else "\u274C Not found (px only)"
    mean_h = f"{np.mean([v for _, v in heights]):.2f} {unit}" if heights else "\u2014"
    mean_w = f"{np.mean([v for _, v in widths]):.2f} {unit}" if widths else "\u2014"

    summary_rows = [
        {"Metric": "Total Grains", "Value": str(total)},
        {"Metric": "\U0001F7E2 Full Grains", "Value": str(n_full)},
        {"Metric": "\U0001F534 Broken Grains", "Value": str(n_broken)},
        {"Metric": "Paper Reference", "Value": paper_status},
        {"Metric": f"Max Height (Grain #{max_h[0]})", "Value": f"{max_h[1]:.2f} {unit}"},
        {"Metric": f"Min Height (Grain #{min_h[0]})", "Value": f"{min_h[1]:.2f} {unit}"},
        {"Metric": f"Max Width (Grain #{max_w[0]})", "Value": f"{max_w[1]:.2f} {unit}"},
        {"Metric": f"Min Width (Grain #{min_w[0]})", "Value": f"{min_w[1]:.2f} {unit}"},
        {"Metric": "Mean Height", "Value": mean_h},
        {"Metric": "Mean Width", "Value": mean_w},
        {"Metric": "Bin Interval (max-min)/10", "Value": f"{interval:.3f} {unit}"},
    ]
    summary_df = pd.DataFrame(summary_rows)
    return grain_df, summary_df
|
|
|
|
| |
| |
| |
def predict_stage1(image: Image.Image):
    """First UI stage: segment, count, and show placeholder tables.

    Returns a 5-tuple matching the UI outputs: zoomed segmentation image,
    one-line summary, markdown count table, and the same "calculating"
    placeholder DataFrame for both measurement tables.  With no image the
    result is ``(None, "", "", None, None)``.
    """
    if image is None:
        return None, "", "", None, None

    img_np = np.array(image.convert("RGB"))
    _, zoomed_pil, _, counts = run_segmentation(img_np)
    total = counts["Full"] + counts["Broken"]

    # Non-ASCII glyphs are written as escapes so they survive re-encoding of
    # the source file: check mark, middle dot, green/red circle, sheaf of
    # rice, hourglass.
    summary = (
        f"\u2705 {total} grains detected \u00b7 "
        f"\U0001F7E2 Full: {counts['Full']} \u00b7 "
        f"\U0001F534 Broken: {counts['Broken']}"
    )
    count_md = (
        f"| | Count |\n|---|---|\n"
        f"| \U0001F33E Total Grains | **{total}** |\n"
        f"| \U0001F7E2 Full Grains | **{counts['Full']}** |\n"
        f"| \U0001F534 Broken Grains | **{counts['Broken']}** |"
    )
    loading_df = pd.DataFrame(
        [{"Status": "\u23F3 Calculating height & width of all grains..."}]
    )
    return zoomed_pil, summary, count_md, loading_df, loading_df
|
|
|
|
def predict_stage2(image: Image.Image):
    """Second UI stage: segment, detect the paper, and fill the tables.

    Segmentation and paper detection are submitted to a two-worker thread
    pool; the grains are then measured and tabulated.

    Returns a 5-tuple matching the UI outputs: zoomed segmentation image,
    one-line summary (including whether measurements are in mm or px),
    markdown count table, per-grain DataFrame and summary DataFrame.  With
    no image the result is ``(None, "", "", None, None)``.
    """
    if image is None:
        return None, "", "", None, None

    img_np = np.array(image.convert("RGB"))

    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
        fut_seg = pool.submit(run_segmentation, img_np)
        fut_paper = pool.submit(detect_paper_pixels, img_np)
        _, zoomed_pil, grain_boxes, counts = fut_seg.result()
        paper_dims = fut_paper.result()

    measurements, paper_px = measure_grains_from_boxes(grain_boxes, img_np.shape, paper_dims)
    total = counts["Full"] + counts["Broken"]

    # Non-ASCII glyphs are written as escapes so they survive re-encoding of
    # the source file: check mark, middle dot, green/red circle, ruler,
    # warning sign, em dash, sheaf of rice.
    if paper_px:
        unit_note = " \u00b7 \U0001F4CF Paper found \u2014 measurements in mm"
    else:
        unit_note = " \u00b7 \u26A0\uFE0F No paper \u2014 measurements in px"
    summary = (
        f"\u2705 {total} grains detected \u00b7 "
        f"\U0001F7E2 Full: {counts['Full']} \u00b7 "
        f"\U0001F534 Broken: {counts['Broken']}"
        + unit_note
    )
    count_md = (
        f"| | Count |\n|---|---|\n"
        f"| \U0001F33E Total Grains | **{total}** |\n"
        f"| \U0001F7E2 Full Grains | **{counts['Full']}** |\n"
        f"| \U0001F534 Broken Grains | **{counts['Broken']}** |"
    )
    grain_df, summary_df = build_table_data(measurements, paper_px, counts)
    return zoomed_pil, summary, count_md, grain_df, summary_df
|
|
|
|
| |
| |
| |
| |
| |
# Gradio theme passed to demo.launch() at the bottom of the file.
THEME = gr.themes.Soft(
    primary_hue="violet",
    secondary_hue="indigo",
    neutral_hue="slate",
    font=gr.themes.GoogleFont("Inter"),
)


# Custom CSS passed to demo.launch(); the selectors target the elem_id
# values assigned to components in the Blocks layout.
CSS = """
#run-btn { margin-top: 6px; }
#status-box textarea { font-size: 0.92rem; }
#count-box { font-size: 0.95rem; }

/* Make both measurement tables tall enough to show all rows */
#grain-table .table-wrap,
#grain-table .svelte-table,
#summary-table .table-wrap,
#summary-table .svelte-table {
max-height: none !important;
overflow-y: visible !important;
}
#grain-table,
#summary-table {
overflow: visible !important;
}
"""
|
|
# UI layout: header, upload column + segmentation output, detection summary
# and counts, then the two measurement tables.  One button click runs
# predict_stage1 and then predict_stage2 against the same five outputs.
with gr.Blocks(title="GrainVision") as demo:

    # Non-ASCII glyphs are written as HTML entities / Python escapes so they
    # survive re-encoding of the source file.
    gr.HTML("""
    <div style="padding:18px 12px 10px 12px; background-color:#0F172A;
    border-radius:10px; margin-bottom:10px;">
    <span style="font-size:2rem;font-weight:900;color:#F1F5F9;font-family:sans-serif;">
    &#x1F33E; GrainVision
    </span>
    <p style="color:#CBD5E1;font-size:0.9rem;margin-top:4px;font-family:sans-serif;">
    Upload a rice image (with white 4&times;4 cm reference paper) to segment, classify,
    measure, and analyse grains.
    </p>
    </div>
    """)

    with gr.Row(equal_height=False):
        with gr.Column(scale=1):
            inp_image = gr.Image(type="pil", label="Upload Rice Image", height=280)
            # NOTE(review): the original button emoji could not be recovered
            # from the mis-encoded source; a magnifying glass is assumed.
            run_btn = gr.Button("\U0001F50D Analyse Grains",
                                variant="primary", size="lg", elem_id="run-btn")
            gr.Markdown("_Upload an image then press **Analyse**. "
                        "Segmentation appears first, measurements follow._")
            status_box = gr.Textbox(
                label="Status", value="", interactive=False,
                visible=True, max_lines=3, elem_id="status-box",
            )
            gr.Markdown("### Example Images _(click to load)_")
            gr.Examples(
                examples=[[p] for p in SAMPLE_PATHS],
                inputs=inp_image, label="", examples_per_page=6,
            )

        with gr.Column(scale=1):
            gr.Markdown("### Segmentation Output *(zoomed to grains)*")
            seg_out = gr.Image(label="", interactive=False)

    gr.Markdown("---")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("#### Detection Summary")
            # NOTE(review): shares elem_id "status-box" with the Status
            # textbox above, so the #status-box CSS rule styles both.
            summary_box = gr.Textbox(
                label="", value="", interactive=False,
                max_lines=2, elem_id="status-box",
            )
        with gr.Column(scale=1):
            gr.Markdown("#### Grain Count")
            count_md = gr.Markdown(
                value="| | Count |\n|---|---|\n"
                      "| \U0001F33E Total | \u2014 |\n"
                      "| \U0001F7E2 Full | \u2014 |\n"
                      "| \U0001F534 Broken | \u2014 |",
                elem_id="count-box",
            )

    gr.Markdown("---")
    gr.Markdown("### Grain Measurements Table")
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("#### Per-Grain Measurements")
            grain_table_out = gr.DataFrame(
                label="", interactive=False, wrap=False,
                elem_id="grain-table",
            )
        with gr.Column(scale=1):
            gr.Markdown("#### Summary Statistics")
            summary_table_out = gr.DataFrame(
                label="", interactive=False, wrap=False,
                elem_id="summary-table",
            )

    OUTPUTS = [seg_out, summary_box, count_md, grain_table_out, summary_table_out]

    # Stage 1 shows the segmentation quickly with placeholder tables;
    # stage 2 then replaces the same outputs with the measured results.
    run_btn.click(
        fn=predict_stage1, inputs=[inp_image], outputs=OUTPUTS,
    ).then(
        fn=predict_stage2, inputs=[inp_image], outputs=OUTPUTS,
    )
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio app with the module-level theme and CSS.
    # NOTE(review): assumes the installed Gradio release accepts theme/css
    # on launch() rather than on gr.Blocks() - confirm against the pinned version.
    demo.launch(theme=THEME, css=CSS)