Spaces:
Paused
Paused
| import gradio as gr | |
| import cv2 | |
| import numpy as np | |
| import io | |
| import os | |
| import zipfile | |
| import tempfile | |
| from PIL import Image | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| # βββ Cellpose model (lazy) ββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| _model = None | |
| def get_model(): | |
| global _model | |
| if _model is None: | |
| from cellpose import models | |
| from huggingface_hub import hf_hub_download | |
| fpath = hf_hub_download(repo_id="mouseland/cellpose-sam", filename="cpsam") | |
| _model = models.CellposeModel(gpu=False, pretrained_model=fpath) | |
| return _model | |
| # βββ Image helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def normalize99(img): | |
| X = img.copy().astype(np.float32) | |
| p1, p99 = np.percentile(X, 1), np.percentile(X, 99) | |
| return (X - p1) / (1e-10 + p99 - p1) | |
| def image_resize(img, resize=1000): | |
| ny, nx = img.shape[:2] | |
| if max(ny, nx) > resize: | |
| if ny > nx: | |
| nx = int(nx / ny * resize); ny = resize | |
| else: | |
| ny = int(ny / nx * resize); nx = resize | |
| img = cv2.resize(img, (nx, ny)) | |
| return img.astype(np.uint8) | |
| def run_cellpose(img, model, flow_threshold=0.4, cellprob_threshold=0.0): | |
| masks, flows, _ = model.eval( | |
| img, niter=250, | |
| flow_threshold=flow_threshold, | |
| cellprob_threshold=cellprob_threshold, | |
| ) | |
| return masks | |
| # βββ YOLO Annotation Exporter βββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def export_yolo_annotations(masks, img_shape, class_id=0): | |
| """ | |
| Converts Cellpose masks β YOLO segmentation format. | |
| YOLO segmentation line format: | |
| class_id x1 y1 x2 y2 ... (all normalized 0β1) | |
| class_id = 0 β 'grain' (you will split into broken/whole on Roboflow) | |
| """ | |
| h, w = img_shape[:2] | |
| lines = [] | |
| num_grains = int(masks.max()) | |
| for i in range(1, num_grains + 1): | |
| # Binary mask for this single grain | |
| single = (masks == i).astype(np.uint8) | |
| contours, _ = cv2.findContours(single, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) | |
| if not contours: | |
| continue | |
| # Pick the largest contour (in case of tiny noise) | |
| c = max(contours, key=cv2.contourArea) | |
| c = c.squeeze() | |
| if c.ndim < 2 or len(c) < 4: | |
| continue | |
| # Normalize each point to [0, 1] | |
| norm_pts = [] | |
| for x, y in c: | |
| norm_pts.append(round(float(x) / w, 6)) | |
| norm_pts.append(round(float(y) / h, 6)) | |
| pts_str = " ".join(map(str, norm_pts)) | |
| lines.append(f"{class_id} {pts_str}") | |
| return "\n".join(lines), num_grains | |
| def make_preview(img_np, masks): | |
| """Draw red outlines of all grain masks on the image for preview.""" | |
| preview = img_np.copy() | |
| for i in range(1, int(masks.max()) + 1): | |
| single = (masks == i).astype(np.uint8) | |
| contours, _ = cv2.findContours(single, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) | |
| cv2.drawContours(preview, contours, -1, (220, 38, 38), 2) | |
| return Image.fromarray(preview) | |
| # βββ Main batch processor βββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def process_batch(image_files, flow_threshold, cellprob_threshold, progress=gr.Progress()): | |
| """ | |
| Takes a list of uploaded image file paths. | |
| Returns: | |
| - Gallery of preview images (with outlines) | |
| - Summary text | |
| - Path to downloadable ZIP | |
| """ | |
| if not image_files: | |
| return [], "β οΈ No images uploaded.", None | |
| model = get_model() | |
| previews = [] # (PIL image, caption) for gallery | |
| log_lines = [] | |
| total_grains = 0 | |
| failed = [] | |
| # Temp folder to collect annotation files | |
| tmp_dir = tempfile.mkdtemp() | |
| images_dir = os.path.join(tmp_dir, "images") | |
| labels_dir = os.path.join(tmp_dir, "labels") | |
| os.makedirs(images_dir, exist_ok=True) | |
| os.makedirs(labels_dir, exist_ok=True) | |
| for idx, file_obj in enumerate(progress.tqdm(image_files, desc="Processing images")): | |
| # file_obj is a filepath string when using gr.File with type="filepath" | |
| filepath = file_obj if isinstance(file_obj, str) else file_obj.name | |
| fname = os.path.splitext(os.path.basename(filepath))[0] | |
| try: | |
| pil_img = Image.open(filepath).convert("RGB") | |
| img_np = np.array(pil_img) | |
| img_np = image_resize(img_np, resize=1000) | |
| masks = run_cellpose(img_np, model, | |
| flow_threshold=float(flow_threshold), | |
| cellprob_threshold=float(cellprob_threshold)) | |
| num_grains = int(masks.max()) | |
| if num_grains == 0: | |
| log_lines.append(f"β οΈ [{idx+1}] {fname} β No grains detected, skipped.") | |
| failed.append(fname) | |
| continue | |
| # Export YOLO annotation txt | |
| annotation_txt, _ = export_yolo_annotations(masks, img_np.shape, class_id=0) | |
| txt_path = os.path.join(labels_dir, f"{fname}.txt") | |
| with open(txt_path, "w") as f: | |
| f.write(annotation_txt) | |
| # Save image to images/ | |
| img_save_path = os.path.join(images_dir, f"{fname}.jpg") | |
| Image.fromarray(img_np).save(img_save_path, quality=95) | |
| # Make preview | |
| preview_pil = make_preview(img_np, masks) | |
| previews.append((preview_pil, f"{fname} β {num_grains} grains")) | |
| total_grains += num_grains | |
| log_lines.append(f"β [{idx+1}] {fname} β {num_grains} grains annotated.") | |
| except Exception as e: | |
| log_lines.append(f"β [{idx+1}] {fname} β Error: {str(e)}") | |
| failed.append(fname) | |
| # ββ Write data.yaml βββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| yaml_content = ( | |
| "# YOLO Dataset β Rice Grain Segmentation\n" | |
| "# Generated by MLBench Annotation Tool\n\n" | |
| "path: ./dataset\n" | |
| "train: images/train\n" | |
| "val: images/val\n\n" | |
| "nc: 2\n" | |
| "names:\n" | |
| " 0: whole_grain\n" | |
| " 1: broken_grain\n\n" | |
| "# NOTE: All grains are currently class 0 (whole_grain).\n" | |
| "# Upload to Roboflow and re-label broken grains as class 1.\n" | |
| ) | |
| with open(os.path.join(tmp_dir, "data.yaml"), "w") as f: | |
| f.write(yaml_content) | |
| # ββ Write README ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| readme = ( | |
| "# Rice Grain YOLO Dataset\n\n" | |
| "## Folder Structure\n" | |
| "```\n" | |
| "dataset/\n" | |
| " images/ β your rice photos (.jpg)\n" | |
| " labels/ β YOLO polygon annotations (.txt)\n" | |
| " data.yaml β class config for YOLO training\n" | |
| "```\n\n" | |
| "## Label Format (YOLO Segmentation)\n" | |
| "Each .txt file has one line per grain:\n" | |
| "```\n" | |
| "class_id x1 y1 x2 y2 x3 y3 ... (normalized 0β1)\n" | |
| "```\n\n" | |
| "## Classes\n" | |
| "| ID | Name |\n" | |
| "|----|-------------|\n" | |
| "| 0 | whole_grain |\n" | |
| "| 1 | broken_grain |\n\n" | |
| "## Next Steps\n" | |
| "1. Upload this zip to **Roboflow** (Import > YOLOv8 Segmentation format)\n" | |
| "2. Re-label broken grains as class `1` in Roboflow\n" | |
| "3. Export from Roboflow as YOLOv8 format\n" | |
| "4. Train: `yolo segment train data=data.yaml model=yolov8n-seg.pt epochs=100`\n" | |
| ) | |
| with open(os.path.join(tmp_dir, "README.md"), "w") as f: | |
| f.write(readme) | |
| # ββ Package as ZIP ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| zip_path = os.path.join(tempfile.mkdtemp(), "rice_yolo_dataset.zip") | |
| with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf: | |
| for root, _, files in os.walk(tmp_dir): | |
| for file in files: | |
| full_path = os.path.join(root, file) | |
| arcname = os.path.relpath(full_path, tmp_dir) | |
| zf.write(full_path, arcname) | |
| # ββ Summary βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| ok_count = len(image_files) - len(failed) | |
| summary = ( | |
| f"### β Done!\n" | |
| f"- **{ok_count} / {len(image_files)}** images processed\n" | |
| f"- **{total_grains}** total grains annotated\n" | |
| f"- **{len(failed)}** failed: {', '.join(failed) if failed else 'none'}\n\n" | |
| "**Download the ZIP below β upload to Roboflow β label broken grains β train YOLO!**\n\n" | |
| "---\n" + "\n".join(log_lines) | |
| ) | |
| return previews, summary, zip_path | |
| # βββ UI βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| CSS = """ | |
| body { font-family: 'IBM Plex Mono', monospace; } | |
| #header { | |
| background: #0F172A; | |
| padding: 20px 24px 14px; | |
| border-radius: 10px; | |
| margin-bottom: 12px; | |
| } | |
| #run-btn { margin-top: 8px; background: #7C3AED !important; } | |
| #dl-btn { margin-top: 6px; } | |
| .gr-gallery-item img { border-radius: 6px; } | |
| """ | |
| THEME = gr.themes.Soft( | |
| primary_hue="violet", | |
| secondary_hue="indigo", | |
| neutral_hue="slate", | |
| ) | |
| with gr.Blocks(theme=THEME, css=CSS, title="Rice YOLO Annotator") as demo: | |
| gr.HTML(""" | |
| <div id="header"> | |
| <span style="font-size:1.9rem;font-weight:900;color:#F1F5F9;font-family:monospace;"> | |
| ML<span style="color:#EF4444;">Bench</span> | |
| <span style="font-size:1rem;font-weight:400;color:#94A3B8;margin-left:12px;"> | |
| Rice Grain β YOLO Annotation Exporter | |
| </span> | |
| </span> | |
| <p style="color:#64748B;font-size:0.85rem;margin-top:6px;font-family:monospace;"> | |
| Upload up to 50 images Β· Cellpose segments each grain Β· | |
| Download ZIP with YOLO labels ready for Roboflow | |
| </p> | |
| </div> | |
| """) | |
| with gr.Row(): | |
| # ββ LEFT ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| with gr.Column(scale=1): | |
| gr.Markdown("### π Upload Images") | |
| image_input = gr.File( | |
| file_count="multiple", | |
| file_types=["image"], | |
| label="Drop up to 50 rice images here", | |
| height=180, | |
| ) | |
| with gr.Accordion("βοΈ Cellpose Settings", open=False): | |
| flow_thresh = gr.Slider( | |
| 0.0, 1.0, value=0.4, step=0.05, | |
| label="Flow Threshold", | |
| info="Higher = stricter (fewer false grains)" | |
| ) | |
| cellprob_thresh = gr.Slider( | |
| -4.0, 4.0, value=0.0, step=0.5, | |
| label="Cell Probability Threshold", | |
| info="Lower = detect more grains" | |
| ) | |
| run_btn = gr.Button( | |
| "π Run Cellpose & Export Annotations", | |
| variant="primary", size="lg", elem_id="run-btn" | |
| ) | |
| gr.Markdown(""" | |
| ### π Workflow | |
| 1. Upload 50 images here | |
| 2. Click **Run** β Cellpose segments every grain | |
| 3. Download the ZIP | |
| 4. Upload ZIP to **Roboflow** (format: YOLOv8 Segmentation) | |
| 5. Re-label broken grains as `broken_grain` class | |
| 6. Export & train YOLOv8! | |
| """) | |
| download_btn = gr.File( | |
| label="β¬οΈ Download YOLO Dataset ZIP", | |
| interactive=False, | |
| elem_id="dl-btn", | |
| ) | |
| # ββ RIGHT βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| with gr.Column(scale=2): | |
| gr.Markdown("### π Segmentation Previews") | |
| gallery = gr.Gallery( | |
| label="", | |
| show_label=False, | |
| columns=3, | |
| height=460, | |
| object_fit="contain", | |
| ) | |
| summary_box = gr.Markdown( | |
| value="_Results will appear here after processing..._" | |
| ) | |
| run_btn.click( | |
| fn=process_batch, | |
| inputs=[image_input, flow_thresh, cellprob_thresh], | |
| outputs=[gallery, summary_box, download_btn], | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch(share=True) |