Napron commited on
Commit
d6dee9e
·
verified ·
1 Parent(s): cd64594

Update dfine_jina_pipeline.py

Browse files
Files changed (1) hide show
  1. dfine_jina_pipeline.py +327 -379
dfine_jina_pipeline.py CHANGED
@@ -1,8 +1,6 @@
1
- """
2
- Pipeline: D-FINE (person/car only)group detections crop group regions
3
- classify all inner object detections with Jina-CLIP-v2 or Nomic.
4
  Outputs separate crop folders per model (jina_crops, nomic_crops) for visual comparison.
5
- Each saved image is the D-FINE group crop, with bboxes drawn only for known classes.
6
  """
7
 
8
  import argparse
@@ -12,7 +10,8 @@ from pathlib import Path
12
 
13
  import numpy as np
14
  import torch
15
- from PIL import Image, ImageDraw, ImageFont
 
16
  from transformers import AutoImageProcessor, DFineForObjectDetection
17
 
18
  # Jina-CLIP-v2 few-shot (same refs + classify as jina_fewshot.py)
@@ -22,16 +21,16 @@ from jina_fewshot import (
22
  JinaCLIPv2Encoder,
23
  build_refs,
24
  classify as jina_classify,
 
25
  )
26
 
27
  from nomic_fewshot import NomicTextEncoder, NomicVisionEncoder, build_refs_nomic
28
 
29
 
30
  # -----------------------------------------------------------------------------
31
- # Detection + grouping
32
  # -----------------------------------------------------------------------------
33
 
34
-
35
  def get_box_dist(box1, box2):
36
  """Euclidean distance between box centers. box = [x1, y1, x2, y2]."""
37
  c1 = np.array([(box1[0] + box1[2]) / 2, (box1[1] + box1[3]) / 2])
@@ -42,6 +41,7 @@ def get_box_dist(box1, box2):
42
  def group_detections(detections, threshold):
43
  """
44
  Group detections by proximity (center distance < threshold).
 
45
  detections: list of {"box": [x1,y1,x2,y2], "conf", "cls", ...}
46
  Returns list of {"box": merged [x1,y1,x2,y2], "conf": best in group, "cls": best in group}.
47
  """
@@ -51,6 +51,7 @@ def group_detections(detections, threshold):
51
  boxes = [d["box"] for d in detections]
52
  n = len(boxes)
53
  adj = {i: [] for i in range(n)}
 
54
  for i in range(n):
55
  for j in range(i + 1, n):
56
  if get_box_dist(boxes[i], boxes[j]) < threshold:
@@ -59,14 +60,17 @@ def group_detections(detections, threshold):
59
 
60
  groups = []
61
  visited = [False] * n
 
62
  for i in range(n):
63
  if not visited[i]:
64
  group_indices = []
65
  stack = [i]
66
  visited[i] = True
 
67
  while stack:
68
  curr = stack.pop()
69
  group_indices.append(curr)
 
70
  for neighbor in adj[curr]:
71
  if not visited[neighbor]:
72
  visited[neighbor] = True
@@ -77,14 +81,15 @@ def group_detections(detections, threshold):
77
  y1 = min(d["box"][1] for d in group_dets)
78
  x2 = max(d["box"][2] for d in group_dets)
79
  y2 = max(d["box"][3] for d in group_dets)
80
- best_det = max(group_dets, key=lambda x: x["conf"])
81
 
 
82
  groups.append({
83
  "box": [x1, y1, x2, y2],
84
  "conf": best_det["conf"],
85
  "cls": best_det["cls"],
86
  "label": best_det.get("label", str(best_det["cls"])),
87
  })
 
88
  return groups
89
 
90
 
@@ -92,7 +97,10 @@ def box_center_inside(box, crop_box):
92
  """True if center of box is inside crop_box. All [x1,y1,x2,y2]."""
93
  cx = (box[0] + box[2]) / 2
94
  cy = (box[1] + box[3]) / 2
95
- return crop_box[0] <= cx <= crop_box[2] and crop_box[1] <= cy <= crop_box[3]
 
 
 
96
 
97
 
98
  def squarify_crop_box(bx1, by1, bx2, by2, img_w, img_h):
@@ -104,8 +112,10 @@ def squarify_crop_box(bx1, by1, bx2, by2, img_w, img_h):
104
  orig = (int(bx1), int(by1), int(bx2), int(by2))
105
  w = bx2 - bx1
106
  h = by2 - by1
 
107
  if w <= 0 or h <= 0:
108
  return orig
 
109
  if h > w:
110
  add = (h - w) / 2.0
111
  bx1 = max(0, bx1 - add)
@@ -114,9 +124,12 @@ def squarify_crop_box(bx1, by1, bx2, by2, img_w, img_h):
114
  add = (w - h) / 2.0
115
  by1 = max(0, by1 - add)
116
  by2 = min(img_h, by2 + add)
 
117
  bx1, by1, bx2, by2 = int(bx1), int(by1), int(bx2), int(by2)
 
118
  if bx2 <= bx1 or by2 <= by1:
119
  return orig
 
120
  return bx1, by1, bx2, by2
121
 
122
 
@@ -126,12 +139,15 @@ def box_iou(box1, box2):
126
  iy1 = max(box1[1], box2[1])
127
  ix2 = min(box1[2], box2[2])
128
  iy2 = min(box1[3], box2[3])
 
129
  inter_w = max(0, ix2 - ix1)
130
  inter_h = max(0, iy2 - iy1)
131
  inter = inter_w * inter_h
 
132
  a1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
133
  a2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
134
  union = a1 + a2 - inter
 
135
  return inter / union if union > 0 else 0.0
136
 
137
 
@@ -139,190 +155,33 @@ def deduplicate_by_iou(detections, iou_threshold=0.9):
139
  """Keep one detection per overlapping group (IoU >= iou_threshold). Prefer higher confidence."""
140
  if not detections:
141
  return []
 
 
142
  sorted_d = sorted(detections, key=lambda x: -x["conf"])
143
  kept = []
 
144
  for d in sorted_d:
145
  if not any(box_iou(d["box"], k["box"]) >= iou_threshold for k in kept):
146
  kept.append(d)
147
- return kept
148
-
149
-
150
- # -----------------------------------------------------------------------------
151
- # Drawing / layout helpers
152
- # -----------------------------------------------------------------------------
153
-
154
-
155
- def _load_font_for_box(img_w, img_h):
156
- font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
157
- size = max(10, min(img_w, img_h) // 14)
158
- try:
159
- return ImageFont.truetype(font_path, size=size)
160
- except OSError:
161
- return ImageFont.load_default()
162
-
163
-
164
- def draw_bbox_with_optional_label(
165
- img,
166
- box,
167
- label=None,
168
- confidence=None,
169
- color=(255, 0, 0),
170
- width=3,
171
- ):
172
- """
173
- Draw bbox on the given image. If label is None, draw only the box.
174
- box is [x1, y1, x2, y2] in this image's coordinate space.
175
- """
176
- out = img.copy().convert("RGB")
177
- draw = ImageDraw.Draw(out)
178
- x1, y1, x2, y2 = [int(v) for v in box]
179
-
180
- x1 = max(0, min(x1, out.width - 1))
181
- y1 = max(0, min(y1, out.height - 1))
182
- x2 = max(0, min(x2, out.width - 1))
183
- y2 = max(0, min(y2, out.height - 1))
184
-
185
- for i in range(width):
186
- draw.rectangle([x1 - i, y1 - i, x2 + i, y2 + i], outline=color)
187
-
188
- if label:
189
- font = _load_font_for_box(out.width, out.height)
190
- text = f"{label} ({confidence:.2f})" if confidence is not None else label
191
- bbox = draw.textbbox((0, 0), text, font=font)
192
- tw = bbox[2] - bbox[0]
193
- th = bbox[3] - bbox[1]
194
-
195
- tx = max(0, min(x1, out.width - tw - 8))
196
- ty = y1 - th - 8
197
- if ty < 0:
198
- ty = min(out.height - th - 4, y1 + 4)
199
-
200
- draw.rectangle([tx, ty, tx + tw + 8, ty + th + 6], fill=color)
201
- draw.text((tx + 4, ty + 3), text, fill=(255, 255, 255), font=font)
202
-
203
- return out
204
-
205
-
206
- def stack_images_vertical(images, bg=(255, 255, 255), pad=10):
207
- """Stack PIL images vertically into one output image."""
208
- if not images:
209
- return None
210
- widths = [img.width for img in images]
211
- heights = [img.height for img in images]
212
- out_w = max(widths)
213
- out_h = sum(heights) + pad * (len(images) - 1)
214
- canvas = Image.new("RGB", (out_w, out_h), color=bg)
215
-
216
- y = 0
217
- for img in images:
218
- x = (out_w - img.width) // 2
219
- canvas.paste(img, (x, y))
220
- y += img.height + pad
221
- return canvas
222
-
223
-
224
- # -----------------------------------------------------------------------------
225
- # Shared crop/object preparation
226
- # -----------------------------------------------------------------------------
227
-
228
-
229
- def expand_box_with_padding(box, img_w, img_h, padding):
230
- x1, y1, x2, y2 = [float(v) for v in box]
231
- w = x2 - x1
232
- h = y2 - y1
233
- if w <= 0 or h <= 0:
234
- return None
235
- pad_x = w * padding
236
- pad_y = h * padding
237
- ex1 = max(0, int(x1 - pad_x))
238
- ey1 = max(0, int(y1 - pad_y))
239
- ex2 = min(img_w, int(x2 + pad_x))
240
- ey2 = min(img_h, int(y2 + pad_y))
241
- if ex2 <= ex1 or ey2 <= ey1:
242
- return None
243
- return [ex1, ey1, ex2, ey2]
244
-
245
-
246
- def build_group_crop_box(group_box, img_w, img_h, padding=0.2, squarify=True):
247
- expanded = expand_box_with_padding(group_box, img_w, img_h, padding)
248
- if expanded is None:
249
- return None
250
- if squarify:
251
- return list(squarify_crop_box(expanded[0], expanded[1], expanded[2], expanded[3], img_w, img_h))
252
- return expanded
253
-
254
-
255
- def collect_group_object_candidates(
256
- detections,
257
- person_car_ids,
258
- group_box,
259
- img_w,
260
- img_h,
261
- min_side,
262
- crop_dedup_iou,
263
- object_padding=0.3,
264
- ):
265
- """
266
- For one group crop, collect and deduplicate all non-person/car detections inside it.
267
- Returns list of dicts with:
268
- {
269
- "det": original detection,
270
- "expanded_box": expanded object crop box in full-image coords
271
- }
272
- """
273
- inside = [
274
- d for d in detections
275
- if box_center_inside(d["box"], group_box) and d["cls"] not in person_car_ids
276
- ]
277
- inside = deduplicate_by_iou(inside, iou_threshold=0.9)
278
-
279
- candidates = []
280
- for d in inside:
281
- expanded_box = expand_box_with_padding(d["box"], img_w, img_h, object_padding)
282
- if expanded_box is None:
283
- continue
284
- if min(expanded_box[2] - expanded_box[0], expanded_box[3] - expanded_box[1]) < min_side:
285
- continue
286
- candidates.append({
287
- "det": d,
288
- "expanded_box": expanded_box,
289
- })
290
-
291
- def crop_area(box):
292
- return (box[2] - box[0]) * (box[3] - box[1])
293
-
294
- candidates.sort(key=lambda c: -crop_area(c["expanded_box"]))
295
- kept = []
296
-
297
- def is_same_object(box_a, box_b):
298
- if box_iou(box_a, box_b) >= crop_dedup_iou:
299
- return True
300
- if box_center_inside(box_a, box_b) or box_center_inside(box_b, box_a):
301
- return True
302
- return False
303
-
304
- for c in candidates:
305
- if not any(is_same_object(c["expanded_box"], k["expanded_box"]) for k in kept):
306
- kept.append(c)
307
 
308
  return kept
309
 
310
 
311
  def parse_args():
312
  p = argparse.ArgumentParser(
313
- description="D-FINE (person/car) → group → classify objects inside each group crop"
314
  )
315
  p.add_argument("--refs", required=True, help="Reference images folder for Jina and Nomic (e.g. refs/)")
316
  p.add_argument("--input", required=True, help="Full-frame images folder")
317
  p.add_argument("--output", default="pipeline_results", help="Output folder (CSV, etc.)")
318
  p.add_argument("--det-threshold", type=float, default=0.13, help="D-FINE score threshold")
319
  p.add_argument("--group-dist", type=float, default=None, help="Group distance (default: 0.1 * max(H,W))")
320
- p.add_argument("--min-side", type=int, default=40, help="Min side of expanded object bbox in px (skip smaller)")
321
- p.add_argument("--crop-dedup-iou", type=float, default=0.35, help="Min IoU to treat two object crops as same object")
322
- p.add_argument("--no-squarify", action="store_true", help="Skip squarify on group crop")
323
- p.add_argument("--padding", type=float, default=0.2, help="Padding around D-FINE group crop box")
324
- p.add_argument("--conf-threshold", type=float, default=0.75, help="Accept confidence")
325
- p.add_argument("--gap-threshold", type=float, default=0.05, help="Accept gap")
326
  p.add_argument("--text-weight", type=float, default=0.3)
327
  p.add_argument("--max-images", type=int, default=None)
328
  p.add_argument("--device", default=None)
@@ -333,25 +192,32 @@ def get_person_car_label_ids(model):
333
  """Return set of label IDs for person and car (Objects365: Person, Car, SUV, etc.)."""
334
  id2label = getattr(model.config, "id2label", None) or {}
335
  ids = set()
 
336
  for idx, name in id2label.items():
337
  try:
338
  i = int(idx)
339
  except (ValueError, TypeError):
340
  continue
 
341
  n = (name or "").lower()
342
  if "person" in n or n in ("car", "suv"):
343
  ids.add(i)
 
344
  return ids
345
 
346
 
347
  def run_dfine(image, processor, model, device, score_threshold):
348
  """Run D-FINE, return all detections as list of {box, score, label_id, label}."""
 
 
349
  if isinstance(image, Image.Image):
350
  pil = image.convert("RGB")
351
  else:
352
  pil = Image.fromarray(image).convert("RGB")
 
353
  w, h = pil.size
354
  target_size = torch.tensor([[h, w]], device=device)
 
355
  inputs = processor(images=pil, return_tensors="pt")
356
  inputs = {k: v.to(device) for k, v in inputs.items()}
357
 
@@ -360,13 +226,20 @@ def run_dfine(image, processor, model, device, score_threshold):
360
 
361
  target_sizes = target_size.to(outputs["logits"].device)
362
  results = processor.post_process_object_detection(
363
- outputs, target_sizes=target_sizes, threshold=score_threshold
 
 
364
  )
365
- id2label = getattr(model.config, "id2label", {}) or {}
366
 
 
367
  detections = []
 
368
  for result in results:
369
- for score, label_id, box in zip(result["scores"], result["labels"], result["boxes"]):
 
 
 
 
370
  sid = int(label_id.item())
371
  detections.append({
372
  "box": [float(x) for x in box.cpu().tolist()],
@@ -374,54 +247,14 @@ def run_dfine(image, processor, model, device, score_threshold):
374
  "cls": sid,
375
  "label": id2label.get(sid, str(sid)),
376
  })
377
- return detections
378
-
379
-
380
- def annotate_group_crop(
381
- pil,
382
- detections,
383
- person_car_ids,
384
- group_box,
385
- crop_box,
386
- encoder_choice,
387
- encoder,
388
- ref_labels,
389
- ref_embs,
390
- conf_threshold,
391
- gap_threshold,
392
- ):
393
- """
394
- Build one D-FINE group crop, classify all object candidates inside it,
395
- and draw bbox+label only for known classes.
396
-
397
- Returns:
398
- annotated_crop_pil,
399
- rows_for_csv,
400
- known_lines
401
- """
402
- crop_x1, crop_y1, crop_x2, crop_y2 = crop_box
403
- group_crop = pil.crop((crop_x1, crop_y1, crop_x2, crop_y2)).convert("RGB")
404
-
405
- obj_candidates = collect_group_object_candidates(
406
- detections=detections,
407
- person_car_ids=person_car_ids,
408
- group_box=group_box,
409
- img_w=pil.width,
410
- img_h=pil.height,
411
- min_side=1, # not used here; caller filters before this if needed
412
- crop_dedup_iou=1.0, # not used here; caller passes already-filtered set if needed
413
- object_padding=0.0,
414
- )
415
- # This helper is not used directly because caller already builds filtered candidates.
416
- # Kept here only for API symmetry.
417
- _ = obj_candidates
418
 
419
- return group_crop, [], []
420
 
421
 
422
  def main():
423
  args = parse_args()
424
  device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
 
425
  input_dir = Path(args.input)
426
  output_dir = Path(args.output)
427
  refs_dir = Path(args.refs)
@@ -432,9 +265,13 @@ def main():
432
  if not input_dir.is_dir():
433
  raise SystemExit(f"Input folder not found: {input_dir}")
434
 
435
- paths = sorted(p for p in input_dir.iterdir() if p.suffix.lower() in IMAGE_EXTS)
 
 
 
436
  if args.max_images is not None:
437
  paths = paths[: args.max_images]
 
438
  if not paths:
439
  raise SystemExit(f"No images in {input_dir}")
440
 
@@ -445,16 +282,22 @@ def main():
445
  dfine_model = DFineForObjectDetection.from_pretrained("ustc-community/dfine-medium-obj365")
446
  dfine_model = dfine_model.to(device).eval()
447
  person_car_ids = get_person_car_label_ids(dfine_model)
448
- print(f" Person/car label IDs: {person_car_ids} ({time.perf_counter() - t0:.1f}s)")
449
 
450
- # Load Jina refs
451
  print("[*] Loading Jina-CLIP-v2 and building refs...")
452
  t0 = time.perf_counter()
453
  jina_encoder = JinaCLIPv2Encoder(device)
454
- ref_labels, ref_embs = build_refs(jina_encoder, refs_dir, TRUNCATE_DIM, args.text_weight, batch_size=16)
455
- print(f" Jina refs: {ref_labels} ({time.perf_counter() - t0:.1f}s)\n")
 
 
 
 
 
 
456
 
457
- # Load Nomic refs
458
  print("[*] Loading Nomic embed-vision + embed-text and building refs...")
459
  t0 = time.perf_counter()
460
  nomic_encoder = NomicVisionEncoder(device)
@@ -466,21 +309,38 @@ def main():
466
  text_encoder=nomic_text_encoder,
467
  text_weight=args.text_weight,
468
  )
469
- print(f" Nomic refs: {ref_labels_nomic} ({time.perf_counter() - t0:.1f}s)\n")
470
 
 
471
  jina_crops_dir = output_dir / "jina_crops"
472
  nomic_crops_dir = output_dir / "nomic_crops"
473
  jina_crops_dir.mkdir(parents=True, exist_ok=True)
474
  nomic_crops_dir.mkdir(parents=True, exist_ok=True)
475
 
 
476
  csv_path = output_dir / "results.csv"
477
  f = open(csv_path, "w", newline="")
478
  w = csv.writer(f)
479
  w.writerow([
480
- "image", "crop_filename", "group_idx", "crop_x1", "crop_y1", "crop_x2", "crop_y2",
481
- "bbox_x1", "bbox_y1", "bbox_x2", "bbox_y2", "dfine_label", "dfine_conf",
482
- "jina_prediction", "jina_confidence", "jina_status",
483
- "nomic_prediction", "nomic_confidence", "nomic_status",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
484
  ])
485
 
486
  for img_path in paths:
@@ -488,101 +348,167 @@ def main():
488
  img_w, img_h = pil.size
489
  group_dist = args.group_dist if args.group_dist is not None else 0.1 * max(img_h, img_w)
490
 
491
- detections = run_dfine(pil, image_processor, dfine_model, device, args.det_threshold)
 
 
 
 
 
 
 
 
492
  person_car = [d for d in detections if d["cls"] in person_car_ids]
493
  if not person_car:
494
  continue
495
 
 
496
  grouped = group_detections(person_car, group_dist)
497
  grouped.sort(key=lambda x: x["conf"], reverse=True)
498
- top_groups = grouped[:10]
 
 
 
 
499
 
500
  for gidx, grp in enumerate(top_groups):
501
- group_box = grp["box"]
502
- crop_box = build_group_crop_box(
503
- group_box,
504
- img_w,
505
- img_h,
506
- padding=args.padding,
507
- squarify=not args.no_squarify,
508
- )
509
- if crop_box is None:
510
- continue
511
 
512
- crop_x1, crop_y1, crop_x2, crop_y2 = crop_box
513
- crop_name = f"{img_path.stem}_g{gidx}_{crop_x1}_{crop_y1}_{crop_x2}_{crop_y2}{img_path.suffix}"
514
- base_group_crop = pil.crop((crop_x1, crop_y1, crop_x2, crop_y2)).convert("RGB")
515
-
516
- obj_candidates = collect_group_object_candidates(
517
- detections=detections,
518
- person_car_ids=person_car_ids,
519
- group_box=group_box,
520
- img_w=img_w,
521
- img_h=img_h,
522
- min_side=args.min_side,
523
- crop_dedup_iou=args.crop_dedup_iou,
524
- object_padding=0.3,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
525
  )
526
 
527
- ann_jina = base_group_crop.copy()
528
- ann_nomic = base_group_crop.copy()
529
-
530
- for item in obj_candidates:
531
- d = item["det"]
532
- ex1, ey1, ex2, ey2 = item["expanded_box"]
533
-
534
- obj_crop = pil.crop((ex1, ey1, ex2, ey2)).convert("RGB")
535
-
536
- q_jina = jina_encoder.encode_images([obj_crop], TRUNCATE_DIM)
537
- result_jina = jina_classify(q_jina, ref_labels, ref_embs, args.conf_threshold, args.gap_threshold)
538
-
539
- q_nomic = nomic_encoder.encode_images([obj_crop])
540
- result_nomic = jina_classify(q_nomic, ref_labels_nomic, ref_embs_nomic, args.conf_threshold, args.gap_threshold)
541
-
542
- rel_box = [
543
- max(0, int(round(d["box"][0] - crop_x1))),
544
- max(0, int(round(d["box"][1] - crop_y1))),
545
- min(base_group_crop.width, int(round(d["box"][2] - crop_x1))),
546
- min(base_group_crop.height, int(round(d["box"][3] - crop_y1))),
547
- ]
548
-
549
- if result_jina["prediction"] in ref_labels:
550
- ann_jina = draw_bbox_with_optional_label(
551
- ann_jina,
552
- rel_box,
553
- label=result_jina["prediction"],
554
- confidence=result_jina["confidence"],
555
- )
556
-
557
- if result_nomic["prediction"] in ref_labels_nomic:
558
- ann_nomic = draw_bbox_with_optional_label(
559
- ann_nomic,
560
- rel_box,
561
- label=result_nomic["prediction"],
562
- confidence=result_nomic["confidence"],
563
- )
564
-
565
- w.writerow([
566
- img_path.name, crop_name, gidx,
567
- crop_x1, crop_y1, crop_x2, crop_y2,
568
- d["box"][0], d["box"][1], d["box"][2], d["box"][3],
569
- d["label"], f"{d['conf']:.4f}",
570
- result_jina["prediction"], f"{result_jina['confidence']:.4f}", result_jina["status"],
571
- result_nomic["prediction"], f"{result_nomic['confidence']:.4f}", result_nomic["status"],
572
- ])
573
 
 
574
  ann_jina.save(jina_crops_dir / crop_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
575
  ann_nomic.save(nomic_crops_dir / crop_name)
576
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
577
  f.close()
578
  print(f"[*] Wrote {csv_path}")
579
- print(f"[*] Jina crops: {jina_crops_dir}")
580
  print(f"[*] Nomic crops: {nomic_crops_dir}")
581
 
582
 
583
  # -----------------------------------------------------------------------------
584
- # Single-image runner for Gradio app: D-FINE first, then Jina or Nomic
585
  # -----------------------------------------------------------------------------
 
586
  _APP_DFINE = None
587
  _APP_JINA = None
588
  _APP_NOMIC = None
@@ -603,11 +529,16 @@ def run_single_image(
603
  squarify=True,
604
  ):
605
  """
606
- Run D-FINE on one image, then classify small-object detections inside each group crop.
607
- Returns:
608
- - one vertically stacked image of all D-FINE group crops
609
- - text containing only known-class predictions
 
 
610
  """
 
 
 
611
  global _APP_DFINE, _APP_JINA, _APP_NOMIC, _APP_REFS_JINA, _APP_REFS_NOMIC
612
 
613
  refs_dir = Path(refs_dir)
@@ -640,6 +571,61 @@ def run_single_image(
640
  grouped.sort(key=lambda x: x["conf"], reverse=True)
641
  top_groups = grouped[:10]
642
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
643
  # Load encoder + refs for chosen model
644
  if encoder_choice == "jina":
645
  if _APP_JINA is None or _APP_REFS_JINA != str(refs_dir):
@@ -647,7 +633,8 @@ def run_single_image(
647
  ref_labels, ref_embs = build_refs(jina_encoder, refs_dir, TRUNCATE_DIM, 0.3, batch_size=16)
648
  _APP_JINA = (jina_encoder, ref_labels, ref_embs)
649
  _APP_REFS_JINA = str(refs_dir)
650
- encoder, ref_labels, ref_embs = _APP_JINA
 
651
  else:
652
  if _APP_NOMIC is None or _APP_REFS_NOMIC != str(refs_dir):
653
  nomic_encoder = NomicVisionEncoder(device)
@@ -661,83 +648,44 @@ def run_single_image(
661
  )
662
  _APP_NOMIC = (nomic_encoder, ref_labels, ref_embs)
663
  _APP_REFS_NOMIC = str(refs_dir)
664
- encoder, ref_labels, ref_embs = _APP_NOMIC
665
 
666
- output_crops = []
667
- lines = []
668
-
669
- for gidx, grp in enumerate(top_groups):
670
- group_box = grp["box"]
671
- crop_box = build_group_crop_box(
672
- group_box,
673
- img_w,
674
- img_h,
675
- padding=0.2,
676
- squarify=squarify,
677
- )
678
- if crop_box is None:
679
- continue
680
-
681
- crop_x1, crop_y1, crop_x2, crop_y2 = crop_box
682
- group_crop = pil.crop((crop_x1, crop_y1, crop_x2, crop_y2)).convert("RGB")
683
-
684
- obj_candidates = collect_group_object_candidates(
685
- detections=detections,
686
- person_car_ids=person_car_ids,
687
- group_box=group_box,
688
- img_w=img_w,
689
- img_h=img_h,
690
- min_side=min_side,
691
- crop_dedup_iou=crop_dedup_iou,
692
- object_padding=0.3,
693
- )
694
-
695
- ann_crop = group_crop.copy()
696
- group_known_lines = []
697
-
698
- for item in obj_candidates:
699
- d = item["det"]
700
- ex1, ey1, ex2, ey2 = item["expanded_box"]
701
- obj_crop = pil.crop((ex1, ey1, ex2, ey2)).convert("RGB")
702
 
703
- if encoder_choice == "jina":
704
- q = encoder.encode_images([obj_crop], TRUNCATE_DIM)
705
- result = jina_classify(q, ref_labels, ref_embs, conf_threshold, gap_threshold)
706
- else:
707
- q = encoder.encode_images([obj_crop])
708
- result = jina_classify(q, ref_labels, ref_embs, conf_threshold, gap_threshold)
709
-
710
- known = result["prediction"] in ref_labels
711
- if not known:
712
- continue
 
 
 
 
 
713
 
714
- rel_box = [
715
- max(0, int(round(d["box"][0] - crop_x1))),
716
- max(0, int(round(d["box"][1] - crop_y1))),
717
- min(group_crop.width, int(round(d["box"][2] - crop_x1))),
718
- min(group_crop.height, int(round(d["box"][3] - crop_y1))),
719
- ]
720
 
721
- ann_crop = draw_bbox_with_optional_label(
722
- ann_crop,
723
- rel_box,
724
- label=result["prediction"],
725
- confidence=result["confidence"],
726
- )
727
- group_known_lines.append(f"{result['prediction']} ({result['confidence']:.2f})")
728
 
729
- output_crops.append(ann_crop)
 
730
 
731
- if group_known_lines:
732
- lines.append(f"Crop {gidx + 1}:")
733
- lines.extend(group_known_lines)
734
 
735
- if not output_crops:
736
- return np.array(pil), "No D-FINE group crops."
737
 
738
- stacked = stack_images_vertical(output_crops, pad=10)
739
- result_text = "\n".join(lines) if lines else ""
740
- return np.array(stacked), result_text
741
 
742
 
743
  if __name__ == "__main__":
 
1
+ """ Pipeline: D-FINE (person/car only) → group detections → crop regions →
2
+ find all bboxes inside each crop Jina-CLIP-v2 and Nomic embeddings on those crops.
 
3
  Outputs separate crop folders per model (jina_crops, nomic_crops) for visual comparison.
 
4
  """
5
 
6
  import argparse
 
10
 
11
  import numpy as np
12
  import torch
13
+ import torch.nn.functional as F
14
+ from PIL import Image
15
  from transformers import AutoImageProcessor, DFineForObjectDetection
16
 
17
  # Jina-CLIP-v2 few-shot (same refs + classify as jina_fewshot.py)
 
21
  JinaCLIPv2Encoder,
22
  build_refs,
23
  classify as jina_classify,
24
+ draw_label_on_image,
25
  )
26
 
27
  from nomic_fewshot import NomicTextEncoder, NomicVisionEncoder, build_refs_nomic
28
 
29
 
30
  # -----------------------------------------------------------------------------
31
+ # Detection + grouping (from reference_detection.py)
32
  # -----------------------------------------------------------------------------
33
 
 
34
  def get_box_dist(box1, box2):
35
  """Euclidean distance between box centers. box = [x1, y1, x2, y2]."""
36
  c1 = np.array([(box1[0] + box1[2]) / 2, (box1[1] + box1[3]) / 2])
 
41
  def group_detections(detections, threshold):
42
  """
43
  Group detections by proximity (center distance < threshold).
44
+
45
  detections: list of {"box": [x1,y1,x2,y2], "conf", "cls", ...}
46
  Returns list of {"box": merged [x1,y1,x2,y2], "conf": best in group, "cls": best in group}.
47
  """
 
51
  boxes = [d["box"] for d in detections]
52
  n = len(boxes)
53
  adj = {i: [] for i in range(n)}
54
+
55
  for i in range(n):
56
  for j in range(i + 1, n):
57
  if get_box_dist(boxes[i], boxes[j]) < threshold:
 
60
 
61
  groups = []
62
  visited = [False] * n
63
+
64
  for i in range(n):
65
  if not visited[i]:
66
  group_indices = []
67
  stack = [i]
68
  visited[i] = True
69
+
70
  while stack:
71
  curr = stack.pop()
72
  group_indices.append(curr)
73
+
74
  for neighbor in adj[curr]:
75
  if not visited[neighbor]:
76
  visited[neighbor] = True
 
81
  y1 = min(d["box"][1] for d in group_dets)
82
  x2 = max(d["box"][2] for d in group_dets)
83
  y2 = max(d["box"][3] for d in group_dets)
 
84
 
85
+ best_det = max(group_dets, key=lambda x: x["conf"])
86
  groups.append({
87
  "box": [x1, y1, x2, y2],
88
  "conf": best_det["conf"],
89
  "cls": best_det["cls"],
90
  "label": best_det.get("label", str(best_det["cls"])),
91
  })
92
+
93
  return groups
94
 
95
 
 
97
  """True if center of box is inside crop_box. All [x1,y1,x2,y2]."""
98
  cx = (box[0] + box[2]) / 2
99
  cy = (box[1] + box[3]) / 2
100
+ return (
101
+ crop_box[0] <= cx <= crop_box[2]
102
+ and crop_box[1] <= cy <= crop_box[3]
103
+ )
104
 
105
 
106
  def squarify_crop_box(bx1, by1, bx2, by2, img_w, img_h):
 
112
  orig = (int(bx1), int(by1), int(bx2), int(by2))
113
  w = bx2 - bx1
114
  h = by2 - by1
115
+
116
  if w <= 0 or h <= 0:
117
  return orig
118
+
119
  if h > w:
120
  add = (h - w) / 2.0
121
  bx1 = max(0, bx1 - add)
 
124
  add = (w - h) / 2.0
125
  by1 = max(0, by1 - add)
126
  by2 = min(img_h, by2 + add)
127
+
128
  bx1, by1, bx2, by2 = int(bx1), int(by1), int(bx2), int(by2)
129
+
130
  if bx2 <= bx1 or by2 <= by1:
131
  return orig
132
+
133
  return bx1, by1, bx2, by2
134
 
135
 
 
139
  iy1 = max(box1[1], box2[1])
140
  ix2 = min(box1[2], box2[2])
141
  iy2 = min(box1[3], box2[3])
142
+
143
  inter_w = max(0, ix2 - ix1)
144
  inter_h = max(0, iy2 - iy1)
145
  inter = inter_w * inter_h
146
+
147
  a1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
148
  a2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
149
  union = a1 + a2 - inter
150
+
151
  return inter / union if union > 0 else 0.0
152
 
153
 
 
155
  """Keep one detection per overlapping group (IoU >= iou_threshold). Prefer higher confidence."""
156
  if not detections:
157
  return []
158
+
159
+ # Sort by confidence descending; keep first, then add only if no kept box overlaps >= threshold
160
  sorted_d = sorted(detections, key=lambda x: -x["conf"])
161
  kept = []
162
+
163
  for d in sorted_d:
164
  if not any(box_iou(d["box"], k["box"]) >= iou_threshold for k in kept):
165
  kept.append(d)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
  return kept
168
 
169
 
170
  def parse_args():
171
  p = argparse.ArgumentParser(
172
+ description="D-FINE (person/car) → group → Jina-CLIP-v2 on crops inside groups"
173
  )
174
  p.add_argument("--refs", required=True, help="Reference images folder for Jina and Nomic (e.g. refs/)")
175
  p.add_argument("--input", required=True, help="Full-frame images folder")
176
  p.add_argument("--output", default="pipeline_results", help="Output folder (CSV, etc.)")
177
  p.add_argument("--det-threshold", type=float, default=0.13, help="D-FINE score threshold")
178
  p.add_argument("--group-dist", type=float, default=None, help="Group distance (default: 0.1 * max(H,W))")
179
+ p.add_argument("--min-side", type=int, default=40, help="Min side of expanded bbox in px (skip smaller)")
180
+ p.add_argument("--crop-dedup-iou", type=float, default=0.35, help="Min IoU to treat two crops as same object (keep larger)")
181
+ p.add_argument("--no-squarify", action="store_true", help="Skip squarify; use expanded bbox only (tighter crops, often better recognition)")
182
+ p.add_argument("--padding", type=float, default=0.2, help="Crop padding around group box (0.2 = 20%%)")
183
+ p.add_argument("--conf-threshold", type=float, default=0.75, help="Jina accept confidence")
184
+ p.add_argument("--gap-threshold", type=float, default=0.05, help="Jina accept gap")
185
  p.add_argument("--text-weight", type=float, default=0.3)
186
  p.add_argument("--max-images", type=int, default=None)
187
  p.add_argument("--device", default=None)
 
192
  """Return set of label IDs for person and car (Objects365: Person, Car, SUV, etc.)."""
193
  id2label = getattr(model.config, "id2label", None) or {}
194
  ids = set()
195
+
196
  for idx, name in id2label.items():
197
  try:
198
  i = int(idx)
199
  except (ValueError, TypeError):
200
  continue
201
+
202
  n = (name or "").lower()
203
  if "person" in n or n in ("car", "suv"):
204
  ids.add(i)
205
+
206
  return ids
207
 
208
 
209
  def run_dfine(image, processor, model, device, score_threshold):
210
  """Run D-FINE, return all detections as list of {box, score, label_id, label}."""
211
+ from PIL import Image
212
+
213
  if isinstance(image, Image.Image):
214
  pil = image.convert("RGB")
215
  else:
216
  pil = Image.fromarray(image).convert("RGB")
217
+
218
  w, h = pil.size
219
  target_size = torch.tensor([[h, w]], device=device)
220
+
221
  inputs = processor(images=pil, return_tensors="pt")
222
  inputs = {k: v.to(device) for k, v in inputs.items()}
223
 
 
226
 
227
  target_sizes = target_size.to(outputs["logits"].device)
228
  results = processor.post_process_object_detection(
229
+ outputs,
230
+ target_sizes=target_sizes,
231
+ threshold=score_threshold,
232
  )
 
233
 
234
+ id2label = getattr(model.config, "id2label", {}) or {}
235
  detections = []
236
+
237
  for result in results:
238
+ for score, label_id, box in zip(
239
+ result["scores"],
240
+ result["labels"],
241
+ result["boxes"]
242
+ ):
243
  sid = int(label_id.item())
244
  detections.append({
245
  "box": [float(x) for x in box.cpu().tolist()],
 
247
  "cls": sid,
248
  "label": id2label.get(sid, str(sid)),
249
  })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
 
251
+ return detections
252
 
253
 
254
  def main():
255
  args = parse_args()
256
  device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
257
+
258
  input_dir = Path(args.input)
259
  output_dir = Path(args.output)
260
  refs_dir = Path(args.refs)
 
265
  if not input_dir.is_dir():
266
  raise SystemExit(f"Input folder not found: {input_dir}")
267
 
268
+ paths = sorted(
269
+ p for p in input_dir.iterdir()
270
+ if p.suffix.lower() in IMAGE_EXTS
271
+ )
272
  if args.max_images is not None:
273
  paths = paths[: args.max_images]
274
+
275
  if not paths:
276
  raise SystemExit(f"No images in {input_dir}")
277
 
 
282
  dfine_model = DFineForObjectDetection.from_pretrained("ustc-community/dfine-medium-obj365")
283
  dfine_model = dfine_model.to(device).eval()
284
  person_car_ids = get_person_car_label_ids(dfine_model)
285
+ print(f" Person/car label IDs: {person_car_ids} ({time.perf_counter()-t0:.1f}s)")
286
 
287
+ # Load Jina-CLIP-v2 + build refs
288
  print("[*] Loading Jina-CLIP-v2 and building refs...")
289
  t0 = time.perf_counter()
290
  jina_encoder = JinaCLIPv2Encoder(device)
291
+ ref_labels, ref_embs = build_refs(
292
+ jina_encoder,
293
+ refs_dir,
294
+ TRUNCATE_DIM,
295
+ args.text_weight,
296
+ batch_size=16
297
+ )
298
+ print(f" Jina refs: {ref_labels} ({time.perf_counter()-t0:.1f}s)\n")
299
 
300
+ # Load Nomic vision + text, build refs (same as Jina: image + text prompts, text_weight 0.3)
301
  print("[*] Loading Nomic embed-vision + embed-text and building refs...")
302
  t0 = time.perf_counter()
303
  nomic_encoder = NomicVisionEncoder(device)
 
309
  text_encoder=nomic_text_encoder,
310
  text_weight=args.text_weight,
311
  )
312
+ print(f" Nomic refs: {ref_labels_nomic} ({time.perf_counter()-t0:.1f}s)\n")
313
 
314
+ # Separate output folders per model for visual comparison
315
  jina_crops_dir = output_dir / "jina_crops"
316
  nomic_crops_dir = output_dir / "nomic_crops"
317
  jina_crops_dir.mkdir(parents=True, exist_ok=True)
318
  nomic_crops_dir.mkdir(parents=True, exist_ok=True)
319
 
320
+ # CSV
321
  csv_path = output_dir / "results.csv"
322
  f = open(csv_path, "w", newline="")
323
  w = csv.writer(f)
324
  w.writerow([
325
+ "image",
326
+ "crop_filename",
327
+ "group_idx",
328
+ "crop_x1",
329
+ "crop_y1",
330
+ "crop_x2",
331
+ "crop_y2",
332
+ "bbox_x1",
333
+ "bbox_y1",
334
+ "bbox_x2",
335
+ "bbox_y2",
336
+ "dfine_label",
337
+ "dfine_conf",
338
+ "jina_prediction",
339
+ "jina_confidence",
340
+ "jina_status",
341
+ "nomic_prediction",
342
+ "nomic_confidence",
343
+ "nomic_status",
344
  ])
345
 
346
  for img_path in paths:
 
348
  img_w, img_h = pil.size
349
  group_dist = args.group_dist if args.group_dist is not None else 0.1 * max(img_h, img_w)
350
 
351
+ # 1) D-FINE: detect everything, keep all bboxes for the image
352
+ detections = run_dfine(
353
+ pil,
354
+ image_processor,
355
+ dfine_model,
356
+ device,
357
+ args.det_threshold
358
+ )
359
+
360
  person_car = [d for d in detections if d["cls"] in person_car_ids]
361
  if not person_car:
362
  continue
363
 
364
+ # 2) Group person/car detections (same as reference)
365
  grouped = group_detections(person_car, group_dist)
366
  grouped.sort(key=lambda x: x["conf"], reverse=True)
367
+ top_groups = grouped[:10] # limit groups per image
368
+
369
+ # 3) Collect all candidate crops (bboxes inside person/car groups)
370
+ # Each: (crop_box, crop_pil, d, gidx, crop_idx, x1, y1, x2, y2)
371
+ candidates = []
372
 
373
  for gidx, grp in enumerate(top_groups):
374
+ x1, y1, x2, y2 = grp["box"]
375
+ group_box = [x1, y1, x2, y2]
 
 
 
 
 
 
 
 
376
 
377
+ inside = [
378
+ d for d in detections
379
+ if box_center_inside(d["box"], group_box) and d["cls"] not in person_car_ids
380
+ ]
381
+ inside = deduplicate_by_iou(inside, iou_threshold=0.9)
382
+
383
+ for crop_idx, d in enumerate(inside):
384
+ bx1, by1, bx2, by2 = [float(x) for x in d["box"]]
385
+ obj_w, obj_h = bx2 - bx1, by2 - by1
386
+ if obj_w <= 0 or obj_h <= 0:
387
+ continue
388
+
389
+ pad_x = obj_w * 0.3
390
+ pad_y = obj_h * 0.3
391
+ bx1 = max(0, int(bx1 - pad_x))
392
+ by1 = max(0, int(by1 - pad_y))
393
+ bx2 = min(img_w, int(bx2 + pad_x))
394
+ by2 = min(img_h, int(by2 + pad_y))
395
+
396
+ if bx2 <= bx1 or by2 <= by1:
397
+ continue
398
+
399
+ if min(bx2 - bx1, by2 - by1) < args.min_side:
400
+ continue
401
+
402
+ expanded_box = [bx1, by1, bx2, by2]
403
+ candidates.append((expanded_box, d, gidx, crop_idx, x1, y1, x2, y2))
404
+
405
+ # 4) Dedup on EXPANDED boxes (before squarify), keep larger; then squarify only kept
406
def crop_area(box):
    """Area of an axis-aligned box given as [x1, y1, x2, y2]."""
    x1, y1, x2, y2 = box
    return (x2 - x1) * (y2 - y1)
408
+
409
+ candidates.sort(key=lambda c: -crop_area(c[0]))
410
+ kept = []
411
+
412
+ for c in candidates:
413
+ expanded_box = c[0]
414
+
415
def is_same_object(box_a, box_b):
    """True if the two boxes likely cover the same physical object:
    heavy overlap (IoU above the dedup threshold) or either box's
    center lying inside the other."""
    return (
        box_iou(box_a, box_b) >= args.crop_dedup_iou
        or box_center_inside(box_a, box_b)
        or box_center_inside(box_b, box_a)
    )
421
+
422
+ if not any(is_same_object(expanded_box, k[0]) for k in kept):
423
+ kept.append(c)
424
+
425
+ # 5) Optionally squarify, then run Jina and Nomic only on kept crops
426
+ for i, (expanded_box, d, gidx, crop_idx, x1, y1, x2, y2) in enumerate(kept):
427
+ if not args.no_squarify:
428
+ bx1, by1, bx2, by2 = squarify_crop_box(
429
+ expanded_box[0],
430
+ expanded_box[1],
431
+ expanded_box[2],
432
+ expanded_box[3],
433
+ img_w,
434
+ img_h
435
+ )
436
+ else:
437
+ bx1, by1, bx2, by2 = expanded_box[0], expanded_box[1], expanded_box[2], expanded_box[3]
438
+
439
+ crop_pil = pil.crop((bx1, by1, bx2, by2))
440
+ crop_name = f"{img_path.stem}_g{gidx}_{i}_{bx1}_{by1}_{bx2}_{by2}{img_path.suffix}"
441
+
442
+ q_jina = jina_encoder.encode_images([crop_pil], TRUNCATE_DIM)
443
+ result_jina = jina_classify(
444
+ q_jina,
445
+ ref_labels,
446
+ ref_embs,
447
+ args.conf_threshold,
448
+ args.gap_threshold
449
  )
450
 
451
+ if result_jina["prediction"] in ref_labels:
452
+ label_jina = result_jina["prediction"]
453
+ conf_jina = result_jina["confidence"]
454
+ else:
455
+ label_jina = f"unnamed (dfine: {d['label']})"
456
+ conf_jina = 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
457
 
458
+ ann_jina = draw_label_on_image(crop_pil, label_jina, conf_jina)
459
  ann_jina.save(jina_crops_dir / crop_name)
460
+
461
+ q_nomic = nomic_encoder.encode_images([crop_pil])
462
+ result_nomic = jina_classify(
463
+ q_nomic,
464
+ ref_labels_nomic,
465
+ ref_embs_nomic,
466
+ args.conf_threshold,
467
+ args.gap_threshold
468
+ )
469
+
470
+ if result_nomic["prediction"] in ref_labels_nomic:
471
+ label_nomic = result_nomic["prediction"]
472
+ conf_nomic = result_nomic["confidence"]
473
+ else:
474
+ label_nomic = f"unnamed (dfine: {d['label']})"
475
+ conf_nomic = 0.0
476
+
477
+ ann_nomic = draw_label_on_image(crop_pil, label_nomic, conf_nomic)
478
  ann_nomic.save(nomic_crops_dir / crop_name)
479
 
480
+ w.writerow([
481
+ img_path.name,
482
+ crop_name,
483
+ gidx,
484
+ x1,
485
+ y1,
486
+ x2,
487
+ y2,
488
+ bx1,
489
+ by1,
490
+ bx2,
491
+ by2,
492
+ d["label"],
493
+ f"{d['conf']:.4f}",
494
+ result_jina["prediction"],
495
+ f"{result_jina['confidence']:.4f}",
496
+ result_jina["status"],
497
+ result_nomic["prediction"],
498
+ f"{result_nomic['confidence']:.4f}",
499
+ result_nomic["status"],
500
+ ])
501
+
502
  f.close()
503
  print(f"[*] Wrote {csv_path}")
504
+ print(f"[*] Jina crops: {jina_crops_dir}")
505
  print(f"[*] Nomic crops: {nomic_crops_dir}")
506
 
507
 
508
  # -----------------------------------------------------------------------------
509
+ # Single-image runner for Gradio app: D-FINE first, then Jina or Nomic (user choice)
510
  # -----------------------------------------------------------------------------
511
+
512
  _APP_DFINE = None
513
  _APP_JINA = None
514
  _APP_NOMIC = None
 
529
  squarify=True,
530
  ):
531
  """
532
+ Run D-FINE on one image, then classify small-object crops with Jina or Nomic.
533
+
534
+ refs_dir: path to refs folder (str or Path).
535
+ encoder_choice: "jina" or "nomic".
536
+
537
+ Returns (annotated_pil, result_text) for display in app.
538
  """
539
+ import numpy as np
540
+ from PIL import Image
541
+
542
  global _APP_DFINE, _APP_JINA, _APP_NOMIC, _APP_REFS_JINA, _APP_REFS_NOMIC
543
 
544
  refs_dir = Path(refs_dir)
 
571
  grouped.sort(key=lambda x: x["conf"], reverse=True)
572
  top_groups = grouped[:10]
573
 
574
+ candidates = []
575
+
576
+ for gidx, grp in enumerate(top_groups):
577
+ x1, y1, x2, y2 = grp["box"]
578
+ group_box = [x1, y1, x2, y2]
579
+
580
+ inside = [
581
+ d for d in detections
582
+ if box_center_inside(d["box"], group_box) and d["cls"] not in person_car_ids
583
+ ]
584
+ inside = deduplicate_by_iou(inside, iou_threshold=0.9)
585
+
586
+ for crop_idx, d in enumerate(inside):
587
+ bx1, by1, bx2, by2 = [float(x) for x in d["box"]]
588
+ obj_w, obj_h = bx2 - bx1, by2 - by1
589
+ if obj_w <= 0 or obj_h <= 0:
590
+ continue
591
+
592
+ pad_x, pad_y = obj_w * 0.3, obj_h * 0.3
593
+ bx1 = max(0, int(bx1 - pad_x))
594
+ by1 = max(0, int(by1 - pad_y))
595
+ bx2 = min(img_w, int(bx2 + pad_x))
596
+ by2 = min(img_h, int(by2 + pad_y))
597
+
598
+ if bx2 <= bx1 or by2 <= by1:
599
+ continue
600
+
601
+ if min(bx2 - bx1, by2 - by1) < min_side:
602
+ continue
603
+
604
+ expanded_box = [bx1, by1, bx2, by2]
605
+ candidates.append((expanded_box, d, gidx, crop_idx))
606
+
607
def crop_area(box):
    """Area of an axis-aligned box given as [x1, y1, x2, y2]."""
    width = box[2] - box[0]
    height = box[3] - box[1]
    return width * height
609
+
610
+ candidates.sort(key=lambda c: -crop_area(c[0]))
611
+ kept = []
612
+
613
+ for c in candidates:
614
def is_same_object(box_a, box_b):
    """True if the two boxes likely cover the same physical object
    (IoU above the dedup threshold, or either center contained in
    the other box)."""
    if box_center_inside(box_a, box_b) or box_center_inside(box_b, box_a):
        return True
    return box_iou(box_a, box_b) >= crop_dedup_iou
620
+
621
+ if not any(is_same_object(c[0], k[0]) for k in kept):
622
+ kept.append(c)
623
+
624
+ if not kept:
625
+ if not candidates:
626
+ return np.array(pil), "No small-object crops: D-FINE did not detect any object (gun/phone/etc.) inside person/car areas, or all were below min size. Try a higher-resolution image."
627
+ return np.array(pil), "No small-object crops (after dedup)."
628
+
629
  # Load encoder + refs for chosen model
630
  if encoder_choice == "jina":
631
  if _APP_JINA is None or _APP_REFS_JINA != str(refs_dir):
 
633
  ref_labels, ref_embs = build_refs(jina_encoder, refs_dir, TRUNCATE_DIM, 0.3, batch_size=16)
634
  _APP_JINA = (jina_encoder, ref_labels, ref_embs)
635
  _APP_REFS_JINA = str(refs_dir)
636
+
637
+ jina_encoder, ref_labels, ref_embs = _APP_JINA
638
  else:
639
  if _APP_NOMIC is None or _APP_REFS_NOMIC != str(refs_dir):
640
  nomic_encoder = NomicVisionEncoder(device)
 
648
  )
649
  _APP_NOMIC = (nomic_encoder, ref_labels, ref_embs)
650
  _APP_REFS_NOMIC = str(refs_dir)
 
651
 
652
+ nomic_encoder, ref_labels, ref_embs = _APP_NOMIC
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
653
 
654
+ lines = []
655
+ out_img = pil.copy()
656
+
657
+ for i, (expanded_box, d, gidx, crop_idx) in enumerate(kept):
658
+ if squarify:
659
+ bx1, by1, bx2, by2 = squarify_crop_box(
660
+ expanded_box[0],
661
+ expanded_box[1],
662
+ expanded_box[2],
663
+ expanded_box[3],
664
+ img_w,
665
+ img_h
666
+ )
667
+ else:
668
+ bx1, by1, bx2, by2 = expanded_box[0], expanded_box[1], expanded_box[2], expanded_box[3]
669
 
670
+ crop_pil = pil.crop((bx1, by1, bx2, by2))
 
 
 
 
 
671
 
672
+ if encoder_choice == "jina":
673
+ q = jina_encoder.encode_images([crop_pil], TRUNCATE_DIM)
674
+ result = jina_classify(q, ref_labels, ref_embs, conf_threshold, gap_threshold)
675
+ else:
676
+ q = nomic_encoder.encode_images([crop_pil])
677
+ result = jina_classify(q, ref_labels, ref_embs, conf_threshold, gap_threshold)
 
678
 
679
+ pred = result["prediction"] if result["prediction"] in ref_labels else f"unknown ({d['label']})"
680
+ conf = result["confidence"]
681
 
682
+ lines.append(f"Crop {i+1}: {pred} ({conf:.2f})")
 
 
683
 
684
+ labeled = draw_label_on_image(crop_pil, pred, conf)
685
+ out_img.paste(labeled, (bx1, by1))
686
 
687
+ result_text = "\n".join(lines) if lines else "No crops"
688
+ return np.array(out_img), result_text
 
689
 
690
 
691
  if __name__ == "__main__":