# grape_detection / app.py — Hugging Face Space entry point (commit c245089)
#region Imports
import os
# Route caches/configs to /tmp to avoid filling persistent storage and suppress permission warnings
os.environ.setdefault("HF_HOME", "/tmp/hf_home")
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/hf_home/transformers")
os.environ.setdefault("HF_HUB_CACHE", "/tmp/hf_home/hub")
os.environ.setdefault("TORCH_HOME", "/tmp/torch_home")
os.environ.setdefault("PIP_DISABLE_PIP_VERSION_CHECK", "1")
os.environ.setdefault("YOLO_CONFIG_DIR", "/tmp/Ultralytics")
import cv2
import time
import shutil
import numpy as np
import gradio as gr
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
# Try to import ultralytics for native segmentation
try:
from ultralytics import YOLO
_ULTRA_OK = True
except Exception:
_ULTRA_OK = False
#endregion
#region Config and setup
MAX_SIDE_PX = 80  # detections with max(w, h) > this are dropped; set <= 0 to disable the filter
SEG_DEFAULT_ALPHA = 0.45  # default fill opacity for segmentation masks (0..1)
# High-contrast colors for green backgrounds (BGR order, as OpenCV draws in BGR)
BERRIES_COLOR_BGR = (255, 0, 255)  # magenta/pink for detection boxes
BUNCHES_FILL_COLOR_BGR = (255, 255, 0)  # cyan for mask fill
BUNCHES_CONTOUR_COLOR_BGR = (255, 255, 255)  # white for mask contours
# Fixed weights (no UI controls). If you want them editable, add Textbox components and wire them as inputs.
WEIGHTS_DETECTION = "weights/berries.pt"
WEIGHTS_SEGMENTATION = "weights/bunches.pt"
# Simple global caches to avoid reloading models each click
_DET_MODEL_CACHE = {}  # key: (weights_path, device) -> AutoDetectionModel
_SEG_MODEL_CACHE = {}  # key: weights_path -> YOLO
#endregion
#region Model and device handling
def _choose_device(user_choice: str) -> str:
if user_choice != "auto":
return user_choice
try:
import torch
return "cuda:0" if torch.cuda.is_available() else "cpu"
except Exception:
return "cpu"
def _get_det_model(weights_path: str, device: str, conf: float):
    """Return a cached SAHI AutoDetectionModel for the given weights/device.

    On a cache hit only the confidence threshold is refreshed. On a miss
    the model is loaded on *device*, falling back to CPU if that fails.
    Raises gr.Error when the weights file is missing.
    """
    if not os.path.exists(weights_path):
        raise gr.Error(f"Detection weights not found: {weights_path}")
    cache_key = (weights_path, device)
    cached = _DET_MODEL_CACHE.get(cache_key)
    if cached is not None:
        # Refresh the confidence threshold in place when supported.
        try:
            cached.confidence_threshold = float(conf)
        except Exception:
            pass
        return cached
    load_kwargs = dict(
        model_type="yolo11",
        model_path=weights_path,
        confidence_threshold=conf,
    )
    try:
        cached = AutoDetectionModel.from_pretrained(device=device, **load_kwargs)
    except Exception:
        # CPU fallback when the requested device cannot load the model
        cached = AutoDetectionModel.from_pretrained(device="cpu", **load_kwargs)
    _DET_MODEL_CACHE[cache_key] = cached
    return cached
def _get_seg_model(weights_path: str):
    """Return a cached Ultralytics YOLO segmentation model.

    Raises gr.Error when ultralytics is not installed or the weights
    file does not exist.
    """
    if not _ULTRA_OK:
        raise gr.Error("Ultralytics not found, please install it with: pip install ultralytics")
    if not os.path.exists(weights_path):
        raise gr.Error(f"Segmentation weights not found: {weights_path}")
    try:
        return _SEG_MODEL_CACHE[weights_path]
    except KeyError:
        loaded = YOLO(weights_path)
        _SEG_MODEL_CACHE[weights_path] = loaded
        return loaded
#endregion
#region Inference
def _sahi_predict(image_rgb: np.ndarray, det_model, slice_h, slice_w, overlap_h, overlap_w):
    """Run SAHI sliced inference on *image_rgb* and return its result object."""
    options = {
        "slice_height": int(slice_h),
        "slice_width": int(slice_w),
        "overlap_height_ratio": float(overlap_h),
        "overlap_width_ratio": float(overlap_w),
        "postprocess_class_agnostic": False,
        "verbose": 0,
    }
    return get_sliced_prediction(image_rgb, det_model, **options)
def run_det(
    image, state,
    conf_det, slice_h, slice_w, overlap_h, overlap_w, device
):
    """Run model A (berries detection via SAHI) and refresh only the 'det' layer.

    The returned composite stacks both overlays (det + seg) in timestamp
    order on top of the stored base image. Raises gr.Error when no image
    has been loaded into the state yet.
    """
    if state is None or state.get("base") is None:
        raise gr.Error("Loading an image is required before inference.")
    base = state["base"]
    h, w = base.shape[:2]
    # Single-tile images gain nothing from overlap; drop it to speed up.
    if h <= slice_h and w <= slice_w:
        overlap_h = overlap_w = 0.0
    model = _get_det_model(WEIGHTS_DETECTION, _choose_device(device), conf_det)
    result = _sahi_predict(base, model, slice_h, slice_w, overlap_h, overlap_w)
    # No target highlighting in this simplified app
    overlay, alpha_mask, counts = _draw_boxes_overlay(base, result, target_class="", use_target=False)
    state["det"] = {"overlay": overlay, "alpha": alpha_mask, "ts": time.time()}
    state["det_counts"] = counts
    composite = _composite_layers(base, [state["det"], state.get("seg")])
    return composite, state, counts, state.get("seg_counts", "")
def run_seg(
    image, state,
    conf_seg, device, seg_alpha
):
    """Run model B (bunches segmentation) and refresh only the 'seg' layer.

    The returned composite stacks both overlays (det + seg) in timestamp
    order on top of the stored base image. Raises gr.Error when no image
    has been loaded, or when segmentation inference itself fails.
    """
    if state is None or state.get("base") is None:
        raise gr.Error("Loading an image is required before inference.")
    base = state["base"]
    model = _get_seg_model(WEIGHTS_SEGMENTATION)
    try:
        results = model.predict(source=base, conf=float(conf_seg), device=_choose_device(device), verbose=False)
        first = results[0] if isinstance(results, (list, tuple)) else results
    except Exception as e:
        raise gr.Error(f"Error in segmentation inference: {e}")
    # No target highlighting in this simplified app
    overlay, alpha_mask, counts = _draw_seg_overlay(base, first, target_class="", use_target=False, fill_alpha=float(seg_alpha))
    state["seg"] = {"overlay": overlay, "alpha": alpha_mask, "ts": time.time()}
    state["seg_counts"] = counts
    composite = _composite_layers(base, [state.get("det"), state["seg"]])
    return composite, state, state.get("det_counts", ""), counts
#endregion
#region Draw
def _ensure_rgb(img: np.ndarray) -> np.ndarray:
if img is None:
return None
if img.ndim == 2:
return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
if img.shape[2] == 4:
return cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
return img
def _draw_boxes_overlay(image_rgb: np.ndarray, sahi_result, target_class: str, use_target: bool):
    """Rasterize SAHI detections as box outlines on a transparent layer.

    Returns (overlay_rgb HxWx3, alpha_mask HxW uint8, counts_text).
    Only rectangles are drawn (no labels). When MAX_SIDE_PX > 0, boxes
    whose longest side exceeds it are skipped, as are degenerate boxes.
    """
    height, width = image_rgb.shape[:2]
    overlay = np.zeros((height, width, 3), dtype=np.uint8)
    alpha = np.zeros((height, width), dtype=np.uint8)
    n_target = 0
    n_total = 0
    predictions = getattr(sahi_result, "object_prediction_list", []) or []
    for pred in predictions:
        # Parse the bbox; fall back to min/max attributes if to_xyxy() fails.
        try:
            x1, y1, x2, y2 = map(int, pred.bbox.to_xyxy())
        except Exception:
            x1 = int(getattr(pred.bbox, "minx", 0))
            y1 = int(getattr(pred.bbox, "miny", 0))
            x2 = int(getattr(pred.bbox, "maxx", 0))
            y2 = int(getattr(pred.bbox, "maxy", 0))
        # Clamp to the image and orient corners top-left -> bottom-right.
        x1, x2 = sorted((max(0, min(x1, width - 1)), max(0, min(x2, width - 1))))
        y1, y2 = sorted((max(0, min(y1, height - 1)), max(0, min(y2, height - 1))))
        box_w = x2 - x1
        box_h = y2 - y1
        if box_w == 0 or box_h == 0:
            continue
        if MAX_SIDE_PX > 0 and max(box_w, box_h) > MAX_SIDE_PX:
            continue
        # Area may be an attribute or a method depending on SAHI version.
        area = getattr(pred.bbox, "area", box_w * box_h)
        try:
            area_val = float(area() if callable(area) else area)
        except Exception:
            area_val = float(box_w * box_h)
        if area_val <= 0:
            continue
        cls_name = getattr(pred.category, "name", "unknown")
        # Draw on overlay (OpenCV works in BGR)
        cv2.rectangle(overlay, (x1, y1), (x2, y2), BERRIES_COLOR_BGR, 2)
        cv2.rectangle(alpha, (x1, y1), (x2, y2), 255, 2)
        n_total += 1
        if use_target and cls_name == target_class:
            n_target += 1
    overlay_rgb = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
    if use_target:
        counts = f"target='{target_class}': {n_target} | total: {n_total}"
    else:
        counts = f"total: {n_total}"
    return overlay_rgb, alpha, counts
def _draw_seg_overlay(image_rgb: np.ndarray, yolo_result, target_class: str, use_target: bool, fill_alpha: float = SEG_DEFAULT_ALPHA):
    """
    Rasterize Ultralytics segmentation results onto a transparent layer.

    Returns overlay_rgb (H,W,3), alpha_mask (H,W) uint8, counts_text.
    - Mask pixels are filled with BUNCHES_FILL_COLOR_BGR at fill_alpha opacity;
      contours are drawn fully opaque in BUNCHES_CONTOUR_COLOR_BGR.
    - When an instance has no mask (or mask decoding fails), its bounding
      box outline is drawn instead.
    - NOTE(review): color_bgr below is computed per instance but never used
      for drawing — target highlighting appears to have been removed.
    """
    H, W = image_rgb.shape[:2]
    overlay_bgr = np.zeros((H, W, 3), dtype=np.uint8)
    alpha = np.zeros((H, W), dtype=np.uint8)
    r = yolo_result
    names = getattr(r, "names", None)
    boxes = getattr(r, "boxes", None)
    masks = getattr(r, "masks", None)
    # No detections: return an empty layer with a zeroed counts string.
    if boxes is None or len(boxes) == 0:
        counts = f"target='{target_class}': 0 | total: 0" if use_target else "total: 0"
        return cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB), alpha, counts
    N = len(boxes)
    mask_data = None
    if masks is not None and getattr(masks, "data", None) is not None:
        try:
            mask_data = masks.data  # torch.Tensor [N, H, W]
        except Exception:
            mask_data = None
    target_count = 0
    total_count = 0
    # Fill opacity clamped to [0, 1] and scaled to 0-255 for the uint8 alpha mask.
    fa255 = int(max(0.0, min(1.0, float(fill_alpha))) * 255)
    for i in range(N):
        try:
            cls_idx = int(boxes.cls[i].item())
        except Exception:
            cls_idx = -1
        # Resolve a human-readable class name when the names mapping exists.
        cls_name = str(cls_idx)
        if isinstance(names, dict):
            cls_name = names.get(cls_idx, cls_name)
        is_target = (cls_name == target_class) if use_target else False
        color_bgr = (0, 0, 255) if is_target and use_target else (0, 200, 0)
        if mask_data is not None and i < len(mask_data):
            try:
                m = mask_data[i]
                m = m.detach().cpu().numpy()
                m = (m > 0.5).astype(np.uint8)
                # Masks may come out at model resolution; resize to image size.
                if m.shape[:2] != (H, W):
                    m = cv2.resize(m, (W, H), interpolation=cv2.INTER_NEAREST)
                overlay_bgr[m == 1] = BUNCHES_FILL_COLOR_BGR
                # Keep the strongest opacity where instances overlap.
                alpha[m == 1] = np.maximum(alpha[m == 1], fa255)
                cnts, _ = cv2.findContours(m, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                cv2.drawContours(overlay_bgr, cnts, -1, BUNCHES_CONTOUR_COLOR_BGR, 2)
                cv2.drawContours(alpha, cnts, -1, 255, 2)
            except Exception:
                # Mask decoding failed; fall back to the bounding box outline.
                try:
                    xyxy = boxes.xyxy[i].detach().cpu().numpy().astype(int)
                    x1, y1, x2, y2 = map(int, xyxy)
                    cv2.rectangle(overlay_bgr, (x1, y1), (x2, y2), BUNCHES_CONTOUR_COLOR_BGR, 2)
                    cv2.rectangle(alpha, (x1, y1), (x2, y2), 255, 2)
                except Exception:
                    pass
        else:
            # No mask available for this instance; draw its bounding box outline.
            try:
                xyxy = boxes.xyxy[i].detach().cpu().numpy().astype(int)
                x1, y1, x2, y2 = map(int, xyxy)
                cv2.rectangle(overlay_bgr, (x1, y1), (x2, y2), BUNCHES_CONTOUR_COLOR_BGR, 2)
                cv2.rectangle(alpha, (x1, y1), (x2, y2), 255, 2)
            except Exception:
                pass
        total_count += 1
        if is_target:
            target_count += 1
    overlay_rgb = cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB)
    counts = (f"target='{target_class}': {target_count} | total: {total_count}") if use_target else f"total: {total_count}"
    return overlay_rgb, alpha, counts
def _composite_layers(base_rgb: np.ndarray, layers: list):
"""
layers: list of dicts with keys:
- 'overlay' : np.ndarray HxWx3 RGB
- 'alpha' : np.ndarray HxW uint8
- 'ts' : float (timestamp), to control stacking order (oldest first)
Newest layer should be on top: sort by ts ascending and apply in order.
"""
if base_rgb is None:
return None
result = base_rgb.astype(np.float32)
layers_sorted = sorted([l for l in layers if l is not None], key=lambda d: d["ts"])
for layer in layers_sorted:
ov = layer["overlay"].astype(np.float32)
a = (layer["alpha"].astype(np.float32) / 255.0)[..., None] # HxWx1
if ov.shape[:2] != result.shape[:2]:
ov = cv2.resize(ov, (result.shape[1], result.shape[0]), interpolation=cv2.INTER_LINEAR)
a = cv2.resize(a, (result.shape[1], result.shape[0]), interpolation=cv2.INTER_LINEAR)[..., None]
result = ov * a + result * (1.0 - a)
return np.clip(result, 0, 255).astype(np.uint8)
def on_image_upload(image, state):
    """Reset overlays and counters whenever a new image is loaded.

    Returns (display_image, fresh_state, det_counts, seg_counts); a None
    upload clears everything.
    """
    fresh = {"base": None, "det": None, "seg": None, "det_counts": "", "seg_counts": ""}
    if image is None:
        return None, fresh, "", ""
    rgb = _ensure_rgb(image)
    fresh["base"] = rgb
    return rgb, fresh, "", ""
def clear_overlays(image, state):
    """Drop both overlay layers and counters, keeping the base image.

    Returns (base_image, state, "", ""); an empty state is reset entirely.
    """
    if state is None or state.get("base") is None:
        blank = {"base": None, "det": None, "seg": None, "det_counts": "", "seg_counts": ""}
        return None, blank, "", ""
    state.update(det=None, seg=None, det_counts="", seg_counts="")
    return state["base"], state, "", ""
#endregion
#region Maintenance
def _dir_size(path: str) -> int:
try:
total = 0
for root, _, files in os.walk(path):
for f in files:
fp = os.path.join(root, f)
try:
total += os.path.getsize(fp)
except Exception:
pass
return total
except Exception:
return 0
def _fmt_bytes(n: int) -> str:
for unit in ["B", "KB", "MB", "GB", "TB"]:
if n < 1024.0:
return f"{n:.1f} {unit}"
n /= 1024.0
return f"{n:.1f} PB"
def check_storage():
    """Report sizes of known cache directories plus overall disk usage.

    Covers the default home-directory caches and every /tmp location this
    module redirects to via environment variables at import time
    (HF_HUB_CACHE, TRANSFORMERS_CACHE, TORCH_HOME, YOLO_CONFIG_DIR) —
    the original list missed the transformers and Ultralytics redirects.

    Returns a multi-line human-readable report string.
    """
    # Key cache locations
    paths = [
        os.path.expanduser("~/.cache/huggingface/hub"),
        os.path.expanduser("~/.cache/torch"),
        os.path.expanduser("~/.cache/pip"),
        os.path.expanduser("~/.config/Ultralytics"),
        "/tmp/hf_home/hub",
        "/tmp/hf_home/transformers",  # TRANSFORMERS_CACHE redirect
        "/tmp/torch_home",
        "/tmp/Ultralytics",  # YOLO_CONFIG_DIR redirect
    ]
    lines = []
    for p in paths:
        sz = _dir_size(p) if os.path.exists(p) else 0
        lines.append(f"{p}: {_fmt_bytes(sz)}")
    try:
        total, used, free = shutil.disk_usage("/")
        disk_line = f"Disk usage: used {_fmt_bytes(used)} / total {_fmt_bytes(total)} (free {_fmt_bytes(free)})"
    except Exception:
        disk_line = "Disk usage: n/a"
    return "Cache sizes:\n" + "\n".join(lines) + "\n" + disk_line
def clean_caches():
    """Delete known cache directories to reclaim disk space.

    Removes the home-directory caches and the /tmp locations this module
    redirects to at import time (now including /tmp/Ultralytics, the
    YOLO_CONFIG_DIR redirect the original list missed). Directories are
    recreated on demand by the libraries, and in-memory model caches are
    untouched, so already-loaded models keep working.

    Returns a human-readable list of the paths that were removed.
    """
    paths = [
        os.path.expanduser("~/.cache/huggingface/hub"),
        os.path.expanduser("~/.cache/torch"),
        os.path.expanduser("~/.cache/pip"),
        os.path.expanduser("~/.config/Ultralytics"),
        "/tmp/hf_home",
        "/tmp/torch_home",
        "/tmp/Ultralytics",  # YOLO_CONFIG_DIR redirect
    ]
    removed = []
    for p in paths:
        try:
            if os.path.exists(p):
                shutil.rmtree(p, ignore_errors=True)
                removed.append(p)
        except Exception:
            pass
    return "Removed:\n" + ("\n".join(removed) if removed else "(none)")
#endregion
def build_app():
    """Assemble and return the Gradio Blocks UI.

    Left column: input image, one tab of parameters per model, overlay
    clearing, and disk-maintenance tools. Right column: the combined
    overlay image and per-model count readouts. A gr.State dict shared
    by all callbacks accumulates the base image and both overlay layers.
    """
    with gr.Blocks(title="Berries detection & bunches segmentation") as demo:
        gr.Markdown(
            "## Double inference on the same image with combined overlays\n"
            "- Model A: berries detection\n"
            "- Model B: bunches segmentation\n"
            "- Run individually; overlays combine on the same image.\n"
        )
        # Per-session state: base image plus one overlay layer + counts per model.
        state = gr.State({"base": None, "det": None, "seg": None, "det_counts": "", "seg_counts": ""})
        with gr.Row():
            with gr.Column(scale=1):
                img_in = gr.Image(label="Image", type="numpy")
                with gr.Tab("Model A — Berries Detection"):
                    with gr.Row():
                        conf_det = gr.Slider(0.0, 1.0, value=0.35, step=0.01, label="Confidence (A)")
                        device_a = gr.Dropdown(["auto", "cuda:0", "cpu"], value="auto", label="Device")
                    with gr.Row():
                        slice_h = gr.Slider(64, 2048, value=640, step=32, label="Slice H (A)")
                        slice_w = gr.Slider(64, 2048, value=640, step=32, label="Slice W (A)")
                    with gr.Row():
                        overlap_h = gr.Slider(0.0, 0.9, value=0.10, step=0.01, label="Overlap H (A)")
                        overlap_w = gr.Slider(0.0, 0.9, value=0.10, step=0.01, label="Overlap W (A)")
                    btn_det = gr.Button("Run berries detection")
                with gr.Tab("Model B — Bunches Segmentation"):
                    with gr.Row():
                        conf_seg = gr.Slider(0.0, 1.0, value=0.35, step=0.01, label="Confidence (B)")
                        seg_alpha = gr.Slider(0.0, 1.0, value=SEG_DEFAULT_ALPHA, step=0.05, label="Alpha masks (B)")
                        device_b = gr.Dropdown(["auto", "cuda:0", "cpu"], value="auto", label="Device")
                    btn_seg = gr.Button("Run bunches segmentation")
                with gr.Row():
                    btn_clear = gr.Button("Clean overlay", variant="secondary")
                with gr.Accordion("Disk Maintenance", open=False):
                    btn_check = gr.Button("Check storage")
                    btn_clean = gr.Button("Clean cache")
                    maint_out = gr.Textbox(label="Log Maintenance", interactive=False)
            with gr.Column(scale=2):
                img_out = gr.Image(label="Combined Result", type="numpy")
                with gr.Row():
                    counts_out_det = gr.Textbox(label="Counts - Berries", interactive=False)
                    counts_out_seg = gr.Textbox(label="Counts - Bunches", interactive=False)
        # Wiring
        # A new upload resets the state and both count boxes.
        img_in.change(
            on_image_upload,
            inputs=[img_in, state],
            outputs=[img_out, state, counts_out_det, counts_out_seg],
        )
        # Model A: berries detection (SAHI sliced inference).
        btn_det.click(
            run_det,
            inputs=[
                img_in, state,
                conf_det, slice_h, slice_w, overlap_h, overlap_w, device_a
            ],
            outputs=[img_out, state, counts_out_det, counts_out_seg],
        )
        # Model B: bunches segmentation (Ultralytics predict).
        btn_seg.click(
            run_seg,
            inputs=[
                img_in, state,
                conf_seg, device_b, seg_alpha
            ],
            outputs=[img_out, state, counts_out_det, counts_out_seg],
        )
        # Remove overlays but keep the loaded base image.
        btn_clear.click(
            clear_overlays,
            inputs=[img_in, state],
            outputs=[img_out, state, counts_out_det, counts_out_seg],
        )
        btn_check.click(
            check_storage,
            inputs=[],
            outputs=[maint_out],
        )
        btn_clean.click(
            clean_caches,
            inputs=[],
            outputs=[maint_out],
        )
    return demo
if __name__ == "__main__":
    # Build the Gradio UI and start the local server. The module-level
    # `demo` binding is kept (Gradio/Spaces conventions can rely on it).
    demo = build_app()
    demo.launch()