Spaces:
Sleeping
Sleeping
Erick committed on
Upload folder using huggingface_hub
Browse files
- .env.example +7 -0
- .gitattributes +3 -0
- .gitignore +41 -0
- .gradio/certificate.pem +31 -0
- .python-version +1 -0
- CONTEXT.md +164 -0
- Makefile +47 -0
- README.md +204 -6
- app.py +501 -0
- autolabel/__init__.py +12 -0
- autolabel/config.py +177 -0
- autolabel/detect.py +188 -0
- autolabel/export.py +156 -0
- autolabel/finetune.py +554 -0
- autolabel/segment.py +127 -0
- autolabel/utils.py +51 -0
- pyproject.toml +47 -0
- samples/CREDITS.txt +17 -0
- samples/animals.jpg +3 -0
- samples/cat.jpg +3 -0
- samples/dog.jpg +0 -0
- samples/kitchen.jpg +3 -0
- scripts/export_coco.py +47 -0
- scripts/finetune_owlv2.py +152 -0
- scripts/run_detection.py +86 -0
.env.example
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Required on Apple Silicon: enables CPU fallback for MPS ops not yet in Metal
|
| 2 |
+
PYTORCH_ENABLE_MPS_FALLBACK=1
|
| 3 |
+
|
| 4 |
+
# Optional overrides (uncomment and edit as needed)
|
| 5 |
+
# AUTOLABEL_DEVICE=cpu
|
| 6 |
+
# AUTOLABEL_MODEL=google/owlv2-large-patch14-finetuned
|
| 7 |
+
# AUTOLABEL_THRESHOLD=0.1
|
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
samples/animals.jpg filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
samples/cat.jpg filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
samples/kitchen.jpg filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Environment
|
| 2 |
+
.env
|
| 3 |
+
.venv/
|
| 4 |
+
__pycache__/
|
| 5 |
+
*.py[cod]
|
| 6 |
+
*.pyo
|
| 7 |
+
.pytest_cache/
|
| 8 |
+
*.egg-info/
|
| 9 |
+
dist/
|
| 10 |
+
build/
|
| 11 |
+
|
| 12 |
+
# Data directories (potentially large / private)
|
| 13 |
+
data
|
| 14 |
+
data/raw/
|
| 15 |
+
data/detections/
|
| 16 |
+
data/labeled/
|
| 17 |
+
|
| 18 |
+
# Model cache (downloaded from Hugging Face)
|
| 19 |
+
.cache/
|
| 20 |
+
models/
|
| 21 |
+
|
| 22 |
+
# Jupyter
|
| 23 |
+
.ipynb_checkpoints/
|
| 24 |
+
notebooks/.ipynb_checkpoints/
|
| 25 |
+
|
| 26 |
+
# macOS
|
| 27 |
+
.DS_Store
|
| 28 |
+
|
| 29 |
+
# Windows
|
| 30 |
+
Thumbs.db
|
| 31 |
+
ehthumbs.db
|
| 32 |
+
|
| 33 |
+
# IDE
|
| 34 |
+
.idea/
|
| 35 |
+
.vscode/
|
| 36 |
+
*.swp
|
| 37 |
+
*.swo
|
| 38 |
+
|
| 39 |
+
# uv
|
| 40 |
+
uv.lock
|
| 41 |
+
/.claude/
|
.gradio/certificate.pem
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-----BEGIN CERTIFICATE-----
|
| 2 |
+
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
|
| 3 |
+
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
|
| 4 |
+
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
|
| 5 |
+
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
|
| 6 |
+
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
|
| 7 |
+
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
|
| 8 |
+
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
|
| 9 |
+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
|
| 10 |
+
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
|
| 11 |
+
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
|
| 12 |
+
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
|
| 13 |
+
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
|
| 14 |
+
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
|
| 15 |
+
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
|
| 16 |
+
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
|
| 17 |
+
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
|
| 18 |
+
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
|
| 19 |
+
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
|
| 20 |
+
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
|
| 21 |
+
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
|
| 22 |
+
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
|
| 23 |
+
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
|
| 24 |
+
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
|
| 25 |
+
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
|
| 26 |
+
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
|
| 27 |
+
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
|
| 28 |
+
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
|
| 29 |
+
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
|
| 30 |
+
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
|
| 31 |
+
-----END CERTIFICATE-----
|
.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.11
|
CONTEXT.md
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CONTEXT.md — Technical Reference for autolabel
|
| 2 |
+
|
| 3 |
+
> Keep this file up to date as the project evolves. Read this first when
|
| 4 |
+
> resuming work after a break.
|
| 5 |
+
|
| 6 |
+
---
|
| 7 |
+
|
| 8 |
+
## What this project does
|
| 9 |
+
|
| 10 |
+
Uses **OWLv2** (open-vocabulary object detection) and **SAM2** (segment
|
| 11 |
+
anything) to auto-label images via text prompts, then exports a COCO dataset
|
| 12 |
+
for fine-tuning a detection or segmentation model.
|
| 13 |
+
|
| 14 |
+
**Current phase:** labeling — two modes available:
|
| 15 |
+
- **Detection** — OWLv2 only; produces bounding boxes.
|
| 16 |
+
- **Segmentation** — OWLv2 → boxes → SAM2 → pixel masks + COCO polygons.
|
| 17 |
+
|
| 18 |
+
**Future phase:** fine-tune OWLv2 on the exported COCO dataset using
|
| 19 |
+
`scripts/finetune_owlv2.py` (code is ready, not yet in active use).
|
| 20 |
+
|
| 21 |
+
---
|
| 22 |
+
|
| 23 |
+
## Architecture
|
| 24 |
+
|
| 25 |
+
### Primary interface — `app.py` (Gradio web UI)
|
| 26 |
+
|
| 27 |
+
Two-tab UI, all artifacts written to a session temp dir (nothing in the project):
|
| 28 |
+
|
| 29 |
+
| Tab | What it does |
|
| 30 |
+
|-----|-------------|
|
| 31 |
+
| 🧪 Test | Single image → instant annotated preview. Dial in prompts and threshold before a batch run. |
|
| 32 |
+
| 📂 Batch | Multiple images → annotated gallery + downloadable ZIP (resized images + `coco_export.json`). |
|
| 33 |
+
|
| 34 |
+
### CLI scripts (`scripts/`)
|
| 35 |
+
|
| 36 |
+
Independent entry points for headless / automation use:
|
| 37 |
+
|
| 38 |
+
| Script | Purpose |
|
| 39 |
+
|--------|---------|
|
| 40 |
+
| `run_detection.py` | Batch detect → `data/detections/` |
|
| 41 |
+
| `export_coco.py` | Build COCO JSON from `data/labeled/` |
|
| 42 |
+
| `finetune_owlv2.py` | Fine-tune OWLv2 (future) |
|
| 43 |
+
|
| 44 |
+
### `autolabel/` package
|
| 45 |
+
|
| 46 |
+
| Module | Responsibility |
|
| 47 |
+
|--------|---------------|
|
| 48 |
+
| `config.py` | Pydantic settings singleton, auto device detection |
|
| 49 |
+
| `detect.py` | OWLv2 inference — `infer()` (PIL, shared) + `detect_image()` (file) + `run_detection()` (batch CLI) |
|
| 50 |
+
| `segment.py` | SAM2 integration — `load_sam2()`, `segment_with_boxes()`, `_mask_to_polygon()` |
|
| 51 |
+
| `export.py` | COCO JSON builder (no pycocotools); supports both bbox-only and segmentation |
|
| 52 |
+
| `finetune.py` | Training loop, loss, dataset, scheduler |
|
| 53 |
+
| `utils.py` | `collect_images`, `save_json`, `load_json`, `setup_logging` |
|
| 54 |
+
|
| 55 |
+
**Key design:** `detect.infer()` is the single OWLv2 inference implementation.
|
| 56 |
+
`app.py` chains SAM2 on top when mode == "Segmentation" — no duplication.
|
| 57 |
+
|
| 58 |
+
---
|
| 59 |
+
|
| 60 |
+
## Device strategy
|
| 61 |
+
|
| 62 |
+
| Platform | Device | dtype |
|
| 63 |
+
|----------|--------|-------|
|
| 64 |
+
| Apple Silicon | `mps` | `float32` |
|
| 65 |
+
| Windows/Linux GPU | `cuda` | `float16` |
|
| 66 |
+
| CPU fallback | `cpu` | `float32` |
|
| 67 |
+
|
| 68 |
+
`PYTORCH_ENABLE_MPS_FALLBACK=1` must be set before torch is imported on MPS
|
| 69 |
+
(`.env` handles this). Without it, some OWLv2 ops raise `NotImplementedError`.
|
| 70 |
+
|
| 71 |
+
---
|
| 72 |
+
|
| 73 |
+
## OWLv2 model
|
| 74 |
+
|
| 75 |
+
Default: `google/owlv2-large-patch14-finetuned` (~700 MB, cached in
|
| 76 |
+
`~/.cache/huggingface` after first download).
|
| 77 |
+
|
| 78 |
+
Override via env var: `AUTOLABEL_MODEL=google/owlv2-base-patch16`
|
| 79 |
+
|
| 80 |
+
| Variant | Size | Notes |
|
| 81 |
+
|---------|------|-------|
|
| 82 |
+
| `owlv2-base-patch16` | ~300 MB | Faster, lower accuracy |
|
| 83 |
+
| `owlv2-large-patch14` | ~700 MB | Good balance |
|
| 84 |
+
| `owlv2-large-patch14-finetuned` | ~700 MB | Default — pre-trained on LVIS/Objects365 |
|
| 85 |
+
|
| 86 |
+
---
|
| 87 |
+
|
| 88 |
+
## Dependency decisions
|
| 89 |
+
|
| 90 |
+
| Package | Why kept |
|
| 91 |
+
|---------|---------|
|
| 92 |
+
| `torch` / `torchvision` | OWLv2 + SAM2 inference |
|
| 93 |
+
| `transformers>=4.45` | OWLv2 and SAM2 models & processors |
|
| 94 |
+
| `pillow` | Image I/O and annotation drawing |
|
| 95 |
+
| `numpy` | Gradio image array interchange; mask arrays |
|
| 96 |
+
| `opencv-python` | `cv2.findContours` for mask → COCO polygon (SAM2) |
|
| 97 |
+
| `pydantic` / `pydantic-settings` | Type-safe config with env-var loading |
|
| 98 |
+
| `click` | CLI option parsing |
|
| 99 |
+
| `tqdm` | Progress bars in CLI batch runner |
|
| 100 |
+
| `python-dotenv` | Load `.env` before torch (MPS fallback) |
|
| 101 |
+
| `gradio` | Web UI |
|
| 102 |
+
|
| 103 |
+
Removed: `supervision` (unused), `matplotlib` (fine-tune charts gone),
|
| 104 |
+
`requests` (Label Studio gone).
|
| 105 |
+
|
| 106 |
+
---
|
| 107 |
+
|
| 108 |
+
## Inference flow
|
| 109 |
+
|
| 110 |
+
```
|
| 111 |
+
PIL image
|
| 112 |
+
↓
|
| 113 |
+
detect.infer(image, processor, model, prompts, threshold, device, dtype)
|
| 114 |
+
↓
|
| 115 |
+
list[{label, score, box_xyxy}]
|
| 116 |
+
│
|
| 117 |
+
├─ Detection mode ──────────────────────────────────────────────────
|
| 118 |
+
│ ↓ used by app.py directly
|
| 119 |
+
│ ↓ (CLI: wrapped by detect_image → JSON)
|
| 120 |
+
│ ↓ export.build_coco → coco_export.json (bbox only, segmentation:[])
|
| 121 |
+
│
|
| 122 |
+
└─ Segmentation mode ───────────────────────────────────────────────
|
| 123 |
+
↓
|
| 124 |
+
segment.segment_with_boxes(image, detections, sam2_processor, sam2_model)
|
| 125 |
+
↓
|
| 126 |
+
list[{label, score, box_xyxy, mask (np.ndarray), segmentation (polygons)}]
|
| 127 |
+
↓ mask used for visualization overlay; dropped before JSON serialisation
|
| 128 |
+
↓ export.build_coco → coco_export.json (bbox + segmentation polygons)
|
| 129 |
+
```
|
| 130 |
+
|
| 131 |
+
---
|
| 132 |
+
|
| 133 |
+
## Batch export ZIP structure
|
| 134 |
+
|
| 135 |
+
```
|
| 136 |
+
autolabel_export.zip
|
| 137 |
+
├── coco_export.json # COCO format, dimensions match images below
|
| 138 |
+
└── images/
|
| 139 |
+
├── photo1.jpg # resized to chosen training size (e.g. 640×640)
|
| 140 |
+
└── photo2.jpg
|
| 141 |
+
```
|
| 142 |
+
|
| 143 |
+
COCO bounding boxes are in the coordinate space of the resized images.
|
| 144 |
+
|
| 145 |
+
---
|
| 146 |
+
|
| 147 |
+
## Known limitations
|
| 148 |
+
|
| 149 |
+
- OWLv2 itself is detection-only — bounding boxes, no masks (pixel masks come from SAM2 in Segmentation mode).
|
| 150 |
+
- Objects < 32×32 px are often missed at default resolution.
|
| 151 |
+
- MPS inference is slower than CUDA but fast enough for development.
|
| 152 |
+
- Threshold default is 0.1 (intentionally low — easier to discard false
|
| 153 |
+
positives than recover missed objects).
|
| 154 |
+
|
| 155 |
+
---
|
| 156 |
+
|
| 157 |
+
## Fine-tuning (future)
|
| 158 |
+
|
| 159 |
+
The fine-tuning infrastructure is complete (`autolabel/finetune.py`,
|
| 160 |
+
`scripts/finetune_owlv2.py`) but not in active use. Workflow when ready:
|
| 161 |
+
|
| 162 |
+
1. Use the Batch tab to generate a labeled `coco_export.json`
|
| 163 |
+
2. Run `make finetune` (or `uv run python scripts/finetune_owlv2.py --help`)
|
| 164 |
+
3. Evaluate the fine-tuned model in the Test tab
|
Makefile
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.PHONY: setup detect export finetune clean app help
|
| 2 |
+
|
| 3 |
+
PYTHON := python
|
| 4 |
+
UV := uv
|
| 5 |
+
DATA_RAW := data/raw
|
| 6 |
+
DATA_DET := data/detections
|
| 7 |
+
DATA_LAB := data/labeled
|
| 8 |
+
|
| 9 |
+
help:
|
| 10 |
+
@echo "autolabel — OWLv2 labeling pipeline"
|
| 11 |
+
@echo ""
|
| 12 |
+
@echo "Targets:"
|
| 13 |
+
@echo " setup Install dependencies"
|
| 14 |
+
@echo " app Launch the Gradio UI (primary workflow)"
|
| 15 |
+
@echo " detect Run OWLv2 batch detection via CLI → data/detections/"
|
| 16 |
+
@echo " export Build COCO JSON from data/labeled/ via CLI"
|
| 17 |
+
@echo " finetune Fine-tune OWLv2 via CLI (future use)"
|
| 18 |
+
@echo " clean Remove generated JSON files (raw images untouched)"
|
| 19 |
+
|
| 20 |
+
setup:
|
| 21 |
+
$(UV) sync
|
| 22 |
+
@cp -n .env.example .env 2>/dev/null || true
|
| 23 |
+
@echo "Done. Run: make app"
|
| 24 |
+
|
| 25 |
+
app:
|
| 26 |
+
PYTORCH_ENABLE_MPS_FALLBACK=1 $(UV) run python app.py
|
| 27 |
+
|
| 28 |
+
detect:
|
| 29 |
+
$(UV) run python scripts/run_detection.py \
|
| 30 |
+
--image-dir $(DATA_RAW) \
|
| 31 |
+
--output-dir $(DATA_DET)
|
| 32 |
+
|
| 33 |
+
export:
|
| 34 |
+
$(UV) run python scripts/export_coco.py \
|
| 35 |
+
--labeled-dir $(DATA_LAB) \
|
| 36 |
+
--output $(DATA_LAB)/coco_export.json
|
| 37 |
+
|
| 38 |
+
finetune:
|
| 39 |
+
PYTORCH_ENABLE_MPS_FALLBACK=1 $(UV) run python scripts/finetune_owlv2.py \
|
| 40 |
+
--coco-json $(DATA_LAB)/coco_export.json \
|
| 41 |
+
--image-dir $(DATA_RAW)
|
| 42 |
+
|
| 43 |
+
clean:
|
| 44 |
+
@echo "Removing generated files..."
|
| 45 |
+
find $(DATA_DET) -name "*.json" -delete 2>/dev/null || true
|
| 46 |
+
find $(DATA_LAB) -name "*.json" -delete 2>/dev/null || true
|
| 47 |
+
@echo "Done. Raw images in $(DATA_RAW) are untouched."
|
README.md
CHANGED
|
@@ -1,12 +1,210 @@
|
|
| 1 |
---
|
| 2 |
title: LabelPlayground
|
| 3 |
-
|
| 4 |
-
colorFrom: gray
|
| 5 |
-
colorTo: indigo
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 6.8.0
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
|
|
|
|
| 1 |
---
|
| 2 |
title: LabelPlayground
|
| 3 |
+
app_file: app.py
|
|
|
|
|
|
|
| 4 |
sdk: gradio
|
| 5 |
sdk_version: 6.8.0
|
|
|
|
|
|
|
| 6 |
---
|
| 7 |
+
# autolabel — OWLv2 + SAM2 labeling pipeline
|
| 8 |
+
|
| 9 |
+
Auto-label images using **OWLv2** (open-vocabulary object detection) and
|
| 10 |
+
optionally **SAM2** (instance segmentation), then export a COCO dataset ready
|
| 11 |
+
for model fine-tuning.
|
| 12 |
+
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
## Quickstart
|
| 16 |
+
|
| 17 |
+
```bash
|
| 18 |
+
# 1. Install
|
| 19 |
+
uv sync
|
| 20 |
+
|
| 21 |
+
# 2. Copy env file (sets PYTORCH_ENABLE_MPS_FALLBACK=1 for Apple Silicon)
|
| 22 |
+
cp .env.example .env
|
| 23 |
+
|
| 24 |
+
# 3. Launch
|
| 25 |
+
make app
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
Models download automatically on first use and are cached in
|
| 29 |
+
`~/.cache/huggingface`. Nothing else is written to the project directory.
|
| 30 |
+
|
| 31 |
+
| Model | Size | Purpose |
|
| 32 |
+
|-------|------|---------|
|
| 33 |
+
| `owlv2-large-patch14-finetuned` | ~700 MB | Text → bounding boxes |
|
| 34 |
+
| `sam2-hiera-tiny` | ~160 MB | Box prompts → pixel masks |
|
| 35 |
+
|
| 36 |
+
---
|
| 37 |
+
|
| 38 |
+
## How the app works
|
| 39 |
+
|
| 40 |
+
### Mode selector
|
| 41 |
+
|
| 42 |
+
Both tabs have a **Detection / Segmentation** radio button:
|
| 43 |
+
|
| 44 |
+
| Mode | What runs | COCO output |
|
| 45 |
+
|------|-----------|-------------|
|
| 46 |
+
| **Detection** | OWLv2 only | `bbox` + empty `segmentation: []` |
|
| 47 |
+
| **Segmentation** | OWLv2 → SAM2 | `bbox` + `segmentation` polygon list |
|
| 48 |
+
|
| 49 |
+
### How Detection and Segmentation work
|
| 50 |
+
|
| 51 |
+
**Detection** uses [OWLv2](https://huggingface.co/google/owlv2-large-patch14-finetuned) — an
|
| 52 |
+
open-vocabulary object detector. You give it a text prompt ("cup, bottle") and it returns
|
| 53 |
+
bounding boxes with confidence scores. No fixed class list, no retraining needed.
|
| 54 |
+
|
| 55 |
+
**Segmentation** uses the **Grounded SAM2** pattern — two models chained together:
|
| 56 |
+
|
| 57 |
+
```
|
| 58 |
+
Text prompts ("cup, bottle")
|
| 59 |
+
│
|
| 60 |
+
▼
|
| 61 |
+
OWLv2 ← understands text, produces bounding boxes
|
| 62 |
+
│
|
| 63 |
+
▼
|
| 64 |
+
Bounding boxes
|
| 65 |
+
│
|
| 66 |
+
▼
|
| 67 |
+
SAM2 ← understands spatial prompts, produces pixel masks
|
| 68 |
+
│
|
| 69 |
+
▼
|
| 70 |
+
Masks + COCO polygons
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
SAM2 (`sam2-hiera-tiny`) is a *prompt-based* segmenter — it accepts box, point, or mask
|
| 74 |
+
prompts but has no concept of text or class names. It can't answer "find me a cup"; it
|
| 75 |
+
can only answer "segment the object inside this box." OWLv2 is the **grounding** step
|
| 76 |
+
that translates your words into coordinates SAM2 can act on.
|
| 77 |
+
|
| 78 |
+
Both models run in Segmentation mode. Detection mode skips SAM2 entirely.
|
| 79 |
+
|
| 80 |
+
### 🧪 Test tab
|
| 81 |
+
|
| 82 |
+
Upload a single image, pick a mode, and type comma-separated object prompts.
|
| 83 |
+
Hit **Detect** to see an annotated preview alongside a results table (label,
|
| 84 |
+
confidence, bounding box). In Segmentation mode, pixel mask overlays are drawn
|
| 85 |
+
on top of the bounding boxes. Use this tab to dial in prompts and threshold
|
| 86 |
+
before a batch run — nothing is saved to disk.
|
| 87 |
+
|
| 88 |
+
### 📂 Batch tab
|
| 89 |
+
|
| 90 |
+
Upload multiple images and run the chosen mode on all of them at once. You get:
|
| 91 |
+
|
| 92 |
+
- An annotated **gallery** showing every image
|
| 93 |
+
- A **Download ZIP** button containing:
|
| 94 |
+
- `coco_export.json` — COCO-format annotations ready for fine-tuning
|
| 95 |
+
- `images/` — all images resized to your chosen training size
|
| 96 |
+
|
| 97 |
+
The size dropdown offers common YOLOX training resolutions (416 → 1024) plus
|
| 98 |
+
**As is** to keep the original dimensions. Coordinates in the COCO file match
|
| 99 |
+
the resized images exactly.
|
| 100 |
+
|
| 101 |
+
All artifacts live in a system temp directory — nothing is written to the project.
|
| 102 |
+
|
| 103 |
+
---
|
| 104 |
+
|
| 105 |
+
## Project layout
|
| 106 |
+
|
| 107 |
+
```
|
| 108 |
+
autolabel/
|
| 109 |
+
├── config.py # Pydantic settings, auto device detection (CUDA → MPS → CPU)
|
| 110 |
+
├── detect.py # OWLv2 inference — infer() shared by app + CLI
|
| 111 |
+
├── segment.py # SAM2 integration — box prompts → masks + COCO polygons
|
| 112 |
+
├── export.py # COCO JSON builder (no pycocotools); bbox + segmentation
|
| 113 |
+
├── finetune.py # Fine-tuning loop (future use)
|
| 114 |
+
└── utils.py # Shared helpers
|
| 115 |
+
scripts/
|
| 116 |
+
├── run_detection.py # CLI: batch detect → data/detections/
|
| 117 |
+
├── export_coco.py # CLI: build coco_export.json from data/labeled/
|
| 118 |
+
└── finetune_owlv2.py # CLI: fine-tune OWLv2 (future use)
|
| 119 |
+
app.py # Gradio web UI
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
---
|
| 123 |
+
|
| 124 |
+
## CLI workflow
|
| 125 |
+
|
| 126 |
+
Detection and export can be driven from the command line without the UI:
|
| 127 |
+
|
| 128 |
+
```bash
|
| 129 |
+
# Detect all images in data/raw/ → data/detections/
|
| 130 |
+
make detect
|
| 131 |
+
|
| 132 |
+
# Custom prompts
|
| 133 |
+
uv run python scripts/run_detection.py --prompts "cup,mug,bottle"
|
| 134 |
+
|
| 135 |
+
# Force re-run on already-processed images
|
| 136 |
+
uv run python scripts/run_detection.py --force
|
| 137 |
+
|
| 138 |
+
# Build COCO JSON from data/labeled/
|
| 139 |
+
make export
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
---
|
| 143 |
+
|
| 144 |
+
## Fine-tuning (future)
|
| 145 |
+
|
| 146 |
+
The fine-tuning infrastructure is already in place. Once you have a
|
| 147 |
+
`coco_export.json` from a Batch run:
|
| 148 |
+
|
| 149 |
+
```bash
|
| 150 |
+
make finetune
|
| 151 |
+
# or:
|
| 152 |
+
uv run python scripts/finetune_owlv2.py \
|
| 153 |
+
--coco-json data/labeled/coco_export.json \
|
| 154 |
+
--image-dir data/raw \
|
| 155 |
+
--epochs 10
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
### Key hyperparameters
|
| 159 |
+
|
| 160 |
+
| Parameter | Default | Notes |
|
| 161 |
+
|-----------|---------|-------|
|
| 162 |
+
| Epochs | 10 | More epochs → higher overfit risk on small datasets |
|
| 163 |
+
| Learning rate | 1e-4 | Applied to the detection head |
|
| 164 |
+
| Gradient accumulation | 4 | Effective batch size multiplier |
|
| 165 |
+
| Unfreeze backbone | off | Also trains the vision encoder — needs more data |
|
| 166 |
+
|
| 167 |
+
### Tips
|
| 168 |
+
|
| 169 |
+
- Start with **50–100 annotated images per class** minimum; 200–500 is better.
|
| 170 |
+
- Fine-tuned models are more confident — raise the threshold to 0.2–0.4.
|
| 171 |
+
- Leave the backbone frozen unless you have 500+ images per class.
|
| 172 |
+
|
| 173 |
+
---
|
| 174 |
+
|
| 175 |
+
## Prerequisites
|
| 176 |
+
|
| 177 |
+
| Tool | Version | Notes |
|
| 178 |
+
|------|---------|-------|
|
| 179 |
+
| Python | **3.11.x** | Managed by uv |
|
| 180 |
+
| [uv](https://docs.astral.sh/uv/) | latest | `curl -LsSf https://astral.sh/uv/install.sh \| sh` |
|
| 181 |
+
| CUDA toolkit | 11.8+ | Windows/Linux GPU users only |
|
| 182 |
+
|
| 183 |
+
**Apple Silicon:** `PYTORCH_ENABLE_MPS_FALLBACK=1` is pre-set in `.env.example`.
|
| 184 |
+
|
| 185 |
+
**Windows/CUDA:** remove `PYTORCH_ENABLE_MPS_FALLBACK` from `.env`. For a
|
| 186 |
+
specific CUDA build:
|
| 187 |
+
|
| 188 |
+
```powershell
|
| 189 |
+
uv pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
|
| 190 |
+
uv sync
|
| 191 |
+
```
|
| 192 |
+
|
| 193 |
+
---
|
| 194 |
+
|
| 195 |
+
## Makefile targets
|
| 196 |
+
|
| 197 |
+
| Target | Description |
|
| 198 |
+
|--------|-------------|
|
| 199 |
+
| `make setup` | Install dependencies, copy `.env.example` |
|
| 200 |
+
| `make app` | Launch the Gradio UI |
|
| 201 |
+
| `make detect` | Batch detect via CLI → `data/detections/` |
|
| 202 |
+
| `make export` | Build COCO JSON via CLI |
|
| 203 |
+
| `make finetune` | Fine-tune OWLv2 via CLI |
|
| 204 |
+
| `make clean` | Delete generated JSONs (raw images untouched) |
|
| 205 |
+
|
| 206 |
+
---
|
| 207 |
+
|
| 208 |
+
## License
|
| 209 |
|
| 210 |
+
MIT
|
app.py
ADDED
|
@@ -0,0 +1,501 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
app.py — OWLv2 / SAM2 image labeling UI
|
| 3 |
+
|
| 4 |
+
Tab 1 — Test: upload one image, pick Detection or Segmentation mode,
|
| 5 |
+
tune prompts/threshold/size, see instant annotated results.
|
| 6 |
+
Tab 2 — Batch: upload multiple images, run in the chosen mode, download a ZIP
|
| 7 |
+
containing resized images + coco_export.json.
|
| 8 |
+
|
| 9 |
+
All artifacts live in a system temp directory — nothing is written to the project.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import os
|
| 15 |
+
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
|
| 16 |
+
|
| 17 |
+
import logging
|
| 18 |
+
import shutil
|
| 19 |
+
import tempfile
|
| 20 |
+
import zipfile
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from typing import Optional
|
| 23 |
+
|
| 24 |
+
import gradio as gr
|
| 25 |
+
import numpy as np
|
| 26 |
+
from dotenv import load_dotenv
|
| 27 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 28 |
+
|
| 29 |
+
load_dotenv()
|
| 30 |
+
|
| 31 |
+
from autolabel.config import settings
|
| 32 |
+
from autolabel.detect import infer as _owlv2_infer
|
| 33 |
+
from autolabel.export import build_coco
|
| 34 |
+
from autolabel.segment import load_sam2, segment_with_boxes
|
| 35 |
+
from autolabel.utils import save_json, setup_logging
|
| 36 |
+
|
| 37 |
+
setup_logging(logging.INFO)
|
| 38 |
+
logger = logging.getLogger(__name__)
|
| 39 |
+
|
| 40 |
+
# Temp directory for this session — cleaned up by the OS on reboot
|
| 41 |
+
_TMPDIR = Path(tempfile.mkdtemp(prefix="autolabel_"))
|
| 42 |
+
logger.info("Session temp dir: %s", _TMPDIR)
|
| 43 |
+
|
| 44 |
+
# ---------------------------------------------------------------------------
|
| 45 |
+
# Image sizing
|
| 46 |
+
# ---------------------------------------------------------------------------
|
| 47 |
+
_SIZE_OPTIONS = {
|
| 48 |
+
"As is": None,
|
| 49 |
+
"416 × 416": (416, 416),
|
| 50 |
+
"480 × 480": (480, 480),
|
| 51 |
+
"512 × 512": (512, 512),
|
| 52 |
+
"640 × 640": (640, 640),
|
| 53 |
+
"736 × 736": (736, 736),
|
| 54 |
+
"896 × 896": (896, 896),
|
| 55 |
+
"1024 × 1024": (1024, 1024),
|
| 56 |
+
}
|
| 57 |
+
_SIZE_LABELS = list(_SIZE_OPTIONS.keys())
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _resize(pil: Image.Image, size_label: str) -> Image.Image:
    """Return *pil* resized to the dimensions mapped by *size_label*.

    *size_label* is a key of ``_SIZE_OPTIONS``; the "As is" option maps to
    ``None`` and returns the image unchanged. LANCZOS resampling is used
    for quality when downscaling.
    """
    dims = _SIZE_OPTIONS[size_label]
    return pil if dims is None else pil.resize(dims, Image.LANCZOS)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# ---------------------------------------------------------------------------
|
| 68 |
+
# Colours & annotation
|
| 69 |
+
# ---------------------------------------------------------------------------
|
| 70 |
+
_PALETTE = [
|
| 71 |
+
(52, 211, 153), (251, 146, 60), (96, 165, 250), (248, 113, 113),
|
| 72 |
+
(167, 139, 250),(250, 204, 21), (34, 211, 238), (244, 114, 182),
|
| 73 |
+
(74, 222, 128), (232, 121, 249), (125, 211, 252), (253, 186, 116),
|
| 74 |
+
(110, 231, 183),(196, 181, 253), (253, 164, 175), (134, 239, 172),
|
| 75 |
+
]
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _colour_for(label: str, prompts: list[str]) -> tuple[int, int, int]:
|
| 79 |
+
try:
|
| 80 |
+
return _PALETTE[prompts.index(label) % len(_PALETTE)]
|
| 81 |
+
except ValueError:
|
| 82 |
+
return _PALETTE[hash(label) % len(_PALETTE)]
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _annotate(
    pil_image: Image.Image,
    detections: list[dict],
    prompts: list[str],
    mode: str = "Detection",
) -> Image.Image:
    """Draw bounding boxes (+ mask overlays in Segmentation mode) on *pil_image*.

    Args:
        pil_image: Source image; it is copied, never modified in place.
        detections: Dicts with 'label', 'score', 'box_xyxy' and, in
            Segmentation mode, optionally a boolean ndarray under 'mask'.
        prompts: Prompt list used to pick a consistent per-class colour.
        mode: "Detection" (boxes only) or "Segmentation" (masks + boxes).

    Returns:
        A new RGB image with annotations drawn on it.
    """
    img = pil_image.copy().convert("RGBA")

    # --- Segmentation: paint semi-transparent mask overlays first ---
    # Masks go underneath the boxes so outlines and labels stay crisp.
    if mode == "Segmentation":
        overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
        for det in detections:
            mask = det.get("mask")
            if mask is None or not isinstance(mask, np.ndarray):
                continue  # detection without a usable mask: box only
            r, g, b = _colour_for(det["label"], prompts)
            # RGBA layer that is fully transparent except where mask is True.
            mask_rgba = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
            mask_rgba[mask] = [r, g, b, 100]  # semi-transparent fill
            overlay = Image.alpha_composite(overlay, Image.fromarray(mask_rgba, "RGBA"))
        img = Image.alpha_composite(img, overlay)

    # --- Bounding boxes and labels (both modes) ---
    draw = ImageDraw.Draw(img, "RGBA")
    try:
        # macOS system font; on other platforms the truetype load raises and
        # we fall back to PIL's built-in bitmap font.
        font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", size=18)
    except Exception:
        font = ImageFont.load_default()

    for det in detections:
        x1, y1, x2, y2 = det["box_xyxy"]
        r, g, b = _colour_for(det["label"], prompts)
        draw.rectangle([x1, y1, x2, y2], outline=(r, g, b), width=3)
        tag = f"{det['label']} {det['score']:.2f}"
        # Near-opaque chip behind the text keeps it readable on any background.
        bbox = draw.textbbox((x1, y1), tag, font=font)
        draw.rectangle([bbox[0]-3, bbox[1]-3, bbox[2]+3, bbox[3]+3], fill=(r, g, b, 210))
        draw.text((x1, y1), tag, fill=(255, 255, 255), font=font)

    return img.convert("RGB")
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
# ---------------------------------------------------------------------------
|
| 127 |
+
# OWLv2 model (cached)
|
| 128 |
+
# ---------------------------------------------------------------------------
|
| 129 |
+
# Single-entry cache: (processor, model) keyed by settings.model.
_owlv2_cache: dict = {}


def _get_owlv2():
    """Return the cached OWLv2 (processor, model) pair, loading it on demand.

    At most one model stays resident: if settings.model changed since the
    last call, the previous entry is evicted before the new one is loaded.
    """
    pair = _owlv2_cache.get(settings.model)
    if pair is not None:
        return pair

    _owlv2_cache.clear()  # evict any previously loaded model
    from transformers import Owlv2ForObjectDetection, Owlv2Processor
    logger.info("Loading OWLv2 %s on %s …", settings.model, settings.device)
    proc = Owlv2Processor.from_pretrained(settings.model)
    net = Owlv2ForObjectDetection.from_pretrained(
        settings.model, torch_dtype=settings.torch_dtype
    ).to(settings.device)
    net.eval()
    pair = (proc, net)
    _owlv2_cache[settings.model] = pair
    logger.info("OWLv2 ready.")
    return pair
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
# ---------------------------------------------------------------------------
|
| 148 |
+
# SAM2 model (cached)
|
| 149 |
+
# ---------------------------------------------------------------------------
|
| 150 |
+
# SAM2 is a fixed model id, so the cache holds at most one entry.
_sam2_cache: dict = {}
_SAM2_MODEL_ID = "facebook/sam2-hiera-tiny"


def _get_sam2():
    """Return the cached SAM2 (processor, model) pair, loading it once."""
    pair = _sam2_cache.get(_SAM2_MODEL_ID)
    if pair is None:
        pair = load_sam2(settings.device, _SAM2_MODEL_ID)
        _sam2_cache[_SAM2_MODEL_ID] = pair
    return pair
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# ---------------------------------------------------------------------------
|
| 162 |
+
# Shared inference helpers
|
| 163 |
+
# ---------------------------------------------------------------------------
|
| 164 |
+
|
| 165 |
+
def _run_detection(
    pil_image: Image.Image,
    prompts: list[str],
    threshold: float,
    mode: str,
) -> list[dict]:
    """Run OWLv2 (and optionally SAM2) on *pil_image*.

    In "Segmentation" mode, detections are additionally refined by SAM2 and
    come back enriched with 'mask' and 'segmentation' entries; in
    "Detection" mode (or when nothing is detected) the OWLv2 boxes are
    returned as-is.
    """
    owl_processor, owl_model = _get_owlv2()
    dets = _owlv2_infer(
        pil_image, owl_processor, owl_model, prompts, threshold,
        settings.device, settings.torch_dtype,
    )

    # SAM2 only runs when there is at least one box to prompt it with.
    if mode != "Segmentation" or not dets:
        return dets

    sam_processor, sam_model = _get_sam2()
    return segment_with_boxes(pil_image, dets, sam_processor, sam_model, settings.device)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def _parse_prompts(text: str) -> list[str]:
|
| 192 |
+
return [p.strip() for p in text.split(",") if p.strip()]
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
# ---------------------------------------------------------------------------
|
| 196 |
+
# Object crops
|
| 197 |
+
# ---------------------------------------------------------------------------
|
| 198 |
+
|
| 199 |
+
def _make_crops(
    pil_image: Image.Image,
    detections: list[dict],
    prompts: list[str],
    mode: str,
) -> list[tuple[Image.Image, str]]:
    """Return one (cropped PIL image, caption) pair per detection.

    Detection mode: plain bounding-box crop with a coloured border.
    Segmentation mode: tight crop around the mask's nonzero region; pixels
    outside the mask are set to white for a clean cutout.  A detection with
    a missing or empty mask falls back to the plain box crop.

    (Refactor note: the original duplicated the identical box-crop fallback
    in three nested else branches; the mask cutout now lives in a helper and
    the fallback exists once.)
    """
    crops: list[tuple[Image.Image, str]] = []
    img_w, img_h = pil_image.size

    for det in detections:
        x1, y1, x2, y2 = det["box_xyxy"]
        x1, y1 = max(0, int(x1)), max(0, int(y1))
        x2, y2 = min(img_w, int(x2)), min(img_h, int(y2))
        if x2 <= x1 or y2 <= y1:
            continue  # degenerate box after clamping to the image bounds

        r, g, b = _colour_for(det["label"], prompts)

        crop_rgb = None
        if mode == "Segmentation":
            crop_rgb = _mask_cutout(pil_image, det.get("mask"))
        if crop_rgb is None:
            # Detection mode, missing mask, or all-False mask.
            crop_rgb = pil_image.crop((x1, y1, x2, y2)).convert("RGB")

        # 3-pixel coloured border matching the detection's palette colour.
        bordered = Image.new("RGB", (crop_rgb.width + 6, crop_rgb.height + 6), (r, g, b))
        bordered.paste(crop_rgb, (3, 3))

        caption = f"{det['label']} {det['score']:.2f}"
        crops.append((bordered, caption))

    return crops


def _mask_cutout(pil_image: Image.Image, mask) -> Optional[Image.Image]:
    """Tight RGB crop of *mask*'s nonzero region with a white background.

    Returns None when *mask* is absent, not an ndarray, or entirely empty,
    so the caller can fall back to a plain bounding-box crop.
    """
    if mask is None or not isinstance(mask, np.ndarray):
        return None
    rows = np.any(mask, axis=1)
    cols = np.any(mask, axis=0)
    if not rows.any() or not cols.any():
        return None
    # Tight bounding box of the mask's nonzero region.
    r_min, r_max = int(np.where(rows)[0][0]), int(np.where(rows)[0][-1])
    c_min, c_max = int(np.where(cols)[0][0]), int(np.where(cols)[0][-1])
    tight = mask[r_min:r_max + 1, c_min:c_max + 1]
    region = np.array(
        pil_image.crop((c_min, r_min, c_max + 1, r_max + 1)).convert("RGB")
    )
    region[~tight] = [255, 255, 255]  # white background outside the mask
    return Image.fromarray(region)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
# ---------------------------------------------------------------------------
|
| 257 |
+
# Tab 1 — Test
|
| 258 |
+
# ---------------------------------------------------------------------------
|
| 259 |
+
|
| 260 |
+
def run_test(
    image_np: Optional[np.ndarray],
    prompts_text: str,
    threshold: float,
    size_label: str,
    mode: str,
):
    """Handler for the Test tab.

    Runs detection (and optional segmentation) on a single image and returns
    (annotated image as ndarray, table rows, object crops). With no image or
    no prompts, echoes the input image and empty results.
    """
    if image_np is None or not prompts_text.strip():
        return image_np, [], []

    prompts = _parse_prompts(prompts_text)
    if not prompts:
        return image_np, [], []

    pil = _resize(Image.fromarray(image_np), size_label)
    detections = _run_detection(pil, prompts, threshold, mode)

    # One table row per detection: index, label, score, rounded box coords.
    table = []
    for idx, det in enumerate(detections, start=1):
        bx = det["box_xyxy"]
        table.append([
            idx,
            det["label"],
            f"{det['score']:.3f}",
            f"[{bx[0]:.0f}, {bx[1]:.0f}, {bx[2]:.0f}, {bx[3]:.0f}]",
        ])

    crops = _make_crops(pil, detections, prompts, mode)
    annotated = np.array(_annotate(pil, detections, prompts, mode))
    return annotated, table, crops
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
# ---------------------------------------------------------------------------
|
| 288 |
+
# Tab 2 — Batch
|
| 289 |
+
# ---------------------------------------------------------------------------
|
| 290 |
+
|
| 291 |
+
def run_batch(files, prompts_text: str, threshold: float, size_label: str, mode: str):
    """Handler for the Batch tab.

    Runs detection (and optional SAM2 segmentation) on every uploaded file,
    writes a per-image JSON plus a combined COCO export, and packages the
    resized images and COCO JSON into a downloadable ZIP.

    Returns:
        (gallery of annotated PIL images, stats string, zip path or None).

    Fix: uploads that share a basename (e.g. two different "img.jpg" files)
    previously overwrote each other's saved image and detection JSON; names
    are now uniquified before saving.
    """
    if not files or not prompts_text.strip():
        return [], "Upload images and enter prompts to get started.", None

    prompts = _parse_prompts(prompts_text)
    if not prompts:
        return [], "No valid prompts.", None

    # Fresh temp dir for this run
    run_dir = _TMPDIR / "current_run"
    if run_dir.exists():
        shutil.rmtree(run_dir)
    images_dir = run_dir / "images"
    images_dir.mkdir(parents=True)

    gallery: list[Image.Image] = []
    total_dets = 0
    used_names: set[str] = set()  # guards against same-named uploads clobbering each other

    for f in files:
        try:
            src = Path(f.name if hasattr(f, "name") else str(f))
            pil = _resize(Image.open(src).convert("RGB"), size_label)
            w, h = pil.size
            detections = _run_detection(pil, prompts, threshold, mode)
            total_dets += len(detections)

            # Uniquify the saved name so "a/img.jpg" and "b/img.jpg" both survive.
            img_name = src.name
            counter = 1
            while img_name in used_names:
                img_name = f"{src.stem}_{counter}{src.suffix}"
                counter += 1
            used_names.add(img_name)

            # Save resized image (included in the ZIP)
            pil.save(images_dir / img_name)

            # Per-image JSON consumed by build_coco.
            # Drop numpy mask arrays — they are not JSON-serialisable.
            json_dets = [
                {k: v for k, v in det.items() if k != "mask"}
                for det in detections
            ]
            save_json(
                {"image_path": img_name, "image_width": w,
                 "image_height": h, "detections": json_dets},
                run_dir / (Path(img_name).stem + ".json"),
            )
            gallery.append(_annotate(pil, detections, prompts, mode))
        except Exception:
            # Best-effort batch: one bad file must not abort the whole run.
            logger.exception("Failed to process %s", f)

    # Build COCO JSON (may be falsy when nothing was processed)
    coco = build_coco(run_dir)
    coco_path = run_dir / "coco_export.json"
    if coco:
        save_json(coco, coco_path)

    # Package everything into a ZIP
    zip_path = run_dir / "autolabel_export.zip"
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        if coco_path.exists():
            zf.write(coco_path, "coco_export.json")
        for img_file in sorted(images_dir.iterdir()):
            zf.write(img_file, f"images/{img_file.name}")

    n_ann = len(coco.get("annotations", [])) if coco else 0
    size_note = f" · resized to {size_label}" if size_label != "As is" else ""
    mode_note = f" · {mode.lower()}"
    stats = (
        f"{len(gallery)} image(s) · {total_dets} detection(s) · "
        f"{n_ann} annotations{size_note}{mode_note}"
    )
    return gallery, stats, str(zip_path)
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
# ---------------------------------------------------------------------------
|
| 361 |
+
# UI
|
| 362 |
+
# ---------------------------------------------------------------------------
|
| 363 |
+
# Default prompt string shown in both tabs: the first 8 configured prompts.
_DEFAULT_PROMPTS = ", ".join(settings.prompts[:8])

# Markdown body for the "How it works" accordion (rendered by gr.Markdown).
_HOW_IT_WORKS_MD = """\
## How it works

| Mode | Models | Output |
|------|--------|--------|
| **Detection** | OWLv2 | Bounding boxes + class labels |
| **Segmentation** | OWLv2 → SAM2 | Bounding boxes + pixel masks + COCO polygons |

**Detection** uses [OWLv2](https://huggingface.co/google/owlv2-large-patch14-finetuned), an
open-vocabulary detector that converts your text prompts directly into bounding boxes — no
fixed class list required.

**Segmentation** uses the **Grounded SAM2** pattern:

1. **OWLv2** reads your text prompts and produces bounding boxes
2. **SAM2** (`sam2-hiera-tiny`) takes each box as a spatial prompt and refines it into a
   pixel-level mask

SAM2 has no concept of text — it only understands spatial prompts (boxes, points, masks).
OWLv2 acts as the *grounding* step, translating words into coordinates that SAM2 can use.
Both models must run in Segmentation mode; Detection mode skips SAM2 entirely.
"""
|
| 387 |
+
|
| 388 |
+
# Gradio UI: two tabs sharing the same prompt/threshold/size/mode controls.
with gr.Blocks(title="autolabel") as demo:
    gr.Markdown("# autolabel — OWLv2 + SAM2")

    with gr.Accordion("ℹ️ How it works", open=False):
        gr.Markdown(_HOW_IT_WORKS_MD)

    with gr.Tabs():

        # ── Tab 1: Test ──────────────────────────────────────────────────
        # Single-image playground: left column holds inputs, right column
        # the annotated result, a detection table, and per-object crops.
        with gr.Tab("🧪 Test"):
            with gr.Row():
                with gr.Column(scale=1):
                    t1_image = gr.Image(label="Image — upload, paste, or pick a sample below",
                                        type="numpy", sources=["upload", "clipboard"])
                    t1_mode = gr.Radio(
                        ["Detection", "Segmentation"],
                        label="Mode", value="Detection",
                        info="Detection: OWLv2 → boxes only. "
                             "Segmentation: OWLv2 → boxes → SAM2 → pixel masks.",
                    )
                    t1_prompts = gr.Textbox(label="Prompts (comma-separated)",
                                            value=_DEFAULT_PROMPTS, lines=2)
                    t1_threshold = gr.Slider(label="Threshold", minimum=0.01,
                                             maximum=0.9, step=0.01, value=settings.threshold)
                    t1_size = gr.Dropdown(label="Input size", choices=_SIZE_LABELS,
                                          value="As is")
                    t1_btn = gr.Button("Detect", variant="primary")
                with gr.Column(scale=1):
                    t1_output = gr.Image(label="Result", type="numpy")
                    t1_table = gr.Dataframe(
                        headers=["#", "Label", "Score", "Box (xyxy)"],
                        row_count=(0, "dynamic"), column_count=(4, "fixed"),
                    )
                    t1_crops = gr.Gallery(
                        label="Object crops",
                        columns=4, height=220,
                        object_fit="contain", show_label=True,
                    )

            # Sample images — click any thumbnail to load it into the image input
            _SAMPLES_DIR = Path(__file__).parent / "samples"
            gr.Examples(
                label="Sample images (click to load)",
                examples=[
                    [str(_SAMPLES_DIR / "animals.jpg"), "Detection",
                     "crown, necklace, ball, animal eye", 0.40, "As is"],
                    [str(_SAMPLES_DIR / "kitchen.jpg"), "Detection",
                     "apple, banana, orange, broccoli, carrot, bottle, bowl", 0.40, "As is"],
                    [str(_SAMPLES_DIR / "dog.jpg"), "Detection",
                     "dog", 0.40, "As is"],
                    [str(_SAMPLES_DIR / "cat.jpg"), "Detection",
                     "cat", 0.40, "As is"]
                ],
                inputs=[t1_image, t1_mode, t1_prompts, t1_threshold, t1_size],
                examples_per_page=5,
                cache_examples=False,
            )

            # Both the button and Enter in the prompt box trigger run_test.
            t1_btn.click(
                run_test,
                inputs=[t1_image, t1_prompts, t1_threshold, t1_size, t1_mode],
                outputs=[t1_output, t1_table, t1_crops],
            )
            t1_prompts.submit(
                run_test,
                inputs=[t1_image, t1_prompts, t1_threshold, t1_size, t1_mode],
                outputs=[t1_output, t1_table, t1_crops],
            )

        # ── Tab 2: Batch ─────────────────────────────────────────────────
        # Multi-file run producing a gallery plus a downloadable ZIP export.
        with gr.Tab("📂 Batch"):
            with gr.Row():
                with gr.Column(scale=1):
                    t2_files = gr.File(label="Images", file_count="multiple",
                                       file_types=["image"])
                    t2_mode = gr.Radio(
                        ["Detection", "Segmentation"],
                        label="Mode", value="Detection",
                        info="Detection: OWLv2 → boxes only. "
                             "Segmentation: OWLv2 → boxes → SAM2 → pixel masks.",
                    )
                    t2_prompts = gr.Textbox(label="Prompts (comma-separated)",
                                            value=_DEFAULT_PROMPTS, lines=2)
                    t2_threshold = gr.Slider(label="Threshold", minimum=0.01,
                                             maximum=0.9, step=0.01, value=settings.threshold)
                    t2_size = gr.Dropdown(label="Input size", choices=_SIZE_LABELS,
                                          value="640 × 640")
                    t2_btn = gr.Button("Run", variant="primary")
                    t2_stats = gr.Textbox(label="Stats", interactive=False)
                    t2_download = gr.DownloadButton(
                        label="Download ZIP (images + COCO JSON)",
                        visible=False, variant="secondary", size="sm",
                    )
                with gr.Column(scale=2):
                    t2_gallery = gr.Gallery(label="Results", columns=3,
                                            height="auto", object_fit="contain")

            # Wrapper that also toggles the download button's visibility
            # based on whether run_batch produced a ZIP path.
            def _run_and_reveal(files, prompts_text, threshold, size_label, mode):
                gallery, stats, zip_path = run_batch(
                    files, prompts_text, threshold, size_label, mode
                )
                return gallery, stats, gr.update(value=zip_path, visible=zip_path is not None)

            t2_btn.click(
                _run_and_reveal,
                inputs=[t2_files, t2_prompts, t2_threshold, t2_size, t2_mode],
                outputs=[t2_gallery, t2_stats, t2_download],
            )
|
| 496 |
+
|
| 497 |
+
# Cap the request queue so long model runs don't pile up unbounded.
demo.queue(max_size=5)

if __name__ == "__main__":
    # NOTE(review): `theme` is conventionally a gr.Blocks(...) constructor
    # argument rather than a launch() argument — confirm the installed Gradio
    # version accepts (and honours) it here.
    demo.launch(server_name="0.0.0.0", server_port=7860,
                share=True, inbrowser=True, theme=gr.themes.Soft())
|
autolabel/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
autolabel — OWLv2-powered auto-labeling pipeline for household object detection.
|
| 3 |
+
|
| 4 |
+
Pipeline:
|
| 5 |
+
1. detect — run OWLv2 on images, produce per-image detection JSON
|
| 6 |
+
2. export — convert detections to COCO JSON for fine-tuning
|
| 7 |
+
|
| 8 |
+
Primary interface: app.py (Gradio web UI)
|
| 9 |
+
CLI interface: scripts/run_detection.py, scripts/export_coco.py
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
__version__ = "0.1.0"
|
autolabel/config.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
config.py — Pydantic settings for the autolabel pipeline.
|
| 3 |
+
|
| 4 |
+
Handles:
|
| 5 |
+
- Auto device detection: CUDA → MPS → CPU
|
| 6 |
+
- OWLv2 model selection
|
| 7 |
+
- Detection thresholds
|
| 8 |
+
- Data paths derived from project root
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import logging
|
| 14 |
+
import os
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from typing import List
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from pydantic import Field, field_validator, model_validator
|
| 20 |
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Project root — two levels up from this file (autolabel/config.py → project/)
# Used to anchor the .env file and the default data directories.
# ---------------------------------------------------------------------------
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _detect_device() -> str:
    """Return the best available torch device string.

    Preference order: "cuda" > "mps" > "cpu".  Logs the choice so the
    selected backend is visible at startup.
    """
    if torch.cuda.is_available():
        logger.info("Device selected: CUDA (%s)", torch.cuda.get_device_name(0))
        return "cuda"

    if torch.backends.mps.is_available():
        logger.info(
            "Device selected: MPS (Apple Silicon). "
            "Set PYTORCH_ENABLE_MPS_FALLBACK=1 for unsupported ops."
        )
        return "mps"

    logger.warning("Device selected: CPU — no CUDA or MPS found. Inference will be slow.")
    return "cpu"
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class Settings(BaseSettings):
    """Central configuration for the autolabel pipeline.

    All values can be overridden via environment variables prefixed with
    AUTOLABEL_ (e.g., AUTOLABEL_THRESHOLD=0.2).
    The .env file is loaded automatically from the project root.
    """

    model_config = SettingsConfigDict(
        env_prefix="AUTOLABEL_",
        env_file=str(PROJECT_ROOT / ".env"),
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="ignore",  # unknown AUTOLABEL_* variables are silently dropped
    )

    # ------------------------------------------------------------------
    # Device
    # ------------------------------------------------------------------
    # Empty string means "auto-detect" — resolved by _resolve_device below.
    device: str = Field(
        default="",
        description="Torch device override. Leave empty for auto-detection.",
    )

    # ------------------------------------------------------------------
    # OWLv2 model
    # ------------------------------------------------------------------
    model: str = Field(
        default="google/owlv2-large-patch14-finetuned",
        description="Hugging Face model identifier for OWLv2.",
    )

    # ------------------------------------------------------------------
    # Detection
    # ------------------------------------------------------------------
    threshold: float = Field(
        default=0.1,
        ge=0.0,
        le=1.0,
        description="Minimum confidence score to keep a detection.",
    )

    # Default open-vocabulary prompts (household objects). Can also be set
    # via AUTOLABEL_PROMPTS as a comma-separated string (see _parse_prompts).
    prompts: List[str] = Field(
        default=[
            "cup",
            "bottle",
            "keyboard",
            "computer mouse",
            "cell phone",
            "remote control",
            "book",
            "plant",
            "bowl",
            "mug",
            "laptop",
            "monitor",
            "pen",
            "scissors",
            "stapler",
            "headphones",
            "wallet",
            "keys",
            "glasses",
            "candle",
            "backpack",
            "notebook",
            "water bottle",
            "coffee cup",
            "charger",
        ],
        description="Text prompts sent to OWLv2 for open-vocabulary detection.",
    )

    # ------------------------------------------------------------------
    # Paths
    # ------------------------------------------------------------------
    raw_dir: Path = Field(
        default=PROJECT_ROOT / "data" / "raw",
        description="Input images directory.",
    )
    detections_dir: Path = Field(
        default=PROJECT_ROOT / "data" / "detections",
        description="OWLv2 output JSON files.",
    )
    labeled_dir: Path = Field(
        default=PROJECT_ROOT / "data" / "labeled",
        description="Reviewed and accepted annotation JSON files.",
    )

    # ------------------------------------------------------------------
    # Validators
    # ------------------------------------------------------------------
    @field_validator("threshold", mode="before")
    @classmethod
    def _coerce_threshold(cls, v: object) -> float:
        # Env vars arrive as strings; coerce before the ge/le bounds apply.
        return float(v)  # type: ignore[arg-type]

    @field_validator("prompts", mode="before")
    @classmethod
    def _parse_prompts(cls, v: object) -> List[str]:
        """Allow comma-separated string from env var."""
        if isinstance(v, str):
            return [p.strip() for p in v.split(",") if p.strip()]
        return list(v)  # type: ignore[arg-type]

    @model_validator(mode="after")
    def _resolve_device(self) -> "Settings":
        # Empty string → auto-detect; any explicit value is taken verbatim.
        if not self.device:
            self.device = _detect_device()
        else:
            logger.info("Device override from env/config: %s", self.device)
        return self

    @model_validator(mode="after")
    def _ensure_dirs(self) -> "Settings":
        # Side effect: creates the data directories when Settings is built
        # (i.e. at import time via the module-level singleton).
        for path in (self.raw_dir, self.detections_dir, self.labeled_dir):
            path.mkdir(parents=True, exist_ok=True)
        return self

    # ------------------------------------------------------------------
    # Convenience
    # ------------------------------------------------------------------
    @property
    def torch_dtype(self) -> torch.dtype:
        """fp16 on CUDA, fp32 everywhere else (MPS doesn't support fp16 fully)."""
        return torch.float16 if self.device == "cuda" else torch.float32
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
# Module-level singleton — import this everywhere.
# Constructing it triggers device detection and data-directory creation.
settings = Settings()
|
autolabel/detect.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
detect.py — OWLv2 runner.
|
| 3 |
+
|
| 4 |
+
Loads google/owlv2-* from Hugging Face, runs open-vocabulary detection on
|
| 5 |
+
every image in a folder, and saves per-image JSON with boxes, scores, and
|
| 6 |
+
labels. Already-processed images are skipped unless --force is used.
|
| 7 |
+
|
| 8 |
+
Output JSON schema (per image):
|
| 9 |
+
{
|
| 10 |
+
"image_path": "/abs/path/to/image.jpg",
|
| 11 |
+
"image_width": 1920,
|
| 12 |
+
"image_height": 1080,
|
| 13 |
+
"detections": [
|
| 14 |
+
{
|
| 15 |
+
"label": "cup",
|
| 16 |
+
"score": 0.83,
|
| 17 |
+
"box_xyxy": [x1, y1, x2, y2] # absolute pixel coords
|
| 18 |
+
},
|
| 19 |
+
...
|
| 20 |
+
]
|
| 21 |
+
}
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
import logging
|
| 27 |
+
from pathlib import Path
|
| 28 |
+
from typing import List, Optional
|
| 29 |
+
|
| 30 |
+
import torch
|
| 31 |
+
from PIL import Image
|
| 32 |
+
from tqdm import tqdm
|
| 33 |
+
from transformers import Owlv2ForObjectDetection, Owlv2Processor
|
| 34 |
+
|
| 35 |
+
from autolabel.config import Settings, settings as default_settings
|
| 36 |
+
from autolabel.utils import (
|
| 37 |
+
collect_images,
|
| 38 |
+
detection_json_path,
|
| 39 |
+
save_json,
|
| 40 |
+
setup_logging,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
logger = logging.getLogger(__name__)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# ---------------------------------------------------------------------------
|
| 47 |
+
# Core runner
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
+
|
| 50 |
+
def load_model(cfg: Settings) -> tuple[Owlv2Processor, Owlv2ForObjectDetection]:
    """Download (or load from cache) OWLv2 processor and model.

    The model is moved to cfg.device, cast via cfg.torch_dtype, and put in
    eval mode before being returned.
    """
    logger.info("Loading OWLv2 model: %s", cfg.model)
    proc = Owlv2Processor.from_pretrained(cfg.model)
    net = Owlv2ForObjectDetection.from_pretrained(
        cfg.model,
        torch_dtype=cfg.torch_dtype,
    ).to(cfg.device)
    net.eval()
    logger.info("Model loaded on device: %s dtype: %s", cfg.device, cfg.torch_dtype)
    return proc, net
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def infer(
    image: Image.Image,
    processor: Owlv2Processor,
    model: Owlv2ForObjectDetection,
    prompts: List[str],
    threshold: float,
    device: str,
    torch_dtype: torch.dtype,
) -> List[dict]:
    """Run OWLv2 on a PIL image and return detection dicts sorted by score.

    This is the shared inference primitive used by both the web app and the
    CLI batch runner.  Each dict carries 'label', 'score' (rounded to 4 dp)
    and 'box_xyxy' in absolute pixel coordinates of the input image.
    """
    width, height = image.size

    batch = processor(text=[prompts], images=image, return_tensors="pt")
    moved = {}
    for key, tensor in batch.items():
        if tensor.is_floating_point():
            # Float inputs must match the model's dtype (fp16 on CUDA).
            moved[key] = tensor.to(device=device, dtype=torch_dtype)
        else:
            # Integer inputs (e.g. token ids) only change device.
            moved[key] = tensor.to(device=device)

    with torch.no_grad():
        outputs = model(**moved)

    # Post-processing rescales boxes back to the original (height, width).
    target_sizes = torch.tensor([[height, width]], device=device)
    result = processor.post_process_grounded_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold,
    )[0]

    boxes = result["boxes"].cpu().tolist()
    scores = result["scores"].cpu().tolist()
    label_ids = result["labels"].cpu().tolist()

    detections = []
    for box, score, idx in zip(boxes, scores, label_ids):
        detections.append(
            {
                "label": prompts[idx],
                "score": round(float(score), 4),
                "box_xyxy": [round(coord, 1) for coord in box],
            }
        )
    return sorted(detections, key=lambda d: d["score"], reverse=True)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def detect_image(
    image_path: Path,
    processor: Owlv2Processor,
    model: Owlv2ForObjectDetection,
    prompts: List[str],
    threshold: float,
    device: str,
    torch_dtype: torch.dtype,
) -> dict:
    """Run OWLv2 on an image file and return the structured detection dict."""
    img = Image.open(image_path).convert("RGB")
    w, h = img.size
    return {
        "image_path": str(image_path.resolve()),
        "image_width": w,
        "image_height": h,
        "detections": infer(img, processor, model, prompts, threshold, device, torch_dtype),
    }
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def run_detection(
    image_dir: Path,
    output_dir: Path,
    prompts: Optional[List[str]] = None,
    cfg: Optional[Settings] = None,
    force: bool = False,
) -> None:
    """
    Run OWLv2 detection on all images in *image_dir*.

    Args:
        image_dir: Folder containing input images.
        output_dir: Folder where per-image JSON files are written.
        prompts: Override text prompts (uses cfg.prompts if None).
        cfg: Settings instance (uses module default if None).
        force: Re-process images that already have a detection JSON.
    """
    cfg = cfg or default_settings
    active_prompts = prompts or cfg.prompts
    output_dir.mkdir(parents=True, exist_ok=True)

    images = collect_images(image_dir)
    if not images:
        logger.warning("No images found in %s", image_dir)
        return

    processor, model = load_model(cfg)

    skipped = 0
    for img_path in tqdm(images, desc="Detecting", unit="img"):
        dest = detection_json_path(img_path, output_dir)

        # Idempotent by default: an existing JSON means "already done".
        if dest.exists() and not force:
            logger.debug("Skipping (already processed): %s", img_path.name)
            skipped += 1
            continue

        try:
            record = detect_image(
                image_path=img_path,
                processor=processor,
                model=model,
                prompts=active_prompts,
                threshold=cfg.threshold,
                device=cfg.device,
                torch_dtype=cfg.torch_dtype,
            )
        except Exception:
            # One bad image must not abort the whole batch.
            logger.exception("Failed to process %s", img_path)
            continue

        save_json(record, dest)
        logger.debug(
            "%s → %d detection(s)", img_path.name, len(record["detections"])
        )

    logger.info(
        "Detection complete. Processed: %d Skipped: %d",
        len(images) - skipped,
        skipped,
    )
|
autolabel/export.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
export.py — COCO JSON exporter.
|
| 3 |
+
|
| 4 |
+
Reads all per-image labeled JSON files from the labeled/ directory and
|
| 5 |
+
assembles a valid COCO-format JSON file. No pycocotools dependency — the
|
| 6 |
+
format is built from scratch.
|
| 7 |
+
|
| 8 |
+
COCO format reference:
|
| 9 |
+
https://cocodataset.org/#format-data
|
| 10 |
+
|
| 11 |
+
Output structure:
|
| 12 |
+
{
|
| 13 |
+
"info": {...},
|
| 14 |
+
"licenses": [],
|
| 15 |
+
"categories": [{"id": 1, "name": "cup", "supercategory": "object"}, ...],
|
| 16 |
+
"images": [{"id": 1, "file_name": "img.jpg", "width": W, "height": H}, ...],
|
| 17 |
+
"annotations": [
|
| 18 |
+
{
|
| 19 |
+
"id": 1,
|
| 20 |
+
"image_id": 1,
|
| 21 |
+
"category_id": 2,
|
| 22 |
+
"bbox": [x, y, w, h], # COCO uses [x_min, y_min, width, height]
|
| 23 |
+
"area": w * h,
|
| 24 |
+
"iscrowd": 0
|
| 25 |
+
},
|
| 26 |
+
...
|
| 27 |
+
]
|
| 28 |
+
}
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
from __future__ import annotations
|
| 32 |
+
|
| 33 |
+
import logging
|
| 34 |
+
from datetime import datetime, timezone
|
| 35 |
+
from pathlib import Path
|
| 36 |
+
from typing import Optional
|
| 37 |
+
|
| 38 |
+
from autolabel.config import settings as default_settings, Settings
|
| 39 |
+
from autolabel.utils import load_json, save_json
|
| 40 |
+
|
| 41 |
+
logger = logging.getLogger(__name__)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _xyxy_to_xywh(box: list[float]) -> list[float]:
|
| 45 |
+
"""Convert [x1, y1, x2, y2] → [x, y, width, height] (COCO format)."""
|
| 46 |
+
x1, y1, x2, y2 = box
|
| 47 |
+
return [x1, y1, x2 - x1, y2 - y1]
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def build_coco(labeled_dir: Path) -> dict:
    """
    Read all JSON files in *labeled_dir* and build a COCO-format dict.

    Category ids are assigned 1-based in first-encounter order; image ids
    follow sorted filename order. Returns {} when there is nothing to export.
    """
    json_files = sorted(labeled_dir.glob("*.json"))
    # Exclude any existing coco_export.json to avoid self-inclusion
    json_files = [f for f in json_files if f.name != "coco_export.json"]

    if not json_files:
        logger.warning("No labeled JSON files found in %s", labeled_dir)
        return {}

    logger.info("Building COCO export from %d file(s)…", len(json_files))

    # Collect all category names in encounter order, deduplicating
    category_index: dict[str, int] = {}  # name → category_id
    images_list: list[dict] = []
    annotations_list: list[dict] = []

    ann_id = 1  # COCO annotation ids are global across the dataset

    for img_id, json_path in enumerate(json_files, start=1):
        data = load_json(json_path)

        image_path = Path(data["image_path"])
        images_list.append(
            {
                "id": img_id,
                "file_name": image_path.name,
                "width": data["image_width"],
                "height": data["image_height"],
            }
        )

        for det in data.get("detections", []):
            label: str = det["label"]
            if label not in category_index:
                category_index[label] = len(category_index) + 1

            cat_id = category_index[label]
            xywh = _xyxy_to_xywh(det["box_xyxy"])
            area = round(xywh[2] * xywh[3], 2)

            annotations_list.append(
                {
                    "id": ann_id,
                    "image_id": img_id,
                    "category_id": cat_id,
                    "bbox": [round(v, 1) for v in xywh],
                    "area": area,
                    "iscrowd": 0,
                    # Carry segmentation through if the labeler produced one
                    "segmentation": det.get("segmentation", []),
                }
            )
            ann_id += 1

    categories = [
        {"id": cat_id, "name": name, "supercategory": "object"}
        for name, cat_id in sorted(category_index.items(), key=lambda x: x[1])
    ]

    # Fix: take a single timestamp so "year" and "date_created" can never
    # disagree (the original called datetime.now() twice, racing a year
    # boundary).
    now = datetime.now(tz=timezone.utc)
    coco = {
        "info": {
            "description": "autolabel — OWLv2 household object dataset",
            "version": "1.0",
            "year": now.year,
            "date_created": now.isoformat(),
        },
        "licenses": [],
        "categories": categories,
        "images": images_list,
        "annotations": annotations_list,
    }

    logger.info(
        "COCO export: %d image(s), %d annotation(s), %d categor(ies)",
        len(images_list),
        len(annotations_list),
        len(categories),
    )
    return coco
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def run_export(
    labeled_dir: Path,
    output_path: Path,
    cfg: Optional[Settings] = None,
) -> None:
    """
    Build COCO JSON from *labeled_dir* and write to *output_path*.

    Args:
        labeled_dir: Directory containing per-image labeled JSON files.
        output_path: Destination path for the COCO JSON file.
        cfg: Settings instance (module default if None).
    """
    _ = cfg or default_settings  # reserved for future use

    payload = build_coco(labeled_dir)
    if not payload:
        logger.error("Nothing to export.")
        return

    save_json(payload, output_path)
    logger.info("COCO JSON written → %s", output_path)
|
autolabel/finetune.py
ADDED
|
@@ -0,0 +1,554 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
finetune.py — OWLv2 fine-tuning core: dataset, loss, and training loop.
|
| 3 |
+
|
| 4 |
+
Architecture notes
|
| 5 |
+
------------------
|
| 6 |
+
In transformers v5.x, Owlv2ForObjectDetection.forward() does NOT compute loss
|
| 7 |
+
internally — the loss/loss_dict output fields are always None. We compute it
|
| 8 |
+
ourselves using the DETR-style utilities already shipped in transformers:
|
| 9 |
+
|
| 10 |
+
• HungarianMatcher — optimal bipartite assignment (scipy lsa under the hood)
|
| 11 |
+
• generalized_box_iou / center_to_corners_format — from the same module
|
| 12 |
+
|
| 13 |
+
Loss used (OWLv2 uses sigmoid, not softmax, so binary CE fits better):
|
| 14 |
+
total = λ_cls * L_bce + λ_bbox * L_l1 + λ_giou * L_giou
|
| 15 |
+
|
| 16 |
+
Freezing strategy (default: train detection heads only)
|
| 17 |
+
Frozen : owlv2.vision_model, owlv2.text_model, owlv2.text_projection,
|
| 18 |
+
owlv2.visual_projection
|
| 19 |
+
Trained : box_head, class_head, objectness_head, layer_norm, owlv2.logit_scale
|
| 20 |
+
Optional: --unfreeze-vision also trains the ViT image encoder
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
from __future__ import annotations
|
| 24 |
+
|
| 25 |
+
import json
|
| 26 |
+
import logging
|
| 27 |
+
import time
|
| 28 |
+
from dataclasses import dataclass, field
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
from typing import Callable, Optional
|
| 31 |
+
|
| 32 |
+
import torch
|
| 33 |
+
import torch.nn.functional as F
|
| 34 |
+
from PIL import Image
|
| 35 |
+
from torch.utils.data import DataLoader, Dataset, random_split
|
| 36 |
+
from tqdm import tqdm
|
| 37 |
+
from transformers import Owlv2ForObjectDetection, Owlv2Processor
|
| 38 |
+
from transformers.loss.loss_for_object_detection import (
|
| 39 |
+
HungarianMatcher,
|
| 40 |
+
center_to_corners_format,
|
| 41 |
+
generalized_box_iou,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
logger = logging.getLogger(__name__)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class TrainingStoppedError(Exception):
    """Raised by a progress_callback to cancel training mid-epoch.

    NOTE(review): the handler lives in run_finetune's caller — confirm the
    UI catches this to surface a clean "stopped" state rather than a crash.
    """
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# ---------------------------------------------------------------------------
|
| 52 |
+
# Training hyperparameters (all overridable via the CLI)
|
| 53 |
+
# ---------------------------------------------------------------------------
|
| 54 |
+
|
| 55 |
+
@dataclass
class FinetuneConfig:
    """All knobs for one OWLv2 fine-tuning run (overridable via the CLI)."""

    # Required inputs/outputs
    coco_json: Path    # COCO-format annotation file
    image_dir: Path    # directory holding the images named in the JSON
    output_dir: Path   # checkpoints and label_map.json are written here

    # Model / device
    model_name: str = "google/owlv2-large-patch14-finetuned"
    device: str = "cpu"
    torch_dtype: torch.dtype = torch.float32

    # Optimisation
    epochs: int = 10
    batch_size: int = 1
    grad_accum_steps: int = 4  # effective batch = batch_size * grad_accum_steps
    lr: float = 1e-4  # for detection heads
    backbone_lr: float = 0.0  # for vision encoder (0 = frozen)
    weight_decay: float = 1e-4
    grad_clip: float = 0.1    # max gradient norm per optimizer step
    warmup_steps: int = 50    # linear LR warmup length

    val_split: float = 0.2    # fraction of the dataset held out for validation
    save_every: int = 1  # save checkpoint every N epochs

    # Loss weights
    lambda_cls: float = 1.0
    lambda_bbox: float = 5.0
    lambda_giou: float = 2.0

    # Hungarian matcher costs (separate from loss weights)
    class_cost: float = 1.0
    bbox_cost: float = 5.0
    giou_cost: float = 2.0

    resume_from: Optional[Path] = None  # checkpoint dir to resume from (else model_name)
    unfreeze_vision: bool = False       # also train the ViT image encoder
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# ---------------------------------------------------------------------------
|
| 92 |
+
# COCO dataset
|
| 93 |
+
# ---------------------------------------------------------------------------
|
| 94 |
+
|
| 95 |
+
def _coco_xywh_to_cxcywh_norm(
|
| 96 |
+
bbox: list[float], img_w: int, img_h: int
|
| 97 |
+
) -> list[float]:
|
| 98 |
+
"""COCO [x, y, w, h] pixel → normalised [cx, cy, w, h] in [0, 1]."""
|
| 99 |
+
x, y, w, h = bbox
|
| 100 |
+
return [
|
| 101 |
+
(x + w / 2) / img_w,
|
| 102 |
+
(y + h / 2) / img_h,
|
| 103 |
+
w / img_w,
|
| 104 |
+
h / img_h,
|
| 105 |
+
]
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class CocoOwlv2Dataset(Dataset):
    """
    Loads a COCO-format JSON and serves (image, boxes, class_labels) items.

    Args:
        coco_json_path: Path to the COCO JSON file.
        image_dir: Directory where image files live (matched by file_name).
        categories: List of category name strings (defines the query order).
            If None, derived from the JSON (sorted by category id).

    Raises:
        KeyError: if *categories* is supplied but omits a category name that
            appears in the JSON (the original silently mislabeled instead).
    """

    def __init__(
        self,
        coco_json_path: Path,
        image_dir: Path,
        categories: Optional[list[str]] = None,
    ) -> None:
        with coco_json_path.open() as fh:
            coco = json.load(fh)

        # Build category id → 0-based index into the text-query list.
        sorted_cats = sorted(coco["categories"], key=lambda c: c["id"])
        if categories is None:
            categories = [c["name"] for c in sorted_cats]
        self.categories = categories
        # Fix: map ids through *names* into self.categories so labels stay
        # correct even when a caller supplies `categories` in a different
        # order than the JSON's id order (the original assumed id order).
        name_to_idx = {name: i for i, name in enumerate(self.categories)}
        cat_id_to_idx = {c["id"]: name_to_idx[c["name"]] for c in sorted_cats}

        # Index annotations by image_id
        ann_by_image: dict[int, list[dict]] = {}
        for ann in coco["annotations"]:
            ann_by_image.setdefault(ann["image_id"], []).append(ann)

        # Build valid items (images that have at least one annotation)
        self.items: list[dict] = []
        for img_meta in coco["images"]:
            anns = ann_by_image.get(img_meta["id"], [])
            if not anns:
                continue

            img_path = image_dir / img_meta["file_name"]
            if not img_path.exists():
                logger.warning("Image not found, skipping: %s", img_path)
                continue

            boxes_norm = []
            class_labels = []
            w, h = img_meta["width"], img_meta["height"]
            for ann in anns:
                boxes_norm.append(_coco_xywh_to_cxcywh_norm(ann["bbox"], w, h))
                class_labels.append(cat_id_to_idx[ann["category_id"]])

            self.items.append(
                {
                    "image_path": img_path,
                    "boxes": boxes_norm,          # list of [cx, cy, w, h] normalised
                    "class_labels": class_labels, # list of int indices
                }
            )

        logger.info(
            "Dataset: %d images with annotations | %d categories",
            len(self.items),
            len(self.categories),
        )

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> dict:
        item = self.items[idx]
        image = Image.open(item["image_path"]).convert("RGB")
        return {
            "image": image,
            "boxes": torch.tensor(item["boxes"], dtype=torch.float32),
            "class_labels": torch.tensor(item["class_labels"], dtype=torch.long),
        }
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# ---------------------------------------------------------------------------
|
| 188 |
+
# Collate function
|
| 189 |
+
# ---------------------------------------------------------------------------
|
| 190 |
+
|
| 191 |
+
def make_collate_fn(processor: Owlv2Processor, categories: list[str], device: str, dtype: torch.dtype):
    """Returns a collate_fn that encodes images + text queries into model inputs."""

    def collate_fn(batch: list[dict]) -> dict:
        pil_images = [sample["image"] for sample in batch]

        # Every image in the batch is queried with the full category list.
        encoded = processor(
            text=[categories] * len(pil_images),
            images=pil_images,
            return_tensors="pt",
        )

        # Floating tensors get the training dtype; integer tensors keep theirs.
        encoded = {
            name: (
                tensor.to(device=device, dtype=dtype)
                if tensor.is_floating_point()
                else tensor.to(device=device)
            )
            for name, tensor in encoded.items()
        }

        # Targets ride along under "labels"; the training loop pops them off
        # before calling the model.
        encoded["labels"] = [
            {
                "boxes": sample["boxes"].to(device),
                "class_labels": sample["class_labels"].to(device),
            }
            for sample in batch
        ]
        return encoded

    return collate_fn
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
# ---------------------------------------------------------------------------
|
| 225 |
+
# Loss
|
| 226 |
+
# ---------------------------------------------------------------------------
|
| 227 |
+
|
| 228 |
+
def compute_detection_loss(
    logits: torch.Tensor,  # [B, N_patches, N_classes]
    pred_boxes: torch.Tensor,  # [B, N_patches, 4] normalised CxCyWH
    targets: list[dict],  # [{"boxes": [M,4], "class_labels": [M]}]
    matcher: HungarianMatcher,
    lambda_cls: float,
    lambda_bbox: float,
    lambda_giou: float,
) -> tuple[torch.Tensor, dict[str, float]]:
    """
    Compute combined detection loss using Hungarian matching.

    Classification uses sigmoid BCE (OWLv2 uses sigmoid, not softmax).
    Box regression uses L1 + GIoU on matched pairs only.

    Returns:
        (total, log): the differentiable scalar loss and a dict of detached
        float components ("loss", "cls", "bbox", "giou") for logging.
    """
    B, N, C = logits.shape

    # --- Hungarian matching ---
    # The assignment is discrete, so gradients must not flow through it.
    with torch.no_grad():
        indices = matcher({"logits": logits, "pred_boxes": pred_boxes}, targets)

    # --- Classification loss (sigmoid binary CE) ---
    # Target tensor: shape [B, N, C], all zeros (background),
    # set 1.0 at matched (prediction, class) positions.
    target_cls = torch.zeros(B, N, C, device=logits.device, dtype=logits.dtype)
    for b, (pred_idx, gt_idx) in enumerate(indices):
        if len(pred_idx) == 0:
            continue
        gt_labels = targets[b]["class_labels"][gt_idx]  # [M]
        target_cls[b, pred_idx, gt_labels] = 1.0

    # Mean over every (patch, class) logit — unmatched patches act as negatives.
    loss_cls = F.binary_cross_entropy_with_logits(logits, target_cls, reduction="mean")

    # --- Box losses (matched pairs only) ---
    loss_bbox = torch.tensor(0.0, device=logits.device)
    loss_giou = torch.tensor(0.0, device=logits.device)
    num_matched = sum(len(p) for p, _ in indices)

    if num_matched > 0:
        for b, (pred_idx, gt_idx) in enumerate(indices):
            if len(pred_idx) == 0:
                continue
            p_boxes = pred_boxes[b][pred_idx]  # [M, 4] CxCyWH norm
            g_boxes = targets[b]["boxes"][gt_idx]  # [M, 4] CxCyWH norm

            # Sum now, normalise by num_matched once after the loop.
            loss_bbox = loss_bbox + F.l1_loss(p_boxes, g_boxes, reduction="sum")

            p_xyxy = center_to_corners_format(p_boxes)
            g_xyxy = center_to_corners_format(g_boxes)

            # Clamp to [0, 1] to avoid degenerate boxes
            p_xyxy = p_xyxy.clamp(0, 1)
            g_xyxy = g_xyxy.clamp(0, 1)

            # generalized_box_iou returns the full pairwise [M, M] matrix;
            # only the diagonal corresponds to matched (pred, gt) pairs.
            giou_mat = generalized_box_iou(p_xyxy, g_xyxy)  # [M, M]
            loss_giou = loss_giou + (1 - torch.diag(giou_mat)).sum()

        loss_bbox = loss_bbox / num_matched
        loss_giou = loss_giou / num_matched

    total = lambda_cls * loss_cls + lambda_bbox * loss_bbox + lambda_giou * loss_giou

    log = {
        "loss": total.item(),
        "cls": loss_cls.item(),
        "bbox": loss_bbox.item(),
        "giou": loss_giou.item(),
    }
    return total, log
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
# ---------------------------------------------------------------------------
|
| 300 |
+
# Freeze / unfreeze helpers
|
| 301 |
+
# ---------------------------------------------------------------------------
|
| 302 |
+
|
| 303 |
+
# Parameter-name prefixes that make up the frozen CLIP backbone
# (see apply_freeze_strategy).
BACKBONE_PREFIXES = (
    "owlv2.vision_model",
    "owlv2.text_model",
    "owlv2.text_projection",
    "owlv2.visual_projection",
)
# Subset of the backbone that --unfreeze-vision re-enables for training.
VISION_PREFIXES = ("owlv2.vision_model", "owlv2.visual_projection")
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def apply_freeze_strategy(model: Owlv2ForObjectDetection, unfreeze_vision: bool) -> None:
    """
    Freeze the CLIP backbone; only train the detection heads (+ layer_norm).
    With unfreeze_vision=True, also allow gradients through the ViT encoder.
    """
    for name, param in model.named_parameters():
        # str.startswith accepts a tuple of prefixes, so each membership
        # test is a single call.
        in_backbone = name.startswith(BACKBONE_PREFIXES)
        in_vision = name.startswith(VISION_PREFIXES)
        # Heads always train; backbone params only if vision is unfrozen
        # and the param belongs to the vision tower.
        param.requires_grad_((not in_backbone) or (unfreeze_vision and in_vision))

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    logger.info(
        "Trainable params: %s / %s (%.1f%%)",
        f"{trainable:,}", f"{total:,}", 100 * trainable / total,
    )
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
# ---------------------------------------------------------------------------
|
| 335 |
+
# LR scheduler with linear warmup
|
| 336 |
+
# ---------------------------------------------------------------------------
|
| 337 |
+
|
| 338 |
+
def build_scheduler(optimizer: torch.optim.Optimizer, warmup_steps: int, total_steps: int):
    """Linear warmup followed by cosine decay, as a LambdaLR multiplier.

    The multiplier ramps 0 → 1 over *warmup_steps*, then follows half a
    cosine from 1 → 0 over the remaining steps.

    Args:
        optimizer: Optimizer whose base LRs the multiplier scales.
        warmup_steps: Number of linear-warmup steps.
        total_steps: Total scheduler steps for the run.
    """
    import math  # local import keeps this fix self-contained

    def lr_lambda(step: int) -> float:
        if step < warmup_steps:
            return step / max(warmup_steps, 1)
        progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
        # Fix: use math.cos with exact math.pi instead of the original's
        # round-trip through a torch scalar tensor with truncated 3.14159.
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
# ---------------------------------------------------------------------------
|
| 349 |
+
# Single-epoch helpers
|
| 350 |
+
# ---------------------------------------------------------------------------
|
| 351 |
+
|
| 352 |
+
def _run_epoch(
    model: Owlv2ForObjectDetection,
    loader: DataLoader,
    matcher: HungarianMatcher,
    cfg: FinetuneConfig,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler=None,
    grad_accum: int = 1,
    desc: str = "train",
) -> dict[str, float]:
    """Run one pass over *loader*. If optimizer is None, runs in eval mode.

    Returns:
        Per-step mean of each loss component ("loss", "cls", "bbox", "giou").
    """
    training = optimizer is not None
    model.train(training)

    totals: dict[str, float] = {"loss": 0, "cls": 0, "bbox": 0, "giou": 0}
    steps = 0

    if training:
        optimizer.zero_grad()

    with tqdm(loader, desc=desc, leave=False) as pbar:
        for i, batch in enumerate(pbar):
            # collate_fn stores targets under "labels"; the remaining keys
            # are fed straight to the model.
            labels = batch.pop("labels")
            inputs = batch

            # Disable autograd only for the validation pass.
            ctx = torch.no_grad() if not training else torch.enable_grad()
            with ctx:
                outputs = model(**inputs)

                loss, log = compute_detection_loss(
                    logits=outputs.logits,
                    pred_boxes=outputs.pred_boxes,
                    targets=labels,
                    matcher=matcher,
                    lambda_cls=cfg.lambda_cls,
                    lambda_bbox=cfg.lambda_bbox,
                    lambda_giou=cfg.lambda_giou,
                )

            if training:
                # Scale so the accumulated gradient equals one larger batch.
                (loss / grad_accum).backward()
                if (i + 1) % grad_accum == 0:
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), cfg.grad_clip
                    )
                    optimizer.step()
                    if scheduler is not None:
                        scheduler.step()
                    optimizer.zero_grad()
                # NOTE(review): a trailing partial accumulation (when
                # len(loader) % grad_accum != 0) never triggers a step here;
                # its gradient is discarded by the zero_grad at the start of
                # the next epoch. Confirm this is intended.

            for k, v in log.items():
                totals[k] += v
            steps += 1

            # Show running means in the progress bar.
            pbar.set_postfix({k: f"{v/steps:.4f}" for k, v in totals.items()})

    return {k: v / max(steps, 1) for k, v in totals.items()}
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
# ---------------------------------------------------------------------------
|
| 412 |
+
# Main training loop
|
| 413 |
+
# ---------------------------------------------------------------------------
|
| 414 |
+
|
| 415 |
+
def run_finetune(
    cfg: FinetuneConfig,
    progress_callback: Optional[Callable[[int, dict, dict], None]] = None,
) -> None:
    """Full fine-tuning run from a FinetuneConfig.

    Pipeline: load processor + model → apply freeze strategy → split the COCO
    dataset into train/val → AdamW with per-group LRs and a warmup scheduler →
    epoch loop that tracks the best validation loss, saves periodic
    checkpoints, and writes a JSON training history.

    Args:
        cfg: All hyper-parameters, paths, dtype and device for the run.
        progress_callback: Optional per-epoch hook called as
            ``callback(epoch, train_log, val_log)``. Raising
            TrainingStoppedError from it ends training cleanly; any other
            exception is logged and training continues.
    """
    cfg.output_dir.mkdir(parents=True, exist_ok=True)
    t0 = time.time()

    # --- Load processor and model ---
    # The processor always comes from the base model name; the weights come
    # from a checkpoint instead when resume_from is set.
    resume_path = cfg.resume_from or cfg.model_name
    logger.info("Loading processor from %s", cfg.model_name)
    processor = Owlv2Processor.from_pretrained(cfg.model_name)

    logger.info("Loading model from %s", resume_path)
    model = Owlv2ForObjectDetection.from_pretrained(
        str(resume_path),
        torch_dtype=cfg.torch_dtype,
        ignore_mismatched_sizes=True,
    ).to(cfg.device)

    apply_freeze_strategy(model, cfg.unfreeze_vision)

    # --- Dataset ---
    full_dataset = CocoOwlv2Dataset(cfg.coco_json, cfg.image_dir)
    categories = full_dataset.categories

    # Save category metadata alongside the model so inference code can map
    # predicted indices back to label names.
    meta_path = cfg.output_dir / "label_map.json"
    with meta_path.open("w") as fh:
        json.dump({"categories": categories}, fh, indent=2)
    logger.info("Label map saved → %s", meta_path)

    # Always hold out at least one image for validation.
    n_val = max(1, int(len(full_dataset) * cfg.val_split))
    n_train = len(full_dataset) - n_val
    if n_train < 1:
        raise ValueError(
            f"Not enough labeled images for training ({len(full_dataset)} total). "
            "Use the Batch tab in the web UI and run `make export` to build a larger dataset."
        )

    # Fixed seed keeps the train/val split stable across resumed runs.
    train_ds, val_ds = random_split(
        full_dataset,
        [n_train, n_val],
        generator=torch.Generator().manual_seed(42),
    )
    logger.info("Split: %d train / %d val", n_train, n_val)

    # num_workers=0 keeps collation in the main process (the collate fn is
    # handed the target device and dtype).
    collate = make_collate_fn(processor, categories, cfg.device, cfg.torch_dtype)
    train_loader = DataLoader(
        train_ds, batch_size=cfg.batch_size, shuffle=True,
        collate_fn=collate, num_workers=0,
    )
    val_loader = DataLoader(
        val_ds, batch_size=cfg.batch_size, shuffle=False,
        collate_fn=collate, num_workers=0,
    )

    # --- Optimizer (separate LR for heads vs backbone if unfrozen) ---
    head_params = [p for n, p in model.named_parameters()
                   if p.requires_grad and not any(n.startswith(pfx) for pfx in BACKBONE_PREFIXES)]
    vision_params = [p for n, p in model.named_parameters()
                     if p.requires_grad and any(n.startswith(pfx) for pfx in VISION_PREFIXES)]

    param_groups = [{"params": head_params, "lr": cfg.lr}]
    # The vision group only exists when the backbone is unfrozen AND a
    # positive backbone LR was requested.
    if vision_params and cfg.backbone_lr > 0:
        param_groups.append({"params": vision_params, "lr": cfg.backbone_lr})

    optimizer = torch.optim.AdamW(param_groups, weight_decay=cfg.weight_decay)

    # The scheduler is stepped once per optimizer update, hence the division
    # by grad_accum_steps.
    total_steps = (len(train_loader) // cfg.grad_accum_steps) * cfg.epochs
    scheduler = build_scheduler(optimizer, cfg.warmup_steps, total_steps)

    matcher = HungarianMatcher(
        class_cost=cfg.class_cost,
        bbox_cost=cfg.bbox_cost,
        giou_cost=cfg.giou_cost,
    )

    # --- Training loop ---
    best_val_loss = float("inf")
    history: list[dict] = []

    for epoch in range(1, cfg.epochs + 1):
        logger.info("─── Epoch %d / %d ───", epoch, cfg.epochs)

        train_log = _run_epoch(
            model, train_loader, matcher, cfg,
            optimizer=optimizer, scheduler=scheduler,
            grad_accum=cfg.grad_accum_steps,
            desc=f"train {epoch}/{cfg.epochs}",
        )
        # No optimizer → _run_epoch runs in eval mode for validation.
        val_log = _run_epoch(
            model, val_loader, matcher, cfg,
            desc=f"val {epoch}/{cfg.epochs}",
        )

        logger.info(
            "Epoch %d train_loss=%.4f val_loss=%.4f "
            "(cls=%.4f bbox=%.4f giou=%.4f)",
            epoch,
            train_log["loss"], val_log["loss"],
            val_log["cls"], val_log["bbox"], val_log["giou"],
        )

        history.append({"epoch": epoch, "train": train_log, "val": val_log})

        if progress_callback is not None:
            try:
                progress_callback(epoch, train_log, val_log)
            except TrainingStoppedError:
                # NOTE: breaking here skips this epoch's best/checkpoint
                # saves; the last "best" written to disk is what remains.
                logger.info("Training stopped by user at epoch %d.", epoch)
                break
            except Exception as exc:
                # A flaky UI callback must not kill the training run.
                logger.warning("progress_callback raised %s; continuing.", exc)

        # Save best
        if val_log["loss"] < best_val_loss:
            best_val_loss = val_log["loss"]
            best_path = cfg.output_dir / "best"
            model.save_pretrained(str(best_path))
            processor.save_pretrained(str(best_path))
            logger.info(" ✓ New best val_loss=%.4f saved → %s", best_val_loss, best_path)

        # Periodic checkpoint
        if epoch % cfg.save_every == 0:
            ckpt_path = cfg.output_dir / f"checkpoint-epoch-{epoch:03d}"
            model.save_pretrained(str(ckpt_path))
            processor.save_pretrained(str(ckpt_path))
            logger.info(" Checkpoint saved → %s", ckpt_path)

    # Save training history
    history_path = cfg.output_dir / "training_history.json"
    with history_path.open("w") as fh:
        json.dump(history, fh, indent=2)

    elapsed = time.time() - t0
    logger.info(
        "Fine-tuning complete in %.1f min. Best val_loss=%.4f. Model at %s/best",
        elapsed / 60, best_val_loss, cfg.output_dir,
    )
|
autolabel/segment.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
segment.py — SAM2 segmentation using bounding-box prompts.
|
| 3 |
+
|
| 4 |
+
Workflow (Grounded SAM2 pattern):
|
| 5 |
+
OWLv2 text prompts → bounding boxes
|
| 6 |
+
SAM2 box prompts → pixel masks
|
| 7 |
+
|
| 8 |
+
Model: facebook/sam2-hiera-tiny (~160 MB, fast enough for development)
|
| 9 |
+
|
| 10 |
+
Each detection returned by segment_with_boxes() gains two extra fields:
|
| 11 |
+
"mask": bool numpy array (H, W) — pixel mask in image space
|
| 12 |
+
"segmentation": COCO polygon list [[x, y, x, y, ...], ...]
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import logging
|
| 18 |
+
from typing import Optional
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
from PIL import Image
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
SAM2_DEFAULT_MODEL = "facebook/sam2-hiera-tiny"
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def load_sam2(device: str, model_id: str = SAM2_DEFAULT_MODEL):
    """Load a SAM2 processor/model pair onto *device*.

    Returns:
        (processor, model) — the model is in eval mode, float32.
    """
    # Imported lazily so the module can be loaded without transformers' SAM2
    # classes being available.
    from transformers import Sam2Model, Sam2Processor

    logger.info("Loading SAM2 %s on %s …", model_id, device)
    sam_processor = Sam2Processor.from_pretrained(model_id)
    # SAM2 runs in float32 — bfloat16/float16 not reliably supported on all backends
    sam_model = Sam2Model.from_pretrained(model_id, torch_dtype=torch.float32)
    sam_model = sam_model.to(device)
    sam_model.eval()
    logger.info("SAM2 ready.")
    return sam_processor, sam_model
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _mask_to_polygon(mask: np.ndarray) -> list[list[float]]:
    """Convert a boolean 2-D mask to a COCO polygon list.

    Returns a list of polygons; each polygon is a flat [x1,y1,x2,y2,…] list.
    Returns [] if cv2 is unavailable or no contour is found.
    """
    try:
        import cv2
    except ImportError:
        logger.warning("opencv-python not installed — segmentation polygons skipped.")
        return []

    binary = mask.astype(np.uint8) * 255
    found, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # A valid COCO polygon needs >= 3 points (size counts x and y entries).
    return [c.flatten().tolist() for c in found if c.size >= 6]
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def segment_with_boxes(
    pil_image: Image.Image,
    detections: list[dict],
    processor,
    model,
    device: str,
) -> list[dict]:
    """Run SAM2 on *pil_image* using the bounding box from each detection.

    Each detection in the returned list gains:
        "mask"         — bool numpy array (H, W)
        "segmentation" — COCO polygon list

    Detections without a valid box are passed through unchanged (no mask field).
    A SAM2 failure for an individual box is logged and replaced with an
    all-False mask rather than aborting the whole batch.
    """
    if not detections:
        return detections

    augmented: list[dict] = []
    h, w = pil_image.height, pil_image.width

    for det in detections:
        # Boxes are expected in absolute pixel xyxy coordinates.
        box = det.get("box_xyxy")
        if box is None:
            augmented.append(det)
            continue

        x1, y1, x2, y2 = box
        try:
            # input_boxes: [batch=1, n_boxes=1, 4]
            encoding = processor(
                images=pil_image,
                input_boxes=[[[x1, y1, x2, y2]]],
                return_tensors="pt",
            )
            # transformers 5.x Sam2Processor returns: pixel_values, original_sizes,
            # input_boxes — no reshaped_input_sizes. Move all tensors to device.
            inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in encoding.items()}

            with torch.no_grad():
                outputs = model(**inputs, multimask_output=False)

            # pred_masks shape: [batch, n_boxes, n_masks, H_low, W_low]
            # post_process_masks(masks, original_sizes) — transformers 5.x API:
            # iterates over batch; each masks[i] goes through F.interpolate to
            # original_size, then optional binarise. Expects 4-D per-image tensor.
            # We pass pred_masks directly; masks[0] = [n_boxes, n_masks, H_low, W_low]
            # which F.interpolate handles as [N, C, H, W].
            # Fallback to the PIL image size if the processor did not report one.
            original_sizes = encoding.get("original_sizes", torch.tensor([[h, w]]))
            masks = processor.post_process_masks(
                outputs.pred_masks,
                original_sizes,
            )
            # masks[0]: [n_boxes=1, n_masks=1, H_orig, W_orig]
            mask_np: np.ndarray = masks[0][0, 0].cpu().numpy().astype(bool)
        except Exception:
            # Best-effort per detection: keep going with an empty mask.
            logger.exception(
                "SAM2 failed for '%s' — using empty mask", det.get("label", "?")
            )
            mask_np = np.zeros((h, w), dtype=bool)

        polygons = _mask_to_polygon(mask_np)
        # Copy the detection so callers' original dicts are not mutated.
        augmented.append({**det, "mask": mask_np, "segmentation": polygons})

    return augmented
|
autolabel/utils.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
utils.py — shared helpers for the autolabel pipeline.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
import logging
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Any
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif", ".webp"}
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def collect_images(directory: Path) -> list[Path]:
    """Return sorted list of image files under *directory*."""
    found = [
        candidate
        for candidate in directory.rglob("*")
        if candidate.suffix.lower() in IMAGE_EXTENSIONS
    ]
    found.sort()
    logger.info("Found %d image(s) in %s", len(found), directory)
    return found
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def load_json(path: Path) -> Any:
    """Load and return JSON from *path*."""
    text = path.read_text(encoding="utf-8")
    return json.loads(text)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def save_json(data: Any, path: Path, indent: int = 2) -> None:
    """Serialise *data* to JSON at *path*, creating parent dirs as needed."""
    path.parent.mkdir(parents=True, exist_ok=True)
    payload = json.dumps(data, indent=indent, ensure_ascii=False)
    path.write_text(payload, encoding="utf-8")
    logger.debug("Saved JSON → %s", path)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def detection_json_path(image_path: Path, detections_dir: Path) -> Path:
    """Return the expected detection JSON path for a given image."""
    json_name = f"{image_path.stem}.json"
    return detections_dir / json_name
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def setup_logging(level: int = logging.INFO) -> None:
    """Configure root logger with a sensible format."""
    log_format = "%(asctime)s [%(levelname)s] %(name)s — %(message)s"
    logging.basicConfig(level=level, format=log_format, datefmt="%H:%M:%S")
|
pyproject.toml
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "labelplayground"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "Auto-labeling pipeline using OWLv2 + SAM2 for household object detection"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = "==3.11.*"
|
| 7 |
+
license = { text = "MIT" }
|
| 8 |
+
authors = [{ name = "Erick Rosas" }]
|
| 9 |
+
|
| 10 |
+
dependencies = [
|
| 11 |
+
# Deep learning — device-agnostic (CUDA / MPS / CPU)
|
| 12 |
+
"torch>=2.2.0",
|
| 13 |
+
"torchvision>=0.17.0",
|
| 14 |
+
|
| 15 |
+
# Hugging Face — OWLv2 + SAM2 models & processors (Apache 2.0)
|
| 16 |
+
"transformers>=4.45.0", # SAM2 support added in 4.45
|
| 17 |
+
|
| 18 |
+
# Computer vision
|
| 19 |
+
"pillow>=10.3.0",
|
| 20 |
+
"opencv-python>=4.9.0", # mask → COCO polygon via cv2.findContours (SAM2)
|
| 21 |
+
|
| 22 |
+
# Data & utilities
|
| 23 |
+
"numpy>=1.26.0",
|
| 24 |
+
"pydantic>=2.7.0",
|
| 25 |
+
"pydantic-settings>=2.3.0",
|
| 26 |
+
|
| 27 |
+
# CLI
|
| 28 |
+
"click>=8.1.7",
|
| 29 |
+
"tqdm>=4.66.0",
|
| 30 |
+
|
| 31 |
+
# Environment
|
| 32 |
+
"python-dotenv>=1.0.1",
|
| 33 |
+
|
| 34 |
+
# Web UI
|
| 35 |
+
"gradio>=6.0.0",
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
[project.scripts]
|
| 39 |
+
autolabel-detect = "scripts.run_detection:main"
|
| 40 |
+
autolabel-export = "scripts.export_coco:main"
|
| 41 |
+
|
| 42 |
+
[build-system]
|
| 43 |
+
requires = ["hatchling"]
|
| 44 |
+
build-backend = "hatchling.build"
|
| 45 |
+
|
| 46 |
+
[tool.hatch.build.targets.wheel]
|
| 47 |
+
packages = ["autolabel", "scripts"]
|
samples/CREDITS.txt
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Sample images used in the autolabel demo
|
| 2 |
+
=========================================
|
| 3 |
+
|
| 4 |
+
kitchen.jpg
|
| 5 |
+
Title: "Good Food Display - NCI Visuals Online"
|
| 6 |
+
Source: https://commons.wikimedia.org/wiki/File:Good_Food_Display_-_NCI_Visuals_Online.jpg
|
| 7 |
+
License: Public domain (National Cancer Institute / US Government)
|
| 8 |
+
|
| 9 |
+
dog.jpg
|
| 10 |
+
Title: "Yellow Labrador Looking"
|
| 11 |
+
Source: https://commons.wikimedia.org/wiki/File:YellowLabradorLooking_new.jpg
|
| 12 |
+
License: CC BY-SA 3.0 — Jenn Durfey
|
| 13 |
+
|
| 14 |
+
cat.jpg
|
| 15 |
+
Title: "Cat November 2010-1a"
|
| 16 |
+
Source: https://commons.wikimedia.org/wiki/File:Cat_November_2010-1a.jpg
|
| 17 |
+
License: CC BY-SA 3.0 — Alvesgaspar
|
samples/animals.jpg
ADDED
|
Git LFS Details
|
samples/cat.jpg
ADDED
|
Git LFS Details
|
samples/dog.jpg
ADDED
|
samples/kitchen.jpg
ADDED
|
Git LFS Details
|
scripts/export_coco.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
export_coco.py — CLI entrypoint for the COCO JSON export stage.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
uv run python scripts/export_coco.py
|
| 6 |
+
uv run python scripts/export_coco.py --labeled-dir data/labeled --output data/labeled/coco_export.json
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
import click
|
| 15 |
+
from dotenv import load_dotenv
|
| 16 |
+
|
| 17 |
+
load_dotenv()
|
| 18 |
+
|
| 19 |
+
from autolabel.export import run_export
|
| 20 |
+
from autolabel.config import settings
|
| 21 |
+
from autolabel.utils import setup_logging
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@click.command()
@click.option(
    "--labeled-dir",
    # Default captured from settings at import time (after load_dotenv above).
    default=str(settings.labeled_dir),
    show_default=True,
    type=click.Path(exists=True, file_okay=False, path_type=Path),
    help="Directory containing accepted-annotation JSON files.",
)
@click.option(
    "--output",
    default=str(settings.labeled_dir / "coco_export.json"),
    show_default=True,
    type=click.Path(path_type=Path),
    help="Output path for the COCO JSON file.",
)
@click.option("--verbose", "-v", is_flag=True, default=False, help="Debug logging.")
def main(labeled_dir: Path, output: Path, verbose: bool) -> None:
    """Export accepted annotations from LABELED_DIR to COCO JSON format.

    Thin CLI wrapper: configures logging, then delegates all work to
    autolabel.export.run_export.
    """
    setup_logging(logging.DEBUG if verbose else logging.INFO)
    run_export(labeled_dir=labeled_dir, output_path=output, cfg=settings)


if __name__ == "__main__":
    main()
|
scripts/finetune_owlv2.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
finetune_owlv2.py — CLI for fine-tuning OWLv2 on a COCO-format dataset.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
uv run python scripts/finetune_owlv2.py
|
| 6 |
+
uv run python scripts/finetune_owlv2.py --epochs 20 --lr 5e-5
|
| 7 |
+
uv run python scripts/finetune_owlv2.py --unfreeze-vision --backbone-lr 1e-5
|
| 8 |
+
uv run python scripts/finetune_owlv2.py --resume models/owlv2-finetuned/checkpoint-epoch-005
|
| 9 |
+
|
| 10 |
+
Recommended hardware:
|
| 11 |
+
CUDA (Windows/Linux) — use fp16 for speed, set --device cuda
|
| 12 |
+
MPS (Apple Silicon) — fp32 only, slower but functional for small datasets
|
| 13 |
+
CPU — very slow, only for tiny sanity-check runs
|
| 14 |
+
|
| 15 |
+
Typical first run:
|
| 16 |
+
1. make export # build data/labeled/coco_export.json
|
| 17 |
+
2. make finetune # train with defaults
|
| 18 |
+
3. Update app.py to load from models/owlv2-finetuned/best
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
import logging
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
|
| 26 |
+
import click
|
| 27 |
+
import torch
|
| 28 |
+
from dotenv import load_dotenv
|
| 29 |
+
|
| 30 |
+
load_dotenv()
|
| 31 |
+
|
| 32 |
+
from autolabel.config import settings
|
| 33 |
+
from autolabel.finetune import FinetuneConfig, run_finetune
|
| 34 |
+
from autolabel.utils import setup_logging
|
| 35 |
+
|
| 36 |
+
# Resolve paths relative to the repository root (scripts/ → project root) so
# the CLI works regardless of the current working directory.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
DEFAULT_OUTPUT = PROJECT_ROOT / "models" / "owlv2-finetuned"
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@click.command()
@click.option(
    "--coco-json",
    default=str(settings.labeled_dir / "coco_export.json"),
    show_default=True,
    type=click.Path(exists=True, path_type=Path),
    help="COCO JSON file produced by `make export`.",
)
@click.option(
    "--image-dir",
    default=str(settings.raw_dir),
    show_default=True,
    type=click.Path(exists=True, file_okay=False, path_type=Path),
    help="Directory containing the source images (matched by file_name in COCO JSON).",
)
@click.option(
    "--output-dir",
    default=str(DEFAULT_OUTPUT),
    show_default=True,
    type=click.Path(file_okay=False, path_type=Path),
    help="Directory to save checkpoints and the best model.",
)
@click.option("--model", default=settings.model, show_default=True,
              help="Base model to fine-tune.")
@click.option("--epochs", default=10, show_default=True, type=int)
@click.option("--batch-size", default=1, show_default=True, type=int,
              help="Images per forward pass. Keep at 1 for OWLv2-large on ≤8 GB VRAM.")
@click.option("--grad-accum", default=4, show_default=True, type=int,
              help="Gradient accumulation steps. Effective batch = batch_size * grad_accum.")
@click.option("--lr", default=1e-4, show_default=True, type=float,
              help="Learning rate for detection heads.")
@click.option("--val-split", default=0.2, show_default=True, type=float,
              help="Fraction of data to use for validation.")
@click.option("--warmup-steps", default=50, show_default=True, type=int)
@click.option("--save-every", default=1, show_default=True, type=int,
              help="Save a checkpoint every N epochs.")
@click.option(
    "--unfreeze-vision", is_flag=True, default=False,
    help="Also fine-tune the ViT image encoder (needs more VRAM, slower).",
)
@click.option(
    "--backbone-lr", default=1e-5, show_default=True, type=float,
    help="LR for the vision encoder when --unfreeze-vision is set.",
)
@click.option(
    "--resume",
    default=None,
    type=click.Path(path_type=Path),
    help="Path to a saved checkpoint to resume from.",
)
@click.option(
    "--device",
    default=settings.device,
    show_default=True,
    help="Torch device: cuda | mps | cpu.",
)
@click.option("--verbose", "-v", is_flag=True, default=False)
def main(
    coco_json: Path,
    image_dir: Path,
    output_dir: Path,
    model: str,
    epochs: int,
    batch_size: int,
    grad_accum: int,
    lr: float,
    val_split: float,
    warmup_steps: int,
    save_every: int,
    unfreeze_vision: bool,
    backbone_lr: float,
    resume: Path | None,
    device: str,
    verbose: bool,
) -> None:
    """Fine-tune OWLv2 on your labeled COCO dataset.

    Thin CLI wrapper: maps the click options onto a FinetuneConfig, prints a
    short summary, and delegates to autolabel.finetune.run_finetune.
    """
    setup_logging(logging.DEBUG if verbose else logging.INFO)

    # fp16 only on CUDA; other backends train in fp32 (see module docstring).
    dtype = torch.float16 if device == "cuda" else torch.float32

    cfg = FinetuneConfig(
        coco_json=coco_json,
        image_dir=image_dir,
        output_dir=output_dir,
        model_name=model,
        device=device,
        torch_dtype=dtype,
        epochs=epochs,
        batch_size=batch_size,
        grad_accum_steps=grad_accum,
        lr=lr,
        # A zero backbone LR keeps the vision encoder out of the optimizer
        # unless --unfreeze-vision was requested.
        backbone_lr=backbone_lr if unfreeze_vision else 0.0,
        val_split=val_split,
        warmup_steps=warmup_steps,
        save_every=save_every,
        unfreeze_vision=unfreeze_vision,
        resume_from=resume,
    )

    click.echo(f"Fine-tuning OWLv2 on {coco_json}")
    click.echo(f"  device       : {device} ({dtype})")
    click.echo(f"  epochs       : {epochs}")
    click.echo(f"  effective bs : {batch_size * grad_accum}")
    click.echo(f"  heads lr     : {lr}")
    click.echo(f"  unfreeze ViT : {unfreeze_vision}")
    click.echo(f"  output       : {output_dir}")
    click.echo()

    run_finetune(cfg)


if __name__ == "__main__":
    main()
|
scripts/run_detection.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
run_detection.py — CLI entrypoint for the OWLv2 detection stage.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
uv run python scripts/run_detection.py --image-dir data/raw --output-dir data/detections
|
| 6 |
+
uv run python scripts/run_detection.py --help
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import Optional
|
| 14 |
+
|
| 15 |
+
import click
|
| 16 |
+
from dotenv import load_dotenv
|
| 17 |
+
|
| 18 |
+
load_dotenv() # picks up PYTORCH_ENABLE_MPS_FALLBACK and other vars
|
| 19 |
+
|
| 20 |
+
from autolabel.detect import run_detection
|
| 21 |
+
from autolabel.config import settings
|
| 22 |
+
from autolabel.utils import setup_logging
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@click.command()
@click.option(
    "--image-dir",
    # Defaults captured from settings at import time (after load_dotenv above).
    default=str(settings.raw_dir),
    show_default=True,
    type=click.Path(exists=True, file_okay=False, path_type=Path),
    help="Directory containing input images.",
)
@click.option(
    "--output-dir",
    default=str(settings.detections_dir),
    show_default=True,
    type=click.Path(file_okay=False, path_type=Path),
    help="Directory to write per-image detection JSON files.",
)
@click.option(
    "--prompts",
    default=None,
    help="Comma-separated list of text prompts (overrides config defaults).",
)
@click.option(
    "--threshold",
    default=None,
    type=float,
    help="Score threshold override (0.0–1.0).",
)
@click.option(
    "--force",
    is_flag=True,
    default=False,
    help="Re-process images even if a detection JSON already exists.",
)
@click.option("--verbose", "-v", is_flag=True, default=False, help="Debug logging.")
def main(
    image_dir: Path,
    output_dir: Path,
    prompts: Optional[str],
    threshold: Optional[float],
    force: bool,
    verbose: bool,
) -> None:
    """Run OWLv2 open-vocabulary detection on IMAGE_DIR images.

    Thin CLI wrapper: parses the prompt list, applies an optional threshold
    override, and delegates to autolabel.detect.run_detection.
    """
    setup_logging(logging.DEBUG if verbose else logging.INFO)

    # Split "a, b, c" into ["a", "b", "c"], dropping empty entries.
    prompt_list = None
    if prompts:
        prompt_list = [p.strip() for p in prompts.split(",") if p.strip()]

    if threshold is not None:
        # NOTE: mutates the shared settings object so run_detection (which
        # receives cfg=settings) sees the override.
        settings.threshold = threshold

    run_detection(
        image_dir=image_dir,
        output_dir=output_dir,
        prompts=prompt_list,
        cfg=settings,
        force=force,
    )


if __name__ == "__main__":
    main()
|