Spaces:
Running on Zero
Running on Zero
Apiarist Dev
feat: ranked top-3 queen candidates with probability shown - honest AI, user verifies
7119114 | """ | |
| YOLOv8 specialist detector for bees / drones / queens / pollen bees. | |
| Loads the custom-trained weights from weights/honey_bee_detector.pt if | |
| present. Degrades gracefully (empty detections, no annotated image) if the | |
| weights file is missing, useful while training is still in progress. | |
| """ | |
| from __future__ import annotations | |
| import math | |
| import os | |
| import sys | |
| from pathlib import Path | |
| from typing import Optional | |
| from PIL import Image, ImageDraw, ImageFont | |
| # Inference size is chosen per image by _pick_imgsz() rather than fixed at the | |
| # YOLO default (640). Bees are tiny relative to a full-frame photo; at 640 the | |
| # queen's distinguishing abdomen length collapses to a couple of pixels, so | |
| # larger photos step up to 1280 or 1920 (lifting small-bee recall) at some CPU cost. | |
def _pick_imgsz(img):
    """Pick the YOLO inference size that suits the input's resolution.

    The model was trained at 640 and degrades sharply when a small image
    (e.g. a 225x224 thumbnail) is upscaled to 1920, while large phone
    photos (3000+ px) need more than 640 for individual bees to stay
    detectable. So: 640 by default, stepping up only for genuinely large
    photos.

    ``img`` must expose a ``size`` attribute holding its dimensions; the
    largest of them selects the tier.
    """
    # (inclusive upper bound on the longest side, inference size)
    tiers = ((800, 640), (1600, 1280))
    longest = max(img.size)
    for limit, imgsz in tiers:
        if longest <= limit:
            return imgsz
    return 1920
# Locate the weights next to this module's actual on-disk location rather
# than relative to the process working directory.
_HERE = Path(os.path.abspath(__file__)).parent
WEIGHTS_PATH = _HERE.joinpath("weights", "honey_bee_detector.pt")

# Lazy-load state shared by the loader and status helpers in this module.
_yolo = None  # the loaded model, once a load has succeeded
_yolo_failed = False  # set when a load attempt raised
_logged_status = False  # set once the one-time diagnostics were printed
def _log_status_once():
    """Print a one-time diagnostic report about the weights file to stderr.

    Reports the module directory, the expected weights path, whether the
    file exists and how large it is, the head of the file when it is under
    1 KB (almost certainly an LFS pointer text file), and the contents of
    the weights directory. Calls after the first return immediately.

    This is best-effort logging: filesystem errors hit while gathering the
    report (the file vanishing between checks, an unreadable directory, a
    ``weights`` path that is not a directory) are logged instead of
    raised, so diagnostics can never abort the caller.
    """
    global _logged_status
    if _logged_status:
        return
    # Flip the flag first so a failure below can never trigger a re-print.
    _logged_status = True

    def log(message):
        print(message, file=sys.stderr)

    banner = "=" * 60
    log(banner)
    log(f"[detector] module dir: {_HERE}")
    log(f"[detector] weights path: {WEIGHTS_PATH}")
    try:
        exists = WEIGHTS_PATH.exists()
        log(f"[detector] weights exist: {exists}")
        if exists:
            size = WEIGHTS_PATH.stat().st_size
            log(
                f"[detector] weights size: {size} bytes "
                f"({size / 1024 / 1024:.1f} MB)"
            )
            # If file is < 1 KB, it's almost certainly an LFS pointer text
            if size < 1024:
                try:
                    head = WEIGHTS_PATH.read_text()[:200]
                except (OSError, UnicodeError) as e:
                    # Tiny but not readable as text: say so rather than
                    # silently dropping the hint.
                    log(
                        "[detector] suspiciously small, head unreadable: "
                        f"{type(e).__name__}: {e}"
                    )
                else:
                    log(f"[detector] suspiciously small, head: {head}")
        # List contents of weights dir
        weights_dir = WEIGHTS_PATH.parent
        if weights_dir.exists():
            log(
                f"[detector] weights/ contents: "
                f"{list(weights_dir.iterdir())}"
            )
        else:
            log("[detector] weights/ dir does not exist")
    except OSError as e:
        log(f"[detector] status check failed: {type(e).__name__}: {e}")
    log(banner)
def _try_load_yolo():
    """Return the YOLO model, loading it lazily on first successful use.

    Until a load succeeds, every call re-checks the weights file, so a
    file that appears after startup is still picked up. Once loaded, the
    model is cached in the module-level ``_yolo``.

    Returns:
        The loaded model, or ``None`` when the weights file is missing or
        unreadable, is too small to be real weights, or failed to load.
        ``_yolo_failed`` is set to True when the load itself raised and
        back to False when a load succeeds.
    """
    global _yolo, _yolo_failed
    _log_status_once()
    if _yolo is not None:
        return _yolo
    # A single stat() answers both "does it exist" and "how big is it",
    # and cannot race with the file disappearing between separate checks.
    try:
        size = WEIGHTS_PATH.stat().st_size
    except OSError:
        return None
    # Sanity-check: real .pt weights are typically tens of MB; an LFS
    # pointer text file is <200 bytes.
    if size < 1024:
        print(
            f"[detector] weights file is suspiciously small "
            f"({size} bytes); likely an unresolved "
            "LFS pointer. Skipping load.",
            file=sys.stderr,
        )
        return None
    try:
        from ultralytics import YOLO
        model = YOLO(str(WEIGHTS_PATH))
        # Force CPU, we run on the main container, not the ZeroGPU worker.
        # ZeroGPU's CUDA emulation rejects torch.cuda calls outside @gpu.
        model.to("cpu")
    except Exception as e:
        print(f"[detector] YOLO load failed: {type(e).__name__}: {e}",
              file=sys.stderr)
        _yolo_failed = True
        return None
    # Publish only a fully initialised model, so a failure above can never
    # leave a half-configured one cached while _yolo_failed says it failed.
    _yolo = model
    _yolo_failed = False
    print(f"[detector] loaded YOLO weights from {WEIGHTS_PATH} (cpu)",
          file=sys.stderr)
    return _yolo
def is_available() -> bool:
    """Report whether usable-looking weights are present.

    True only when the weights file exists, is at least 1 KB in size
    (i.e. looks like a real binary), and no load attempt has failed.
    """
    if not WEIGHTS_PATH.exists():
        return False
    if WEIGHTS_PATH.stat().st_size < 1024:
        return False
    return not _yolo_failed
| # Canonicalize class names so the app code works with one set of labels | |
| # regardless of which training dataset shipped these weights. The new | |
| # hendricks_ricky dataset uses "Queen Bee" / "Drone Bee" / "Worker Bee" | |
| # / "Varroa Mite" labels, the older Matt Nudi set used "queen" / "drone" | |
| # / "bee" / "pollenbee" - this handles both. | |
def _canonicalize(name: str) -> str:
    """Map a dataset-specific class name onto the app's canonical label.

    Matching is by case-insensitive substring, so both "Queen Bee" and
    "queen" become "queen". Names that match no rule are returned
    lower-cased and stripped.
    """
    lowered = name.lower().strip()
    # Ordered: the first rule with a matching substring wins, so the
    # generic "bee" must come after every more specific label.
    rules = (
        (("queen",), "queen"),
        (("varroa", "mite"), "varroa"),
        (("drone",), "drone"),
        (("pollen",), "pollenbee"),
        (("worker", "bee"), "bee"),
    )
    for needles, label in rules:
        if any(needle in lowered for needle in needles):
            return label
    return lowered
# Minimum confidence a detection must reach to be kept, per canonical class.
_PER_CLASS_CONF = dict(
    bee=0.04,  # very permissive - the classifier downstream filters out the noise
    drone=0.30,
    varroa=0.15,
    pollenbee=0.25,
    queen=0.04,
)
def detect(
    image: Image.Image, conf: float = 0.10
) -> tuple[list[dict], Optional[Image.Image]]:
    """Detect bees in ``image`` and return them with an annotated copy.

    The model is run with ``conf`` as its confidence cutoff, then each box
    is kept only if it also reaches the threshold for its canonical class
    in ``_PER_CLASS_CONF`` (0.25 for classes not listed there).

    Returns:
        ``(detections, annotated_image)``. Each detection is a dict
        ``{class, confidence, bbox: [x1, y1, x2, y2]}``. When YOLO is
        unavailable or inference raises, returns ``([], None)``; when the
        model yields no results, returns ``([], image)`` with the input
        image itself.

    NOTE(review): ``conf`` is presumably applied by the model before the
    per-class filter runs, in which case per-class thresholds below it
    (0.04 against the default 0.10) only take effect when the caller
    passes a lower ``conf`` - confirm against callers.
    """
    model = _try_load_yolo()
    if model is None:
        return [], None
    try:
        # Cast a wide net at the model level, then filter per-class below.
        results = model(
            image,
            conf=conf,
            imgsz=_pick_imgsz(image),
            verbose=False,
            device="cpu",
        )
        if not results:
            return [], image
        first = results[0]
        names = first.names
        kept: list[dict] = []
        for box in first.boxes:
            label = _canonicalize(names[int(box.cls.item())])
            score = float(box.conf.item())
            if score < _PER_CLASS_CONF.get(label, 0.25):
                continue
            corners = box.xyxy[0].cpu().numpy().tolist()
            kept.append({
                "class": label,
                "confidence": round(score, 3),
                "bbox": [round(coord, 1) for coord in corners],
            })
        return kept, _draw_annotations(image, kept)
    except Exception as e:
        print(f"[detector] inference failed: {e}", file=sys.stderr)
        return [], None
# Drawing style per canonical class: outline color as an RGB tuple (the
# alpha component is appended when drawing), outline width, and an optional
# label. Tuned for clarity over a busy honeycomb background.
_CLASS_STYLES = {
    cls: {"color": rgb, "width": width, "label": label}
    for cls, rgb, width, label in (
        ("bee", (244, 163, 0), 1, None),
        ("drone", (255, 80, 80), 1, None),
        ("pollenbee", (255, 220, 80), 1, None),
        ("varroa", (220, 50, 220), 2, "mite"),
        ("queen", (50, 255, 100), 4, "QUEEN"),
    )
}
def _font(size: int = 14):
    """Return a TrueType font at ``size``, or PIL's default font.

    Tries DejaVu Sans (bold, then regular) at a fixed Linux path, then
    DejaVu Sans Bold and Arial by bare file name. The first one that
    loads wins; if none does, ``ImageFont.load_default()`` is returned.
    """
    search_order = (
        "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
        "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
        "DejaVuSans-Bold.ttf",
        "arial.ttf",
    )
    for font_path in search_order:
        try:
            loaded = ImageFont.truetype(font_path, size)
        except Exception:
            continue
        return loaded
    return ImageFont.load_default()
# Cyan - distinct from the green of a confirmed queen.
_CANDIDATE_COLOR = (90, 220, 255)
def _dashed_rectangle(draw, box, color, width=3, dash=12, gap=8):
    """Outline ``box`` with dashes (PIL has no native dashed outline).

    Each side is walked from its start corner in steps of ``dash + gap``,
    drawing one segment of up to ``dash`` at every step via ``draw.line``;
    a segment that would overshoot is clipped at the side's end corner.
    Sides of zero length are skipped.
    """
    left, top, right, bottom = box
    period = dash + gap
    # The four sides in drawing order: top, right, bottom, left.
    sides = (
        ((left, top), (right, top)),
        ((right, top), (right, bottom)),
        ((right, bottom), (left, bottom)),
        ((left, bottom), (left, top)),
    )
    for (sx, sy), (ex, ey) in sides:
        side_len = math.hypot(ex - sx, ey - sy)
        if side_len == 0:
            continue
        dx, dy = ex - sx, ey - sy
        for step in range(int(side_len // period) + 1):
            # start/stop are fractions of the side's length in [0, 1].
            start = (step * period) / side_len
            stop = min(start + dash / side_len, 1.0)
            draw.line(
                [sx + dx * start, sy + dy * start,
                 sx + dx * stop, sy + dy * stop],
                fill=color, width=width,
            )
def draw_annotations(image: Image.Image, detections: list[dict]) -> Image.Image:
    """Return ``image`` annotated with ``detections``.

    Public wrapper around ``_draw_annotations`` so the cascade can
    re-annotate after re-classification.
    """
    annotated = _draw_annotations(image, detections)
    return annotated
def _draw_annotations(image: Image.Image, detections: list[dict]) -> Image.Image:
    """Draw detection boxes on an RGB copy of ``image`` and return the copy.

    Every detection gets a thin, color-coded outline taken from
    ``_CLASS_STYLES`` (white, width 1, for unknown classes). Only the
    queen also gets a translucent fill and a "QUEEN" text chip, so the
    frame stays readable. Detections flagged ``queen_candidate`` get a
    dashed cyan box and a "queen? NN%" chip drawn on top of everything.

    Args:
        image: Source image; it is not modified.
        detections: Dicts with ``class`` and ``bbox`` ([x1, y1, x2, y2]),
            optionally ``queen_candidate`` plus ``queen_prob`` or
            ``queen_standout``.

    Returns:
        The annotated RGB copy.
    """
    out = image.convert("RGB").copy()
    # RGBA drawing mode, so the partially transparent fills below can be
    # expressed as 4-tuples on top of the RGB image.
    draw = ImageDraw.Draw(out, "RGBA")
    font_queen = _font(16)
    fallback_style = {"color": (255, 255, 255), "width": 1, "label": None}

    # Draw common classes first so the queen stacks on top.
    sort_order = {"bee": 0, "pollenbee": 1, "drone": 2, "varroa": 3, "queen": 4}
    ordered = sorted(detections, key=lambda d: sort_order.get(d["class"], 0))
    for det in ordered:
        cls = det["class"]
        style = _CLASS_STYLES.get(cls, fallback_style)
        x1, y1, x2, y2 = det["bbox"]
        color = style["color"]
        draw.rectangle(
            [x1, y1, x2, y2], outline=color + (255,), width=style["width"]
        )
        if cls == "queen":
            # Glowing fill + chip label for the queen.
            draw.rectangle([x1, y1, x2, y2], fill=color + (55,))
            label = "QUEEN"
            tw = draw.textlength(label, font=font_queen)
            th = font_queen.size + 4
            bx1, by1 = x1, max(0, y1 - th - 2)
            bx2 = x1 + tw + 10
            # Normally the chip sits directly above the box (bottom = y1).
            # When the box touches the top edge, by1 is clamped to 0; keep
            # the chip at least text-high instead of letting it collapse
            # behind the text.
            by2 = max(y1, by1 + th)
            draw.rectangle([bx1, by1, bx2, by2], fill=color + (235,))
            draw.text((bx1 + 5, by1 + 1), label, fill=(20, 16, 8), font=font_queen)

    def candidate_prob(det):
        # One definition of "probability" for both ordering and labelling:
        # queen_prob when set, else queen_standout, else 0. Missing keys
        # and None values are treated as absent.
        return det.get("queen_prob") or det.get("queen_standout") or 0

    # Queen candidates: every bee tagged by the cascade with
    # queen_candidate=True gets a dashed cyan box and a probability
    # label. Drawn last so they sit on top of the worker bee outlines.
    font_cand = _font(13)
    # Ascending by probability, so the most likely candidate is drawn
    # last and therefore ends up on top.
    candidates = sorted(
        (d for d in detections if d.get("queen_candidate")),
        key=candidate_prob,
    )
    for cand in candidates:
        cx1, cy1, cx2, cy2 = cand["bbox"]
        _dashed_rectangle(
            draw, [cx1, cy1, cx2, cy2], _CANDIDATE_COLOR + (255,), width=3
        )
        label = f"queen? {int(candidate_prob(cand) * 100)}%"
        tw = draw.textlength(label, font=font_cand)
        th = font_cand.size + 3
        bx1, by1 = cx1, max(0, cy1 - th - 2)
        draw.rectangle(
            [bx1, by1, bx1 + tw + 10, by1 + th],
            fill=_CANDIDATE_COLOR + (235,),
        )
        draw.text(
            (bx1 + 5, by1 + 1), label, fill=(8, 16, 20), font=font_cand
        )
    return out
def summarize_counts(detections: list[dict]) -> dict[str, int]:
    """Count how many detections fall into each class.

    Returns a dict mapping every ``"class"`` value seen to its number of
    occurrences, keyed in order of first appearance.
    """
    tally: dict[str, int] = {}
    for det in detections:
        label = det["class"]
        tally[label] = tally.setdefault(label, 0) + 1
    return tally