Spaces:
Sleeping
Sleeping
"""
Find the best conf_threshold and gap_threshold for Jina and Nomic using COCO ground truth.

Expects an input folder (default: full_frames/) containing the images and
_annotations.coco.json (COCO format; a different path can be given with --annotations).
Runs the same detection + crop pipeline, matches each crop to a GT annotation by IoU,
then grid-searches (conf_threshold, gap_threshold) to maximize accuracy and writes
best_thresholds.txt and grid_search.csv to the output folder.
"""
| import argparse | |
| import json | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from transformers import AutoImageProcessor, DFineForObjectDetection | |
| from dfine_jina_pipeline import ( | |
| box_center_inside, | |
| box_iou, | |
| deduplicate_by_iou, | |
| get_person_car_label_ids, | |
| group_detections, | |
| run_dfine, | |
| squarify_crop_box, | |
| ) | |
| from jina_fewshot import IMAGE_EXTS, TRUNCATE_DIM, JinaCLIPv2Encoder, build_refs, draw_label_on_image | |
| from nomic_fewshot import NomicTextEncoder, NomicVisionEncoder, build_refs_nomic | |
# The 4 target classes. The order must equal the label order returned by
# build_refs for the Jina references; main() verifies this before scoring.
CLASS_NAMES = [
    "cigarette",
    "gun",
    "knife",
    "phone",
]
def coco_bbox_to_xyxy(bbox):
    """Convert a COCO ``[x, y, width, height]`` box to ``[x1, y1, x2, y2]``.

    Every component goes through ``float()``, so numeric strings (as sometimes
    found in hand-edited JSON) are accepted alongside ints and floats.
    """
    left, top, width, height = map(float, bbox)
    right = left + width
    bottom = top + height
    return [left, top, right, bottom]
| def map_category_to_class(name: str) -> str | None: | |
| """Map COCO category name to one of our 4 classes, or None if other.""" | |
| n = (name or "").strip().lower() | |
| if "cigarette" in n: | |
| return "cigarette" | |
| if any(x in n for x in ("gun", "pistol", "handgun", "firearm")): | |
| return "gun" | |
| if "knife" in n or "blade" in n: | |
| return "knife" | |
| if any(x in n for x in ("phone", "cell", "mobile", "smartphone", "telephone")): | |
| return "phone" | |
| return None | |
def load_coco_gt(annotations_path: Path):
    """Load COCO ground truth and index it by image basename.

    Args:
        annotations_path: Path to a COCO-format JSON file.

    Returns:
        A tuple ``(by_basename, categories)``:

        - ``by_basename`` maps ``Path(file_name).name`` (the basename of each
          COCO image's ``file_name``) to a list of ``(bbox_xyxy, category_name)``
          tuples. Every image listed under ``images`` gets an entry, possibly
          an empty list. If two images share a basename, the later one wins.
        - ``categories`` is the raw ``categories`` list from the JSON.

    Annotations whose ``image_id`` is not listed under ``images`` are skipped
    (instead of raising ``KeyError``). Annotations with an unknown
    ``category_id`` get an empty category name.
    """
    # JSON is UTF-8 (RFC 8259); don't depend on the platform's locale encoding.
    with open(annotations_path, encoding="utf-8") as f:
        data = json.load(f)
    images = {im["id"]: im for im in data.get("images", [])}
    categories = {c["id"]: c["name"] for c in data.get("categories", [])}
    file_to_gts = {im["file_name"]: [] for im in images.values()}
    for ann in data.get("annotations", []):
        image = images.get(ann["image_id"])
        if image is None:
            # Dangling reference in a malformed export: nothing to attach it to.
            continue
        cat_name = categories.get(ann["category_id"], "")
        bbox_xyxy = coco_bbox_to_xyxy(ann["bbox"])
        file_to_gts[image["file_name"]].append((bbox_xyxy, cat_name))
    # Callers look images up by Path.name, so re-key by basename.
    by_basename = {Path(fn).name: gts for fn, gts in file_to_gts.items()}
    return by_basename, data.get("categories", [])
def assign_gt_to_crop(crop_box_xyxy, gt_list, iou_min=0.3):
    """Pick the ground-truth class that best overlaps a crop.

    Args:
        crop_box_xyxy: Crop box as ``[x1, y1, x2, y2]``.
        gt_list: Iterable of ``(bbox_xyxy, category_name)`` pairs.
        iou_min: Minimum IoU for a GT box to be considered at all.

    Returns:
        ``(gt_class, iou)``. ``gt_class`` is the CLASS_NAMES entry
        (via ``map_category_to_class``) of the highest-IoU GT box with
        IoU >= ``iou_min`` whose category maps to a target class. GT boxes
        whose category does not map are ignored. When nothing qualifies,
        returns ``(None, 0.0)``. On exact IoU ties the earlier GT box is kept.
    """
    best_class, best_iou = None, 0.0
    for gt_box, gt_name in gt_list:
        overlap = box_iou(crop_box_xyxy, gt_box)
        # Written as "not (...)" rather than De Morgan so NaN IoUs are rejected
        # exactly as before.
        if not (overlap >= iou_min and overlap > best_iou):
            continue
        mapped = map_category_to_class(gt_name)
        if mapped is None:
            continue
        best_class, best_iou = mapped, overlap
    return best_class, best_iou
def parse_args(argv=None):
    """Parse command-line options for the threshold-tuning run.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        An ``argparse.Namespace`` with the options defined below.
    """
    p = argparse.ArgumentParser(description="Tune Jina/Nomic thresholds using COCO GT")
    p.add_argument(
        "--input", default="full_frames",
        help="Folder with images (and _annotations.coco.json unless --annotations is given)",
    )
    p.add_argument(
        "--annotations", default=None,
        help="Path to COCO annotations JSON (default: <input>/_annotations.coco.json)",
    )
    p.add_argument("--refs", required=True, help="Reference images folder (for Jina + Nomic refs)")
    p.add_argument(
        "--output", default="threshold_tuning",
        help="Output folder for best_thresholds.txt, grid_search.csv and saved crops",
    )
    p.add_argument("--det-threshold", type=float, default=0.3, help="Detection threshold passed to run_dfine")
    p.add_argument(
        "--group-dist", type=float, default=None,
        help="Grouping distance for person/car detections (default: 0.1 * max(image width, height))",
    )
    p.add_argument(
        "--expand", type=float, default=0.3,
        help="Padding added to each side of an object box, as a fraction of its width/height",
    )
    p.add_argument(
        "--min-side", type=int, default=40,
        help="Skip crops whose shorter side (after padding) is below this many pixels",
    )
    p.add_argument(
        "--text-weight", type=float, default=0.3,
        help="Text-embedding weight used when building Jina and Nomic reference embeddings",
    )
    p.add_argument("--iou-min", type=float, default=0.3, help="Min IoU to match crop to GT")
    p.add_argument("--crop-dedup-iou", type=float, default=0.35, help="Min IoU to treat two crops as same object (keep larger)")
    p.add_argument("--no-squarify", action="store_true", help="Skip squarify; use expanded bbox only (tighter crops, often better recognition)")
    p.add_argument("--max-images", type=int, default=None, help="Upper bound on the number of images processed")
    p.add_argument("--device", default=None, help="Torch device (default: cuda if available, else cpu)")
    p.add_argument("--no-save-crops", action="store_true", help="Do not save annotated crop images")
    p.add_argument("--save-conf", type=float, default=0.5, help="Conf threshold for saved crop labels")
    p.add_argument("--save-gap", type=float, default=0.02, help="Gap threshold for saved crop labels")
    return p.parse_args(argv)
def _crop_area(box):
    """Return the area of an ``[x1, y1, x2, y2]`` box."""
    return (box[2] - box[0]) * (box[3] - box[1])


def _is_same_object(box_a, box_b, dedup_iou):
    """Return True when two crop boxes are treated as the same object.

    That is the case when their IoU is at least ``dedup_iou``, or when either
    box's center lies inside the other box.
    """
    if box_iou(box_a, box_b) >= dedup_iou:
        return True
    return bool(box_center_inside(box_a, box_b) or box_center_inside(box_b, box_a))


def _dedup_largest_first(candidates, dedup_iou):
    """Greedily keep the largest candidate of each same-object cluster.

    ``candidates`` are tuples whose first element is an ``[x1, y1, x2, y2]``
    box. They are visited by descending area (stable for ties). Each one is
    kept only if it is not the same object as an already-kept candidate.
    """
    kept = []
    for cand in sorted(candidates, key=lambda c: -_crop_area(c[0])):
        if not any(_is_same_object(cand[0], k[0], dedup_iou) for k in kept):
            kept.append(cand)
    return kept


def _expand_box(box, expand, img_w, img_h):
    """Pad a detection box, clip it to the image and cast to int.

    Each side is pushed out by ``expand`` times the box's width (x) or height (y).
    Returns ``[x1, y1, x2, y2]`` ints, or ``None`` when the input box or the
    clipped result has no area.
    """
    bx1, by1, bx2, by2 = (float(v) for v in box)
    obj_w, obj_h = bx2 - bx1, by2 - by1
    if obj_w <= 0 or obj_h <= 0:
        return None
    pad_x, pad_y = obj_w * expand, obj_h * expand
    ex1 = max(0, int(bx1 - pad_x))
    ey1 = max(0, int(by1 - pad_y))
    ex2 = min(img_w, int(bx2 + pad_x))
    ey2 = min(img_h, int(by2 + pad_y))
    if ex2 <= ex1 or ey2 <= ey1:
        return None
    return [ex1, ey1, ex2, ey2]


def _top1_conf_gap(sims):
    """Return ``(best_idx, conf, gap)`` for a similarity vector (one entry per reference).

    ``conf`` is the top similarity; ``gap`` is its margin over the runner-up,
    so at least two references are required.
    """
    best = int(np.argmax(sims))
    conf = float(sims[best])
    gap = float(sims[best] - np.partition(sims, -2)[-2])
    return best, conf, gap


def _accuracy(rows, labels, prefix, conf_t, gap_t):
    """Fraction of ``rows`` whose thresholded ``prefix`` prediction equals ``row["gt"]``.

    A row predicts ``labels[row[f"{prefix}_best_idx"]]`` when both its conf and
    gap meet the thresholds, and ``"unknown"`` otherwise. ``rows`` must be non-empty.
    """
    correct = 0
    for r in rows:
        accepted = r[f"{prefix}_conf"] >= conf_t and r[f"{prefix}_gap"] >= gap_t
        pred = labels[r[f"{prefix}_best_idx"]] if accepted else "unknown"
        if pred == r["gt"]:
            correct += 1
    return correct / len(rows)


def _grid_search(rows, labels, prefix, conf_candidates, gap_candidates):
    """Score every (conf, gap) pair for one model.

    Returns ``(grid, best_conf, best_gap, best_acc)``. ``grid`` lists
    ``(conf, gap, accuracy)`` in visiting order (conf-major). On ties, the
    first pair visited wins.
    """
    grid = []
    best_acc, best_conf, best_gap = -1, None, None
    for c in conf_candidates:
        for g in gap_candidates:
            acc = _accuracy(rows, labels, prefix, c, g)
            grid.append((c, g, acc))
            if acc > best_acc:
                best_acc, best_conf, best_gap = acc, c, g
    return grid, best_conf, best_gap, best_acc


def main():
    """Classify GT-matched crops with Jina and Nomic, then grid-search thresholds.

    For each annotated image, the pipeline works as follows:
    - D-FINE detections are grouped around person/car boxes.
    - Non-person/car detections inside the top groups are padded into crops,
      matched to a GT box by IoU, and deduplicated.
    - Each crop is scored against the Jina and Nomic reference embeddings.

    The output folder receives ``best_thresholds.txt``, ``grid_search.csv``
    and, unless ``--no-save-crops`` is set, crop images.

    Raises:
        SystemExit: if any of the following holds:
            - the annotations file is missing;
            - no input image has COCO annotations;
            - the Jina reference labels differ from CLASS_NAMES;
            - no crop matched a GT box of a target class.
    """
    import csv

    args = parse_args()
    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    input_dir = Path(args.input)
    refs_dir = Path(args.refs)
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    annotations_path = Path(args.annotations) if args.annotations else input_dir / "_annotations.coco.json"
    if not annotations_path.is_file():
        raise SystemExit(f"Annotations not found: {annotations_path}")
    file_to_gts, _ = load_coco_gt(annotations_path)
    print(f"[*] Loaded GT for {len(file_to_gts)} images from {annotations_path}")

    # Keep only images that appear in COCO *before* applying --max-images, so
    # the limit counts usable images instead of silently yielding fewer.
    paths = sorted(
        p for p in input_dir.iterdir()
        if p.suffix.lower() in IMAGE_EXTS and p.name in file_to_gts
    )
    if args.max_images is not None:
        paths = paths[: args.max_images]
    if not paths:
        raise SystemExit("No images in input that have COCO annotations.")

    dfine_checkpoint = "ustc-community/dfine-medium-obj365"
    print("[*] Loading D-FINE...")
    image_processor = AutoImageProcessor.from_pretrained(dfine_checkpoint)
    dfine_model = DFineForObjectDetection.from_pretrained(dfine_checkpoint).to(device).eval()
    person_car_ids = get_person_car_label_ids(dfine_model)

    print("[*] Loading Jina-CLIP-v2 and building refs...")
    jina_encoder = JinaCLIPv2Encoder(device)
    ref_labels, ref_embs = build_refs(
        jina_encoder, refs_dir, TRUNCATE_DIM, args.text_weight, batch_size=16
    )
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if ref_labels != CLASS_NAMES:
        raise SystemExit(f"Ref order {ref_labels} does not match CLASS_NAMES {CLASS_NAMES}")

    print(f"[*] Loading Nomic (vision + text) and building refs (same as Jina: text_weight {args.text_weight})...")
    nomic_encoder = NomicVisionEncoder(device)
    nomic_text_encoder = NomicTextEncoder(device)
    ref_labels_nomic, ref_embs_nomic = build_refs_nomic(
        nomic_encoder, refs_dir, batch_size=16,
        text_encoder=nomic_text_encoder, text_weight=args.text_weight,
    )
    # Nomic predictions are compared against CLASS_NAMES-based GT strings.
    if sorted(ref_labels_nomic) != sorted(CLASS_NAMES):
        print(f"[!] Nomic ref labels {ref_labels_nomic} differ from {CLASS_NAMES}; Nomic accuracy may be misleading.")

    # Optional outputs: annotated crops (Jina, Nomic), raw crops, person/car group crops.
    save_crops = not args.no_save_crops
    jina_crops_dir = output_dir / "jina_crops"
    nomic_crops_dir = output_dir / "nomic_crops"
    crops_no_label_dir = output_dir / "crops"
    detection_crops_dir = output_dir / "detection_crops"
    if save_crops:
        for crop_dir in (jina_crops_dir, nomic_crops_dir, crops_no_label_dir, detection_crops_dir):
            crop_dir.mkdir(parents=True, exist_ok=True)
    sc, sg = args.save_conf, args.save_gap

    # One row per kept crop whose GT maps to one of the 4 classes.
    rows = []
    for img_path in paths:
        gt_list = file_to_gts.get(img_path.name, [])
        if not gt_list:
            # No GT boxes: every crop would be discarded, so skip the detector.
            continue
        # Context manager closes the file handle; convert() returns a copy.
        with Image.open(img_path) as im:
            pil = im.convert("RGB")
        img_w, img_h = pil.size
        group_dist = args.group_dist if args.group_dist is not None else 0.1 * max(img_h, img_w)
        detections = run_dfine(pil, image_processor, dfine_model, device, args.det_threshold)
        person_car = [d for d in detections if d["cls"] in person_car_ids]
        if not person_car:
            continue
        grouped = group_detections(person_car, group_dist)
        grouped.sort(key=lambda x: x["conf"], reverse=True)

        # 1) Candidate crops inside the 10 highest-confidence groups that match a GT box.
        #    Each candidate: (expanded_box, gt_class, gidx, crop_idx)
        candidates = []
        for gidx, grp in enumerate(grouped[:10]):
            x1, y1, x2, y2 = grp["box"]
            group_box = [x1, y1, x2, y2]
            if save_crops:
                gx1, gy1 = max(0, int(x1)), max(0, int(y1))
                gx2, gy2 = min(img_w, int(x2)), min(img_h, int(y2))
                if gx2 > gx1 and gy2 > gy1:
                    pil.crop((gx1, gy1, gx2, gy2)).save(detection_crops_dir / f"{img_path.stem}_group{gidx}.jpg")
            inside = [
                d for d in detections
                if box_center_inside(d["box"], group_box) and d["cls"] not in person_car_ids
            ]
            inside = deduplicate_by_iou(inside, iou_threshold=0.9)
            for crop_idx, d in enumerate(inside):
                expanded_box = _expand_box(d["box"], args.expand, img_w, img_h)
                if expanded_box is None:
                    continue
                if min(expanded_box[2] - expanded_box[0], expanded_box[3] - expanded_box[1]) < args.min_side:
                    continue
                gt_class, _ = assign_gt_to_crop(expanded_box, gt_list, args.iou_min)
                if gt_class is None:
                    continue
                candidates.append((expanded_box, gt_class, gidx, crop_idx))

        # 2) Dedup on EXPANDED boxes (before squarify), keeping the larger crop.
        kept = _dedup_largest_first(candidates, args.crop_dedup_iou)

        # 3) Optionally squarify, then score only the kept crops.
        for i, (expanded_box, gt_class, gidx, _crop_idx) in enumerate(kept):
            if args.no_squarify:
                bx1, by1, bx2, by2 = expanded_box
            else:
                bx1, by1, bx2, by2 = squarify_crop_box(*expanded_box, img_w, img_h)
            crop_pil = pil.crop((bx1, by1, bx2, by2))
            crop_name = f"{img_path.stem}_g{gidx}_{i}_{bx1}_{by1}_{bx2}_{by2}.jpg"

            q_jina = jina_encoder.encode_images([crop_pil], TRUNCATE_DIM)
            best_jina, conf_jina, gap_jina = _top1_conf_gap((q_jina @ ref_embs.T).squeeze(0))
            q_nomic = nomic_encoder.encode_images([crop_pil])
            best_nomic, conf_nomic, gap_nomic = _top1_conf_gap((q_nomic @ ref_embs_nomic.T).squeeze(0))

            rows.append({
                "gt": gt_class,
                "jina_best_idx": best_jina,
                "jina_conf": conf_jina,
                "jina_gap": gap_jina,
                "nomic_best_idx": best_nomic,
                "nomic_conf": conf_nomic,
                "nomic_gap": gap_nomic,
            })
            if save_crops:
                crop_pil.save(crops_no_label_dir / crop_name)
                label_jina = ref_labels[best_jina] if (conf_jina >= sc and gap_jina >= sg) else f"unknown (gt:{gt_class})"
                label_nomic = ref_labels_nomic[best_nomic] if (conf_nomic >= sc and gap_nomic >= sg) else f"unknown (gt:{gt_class})"
                draw_label_on_image(crop_pil, label_jina, conf_jina).save(jina_crops_dir / crop_name)
                draw_label_on_image(crop_pil, label_nomic, conf_nomic).save(nomic_crops_dir / crop_name)

    if not rows:
        raise SystemExit("No crops matched to GT (with our 4 classes). Check annotations and iou_min.")
    print(f"[*] {len(rows)} crops with GT in {{cigarette, gun, knife, phone}}")
    if save_crops:
        print(f"[*] Annotated crops saved to {jina_crops_dir} and {nomic_crops_dir}")
        print(f"[*] Raw crops (no label) saved to {crops_no_label_dir}")
        print(f"[*] Person/car grouping crops saved to {detection_crops_dir}")

    # Grid search.
    # NOTE(review): every row has a real GT class (crops whose GT match is None
    # are skipped above), so an "unknown" prediction is never counted as
    # correct. Raising either threshold can only keep or lower accuracy, and
    # since both candidate lists are ascending and ties keep the first pair,
    # the search always returns (0.50, 0.02). Making rejection pay off needs
    # negative crops scored with GT "unknown".
    conf_candidates = [0.50, 0.55, 0.60, 0.65, 0.70, 0.75, 0.80]
    gap_candidates = [0.02, 0.05, 0.08, 0.10]
    jina_grid, best_jina_conf, best_jina_gap, best_jina_acc = _grid_search(
        rows, ref_labels, "jina", conf_candidates, gap_candidates
    )
    nomic_grid, best_nomic_conf, best_nomic_gap, best_nomic_acc = _grid_search(
        rows, ref_labels_nomic, "nomic", conf_candidates, gap_candidates
    )

    # Report
    report_path = output_dir / "best_thresholds.txt"
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(f"Based on {len(rows)} crops with GT in {{cigarette, gun, knife, phone}}\n")
        if save_crops:
            f.write(f"Annotated crops: jina_crops/ and nomic_crops/ (conf>={args.save_conf}, gap>={args.save_gap})\n")
            f.write("Raw crops (no label): crops/\n")
            f.write("Person/car grouping only: detection_crops/\n")
        f.write("\n")
        f.write("Jina (best accuracy):\n")
        f.write(f"  conf_threshold = {best_jina_conf}\n")
        f.write(f"  gap_threshold  = {best_jina_gap}\n")
        f.write(f"  accuracy       = {best_jina_acc:.4f}\n\n")
        f.write("Nomic (best accuracy):\n")
        f.write(f"  conf_threshold = {best_nomic_conf}\n")
        f.write(f"  gap_threshold  = {best_nomic_gap}\n")
        f.write(f"  accuracy       = {best_nomic_acc:.4f}\n")
    print(f"\n[*] Best thresholds written to {report_path}")
    print("\nJina  best: conf_threshold={}, gap_threshold={}  -> accuracy={:.4f}".format(
        best_jina_conf, best_jina_gap, best_jina_acc))
    print("Nomic best: conf_threshold={}, gap_threshold={}  -> accuracy={:.4f}".format(
        best_nomic_conf, best_nomic_gap, best_nomic_acc))

    # Full grid CSV (reuses the accuracies computed above).
    csv_path = output_dir / "grid_search.csv"
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["model", "conf_threshold", "gap_threshold", "accuracy"])
        for c, g, acc in jina_grid:
            w.writerow(["jina", c, g, f"{acc:.4f}"])
        for c, g, acc in nomic_grid:
            w.writerow(["nomic", c, g, f"{acc:.4f}"])
    print(f"[*] Full grid written to {csv_path}")


if __name__ == "__main__":
    main()