import gradio as gr import cv2 import numpy as np import io import os import zipfile import tempfile from PIL import Image import matplotlib matplotlib.use("Agg") # ─── Cellpose model (lazy) ──────────────────────────────────────────────────── _model = None def get_model(): global _model if _model is None: from cellpose import models from huggingface_hub import hf_hub_download fpath = hf_hub_download(repo_id="mouseland/cellpose-sam", filename="cpsam") _model = models.CellposeModel(gpu=False, pretrained_model=fpath) return _model # ─── Image helpers ──────────────────────────────────────────────────────────── def normalize99(img): X = img.copy().astype(np.float32) p1, p99 = np.percentile(X, 1), np.percentile(X, 99) return (X - p1) / (1e-10 + p99 - p1) def image_resize(img, resize=1000): ny, nx = img.shape[:2] if max(ny, nx) > resize: if ny > nx: nx = int(nx / ny * resize); ny = resize else: ny = int(ny / nx * resize); nx = resize img = cv2.resize(img, (nx, ny)) return img.astype(np.uint8) def run_cellpose(img, model, flow_threshold=0.4, cellprob_threshold=0.0): masks, flows, _ = model.eval( img, niter=250, flow_threshold=flow_threshold, cellprob_threshold=cellprob_threshold, ) return masks # ─── YOLO Annotation Exporter ───────────────────────────────────────────────── def export_yolo_annotations(masks, img_shape, class_id=0): """ Converts Cellpose masks → YOLO segmentation format. YOLO segmentation line format: class_id x1 y1 x2 y2 ... (all normalized 0–1) class_id = 0 → 'grain' (you will split into broken/whole on Roboflow) """ h, w = img_shape[:2] lines = [] num_grains = int(masks.max()) for i in range(1, num_grains + 1): # Binary mask for this single grain single = (masks == i).astype(np.uint8) contours, _ = cv2.findContours(single, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) if not contours: continue # Pick the largest contour (in case of tiny noise) c = max(contours, key=cv2.contourArea) c = c.squeeze() if c.ndim < 2 or len(c) < 4: continue # Normalize each point to [0, 1] norm_pts = [] for x, y in c: norm_pts.append(round(float(x) / w, 6)) norm_pts.append(round(float(y) / h, 6)) pts_str = " ".join(map(str, norm_pts)) lines.append(f"{class_id} {pts_str}") return "\n".join(lines), num_grains def make_preview(img_np, masks): """Draw red outlines of all grain masks on the image for preview.""" preview = img_np.copy() for i in range(1, int(masks.max()) + 1): single = (masks == i).astype(np.uint8) contours, _ = cv2.findContours(single, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) cv2.drawContours(preview, contours, -1, (220, 38, 38), 2) return Image.fromarray(preview) # ─── Main batch processor ───────────────────────────────────────────────────── def process_batch(image_files, flow_threshold, cellprob_threshold, progress=gr.Progress()): """ Takes a list of uploaded image file paths. Returns: - Gallery of preview images (with outlines) - Summary text - Path to downloadable ZIP """ if not image_files: return [], "⚠️ No images uploaded.", None model = get_model() previews = [] # (PIL image, caption) for gallery log_lines = [] total_grains = 0 failed = [] # Temp folder to collect annotation files tmp_dir = tempfile.mkdtemp() images_dir = os.path.join(tmp_dir, "images") labels_dir = os.path.join(tmp_dir, "labels") os.makedirs(images_dir, exist_ok=True) os.makedirs(labels_dir, exist_ok=True) for idx, file_obj in enumerate(progress.tqdm(image_files, desc="Processing images")): # file_obj is a filepath string when using gr.File with type="filepath" filepath = file_obj if isinstance(file_obj, str) else file_obj.name fname = os.path.splitext(os.path.basename(filepath))[0] try: pil_img = Image.open(filepath).convert("RGB") img_np = np.array(pil_img) img_np = image_resize(img_np, resize=1000) masks = run_cellpose(img_np, model, flow_threshold=float(flow_threshold), cellprob_threshold=float(cellprob_threshold)) num_grains = int(masks.max()) if num_grains == 0: log_lines.append(f"⚠️ [{idx+1}] {fname} — No grains detected, skipped.") failed.append(fname) continue # Export YOLO annotation txt annotation_txt, _ = export_yolo_annotations(masks, img_np.shape, class_id=0) txt_path = os.path.join(labels_dir, f"{fname}.txt") with open(txt_path, "w") as f: f.write(annotation_txt) # Save image to images/ img_save_path = os.path.join(images_dir, f"{fname}.jpg") Image.fromarray(img_np).save(img_save_path, quality=95) # Make preview preview_pil = make_preview(img_np, masks) previews.append((preview_pil, f"{fname} — {num_grains} grains")) total_grains += num_grains log_lines.append(f"✅ [{idx+1}] {fname} — {num_grains} grains annotated.") except Exception as e: log_lines.append(f"❌ [{idx+1}] {fname} — Error: {str(e)}") failed.append(fname) # ── Write data.yaml ─────────────────────────────────────────────────────── yaml_content = ( "# YOLO Dataset — Rice Grain Segmentation\n" "# Generated by MLBench Annotation Tool\n\n" "path: ./dataset\n" "train: images/train\n" "val: images/val\n\n" "nc: 2\n" "names:\n" " 0: whole_grain\n" " 1: broken_grain\n\n" "# NOTE: All grains are currently class 0 (whole_grain).\n" "# Upload to Roboflow and re-label broken grains as class 1.\n" ) with open(os.path.join(tmp_dir, "data.yaml"), "w") as f: f.write(yaml_content) # ── Write README ────────────────────────────────────────────────────────── readme = ( "# Rice Grain YOLO Dataset\n\n" "## Folder Structure\n" "```\n" "dataset/\n" " images/ ← your rice photos (.jpg)\n" " labels/ ← YOLO polygon annotations (.txt)\n" " data.yaml ← class config for YOLO training\n" "```\n\n" "## Label Format (YOLO Segmentation)\n" "Each .txt file has one line per grain:\n" "```\n" "class_id x1 y1 x2 y2 x3 y3 ... (normalized 0–1)\n" "```\n\n" "## Classes\n" "| ID | Name |\n" "|----|-------------|\n" "| 0 | whole_grain |\n" "| 1 | broken_grain |\n\n" "## Next Steps\n" "1. Upload this zip to **Roboflow** (Import > YOLOv8 Segmentation format)\n" "2. Re-label broken grains as class `1` in Roboflow\n" "3. Export from Roboflow as YOLOv8 format\n" "4. Train: `yolo segment train data=data.yaml model=yolov8n-seg.pt epochs=100`\n" ) with open(os.path.join(tmp_dir, "README.md"), "w") as f: f.write(readme) # ── Package as ZIP ──────────────────────────────────────────────────────── zip_path = os.path.join(tempfile.mkdtemp(), "rice_yolo_dataset.zip") with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf: for root, _, files in os.walk(tmp_dir): for file in files: full_path = os.path.join(root, file) arcname = os.path.relpath(full_path, tmp_dir) zf.write(full_path, arcname) # ── Summary ─────────────────────────────────────────────────────────────── ok_count = len(image_files) - len(failed) summary = ( f"### ✅ Done!\n" f"- **{ok_count} / {len(image_files)}** images processed\n" f"- **{total_grains}** total grains annotated\n" f"- **{len(failed)}** failed: {', '.join(failed) if failed else 'none'}\n\n" "**Download the ZIP below → upload to Roboflow → label broken grains → train YOLO!**\n\n" "---\n" + "\n".join(log_lines) ) return previews, summary, zip_path # ─── UI ─────────────────────────────────────────────────────────────────────── CSS = """ body { font-family: 'IBM Plex Mono', monospace; } #header { background: #0F172A; padding: 20px 24px 14px; border-radius: 10px; margin-bottom: 12px; } #run-btn { margin-top: 8px; background: #7C3AED !important; } #dl-btn { margin-top: 6px; } .gr-gallery-item img { border-radius: 6px; } """ THEME = gr.themes.Soft( primary_hue="violet", secondary_hue="indigo", neutral_hue="slate", ) with gr.Blocks(theme=THEME, css=CSS, title="Rice YOLO Annotator") as demo: gr.HTML("""
Upload up to 50 images · Cellpose segments each grain · Download ZIP with YOLO labels ready for Roboflow