#region Imports
import os

# Route caches/configs to /tmp to avoid filling persistent storage and suppress permission warnings.
# Must run before importing HF/torch/ultralytics so they pick the env vars up.
os.environ.setdefault("HF_HOME", "/tmp/hf_home")
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/hf_home/transformers")
os.environ.setdefault("HF_HUB_CACHE", "/tmp/hf_home/hub")
os.environ.setdefault("TORCH_HOME", "/tmp/torch_home")
os.environ.setdefault("PIP_DISABLE_PIP_VERSION_CHECK", "1")
os.environ.setdefault("YOLO_CONFIG_DIR", "/tmp/Ultralytics")

import time
import shutil

import cv2
import numpy as np
import gradio as gr
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction

# Try to import ultralytics for native segmentation
try:
    from ultralytics import YOLO
    _ULTRA_OK = True
except Exception:
    _ULTRA_OK = False
#endregion

#region Config and setup
MAX_SIDE_PX = 80  # set >0 (e.g., 70) to filter detections with large side; -1 disables
SEG_DEFAULT_ALPHA = 0.45

# High-contrast colors for green backgrounds (BGR order)
BERRIES_COLOR_BGR = (255, 0, 255)        # magenta/pink for detection boxes
BUNCHES_FILL_COLOR_BGR = (255, 255, 0)   # cyan for mask fill
BUNCHES_CONTOUR_COLOR_BGR = (255, 255, 255)  # white for mask contours

# Fixed weights (no UI controls). If you want them editable, add Textbox components and wire them as inputs.
WEIGHTS_DETECTION = "weights/berries.pt"
WEIGHTS_SEGMENTATION = "weights/bunches.pt"

# Simple global caches to avoid reloading models each click
_DET_MODEL_CACHE = {}  # key: (weights_path, device) -> AutoDetectionModel
_SEG_MODEL_CACHE = {}  # key: weights_path -> YOLO
#endregion


#region Model and device handling
def _choose_device(user_choice: str) -> str:
    """Resolve 'auto' to 'cuda:0' when torch reports CUDA, else 'cpu'.

    Any explicit choice is passed through unchanged; torch import failures
    degrade to CPU rather than raising.
    """
    if user_choice != "auto":
        return user_choice
    try:
        import torch
        return "cuda:0" if torch.cuda.is_available() else "cpu"
    except Exception:
        return "cpu"


def _get_det_model(weights_path: str, device: str, conf: float):
    """Return a cached SAHI AutoDetectionModel; update confidence on the fly.

    Raises gr.Error when the weights file is missing. On a load failure for
    the requested device (e.g. CUDA unavailable at runtime) the model is
    loaded on CPU as a best-effort fallback; note it is still cached under
    the originally requested (weights_path, device) key so the failing
    device is not retried on every click.
    """
    if not os.path.exists(weights_path):
        raise gr.Error(f"Detection weights not found: {weights_path}")
    key = (weights_path, device)
    model = _DET_MODEL_CACHE.get(key)
    if model is None:
        try:
            model = AutoDetectionModel.from_pretrained(
                model_type="yolo11",
                model_path=weights_path,
                confidence_threshold=conf,
                device=device,
            )
        except Exception:
            # CPU fallback
            model = AutoDetectionModel.from_pretrained(
                model_type="yolo11",
                model_path=weights_path,
                confidence_threshold=conf,
                device="cpu",
            )
        _DET_MODEL_CACHE[key] = model
    else:
        # Update confidence threshold if present
        try:
            model.confidence_threshold = float(conf)
        except Exception:
            pass
    return model


def _get_seg_model(weights_path: str):
    """Return a cached ultralytics YOLO segmentation model.

    Raises gr.Error when ultralytics is not installed or weights are missing.
    """
    if not _ULTRA_OK:
        raise gr.Error("Ultralytics not found, please install it with: pip install ultralytics")
    if not os.path.exists(weights_path):
        raise gr.Error(f"Segmentation weights not found: {weights_path}")
    model = _SEG_MODEL_CACHE.get(weights_path)
    if model is None:
        model = YOLO(weights_path)
        _SEG_MODEL_CACHE[weights_path] = model
    return model
#endregion


#region Inference
def _sahi_predict(image_rgb: np.ndarray, det_model, slice_h, slice_w, overlap_h, overlap_w):
    """Run SAHI sliced prediction on an RGB image with the given tiling params."""
    return get_sliced_prediction(
        image_rgb,
        det_model,
        slice_height=int(slice_h),
        slice_width=int(slice_w),
        overlap_height_ratio=float(overlap_h),
        overlap_width_ratio=float(overlap_w),
        postprocess_class_agnostic=False,
        verbose=0,
    )


def run_det(image, state, conf_det, slice_h, slice_w, overlap_h, overlap_w, device):
    """Run model A (berries detection via SAHI) and update only the 'det' overlay.

    Assembles the final image with both layers (det + seg) in timestamp order.
    `image` is unused directly (the cached `state['base']` is authoritative)
    but kept in the signature for the Gradio wiring.
    """
    if state is None or state.get("base") is None:
        raise gr.Error("Loading an image is required before inference.")
    base = state["base"]
    # basic auto-opt: if image fits one tile, set overlap 0 to speed up
    H, W = base.shape[:2]
    if H <= slice_h and W <= slice_w:
        overlap_h, overlap_w = 0.0, 0.0
    det_model = _get_det_model(WEIGHTS_DETECTION, _choose_device(device), conf_det)
    sahi_res = _sahi_predict(base, det_model, slice_h, slice_w, overlap_h, overlap_w)
    # No target highlighting in this simplified app
    overlay_rgb, alpha, counts = _draw_boxes_overlay(base, sahi_res, target_class="", use_target=False)
    state["det"] = {"overlay": overlay_rgb, "alpha": alpha, "ts": time.time()}
    state["det_counts"] = counts
    layers = [state["det"], state.get("seg")]
    composite = _composite_layers(base, layers)
    return composite, state, state["det_counts"], state.get("seg_counts", "")


def run_seg(image, state, conf_seg, device, seg_alpha):
    """Run model B (bunches segmentation) and update only the 'seg' overlay.

    Assembles the final image with both layers (det + seg) in timestamp order.
    """
    if state is None or state.get("base") is None:
        raise gr.Error("Loading an image is required before inference.")
    base = state["base"]
    seg_model = _get_seg_model(WEIGHTS_SEGMENTATION)
    try:
        seg_results = seg_model.predict(source=base, conf=float(conf_seg), device=_choose_device(device), verbose=False)
        r0 = seg_results[0] if isinstance(seg_results, (list, tuple)) else seg_results
    except Exception as e:
        raise gr.Error(f"Error in segmentation inference: {e}")
    # No target highlighting in this simplified app
    overlay_rgb, alpha, counts = _draw_seg_overlay(base, r0, target_class="", use_target=False, fill_alpha=float(seg_alpha))
    state["seg"] = {"overlay": overlay_rgb, "alpha": alpha, "ts": time.time()}
    state["seg_counts"] = counts
    layers = [state.get("det"), state["seg"]]
    composite = _composite_layers(base, layers)
    return composite, state, state.get("det_counts", ""), state["seg_counts"]
#endregion


#region Draw
def _ensure_rgb(img: np.ndarray) -> np.ndarray:
    """Normalize grayscale/RGBA input to a 3-channel RGB array (None passes through)."""
    if img is None:
        return None
    if img.ndim == 2:
        return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    if img.shape[2] == 4:
        return cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
    return img


def _draw_boxes_overlay(image_rgb: np.ndarray, sahi_result, target_class: str, use_target: bool):
    """Draw detection rectangles on a transparent overlay.

    Returns (overlay_rgb HxWx3, alpha_mask HxW uint8, counts_text).
    Only draws rectangles (no labels). Boxes whose longer side exceeds
    MAX_SIDE_PX are filtered out when MAX_SIDE_PX > 0.
    """
    H, W = image_rgb.shape[:2]
    overlay = np.zeros((H, W, 3), dtype=np.uint8)
    alpha = np.zeros((H, W), dtype=np.uint8)
    target_count = 0
    total_count = 0
    object_predictions = getattr(sahi_result, "object_prediction_list", []) or []
    for item in object_predictions:
        # parse bbox
        try:
            x1, y1, x2, y2 = map(int, item.bbox.to_xyxy())
        except Exception:
            x1, y1 = int(getattr(item.bbox, "minx", 0)), int(getattr(item.bbox, "miny", 0))
            x2, y2 = int(getattr(item.bbox, "maxx", 0)), int(getattr(item.bbox, "maxy", 0))
        # clamp and normalize
        x1 = max(0, min(x1, W - 1)); x2 = max(0, min(x2, W - 1))
        y1 = max(0, min(y1, H - 1)); y2 = max(0, min(y2, H - 1))
        if x2 < x1:
            x1, x2 = x2, x1
        if y2 < y1:
            y1, y2 = y2, y1
        w = max(0, x2 - x1)
        h = max(0, y2 - y1)
        if w == 0 or h == 0:
            continue
        if MAX_SIDE_PX > 0 and max(w, h) > MAX_SIDE_PX:
            continue
        # bbox.area may be an attribute or a callable depending on SAHI version
        area = getattr(item.bbox, "area", w * h)
        try:
            area_val = float(area() if callable(area) else area)
        except Exception:
            area_val = float(w * h)
        if area_val <= 0:
            continue
        cls = getattr(item.category, "name", "unknown")
        is_target = (cls == target_class) if use_target else False
        color_bgr = BERRIES_COLOR_BGR
        # Draw on overlay (BGR)
        cv2.rectangle(overlay, (x1, y1), (x2, y2), color_bgr, 2)
        cv2.rectangle(alpha, (x1, y1), (x2, y2), 255, 2)
        total_count += 1
        if is_target:
            target_count += 1
    # Convert overlay BGR -> RGB
    overlay_rgb = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
    counts = (f"target='{target_class}': {target_count} | total: {total_count}") if use_target else f"total: {total_count}"
    return overlay_rgb, alpha, counts


def _draw_seg_overlay(image_rgb: np.ndarray, yolo_result, target_class: str, use_target: bool, fill_alpha: float = SEG_DEFAULT_ALPHA):
    """Draw segmentation masks (fill + opaque contour) on a transparent overlay.

    Returns (overlay_rgb HxWx3, alpha_mask HxW uint8, counts_text).
    Masks are filled with BUNCHES_FILL_COLOR_BGR at `fill_alpha` opacity and
    outlined with BUNCHES_CONTOUR_COLOR_BGR at full opacity. When a mask is
    unavailable or fails to rasterize, the instance's bounding box is drawn
    instead as a fallback.
    """
    H, W = image_rgb.shape[:2]
    overlay_bgr = np.zeros((H, W, 3), dtype=np.uint8)
    alpha = np.zeros((H, W), dtype=np.uint8)
    r = yolo_result
    names = getattr(r, "names", None)
    boxes = getattr(r, "boxes", None)
    masks = getattr(r, "masks", None)
    if boxes is None or len(boxes) == 0:
        counts = f"target='{target_class}': 0 | total: 0" if use_target else "total: 0"
        return cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB), alpha, counts
    N = len(boxes)
    mask_data = None
    if masks is not None and getattr(masks, "data", None) is not None:
        try:
            mask_data = masks.data  # torch.Tensor [N, H, W]
        except Exception:
            mask_data = None
    target_count = 0
    total_count = 0
    # fill alpha clamped to [0, 1] and scaled to uint8
    fa255 = int(max(0.0, min(1.0, float(fill_alpha))) * 255)
    for i in range(N):
        try:
            cls_idx = int(boxes.cls[i].item())
        except Exception:
            cls_idx = -1
        cls_name = str(cls_idx)
        if isinstance(names, dict):
            cls_name = names.get(cls_idx, cls_name)
        is_target = (cls_name == target_class) if use_target else False
        if mask_data is not None and i < len(mask_data):
            try:
                m = mask_data[i]
                m = m.detach().cpu().numpy()
                m = (m > 0.5).astype(np.uint8)
                # model masks may come at network resolution; resize to image size
                if m.shape[:2] != (H, W):
                    m = cv2.resize(m, (W, H), interpolation=cv2.INTER_NEAREST)
                overlay_bgr[m == 1] = BUNCHES_FILL_COLOR_BGR
                # keep the strongest alpha where masks overlap
                alpha[m == 1] = np.maximum(alpha[m == 1], fa255)
                cnts, _ = cv2.findContours(m, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                cv2.drawContours(overlay_bgr, cnts, -1, BUNCHES_CONTOUR_COLOR_BGR, 2)
                cv2.drawContours(alpha, cnts, -1, 255, 2)
            except Exception:
                # mask rasterization failed: fall back to the bounding box
                try:
                    xyxy = boxes.xyxy[i].detach().cpu().numpy().astype(int)
                    x1, y1, x2, y2 = map(int, xyxy)
                    cv2.rectangle(overlay_bgr, (x1, y1), (x2, y2), BUNCHES_CONTOUR_COLOR_BGR, 2)
                    cv2.rectangle(alpha, (x1, y1), (x2, y2), 255, 2)
                except Exception:
                    pass
        else:
            # no mask for this instance: draw its bounding box
            try:
                xyxy = boxes.xyxy[i].detach().cpu().numpy().astype(int)
                x1, y1, x2, y2 = map(int, xyxy)
                cv2.rectangle(overlay_bgr, (x1, y1), (x2, y2), BUNCHES_CONTOUR_COLOR_BGR, 2)
                cv2.rectangle(alpha, (x1, y1), (x2, y2), 255, 2)
            except Exception:
                pass
        total_count += 1
        if is_target:
            target_count += 1
    overlay_rgb = cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB)
    counts = (f"target='{target_class}': {target_count} | total: {total_count}") if use_target else f"total: {total_count}"
    return overlay_rgb, alpha, counts


def _composite_layers(base_rgb: np.ndarray, layers: list):
    """Alpha-composite overlay layers onto the base image.

    layers: list of dicts (or None entries, which are skipped) with keys:
        - 'overlay' : np.ndarray HxWx3 RGB
        - 'alpha'   : np.ndarray HxW uint8
        - 'ts'      : float timestamp controlling stacking order
    Layers are applied oldest-first so the newest ends up on top.
    """
    if base_rgb is None:
        return None
    result = base_rgb.astype(np.float32)
    layers_sorted = sorted([l for l in layers if l is not None], key=lambda d: d["ts"])
    for layer in layers_sorted:
        ov = layer["overlay"].astype(np.float32)
        a = (layer["alpha"].astype(np.float32) / 255.0)[..., None]  # HxWx1
        if ov.shape[:2] != result.shape[:2]:
            ov = cv2.resize(ov, (result.shape[1], result.shape[0]), interpolation=cv2.INTER_LINEAR)
            a = cv2.resize(a, (result.shape[1], result.shape[0]), interpolation=cv2.INTER_LINEAR)[..., None]
        result = ov * a + result * (1.0 - a)
    return np.clip(result, 0, 255).astype(np.uint8)


def on_image_upload(image, state):
    """Reset overlays and rebuild state when a new image is uploaded."""
    if image is None:
        return None, {"base": None, "det": None, "seg": None, "det_counts": "", "seg_counts": ""}, "", ""
    img_rgb = _ensure_rgb(image)
    new_state = {"base": img_rgb, "det": None, "seg": None, "det_counts": "", "seg_counts": ""}
    return img_rgb, new_state, "", ""


def clear_overlays(image, state):
    """Drop both overlays and counts, returning the bare base image."""
    if state is None or state.get("base") is None:
        return None, {"base": None, "det": None, "seg": None, "det_counts": "", "seg_counts": ""}, "", ""
    base = state["base"]
    state["det"] = None
    state["seg"] = None
    state["det_counts"] = ""
    state["seg_counts"] = ""
    return base, state, "", ""
#endregion


#region Maintenance
def _dir_size(path: str) -> int:
    """Return the total size in bytes of all files under `path` (best-effort, 0 on error)."""
    try:
        total = 0
        for root, _, files in os.walk(path):
            for f in files:
                fp = os.path.join(root, f)
                try:
                    total += os.path.getsize(fp)
                except Exception:
                    pass
        return total
    except Exception:
        return 0


def _fmt_bytes(n: int) -> str:
    """Render a byte count as a human-readable string (e.g. '1.5 MB')."""
    for unit in ["B", "KB", "MB", "GB", "TB"]:
        if n < 1024.0:
            return f"{n:.1f} {unit}"
        n /= 1024.0
    return f"{n:.1f} PB"


def check_storage():
    """Report sizes of known cache directories plus overall disk usage."""
    # Key cache locations
    paths = [
        os.path.expanduser("~/.cache/huggingface/hub"),
        os.path.expanduser("~/.cache/torch"),
        os.path.expanduser("~/.cache/pip"),
        os.path.expanduser("~/.config/Ultralytics"),
        "/tmp/hf_home/hub",
        "/tmp/torch_home",
    ]
    lines = []
    for p in paths:
        sz = _dir_size(p) if os.path.exists(p) else 0
        lines.append(f"{p}: {_fmt_bytes(sz)}")
    try:
        total, used, free = shutil.disk_usage("/")
        disk_line = f"Disk usage: used {_fmt_bytes(used)} / total {_fmt_bytes(total)} (free {_fmt_bytes(free)})"
    except Exception:
        disk_line = "Disk usage: n/a"
    return "Cache sizes:\n" + "\n".join(lines) + "\n" + disk_line


def clean_caches():
    """Delete known cache directories (best-effort) and report what was removed."""
    paths = [
        os.path.expanduser("~/.cache/huggingface/hub"),
        os.path.expanduser("~/.cache/torch"),
        os.path.expanduser("~/.cache/pip"),
        os.path.expanduser("~/.config/Ultralytics"),
        "/tmp/hf_home",
        "/tmp/torch_home",
    ]
    removed = []
    for p in paths:
        try:
            if os.path.exists(p):
                shutil.rmtree(p, ignore_errors=True)
                removed.append(p)
        except Exception:
            pass
    return "Removed:\n" + ("\n".join(removed) if removed else "(none)")
#endregion


def build_app():
    """Build and return the Gradio Blocks UI.

    NOTE(review): the source arrived with all indentation collapsed; the
    component nesting below (Tabs inside the left column) was reconstructed
    from statement order — confirm against the original layout.
    """
    with gr.Blocks(title="Berries detection & bunches segmentation") as demo:
        gr.Markdown(
            "## Double inference on the same image with combined overlays\n"
            "- Model A: berries detection\n"
            "- Model B: bunches segmentation\n"
            "- Run individually; overlays combine on the same image.\n"
        )
        state = gr.State({"base": None, "det": None, "seg": None, "det_counts": "", "seg_counts": ""})
        with gr.Row():
            with gr.Column(scale=1):
                img_in = gr.Image(label="Image", type="numpy")
                with gr.Tab("Model A — Berries Detection"):
                    with gr.Row():
                        conf_det = gr.Slider(0.0, 1.0, value=0.35, step=0.01, label="Confidence (A)")
                        device_a = gr.Dropdown(["auto", "cuda:0", "cpu"], value="auto", label="Device")
                    with gr.Row():
                        slice_h = gr.Slider(64, 2048, value=640, step=32, label="Slice H (A)")
                        slice_w = gr.Slider(64, 2048, value=640, step=32, label="Slice W (A)")
                    with gr.Row():
                        overlap_h = gr.Slider(0.0, 0.9, value=0.10, step=0.01, label="Overlap H (A)")
                        overlap_w = gr.Slider(0.0, 0.9, value=0.10, step=0.01, label="Overlap W (A)")
                    btn_det = gr.Button("Run berries detection")
                with gr.Tab("Model B — Bunches Segmentation"):
                    with gr.Row():
                        conf_seg = gr.Slider(0.0, 1.0, value=0.35, step=0.01, label="Confidence (B)")
                        seg_alpha = gr.Slider(0.0, 1.0, value=SEG_DEFAULT_ALPHA, step=0.05, label="Alpha masks (B)")
                        device_b = gr.Dropdown(["auto", "cuda:0", "cpu"], value="auto", label="Device")
                    btn_seg = gr.Button("Run bunches segmentation")
                with gr.Row():
                    btn_clear = gr.Button("Clean overlay", variant="secondary")
                with gr.Accordion("Disk Maintenance", open=False):
                    btn_check = gr.Button("Check storage")
                    btn_clean = gr.Button("Clean cache")
                    maint_out = gr.Textbox(label="Log Maintenance", interactive=False)
            with gr.Column(scale=2):
                img_out = gr.Image(label="Combined Result", type="numpy")
                with gr.Row():
                    counts_out_det = gr.Textbox(label="Counts - Berries", interactive=False)
                    counts_out_seg = gr.Textbox(label="Counts - Bunches", interactive=False)

        # Wiring
        img_in.change(
            on_image_upload,
            inputs=[img_in, state],
            outputs=[img_out, state, counts_out_det, counts_out_seg],
        )
        btn_det.click(
            run_det,
            inputs=[img_in, state, conf_det, slice_h, slice_w, overlap_h, overlap_w, device_a],
            outputs=[img_out, state, counts_out_det, counts_out_seg],
        )
        btn_seg.click(
            run_seg,
            inputs=[img_in, state, conf_seg, device_b, seg_alpha],
            outputs=[img_out, state, counts_out_det, counts_out_seg],
        )
        btn_clear.click(
            clear_overlays,
            inputs=[img_in, state],
            outputs=[img_out, state, counts_out_det, counts_out_seg],
        )
        btn_check.click(
            check_storage,
            inputs=[],
            outputs=[maint_out],
        )
        btn_clean.click(
            clean_caches,
            inputs=[],
            outputs=[maint_out],
        )
    return demo


if __name__ == "__main__":
    demo = build_app()
    demo.launch()