Spaces:
Sleeping
Sleeping
"""
Few-shot object classification using jina-clip-v2 via ONNX Runtime.
Bypasses all PyTorch custom code / dtype issues on HF Spaces (T4).

Combines IMAGE embeddings from reference photos + TEXT embeddings
from class names. Dual threshold: confidence + gap between top-1 and top-2.

Usage:
    python jina_fewshot.py \
        --refs refs/ \
        --input crops/ \
        --output results/ \
        --text-weight 0.3 \
        --conf-threshold 0.75 \
        --gap-threshold 0.05

refs/ folder structure (3-10 images per class recommended):
    refs/
    ├── cigarette/
    ├── gun/
    ├── knife/
    ├── phone/
    └── nothing/    (empty hands, random objects)
"""
| import argparse | |
| import csv | |
| import json | |
| import time | |
| from pathlib import Path | |
| import numpy as np | |
| import onnxruntime as ort | |
| from PIL import Image, ImageDraw, ImageFont | |
| from huggingface_hub import hf_hub_download | |
| from transformers import AutoImageProcessor, AutoTokenizer | |
# File suffixes recognized as images when scanning reference/query folders.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tiff"}
# Default embedding dimensionality; jina-clip-v2 supports Matryoshka-style
# truncation, so embeddings may be cut to a shorter prefix (see --dim).
TRUNCATE_DIM = 1024
# ONNX model outputs: [text_unnorm, image_unnorm, text_norm, image_norm]
_TEXT_NORM_IDX = 2
_IMAGE_NORM_IDX = 3
def draw_label_on_image(img: Image.Image, label: str, confidence: float) -> Image.Image:
    """Draw the label in a bar outside and on top of the image (full width). Returns new image."""
    img = img.convert("RGB")
    w, h = img.width, img.height
    caption = f"{label} ({confidence:.2f})"
    pad = 8
    usable_w = max(1, w - 2 * pad)
    dejavu = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
    size = max(10, min(h, w) // 12)
    try:
        font = ImageFont.truetype(dejavu, size=size)
    except OSError:
        # Font file unavailable: use the bitmap fallback, which cannot be resized.
        font = ImageFont.load_default()
        size = None

    # Measure text on a throwaway 1x1 canvas.
    measurer = ImageDraw.Draw(Image.new("RGB", (1, 1)))

    def _extent(f):
        left, top, right, bottom = measurer.textbbox((0, 0), caption, font=f)
        return right - left, bottom - top

    tw, th = _extent(font)
    if size is not None:
        # Shrink the font until the caption fits the bar width (floor at 8pt).
        while tw > usable_w and size > 8:
            size = max(8, size - 2)
            font = ImageFont.truetype(dejavu, size=size)
            tw, th = _extent(font)

    bar_h = th + 2 * pad
    canvas = Image.new("RGB", (w, bar_h + h), color=(255, 255, 255))
    painter = ImageDraw.Draw(canvas)
    painter.rectangle([0, 0, w, bar_h], fill=(0, 0, 0))
    painter.text(((w - tw) // 2, pad), caption, fill=(255, 255, 255), font=font)
    canvas.paste(img, (0, bar_h))
    return canvas
def draw_bboxes_on_image(
    img: Image.Image,
    boxes: list[tuple[float, float, float, float, str, float]],
) -> Image.Image:
    """Draw bboxes and labels (label conf) on image. boxes: list of (x1, y1, x2, y2, label, conf)."""
    out = img.convert("RGB")
    painter = ImageDraw.Draw(out)
    w, h = out.width, out.height
    try:
        font = ImageFont.truetype(
            "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
            size=max(10, min(h, w) // 20),
        )
    except OSError:
        # Font file unavailable: fall back to Pillow's built-in bitmap font.
        font = ImageFont.load_default()
    for box in boxes:
        label, conf = box[4], box[5]
        x1, y1, x2, y2 = (int(v) for v in box[:4])
        painter.rectangle([x1, y1, x2, y2], outline=(0, 255, 0), width=2)
        # Place the caption just above the box, clamped to the top edge.
        painter.text((x1, max(0, y1 - 16)), f"{label} {conf:.2f}", fill=(0, 255, 0), font=font)
    return out
# Curated text prompts per class. build_refs() averages each class's prompt
# embeddings and blends them with the image prototype (weight = --text-weight).
# Classes not listed here fall back to generic "a {name}" prompts in build_refs.
CLASS_PROMPTS = {
    "knife": [
        "a knife",
        "a person holding a knife",
        "a sharp blade knife",
    ],
    "gun": [
        "a gun",
        "a pistol",
        "a handgun",
        "a person holding a gun",
        "a person holding a pistol",
        "a firearm weapon",
    ],
    "cigarette": [
        "a cigarette",
        "a person smoking a cigarette",
        "a lit cigarette in hand",
    ],
    "phone": [
        "a phone",
        "a person holding a smartphone",
        "a mobile phone cell phone",
    ],
    # Negative class: teaches the classifier what "no object of interest" looks like.
    "nothing": [
        "a person with empty hands",
        "a person standing with no objects",
        "empty hands no weapon",
    ],
}
def parse_args():
    """Parse command-line options for the few-shot classifier."""
    parser = argparse.ArgumentParser(description="Jina-CLIP-v2 few-shot classifier (ONNX)")
    parser.add_argument("--refs", required=True, help="Reference images folder")
    parser.add_argument("--input", required=True, help="Query crop images folder")
    parser.add_argument("--output", default="jinaclip_results", help="Output folder")
    parser.add_argument("--dim", type=int, default=TRUNCATE_DIM, help="Embedding dim (64-1024)")
    parser.add_argument(
        "--text-weight", type=float, default=0.3,
        help="Text embedding weight (0.0=image only, default 0.3)",
    )
    parser.add_argument(
        "--conf-threshold", type=float, default=0.75,
        help="Min confidence to accept prediction (default 0.75)",
    )
    parser.add_argument(
        "--gap-threshold", type=float, default=0.05,
        help="Min gap between top-1 and top-2 (default 0.05)",
    )
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument(
        "--save-refs", action="store_true",
        help="Save reference embeddings to .npy for fast reload",
    )
    return parser.parse_args()
def _download_onnx_model():
    """
    Download the ONNX model from HF Hub.

    Fetches the fp32 export (model.onnx + model.onnx_data). Both files must
    end up in the same directory for ONNX Runtime to resolve the external
    data file.
    """
    print(" Downloading ONNX model files from jinaai/jina-clip-v2...")
    repo = "jinaai/jina-clip-v2"
    # hf_hub_download places both files in the same snapshot directory.
    onnx_path = hf_hub_download(repo_id=repo, filename="onnx/model.onnx")
    # External weights file: MUST sit next to model.onnx.
    hf_hub_download(repo_id=repo, filename="onnx/model.onnx_data")
    print(f" Downloaded: {onnx_path}")
    print(" External data: model.onnx_data (same directory)")
    return onnx_path
class JinaCLIPv2Encoder:
    """
    ONNX Runtime based encoder for jina-clip-v2.

    Runs the exported combined text+image graph directly, completely
    bypassing PyTorch — no dtype/NaN issues.
    """

    def __init__(self, device="cuda"):
        """Download the model, build the ORT session, and run sanity checks.

        Args:
            device: "cuda" prefers the CUDA execution provider when ONNX
                Runtime reports it available; otherwise falls back to CPU.
        """
        self.device = device
        print("[*] Loading jina-clip-v2 (ONNX)...")
        t0 = time.perf_counter()
        # Download ONNX model (fp32 with external data)
        onnx_path = _download_onnx_model()
        # Pick providers: prefer CUDA if available
        available = ort.get_available_providers()
        if "CUDAExecutionProvider" in available and device == "cuda":
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]
        print(f" ONNX providers: {providers}")
        self.session = ort.InferenceSession(onnx_path, providers=providers)
        # Tokenizer / image processor are used only for preprocessing;
        # no torch model is instantiated.
        self.tokenizer = AutoTokenizer.from_pretrained(
            "jinaai/jina-clip-v2", trust_remote_code=True
        )
        self.image_processor = AutoImageProcessor.from_pretrained(
            "jinaai/jina-clip-v2", trust_remote_code=True
        )
        # Inspect ONNX model I/O
        self.input_names = [inp.name for inp in self.session.get_inputs()]
        self.output_names = [out.name for out in self.session.get_outputs()]
        print(f" ONNX inputs: {self.input_names}")
        print(f" ONNX outputs: {self.output_names}")
        # Map graph input names heuristically — the export may rename them,
        # so match on substrings rather than exact names.
        self._pixel_name = None
        self._ids_name = None
        self._mask_name = None
        for name in self.input_names:
            nl = name.lower()
            if "pixel" in nl:
                self._pixel_name = name
            elif "input_id" in nl:
                self._ids_name = name
            elif "attention" in nl or "mask" in nl:
                self._mask_name = name
        print(f" Mapped: pixel={self._pixel_name}, ids={self._ids_name}, mask={self._mask_name}")
        # Sanity checks: a broken export typically yields NaN or near-zero norms.
        _dummy = Image.new("RGB", (512, 512), color=(255, 0, 0))
        _test = self.encode_images([_dummy], dim=64)
        _norm = float(np.linalg.norm(_test))
        _is_nan = bool(np.isnan(_norm))
        print(f" [SANITY] dummy image embed norm={_norm:.4f}, nan={_is_nan}")
        if _is_nan or _norm < 0.01:
            print(" [ERROR] ONNX vision encoder broken!")
        else:
            print(" [OK] ONNX vision encoder producing valid embeddings")
        _test_t = self.encode_texts(["a red square"], dim=64)
        _tn = float(np.linalg.norm(_test_t))
        print(f" [SANITY] dummy text embed norm={_tn:.4f}")
        elapsed = time.perf_counter() - t0
        print(f"[*] Loaded in {elapsed:.1f}s (ONNX, providers={providers})\n")

    def _run_image(self, pixel_values: np.ndarray) -> np.ndarray:
        """Run ONNX for images only. Returns normalized image embeddings."""
        bs = pixel_values.shape[0]
        # The combined graph also requires text inputs; feed minimal
        # 1-token dummies so the batch dimensions line up.
        dummy_ids = np.zeros((bs, 1), dtype=np.int64)
        dummy_mask = np.ones((bs, 1), dtype=np.int64)
        feeds = {}
        if self._pixel_name:
            feeds[self._pixel_name] = pixel_values.astype(np.float32)
        if self._ids_name:
            feeds[self._ids_name] = dummy_ids
        if self._mask_name:
            feeds[self._mask_name] = dummy_mask
        outputs = self.session.run(self.output_names, feeds)
        return outputs[_IMAGE_NORM_IDX]

    def _run_text(self, input_ids: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
        """Run ONNX for text only. Returns normalized text embeddings."""
        bs = input_ids.shape[0]
        # The combined graph also requires pixel inputs; feed zero-filled
        # 512x512 images whose (discarded) vision output we ignore.
        dummy_pv = np.zeros((bs, 3, 512, 512), dtype=np.float32)
        feeds = {}
        if self._pixel_name:
            feeds[self._pixel_name] = dummy_pv
        if self._ids_name:
            feeds[self._ids_name] = input_ids.astype(np.int64)
        if self._mask_name:
            feeds[self._mask_name] = attention_mask.astype(np.int64)
        outputs = self.session.run(self.output_names, feeds)
        return outputs[_TEXT_NORM_IDX]

    def encode_images(self, images: list[Image.Image], dim: int = TRUNCATE_DIM) -> np.ndarray:
        """Encode PIL images to unit-norm float32 embeddings of width `dim`."""
        rgb = [img.convert("RGB") for img in images]
        processed = self.image_processor(rgb, return_tensors="np")
        pv = processed["pixel_values"]
        # Some processor versions return torch-like tensors even with
        # return_tensors="np"; normalize to a float32 ndarray either way.
        pixel_values = pv.numpy().astype(np.float32) if hasattr(pv, "numpy") else np.asarray(pv, dtype=np.float32)
        embs = self._run_image(pixel_values)
        if dim and dim < embs.shape[-1]:
            # Matryoshka truncation: keep the leading `dim` components.
            embs = embs[:, :dim]
        embs = np.nan_to_num(embs, nan=0.0, posinf=0.0, neginf=0.0)
        # Re-normalize after truncation/NaN-scrubbing.
        norms = np.linalg.norm(embs, axis=-1, keepdims=True)
        norms = np.maximum(norms, 1e-12)
        return (embs / norms).astype(np.float32)

    def encode_texts(self, texts: list[str], dim: int = TRUNCATE_DIM) -> np.ndarray:
        """Encode text strings to unit-norm float32 embeddings of width `dim`."""
        tokens = self.tokenizer(
            texts, return_tensors="np", padding=True, truncation=True, max_length=512
        )
        input_ids = tokens["input_ids"].astype(np.int64)
        attention_mask = tokens["attention_mask"].astype(np.int64)
        embs = self._run_text(input_ids, attention_mask)
        if dim and dim < embs.shape[-1]:
            embs = embs[:, :dim]
        embs = np.nan_to_num(embs, nan=0.0, posinf=0.0, neginf=0.0)
        norms = np.linalg.norm(embs, axis=-1, keepdims=True)
        norms = np.maximum(norms, 1e-12)
        return (embs / norms).astype(np.float32)

    def encode_image_paths(self, paths: list[str], dim: int = TRUNCATE_DIM,
                           batch_size: int = 16) -> np.ndarray:
        """Encode images loaded from disk, in batches.

        Returns an (len(paths), dim)-shaped float32 array; an empty `paths`
        list yields a (0, dim) array instead of crashing in np.concatenate.
        """
        if not paths:
            return np.zeros((0, dim), dtype=np.float32)
        all_embs = []
        for i in range(0, len(paths), batch_size):
            batch = []
            for p in paths[i:i + batch_size]:
                # Use a context manager so the file handle is closed promptly
                # (the original leaked one open handle per image).
                with Image.open(p) as im:
                    batch.append(im.convert("RGB"))
            all_embs.append(self.encode_images(batch, dim))
        return np.concatenate(all_embs, axis=0)
def build_refs(encoder: JinaCLIPv2Encoder, refs_dir: Path,
               dim: int, text_weight: float, batch_size: int):
    """Build one prototype embedding per class subfolder of `refs_dir`.

    Each prototype blends the mean image embedding with the mean prompt
    embedding (weight `text_weight`), then re-normalizes to unit length.

    Returns:
        (labels, embeddings) — embeddings has shape (num_classes, dim).

    Raises:
        ValueError: if `refs_dir` has no subfolders, or no subfolder
            contains any image (previously np.stack([]) raised an
            opaque error in that case).
    """
    class_dirs = sorted(d for d in refs_dir.iterdir() if d.is_dir())
    if not class_dirs:
        raise ValueError(f"No subfolders in {refs_dir}")
    labels, embeddings = [], []
    print(f" Text weight: {text_weight:.1f} | Image weight: {1 - text_weight:.1f}\n")
    for d in class_dirs:
        name = d.name
        paths = sorted(str(p) for p in d.iterdir() if p.suffix.lower() in IMAGE_EXTS)
        if not paths:
            # Skip class folders with no images.
            continue
        img_embs = encoder.encode_image_paths(paths, dim, batch_size)
        img_avg = np.nan_to_num(img_embs.mean(axis=0), nan=0.0, posinf=0.0, neginf=0.0)
        # Fall back to generic prompts for classes without curated ones.
        prompts = CLASS_PROMPTS.get(name, [f"a {name}", f"a person holding a {name}"])
        text_embs = encoder.encode_texts(prompts, dim)
        text_avg = np.nan_to_num(text_embs.mean(axis=0), nan=0.0, posinf=0.0, neginf=0.0)
        combined = (1.0 - text_weight) * img_avg + text_weight * text_avg
        combined = np.nan_to_num(combined, nan=0.0, posinf=0.0, neginf=0.0)
        combined = combined / (np.linalg.norm(combined) + 1e-12)
        labels.append(name)
        embeddings.append(combined)
        # Report image/text prototype agreement as a quick quality signal.
        img_norm = img_avg / (np.linalg.norm(img_avg) + 1e-12)
        text_norm = text_avg / (np.linalg.norm(text_avg) + 1e-12)
        sim = float(np.nan_to_num(np.dot(img_norm, text_norm), nan=0.0))
        print(f" {name:<14}: {len(paths)} imgs + {len(prompts)} prompts | "
              f"img-text sim: {sim:.4f}")
    if not embeddings:
        # Explicit failure beats np.stack's cryptic "need at least one array".
        raise ValueError(f"No reference images found in any subfolder of {refs_dir}")
    return labels, np.stack(embeddings)
def classify(query_emb: np.ndarray, ref_labels: list[str], ref_embs: np.ndarray,
             conf_threshold: float, gap_threshold: float) -> dict:
    """Classify one query embedding against the reference prototypes.

    Args:
        query_emb: (1, dim) unit-norm query embedding.
        ref_labels: class names, aligned with rows of `ref_embs`.
        ref_embs: (num_classes, dim) unit-norm prototype matrix.
        conf_threshold: minimum top-1 cosine similarity to accept.
        gap_threshold: minimum margin between top-1 and top-2 to accept.

    Returns:
        dict with prediction ("unknown" when rejected), raw_prediction,
        confidence, gap, runner-up info, status, and all similarities.
    """
    sims = (query_emb @ ref_embs.T).squeeze(0)
    # Scrub NaN/inf so thresholding and argsort behave deterministically.
    sims = np.nan_to_num(sims.astype(np.float64), nan=0.0, posinf=0.0, neginf=0.0)
    order = np.argsort(sims)[::-1]
    best_idx = int(order[0])
    conf = float(sims[best_idx])
    if len(ref_labels) > 1:
        second_idx = int(order[1])
        second_label = ref_labels[second_idx]
        second_conf = float(sims[second_idx])
        gap = conf - second_conf
    else:
        # Single reference class: no runner-up exists, so the gap test can
        # never reject. (Previously this path raised IndexError.)
        second_label = ""
        second_conf = float("-inf")
        gap = float("inf")
    conf_ok = conf >= conf_threshold
    gap_ok = gap >= gap_threshold
    if conf_ok and gap_ok:
        prediction = ref_labels[best_idx]
        status = "accepted"
    else:
        prediction = "unknown"
        reasons = []
        if not conf_ok:
            reasons.append(f"conf {conf:.4f} < {conf_threshold}")
        if not gap_ok:
            reasons.append(f"gap {gap:.4f} < {gap_threshold}")
        status = "rejected: " + ", ".join(reasons)
    return {
        "prediction": prediction,
        "raw_prediction": ref_labels[best_idx],
        "confidence": conf,
        "gap": gap,
        "second_best": second_label,
        "second_conf": second_conf,
        "status": status,
        "all_sims": {ref_labels[j]: float(sims[j]) for j in range(len(ref_labels))},
    }
def main():
    """CLI entry point: build reference prototypes, classify every query crop,
    write an annotated copy of each image plus a CSV of all scores, and print
    a run summary.
    """
    args = parse_args()
    input_dir, output_dir = Path(args.input), Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)
    paths = sorted(p for p in input_dir.iterdir() if p.suffix.lower() in IMAGE_EXTS)
    if not paths:
        print(f"[!] No images in {input_dir}")
        return
    print(f"[*] {len(paths)} query images")
    print(f"[*] Conf threshold: {args.conf_threshold} | Gap threshold: {args.gap_threshold}\n")
    encoder = JinaCLIPv2Encoder("cuda")
    print("[*] Building references...")
    ref_labels, ref_embs = build_refs(
        encoder, Path(args.refs), args.dim, args.text_weight, args.batch_size
    )
    print(f"\n[*] {len(ref_labels)} classes: {ref_labels}\n")
    if args.save_refs:
        np.save(output_dir / "ref_embeddings.npy", ref_embs)
        with open(output_dir / "ref_labels.json", "w") as jf:
            json.dump(ref_labels, jf)
        print(f"[*] Saved refs to {output_dir}\n")
    csv_path = output_dir / "classifications.csv"
    times = []
    counts = {"unknown": 0}
    for l in ref_labels:
        counts[l] = 0
    accepted, rejected = 0, 0
    hdr = " ".join(f"{l:>10}" for l in ref_labels)
    print(f"{'Image':<30} {'Result':<10} {'Conf':>6} {'Gap':>6} {hdr} {'Status'}")
    print("=" * (30 + 10 + 14 + len(hdr) + 40))
    # Context manager guarantees the CSV is flushed and closed even if an
    # exception escapes the loop (the original leaked the handle).
    with open(csv_path, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow(["image", "prediction", "raw_prediction", "confidence", "gap",
                    "second_best", "second_conf", "status"] +
                   [f"sim_{l}" for l in ref_labels] + ["time_ms"])
        for p in paths:
            t0 = time.perf_counter()
            # Load eagerly inside `with` so the file handle is released
            # immediately instead of lingering until GC.
            with Image.open(p) as im:
                img = im.convert("RGB")
            q = encoder.encode_images([img], args.dim)
            ms = (time.perf_counter() - t0) * 1000
            times.append(ms)
            result = classify(q, ref_labels, ref_embs, args.conf_threshold, args.gap_threshold)
            counts[result["prediction"]] += 1
            if result["prediction"] != "unknown":
                accepted += 1
            else:
                rejected += 1
            annotated = draw_label_on_image(img, result["prediction"], result["confidence"])
            annotated.save(output_dir / p.name)
            sim_str = " ".join(f"{result['all_sims'][l]:>10.4f}" for l in ref_labels)
            print(f"{p.name:<30} {result['prediction']:<10} "
                  f"{result['confidence']:>6.4f} {result['gap']:>6.4f} "
                  f"{sim_str} {result['status']}")
            w.writerow([
                p.name, result["prediction"], result["raw_prediction"],
                f"{result['confidence']:.4f}", f"{result['gap']:.4f}",
                result["second_best"], f"{result['second_conf']:.4f}",
                result["status"],
            ] + [f"{result['all_sims'][l]:.4f}" for l in ref_labels] +
                [f"{ms:.1f}"])
    n = len(times)
    total = sum(times)
    # Plain ASCII separators (the originals were mojibake-damaged characters).
    rule = "  " + "-" * 42
    print(f"\n{'='*70}")
    print("SUMMARY")
    print(f"{'='*70}")
    print(" Model : jina-clip-v2 (ONNX Runtime, fp32)")
    print(f" Embed dim : {args.dim}")
    print(f" Text weight : {args.text_weight}")
    print(f" Conf threshold : {args.conf_threshold}")
    print(f" Gap threshold : {args.gap_threshold}")
    print(f" Images : {n}")
    if n:
        print(f" Accepted : {accepted} ({accepted/n*100:.1f}%)")
        print(f" Rejected : {rejected} ({rejected/n*100:.1f}%)")
    print(rule)
    for l in ref_labels + ["unknown"]:
        c = counts.get(l, 0)
        pct = (c / n * 100) if n else 0
        print(f" {l:<14}: {c:>4} ({pct:.1f}%)")
    print(rule)
    if n:
        print(f" Total : {total:.0f}ms ({total/1000:.2f}s)")
        print(f" Avg/image : {total/n:.1f}ms")
        print(f" Throughput : {n/(total/1000):.1f} img/s")
    print(f" CSV : {csv_path}")
    print(f" Annotated imgs : {output_dir}")
    print(f"{'='*70}")
| if __name__ == "__main__": | |
| main() |