Spaces:
Running on Zero
Running on Zero
Apiarist Dev
feat: binary classifier cascade - prefers dedicated queen classifier, falls back to VLM grid
"""
Train a dedicated binary queen-vs-worker classifier on bee crops.

Why: YOLO is good at finding bees but mediocre at telling queens from
workers, because it has to learn localisation and classification at the
same time. A focused binary classifier sees ONLY cropped bees and only
decides "queen or not" - a much easier task.

Pipeline:
  1. Download both labelled datasets (Matt Nudi + Hendricks).
  2. For every annotated bounding box, crop the bee and write it to
     either queen/ or worker/ depending on its class label
     (drones are counted as "worker"; Varroa mites are dropped).
  3. Train an EfficientNet-B0 on those crops with flip, rotation and
     colour-jitter augmentation.
  4. Save the best weights (by queen F1) to a Modal volume for inference
     on the Space.

Run (requires ROBOFLOW_API_KEY in .env or the environment):
  py scripts/train_queen_classifier.py
"""
| import os | |
| import shutil | |
| from pathlib import Path | |
| import modal | |
APP_NAME = "apiarist-queen-classifier"
VOLUME_NAME = "apiarist-weights"

# Training hyper-parameters read by the remote training function.
IMG_SIZE = 224  # crops are resized to IMG_SIZE x IMG_SIZE before training
BATCH = 64
EPOCHS = 25

# Python and system packages baked into the remote container image.
_PIP_PACKAGES = (
    "roboflow==1.1.50",
    "timm==1.0.11",
    "torch==2.4.0",
    "torchvision==0.19.0",
    "pillow",
    "pyyaml",
)
_APT_PACKAGES = ("libgl1", "libglib2.0-0")

_base_image = modal.Image.debian_slim(python_version="3.11")
image = _base_image.pip_install(*_PIP_PACKAGES).apt_install(*_APT_PACKAGES)

# Named volume the trained weights are committed to (created on first use).
vol = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True)
app = modal.App(APP_NAME)
| def _yolo_bbox_to_pixel(box, img_w, img_h): | |
| """YOLO format: cx, cy, w, h (all normalised 0-1) -> pixel x1,y1,x2,y2.""" | |
| cx, cy, w, h = box | |
| x1 = max(0, int((cx - w / 2) * img_w)) | |
| y1 = max(0, int((cy - h / 2) * img_h)) | |
| x2 = min(img_w, int((cx + w / 2) * img_w)) | |
| y2 = min(img_h, int((cy + h / 2) * img_h)) | |
| return x1, y1, x2, y2 | |
def _extract_crops(dataset_root, label_map, out_root, prefix, pad=8, min_size=24):
    """Crop every labelled bee out of a YOLO-format dataset and save it by class.

    Walks the ``train``/``valid``/``test`` splits under *dataset_root*. For
    each image with a matching ``labels/<stem>.txt`` file, every box whose
    class id maps to a non-None name in *label_map* is cropped (with *pad*
    pixels of context) and written to
    ``out_root/<canonical_class>/<prefix>_<split>_<stem>_<i>.jpg``.

    Args:
        dataset_root: Path to the extracted dataset (contains the split folders).
        label_map: Maps integer YOLO class id -> canonical class name, or None
            to drop that class. Ids missing from the map are dropped too.
        out_root: Path under which one sub-folder per canonical class is made.
        prefix: Short tag identifying the source dataset in file names.
        pad: Pixels of context added on each side of the box (clamped to the image).
        min_size: Boxes narrower or shorter than this many pixels (before
            padding) are skipped.

    Returns:
        collections.Counter mapping canonical class name -> crops written.
    """
    import math
    from collections import Counter

    from PIL import Image as PILImage

    written = Counter()
    for split in ("train", "valid", "test"):
        img_dir = dataset_root / split / "images"
        lbl_dir = dataset_root / split / "labels"
        if not img_dir.exists() or not lbl_dir.exists():
            continue
        for img_path in img_dir.iterdir():
            if img_path.suffix.lower() not in (".jpg", ".jpeg", ".png"):
                continue
            lbl_path = lbl_dir / f"{img_path.stem}.txt"
            if not lbl_path.exists():
                continue
            try:
                # The context manager closes the file handle; convert() returns a
                # fully loaded copy that stays usable after the file is closed.
                with PILImage.open(img_path) as src:
                    img = src.convert("RGB")
            except Exception:
                # Best effort: skip unreadable or corrupt images.
                continue
            W, H = img.size
            for i, line in enumerate(lbl_path.read_text().splitlines()):
                parts = line.split()
                # Only plain detection rows (class cx cy w h) are handled.
                if len(parts) != 5:
                    continue
                try:
                    cls_id = int(parts[0])
                    coords = [float(x) for x in parts[1:]]
                except ValueError:
                    continue
                # NaN/inf would make int() raise inside _yolo_bbox_to_pixel.
                if not all(math.isfinite(c) for c in coords):
                    continue
                canonical = label_map.get(cls_id)
                if canonical is None:
                    continue
                x1, y1, x2, y2 = _yolo_bbox_to_pixel(coords, W, H)
                if x2 - x1 < min_size or y2 - y1 < min_size:
                    continue
                box = (
                    max(0, x1 - pad),
                    max(0, y1 - pad),
                    min(W, x2 + pad),
                    min(H, y2 + pad),
                )
                cls_dir = out_root / canonical
                cls_dir.mkdir(parents=True, exist_ok=True)
                # The split is part of the name so identical stems in different
                # splits do not overwrite each other.
                out_name = f"{prefix}_{split}_{img_path.stem}_{i}.jpg"
                img.crop(box).save(cls_dir / out_name, "JPEG")
                written[canonical] += 1
    return written
def _build_loaders(crops_root):
    """Split the crop folder 90/10 and return ``(train_loader, val_loader, class_to_idx)``."""
    import torch
    from torch.utils.data import DataLoader, Subset, WeightedRandomSampler
    from torchvision import transforms
    from torchvision.datasets import ImageFolder

    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    train_tf = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(20),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        normalize,
    ])
    val_tf = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        normalize,
    ])

    # Two views of the same folder that differ only in transform. ImageFolder
    # sorts its file list, so index i is the same crop in both views. (Setting
    # the transform on a random_split subset's .dataset would change it for the
    # shared dataset and silently disable training augmentation.)
    train_view = ImageFolder(str(crops_root), transform=train_tf)
    val_view = ImageFolder(str(crops_root), transform=val_tf)
    print(f"Class to idx: {train_view.class_to_idx}")

    n = len(train_view)
    n_val = max(50, n // 10)
    n_train = n - n_val
    train_split, val_split = torch.utils.data.random_split(
        train_view, [n_train, n_val],
        generator=torch.Generator().manual_seed(42),
    )
    train_ds = Subset(train_view, train_split.indices)
    val_ds = Subset(val_view, val_split.indices)

    # Inverse-frequency sample weights to balance queen vs worker per batch.
    labels = [train_view.targets[i] for i in train_split.indices]
    class_counts = [labels.count(c) for c in range(len(train_view.classes))]
    print(f"Train class counts: {dict(zip(train_view.classes, class_counts))}")
    # max(c, 1) guards against a class the random split left out of training.
    class_weights = [1.0 / max(c, 1) for c in class_counts]
    sampler = WeightedRandomSampler(
        [class_weights[label] for label in labels],
        num_samples=len(labels),
        replacement=True,
    )
    pin = torch.cuda.is_available()
    train_loader = DataLoader(train_ds, batch_size=BATCH, sampler=sampler, num_workers=4, pin_memory=pin)
    val_loader = DataLoader(val_ds, batch_size=BATCH, shuffle=False, num_workers=2, pin_memory=pin)
    return train_loader, val_loader, train_view.class_to_idx


def _queen_metrics(model, loader, device, queen_idx):
    """Evaluate *model* on *loader*; return ``(precision, recall, acc, f1)`` for the queen class."""
    import torch

    tp = fp = tn = fn = 0
    model.eval()
    with torch.no_grad():
        for imgs, lbls in loader:
            pred_q = model(imgs.to(device)).argmax(1).cpu() == queen_idx
            true_q = lbls == queen_idx
            tp += int((pred_q & true_q).sum())
            fp += int((pred_q & ~true_q).sum())
            fn += int((~pred_q & true_q).sum())
            tn += int((~pred_q & ~true_q).sum())
    precision = tp / max(1, tp + fp)
    recall = tp / max(1, tp + fn)
    acc = (tp + tn) / max(1, tp + tn + fp + fn)
    f1 = 2 * precision * recall / max(1e-6, precision + recall)
    return precision, recall, acc, f1


@app.function(
    image=image,
    gpu="A10G",  # NOTE(review): GPU type is an assumption - confirm against the deployed config
    volumes={"/weights": vol},  # best.pt is written under /weights and persisted by vol.commit()
    timeout=2 * 60 * 60,  # NOTE(review): Modal's default timeout is too short for 25 epochs; tune as needed
)
def train(rf_api_key: str) -> str:
    """Download both datasets, extract bee crops, and fine-tune EfficientNet-B0.

    The checkpoint with the best validation F1 for the queen class is saved
    to ``/weights/queen_classifier/best.pt`` and committed to the volume.

    Args:
        rf_api_key: Roboflow API key used to download the labelled datasets.

    Returns:
        The in-container path of the saved checkpoint.

    Raises:
        RuntimeError: if fewer than 50 queen or 50 worker crops were extracted.
    """
    from pathlib import Path

    import timm
    import torch
    import torch.nn as nn
    from roboflow import Roboflow

    print("Downloading datasets ...")
    rf = Roboflow(api_key=rf_api_key)
    hend = rf.workspace("hendricks_ricky-hotmail-de").project("bee-project").version(2).download("yolov8", location="/tmp/hendricks")
    matt = rf.workspace("matt-nudi").project("honey-bee-detection-model-zgjnb").version(4).download("yolov8", location="/tmp/matt_nudi")

    # YOLO class id -> canonical label for this binary task; None drops the class.
    HENDRICKS_LABEL_MAP = {
        0: "worker",  # Drone Bee, NOT a queen (drones count as worker here)
        1: "queen",
        2: None,      # Varroa Mite, exclude
        3: "worker",  # Worker Bee
    }
    MATT_LABEL_MAP = {
        0: "worker",  # bee
        1: "worker",  # drone (same: not a queen)
        2: "worker",  # pollenbee
        3: "queen",   # queen
    }

    crops_root = Path("/tmp/crops")
    if crops_root.exists():
        # Start clean so crops from a previous run in this container don't leak in.
        shutil.rmtree(crops_root)
    print("Extracting crops ...")
    _extract_crops(Path(hend.location), HENDRICKS_LABEL_MAP, crops_root, "hendr")
    _extract_crops(Path(matt.location), MATT_LABEL_MAP, crops_root, "mnudi")

    n_queen = len(list((crops_root / "queen").iterdir())) if (crops_root / "queen").exists() else 0
    n_worker = len(list((crops_root / "worker").iterdir())) if (crops_root / "worker").exists() else 0
    print(f"Crops extracted: queen={n_queen}, worker={n_worker}")
    if n_queen < 50 or n_worker < 50:
        # Raise (rather than sys.exit) so the reason reaches the caller of train.remote().
        raise RuntimeError(
            f"not enough crops: queen={n_queen}, worker={n_worker} (need >= 50 of each)"
        )

    train_loader, val_loader, class_to_idx = _build_loaders(crops_root)
    queen_idx = class_to_idx["queen"]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = timm.create_model("efficientnet_b0", pretrained=True, num_classes=2).to(device)
    crit = nn.CrossEntropyLoss()
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)

    best_path = Path("/weights/queen_classifier/best.pt")
    best_path.parent.mkdir(parents=True, exist_ok=True)
    # Start below any reachable F1 so the first epoch is always saved and the
    # returned path exists even if queen F1 never rises above 0.
    best_f1 = -1.0
    for epoch in range(1, EPOCHS + 1):
        model.train()
        train_loss = 0.0
        for imgs, lbls in train_loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            opt.zero_grad()
            loss = crit(model(imgs), lbls)
            loss.backward()
            opt.step()
            train_loss += loss.item() * imgs.size(0)
        sched.step()
        # The sampler draws len(train subset) samples per epoch -> mean per-sample loss.
        train_loss /= len(train_loader.dataset)

        precision, recall, acc, f1 = _queen_metrics(model, val_loader, device, queen_idx)
        print(f"epoch {epoch:>2}: train_loss={train_loss:.4f} val_acc={acc:.3f} P={precision:.3f} R={recall:.3f} F1={f1:.3f}")
        if f1 > best_f1:
            best_f1 = f1
            torch.save({
                "state_dict": model.state_dict(),
                "class_to_idx": class_to_idx,
                "img_size": IMG_SIZE,
                "arch": "efficientnet_b0",
                "epoch": epoch,
                "f1": f1, "precision": precision, "recall": recall, "acc": acc,
            }, best_path)
            print(f" saved new best (F1={f1:.3f})")

    print(f"\n[OK] best F1={best_f1:.3f}, weights at {best_path}")
    vol.commit()
    return str(best_path)
def main() -> None:
    """Read the Roboflow API key locally and launch ``train`` on Modal.

    The key is read from the ``ROBOFLOW_API_KEY`` environment variable, after
    loading a ``.env`` file when python-dotenv is installed.

    Raises:
        SystemExit: if ``ROBOFLOW_API_KEY`` is not set.
    """
    try:
        from dotenv import load_dotenv
    except ImportError:
        # python-dotenv is optional; fall back to the plain process environment.
        pass
    else:
        load_dotenv()
    api_key = os.environ.get("ROBOFLOW_API_KEY")
    if not api_key:
        raise SystemExit("Missing ROBOFLOW_API_KEY in .env")
    print("Kicking off queen classifier training on Modal ...")
    weights_path = train.remote(rf_api_key=api_key)
    print(f"\nDONE. {weights_path}")
    print(
        "\nDownload:\n"
        f" modal volume get {VOLUME_NAME} queen_classifier/best.pt weights/queen_classifier.pt"
    )


if __name__ == "__main__":
    # When the file is run directly (`py scripts/train_queen_classifier.py`),
    # `.remote()` needs the Modal app to be running, so wrap main() in app.run().
    with app.run():
        main()