"""Push the trained queen-vs-worker classifier and its model card to a dedicated Hugging Face model repo."""

import argparse
import os
from pathlib import Path

from dotenv import load_dotenv
from huggingface_hub import HfApi


WEIGHTS = Path(__file__).parent.parent / "weights" / "queen_classifier.pt"

README = """---
license: apache-2.0
library_name: timm
tags:
- bees
- beekeeping
- image-classification
- efficientnet
pipeline_tag: image-classification
---

# Apiarist Queen-vs-Worker Bee Classifier

Binary image classifier (EfficientNet-B0, ~4M parameters with the two-class
head) trained to distinguish queen bees from worker bees in cropped bee images.

Built as part of [Apiarist](https://huggingface.co/spaces/build-small-hackathon/Apiarist),
an offline AI hive inspector for backyard beekeepers, made for the
[Build Small Hackathon](https://huggingface.co/build-small-hackathon).

## Why a dedicated classifier?

A multi-class YOLO detector has to solve two problems at once (localization
and classification), and the queen class loses out because queens are rare
in the data and look only subtly different from workers. A focused binary
classifier on cropped bee images is a better fit: small, fast, and trained
for a single decision.

## Training

- Backbone: `efficientnet_b0` (ImageNet pretrained)
- Training data: bee crops extracted from labelled bounding boxes in two
  Roboflow datasets (Matt Nudi honey bees + Hendricks Ricky bee-project)
- 1,146 queen crops + 29,825 worker crops (roughly 1:26)
- Heavy augmentation: rotations, flips, color jitter
- 90/10 train/val split; weighted random sampling to balance the classes
- AdamW with a cosine learning-rate schedule, mixed precision on a single T4 GPU
- Trained on [Modal](https://modal.com)
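
The class balancing above can be illustrated with inverse-frequency weights.
This is a minimal sketch under assumptions, not the actual training code
(which is not included here). Each crop is weighted by one over the size of its
class, so queens and workers carry equal total weight. In PyTorch such
per-sample weights would typically be passed to
`torch.utils.data.WeightedRandomSampler`. In the sketch, the stdlib
`random.choices` call stands in for that sampler.

```python
import random
from collections import Counter

# Class counts from the training data described above.
counts = {"queen": 1146, "worker": 29825}

# Inverse-frequency weights: each sample gets 1 / (size of its class),
# so both classes carry a total weight of 1.
class_weight = {label: 1.0 / n for label, n in counts.items()}

# One label per crop (a stand-in for the real dataset's targets).
labels = ["queen"] * counts["queen"] + ["worker"] * counts["worker"]
sample_weights = [class_weight[label] for label in labels]

# Draw indices with replacement, as a weighted random sampler would.
rng = random.Random(0)
batch = rng.choices(range(len(labels)), weights=sample_weights, k=10_000)
drawn = Counter(labels[i] for i in batch)
queen_share = drawn["queen"] / len(batch)
print(f"queen share after reweighting: {queen_share:.2f}")  # close to 0.5
```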

## Validation metrics

- Accuracy: 0.997
- Precision (queen): 0.991
- Recall (queen): 0.934
- **F1: 0.962**
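
F1 is the harmonic mean of queen precision and recall, so you can check it directly
from the two numbers above. The sketch also computes the accuracy of a trivial
"always worker" model on the full crop set. It assumes the validation split keeps
roughly the same class ratio. That baseline shows why the queen-class metrics say
more than overall accuracy does.

```python
# Queen-class metrics reported above.
precision = 0.991
recall = 0.934

# F1 is the harmonic mean of precision and recall.
f1 = 2 * precision * recall / (precision + recall)
print(f"F1 = {f1:.3f}")  # prints: F1 = 0.962

# Accuracy of a model that always predicts worker, using the full crop counts
# (assumes the validation split has roughly the same class ratio).
always_worker_acc = 29825 / (1146 + 29825)
print(f"always-worker accuracy = {always_worker_acc:.3f}")  # prints: 0.963
```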

## Recommended use

Pair with a bee detector (e.g. YOLOv8). Run the detector first, then
classify each cropped bee with this model. For high-precision flagging,
mark a bee as a queen only when its queen probability is at least 0.85;
raising the threshold trades some recall for precision.

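```python
import timm
import torch
from PIL import Image
from torchvision import transforms

# The checkpoint is a dict holding the timm architecture name,
# the weights, and the class-name to index mapping.
ckpt = torch.load("queen_classifier.pt", map_location="cpu")
model = timm.create_model(ckpt["arch"], pretrained=False, num_classes=2)
model.load_state_dict(ckpt["state_dict"])
model.eval()

# Resize plus ImageNet normalization for the pretrained backbone.
tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# One bee cropped from a detector box ("bee_crop.jpg" is a placeholder path).
crop = Image.open("bee_crop.jpg").convert("RGB")

with torch.no_grad():
    probs = torch.softmax(model(tf(crop).unsqueeze(0)), dim=1)
queen_idx = ckpt["class_to_idx"]["queen"]
queen_prob = probs[0, queen_idx].item()
is_queen = queen_prob >= 0.85  # high-precision threshold recommended above
```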
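
The detector-then-classifier flow can be sketched in plain Python. This is a
hypothetical helper, not part of Apiarist. The name `flag_queens`, the
`(x1, y1, x2, y2)` pixel box format and the `classify_crop` callable are all
assumptions. `classify_crop` stands in for cropping one box and running the
snippet above to get its queen probability.

```python
from typing import Callable, List, Sequence, Tuple

# Hypothetical box format: (x1, y1, x2, y2) in pixels, as a detector might return.
Box = Tuple[int, int, int, int]


def flag_queens(
    boxes: Sequence[Box],
    classify_crop: Callable[[Box], float],
    threshold: float = 0.85,
) -> List[Tuple[Box, float]]:
    # Classify every detected bee and keep those whose queen probability
    # reaches the threshold, most confident first.
    flagged = []
    for box in boxes:
        queen_prob = classify_crop(box)
        if queen_prob >= threshold:
            flagged.append((box, queen_prob))
    flagged.sort(key=lambda item: item[1], reverse=True)
    return flagged


# Stub classifier with fixed probabilities per box, for illustration only.
fake_probs = {(0, 0, 40, 40): 0.97, (50, 50, 80, 80): 0.40, (90, 10, 130, 60): 0.88}
result = flag_queens(list(fake_probs), lambda box: fake_probs[box])
print(result)  # [((0, 0, 40, 40), 0.97), ((90, 10, 130, 60), 0.88)]
```

Keeping the threshold as a parameter lets you tune the precision/recall
trade-off per deployment without touching the model.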

## Caveats

The training distribution leans toward close-up macro photos of bees on
honeycomb. Generalization to wide-angle inspection photos (with hands or
background in frame) is weaker, because YOLO's bee bounding boxes on those
photos are often smaller and less tightly fitted than the training crops.

## License

Apache 2.0. Trained on data released under CC BY 4.0.
"""


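def main():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--repo-id", required=True,
                        help="Target model repo, e.g. <user>/<model-name>")
    parser.add_argument("--private", action="store_true",
                        help="Create the repo as private")
    args = parser.parse_args()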

    load_dotenv()
    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    if not token:
        raise SystemExit("Set HF_TOKEN in .env or the environment")
    if not WEIGHTS.exists():
        raise SystemExit(f"Weights not found at {WEIGHTS}")

    api = HfApi(token=token)
    print(f"Creating repo {args.repo_id} ...")
    api.create_repo(args.repo_id, repo_type="model", private=args.private,
                    exist_ok=True)

    api.upload_file(
        path_or_fileobj=README.encode("utf-8"),
        path_in_repo="README.md",
        repo_id=args.repo_id, repo_type="model",
        commit_message="Add model card",
    )
    api.upload_file(
        path_or_fileobj=str(WEIGHTS),
        path_in_repo="queen_classifier.pt",
        repo_id=args.repo_id, repo_type="model",
        commit_message="Upload EfficientNet-B0 queen classifier weights",
    )
    print(f"\n[OK] Model live at: https://huggingface.co/{args.repo_id}")


if __name__ == "__main__":
    main()