# Apiarist / scripts / push_classifier_to_hub.py
# Commit 7119114 by Apiarist Dev: "feat: ranked top-3 queen candidates with probability shown - honest AI, user verifies"
"""Push the trained queen vs worker classifier to its own HF model repo."""

import argparse
import os
from pathlib import Path

from dotenv import load_dotenv
from huggingface_hub import HfApi

# Checkpoint produced by training, expected at <repo root>/weights/queen_classifier.pt.
WEIGHTS = Path(__file__).parent.parent / "weights" / "queen_classifier.pt"

README = """---
license: apache-2.0
library_name: timm
tags:
- bees
- beekeeping
- image-classification
- efficientnet
pipeline_tag: image-classification
---
# Apiarist Queen-vs-Worker Bee Classifier
Binary image classifier (EfficientNet-B0 backbone, ~4M parameters with the
2-class head) trained to distinguish queen bees from worker bees on cropped
bee images.
Built as part of [Apiarist](https://huggingface.co/spaces/build-small-hackathon/Apiarist),
an offline AI hive inspector for backyard beekeepers, made for the
[Build Small Hackathon](https://huggingface.co/build-small-hackathon).
## Why a dedicated classifier?
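A multi-class YOLO detector has to solve two problems at once (localizing
bees and classifying them), and the queen class suffers: queens are rare in
the training data and look only subtly different from workers. Splitting the
job works better here. The detector only localizes bees, and a focused binary
classifier on the cropped bees makes the queen-vs-worker call. That classifier
stays small and fast because it is trained for exactly one decision.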
## Training
- Backbone: `efficientnet_b0` (ImageNet pretrained)
- Training data: bee crops extracted from labelled bounding boxes in two
  Roboflow datasets (Matt Nudi honey bees + Hendricks Ricky bee-project)
- 1,146 queen crops + 29,825 worker crops (roughly 1:26), balanced with
  weighted random sampling
- Heavy augmentation: rotations, flips, color jitter
- 90/10 train/val split
- AdamW + cosine schedule, mixed precision on a single T4 GPU
- Trained on [Modal](https://modal.com)
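
The class balancing above can be sketched with inverse-frequency weights. Each crop is weighted by 1 / (count of its class), which gives both classes equal total weight. A sampler that draws with replacement then sees queens about as often as workers. This is a minimal stdlib sketch that uses the crop counts listed above. The actual training code is not part of this repo, and the helper names are hypothetical. In PyTorch, the usual tool is `torch.utils.data.WeightedRandomSampler`, which takes the same kind of per-sample weights.

```python
import random
from collections import Counter


def inverse_frequency_weights(labels):
    # Per-sample weight = 1 / (number of samples sharing that label),
    # so every class contributes the same total weight.
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels]


def weighted_sample(labels, num_samples, seed=0):
    # Hypothetical helper: draw indices with replacement, proportional to
    # their weights (the same idea as torch.utils.data.WeightedRandomSampler).
    rng = random.Random(seed)
    weights = inverse_frequency_weights(labels)
    return rng.choices(range(len(labels)), weights=weights, k=num_samples)


# Class counts from the crop set described above.
labels = ["queen"] * 1146 + ["worker"] * 29825
idx = weighted_sample(labels, num_samples=20_000)
queen_share = sum(labels[i] == "queen" for i in idx) / len(idx)
print(f"queen share among sampled crops: {queen_share:.2f}")  # close to 0.50
```

Without the weights, a plain shuffle would show a queen in only about 1 of every 27 crops (1,146 / 30,971, or about 3.7%).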
## Validation metrics
Queen is the positive class; all figures are on the 10% validation split.
- Accuracy: 0.997
- Precision (queen): 0.991
- Recall (queen): 0.934
- **F1 (queen): 0.962**
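
The metrics above can be checked quickly. F1 is the harmonic mean of queen precision and recall. Accuracy should be read against a majority-class baseline, because a model that always answers "worker" would already score about 0.963 on data with this class mix. The sketch below assumes the validation split has roughly the same queen/worker ratio as the full crop set, which this card does not state explicitly.

```python
def f1_score(precision, recall):
    # Harmonic mean of precision and recall.
    return 2 * precision * recall / (precision + recall)


# Reported queen-class metrics from this card.
precision, recall = 0.991, 0.934
f1 = f1_score(precision, recall)

# Always-"worker" baseline, assuming the val split mirrors the full crop counts.
queens, workers = 1146, 29825
baseline_accuracy = workers / (queens + workers)

print(f"F1 (queen): {f1:.3f}")                             # F1 (queen): 0.962
print(f"always-worker accuracy: {baseline_accuracy:.3f}")  # always-worker accuracy: 0.963
```

Comparing the reported accuracy of 0.997 against that 0.963 baseline shows why queen precision and recall are the numbers to watch.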
## Recommended use
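Pair with a bee detector (e.g. YOLOv8): run the detector first, then classify
each cropped bee with this model. For high-precision flagging, mark a bee as a
queen only when its queen probability is at least 0.85. This cuts false queen
alerts at the cost of missing some true queens.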
```python
import timm
import torch
from PIL import Image
from torchvision import transforms

ckpt = torch.load("queen_classifier.pt", map_location="cpu")
model = timm.create_model(ckpt["arch"], pretrained=False, num_classes=2)
model.load_state_dict(ckpt["state_dict"])
model.eval()

# Resize to 224x224 and normalize with ImageNet mean/std.
tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# `crop` is one bee cut out of a detector bounding box, as a PIL RGB image.
crop = Image.open("bee_crop.jpg").convert("RGB")  # hypothetical file name

with torch.no_grad():
    probs = torch.softmax(model(tf(crop).unsqueeze(0)), dim=1)
queen_idx = ckpt["class_to_idx"]["queen"]
queen_prob = probs[0, queen_idx].item()
```
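The detector-then-classifier flow recommended above ends with a thresholding step. Below is a minimal sketch of that last step. It assumes you already have one queen probability per detected bee, for example `queen_prob` from the snippet above. The `Detection` type, the `flag_queens` helper and its `top_k` cap are hypothetical illustrations; they are not part of this model or of Apiarist's code. Only the 0.85 threshold comes from this card.

```python
from dataclasses import dataclass


@dataclass
class Detection:
    # Hypothetical record: detector box in pixel coords plus classifier output.
    box: tuple  # (x1, y1, x2, y2)
    queen_prob: float


def flag_queens(detections, threshold=0.85, top_k=3):
    # Keep only confident queen calls, highest probability first,
    # and return at most top_k of them.
    confident = [d for d in detections if d.queen_prob >= threshold]
    confident.sort(key=lambda d: d.queen_prob, reverse=True)
    return confident[:top_k]


detections = [
    Detection((10, 10, 60, 60), 0.12),
    Detection((80, 40, 140, 110), 0.93),
    Detection((200, 15, 250, 70), 0.86),
    Detection((300, 90, 350, 150), 0.84),
]
for d in flag_queens(detections):
    print(d.box, round(d.queen_prob, 2))
# (80, 40, 140, 110) 0.93
# (200, 15, 250, 70) 0.86
```

The 0.84 detection is dropped even though it is close to the threshold. Lowering `threshold` recovers more true queens but also raises more false alerts.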
## Caveats
The training distribution leans toward close-up macro photos of bees on
honeycomb. Generalization to wide-angle inspection photos (with hands or
background visible) is weaker, since YOLO's bee bounding boxes on those
photos are often smaller and less precise than the training crops.
## License
Apache 2.0. Trained on data released under CC BY 4.0.
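"""


def main():
    parser = argparse.ArgumentParser(
        description="Push the queen classifier weights and model card to the Hugging Face Hub."
    )
    parser.add_argument("--repo-id", required=True,
                        help="Target model repo id, e.g. <namespace>/<name>")
    parser.add_argument("--private", action="store_true",
                        help="Create the repo as private (no effect if it already exists)")
    args = parser.parse_args()

    # Token comes from the environment, optionally populated from a local .env file.
    load_dotenv()
    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    if not token:
        raise SystemExit("Set HF_TOKEN (or HUGGING_FACE_HUB_TOKEN) in the environment or .env")
    if not WEIGHTS.exists():
        raise SystemExit(f"Weights not found at {WEIGHTS}")

    api = HfApi(token=token)
    # exist_ok=True makes reruns safe: an existing repo is reused, not recreated.
    print(f"Creating repo {args.repo_id} ...")
    api.create_repo(args.repo_id, repo_type="model", private=args.private,
                    exist_ok=True)

    # Upload the model card (README.md) from the in-memory string above.
    api.upload_file(
        path_or_fileobj=README.encode("utf-8"),
        path_in_repo="README.md",
        repo_id=args.repo_id, repo_type="model",
        commit_message="Add model card",
    )
    # Upload the checkpoint (state_dict, arch name, class_to_idx).
    api.upload_file(
        path_or_fileobj=str(WEIGHTS),
        path_in_repo="queen_classifier.pt",
        repo_id=args.repo_id, repo_type="model",
        commit_message="Upload EfficientNet-B0 queen classifier weights",
    )
    print(f"\n[OK] Model live at: https://huggingface.co/{args.repo_id}")


if __name__ == "__main__":
    main()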