"""
Find the best conf_threshold and gap_threshold for Jina and Nomic using COCO ground truth.

Expects an input folder (default: full_frames/) containing the images and a
COCO-format annotation file (default: <input>/_annotations.coco.json).

Per image, the same detection + crop pipeline is run:
  1. D-FINE detections; person/car detections are grouped.
  2. Every other detection whose center lies inside one of the top groups is
     expanded, size-filtered, and matched to a GT annotation by IoU.
  3. Matched crops are de-duplicated (largest first), optionally squarified,
     and embedded with Jina-CLIP-v2 and Nomic. For each model we record the
     best reference class, its similarity (conf) and its margin over the
     runner-up (gap).

Finally (conf_threshold, gap_threshold) is grid-searched per model to maximize
accuracy. A crop whose conf or gap falls below the thresholds is predicted as
"unknown", which never equals a GT class and therefore counts as wrong.
"""
import argparse
import csv
import json
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, DFineForObjectDetection

from dfine_jina_pipeline import (
    box_center_inside,
    box_iou,
    deduplicate_by_iou,
    get_person_car_label_ids,
    group_detections,
    run_dfine,
    squarify_crop_box,
)
from jina_fewshot import IMAGE_EXTS, TRUNCATE_DIM, JinaCLIPv2Encoder, build_refs, draw_label_on_image
from nomic_fewshot import NomicTextEncoder, NomicVisionEncoder, build_refs_nomic

# Our 4 classes (same order as the Jina refs built from --refs).
CLASS_NAMES = ["cigarette", "gun", "knife", "phone"]

# Detector checkpoint used for person/car grouping and object proposals.
DFINE_CHECKPOINT = "ustc-community/dfine-medium-obj365"

# Threshold grid searched for both models.
CONF_CANDIDATES = [0.50, 0.55, 0.60, 0.65, 0.70, 0.75, 0.80]
GAP_CANDIDATES = [0.02, 0.05, 0.08, 0.10]

# Label assigned when a prediction falls below the conf/gap thresholds.
UNKNOWN_LABEL = "unknown"


def coco_bbox_to_xyxy(bbox):
    """Convert a COCO bbox ``[x, y, w, h]`` to ``[x1, y1, x2, y2]``.

    Each value goes through ``float()`` so string numbers from JSON are tolerated.
    """
    x, y, w, h = (float(v) for v in bbox)
    return [x, y, x + w, y + h]


def map_category_to_class(name: str) -> str | None:
    """Map a COCO category name to one of CLASS_NAMES, or None for any other category.

    Matching is a case-insensitive substring test, tried in the order
    cigarette, gun, knife, phone; the first hit wins.
    """
    n = (name or "").strip().lower()
    if "cigarette" in n:
        return "cigarette"
    if any(x in n for x in ("gun", "pistol", "handgun", "firearm")):
        return "gun"
    if "knife" in n or "blade" in n:
        return "knife"
    if any(x in n for x in ("phone", "cell", "mobile", "smartphone", "telephone")):
        return "phone"
    return None


def load_coco_gt(annotations_path: Path):
    """Load a COCO annotation file and index its boxes by image basename.

    Args:
        annotations_path: Path to the COCO JSON file.

    Returns:
        A tuple ``(by_basename, categories)``:
          - by_basename: dict mapping ``Path(file_name).name`` to a list of
            ``(bbox_xyxy, category_name)`` tuples. Every image listed in the
            file gets an entry (possibly an empty list).
          - categories: the raw ``categories`` list from the COCO file.

    Annotations that reference an unknown image id or carry no bbox are skipped.
    If two images share a basename, the one listed later wins.
    """
    with open(annotations_path, encoding="utf-8") as f:
        data = json.load(f)

    images = {im["id"]: im for im in data.get("images", [])}
    categories = {c["id"]: c["name"] for c in data.get("categories", [])}

    file_to_gts = {im["file_name"]: [] for im in images.values()}
    for ann in data.get("annotations", []):
        image = images.get(ann.get("image_id"))
        bbox = ann.get("bbox")
        if image is None or not bbox:
            # Dangling or box-less annotation: nothing a crop could match against.
            continue
        cat_name = categories.get(ann.get("category_id"), "")
        file_to_gts[image["file_name"]].append((coco_bbox_to_xyxy(bbox), cat_name))

    # Images on disk are looked up by Path.name, so re-key by basename.
    by_basename = {Path(fn).name: gts for fn, gts in file_to_gts.items()}
    return by_basename, data.get("categories", [])


def assign_gt_to_crop(crop_box_xyxy, gt_list, iou_min=0.3):
    """Find the best-overlapping GT box (among our 4 classes) for a crop.

    Args:
        crop_box_xyxy: Crop box ``[x1, y1, x2, y2]``.
        gt_list: List of ``(bbox_xyxy, category_name)`` tuples for the image.
        iou_min: Minimum IoU for a GT box to be considered a match.

    Returns:
        ``(gt_class, iou)``. ``gt_class`` is one of CLASS_NAMES, or None when no
        GT box of a mapped category reaches ``iou_min`` (``iou`` is then 0.0).
    """
    best_iou = 0.0
    best_class = None
    for bbox_xyxy, cat_name in gt_list:
        iou = box_iou(crop_box_xyxy, bbox_xyxy)
        if iou >= iou_min and iou > best_iou:
            cls = map_category_to_class(cat_name)
            if cls is not None:
                best_iou = iou
                best_class = cls
    return best_class, best_iou


def parse_args():
    """Parse command-line options for the threshold-tuning run."""
    p = argparse.ArgumentParser(description="Tune Jina/Nomic thresholds using COCO GT")
    p.add_argument("--input", default="full_frames", help="Folder with images and _annotations.coco.json")
    p.add_argument("--annotations", default=None,
                   help="Path to the COCO JSON (default: <input>/_annotations.coco.json)")
    p.add_argument("--refs", required=True, help="Reference images folder (for Jina + Nomic refs)")
    p.add_argument("--output", default="threshold_tuning", help="Output folder for results CSV")
    p.add_argument("--det-threshold", type=float, default=0.3, help="D-FINE detection score threshold")
    p.add_argument("--group-dist", type=float, default=None,
                   help="Person/car grouping distance in pixels (default: 10%% of the longer image side)")
    p.add_argument("--expand", type=float, default=0.3,
                   help="Pad each object box by this fraction of its width/height on each side")
    p.add_argument("--min-side", type=int, default=40, help="Drop expanded crops whose shorter side is smaller")
    p.add_argument("--text-weight", type=float, default=0.3, help="Text weight used when building refs")
    p.add_argument("--iou-min", type=float, default=0.3, help="Min IoU to match crop to GT")
    p.add_argument("--crop-dedup-iou", type=float, default=0.35,
                   help="Min IoU to treat two crops as same object (keep larger)")
    p.add_argument("--no-squarify", action="store_true",
                   help="Skip squarify; use expanded bbox only (tighter crops, often better recognition)")
    p.add_argument("--max-images", type=int, default=None,
                   help="Process at most this many annotated images")
    p.add_argument("--max-groups", type=int, default=10,
                   help="Only consider the N highest-confidence person/car groups per image")
    p.add_argument("--device", default=None)
    p.add_argument("--no-save-crops", action="store_true", help="Do not save annotated crop images")
    p.add_argument("--save-conf", type=float, default=0.5, help="Conf threshold for saved crop labels")
    p.add_argument("--save-gap", type=float, default=0.02, help="Gap threshold for saved crop labels")
    return p.parse_args()


def _expand_box(box, expand, img_w, img_h):
    """Pad ``box`` by ``expand`` times its width/height on each side, clip to the image, cast to int.

    Returns the integer ``[x1, y1, x2, y2]`` box, or None when the detection
    or the clipped result has zero/negative width or height.
    """
    bx1, by1, bx2, by2 = (float(v) for v in box)
    obj_w, obj_h = bx2 - bx1, by2 - by1
    if obj_w <= 0 or obj_h <= 0:
        return None
    pad_x, pad_y = obj_w * expand, obj_h * expand
    x1 = max(0, int(bx1 - pad_x))
    y1 = max(0, int(by1 - pad_y))
    x2 = min(img_w, int(bx2 + pad_x))
    y2 = min(img_h, int(by2 + pad_y))
    if x2 <= x1 or y2 <= y1:
        return None
    return [x1, y1, x2, y2]


def _box_area(box):
    """Area of an ``[x1, y1, x2, y2]`` box."""
    return (box[2] - box[0]) * (box[3] - box[1])


def _is_same_object(box_a, box_b, dedup_iou):
    """True if IoU >= ``dedup_iou`` or either box's center lies inside the other."""
    if box_iou(box_a, box_b) >= dedup_iou:
        return True
    return bool(box_center_inside(box_a, box_b) or box_center_inside(box_b, box_a))


def _dedup_largest_first(candidates, dedup_iou):
    """Greedy de-duplication of candidates whose first element is a box.

    Candidates are visited largest-area first (stable for ties); one is kept
    only if it is not the same object as any already-kept candidate.
    """
    kept = []
    for cand in sorted(candidates, key=lambda c: -_box_area(c[0])):
        if not any(_is_same_object(cand[0], k[0], dedup_iou) for k in kept):
            kept.append(cand)
    return kept


def _score(query_emb, ref_embs):
    """Compare a single-row query embedding against reference embeddings.

    Returns ``(best_idx, conf, gap)``: index of the most similar reference,
    its similarity, and its margin over the second most similar reference
    (requires at least two references).
    """
    sims = (query_emb @ ref_embs.T).squeeze(0)
    best = int(np.argmax(sims))
    conf = float(sims[best])
    gap = float(sims[best] - np.partition(sims, -2)[-2])
    return best, conf, gap


def _predict(labels, best_idx, conf, gap, conf_t, gap_t, unknown=UNKNOWN_LABEL):
    """Return ``labels[best_idx]`` when conf >= conf_t and gap >= gap_t, else ``unknown``."""
    return labels[best_idx] if (conf >= conf_t and gap >= gap_t) else unknown


def _accuracy(rows, labels, prefix, conf_t, gap_t):
    """Fraction of ``rows`` whose thresholded ``prefix`` (jina/nomic) prediction equals the GT class."""
    correct = sum(
        _predict(labels, r[f"{prefix}_best_idx"], r[f"{prefix}_conf"], r[f"{prefix}_gap"], conf_t, gap_t)
        == r["gt"]
        for r in rows
    )
    return correct / len(rows)


def _grid_search(rows, labels, prefix):
    """Evaluate accuracy for every (conf, gap) pair in the grid.

    Returns ``(grid, best)``: ``grid`` lists ``(conf, gap, acc)`` in
    CONF_CANDIDATES x GAP_CANDIDATES order; ``best`` is the first entry with
    the highest accuracy (``max`` keeps the first maximal element).
    """
    grid = [
        (c, g, _accuracy(rows, labels, prefix, c, g))
        for c in CONF_CANDIDATES
        for g in GAP_CANDIDATES
    ]
    best = max(grid, key=lambda entry: entry[2])
    return grid, best


def main():
    """Run detection + crop + embedding over annotated images and grid-search thresholds."""
    args = parse_args()
    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    input_dir = Path(args.input)
    refs_dir = Path(args.refs)
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    annotations_path = Path(args.annotations) if args.annotations else input_dir / "_annotations.coco.json"
    if not annotations_path.is_file():
        raise SystemExit(f"Annotations not found: {annotations_path}")
    file_to_gts, _ = load_coco_gt(annotations_path)
    print(f"[*] Loaded GT for {len(file_to_gts)} images from {annotations_path}")

    # Keep only images present in the COCO file *before* applying --max-images,
    # so the limit counts images that can actually be evaluated.
    paths = sorted(
        p for p in input_dir.iterdir()
        if p.suffix.lower() in IMAGE_EXTS and p.name in file_to_gts
    )
    if args.max_images is not None:
        paths = paths[: args.max_images]
    if not paths:
        raise SystemExit("No images in input that have COCO annotations.")

    print("[*] Loading D-FINE...")
    image_processor = AutoImageProcessor.from_pretrained(DFINE_CHECKPOINT)
    dfine_model = DFineForObjectDetection.from_pretrained(DFINE_CHECKPOINT).to(device).eval()
    person_car_ids = get_person_car_label_ids(dfine_model)

    print("[*] Loading Jina-CLIP-v2 and building refs...")
    jina_encoder = JinaCLIPv2Encoder(device)
    ref_labels, ref_embs = build_refs(jina_encoder, refs_dir, TRUNCATE_DIM, args.text_weight, batch_size=16)
    if list(ref_labels) != CLASS_NAMES:
        # Explicit check rather than `assert`, which `python -O` strips.
        raise SystemExit(f"Ref order {ref_labels} does not match expected {CLASS_NAMES}")

    print(f"[*] Loading Nomic (vision + text) and building refs (same as Jina: text_weight {args.text_weight})...")
    nomic_encoder = NomicVisionEncoder(device)
    nomic_text_encoder = NomicTextEncoder(device)
    ref_labels_nomic, ref_embs_nomic = build_refs_nomic(
        nomic_encoder, refs_dir, batch_size=16,
        text_encoder=nomic_text_encoder, text_weight=args.text_weight,
    )

    # Optional outputs: annotated crops per model, raw crops, and person/car group crops.
    save_crops = not args.no_save_crops
    jina_crops_dir = output_dir / "jina_crops"
    nomic_crops_dir = output_dir / "nomic_crops"
    crops_no_label_dir = output_dir / "crops"
    detection_crops_dir = output_dir / "detection_crops"
    if save_crops:
        for d in (jina_crops_dir, nomic_crops_dir, crops_no_label_dir, detection_crops_dir):
            d.mkdir(parents=True, exist_ok=True)

    # One row per kept crop whose GT class is one of CLASS_NAMES.
    rows = []
    for img_path in paths:
        gt_list = file_to_gts[img_path.name]
        if not gt_list:
            # No GT boxes means no crop can match; skip the detector entirely.
            continue

        with Image.open(img_path) as im:  # close the file handle promptly
            pil = im.convert("RGB")
        img_w, img_h = pil.size
        group_dist = args.group_dist if args.group_dist is not None else 0.1 * max(img_h, img_w)

        detections = run_dfine(pil, image_processor, dfine_model, device, args.det_threshold)
        person_car = [d for d in detections if d["cls"] in person_car_ids]
        if not person_car:
            continue
        grouped = group_detections(person_car, group_dist)
        grouped.sort(key=lambda x: x["conf"], reverse=True)

        # 1) Candidate crops: non-person/car detections centered in a top group,
        #    expanded, size-filtered, and matched to a GT class.
        candidates = []  # (expanded_box, gt_class, gidx)
        for gidx, grp in enumerate(grouped[: args.max_groups]):
            x1, y1, x2, y2 = grp["box"]
            group_box = [x1, y1, x2, y2]

            if save_crops:
                gx1, gy1 = max(0, int(x1)), max(0, int(y1))
                gx2, gy2 = min(img_w, int(x2)), min(img_h, int(y2))
                if gx2 > gx1 and gy2 > gy1:
                    pil.crop((gx1, gy1, gx2, gy2)).save(detection_crops_dir / f"{img_path.stem}_group{gidx}.jpg")

            inside = [
                d for d in detections
                if box_center_inside(d["box"], group_box) and d["cls"] not in person_car_ids
            ]
            inside = deduplicate_by_iou(inside, iou_threshold=0.9)
            for d in inside:
                expanded_box = _expand_box(d["box"], args.expand, img_w, img_h)
                if expanded_box is None:
                    continue
                if min(expanded_box[2] - expanded_box[0], expanded_box[3] - expanded_box[1]) < args.min_side:
                    continue
                gt_class, _ = assign_gt_to_crop(expanded_box, gt_list, args.iou_min)
                if gt_class is not None:
                    candidates.append((expanded_box, gt_class, gidx))

        # 2) Dedup on EXPANDED boxes (before squarify), keeping the larger crop.
        kept = _dedup_largest_first(candidates, args.crop_dedup_iou)

        # 3) Optionally squarify, then score each kept crop with Jina and Nomic.
        for i, (expanded_box, gt_class, gidx) in enumerate(kept):
            if args.no_squarify:
                bx1, by1, bx2, by2 = expanded_box
            else:
                bx1, by1, bx2, by2 = squarify_crop_box(*expanded_box, img_w, img_h)
            crop_pil = pil.crop((bx1, by1, bx2, by2))

            best_jina, conf_jina, gap_jina = _score(
                jina_encoder.encode_images([crop_pil], TRUNCATE_DIM), ref_embs
            )
            best_nomic, conf_nomic, gap_nomic = _score(
                nomic_encoder.encode_images([crop_pil]), ref_embs_nomic
            )
            rows.append({
                "gt": gt_class,
                "jina_best_idx": best_jina,
                "jina_conf": conf_jina,
                "jina_gap": gap_jina,
                "nomic_best_idx": best_nomic,
                "nomic_conf": conf_nomic,
                "nomic_gap": gap_nomic,
            })

            if save_crops:
                crop_name = f"{img_path.stem}_g{gidx}_{i}_{bx1}_{by1}_{bx2}_{by2}.jpg"
                crop_pil.save(crops_no_label_dir / crop_name)
                fallback = f"unknown (gt:{gt_class})"
                label_jina = _predict(ref_labels, best_jina, conf_jina, gap_jina,
                                      args.save_conf, args.save_gap, unknown=fallback)
                label_nomic = _predict(ref_labels_nomic, best_nomic, conf_nomic, gap_nomic,
                                       args.save_conf, args.save_gap, unknown=fallback)
                draw_label_on_image(crop_pil, label_jina, conf_jina).save(jina_crops_dir / crop_name)
                draw_label_on_image(crop_pil, label_nomic, conf_nomic).save(nomic_crops_dir / crop_name)

    if not rows:
        raise SystemExit("No crops matched to GT (with our 4 classes). Check annotations and iou_min.")
    classes_str = ", ".join(CLASS_NAMES)
    print(f"[*] {len(rows)} crops with GT in {{{classes_str}}}")
    if save_crops:
        print(f"[*] Annotated crops saved to {jina_crops_dir} and {nomic_crops_dir}")
        print(f"[*] Raw crops (no label) saved to {crops_no_label_dir}")
        print(f"[*] Person/car grouping crops saved to {detection_crops_dir}")

    # Grid search (each grid evaluated once, reused for the CSV below).
    jina_grid, (best_jina_conf, best_jina_gap, best_jina_acc) = _grid_search(rows, ref_labels, "jina")
    nomic_grid, (best_nomic_conf, best_nomic_gap, best_nomic_acc) = _grid_search(rows, ref_labels_nomic, "nomic")

    # Report
    report_path = output_dir / "best_thresholds.txt"
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(f"Based on {len(rows)} crops with GT in {{{classes_str}}}\n")
        if save_crops:
            f.write(f"Annotated crops: jina_crops/ and nomic_crops/ (conf>={args.save_conf}, gap>={args.save_gap})\n")
            f.write("Raw crops (no label): crops/\n")
            f.write("Person/car grouping only: detection_crops/\n")
        f.write("\n")
        f.write("Jina (best accuracy):\n")
        f.write(f"  conf_threshold = {best_jina_conf}\n")
        f.write(f"  gap_threshold  = {best_jina_gap}\n")
        f.write(f"  accuracy       = {best_jina_acc:.4f}\n\n")
        f.write("Nomic (best accuracy):\n")
        f.write(f"  conf_threshold = {best_nomic_conf}\n")
        f.write(f"  gap_threshold  = {best_nomic_gap}\n")
        f.write(f"  accuracy       = {best_nomic_acc:.4f}\n")

    print(f"\n[*] Best thresholds written to {report_path}")
    print(f"\nJina best: conf_threshold={best_jina_conf}, gap_threshold={best_jina_gap} -> accuracy={best_jina_acc:.4f}")
    print(f"Nomic best: conf_threshold={best_nomic_conf}, gap_threshold={best_nomic_gap} -> accuracy={best_nomic_acc:.4f}")

    # Full grid CSV
    csv_path = output_dir / "grid_search.csv"
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["model", "conf_threshold", "gap_threshold", "accuracy"])
        for model, grid in (("jina", jina_grid), ("nomic", nomic_grid)):
            for c, g, acc in grid:
                w.writerow([model, c, g, f"{acc:.4f}"])
    print(f"[*] Full grid written to {csv_path}")


if __name__ == "__main__":
    main()