Apiarist / detector.py
Apiarist Dev
feat: ranked top-3 queen candidates with probability shown - honest AI, user verifies
7119114
Raw
History Blame Contribute Delete
11.4 kB
"""
YOLOv8 specialist detector for bees / drones / queens / pollen bees.
Loads the custom-trained weights from weights/honey_bee_detector.pt if
present. Degrades gracefully (no detections and no annotated image) if the
weights file is missing or fails to load, useful while training is still in
progress.
"""
from __future__ import annotations
import math
import os
import sys
from pathlib import Path
from typing import Optional
from PIL import Image, ImageDraw, ImageFont
# Inference size is chosen per image by _pick_imgsz rather than fixed at the
# YOLO default (640). Bees are tiny relative to a full-frame photo; on large
# photos, running at 640 collapses the queen's distinguishing abdomen length
# to a couple of pixels, so those are run at 1280 or 1920 to keep that detail
# (and lift small-bee recall generally) at a modest CPU cost.
def _pick_imgsz(img):
"""Choose inference image size based on input.
The YOLO was trained at 640 and degrades sharply when you upscale a
small image (e.g. a 225x224 thumbnail) to 1920. But large phone
photos (3000+ px) need more than 640 to keep individual bees
detectable. Strategy: stay at 640 unless the photo is genuinely
large, then step up.
"""
long_side = max(img.size)
if long_side <= 800:
return 640
if long_side <= 1600:
return 1280
return 1920
# Locate the weights next to this module's own file (absolute path), so the
# lookup does not depend on the process working directory.
_HERE = Path(os.path.abspath(__file__)).parent
WEIGHTS_PATH = _HERE / "weights" / "honey_bee_detector.pt"

# Loader state, mutated by _try_load_yolo / _log_status_once.
_yolo = None  # cached YOLO model once a load has succeeded
_yolo_failed = False  # True when the most recent load attempt raised
_logged_status = False  # guards the one-time diagnostics banner
def _log_status_once():
    """Print a one-time diagnostics banner about the weights file to stderr.

    Reports the module directory, the weights path, whether it exists and
    its size. For files under 1 KB it also prints the first 200 characters,
    because a file that small is almost certainly an unresolved Git LFS
    pointer rather than real weights. Finally it lists the weights/
    directory. Only the first call prints anything.

    Diagnostics are best-effort: filesystem errors are reported in the
    banner instead of propagating, so logging can never break model loading.
    """
    global _logged_status
    if _logged_status:
        return
    _logged_status = True
    print("=" * 60, file=sys.stderr)
    print(f"[detector] module dir: {_HERE}", file=sys.stderr)
    print(f"[detector] weights path: {WEIGHTS_PATH}", file=sys.stderr)
    exists = WEIGHTS_PATH.exists()
    print(f"[detector] weights exist: {exists}", file=sys.stderr)
    if exists:
        try:
            size = WEIGHTS_PATH.stat().st_size
        except OSError as e:
            print(
                f"[detector] could not stat weights: {type(e).__name__}: {e}",
                file=sys.stderr,
            )
        else:
            print(
                f"[detector] weights size: {size} bytes "
                f"({size / 1024 / 1024:.1f} MB)",
                file=sys.stderr,
            )
            # If file is < 1 KB, it's almost certainly an LFS pointer text.
            if size < 1024:
                try:
                    # errors="replace" so a binary head cannot fail decoding.
                    head = WEIGHTS_PATH.read_text(
                        encoding="utf-8", errors="replace"
                    )[:200]
                except OSError:
                    pass  # Diagnostics only; never let them break loading.
                else:
                    print(f"[detector] suspiciously small, head: {head}",
                          file=sys.stderr)
    # List contents of weights dir. is_dir() (not exists()) so a stray
    # *file* named "weights" cannot make iterdir() raise.
    weights_dir = WEIGHTS_PATH.parent
    if weights_dir.is_dir():
        try:
            entries = list(weights_dir.iterdir())
        except OSError as e:
            print(
                f"[detector] could not list weights/: {type(e).__name__}: {e}",
                file=sys.stderr,
            )
        else:
            print(f"[detector] weights/ contents: {entries}", file=sys.stderr)
    elif weights_dir.exists():
        print("[detector] weights/ exists but is not a directory",
              file=sys.stderr)
    else:
        print("[detector] weights/ dir does not exist", file=sys.stderr)
    print("=" * 60, file=sys.stderr)
def _try_load_yolo():
    """Return the cached YOLO model, loading it if it is not loaded yet.

    Loading is re-attempted on every call until it succeeds, so a weights
    file that appears after startup is still picked up. Returns None when:

    * the weights file is missing or cannot be stat'ed,
    * it is under 1 KB (most likely an unresolved Git LFS pointer), or
    * loading raises, in which case ``_yolo_failed`` is set to True.

    The model is cached in ``_yolo`` only after it has been fully set up
    (including the move to CPU), so a partial load is never reused.
    """
    global _yolo, _yolo_failed
    _log_status_once()
    if _yolo is not None:
        return _yolo
    try:
        size = WEIGHTS_PATH.stat().st_size
    except OSError:
        # Missing (or unreachable) weights: nothing to load yet.
        return None
    # Sanity-check: real .pt weights are typically tens of MB; an LFS
    # pointer text file is <200 bytes.
    if size < 1024:
        print(
            f"[detector] weights file is suspiciously small "
            f"({size} bytes); likely an unresolved "
            "LFS pointer. Skipping load.",
            file=sys.stderr,
        )
        return None
    try:
        from ultralytics import YOLO
        model = YOLO(str(WEIGHTS_PATH))
        # Force CPU, we run on the main container, not the ZeroGPU worker.
        # ZeroGPU's CUDA emulation rejects torch.cuda calls outside @gpu.
        model.to("cpu")
    except Exception as e:
        print(f"[detector] YOLO load failed: {type(e).__name__}: {e}",
              file=sys.stderr)
        _yolo_failed = True
        return None
    _yolo = model
    _yolo_failed = False
    print(f"[detector] loaded YOLO weights from {WEIGHTS_PATH} (cpu)",
          file=sys.stderr)
    return _yolo
def is_available() -> bool:
    """Return True if the weights look usable and loading has not failed.

    "Usable" means the weights file exists and is at least 1 KB; smaller
    files are treated as unresolved Git LFS pointers. The model itself is
    not loaded here; ``_yolo_failed`` only reflects the most recent load
    attempt, if one has been made. Never raises: a file that cannot be
    stat'ed (e.g. removed concurrently) counts as unavailable.
    """
    if _yolo_failed:
        return False
    try:
        return WEIGHTS_PATH.stat().st_size >= 1024
    except OSError:
        return False
# Canonicalize class names so the app code works with one set of labels
# regardless of which training dataset shipped these weights. The new
# hendricks_ricky dataset uses "Queen Bee" / "Drone Bee" / "Worker Bee"
# / "Varroa Mite" labels, the older Matt Nudi set used "queen" / "drone"
# / "bee" / "pollenbee" - this handles both.
def _canonicalize(name: str) -> str:
n = name.lower().strip()
if "queen" in n:
return "queen"
if "varroa" in n or "mite" in n:
return "varroa"
if "drone" in n:
return "drone"
if "pollen" in n:
return "pollenbee"
if "worker" in n or "bee" in n:
return "bee"
return n
# Minimum confidence kept per canonical class, applied in detect() after the
# model-level `conf` cut; classes not listed here fall back to 0.25 there.
# NOTE(review): detect() defaults to conf=0.10, so the 0.04 entries below only
# take effect when a caller passes a lower conf — confirm that is intended.
_PER_CLASS_CONF = dict(
    bee=0.04,  # very permissive - the classifier downstream filters out the noise
    drone=0.30,
    varroa=0.15,
    pollenbee=0.25,
    queen=0.04,
)
def _collect_detections(result) -> list[dict]:
    """Turn one YOLO result into canonicalised, per-class-filtered detections.

    Each box's class id is looked up in ``result.names``, mapped through
    ``_canonicalize``, and the box is dropped if its confidence is below that
    class's threshold in ``_PER_CLASS_CONF`` (0.25 for unlisted classes).
    """
    class_names = result.names
    detections: list[dict] = []
    for box in result.boxes:
        cls_name = _canonicalize(class_names[int(box.cls.item())])
        confidence = float(box.conf.item())
        if confidence < _PER_CLASS_CONF.get(cls_name, 0.25):
            continue
        xyxy = box.xyxy[0].cpu().numpy().tolist()
        detections.append({
            "class": cls_name,
            "confidence": round(confidence, 3),
            "bbox": [round(v, 1) for v in xyxy],
        })
    return detections


def detect(
    image: Image.Image, conf: float = 0.10
) -> tuple[list[dict], Optional[Image.Image]]:
    """Run YOLO on *image*, then apply per-class confidence thresholds.

    The model runs on CPU with ``conf`` as its own confidence cut and an
    input size chosen by ``_pick_imgsz``; surviving boxes are then filtered
    per class by ``_PER_CLASS_CONF``.

    NOTE(review): a per-class threshold below ``conf`` has no effect, since
    the model never returns those boxes (bee/queen use 0.04 while the default
    ``conf`` is 0.10). Confirm whether callers should pass a lower ``conf``.

    Args:
        image: PIL image to analyse.
        conf: model-level confidence cut passed to YOLO.

    Returns:
        ``(detections, annotated_image)``. Each detection is a dict
        ``{"class", "confidence", "bbox": [x1, y1, x2, y2]}`` with the
        confidence rounded to 3 decimals and coordinates to 1. The image is:

        * ``None`` if YOLO is unavailable or inference fails (detections
          are then ``[]``);
        * the unmodified input ``image`` if the model returns no results;
        * ``None`` if only the annotation step fails (the detections are
          still returned);
        * otherwise an annotated RGB copy of ``image``.
    """
    yolo = _try_load_yolo()
    if yolo is None:
        return [], None
    try:
        # Cast a wide net at the model level, then filter per-class below.
        results = yolo(
            image, conf=conf, imgsz=_pick_imgsz(image), verbose=False,
            device="cpu",
        )
        if not results:
            return [], image
        detections = _collect_detections(results[0])
    except Exception as e:
        print(f"[detector] inference failed: {type(e).__name__}: {e}",
              file=sys.stderr)
        return [], None
    # Drawing is guarded separately so a rendering problem (e.g. a font
    # without a .size attribute) cannot throw away valid detections.
    try:
        annotated_pil = _draw_annotations(image, detections)
    except Exception as e:
        print(f"[detector] annotation failed: {type(e).__name__}: {e}",
              file=sys.stderr)
        return detections, None
    return detections, annotated_pil
# Per-class drawing style: outline colour as an RGB tuple (the alpha channel
# is appended when drawing), outline width in pixels, and a text label.
# Tuned for clarity over a busy honeycomb background.
# NOTE(review): "label" is not read by _draw_annotations in this file (the
# queen's "QUEEN" text is hard-coded there); confirm whether anything uses it.
_CLASS_STYLES = dict(
    bee=dict(color=(244, 163, 0), width=1, label=None),
    drone=dict(color=(255, 80, 80), width=1, label=None),
    pollenbee=dict(color=(255, 220, 80), width=1, label=None),
    varroa=dict(color=(220, 50, 220), width=2, label="mite"),
    queen=dict(color=(50, 255, 100), width=4, label="QUEEN"),
)
def _font(size: int = 14):
    """Return a font for label text at roughly *size* pixels.

    Tries DejaVu Sans Bold and DejaVu Sans from the system font directory,
    then bare font filenames ("DejaVuSans-Bold.ttf", "arial.ttf"). If none
    can be opened, falls back to Pillow's built-in default font. The size is
    requested from the default font too (``load_default(size=...)``, Pillow
    10.1+). Otherwise, on older Pillow or without FreeType, the result is a
    fixed-size bitmap font that ignores *size* and may have no ``.size``
    attribute.
    """
    candidates = (
        "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
        "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
        "DejaVuSans-Bold.ttf",
        "arial.ttf",
    )
    for path in candidates:
        try:
            return ImageFont.truetype(path, size)
        except Exception:
            # Missing file or no FreeType support: try the next option.
            continue
    try:
        return ImageFont.load_default(size=size)
    except Exception:
        # Pillow < 10.1 has no size argument (TypeError); without FreeType a
        # sized default font cannot be built either.
        return ImageFont.load_default()
# Outline/chip colour for queen candidates: cyan, chosen to stay distinct
# from the green used for the "queen" class.
_CANDIDATE_COLOR = (90, 220, 255)
def _dashed_rectangle(draw, box, color, width=3, dash=12, gap=8):
"""Draw a dashed rectangle (PIL has no native dashed outline)."""
x1, y1, x2, y2 = box
corners = [(x1, y1), (x2, y1), (x2, y2), (x1, y2), (x1, y1)]
for (ax, ay), (bx, by) in zip(corners, corners[1:]):
length = math.hypot(bx - ax, by - ay)
if length == 0:
continue
n = int(length // (dash + gap)) + 1
for s in range(n):
t0 = (s * (dash + gap)) / length
t1 = min(t0 + dash / length, 1.0)
draw.line(
[ax + (bx - ax) * t0, ay + (by - ay) * t0,
ax + (bx - ax) * t1, ay + (by - ay) * t1],
fill=color, width=width,
)
def draw_annotations(image: Image.Image, detections: list[dict]) -> Image.Image:
    """Return an annotated copy of *image* for the given *detections*.

    Public entry point so the cascade can re-annotate after
    re-classification; all drawing is delegated to ``_draw_annotations``.
    """
    annotated = _draw_annotations(image, detections)
    return annotated
def _font_height(draw, font, text):
    """Return the pixel height used to size a label chip drawn in *font*.

    FreeType fonts expose ``.size``, which is used unchanged. Pillow's bitmap
    fallback font has no ``.size`` attribute, so the height of *text*'s
    bounding box is measured instead of raising AttributeError.
    """
    size = getattr(font, "size", None)
    if size is not None:
        return size
    _, top, _, bottom = draw.textbbox((0, 0), text, font=font)
    return bottom - top


def _candidate_prob(det):
    """Return the probability displayed for a queen-candidate detection.

    Uses ``queen_prob`` when present and truthy, else ``queen_standout``,
    else 0. The same value drives both the label and the draw order.
    NOTE(review): queen_standout is rendered as a percentage as well;
    confirm it is on a 0-1 scale.
    """
    return det.get("queen_prob") or det.get("queen_standout") or 0


def _draw_annotations(image: Image.Image, detections: list[dict]) -> Image.Image:
    """Return an RGB copy of *image* with *detections* drawn on it.

    Every detection gets a thin outline in its class style from
    ``_CLASS_STYLES`` (white, width 1, for unknown classes). Detections of
    class "queen" also get a translucent fill and a "QUEEN" chip label; no
    other class is labelled, so the frame stays readable. Detections with a
    truthy ``queen_candidate`` are then overlaid with a dashed cyan box and
    a ``queen? NN%`` chip. *image* itself is not modified.

    Draw order: bee, pollenbee, drone, varroa, queen (unknown classes with
    bees), then candidates by ascending probability, so the most salient
    marks end up on top.
    """
    out = image.convert("RGB").copy()
    draw = ImageDraw.Draw(out, "RGBA")
    font_queen = _font(16)
    default_style = {"color": (255, 255, 255), "width": 1, "label": None}

    # Draw common classes first so the queen stacks on top.
    sort_order = {"bee": 0, "pollenbee": 1, "drone": 2, "varroa": 3, "queen": 4}
    for det in sorted(detections, key=lambda d: sort_order.get(d["class"], 0)):
        cls = det["class"]
        style = _CLASS_STYLES.get(cls, default_style)
        x1, y1, x2, y2 = det["bbox"]
        color = style["color"]
        draw.rectangle([x1, y1, x2, y2], outline=color + (255,),
                       width=style["width"])
        if cls != "queen":
            continue
        # Glowing fill + chip label for the queen.
        draw.rectangle([x1, y1, x2, y2], fill=color + (55,))
        label = "QUEEN"
        tw = draw.textlength(label, font=font_queen)
        th = _font_height(draw, font_queen, label) + 4
        bx1, by1 = x1, max(0, y1 - th - 2)
        draw.rectangle([bx1, by1, x1 + tw + 10, y1], fill=color + (235,))
        draw.text((bx1 + 5, by1 + 1), label, fill=(20, 16, 8), font=font_queen)

    # Queen candidates: dashed cyan box + probability chip, drawn after all
    # class outlines. Sorted ascending by the displayed probability, so the
    # most likely candidate is drawn last (on top).
    font_cand = _font(13)
    candidates = sorted(
        (d for d in detections if d.get("queen_candidate")),
        key=_candidate_prob,
    )
    for cand in candidates:
        cx1, cy1, cx2, cy2 = cand["bbox"]
        _dashed_rectangle(
            draw, [cx1, cy1, cx2, cy2], _CANDIDATE_COLOR + (255,), width=3
        )
        label = f"queen? {int(_candidate_prob(cand) * 100)}%"
        tw = draw.textlength(label, font=font_cand)
        th = _font_height(draw, font_cand, label) + 3
        bx1, by1 = cx1, max(0, cy1 - th - 2)
        draw.rectangle(
            [bx1, by1, bx1 + tw + 10, by1 + th],
            fill=_CANDIDATE_COLOR + (235,),
        )
        draw.text(
            (bx1 + 5, by1 + 1), label, fill=(8, 16, 20), font=font_cand
        )
    return out
def summarize_counts(detections: list[dict]) -> dict[str, int]:
    """Count detections per class.

    Args:
        detections: detection dicts, each carrying a ``"class"`` key.

    Returns:
        A plain dict mapping class name to count, in first-seen order; empty
        when *detections* is empty.
    """
    from collections import Counter

    return dict(Counter(d["class"] for d in detections))