Spaces:
Sleeping
Sleeping
| #region Imports | |
| import os | |
| # Route caches/configs to /tmp to avoid filling persistent storage and suppress permission warnings | |
| os.environ.setdefault("HF_HOME", "/tmp/hf_home") | |
| os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/hf_home/transformers") | |
| os.environ.setdefault("HF_HUB_CACHE", "/tmp/hf_home/hub") | |
| os.environ.setdefault("TORCH_HOME", "/tmp/torch_home") | |
| os.environ.setdefault("PIP_DISABLE_PIP_VERSION_CHECK", "1") | |
| os.environ.setdefault("YOLO_CONFIG_DIR", "/tmp/Ultralytics") | |
| import cv2 | |
| import time | |
| import shutil | |
| import numpy as np | |
| import gradio as gr | |
| from sahi import AutoDetectionModel | |
| from sahi.predict import get_sliced_prediction | |
| # Try to import ultralytics for native segmentation | |
| try: | |
| from ultralytics import YOLO | |
| _ULTRA_OK = True | |
| except Exception: | |
| _ULTRA_OK = False | |
| #endregion | |
#region Config and setup
# Boxes whose longer side exceeds this many pixels are dropped (e.g. 70 or 80).
# The filter only applies when the value is > 0; use -1 to disable it.
MAX_SIDE_PX = 80
# Default opacity of the bunch-mask fill (0 = invisible, 1 = opaque).
SEG_DEFAULT_ALPHA = 0.45

# Drawing colours picked to stand out against green foliage.
# OpenCV draws in BGR order, so each tuple is (blue, green, red).
BERRIES_COLOR_BGR = (255, 0, 255)  # magenta: berry detection rectangles
BUNCHES_FILL_COLOR_BGR = (255, 255, 0)  # cyan: bunch mask fill
BUNCHES_CONTOUR_COLOR_BGR = (255, 255, 255)  # white: bunch mask outlines

# Weight files are fixed paths with no UI controls. To make them editable,
# add Textbox components and pass them to the handlers as inputs.
WEIGHTS_DETECTION = "weights/berries.pt"
WEIGHTS_SEGMENTATION = "weights/bunches.pt"

# Module-level caches so that models are not reloaded on every click.
_DET_MODEL_CACHE = {}  # (weights_path, device) -> SAHI AutoDetectionModel
_SEG_MODEL_CACHE = {}  # weights_path -> ultralytics YOLO
#endregion
#region Model and device handling
def _choose_device(user_choice: str) -> str:
    """Turn the device dropdown value into a concrete device string.

    Any explicit choice (e.g. "cpu" or "cuda:0") is returned as-is.
    For "auto", the result is "cuda:0" when torch can be imported and
    reports CUDA as available. Otherwise it is "cpu", including when the
    import or the availability query raises.
    """
    if user_choice == "auto":
        try:
            import torch
            cuda_ready = torch.cuda.is_available()
        except Exception:
            cuda_ready = False
        return "cuda:0" if cuda_ready else "cpu"
    return user_choice
def _get_det_model(weights_path: str, device: str, conf: float):
    """Return a cached SAHI AutoDetectionModel for ``weights_path`` on ``device``.

    Models are memoised in ``_DET_MODEL_CACHE`` under ``(weights_path, device)``.

    - On a cache miss, the model is built on ``device``. If that raises, it
      is rebuilt on "cpu" and still stored under the requested device key.
    - On a cache hit, ``conf`` is written to the model's
      ``confidence_threshold``, so slider changes take effect without a
      reload. A failure of that assignment is ignored.

    Raises:
        gr.Error: if ``weights_path`` does not exist.
    """
    if not os.path.exists(weights_path):
        raise gr.Error(f"Detection weights not found: {weights_path}")

    cache_key = (weights_path, device)
    cached = _DET_MODEL_CACHE.get(cache_key)
    if cached is not None:
        # Refresh the threshold on the reused model (best effort).
        try:
            cached.confidence_threshold = float(conf)
        except Exception:
            pass
        return cached

    def _load(target_device):
        return AutoDetectionModel.from_pretrained(
            model_type="yolo11",
            model_path=weights_path,
            confidence_threshold=conf,
            device=target_device,
        )

    try:
        fresh = _load(device)
    except Exception:
        # Loading on the requested device failed: fall back to CPU.
        fresh = _load("cpu")
    _DET_MODEL_CACHE[cache_key] = fresh
    return fresh
def _get_seg_model(weights_path: str):
    """Return the ultralytics YOLO segmentation model, loading it once per path.

    Loaded models are kept in ``_SEG_MODEL_CACHE`` keyed by ``weights_path``.

    Raises:
        gr.Error: if ultralytics could not be imported, or if
            ``weights_path`` does not exist.
    """
    if not _ULTRA_OK:
        raise gr.Error("Ultralytics not found, please install it with: pip install ultralytics")
    if not os.path.exists(weights_path):
        raise gr.Error(f"Segmentation weights not found: {weights_path}")

    seg_model = _SEG_MODEL_CACHE.get(weights_path)
    if seg_model is None:
        seg_model = _SEG_MODEL_CACHE[weights_path] = YOLO(weights_path)
    return seg_model
#endregion
#region Inference
def _sahi_predict(image_rgb: np.ndarray, det_model, slice_h, slice_w, overlap_h, overlap_w):
    """Run SAHI sliced inference of ``det_model`` over ``image_rgb``.

    Slice sizes are coerced to int and overlap ratios to float. Postprocessing
    stays class-aware and SAHI's verbose output is silenced. SAHI's
    prediction result is returned unchanged.
    """
    slicing = dict(
        slice_height=int(slice_h),
        slice_width=int(slice_w),
        overlap_height_ratio=float(overlap_h),
        overlap_width_ratio=float(overlap_w),
    )
    return get_sliced_prediction(
        image_rgb,
        det_model,
        postprocess_class_agnostic=False,
        verbose=0,
        **slicing,
    )
def run_det(
    image, state,
    conf_det, slice_h, slice_w, overlap_h, overlap_w, device
):
    """Run model A (berry detection via SAHI) and refresh only the 'det' overlay.

    ``image`` is unused; the image analysed is ``state["base"]``. The new
    overlay is stored in ``state["det"]`` with the current timestamp. It is
    then composited with any existing 'seg' layer in timestamp order, so the
    newest layer is drawn on top.

    Returns:
        (composite image, updated state, detection counts text,
        segmentation counts text)

    Raises:
        gr.Error: if no image is loaded, if the detection weights are
            missing, or if sliced inference fails.
    """
    if state is None or state.get("base") is None:
        raise gr.Error("Loading an image is required before inference.")
    base = state["base"]

    # Basic auto-optimisation: if the whole image fits in a single tile,
    # drop the overlap to speed things up.
    H, W = base.shape[:2]
    if H <= slice_h and W <= slice_w:
        overlap_h, overlap_w = 0.0, 0.0

    det_model = _get_det_model(WEIGHTS_DETECTION, _choose_device(device), conf_det)
    try:
        sahi_res = _sahi_predict(base, det_model, slice_h, slice_w, overlap_h, overlap_w)
    except Exception as e:
        # Report inference failures as a readable UI error, the same way
        # run_seg reports segmentation failures. Plain exceptions are shown
        # by Gradio without their message.
        raise gr.Error(f"Error in detection inference: {e}") from e

    # No target-class highlighting in this simplified app.
    overlay_rgb, alpha, counts = _draw_boxes_overlay(base, sahi_res, target_class="", use_target=False)
    state["det"] = {"overlay": overlay_rgb, "alpha": alpha, "ts": time.time()}
    state["det_counts"] = counts

    composite = _composite_layers(base, [state["det"], state.get("seg")])
    return composite, state, state["det_counts"], state.get("seg_counts", "")
def run_seg(
    image, state,
    conf_seg, device, seg_alpha
):
    """Run model B (bunch segmentation) and refresh only the 'seg' overlay.

    ``image`` is unused; the model runs on ``state["base"]``. The new layer is
    timestamped and composited with any existing 'det' layer, newest on top.

    Returns:
        (composite image, updated state, detection counts text,
        segmentation counts text)

    Raises:
        gr.Error: if no image is loaded, if the weights or ultralytics are
            unavailable, or if prediction fails.
    """
    if state is None or state.get("base") is None:
        raise gr.Error("Loading an image is required before inference.")
    base_img = state["base"]

    model = _get_seg_model(WEIGHTS_SEGMENTATION)
    try:
        outputs = model.predict(source=base_img, conf=float(conf_seg), device=_choose_device(device), verbose=False)
        # predict() usually yields a list with one Results object per image.
        first = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
    except Exception as e:
        raise gr.Error(f"Error in segmentation inference: {e}")

    # No target-class highlighting in this simplified app.
    overlay_rgb, alpha_mask, seg_counts = _draw_seg_overlay(
        base_img, first, target_class="", use_target=False, fill_alpha=float(seg_alpha)
    )
    state["seg"] = {"overlay": overlay_rgb, "alpha": alpha_mask, "ts": time.time()}
    state["seg_counts"] = seg_counts

    composite = _composite_layers(base_img, [state.get("det"), state["seg"]])
    return composite, state, state.get("det_counts", ""), seg_counts
#endregion
#region Draw
def _ensure_rgb(img: np.ndarray) -> np.ndarray:
    """Return ``img`` as a 3-channel RGB array.

    - ``None`` is passed through.
    - 2-D grayscale arrays are expanded with ``cv2.COLOR_GRAY2RGB``.
    - Single-channel ``(H, W, 1)`` arrays are expanded by repeating the
      channel. Without this they would be returned as 1-channel images,
      which the 3-channel drawing/inference code cannot use.
    - 4-channel arrays are treated as RGBA and the alpha channel is dropped.
    - Anything else is returned unchanged.
    """
    if img is None:
        return None
    if img.ndim == 2:
        return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    if img.shape[2] == 1:
        # Same result as GRAY2RGB: copy the single channel into R, G and B.
        return np.repeat(img, 3, axis=2)
    if img.shape[2] == 4:
        return cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
    return img
def _draw_boxes_overlay(image_rgb: np.ndarray, sahi_result, target_class: str, use_target: bool):
    """Draw SAHI detections as rectangle outlines on a transparent layer.

    Filtering rules:
    - Boxes are clipped to the image.
    - Boxes with zero width/height or non-positive area are skipped.
    - Boxes whose longer side exceeds ``MAX_SIDE_PX`` are skipped, when that
      constant is > 0.

    No labels are drawn, and every box uses ``BERRIES_COLOR_BGR``.
    ``target_class``/``use_target`` only affect the counts text.

    Returns:
        overlay_rgb: (H, W, 3) uint8 layer with the outlines, in RGB order.
        alpha: (H, W) uint8 mask, 255 on the outlines and 0 elsewhere.
        counts: e.g. "total: 12", or "target='<name>': k | total: n" when
            ``use_target`` is set.
    """
    height, width = image_rgb.shape[:2]
    layer_bgr = np.zeros((height, width, 3), dtype=np.uint8)
    mask = np.zeros((height, width), dtype=np.uint8)
    n_total = 0
    n_target = 0

    def _clip(value, upper):
        # Clamp a pixel coordinate into [0, upper].
        return max(0, min(value, upper))

    for pred in getattr(sahi_result, "object_prediction_list", []) or []:
        try:
            x1, y1, x2, y2 = map(int, pred.bbox.to_xyxy())
        except Exception:
            # Fallback when to_xyxy() is unavailable or fails: read the corners directly.
            x1, y1 = int(getattr(pred.bbox, "minx", 0)), int(getattr(pred.bbox, "miny", 0))
            x2, y2 = int(getattr(pred.bbox, "maxx", 0)), int(getattr(pred.bbox, "maxy", 0))

        # Clip into the image, then order each corner pair so min comes first.
        x1, x2 = sorted((_clip(x1, width - 1), _clip(x2, width - 1)))
        y1, y2 = sorted((_clip(y1, height - 1), _clip(y2, height - 1)))
        box_w, box_h = x2 - x1, y2 - y1
        if box_w == 0 or box_h == 0:
            continue
        if MAX_SIDE_PX > 0 and max(box_w, box_h) > MAX_SIDE_PX:
            continue

        # bbox.area may be a plain value or a method; fall back to w*h.
        area_attr = getattr(pred.bbox, "area", box_w * box_h)
        try:
            area_val = float(area_attr() if callable(area_attr) else area_attr)
        except Exception:
            area_val = float(box_w * box_h)
        if area_val <= 0:
            continue

        label = getattr(pred.category, "name", "unknown")
        cv2.rectangle(layer_bgr, (x1, y1), (x2, y2), BERRIES_COLOR_BGR, 2)
        cv2.rectangle(mask, (x1, y1), (x2, y2), 255, 2)
        n_total += 1
        if use_target and label == target_class:
            n_target += 1

    overlay_rgb = cv2.cvtColor(layer_bgr, cv2.COLOR_BGR2RGB)
    if use_target:
        counts = f"target='{target_class}': {n_target} | total: {n_total}"
    else:
        counts = f"total: {n_total}"
    return overlay_rgb, mask, counts
def _letterboxed_mask_to_image(mask: np.ndarray, H: int, W: int) -> np.ndarray:
    """Map a binary (h, w) mask from model-input size back to image size (H, W).

    With ultralytics' default ``retina_masks=False``, masks come back at the
    network-input resolution. That input is the image scaled by one gain and
    padded symmetrically (letterboxed). The padding is cropped off before the
    nearest-neighbour resize, so the mask lines up with the image instead of
    being stretched. The crop arithmetic mirrors ultralytics'
    ``ops.scale_image``. A mask already at (H, W) is returned unchanged.
    """
    mh, mw = mask.shape[:2]
    if (mh, mw) == (H, W):
        return mask
    gain = min(mh / H, mw / W)
    pad_h = (mh - H * gain) / 2
    pad_w = (mw - W * gain) / 2
    top, left = int(round(pad_h - 0.1)), int(round(pad_w - 0.1))
    bottom, right = mh - int(round(pad_h + 0.1)), mw - int(round(pad_w + 0.1))
    cropped = mask[top:bottom, left:right]
    if cropped.size == 0:
        # Degenerate geometry: fall back to resizing the whole mask.
        cropped = mask
    return cv2.resize(cropped, (W, H), interpolation=cv2.INTER_NEAREST)


def _outline_seg_box(overlay_bgr: np.ndarray, alpha: np.ndarray, boxes, i: int) -> None:
    """Best effort: draw an opaque outline of box ``i`` when its mask is unusable.

    Any failure (missing coordinates, drawing error) is ignored, so one bad
    instance cannot abort the whole overlay.
    """
    try:
        x1, y1, x2, y2 = map(int, boxes.xyxy[i].detach().cpu().numpy().astype(int))
        cv2.rectangle(overlay_bgr, (x1, y1), (x2, y2), BUNCHES_CONTOUR_COLOR_BGR, 2)
        cv2.rectangle(alpha, (x1, y1), (x2, y2), 255, 2)
    except Exception:
        pass


def _draw_seg_overlay(image_rgb: np.ndarray, yolo_result, target_class: str, use_target: bool, fill_alpha: float = SEG_DEFAULT_ALPHA):
    """Draw an ultralytics segmentation result on a transparent layer.

    Each instance mask is handled as follows:
    - It is mapped back to image size (letterbox padding removed).
    - It is filled with ``BUNCHES_FILL_COLOR_BGR`` at opacity ``fill_alpha``
      (clamped to [0, 1]).
    - It is outlined with an opaque ``BUNCHES_CONTOUR_COLOR_BGR`` contour.

    Instances without a usable mask get an opaque bounding-box outline
    instead. Colours do not depend on the class; ``target_class`` and
    ``use_target`` only affect the counts text.

    Returns:
        overlay_rgb: (H, W, 3) uint8 layer, in RGB order.
        alpha: (H, W) uint8 per-pixel opacity of the layer.
        counts: e.g. "total: 3", or "target='<name>': k | total: n" when
            ``use_target`` is set.
    """
    H, W = image_rgb.shape[:2]
    overlay_bgr = np.zeros((H, W, 3), dtype=np.uint8)
    alpha = np.zeros((H, W), dtype=np.uint8)

    names = getattr(yolo_result, "names", None)
    boxes = getattr(yolo_result, "boxes", None)
    masks = getattr(yolo_result, "masks", None)

    if boxes is None or len(boxes) == 0:
        counts = f"target='{target_class}': 0 | total: 0" if use_target else "total: 0"
        return cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB), alpha, counts

    # Presumably a torch tensor of shape [N, h, w]. It is None when the
    # result carries no masks (getattr on None also yields None).
    mask_data = getattr(masks, "data", None)
    fa255 = int(max(0.0, min(1.0, float(fill_alpha))) * 255)

    N = len(boxes)
    target_count = 0
    for i in range(N):
        try:
            cls_idx = int(boxes.cls[i].item())
        except Exception:
            cls_idx = -1
        cls_name = names.get(cls_idx, str(cls_idx)) if isinstance(names, dict) else str(cls_idx)
        if use_target and cls_name == target_class:
            target_count += 1

        mask_drawn = False
        if mask_data is not None and i < len(mask_data):
            try:
                m = (mask_data[i].detach().cpu().numpy() > 0.5).astype(np.uint8)
                m = _letterboxed_mask_to_image(m, H, W)
                inside = m == 1
                overlay_bgr[inside] = BUNCHES_FILL_COLOR_BGR
                alpha[inside] = np.maximum(alpha[inside], fa255)
                cnts, _ = cv2.findContours(m, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                cv2.drawContours(overlay_bgr, cnts, -1, BUNCHES_CONTOUR_COLOR_BGR, 2)
                cv2.drawContours(alpha, cnts, -1, 255, 2)
                mask_drawn = True
            except Exception:
                # Deliberate best effort: fall back to the box outline below.
                mask_drawn = False
        if not mask_drawn:
            _outline_seg_box(overlay_bgr, alpha, boxes, i)

    overlay_rgb = cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB)
    counts = f"target='{target_class}': {target_count} | total: {N}" if use_target else f"total: {N}"
    return overlay_rgb, alpha, counts
def _composite_layers(base_rgb: np.ndarray, layers: list):
    """Alpha-blend overlay layers onto ``base_rgb`` and return a uint8 image.

    Each entry of ``layers`` is either ``None`` (skipped) or a dict with:
        'overlay': (H, W, 3) RGB layer
        'alpha':   (H, W) uint8 per-pixel opacity, 0..255
        'ts':      float timestamp

    Layers are applied in ascending 'ts' order (stable for ties), so the
    newest one ends up on top. A layer whose size differs from the canvas is
    bilinearly resized first. Returns ``None`` when ``base_rgb`` is ``None``.
    """
    if base_rgb is None:
        return None
    canvas = base_rgb.astype(np.float32)
    visible = (item for item in layers if item is not None)
    for layer in sorted(visible, key=lambda entry: entry["ts"]):
        color = layer["overlay"].astype(np.float32)
        weight = np.expand_dims(layer["alpha"].astype(np.float32) / 255.0, axis=-1)
        canvas_hw = canvas.shape[:2]
        if color.shape[:2] != canvas_hw:
            size_wh = (canvas.shape[1], canvas.shape[0])
            color = cv2.resize(color, size_wh, interpolation=cv2.INTER_LINEAR)
            # cv2.resize drops the trailing singleton axis; restore it.
            weight = cv2.resize(weight, size_wh, interpolation=cv2.INTER_LINEAR)[..., None]
        canvas = color * weight + canvas * (1.0 - weight)
    return np.clip(canvas, 0, 255).astype(np.uint8)
def on_image_upload(image, state):
    """Start a fresh session state for a newly uploaded (or cleared) image.

    Any previous overlays and counts are discarded; ``state`` is ignored.

    Returns:
        (display image, new state, "", ""). The display image is the
        RGB-normalised upload, or ``None`` when the input was cleared.
    """
    base = None if image is None else _ensure_rgb(image)
    fresh_state = {"base": base, "det": None, "seg": None, "det_counts": "", "seg_counts": ""}
    return base, fresh_state, "", ""
def clear_overlays(image, state):
    """Drop both overlay layers and their counts, keeping the loaded image.

    ``image`` is unused. With no loaded image, a blank state is returned.
    Otherwise ``state`` is modified in place and returned, together with the
    untouched base image and two empty count strings.
    """
    if state is not None and state.get("base") is not None:
        state.update(det=None, seg=None, det_counts="", seg_counts="")
        return state["base"], state, "", ""
    return None, {"base": None, "det": None, "seg": None, "det_counts": "", "seg_counts": ""}, "", ""
#endregion
#region Maintenance
def _dir_size(path: str) -> int:
    """Return the total size in bytes of all files below ``path``.

    Best effort: files whose size cannot be read count as 0, and any other
    error while walking yields 0 overall. A missing ``path`` also yields 0,
    because the walk finds nothing.
    """
    def _size_or_zero(file_path):
        try:
            return os.path.getsize(file_path)
        except Exception:
            return 0

    try:
        return sum(
            _size_or_zero(os.path.join(root, name))
            for root, _subdirs, names in os.walk(path)
            for name in names
        )
    except Exception:
        return 0
def _fmt_bytes(n: int) -> str:
    """Format a byte count with 1024-based units and one decimal place.

    The unit is the first of B, KB, MB, GB, TB that keeps the value below
    1024. Anything from 1024 TB upward is expressed in PB.
    """
    value = n
    for unit in ("B", "KB", "MB", "GB", "TB"):
        if value < 1024.0:
            break
        value /= 1024.0
    else:
        unit = "PB"
    return f"{value:.1f} {unit}"
def check_storage():
    """Report the size of each known cache location plus root-disk usage.

    Besides the default per-user caches, the app's redirected cache dirs are
    resolved from the environment variables set at import time (HF_HOME,
    TORCH_HOME, YOLO_CONFIG_DIR), falling back to the same /tmp defaults.
    This means environment overrides are respected. The whole of HF_HOME is
    measured, not just its hub/ subfolder, so the transformers cache that
    TRANSFORMERS_CACHE points into is included too.

    Returns:
        Multi-line text:
        - a "Cache sizes:" header,
        - one "<path>: <size>" line per location (missing paths count as 0),
        - a final "Disk usage: ..." line for "/" ("n/a" if it cannot be queried).
    """
    candidates = [
        os.path.expanduser("~/.cache/huggingface/hub"),
        os.path.expanduser("~/.cache/torch"),
        os.path.expanduser("~/.cache/pip"),
        os.path.expanduser("~/.config/Ultralytics"),
        os.environ.get("HF_HOME", "/tmp/hf_home"),
        os.environ.get("TORCH_HOME", "/tmp/torch_home"),
        os.environ.get("YOLO_CONFIG_DIR", "/tmp/Ultralytics"),
    ]
    # De-duplicate while keeping order (an env var may point at a default path).
    paths = list(dict.fromkeys(candidates))
    lines = [f"{p}: {_fmt_bytes(_dir_size(p) if os.path.exists(p) else 0)}" for p in paths]
    try:
        total, used, free = shutil.disk_usage("/")
        disk_line = f"Disk usage: used {_fmt_bytes(used)} / total {_fmt_bytes(total)} (free {_fmt_bytes(free)})"
    except Exception:
        disk_line = "Disk usage: n/a"
    return "Cache sizes:\n" + "\n".join(lines) + "\n" + disk_line
def clean_caches():
    """Delete the known cache directories and report the outcome.

    Each existing path is removed with ``shutil.rmtree(..., ignore_errors=True)``.
    Because errors are ignored, each path is re-checked afterwards:
    - it is listed under "Removed:" only if it no longer exists;
    - anything still present is listed under a separate heading instead of
      being reported as removed.

    Returns:
        "Removed:" followed by one path per line, or "(none)". If some
        removals failed, a "Could not (fully) remove:" section follows.
    """
    paths = [
        os.path.expanduser("~/.cache/huggingface/hub"),
        os.path.expanduser("~/.cache/torch"),
        os.path.expanduser("~/.cache/pip"),
        os.path.expanduser("~/.config/Ultralytics"),
        "/tmp/hf_home",
        "/tmp/torch_home",
    ]
    removed = []
    leftover = []
    for p in paths:
        if not os.path.exists(p):
            continue
        shutil.rmtree(p, ignore_errors=True)
        if os.path.exists(p):
            leftover.append(p)
        else:
            removed.append(p)
    report = "Removed:\n" + ("\n".join(removed) if removed else "(none)")
    if leftover:
        report += "\nCould not (fully) remove:\n" + "\n".join(leftover)
    return report
#endregion
def build_app():
    """Build the Gradio Blocks UI and wire its events to the handlers.

    Left column:
    - the image input;
    - a tab for model A (berry detection: confidence, device, SAHI slice
      and overlap sliders);
    - a tab for model B (bunch segmentation: confidence, mask alpha, device);
    - a "Clean overlay" button;
    - a collapsed "Disk Maintenance" accordion.

    Right column:
    - the combined result image;
    - one counts textbox per model.

    Returns:
        The ``gr.Blocks`` app; it is not launched here.
    """
    # NOTE(review): the nesting below was reconstructed from a copy with its
    # indentation stripped. Confirm the layout against the deployed app.
    with gr.Blocks(title="Berries detection & bunches segmentation") as demo:
        gr.Markdown(
            "## Double inference on the same image with combined overlays\n"
            "- Model A: berries detection\n"
            "- Model B: bunches segmentation\n"
            "- Run individually; overlays combine on the same image.\n"
        )
        # Per-session state: base image, the two overlay layers and their count strings.
        state = gr.State({"base": None, "det": None, "seg": None, "det_counts": "", "seg_counts": ""})
        with gr.Row():
            with gr.Column(scale=1):
                img_in = gr.Image(label="Image", type="numpy")
                with gr.Tab("Model A — Berries Detection"):
                    with gr.Row():
                        conf_det = gr.Slider(0.0, 1.0, value=0.35, step=0.01, label="Confidence (A)")
                        device_a = gr.Dropdown(["auto", "cuda:0", "cpu"], value="auto", label="Device")
                    # SAHI tiling: slice size in pixels and overlap as a fraction of the slice.
                    with gr.Row():
                        slice_h = gr.Slider(64, 2048, value=640, step=32, label="Slice H (A)")
                        slice_w = gr.Slider(64, 2048, value=640, step=32, label="Slice W (A)")
                    with gr.Row():
                        overlap_h = gr.Slider(0.0, 0.9, value=0.10, step=0.01, label="Overlap H (A)")
                        overlap_w = gr.Slider(0.0, 0.9, value=0.10, step=0.01, label="Overlap W (A)")
                    btn_det = gr.Button("Run berries detection")
                with gr.Tab("Model B — Bunches Segmentation"):
                    with gr.Row():
                        conf_seg = gr.Slider(0.0, 1.0, value=0.35, step=0.01, label="Confidence (B)")
                        seg_alpha = gr.Slider(0.0, 1.0, value=SEG_DEFAULT_ALPHA, step=0.05, label="Alpha masks (B)")
                        device_b = gr.Dropdown(["auto", "cuda:0", "cpu"], value="auto", label="Device")
                    btn_seg = gr.Button("Run bunches segmentation")
                with gr.Row():
                    btn_clear = gr.Button("Clean overlay", variant="secondary")
                with gr.Accordion("Disk Maintenance", open=False):
                    btn_check = gr.Button("Check storage")
                    btn_clean = gr.Button("Clean cache")
                    maint_out = gr.Textbox(label="Log Maintenance", interactive=False)
            with gr.Column(scale=2):
                img_out = gr.Image(label="Combined Result", type="numpy")
                with gr.Row():
                    counts_out_det = gr.Textbox(label="Counts - Berries", interactive=False)
                    counts_out_seg = gr.Textbox(label="Counts - Bunches", interactive=False)
        # Wiring: every image handler returns (image, state, det counts, seg counts).
        # A new upload resets the session state and clears both overlays.
        img_in.change(
            on_image_upload,
            inputs=[img_in, state],
            outputs=[img_out, state, counts_out_det, counts_out_seg],
        )
        btn_det.click(
            run_det,
            inputs=[
                img_in, state,
                conf_det, slice_h, slice_w, overlap_h, overlap_w, device_a
            ],
            outputs=[img_out, state, counts_out_det, counts_out_seg],
        )
        btn_seg.click(
            run_seg,
            inputs=[
                img_in, state,
                conf_seg, device_b, seg_alpha
            ],
            outputs=[img_out, state, counts_out_det, counts_out_seg],
        )
        btn_clear.click(
            clear_overlays,
            inputs=[img_in, state],
            outputs=[img_out, state, counts_out_det, counts_out_seg],
        )
        # Maintenance handlers take no inputs and write a text log.
        btn_check.click(
            check_storage,
            inputs=[],
            outputs=[maint_out],
        )
        btn_clean.click(
            clean_caches,
            inputs=[],
            outputs=[maint_out],
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI and start the Gradio server with
    # default launch settings.
    demo = build_app()
    demo.launch()