# small_object_detection / tune_thresholds.py
# orik-ss's picture
# Log device, Jina CPU warning, pin revision
# 82551bb
"""
Find best conf_threshold and gap_threshold for Jina and Nomic using COCO ground truth.
Expects full_frames/ with images and annotations.coco.json (COCO format).
Runs the same detection + crop pipeline, matches each crop to a GT annotation (IoU),
then grid-searches (conf_threshold, gap_threshold) to maximize accuracy.
"""
import argparse
import json
from pathlib import Path
import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, DFineForObjectDetection
from dfine_jina_pipeline import (
box_center_inside,
box_iou,
deduplicate_by_iou,
get_person_car_label_ids,
group_detections,
run_dfine,
squarify_crop_box,
)
from jina_fewshot import IMAGE_EXTS, TRUNCATE_DIM, JinaCLIPv2Encoder, build_refs, draw_label_on_image
from nomic_fewshot import NomicTextEncoder, NomicVisionEncoder, build_refs_nomic
# Our 4 target classes. main() requires the labels returned by build_refs to
# equal this list exactly (same names, same order), and map_category_to_class
# maps COCO category names onto these same strings.
CLASS_NAMES = ["cigarette", "gun", "knife", "phone"]
def coco_bbox_to_xyxy(bbox):
    """Convert a COCO ``[x, y, w, h]`` box into corner form ``[x1, y1, x2, y2]``.

    Every component is passed through ``float`` first, so numeric strings
    (as sometimes found in exported JSON) are accepted as well.
    """
    left, top, width, height = map(float, bbox)
    return [left, top, left + width, top + height]
def map_category_to_class(name: str) -> str | None:
"""Map COCO category name to one of our 4 classes, or None if other."""
n = (name or "").strip().lower()
if "cigarette" in n:
return "cigarette"
if any(x in n for x in ("gun", "pistol", "handgun", "firearm")):
return "gun"
if "knife" in n or "blade" in n:
return "knife"
if any(x in n for x in ("phone", "cell", "mobile", "smartphone", "telephone")):
return "phone"
return None
def load_coco_gt(annotations_path: Path):
    """
    Load a COCO JSON file and index its ground-truth boxes per image.

    Returns a 2-tuple:
    - by_basename: dict mapping each image's base file name
      (``Path(file_name).name``) to a list of ``(bbox_xyxy, category_name)``
      tuples. Images listed in the JSON without annotations map to ``[]``.
      If two images share a base name, the later entry wins.
    - categories: the raw ``categories`` list from the JSON (``[]`` if absent).

    Annotations that reference an image id missing from ``images``, or that
    carry no 4-element ``bbox``, are skipped (with a printed count) instead of
    aborting the whole load with a KeyError/ValueError. An unknown
    ``category_id`` yields an empty category name.
    """
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(annotations_path, encoding="utf-8") as f:
        data = json.load(f)

    id_to_file = {im["id"]: im["file_name"] for im in data.get("images", [])}
    id_to_category = {c["id"]: c["name"] for c in data.get("categories", [])}

    # Pre-seed every known image so annotation-free images are still present.
    file_to_gts = {file_name: [] for file_name in id_to_file.values()}

    skipped = 0
    for ann in data.get("annotations", []):
        file_name = id_to_file.get(ann.get("image_id"))
        bbox = ann.get("bbox")
        if file_name is None or bbox is None or len(bbox) != 4:
            skipped += 1
            continue
        cat_name = id_to_category.get(ann.get("category_id"), "")
        file_to_gts[file_name].append((coco_bbox_to_xyxy(bbox), cat_name))
    if skipped:
        print(f"[!] Skipped {skipped} annotations with unknown image_id or missing bbox")

    # Index by basename so callers can look images up via Path.name.
    by_basename = {Path(fn).name: gts for fn, gts in file_to_gts.items()}
    return by_basename, data.get("categories", [])
def assign_gt_to_crop(crop_box_xyxy, gt_list, iou_min=0.3):
    """
    Pick the ground-truth box that overlaps this crop the most.

    Only GT entries whose IoU with the crop is at least ``iou_min`` and whose
    category maps to one of CLASS_NAMES are eligible. Returns
    ``(gt_class, iou)`` for the winner, or ``(None, 0.0)`` if nothing qualifies.
    """
    winner_class, winner_iou = None, 0.0
    for gt_box, category in gt_list:
        overlap = box_iou(crop_box_xyxy, gt_box)
        # Negated form of the acceptance test keeps NaN overlaps rejected.
        if not (overlap >= iou_min and overlap > winner_iou):
            continue
        mapped = map_category_to_class(category)
        if mapped is None:
            continue
        winner_class, winner_iou = mapped, overlap
    return winner_class, winner_iou
def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Options are declared in a table (flag, add_argument keyword arguments) and
    registered in order, so ``--help`` lists them in the same sequence as the
    table. ``--refs`` is the only required option.
    """
    parser = argparse.ArgumentParser(description="Tune Jina/Nomic thresholds using COCO GT")
    option_specs = (
        ("--input", {"default": "full_frames", "help": "Folder with images and annotations.coco.json"}),
        ("--annotations", {"default": None, "help": "Path to annotations.coco.json (default: input/_annotations.coco.json)"}),
        ("--refs", {"required": True, "help": "Reference images folder (for Jina + Nomic refs)"}),
        ("--output", {"default": "threshold_tuning", "help": "Output folder for results CSV"}),
        ("--det-threshold", {"type": float, "default": 0.3}),
        ("--group-dist", {"type": float, "default": None}),
        ("--expand", {"type": float, "default": 0.3}),
        ("--min-side", {"type": int, "default": 40}),
        ("--text-weight", {"type": float, "default": 0.3}),
        ("--iou-min", {"type": float, "default": 0.3, "help": "Min IoU to match crop to GT"}),
        ("--crop-dedup-iou", {"type": float, "default": 0.35, "help": "Min IoU to treat two crops as same object (keep larger)"}),
        ("--no-squarify", {"action": "store_true", "help": "Skip squarify; use expanded bbox only (tighter crops, often better recognition)"}),
        ("--max-images", {"type": int, "default": None}),
        ("--device", {"default": None}),
        ("--no-save-crops", {"action": "store_true", "help": "Do not save annotated crop images"}),
        ("--save-conf", {"type": float, "default": 0.5, "help": "Conf threshold for saved crop labels"}),
        ("--save-gap", {"type": float, "default": 0.02, "help": "Gap threshold for saved crop labels"}),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()
# Only the highest-confidence person/car groups of each image are examined.
_MAX_GROUPS_PER_IMAGE = 10


def _is_same_object(box_a, box_b, iou_threshold):
    """Return True when two crop boxes overlap enough, or one's center lies inside the other."""
    if box_iou(box_a, box_b) >= iou_threshold:
        return True
    return bool(box_center_inside(box_a, box_b) or box_center_inside(box_b, box_a))


def _dedup_candidates(candidates, iou_threshold):
    """Greedily keep candidates largest-area first, dropping any that duplicate a kept one."""
    def area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    kept = []
    # sorted() is stable, so equal-area candidates keep their collection order.
    for cand in sorted(candidates, key=lambda c: -area(c[0])):
        if not any(_is_same_object(cand[0], k[0], iou_threshold) for k in kept):
            kept.append(cand)
    return kept


def _collect_image_candidates(pil, stem, detections, groups, person_car_ids,
                              gt_list, args, detection_crops_dir=None):
    """Collect GT-matched crop boxes for one image.

    For each group, every non person/car detection whose center is inside the
    group box is expanded by ``args.expand``, clipped to the image, filtered by
    ``args.min_side`` and matched against ``gt_list``. Returns a list of
    ``(expanded_box, gt_class, group_index)``. When ``detection_crops_dir`` is
    given, the group box itself is also saved there as a JPEG.
    """
    img_w, img_h = pil.size
    candidates = []
    for gidx, grp in enumerate(groups):
        x1, y1, x2, y2 = grp["box"]
        group_box = [x1, y1, x2, y2]
        if detection_crops_dir is not None:
            gx1, gy1 = max(0, int(x1)), max(0, int(y1))
            gx2, gy2 = min(img_w, int(x2)), min(img_h, int(y2))
            if gx2 > gx1 and gy2 > gy1:
                pil.crop((gx1, gy1, gx2, gy2)).save(detection_crops_dir / f"{stem}_group{gidx}.jpg")
        inside = [
            d for d in detections
            if box_center_inside(d["box"], group_box)
            and d["cls"] not in person_car_ids
        ]
        inside = deduplicate_by_iou(inside, iou_threshold=0.9)
        for det in inside:
            bx1, by1, bx2, by2 = (float(v) for v in det["box"])
            obj_w, obj_h = bx2 - bx1, by2 - by1
            if obj_w <= 0 or obj_h <= 0:
                continue
            pad_x, pad_y = obj_w * args.expand, obj_h * args.expand
            bx1 = max(0, int(bx1 - pad_x))
            by1 = max(0, int(by1 - pad_y))
            bx2 = min(img_w, int(bx2 + pad_x))
            by2 = min(img_h, int(by2 + pad_y))
            if bx2 <= bx1 or by2 <= by1:
                continue
            if min(bx2 - bx1, by2 - by1) < args.min_side:
                continue
            expanded_box = [bx1, by1, bx2, by2]
            gt_class, _ = assign_gt_to_crop(expanded_box, gt_list, args.iou_min)
            if gt_class is None:
                continue
            candidates.append((expanded_box, gt_class, gidx))
    return candidates


def _top1_conf_gap(sims):
    """Return (best index, best similarity, best minus second-best similarity)."""
    best = int(np.argmax(sims))
    conf = float(sims[best])
    gap = float(sims[best] - np.partition(sims, -2)[-2])
    return best, conf, gap


def _accuracy(rows, labels, model, conf_t, gap_t):
    """Fraction of rows whose thresholded prediction for ``model`` equals the GT class."""
    correct = 0
    for r in rows:
        passed = r[f"{model}_conf"] >= conf_t and r[f"{model}_gap"] >= gap_t
        pred = labels[r[f"{model}_best_idx"]] if passed else "unknown"
        if pred == r["gt"]:
            correct += 1
    return correct / len(rows)


def _grid_search(rows, labels, model, conf_candidates, gap_candidates):
    """Score every (conf, gap) pair once; return (grid, best) as (conf, gap, accuracy) tuples."""
    grid = [
        (c, g, _accuracy(rows, labels, model, c, g))
        for c in conf_candidates
        for g in gap_candidates
    ]
    # max() returns the first maximal entry, i.e. ties resolve to the earliest pair.
    best = max(grid, key=lambda entry: entry[2])
    return grid, best


def main():
    """Run the detection + crop pipeline on annotated images and grid-search thresholds.

    Steps: load COCO ground truth, load D-FINE plus the Jina and Nomic encoders
    and their reference embeddings, collect GT-matched crops per image, score
    each crop against both reference sets, then evaluate every
    (conf_threshold, gap_threshold) pair. Writes ``best_thresholds.txt`` and
    ``grid_search.csv`` (and, unless ``--no-save-crops``, crop images) to the
    output folder.

    Raises:
        SystemExit: when the input folder or annotations are missing, no
            annotated image is found, the Jina reference labels do not equal
            CLASS_NAMES, or no crop could be matched to ground truth.
    """
    import csv

    args = parse_args()
    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    input_dir = Path(args.input)
    refs_dir = Path(args.refs)
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    annotations_path = Path(args.annotations) if args.annotations else input_dir / "_annotations.coco.json"
    if not annotations_path.is_file():
        raise SystemExit(f"Annotations not found: {annotations_path}")
    if not input_dir.is_dir():
        raise SystemExit(f"Input folder not found: {input_dir}")
    file_to_gts, _ = load_coco_gt(annotations_path)
    print(f"[*] Loaded GT for {len(file_to_gts)} images from {annotations_path}")

    # Keep only images known to COCO *before* applying --max-images, so the
    # limit counts usable images and cannot empty the list spuriously.
    paths = sorted(
        p for p in input_dir.iterdir()
        if p.suffix.lower() in IMAGE_EXTS and p.name in file_to_gts
    )
    if args.max_images is not None:
        paths = paths[: args.max_images]
    if not paths:
        raise SystemExit("No images in input that have COCO annotations.")

    print("[*] Loading D-FINE...")
    image_processor = AutoImageProcessor.from_pretrained("ustc-community/dfine-medium-obj365")
    dfine_model = DFineForObjectDetection.from_pretrained("ustc-community/dfine-medium-obj365")
    dfine_model = dfine_model.to(device).eval()
    person_car_ids = get_person_car_label_ids(dfine_model)

    print("[*] Loading Jina-CLIP-v2 and building refs...")
    jina_encoder = JinaCLIPv2Encoder(device)
    ref_labels, ref_embs = build_refs(
        jina_encoder, refs_dir, TRUNCATE_DIM, args.text_weight, batch_size=16
    )
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if ref_labels != CLASS_NAMES:
        raise SystemExit(f"Reference labels {ref_labels} do not match expected {CLASS_NAMES}")

    print("[*] Loading Nomic (vision + text) and building refs (same as Jina: text_weight 0.3)...")
    nomic_encoder = NomicVisionEncoder(device)
    nomic_text_encoder = NomicTextEncoder(device)
    ref_labels_nomic, ref_embs_nomic = build_refs_nomic(
        nomic_encoder, refs_dir, batch_size=16,
        text_encoder=nomic_text_encoder, text_weight=args.text_weight,
    )
    if list(ref_labels_nomic) != CLASS_NAMES:
        # Predictions are compared by label name, so this is only a warning.
        print(f"[!] Nomic reference labels {ref_labels_nomic} differ from {CLASS_NAMES}")

    # Optional outputs: annotated crops per model, raw crops, and person/car group crops.
    save_crops = not args.no_save_crops
    jina_crops_dir = output_dir / "jina_crops"
    nomic_crops_dir = output_dir / "nomic_crops"
    crops_no_label_dir = output_dir / "crops"
    detection_crops_dir = output_dir / "detection_crops"
    if save_crops:
        for crop_dir in (jina_crops_dir, nomic_crops_dir, crops_no_label_dir, detection_crops_dir):
            crop_dir.mkdir(parents=True, exist_ok=True)

    # One row per kept crop: GT class plus top-1 index / conf / gap for each model.
    rows = []
    for img_path in paths:
        gt_list = file_to_gts.get(img_path.name, [])
        if not gt_list:
            # Nothing to match against; skip before paying for detection.
            continue
        with Image.open(img_path) as opened:
            pil = opened.convert("RGB")
        img_w, img_h = pil.size
        group_dist = args.group_dist if args.group_dist is not None else 0.1 * max(img_h, img_w)
        detections = run_dfine(pil, image_processor, dfine_model, device, args.det_threshold)
        person_car = [d for d in detections if d["cls"] in person_car_ids]
        if not person_car:
            continue
        grouped = group_detections(person_car, group_dist)
        grouped.sort(key=lambda grp: grp["conf"], reverse=True)

        # 1) Candidate crops inside the top person/car groups that match a GT box.
        candidates = _collect_image_candidates(
            pil, img_path.stem, detections, grouped[:_MAX_GROUPS_PER_IMAGE],
            person_car_ids, gt_list, args,
            detection_crops_dir if save_crops else None,
        )
        # 2) Dedup on the EXPANDED boxes (before squarify), keeping the larger one.
        kept = _dedup_candidates(candidates, args.crop_dedup_iou)

        # 3) Optionally squarify, then score only the kept crops with both encoders.
        for i, (expanded_box, gt_class, gidx) in enumerate(kept):
            if args.no_squarify:
                bx1, by1, bx2, by2 = expanded_box
            else:
                bx1, by1, bx2, by2 = squarify_crop_box(*expanded_box, img_w, img_h)
            crop_pil = pil.crop((bx1, by1, bx2, by2))
            crop_name = f"{img_path.stem}_g{gidx}_{i}_{bx1}_{by1}_{bx2}_{by2}.jpg"

            q_jina = jina_encoder.encode_images([crop_pil], TRUNCATE_DIM)
            best_jina, conf_jina, gap_jina = _top1_conf_gap((q_jina @ ref_embs.T).squeeze(0))
            q_nomic = nomic_encoder.encode_images([crop_pil])
            best_nomic, conf_nomic, gap_nomic = _top1_conf_gap((q_nomic @ ref_embs_nomic.T).squeeze(0))

            rows.append({
                "gt": gt_class,
                "jina_best_idx": best_jina,
                "jina_conf": conf_jina,
                "jina_gap": gap_jina,
                "nomic_best_idx": best_nomic,
                "nomic_conf": conf_nomic,
                "nomic_gap": gap_nomic,
            })

            if save_crops:
                crop_pil.save(crops_no_label_dir / crop_name)
                sc, sg = args.save_conf, args.save_gap
                label_jina = ref_labels[best_jina] if (conf_jina >= sc and gap_jina >= sg) else f"unknown (gt:{gt_class})"
                label_nomic = ref_labels_nomic[best_nomic] if (conf_nomic >= sc and gap_nomic >= sg) else f"unknown (gt:{gt_class})"
                draw_label_on_image(crop_pil, label_jina, conf_jina).save(jina_crops_dir / crop_name)
                draw_label_on_image(crop_pil, label_nomic, conf_nomic).save(nomic_crops_dir / crop_name)

    if not rows:
        raise SystemExit("No crops matched to GT (with our 4 classes). Check annotations and iou_min.")
    print(f"[*] {len(rows)} crops with GT in {{cigarette, gun, knife, phone}}")
    if save_crops:
        print(f"[*] Annotated crops saved to {jina_crops_dir} and {nomic_crops_dir}")
        print(f"[*] Raw crops (no label) saved to {crops_no_label_dir}")
        print(f"[*] Person/car grouping crops saved to {detection_crops_dir}")

    # Grid search.
    # NOTE(review): every row has a GT in CLASS_NAMES, so a prediction turned
    # into "unknown" by a threshold can never be correct. Accuracy is therefore
    # non-increasing in both thresholds and the search always selects the
    # lowest pair. Making this meaningful would require scoring crops without a
    # GT match as "unknown" negatives - confirm intent before relying on it.
    conf_candidates = [0.50, 0.55, 0.60, 0.65, 0.70, 0.75, 0.80]
    gap_candidates = [0.02, 0.05, 0.08, 0.10]
    jina_grid, (best_jina_conf, best_jina_gap, best_jina_acc) = _grid_search(
        rows, ref_labels, "jina", conf_candidates, gap_candidates
    )
    nomic_grid, (best_nomic_conf, best_nomic_gap, best_nomic_acc) = _grid_search(
        rows, ref_labels_nomic, "nomic", conf_candidates, gap_candidates
    )

    # Report
    report_path = output_dir / "best_thresholds.txt"
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(f"Based on {len(rows)} crops with GT in {{cigarette, gun, knife, phone}}\n")
        if save_crops:
            f.write(f"Annotated crops: jina_crops/ and nomic_crops/ (conf>={args.save_conf}, gap>={args.save_gap})\n")
            f.write("Raw crops (no label): crops/\n")
            f.write("Person/car grouping only: detection_crops/\n")
        f.write("\n")
        f.write("Jina (best accuracy):\n")
        f.write(f" conf_threshold = {best_jina_conf}\n")
        f.write(f" gap_threshold = {best_jina_gap}\n")
        f.write(f" accuracy = {best_jina_acc:.4f}\n\n")
        f.write("Nomic (best accuracy):\n")
        f.write(f" conf_threshold = {best_nomic_conf}\n")
        f.write(f" gap_threshold = {best_nomic_gap}\n")
        f.write(f" accuracy = {best_nomic_acc:.4f}\n")
    print(f"\n[*] Best thresholds written to {report_path}")
    print(f"\nJina best: conf_threshold={best_jina_conf}, gap_threshold={best_jina_gap} -> accuracy={best_jina_acc:.4f}")
    print(f"Nomic best: conf_threshold={best_nomic_conf}, gap_threshold={best_nomic_gap} -> accuracy={best_nomic_acc:.4f}")

    # Full grid CSV (reuses the accuracies computed during the search).
    csv_path = output_dir / "grid_search.csv"
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["model", "conf_threshold", "gap_threshold", "accuracy"])
        for model, grid in (("jina", jina_grid), ("nomic", nomic_grid)):
            for c, g, acc in grid:
                writer.writerow([model, c, g, f"{acc:.4f}"])
    print(f"[*] Full grid written to {csv_path}")
# Run the tuning pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()