"""Push the trained queen vs worker classifier to its own HF model repo."""
import argparse
import os
from pathlib import Path

from dotenv import load_dotenv
from huggingface_hub import HfApi

WEIGHTS = Path(__file__).parent.parent / "weights" / "queen_classifier.pt"
README = """---
license: apache-2.0
library_name: timm
tags:
- bees
- beekeeping
- image-classification
- efficientnet
pipeline_tag: image-classification
---

# Apiarist Queen-vs-Worker Bee Classifier

Binary image classifier (EfficientNet-B0, ~5M params) trained to
distinguish queen bees from worker bees on cropped bee images.

Built as part of [Apiarist](https://huggingface.co/spaces/build-small-hackathon/Apiarist),
an offline AI hive inspector for backyard beekeepers, made for the
[Build Small Hackathon](https://huggingface.co/build-small-hackathon).

## Why a dedicated classifier?

Multi-class YOLO detectors fight two problems at once (localize + classify),
and queens lose out because they are rare and visually subtle. A focused
binary classifier on cropped bee images is a better fit:
small, fast, and trained for a single decision.

## Training

- Backbone: `efficientnet_b0` (ImageNet pretrained)
- Training data: bee crops extracted from labelled bounding boxes in two
  Roboflow datasets (Matt Nudi honey bees + Hendricks Ricky bee-project)
- 1,146 queen crops + 29,825 worker crops
- Heavy augmentation: rotations, flips, color jitter
- 90/10 train/val split, weighted random sampling for class balance
- AdamW + cosine schedule, mixed precision on a single T4 GPU
- Trained on [Modal](https://modal.com)
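
The class balancing listed above can be illustrated with inverse-frequency sample weights. This is a minimal sketch, not the actual training code: the helper name `sample_weights` is hypothetical, and it assumes each crop is weighted by one over its class count, which is one common way to build the weights for `torch.utils.data.WeightedRandomSampler`.

```python
from collections import Counter


def sample_weights(labels):
    # Weight each sample by 1 / (number of samples in its class), so that
    # every class contributes the same total weight to the sampler.
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]


# Crop counts taken from the training set described above.
labels = ["queen"] * 1146 + ["worker"] * 29825
weights = sample_weights(labels)

# Total sampling mass per class; the two should be (almost exactly) equal.
queen_mass = sum(w for w, l in zip(weights, labels) if l == "queen")
worker_mass = sum(w for w, l in zip(weights, labels) if l == "worker")
print(round(queen_mass, 6), round(worker_mass, 6))
```

With weights like these, `WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)` would draw queens and workers about equally often, even though workers outnumber queens roughly 26 to 1.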

## Validation metrics

- Accuracy: 0.997
- Precision (queen): 0.991
- Recall (queen): 0.934
- **F1: 0.962**
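
As a consistency check, the F1 score follows from the reported queen precision and recall as their harmonic mean. A minimal sketch (the function name `f1_score` is illustrative and not part of the training code):

```python
def f1_score(precision, recall):
    # F1 is the harmonic mean of precision and recall.
    return 2 * precision * recall / (precision + recall)


f1 = f1_score(0.991, 0.934)
print(f"{f1:.3f}")  # prints 0.962
```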

## Recommended use

Pair with a bee detector (e.g. YOLOv8). Run the detector first, then
classify each cropped bee with this model. Threshold the queen
probability at 0.85 for high-precision flagging.

```python
import torch, timm
from torchvision import transforms

ckpt = torch.load("queen_classifier.pt", map_location="cpu")
model = timm.create_model(ckpt["arch"], pretrained=False, num_classes=2)
model.load_state_dict(ckpt["state_dict"])
model.eval()

tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225]),
])

# `crop` is an RGB PIL image of a single detected bee.
with torch.no_grad():
    probs = torch.softmax(model(tf(crop).unsqueeze(0)), dim=1)
queen_idx = ckpt["class_to_idx"]["queen"]
queen_prob = probs[0, queen_idx].item()
```
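
When a detector returns many bees per photo, the per-crop probabilities can be ranked so that only the strongest queen candidates are shown for the user to verify. The following is a sketch under assumptions: `rank_queen_candidates` and the `(box, queen_prob)` tuple layout are hypothetical, `k=3` is an illustrative default, and 0.85 is the threshold recommended above.

```python
def rank_queen_candidates(scored, k=3, threshold=0.85):
    # scored: list of (box, queen_prob) pairs, one per detected bee.
    # Returns the k most queen-like detections, highest probability first,
    # each tagged with whether it clears the high-precision threshold.
    ranked = sorted(scored, key=lambda item: item[1], reverse=True)[:k]
    return [(box, prob, prob >= threshold) for box, prob in ranked]


# Hypothetical detections: (x1, y1, x2, y2) boxes with queen probabilities.
scored = [
    ((10, 10, 60, 80), 0.12),
    ((200, 40, 260, 130), 0.91),
    ((90, 300, 140, 370), 0.47),
    ((400, 220, 450, 300), 0.86),
]
for box, prob, flagged in rank_queen_candidates(scored):
    print(box, f"{prob:.2f}", "flag" if flagged else "review")
```

Showing the probability next to each candidate, rather than a bare yes/no, leaves the final call to the beekeeper when a score falls below the threshold.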

## Caveats

The training distribution leans toward close-up macro photos of bees on
honeycomb. Generalization to wide-angle inspection photos (with hands /
background visible) is weaker, since YOLO's bee bounding boxes on those
photos are often smaller and less precise than the training crops.

## License

Apache 2.0. Trained on data released under CC BY 4.0.
"""


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--repo-id", required=True)
    parser.add_argument("--private", action="store_true")
    args = parser.parse_args()

    load_dotenv()
    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    if not token:
        raise SystemExit("Set HF_TOKEN in .env")
    if not WEIGHTS.exists():
        raise SystemExit(f"Weights not at {WEIGHTS}")

    api = HfApi(token=token)
    print(f"Creating repo {args.repo_id} ...")
    api.create_repo(args.repo_id, repo_type="model", private=args.private,
                    exist_ok=True)
    api.upload_file(
        path_or_fileobj=README.encode("utf-8"),
        path_in_repo="README.md",
        repo_id=args.repo_id, repo_type="model",
        commit_message="Add model card",
    )
    api.upload_file(
        path_or_fileobj=str(WEIGHTS),
        path_in_repo="queen_classifier.pt",
        repo_id=args.repo_id, repo_type="model",
        commit_message="Upload EfficientNet-B0 queen classifier weights",
    )
    print(f"\n[OK] Model live at: https://huggingface.co/{args.repo_id}")


if __name__ == "__main__":
    main()