# small_object_detection / jina_fewshot.py
# Orkhan Hasanli
# D-FINE: person/car crop galleries, known-object bboxes only
# be400de
"""
Few-shot object classification using jina-clip-v2 via ONNX Runtime.
Bypasses all PyTorch custom code / dtype issues on HF Spaces (T4).
Combines IMAGE embeddings from reference photos + TEXT embeddings
from class names. Dual threshold: confidence + gap between top-1 and top-2.
Usage:
python jina_fewshot.py \
--refs refs/ \
--input crops/ \
--output results/ \
--text-weight 0.3 \
--conf-threshold 0.75 \
--gap-threshold 0.05
refs/ folder structure (3-10 images per class recommended):
refs/
├── cigarette/
├── gun/
├── knife/
├── phone/
└── nothing/ (empty hands, random objects)
"""
import argparse
import csv
import json
import time
from pathlib import Path
import numpy as np
import onnxruntime as ort
from PIL import Image, ImageDraw, ImageFont
from huggingface_hub import hf_hub_download
from transformers import AutoImageProcessor, AutoTokenizer
# File extensions accepted as query/reference images (compared lowercase).
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tiff"}
# Default embedding dimension; model outputs are truncated to --dim (64-1024).
TRUNCATE_DIM = 1024
# ONNX model outputs: [text_unnorm, image_unnorm, text_norm, image_norm]
_TEXT_NORM_IDX = 2
_IMAGE_NORM_IDX = 3
def draw_label_on_image(img: Image.Image, label: str, confidence: float) -> Image.Image:
    """Draw the label in a bar outside and on top of the image (full width). Returns new image.

    Args:
        img: source image (converted to RGB; the input is not mutated).
        label: class name to render.
        confidence: score rendered after the label as "(0.00)".
    """
    img = img.convert("RGB")
    w, h = img.width, img.height
    text = f"{label} ({confidence:.2f})"
    margin = 8
    max_text_w = max(1, w - 2 * margin)
    font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
    try:
        font_size = max(10, min(h, w) // 12)
        font = ImageFont.truetype(font_path, size=font_size)
    except OSError:
        # Bitmap fallback font: fixed size, so the shrink loop below is skipped.
        font = ImageFont.load_default()
        font_size = None
    # Measure the text on a throwaway 1x1 canvas.
    dummy = Image.new("RGB", (1, 1))
    ddraw = ImageDraw.Draw(dummy)
    bbox = ddraw.textbbox((0, 0), text, font=font)
    tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1]
    if font_size is not None:
        # Shrink the TrueType font until the text fits the bar width (floor: 8pt).
        while tw > max_text_w and font_size > 8:
            font_size = max(8, font_size - 2)
            font = ImageFont.truetype(font_path, size=font_size)
            bbox = ddraw.textbbox((0, 0), text, font=font)
            tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1]
    bar_height = th + 2 * margin
    out = Image.new("RGB", (w, bar_height + h), color=(255, 255, 255))
    draw = ImageDraw.Draw(out)
    draw.rectangle([0, 0, w, bar_height], fill=(0, 0, 0))
    # BUGFIX: textbbox() measures from the glyph-box origin, which is usually
    # offset from (0, 0); drawing at (x, margin) without compensating bbox[0]/
    # bbox[1] let the text sit low/right and spill out of the bar. Also clamp
    # x to 0 so over-long text with the fallback font stays on-canvas.
    x = max(0, (w - tw) // 2 - bbox[0])
    y = margin - bbox[1]
    draw.text((x, y), text, fill=(255, 255, 255), font=font)
    out.paste(img, (0, bar_height))
    return out
def draw_bboxes_on_image(
    img: Image.Image,
    boxes: list[tuple[float, float, float, float, str, float]],
) -> Image.Image:
    """Draw bboxes and labels (label conf) on image. boxes: list of (x1, y1, x2, y2, label, conf)."""
    canvas = img.convert("RGB")
    painter = ImageDraw.Draw(canvas)
    width, height = canvas.width, canvas.height
    dejavu_bold = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
    try:
        font = ImageFont.truetype(dejavu_bold, size=max(10, min(height, width) // 20))
    except OSError:
        # DejaVu not installed — fall back to Pillow's built-in bitmap font.
        font = ImageFont.load_default()
    green = (0, 255, 0)
    for left, top, right, bottom, label, conf in boxes:
        left, top, right, bottom = int(left), int(top), int(right), int(bottom)
        painter.rectangle([left, top, right, bottom], outline=green, width=2)
        # Label just above the box, clamped so it never goes off the top edge.
        painter.text((left, max(0, top - 16)), f"{label} {conf:.2f}", fill=green, font=font)
    return canvas
# Text prompts per class; build_refs() averages each list into one text
# embedding and blends it with the image-reference embedding. Keys should
# match the refs/ subfolder names; unknown classes fall back to generic
# prompts (see build_refs).
CLASS_PROMPTS = {
    "knife": [
        "a knife",
        "a person holding a knife",
        "a sharp blade knife",
    ],
    "gun": [
        "a gun",
        "a pistol",
        "a handgun",
        "a person holding a gun",
        "a person holding a pistol",
        "a firearm weapon",
    ],
    "cigarette": [
        "a cigarette",
        "a person smoking a cigarette",
        "a lit cigarette in hand",
    ],
    "phone": [
        "a phone",
        "a person holding a smartphone",
        "a mobile phone cell phone",
    ],
    # Negative class: lets the classifier map "no object of interest" crops
    # to an explicit label instead of forcing a weapon/phone choice.
    "nothing": [
        "a person with empty hands",
        "a person standing with no objects",
        "empty hands no weapon",
    ],
}
def parse_args():
    """Build the CLI and return the parsed arguments namespace."""
    parser = argparse.ArgumentParser(description="Jina-CLIP-v2 few-shot classifier (ONNX)")
    parser.add_argument("--refs", required=True, help="Reference images folder")
    parser.add_argument("--input", required=True, help="Query crop images folder")
    parser.add_argument("--output", default="jinaclip_results", help="Output folder")
    parser.add_argument("--dim", type=int, default=TRUNCATE_DIM, help="Embedding dim (64-1024)")
    parser.add_argument(
        "--text-weight", type=float, default=0.3,
        help="Text embedding weight (0.0=image only, default 0.3)",
    )
    parser.add_argument(
        "--conf-threshold", type=float, default=0.75,
        help="Min confidence to accept prediction (default 0.75)",
    )
    parser.add_argument(
        "--gap-threshold", type=float, default=0.05,
        help="Min gap between top-1 and top-2 (default 0.05)",
    )
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument(
        "--save-refs", action="store_true",
        help="Save reference embeddings to .npy for fast reload",
    )
    return parser.parse_args()
def _download_onnx_model():
    """
    Download the ONNX model from HF Hub.
    Try fp32 (model.onnx + model.onnx_data) first.
    Both files must be in the same directory for ONNX Runtime to find the
    external data file.
    """
    print(" Downloading ONNX model files from jinaai/jina-clip-v2...")
    repo = "jinaai/jina-clip-v2"
    # hf_hub_download caches both files into the same snapshot directory,
    # which is required: model.onnx references model.onnx_data externally.
    onnx_path = hf_hub_download(repo_id=repo, filename="onnx/model.onnx")
    hf_hub_download(repo_id=repo, filename="onnx/model.onnx_data")
    print(f" Downloaded: {onnx_path}")
    print(" External data: model.onnx_data (same directory)")
    return onnx_path
class JinaCLIPv2Encoder:
    """
    ONNX Runtime based encoder for jina-clip-v2.
    Completely bypasses PyTorch — no dtype/NaN issues.

    The ONNX graph appears to take image AND text inputs in one pass, so
    single-modality calls feed a minimal dummy for the other modality
    (see _run_image / _run_text).
    """

    def __init__(self, device="cuda"):
        """Download the model, create the ORT session, and run sanity checks.

        Args:
            device: "cuda" prefers CUDAExecutionProvider when available;
                any other value (or no CUDA build) falls back to CPU.
        """
        self.device = device
        print("[*] Loading jina-clip-v2 (ONNX)...")
        t0 = time.perf_counter()
        # Download ONNX model (fp32 with external data)
        onnx_path = _download_onnx_model()
        # Pick providers: prefer CUDA if available
        available = ort.get_available_providers()
        if "CUDAExecutionProvider" in available and device == "cuda":
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]
        print(f" ONNX providers: {providers}")
        self.session = ort.InferenceSession(onnx_path, providers=providers)
        # Load tokenizer and image processor
        self.tokenizer = AutoTokenizer.from_pretrained(
            "jinaai/jina-clip-v2", trust_remote_code=True
        )
        self.image_processor = AutoImageProcessor.from_pretrained(
            "jinaai/jina-clip-v2", trust_remote_code=True
        )
        # Inspect ONNX model I/O
        self.input_names = [inp.name for inp in self.session.get_inputs()]
        self.output_names = [out.name for out in self.session.get_outputs()]
        print(f" ONNX inputs: {self.input_names}")
        print(f" ONNX outputs: {self.output_names}")
        # Map graph input names by substring so minor renames across model
        # revisions don't break the feed dict.
        self._pixel_name = None
        self._ids_name = None
        self._mask_name = None
        for name in self.input_names:
            nl = name.lower()
            if "pixel" in nl:
                self._pixel_name = name
            elif "input_id" in nl:
                self._ids_name = name
            elif "attention" in nl or "mask" in nl:
                self._mask_name = name
        print(f" Mapped: pixel={self._pixel_name}, ids={self._ids_name}, mask={self._mask_name}")
        # Sanity checks: a broken export typically yields NaN or near-zero norms.
        _dummy = Image.new("RGB", (512, 512), color=(255, 0, 0))
        _test = self.encode_images([_dummy], dim=64)
        _norm = float(np.linalg.norm(_test))
        _is_nan = bool(np.isnan(_norm))
        print(f" [SANITY] dummy image embed norm={_norm:.4f}, nan={_is_nan}")
        if _is_nan or _norm < 0.01:
            print(" [ERROR] ONNX vision encoder broken!")
        else:
            print(" [OK] ONNX vision encoder producing valid embeddings")
        _test_t = self.encode_texts(["a red square"], dim=64)
        _tn = float(np.linalg.norm(_test_t))
        print(f" [SANITY] dummy text embed norm={_tn:.4f}")
        elapsed = time.perf_counter() - t0
        print(f"[*] Loaded in {elapsed:.1f}s (ONNX, providers={providers})\n")

    def _run_image(self, pixel_values: np.ndarray) -> np.ndarray:
        """Run ONNX for images only. Returns normalized image embeddings."""
        bs = pixel_values.shape[0]
        # Dummy text input (minimal tokens) — the text outputs are discarded.
        dummy_ids = np.zeros((bs, 1), dtype=np.int64)
        dummy_mask = np.ones((bs, 1), dtype=np.int64)
        feeds = {}
        if self._pixel_name:
            feeds[self._pixel_name] = pixel_values.astype(np.float32)
        if self._ids_name:
            feeds[self._ids_name] = dummy_ids
        if self._mask_name:
            feeds[self._mask_name] = dummy_mask
        outputs = self.session.run(self.output_names, feeds)
        return outputs[_IMAGE_NORM_IDX]

    def _run_text(self, input_ids: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
        """Run ONNX for text only. Returns normalized text embeddings."""
        bs = input_ids.shape[0]
        # Dummy pixel values — the image outputs are discarded.
        # NOTE(review): this allocates bs*3*512*512 floats per call; fine for
        # the small prompt batches used here, costly for large text batches.
        dummy_pv = np.zeros((bs, 3, 512, 512), dtype=np.float32)
        feeds = {}
        if self._pixel_name:
            feeds[self._pixel_name] = dummy_pv
        if self._ids_name:
            feeds[self._ids_name] = input_ids.astype(np.int64)
        if self._mask_name:
            feeds[self._mask_name] = attention_mask.astype(np.int64)
        outputs = self.session.run(self.output_names, feeds)
        return outputs[_TEXT_NORM_IDX]

    def encode_images(self, images: list[Image.Image], dim: int = TRUNCATE_DIM) -> np.ndarray:
        """Encode PIL images to L2-normalized float32 embeddings truncated to `dim`."""
        rgb = [img.convert("RGB") for img in images]
        processed = self.image_processor(rgb, return_tensors="np")
        pv = processed["pixel_values"]
        # Some processors hand back torch-like tensors even with return_tensors="np".
        pixel_values = pv.numpy().astype(np.float32) if hasattr(pv, "numpy") else np.asarray(pv, dtype=np.float32)
        embs = self._run_image(pixel_values)
        if dim and dim < embs.shape[-1]:
            embs = embs[:, :dim]
        # Scrub NaN/Inf and renormalize (truncation changes the norm).
        embs = np.nan_to_num(embs, nan=0.0, posinf=0.0, neginf=0.0)
        norms = np.linalg.norm(embs, axis=-1, keepdims=True)
        norms = np.maximum(norms, 1e-12)
        return (embs / norms).astype(np.float32)

    def encode_texts(self, texts: list[str], dim: int = TRUNCATE_DIM) -> np.ndarray:
        """Encode strings to L2-normalized float32 embeddings truncated to `dim`."""
        tokens = self.tokenizer(
            texts, return_tensors="np", padding=True, truncation=True, max_length=512
        )
        input_ids = tokens["input_ids"].astype(np.int64)
        attention_mask = tokens["attention_mask"].astype(np.int64)
        embs = self._run_text(input_ids, attention_mask)
        if dim and dim < embs.shape[-1]:
            embs = embs[:, :dim]
        embs = np.nan_to_num(embs, nan=0.0, posinf=0.0, neginf=0.0)
        norms = np.linalg.norm(embs, axis=-1, keepdims=True)
        norms = np.maximum(norms, 1e-12)
        return (embs / norms).astype(np.float32)

    def encode_image_paths(self, paths: list[str], dim: int = TRUNCATE_DIM,
                           batch_size: int = 16) -> np.ndarray:
        """Encode image files in batches of `batch_size`.

        BUGFIX: Image.open() is lazy and keeps the file handle open; the
        original leaked one handle per image. convert("RGB") forces a full
        decode into a new, independent image, so the file can be closed
        immediately. Also returns an empty (0, dim) matrix for an empty
        path list instead of letting np.concatenate raise.
        """
        if not paths:
            return np.zeros((0, dim), dtype=np.float32)
        all_embs = []
        for i in range(0, len(paths), batch_size):
            batch = []
            for p in paths[i:i + batch_size]:
                with Image.open(p) as im:
                    batch.append(im.convert("RGB"))
            all_embs.append(self.encode_images(batch, dim))
        return np.concatenate(all_embs, axis=0)
def build_refs(encoder: JinaCLIPv2Encoder, refs_dir: Path,
               dim: int, text_weight: float, batch_size: int):
    """Build one prototype embedding per class from reference images + text prompts.

    Each class prototype is the renormalized blend
    (1 - text_weight) * mean(image embeddings) + text_weight * mean(text embeddings).
    Subfolders of refs_dir are class names; empty subfolders are skipped.

    Returns:
        (labels, embeddings): list of class names and a stacked (C, dim) array.

    Raises:
        ValueError: if refs_dir has no subfolders, or no subfolder contains images.
    """
    class_dirs = sorted(d for d in refs_dir.iterdir() if d.is_dir())
    if not class_dirs:
        raise ValueError(f"No subfolders in {refs_dir}")
    labels, embeddings = [], []
    print(f" Text weight: {text_weight:.1f} | Image weight: {1 - text_weight:.1f}\n")
    for d in class_dirs:
        name = d.name
        paths = sorted(str(p) for p in d.iterdir() if p.suffix.lower() in IMAGE_EXTS)
        if not paths:
            continue
        img_embs = encoder.encode_image_paths(paths, dim, batch_size)
        img_avg = np.nan_to_num(img_embs.mean(axis=0), nan=0.0, posinf=0.0, neginf=0.0)
        # Classes without curated prompts fall back to generic templates.
        prompts = CLASS_PROMPTS.get(name, [f"a {name}", f"a person holding a {name}"])
        text_embs = encoder.encode_texts(prompts, dim)
        text_avg = np.nan_to_num(text_embs.mean(axis=0), nan=0.0, posinf=0.0, neginf=0.0)
        combined = (1.0 - text_weight) * img_avg + text_weight * text_avg
        combined = np.nan_to_num(combined, nan=0.0, posinf=0.0, neginf=0.0)
        combined = combined / (np.linalg.norm(combined) + 1e-12)
        labels.append(name)
        embeddings.append(combined)
        # Diagnostic only: cosine similarity between image and text prototypes.
        img_norm = img_avg / (np.linalg.norm(img_avg) + 1e-12)
        text_norm = text_avg / (np.linalg.norm(text_avg) + 1e-12)
        sim = float(np.nan_to_num(np.dot(img_norm, text_norm), nan=0.0))
        print(f" {name:<14}: {len(paths)} imgs + {len(prompts)} prompts | "
              f"img-text sim: {sim:.4f}")
    if not embeddings:
        # BUGFIX: np.stack([]) raises an opaque error when every class folder
        # is empty; fail with an actionable message instead.
        raise ValueError(f"No reference images found in any subfolder of {refs_dir}")
    return labels, np.stack(embeddings)
def classify(query_emb: np.ndarray, ref_labels: list[str], ref_embs: np.ndarray,
             conf_threshold: float, gap_threshold: float) -> dict:
    """Score one query embedding against the class prototypes.

    Accepts the top-1 class only when its cosine similarity clears
    conf_threshold AND leads the runner-up by at least gap_threshold;
    otherwise the prediction is "unknown" with a rejection reason string.
    """
    scores = np.nan_to_num(
        (query_emb @ ref_embs.T).squeeze(0).astype(np.float64),
        nan=0.0, posinf=0.0, neginf=0.0,
    )
    order = np.argsort(scores)[::-1]
    top, runner_up = order[0], order[1]
    conf = float(scores[top])
    gap = float(scores[top] - scores[runner_up])

    # Collect rejection reasons; an empty list means the prediction is accepted.
    reasons = []
    if conf < conf_threshold:
        reasons.append(f"conf {conf:.4f} < {conf_threshold}")
    if gap < gap_threshold:
        reasons.append(f"gap {gap:.4f} < {gap_threshold}")

    if reasons:
        prediction = "unknown"
        status = "rejected: " + ", ".join(reasons)
    else:
        prediction = ref_labels[top]
        status = "accepted"

    return {
        "prediction": prediction,
        "raw_prediction": ref_labels[top],
        "confidence": conf,
        "gap": gap,
        "second_best": ref_labels[runner_up],
        "second_conf": float(scores[runner_up]),
        "status": status,
        "all_sims": {lbl: float(scores[j]) for j, lbl in enumerate(ref_labels)},
    }
def main():
    """CLI entry point: build references, classify every crop, write a CSV
    plus one annotated copy of each query image, then print a summary."""
    args = parse_args()
    input_dir, output_dir = Path(args.input), Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)
    paths = sorted(p for p in input_dir.iterdir() if p.suffix.lower() in IMAGE_EXTS)
    if not paths:
        print(f"[!] No images in {input_dir}")
        return
    print(f"[*] {len(paths)} query images")
    print(f"[*] Conf threshold: {args.conf_threshold} | Gap threshold: {args.gap_threshold}\n")
    encoder = JinaCLIPv2Encoder("cuda")
    print("[*] Building references...")
    ref_labels, ref_embs = build_refs(
        encoder, Path(args.refs), args.dim, args.text_weight, args.batch_size
    )
    print(f"\n[*] {len(ref_labels)} classes: {ref_labels}\n")
    if args.save_refs:
        np.save(output_dir / "ref_embeddings.npy", ref_embs)
        with open(output_dir / "ref_labels.json", "w") as jf:
            json.dump(ref_labels, jf)
        print(f"[*] Saved refs to {output_dir}\n")
    csv_path = output_dir / "classifications.csv"
    times = []
    counts = {"unknown": 0}
    for l in ref_labels:
        counts[l] = 0
    accepted, rejected = 0, 0
    hdr = " ".join(f"{l:>10}" for l in ref_labels)
    print(f"{'Image':<30} {'Result':<10} {'Conf':>6} {'Gap':>6} {hdr} {'Status'}")
    print("=" * (30 + 10 + 14 + len(hdr) + 40))
    # BUGFIX: the CSV was opened with a bare open()/close() pair, leaking the
    # handle (and buffered rows) if any image failed mid-loop; use a context
    # manager so it is flushed and closed on every exit path.
    with open(csv_path, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow(["image", "prediction", "raw_prediction", "confidence", "gap",
                    "second_best", "second_conf", "status"] +
                   [f"sim_{l}" for l in ref_labels] + ["time_ms"])
        for p in paths:
            t0 = time.perf_counter()
            # BUGFIX: Image.open is lazy; close each query's file handle
            # once the encode + annotation are done.
            with Image.open(p) as img:
                q = encoder.encode_images([img], args.dim)
                ms = (time.perf_counter() - t0) * 1000
                times.append(ms)
                result = classify(q, ref_labels, ref_embs, args.conf_threshold, args.gap_threshold)
                counts[result["prediction"]] += 1
                if result["prediction"] != "unknown":
                    accepted += 1
                else:
                    rejected += 1
                annotated = draw_label_on_image(img, result["prediction"], result["confidence"])
            out_path = output_dir / p.name
            annotated.save(out_path)
            sim_str = " ".join(f"{result['all_sims'][l]:>10.4f}" for l in ref_labels)
            print(f"{p.name:<30} {result['prediction']:<10} "
                  f"{result['confidence']:>6.4f} {result['gap']:>6.4f} "
                  f"{sim_str} {result['status']}")
            w.writerow([
                p.name, result["prediction"], result["raw_prediction"],
                f"{result['confidence']:.4f}", f"{result['gap']:.4f}",
                result["second_best"], f"{result['second_conf']:.4f}",
                result["status"],
            ] + [f"{result['all_sims'][l]:.4f}" for l in ref_labels] +
                [f"{ms:.1f}"])
    n = len(times)
    total = sum(times)
    print(f"\n{'='*70}")
    print("SUMMARY")
    print(f"{'='*70}")
    print(" Model : jina-clip-v2 (ONNX Runtime, fp32)")
    print(f" Embed dim : {args.dim}")
    print(f" Text weight : {args.text_weight}")
    print(f" Conf threshold : {args.conf_threshold}")
    print(f" Gap threshold : {args.gap_threshold}")
    print(f" Images : {n}")
    if n:
        print(f" Accepted : {accepted} ({accepted/n*100:.1f}%)")
        print(f" Rejected : {rejected} ({rejected/n*100:.1f}%)")
    print(" ──────────────────────────────────────────")
    for l in ref_labels + ["unknown"]:
        c = counts.get(l, 0)
        pct = (c / n * 100) if n else 0
        print(f" {l:<14}: {c:>4} ({pct:.1f}%)")
    print(" ──────────────────────────────────────────")
    if n:
        print(f" Total : {total:.0f}ms ({total/1000:.2f}s)")
        print(f" Avg/image : {total/n:.1f}ms")
        print(f" Throughput : {n/(total/1000):.1f} img/s")
    print(f" CSV : {csv_path}")
    print(f" Annotated imgs : {output_dir}")
    print(f"{'='*70}")
# Script entry point: run the full classify-and-report pipeline.
if __name__ == "__main__":
    main()