Spaces:
Sleeping
Sleeping
Commit ·
82551bb
0
Parent(s):
Log device, Jina CPU warning, pin revision
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .dockerignore +27 -0
- .gitattributes +2 -0
- .gitignore +14 -0
- Dockerfile +35 -0
- README.md +38 -0
- app.py +202 -0
- dfine_jina_pipeline.py +565 -0
- jina_fewshot.py +399 -0
- models/README.md +6 -0
- models/v1/best.pt +3 -0
- nomic_fewshot.py +147 -0
- refs/cigarette/c2.png +3 -0
- refs/cigarette/c3.png +3 -0
- refs/cigarette/c4.png +3 -0
- refs/cigarette/c5.png +3 -0
- refs/cigarette/c6.png +3 -0
- refs/cigarette/c7.png +3 -0
- refs/cigarette/c9.png +3 -0
- refs/cigarette/cigarette.jpg +3 -0
- refs/gun/g1.png +3 -0
- refs/gun/g2.png +3 -0
- refs/gun/g3.png +3 -0
- refs/gun/g4.png +3 -0
- refs/gun/g5.png +3 -0
- refs/gun/g6.png +3 -0
- refs/gun/g7.png +3 -0
- refs/gun/g8.png +3 -0
- refs/gun/g9.png +3 -0
- refs/gun/pistol.jpeg +3 -0
- refs/knife/k1.png +3 -0
- refs/knife/k2.png +3 -0
- refs/knife/k3.png +3 -0
- refs/knife/k4.png +3 -0
- refs/knife/k5.png +3 -0
- refs/knife/k6.png +3 -0
- refs/knife/k7.png +3 -0
- refs/knife/k8.png +3 -0
- refs/knife/k9.png +3 -0
- refs/knife/knife.jpeg +3 -0
- refs/phone/p1.png +3 -0
- refs/phone/p2.png +3 -0
- refs/phone/p3.png +3 -0
- refs/phone/p4.png +3 -0
- refs/phone/p5.png +3 -0
- refs/phone/p6.png +3 -0
- refs/phone/p7.png +3 -0
- refs/phone/p8.png +3 -0
- refs/phone/p9.jpg +3 -0
- refs/phone/phone.jpg +3 -0
- requirements-lock.txt +17 -0
.dockerignore
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Git and env
|
| 2 |
+
.git
|
| 3 |
+
.gitignore
|
| 4 |
+
.gitattributes
|
| 5 |
+
.venv
|
| 6 |
+
venv
|
| 7 |
+
env
|
| 8 |
+
.env
|
| 9 |
+
|
| 10 |
+
# Build / cache
|
| 11 |
+
__pycache__
|
| 12 |
+
*.py[cod]
|
| 13 |
+
*.pyo
|
| 14 |
+
.pytest_cache
|
| 15 |
+
.mypy_cache
|
| 16 |
+
|
| 17 |
+
# Large or generated (not needed in image)
|
| 18 |
+
full_frames_GT
|
| 19 |
+
threshold_tuning
|
| 20 |
+
*.pt
|
| 21 |
+
models/*.pt
|
| 22 |
+
models/*.onnx
|
| 23 |
+
|
| 24 |
+
# IDE / OS
|
| 25 |
+
.cursor
|
| 26 |
+
.DS_Store
|
| 27 |
+
*.swp
|
.gitattributes
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
models/v1/best.pt filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
refs/** filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.py[cod]
|
| 3 |
+
.env
|
| 4 |
+
.venv/
|
| 5 |
+
venv/
|
| 6 |
+
|
| 7 |
+
# Binary / large assets (refs/ tracked with Git LFS; models not pushed to Space)
|
| 8 |
+
models/*.pt
|
| 9 |
+
models/*.onnx
|
| 10 |
+
full_frames_GT/
|
| 11 |
+
threshold_tuning/crops/
|
| 12 |
+
threshold_tuning/jina_crops/
|
| 13 |
+
threshold_tuning/nomic_crops/
|
| 14 |
+
threshold_tuning/detection_crops/
|
Dockerfile
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Small Object Detection — Docker Space
|
| 2 |
+
# Match local: Python 3.10, pinned deps (requirements-lock.txt). Gradio on 7860.
|
| 3 |
+
FROM python:3.10-slim-bookworm
|
| 4 |
+
|
| 5 |
+
# System deps: font for draw_label; opencv/ultralytics headless (libxcb, glib, etc.)
|
| 6 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 7 |
+
fonts-dejavu-core \
|
| 8 |
+
libglib2.0-0 \
|
| 9 |
+
libxcb1 \
|
| 10 |
+
libxcb-shm0 \
|
| 11 |
+
libxcb-xfixes0 \
|
| 12 |
+
libxrender1 \
|
| 13 |
+
libsm6 \
|
| 14 |
+
libxext6 \
|
| 15 |
+
libgl1-mesa-glx \
|
| 16 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 17 |
+
|
| 18 |
+
# HF Spaces run as user 1000
|
| 19 |
+
RUN useradd -m -u 1000 user
|
| 20 |
+
ENV HOME=/home/user PATH=/home/user/.local/bin:$PATH
|
| 21 |
+
WORKDIR $HOME/app
|
| 22 |
+
USER user
|
| 23 |
+
|
| 24 |
+
# Install Python deps from lock file so Space matches local versions (no GPU at build time)
|
| 25 |
+
COPY --chown=user requirements-lock.txt .
|
| 26 |
+
RUN pip install --no-cache-dir --upgrade pip \
|
| 27 |
+
&& pip install --no-cache-dir -r requirements-lock.txt
|
| 28 |
+
|
| 29 |
+
# App code (refs/ and code)
|
| 30 |
+
COPY --chown=user . .
|
| 31 |
+
|
| 32 |
+
# Gradio must listen on 0.0.0.0 for Docker
|
| 33 |
+
ENV GRADIO_SERVER_NAME=0.0.0.0
|
| 34 |
+
EXPOSE 7860
|
| 35 |
+
CMD ["python", "app.py"]
|
README.md
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Small Object Detection
|
| 3 |
+
emoji: 🔍
|
| 4 |
+
sdk: docker
|
| 5 |
+
app_port: 7860
|
| 6 |
+
pinned: false
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
# Small Object Detection
|
| 10 |
+
|
| 11 |
+
Upload an image to detect objects using the trained YOLO model **`best.pt`** in this repo. **CPU-only** — runs on basic (free) Hugging Face Spaces. The `train26m` folder is not part of this repo; only `best.pt` is included.
|
| 12 |
+
|
| 13 |
+
## Run locally
|
| 14 |
+
|
| 15 |
+
**Using uv (recommended):**
|
| 16 |
+
```bash
|
| 17 |
+
pip install uv
|
| 18 |
+
uv pip install -r requirements.txt
|
| 19 |
+
python app.py
|
| 20 |
+
```
|
| 21 |
+
|
| 22 |
+
**Or with pip only:**
|
| 23 |
+
```bash
|
| 24 |
+
pip install -r requirements.txt
|
| 25 |
+
python app.py
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
Then open the URL shown in the terminal (e.g. http://127.0.0.1:7860).
|
| 29 |
+
|
| 30 |
+
## Docker (Space)
|
| 31 |
+
|
| 32 |
+
The Space builds from the Dockerfile using **Python 3.10** and **requirements-lock.txt** so the container matches a known set of versions. To match your local env exactly, from your venv run:
|
| 33 |
+
|
| 34 |
+
```bash
|
| 35 |
+
pip freeze > requirements-lock.txt
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
Then commit and push; the next Space build will use those exact versions.
|
app.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio app: Tab 1 = Object Detection (YOLO models/v1), Tab 2 = D-FINE + Classify (Jina or Nomic).
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
os.environ["YOLO_CONFIG_DIR"] = os.environ.get("YOLO_CONFIG_DIR", "/tmp")
|
| 6 |
+
import json
|
| 7 |
+
import numpy as np
|
| 8 |
+
import gradio as gr
|
| 9 |
+
from ultralytics import YOLO
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
# Tab 2: D-FINE runs first, then user chooses Jina or Nomic for crop classification
|
| 13 |
+
from dfine_jina_pipeline import run_single_image
|
| 14 |
+
|
| 15 |
+
# --- Object Detection (Tab 1) ---
|
| 16 |
+
PERSON_CLASS = 0
|
| 17 |
+
CAR_CLASS = 2
|
| 18 |
+
KNIFE_CLASS = 80
|
| 19 |
+
WEAPON_CLASS = 81
|
| 20 |
+
DRAW_CLASSES = [PERSON_CLASS, CAR_CLASS, KNIFE_CLASS, WEAPON_CLASS]
|
| 21 |
+
|
| 22 |
+
CLASS_NAMES = {
|
| 23 |
+
PERSON_CLASS: "person",
|
| 24 |
+
CAR_CLASS: "car",
|
| 25 |
+
KNIFE_CLASS: "knife",
|
| 26 |
+
WEAPON_CLASS: "weapon",
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
CONF = 0.25
|
| 30 |
+
IMGSZ = 640
|
| 31 |
+
|
| 32 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 33 |
+
MODELS_DIR = os.path.join(BASE_DIR, "models")
|
| 34 |
+
REFS_DIR = os.path.join(BASE_DIR, "refs")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _load_model(version: str):
|
| 38 |
+
path = os.path.join(MODELS_DIR, version, "best.pt")
|
| 39 |
+
if not os.path.isfile(path):
|
| 40 |
+
raise FileNotFoundError(f"Model not found: {path}")
|
| 41 |
+
return YOLO(path)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
MODELS = {"v1": _load_model("v1")}
|
| 45 |
+
MODEL_CLASSES = {"v1": ["person", "car", "knife", "weapon"]}
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def run_detection(image, model):
|
| 49 |
+
if image is None:
|
| 50 |
+
return None, "{}"
|
| 51 |
+
|
| 52 |
+
img = image if isinstance(image, np.ndarray) else np.array(image)
|
| 53 |
+
if img.ndim == 2:
|
| 54 |
+
img = np.stack([img] * 3, axis=-1)
|
| 55 |
+
|
| 56 |
+
results = model.predict(
|
| 57 |
+
source=img,
|
| 58 |
+
imgsz=IMGSZ,
|
| 59 |
+
conf=CONF,
|
| 60 |
+
device="cpu",
|
| 61 |
+
verbose=False,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
r = results[0]
|
| 65 |
+
if r.boxes is None or len(r.boxes) == 0:
|
| 66 |
+
return image, json.dumps({"detections": []}, indent=2)
|
| 67 |
+
|
| 68 |
+
clss = r.boxes.cls.cpu().numpy()
|
| 69 |
+
confs = r.boxes.conf.cpu().numpy()
|
| 70 |
+
keep = [i for i in range(len(r.boxes)) if int(clss[i]) in DRAW_CLASSES]
|
| 71 |
+
|
| 72 |
+
if not keep:
|
| 73 |
+
return image, json.dumps({"detections": []}, indent=2)
|
| 74 |
+
|
| 75 |
+
detections = []
|
| 76 |
+
for i in keep:
|
| 77 |
+
cls_id = int(clss[i])
|
| 78 |
+
detections.append({
|
| 79 |
+
"class": CLASS_NAMES.get(cls_id, str(cls_id)),
|
| 80 |
+
"confidence": round(float(confs[i]), 3),
|
| 81 |
+
"bbox": r.boxes.xyxy[i].cpu().numpy().tolist(),
|
| 82 |
+
})
|
| 83 |
+
|
| 84 |
+
r.boxes = r.boxes[keep]
|
| 85 |
+
out_img = r.plot()
|
| 86 |
+
det_json = json.dumps({"detections": detections}, indent=2)
|
| 87 |
+
return out_img, det_json
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def run_dfine_classify(image, encoder_choice, refs_path):
|
| 91 |
+
"""Tab 2: D-FINE first, then classify crops with Jina or Nomic."""
|
| 92 |
+
if image is None:
|
| 93 |
+
return None, "Upload an image."
|
| 94 |
+
refs = Path(refs_path.strip()) if refs_path and refs_path.strip() else Path(REFS_DIR)
|
| 95 |
+
if not refs.is_dir():
|
| 96 |
+
return None, f"Refs folder not found: {refs}"
|
| 97 |
+
# Tuned on COCO GT: conf=0.5, gap=0.02. Lower det_threshold/min_side so D-FINE picks up more objects (gun, phone, etc.) like local.
|
| 98 |
+
out_img, text = run_single_image(
|
| 99 |
+
image,
|
| 100 |
+
refs_dir=refs,
|
| 101 |
+
encoder_choice=encoder_choice.lower(),
|
| 102 |
+
det_threshold=0.15,
|
| 103 |
+
conf_threshold=0.5,
|
| 104 |
+
gap_threshold=0.02,
|
| 105 |
+
min_side=24,
|
| 106 |
+
crop_dedup_iou=0.4,
|
| 107 |
+
)
|
| 108 |
+
if out_img is None:
|
| 109 |
+
return None, text
|
| 110 |
+
return out_img, text
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
IMG_HEIGHT = 400
|
| 114 |
+
|
| 115 |
+
TAB_STYLE = """
|
| 116 |
+
<style>
|
| 117 |
+
[data-testid="tabs"] > div:first-child,
|
| 118 |
+
.gr-tabs > div:first-child,
|
| 119 |
+
div[class*="tabs"] > div:first-child {
|
| 120 |
+
display: flex !important;
|
| 121 |
+
width: 100% !important;
|
| 122 |
+
}
|
| 123 |
+
[data-testid="tabs"] button,
|
| 124 |
+
.gr-tabs button,
|
| 125 |
+
div[class*="tabs"] > div:first-child button {
|
| 126 |
+
flex: 1 !important;
|
| 127 |
+
min-width: 0 !important;
|
| 128 |
+
min-height: 40px !important;
|
| 129 |
+
color: white !important;
|
| 130 |
+
font-weight: 700 !important;
|
| 131 |
+
font-size: 1rem !important;
|
| 132 |
+
text-align: center !important;
|
| 133 |
+
justify-content: center !important;
|
| 134 |
+
}
|
| 135 |
+
[data-testid="tabs"] button:not([aria-selected="true"]),
|
| 136 |
+
.gr-tabs button:not([aria-selected="true"]),
|
| 137 |
+
div[class*="tabs"] > div:first-child button:not([aria-selected="true"]) {
|
| 138 |
+
background: #6b7280 !important;
|
| 139 |
+
border-color: #6b7280 !important;
|
| 140 |
+
}
|
| 141 |
+
[data-testid="tabs"] button[aria-selected="true"],
|
| 142 |
+
.gr-tabs button[aria-selected="true"],
|
| 143 |
+
div[class*="tabs"] > div:first-child button[aria-selected="true"] {
|
| 144 |
+
background: var(--primary-500, #f97316) !important;
|
| 145 |
+
border-color: var(--primary-500, #f97316) !important;
|
| 146 |
+
}
|
| 147 |
+
</style>
|
| 148 |
+
"""
|
| 149 |
+
|
| 150 |
+
with gr.Blocks(title="Small Object Detection") as app:
|
| 151 |
+
gr.HTML(TAB_STYLE)
|
| 152 |
+
gr.Markdown("# Small Object Detection")
|
| 153 |
+
|
| 154 |
+
with gr.Tabs():
|
| 155 |
+
with gr.TabItem("Object Detection"):
|
| 156 |
+
gr.Markdown("**Classes:** " + ", ".join(MODEL_CLASSES["v1"]))
|
| 157 |
+
with gr.Row():
|
| 158 |
+
with gr.Column(scale=1):
|
| 159 |
+
inp_det = gr.Image(label="Input image", height=IMG_HEIGHT)
|
| 160 |
+
btn_det = gr.Button("Detect", variant="primary")
|
| 161 |
+
out_img_det = gr.Image(label="Output", height=IMG_HEIGHT)
|
| 162 |
+
det_output = gr.JSON(label="Detections")
|
| 163 |
+
btn_det.click(
|
| 164 |
+
fn=lambda img: run_detection(img, MODELS["v1"]),
|
| 165 |
+
inputs=inp_det,
|
| 166 |
+
outputs=[out_img_det, det_output],
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
with gr.TabItem("D-FINE + Classify"):
|
| 170 |
+
gr.Markdown(
|
| 171 |
+
"**D-FINE** runs first (person/car grouping), then small-object crops are classified. "
|
| 172 |
+
"Choose **Jina** or **Nomic** for the embedding/classification model. "
|
| 173 |
+
"Uses the **refs** folder (one subfolder per class, e.g. refs/phone/, refs/cigarette/) with reference images."
|
| 174 |
+
)
|
| 175 |
+
with gr.Row():
|
| 176 |
+
with gr.Column(scale=1):
|
| 177 |
+
inp_dfine = gr.Image(type="pil", label="Input image", height=IMG_HEIGHT)
|
| 178 |
+
encoder_choice = gr.Radio(
|
| 179 |
+
choices=["Jina", "Nomic"],
|
| 180 |
+
value="Jina",
|
| 181 |
+
label="Embedding / classification model",
|
| 182 |
+
)
|
| 183 |
+
refs_path = gr.Textbox(
|
| 184 |
+
label="Refs folder path",
|
| 185 |
+
value=REFS_DIR,
|
| 186 |
+
placeholder="e.g. refs or /path/to/refs",
|
| 187 |
+
)
|
| 188 |
+
btn_dfine = gr.Button("Run D-FINE + Classify", variant="primary")
|
| 189 |
+
with gr.Column(scale=1):
|
| 190 |
+
out_img_dfine = gr.Image(label="Output (crops with labels)", height=IMG_HEIGHT)
|
| 191 |
+
out_text_dfine = gr.Textbox(label="Crop predictions", lines=10, interactive=False)
|
| 192 |
+
btn_dfine.click(
|
| 193 |
+
fn=run_dfine_classify,
|
| 194 |
+
inputs=[inp_dfine, encoder_choice, refs_path],
|
| 195 |
+
outputs=[out_img_dfine, out_text_dfine],
|
| 196 |
+
concurrency_limit=1,
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
app.launch(
|
| 200 |
+
server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
|
| 201 |
+
server_port=int(os.environ.get("PORT", os.environ.get("GRADIO_SERVER_PORT", 7860))),
|
| 202 |
+
)
|
dfine_jina_pipeline.py
ADDED
|
@@ -0,0 +1,565 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pipeline: D-FINE (person/car only) → group detections → crop regions →
|
| 3 |
+
find all bboxes inside each crop → Jina-CLIP-v2 and Nomic embeddings on those crops.
|
| 4 |
+
|
| 5 |
+
Outputs separate crop folders per model (jina_crops, nomic_crops) for visual comparison.
|
| 6 |
+
"""
|
| 7 |
+
import argparse
|
| 8 |
+
import csv
|
| 9 |
+
import time
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from PIL import Image
|
| 16 |
+
from transformers import AutoImageProcessor, DFineForObjectDetection
|
| 17 |
+
|
| 18 |
+
# Jina-CLIP-v2 few-shot (same refs + classify as jina_fewshot.py)
|
| 19 |
+
from jina_fewshot import (
|
| 20 |
+
IMAGE_EXTS,
|
| 21 |
+
TRUNCATE_DIM,
|
| 22 |
+
JinaCLIPv2Encoder,
|
| 23 |
+
build_refs,
|
| 24 |
+
classify as jina_classify,
|
| 25 |
+
draw_label_on_image,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
# -----------------------------------------------------------------------------
|
| 29 |
+
# Detection + grouping (from reference_detection.py)
|
| 30 |
+
# -----------------------------------------------------------------------------
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_box_dist(box1, box2):
|
| 34 |
+
"""Euclidean distance between box centers. box = [x1, y1, x2, y2]."""
|
| 35 |
+
c1 = np.array([(box1[0] + box1[2]) / 2, (box1[1] + box1[3]) / 2])
|
| 36 |
+
c2 = np.array([(box2[0] + box2[2]) / 2, (box2[1] + box2[3]) / 2])
|
| 37 |
+
return np.linalg.norm(c1 - c2)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def group_detections(detections, threshold):
|
| 41 |
+
"""
|
| 42 |
+
Group detections by proximity (center distance < threshold).
|
| 43 |
+
detections: list of {"box": [x1,y1,x2,y2], "conf", "cls", ...}
|
| 44 |
+
Returns list of {"box": merged [x1,y1,x2,y2], "conf": best in group, "cls": best in group}.
|
| 45 |
+
"""
|
| 46 |
+
if not detections:
|
| 47 |
+
return []
|
| 48 |
+
|
| 49 |
+
boxes = [d["box"] for d in detections]
|
| 50 |
+
n = len(boxes)
|
| 51 |
+
adj = {i: [] for i in range(n)}
|
| 52 |
+
for i in range(n):
|
| 53 |
+
for j in range(i + 1, n):
|
| 54 |
+
if get_box_dist(boxes[i], boxes[j]) < threshold:
|
| 55 |
+
adj[i].append(j)
|
| 56 |
+
adj[j].append(i)
|
| 57 |
+
|
| 58 |
+
groups = []
|
| 59 |
+
visited = [False] * n
|
| 60 |
+
for i in range(n):
|
| 61 |
+
if not visited[i]:
|
| 62 |
+
group_indices = []
|
| 63 |
+
stack = [i]
|
| 64 |
+
visited[i] = True
|
| 65 |
+
while stack:
|
| 66 |
+
curr = stack.pop()
|
| 67 |
+
group_indices.append(curr)
|
| 68 |
+
for neighbor in adj[curr]:
|
| 69 |
+
if not visited[neighbor]:
|
| 70 |
+
visited[neighbor] = True
|
| 71 |
+
stack.append(neighbor)
|
| 72 |
+
|
| 73 |
+
group_dets = [detections[k] for k in group_indices]
|
| 74 |
+
x1 = min(d["box"][0] for d in group_dets)
|
| 75 |
+
y1 = min(d["box"][1] for d in group_dets)
|
| 76 |
+
x2 = max(d["box"][2] for d in group_dets)
|
| 77 |
+
y2 = max(d["box"][3] for d in group_dets)
|
| 78 |
+
best_det = max(group_dets, key=lambda x: x["conf"])
|
| 79 |
+
|
| 80 |
+
groups.append({
|
| 81 |
+
"box": [x1, y1, x2, y2],
|
| 82 |
+
"conf": best_det["conf"],
|
| 83 |
+
"cls": best_det["cls"],
|
| 84 |
+
"label": best_det.get("label", str(best_det["cls"])),
|
| 85 |
+
})
|
| 86 |
+
return groups
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def box_center_inside(box, crop_box):
|
| 90 |
+
"""True if center of box is inside crop_box. All [x1,y1,x2,y2]."""
|
| 91 |
+
cx = (box[0] + box[2]) / 2
|
| 92 |
+
cy = (box[1] + box[3]) / 2
|
| 93 |
+
return (
|
| 94 |
+
crop_box[0] <= cx <= crop_box[2]
|
| 95 |
+
and crop_box[1] <= cy <= crop_box[3]
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def squarify_crop_box(bx1, by1, bx2, by2, img_w, img_h):
|
| 100 |
+
"""
|
| 101 |
+
Expand the shorter side to match the longer (same ratio / square), centered, clamped to image.
|
| 102 |
+
If height > width: expand width. If width >= height: expand height.
|
| 103 |
+
Returns (bx1, by1, bx2, by2) as integers.
|
| 104 |
+
"""
|
| 105 |
+
orig = (int(bx1), int(by1), int(bx2), int(by2))
|
| 106 |
+
w = bx2 - bx1
|
| 107 |
+
h = by2 - by1
|
| 108 |
+
if w <= 0 or h <= 0:
|
| 109 |
+
return orig
|
| 110 |
+
if h > w:
|
| 111 |
+
add = (h - w) / 2.0
|
| 112 |
+
bx1 = max(0, bx1 - add)
|
| 113 |
+
bx2 = min(img_w, bx2 + add)
|
| 114 |
+
else:
|
| 115 |
+
add = (w - h) / 2.0
|
| 116 |
+
by1 = max(0, by1 - add)
|
| 117 |
+
by2 = min(img_h, by2 + add)
|
| 118 |
+
bx1, by1, bx2, by2 = int(bx1), int(by1), int(bx2), int(by2)
|
| 119 |
+
if bx2 <= bx1 or by2 <= by1:
|
| 120 |
+
return orig
|
| 121 |
+
return bx1, by1, bx2, by2
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def box_iou(box1, box2):
|
| 125 |
+
"""IoU of two boxes [x1,y1,x2,y2]. Returns float in [0, 1]."""
|
| 126 |
+
ix1 = max(box1[0], box2[0])
|
| 127 |
+
iy1 = max(box1[1], box2[1])
|
| 128 |
+
ix2 = min(box1[2], box2[2])
|
| 129 |
+
iy2 = min(box1[3], box2[3])
|
| 130 |
+
inter_w = max(0, ix2 - ix1)
|
| 131 |
+
inter_h = max(0, iy2 - iy1)
|
| 132 |
+
inter = inter_w * inter_h
|
| 133 |
+
a1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
| 134 |
+
a2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
|
| 135 |
+
union = a1 + a2 - inter
|
| 136 |
+
return inter / union if union > 0 else 0.0
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def deduplicate_by_iou(detections, iou_threshold=0.9):
|
| 140 |
+
"""Keep one detection per overlapping group (IoU >= iou_threshold). Prefer higher confidence."""
|
| 141 |
+
if not detections:
|
| 142 |
+
return []
|
| 143 |
+
# Sort by confidence descending; keep first, then add only if no kept box overlaps >= threshold
|
| 144 |
+
sorted_d = sorted(detections, key=lambda x: -x["conf"])
|
| 145 |
+
kept = []
|
| 146 |
+
for d in sorted_d:
|
| 147 |
+
if not any(box_iou(d["box"], k["box"]) >= iou_threshold for k in kept):
|
| 148 |
+
kept.append(d)
|
| 149 |
+
return kept
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
from nomic_fewshot import NomicTextEncoder, NomicVisionEncoder, build_refs_nomic
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def parse_args():
|
| 156 |
+
p = argparse.ArgumentParser(
|
| 157 |
+
description="D-FINE (person/car) → group → Jina-CLIP-v2 on crops inside groups"
|
| 158 |
+
)
|
| 159 |
+
p.add_argument("--refs", required=True, help="Reference images folder for Jina and Nomic (e.g. refs/)")
|
| 160 |
+
p.add_argument("--input", required=True, help="Full-frame images folder")
|
| 161 |
+
p.add_argument("--output", default="pipeline_results", help="Output folder (CSV, etc.)")
|
| 162 |
+
p.add_argument("--det-threshold", type=float, default=0.13, help="D-FINE score threshold")
|
| 163 |
+
p.add_argument("--group-dist", type=float, default=None,
|
| 164 |
+
help="Group distance (default: 0.1 * max(H,W))")
|
| 165 |
+
p.add_argument("--min-side", type=int, default=40, help="Min side of expanded bbox in px (skip smaller)")
|
| 166 |
+
p.add_argument("--crop-dedup-iou", type=float, default=0.35, help="Min IoU to treat two crops as same object (keep larger)")
|
| 167 |
+
p.add_argument("--no-squarify", action="store_true", help="Skip squarify; use expanded bbox only (tighter crops, often better recognition)")
|
| 168 |
+
p.add_argument("--padding", type=float, default=0.2, help="Crop padding around group box (0.2 = 20%%)")
|
| 169 |
+
p.add_argument("--conf-threshold", type=float, default=0.75, help="Jina accept confidence")
|
| 170 |
+
p.add_argument("--gap-threshold", type=float, default=0.05, help="Jina accept gap")
|
| 171 |
+
p.add_argument("--text-weight", type=float, default=0.3)
|
| 172 |
+
p.add_argument("--max-images", type=int, default=None)
|
| 173 |
+
p.add_argument("--device", default=None)
|
| 174 |
+
return p.parse_args()
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def get_person_car_label_ids(model):
|
| 178 |
+
"""Return set of label IDs for person and car (Objects365: Person, Car, SUV, etc.)."""
|
| 179 |
+
id2label = getattr(model.config, "id2label", None) or {}
|
| 180 |
+
ids = set()
|
| 181 |
+
for idx, name in id2label.items():
|
| 182 |
+
try:
|
| 183 |
+
i = int(idx)
|
| 184 |
+
except (ValueError, TypeError):
|
| 185 |
+
continue
|
| 186 |
+
n = (name or "").lower()
|
| 187 |
+
if "person" in n or n in ("car", "suv"):
|
| 188 |
+
ids.add(i)
|
| 189 |
+
return ids
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def run_dfine(image, processor, model, device, score_threshold):
|
| 193 |
+
"""Run D-FINE, return all detections as list of {box, score, label_id, label}."""
|
| 194 |
+
from PIL import Image
|
| 195 |
+
if isinstance(image, Image.Image):
|
| 196 |
+
pil = image.convert("RGB")
|
| 197 |
+
else:
|
| 198 |
+
pil = Image.fromarray(image).convert("RGB")
|
| 199 |
+
w, h = pil.size
|
| 200 |
+
target_size = torch.tensor([[h, w]], device=device)
|
| 201 |
+
inputs = processor(images=pil, return_tensors="pt")
|
| 202 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 203 |
+
|
| 204 |
+
with torch.no_grad():
|
| 205 |
+
outputs = model(**inputs)
|
| 206 |
+
|
| 207 |
+
target_sizes = target_size.to(outputs["logits"].device)
|
| 208 |
+
results = processor.post_process_object_detection(
|
| 209 |
+
outputs, target_sizes=target_sizes, threshold=score_threshold
|
| 210 |
+
)
|
| 211 |
+
id2label = getattr(model.config, "id2label", {}) or {}
|
| 212 |
+
|
| 213 |
+
detections = []
|
| 214 |
+
for result in results:
|
| 215 |
+
for score, label_id, box in zip(
|
| 216 |
+
result["scores"], result["labels"], result["boxes"]
|
| 217 |
+
):
|
| 218 |
+
sid = int(label_id.item())
|
| 219 |
+
detections.append({
|
| 220 |
+
"box": [float(x) for x in box.cpu().tolist()],
|
| 221 |
+
"conf": float(score.item()),
|
| 222 |
+
"cls": sid,
|
| 223 |
+
"label": id2label.get(sid, str(sid)),
|
| 224 |
+
})
|
| 225 |
+
return detections
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def main():
|
| 229 |
+
args = parse_args()
|
| 230 |
+
device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 231 |
+
input_dir = Path(args.input)
|
| 232 |
+
output_dir = Path(args.output)
|
| 233 |
+
refs_dir = Path(args.refs)
|
| 234 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 235 |
+
|
| 236 |
+
if not refs_dir.is_dir():
|
| 237 |
+
raise SystemExit(f"Refs folder not found: {refs_dir}")
|
| 238 |
+
if not input_dir.is_dir():
|
| 239 |
+
raise SystemExit(f"Input folder not found: {input_dir}")
|
| 240 |
+
|
| 241 |
+
paths = sorted(
|
| 242 |
+
p for p in input_dir.iterdir()
|
| 243 |
+
if p.suffix.lower() in IMAGE_EXTS
|
| 244 |
+
)
|
| 245 |
+
if args.max_images is not None:
|
| 246 |
+
paths = paths[: args.max_images]
|
| 247 |
+
if not paths:
|
| 248 |
+
raise SystemExit(f"No images in {input_dir}")
|
| 249 |
+
|
| 250 |
+
# Load D-FINE
|
| 251 |
+
print("[*] Loading D-FINE (dfine-medium-obj365)...")
|
| 252 |
+
t0 = time.perf_counter()
|
| 253 |
+
image_processor = AutoImageProcessor.from_pretrained("ustc-community/dfine-medium-obj365")
|
| 254 |
+
dfine_model = DFineForObjectDetection.from_pretrained("ustc-community/dfine-medium-obj365")
|
| 255 |
+
dfine_model = dfine_model.to(device).eval()
|
| 256 |
+
person_car_ids = get_person_car_label_ids(dfine_model)
|
| 257 |
+
print(f" Person/car label IDs: {person_car_ids} ({time.perf_counter()-t0:.1f}s)")
|
| 258 |
+
|
| 259 |
+
# Load Jina-CLIP-v2 + build refs
|
| 260 |
+
print("[*] Loading Jina-CLIP-v2 and building refs...")
|
| 261 |
+
t0 = time.perf_counter()
|
| 262 |
+
jina_encoder = JinaCLIPv2Encoder(device)
|
| 263 |
+
ref_labels, ref_embs = build_refs(
|
| 264 |
+
jina_encoder, refs_dir, TRUNCATE_DIM, args.text_weight, batch_size=16
|
| 265 |
+
)
|
| 266 |
+
print(f" Jina refs: {ref_labels} ({time.perf_counter()-t0:.1f}s)\n")
|
| 267 |
+
|
| 268 |
+
# Load Nomic vision + text, build refs (same as Jina: image + text prompts, text_weight 0.3)
|
| 269 |
+
print("[*] Loading Nomic embed-vision + embed-text and building refs...")
|
| 270 |
+
t0 = time.perf_counter()
|
| 271 |
+
nomic_encoder = NomicVisionEncoder(device)
|
| 272 |
+
nomic_text_encoder = NomicTextEncoder(device)
|
| 273 |
+
ref_labels_nomic, ref_embs_nomic = build_refs_nomic(
|
| 274 |
+
nomic_encoder, refs_dir, batch_size=16,
|
| 275 |
+
text_encoder=nomic_text_encoder, text_weight=args.text_weight,
|
| 276 |
+
)
|
| 277 |
+
print(f" Nomic refs: {ref_labels_nomic} ({time.perf_counter()-t0:.1f}s)\n")
|
| 278 |
+
|
| 279 |
+
# Separate output folders per model for visual comparison
|
| 280 |
+
jina_crops_dir = output_dir / "jina_crops"
|
| 281 |
+
nomic_crops_dir = output_dir / "nomic_crops"
|
| 282 |
+
jina_crops_dir.mkdir(parents=True, exist_ok=True)
|
| 283 |
+
nomic_crops_dir.mkdir(parents=True, exist_ok=True)
|
| 284 |
+
|
| 285 |
+
# CSV
|
| 286 |
+
csv_path = output_dir / "results.csv"
|
| 287 |
+
f = open(csv_path, "w", newline="")
|
| 288 |
+
w = csv.writer(f)
|
| 289 |
+
w.writerow([
|
| 290 |
+
"image", "crop_filename", "group_idx", "crop_x1", "crop_y1", "crop_x2", "crop_y2",
|
| 291 |
+
"bbox_x1", "bbox_y1", "bbox_x2", "bbox_y2", "dfine_label", "dfine_conf",
|
| 292 |
+
"jina_prediction", "jina_confidence", "jina_status",
|
| 293 |
+
"nomic_prediction", "nomic_confidence", "nomic_status",
|
| 294 |
+
])
|
| 295 |
+
|
| 296 |
+
for img_path in paths:
|
| 297 |
+
pil = Image.open(img_path).convert("RGB")
|
| 298 |
+
img_w, img_h = pil.size
|
| 299 |
+
group_dist = args.group_dist if args.group_dist is not None else 0.1 * max(img_h, img_w)
|
| 300 |
+
|
| 301 |
+
# 1) D-FINE: detect everything, keep all bboxes for the image
|
| 302 |
+
detections = run_dfine(
|
| 303 |
+
pil, image_processor, dfine_model, device, args.det_threshold
|
| 304 |
+
)
|
| 305 |
+
person_car = [d for d in detections if d["cls"] in person_car_ids]
|
| 306 |
+
if not person_car:
|
| 307 |
+
continue
|
| 308 |
+
|
| 309 |
+
# 2) Group person/car detections (same as reference)
|
| 310 |
+
grouped = group_detections(person_car, group_dist)
|
| 311 |
+
grouped.sort(key=lambda x: x["conf"], reverse=True)
|
| 312 |
+
top_groups = grouped[:10] # limit groups per image
|
| 313 |
+
|
| 314 |
+
# 3) Collect all candidate crops (bboxes inside person/car groups)
|
| 315 |
+
# Each: (crop_box, crop_pil, d, gidx, crop_idx, x1, y1, x2, y2)
|
| 316 |
+
candidates = []
|
| 317 |
+
for gidx, grp in enumerate(top_groups):
|
| 318 |
+
x1, y1, x2, y2 = grp["box"]
|
| 319 |
+
group_box = [x1, y1, x2, y2]
|
| 320 |
+
inside = [
|
| 321 |
+
d for d in detections
|
| 322 |
+
if box_center_inside(d["box"], group_box)
|
| 323 |
+
and d["cls"] not in person_car_ids
|
| 324 |
+
]
|
| 325 |
+
inside = deduplicate_by_iou(inside, iou_threshold=0.9)
|
| 326 |
+
|
| 327 |
+
for crop_idx, d in enumerate(inside):
|
| 328 |
+
bx1, by1, bx2, by2 = [float(x) for x in d["box"]]
|
| 329 |
+
obj_w, obj_h = bx2 - bx1, by2 - by1
|
| 330 |
+
if obj_w <= 0 or obj_h <= 0:
|
| 331 |
+
continue
|
| 332 |
+
pad_x = obj_w * 0.3
|
| 333 |
+
pad_y = obj_h * 0.3
|
| 334 |
+
bx1 = max(0, int(bx1 - pad_x))
|
| 335 |
+
by1 = max(0, int(by1 - pad_y))
|
| 336 |
+
bx2 = min(img_w, int(bx2 + pad_x))
|
| 337 |
+
by2 = min(img_h, int(by2 + pad_y))
|
| 338 |
+
if bx2 <= bx1 or by2 <= by1:
|
| 339 |
+
continue
|
| 340 |
+
if min(bx2 - bx1, by2 - by1) < args.min_side:
|
| 341 |
+
continue
|
| 342 |
+
expanded_box = [bx1, by1, bx2, by2]
|
| 343 |
+
candidates.append((expanded_box, d, gidx, crop_idx, x1, y1, x2, y2))
|
| 344 |
+
|
| 345 |
+
# 4) Dedup on EXPANDED boxes (before squarify), keep larger; then squarify only kept
|
| 346 |
+
def crop_area(box):
|
| 347 |
+
return (box[2] - box[0]) * (box[3] - box[1])
|
| 348 |
+
|
| 349 |
+
candidates.sort(key=lambda c: -crop_area(c[0]))
|
| 350 |
+
kept = []
|
| 351 |
+
for c in candidates:
|
| 352 |
+
expanded_box = c[0]
|
| 353 |
+
def is_same_object(box_a, box_b):
|
| 354 |
+
if box_iou(box_a, box_b) >= args.crop_dedup_iou:
|
| 355 |
+
return True
|
| 356 |
+
if box_center_inside(box_a, box_b) or box_center_inside(box_b, box_a):
|
| 357 |
+
return True
|
| 358 |
+
return False
|
| 359 |
+
if not any(is_same_object(expanded_box, k[0]) for k in kept):
|
| 360 |
+
kept.append(c)
|
| 361 |
+
|
| 362 |
+
# 5) Optionally squarify, then run Jina and Nomic only on kept crops
|
| 363 |
+
for i, (expanded_box, d, gidx, crop_idx, x1, y1, x2, y2) in enumerate(kept):
|
| 364 |
+
if not args.no_squarify:
|
| 365 |
+
bx1, by1, bx2, by2 = squarify_crop_box(
|
| 366 |
+
expanded_box[0], expanded_box[1], expanded_box[2], expanded_box[3], img_w, img_h
|
| 367 |
+
)
|
| 368 |
+
else:
|
| 369 |
+
bx1, by1, bx2, by2 = expanded_box[0], expanded_box[1], expanded_box[2], expanded_box[3]
|
| 370 |
+
crop_pil = pil.crop((bx1, by1, bx2, by2))
|
| 371 |
+
crop_name = f"{img_path.stem}_g{gidx}_{i}_{bx1}_{by1}_{bx2}_{by2}{img_path.suffix}"
|
| 372 |
+
|
| 373 |
+
q_jina = jina_encoder.encode_images([crop_pil], TRUNCATE_DIM)
|
| 374 |
+
result_jina = jina_classify(
|
| 375 |
+
q_jina, ref_labels, ref_embs,
|
| 376 |
+
args.conf_threshold, args.gap_threshold
|
| 377 |
+
)
|
| 378 |
+
if result_jina["prediction"] in ref_labels:
|
| 379 |
+
label_jina = result_jina["prediction"]
|
| 380 |
+
conf_jina = result_jina["confidence"]
|
| 381 |
+
else:
|
| 382 |
+
label_jina = f"unnamed (dfine: {d['label']})"
|
| 383 |
+
conf_jina = 0.0
|
| 384 |
+
ann_jina = draw_label_on_image(crop_pil, label_jina, conf_jina)
|
| 385 |
+
ann_jina.save(jina_crops_dir / crop_name)
|
| 386 |
+
|
| 387 |
+
q_nomic = nomic_encoder.encode_images([crop_pil])
|
| 388 |
+
result_nomic = jina_classify(
|
| 389 |
+
q_nomic, ref_labels_nomic, ref_embs_nomic,
|
| 390 |
+
args.conf_threshold, args.gap_threshold
|
| 391 |
+
)
|
| 392 |
+
if result_nomic["prediction"] in ref_labels_nomic:
|
| 393 |
+
label_nomic = result_nomic["prediction"]
|
| 394 |
+
conf_nomic = result_nomic["confidence"]
|
| 395 |
+
else:
|
| 396 |
+
label_nomic = f"unnamed (dfine: {d['label']})"
|
| 397 |
+
conf_nomic = 0.0
|
| 398 |
+
ann_nomic = draw_label_on_image(crop_pil, label_nomic, conf_nomic)
|
| 399 |
+
ann_nomic.save(nomic_crops_dir / crop_name)
|
| 400 |
+
|
| 401 |
+
w.writerow([
|
| 402 |
+
img_path.name, crop_name, gidx,
|
| 403 |
+
x1, y1, x2, y2,
|
| 404 |
+
bx1, by1, bx2, by2,
|
| 405 |
+
d["label"], f"{d['conf']:.4f}",
|
| 406 |
+
result_jina["prediction"], f"{result_jina['confidence']:.4f}", result_jina["status"],
|
| 407 |
+
result_nomic["prediction"], f"{result_nomic['confidence']:.4f}", result_nomic["status"],
|
| 408 |
+
])
|
| 409 |
+
|
| 410 |
+
f.close()
|
| 411 |
+
print(f"[*] Wrote {csv_path}")
|
| 412 |
+
print(f"[*] Jina crops: {jina_crops_dir}")
|
| 413 |
+
print(f"[*] Nomic crops: {nomic_crops_dir}")
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
# -----------------------------------------------------------------------------
|
| 417 |
+
# Single-image runner for Gradio app: D-FINE first, then Jina or Nomic (user choice)
|
| 418 |
+
# -----------------------------------------------------------------------------
|
| 419 |
+
# Lazy-initialized singletons for the Gradio app: heavy models and reference
# embeddings are loaded on first request and reused on subsequent calls.
_APP_DFINE = None        # (image_processor, dfine_model, person_car_ids)
_APP_JINA = None         # (jina_encoder, ref_labels, ref_embs)
_APP_NOMIC = None        # (nomic_encoder, ref_labels, ref_embs)
_APP_REFS_JINA = None    # refs dir path the cached Jina refs were built from
_APP_REFS_NOMIC = None   # refs dir path the cached Nomic refs were built from
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
def run_single_image(
    pil_image,
    refs_dir,
    device=None,
    encoder_choice="jina",
    det_threshold=0.3,
    conf_threshold=0.75,
    gap_threshold=0.05,
    min_side=40,
    crop_dedup_iou=0.35,
    squarify=True,
):
    """
    Run D-FINE on one image, then classify small-object crops with Jina or Nomic.

    Pipeline: detect all objects -> group person/car detections -> collect
    padded crops of the non-person/car detections inside those groups ->
    dedup overlapping crops -> classify each kept crop against the few-shot
    references using the chosen encoder.

    pil_image: PIL image or numpy array (converted to RGB).
    refs_dir: path to refs folder (str or Path). encoder_choice: "jina" or "nomic".
    device: torch device string; auto-selects "cuda" when available if None.
    det_threshold: D-FINE detection confidence cutoff.
    conf_threshold / gap_threshold: dual-threshold params for crop classification.
    min_side: minimum crop side length (pixels) after 30% padding.
    crop_dedup_iou: IoU at or above which two crops count as the same object.
    squarify: expand each kept crop to a square before classification.
    Returns (annotated_pil, result_text) for display in app; returns
    (None, error message) when refs_dir does not exist.
    """
    import numpy as np
    from PIL import Image
    # Module-level caches so repeated app calls reuse loaded models/refs.
    global _APP_DFINE, _APP_JINA, _APP_NOMIC, _APP_REFS_JINA, _APP_REFS_NOMIC
    refs_dir = Path(refs_dir)
    if not refs_dir.is_dir():
        return None, f"Refs folder not found: {refs_dir}"
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[*] Device: {device}")
    pil = pil_image.convert("RGB") if isinstance(pil_image, Image.Image) else Image.fromarray(pil_image).convert("RGB")
    img_w, img_h = pil.size
    # Grouping radius scales with image size (10% of the longer side).
    group_dist = 0.1 * max(img_h, img_w)

    # Load D-FINE once
    if _APP_DFINE is None:
        image_processor = AutoImageProcessor.from_pretrained("ustc-community/dfine-medium-obj365")
        dfine_model = DFineForObjectDetection.from_pretrained("ustc-community/dfine-medium-obj365")
        dfine_model = dfine_model.to(device).eval()
        person_car_ids = get_person_car_label_ids(dfine_model)
        _APP_DFINE = (image_processor, dfine_model, person_car_ids)
    image_processor, dfine_model, person_car_ids = _APP_DFINE

    # 1) Detect everything; keep only person/car detections for grouping.
    detections = run_dfine(pil, image_processor, dfine_model, device, det_threshold)
    person_car = [d for d in detections if d["cls"] in person_car_ids]
    if not person_car:
        return np.array(pil), "No person/car detected. No small-object crops."

    # 2) Group person/car boxes and keep the 10 most confident groups.
    grouped = group_detections(person_car, group_dist)
    grouped.sort(key=lambda x: x["conf"], reverse=True)
    top_groups = grouped[:10]
    # 3) Candidate crops: non-person/car detections whose center falls inside a group.
    candidates = []
    for gidx, grp in enumerate(top_groups):
        x1, y1, x2, y2 = grp["box"]
        group_box = [x1, y1, x2, y2]
        inside = [
            d for d in detections
            if box_center_inside(d["box"], group_box) and d["cls"] not in person_car_ids
        ]
        inside = deduplicate_by_iou(inside, iou_threshold=0.9)
        for crop_idx, d in enumerate(inside):
            bx1, by1, bx2, by2 = [float(x) for x in d["box"]]
            obj_w, obj_h = bx2 - bx1, by2 - by1
            if obj_w <= 0 or obj_h <= 0:
                continue
            # Pad the box by 30% on each axis, clamped to the image bounds.
            pad_x, pad_y = obj_w * 0.3, obj_h * 0.3
            bx1 = max(0, int(bx1 - pad_x))
            by1 = max(0, int(by1 - pad_y))
            bx2 = min(img_w, int(bx2 + pad_x))
            by2 = min(img_h, int(by2 + pad_y))
            if bx2 <= bx1 or by2 <= by1:
                continue
            if min(bx2 - bx1, by2 - by1) < min_side:
                continue
            expanded_box = [bx1, by1, bx2, by2]
            candidates.append((expanded_box, d, gidx, crop_idx))

    def crop_area(box):
        # Area of an [x1, y1, x2, y2] box.
        return (box[2] - box[0]) * (box[3] - box[1])

    # 4) Dedup on the expanded boxes, largest first (larger crop wins a tie).
    candidates.sort(key=lambda c: -crop_area(c[0]))
    kept = []
    for c in candidates:
        # NOTE: is_same_object is (re)defined each iteration; it closes over
        # crop_dedup_iou only, so behavior is identical across iterations.
        def is_same_object(box_a, box_b):
            if box_iou(box_a, box_b) >= crop_dedup_iou:
                return True
            if box_center_inside(box_a, box_b) or box_center_inside(box_b, box_a):
                return True
            return False
        if not any(is_same_object(c[0], k[0]) for k in kept):
            kept.append(c)

    if not kept:
        if not candidates:
            return np.array(pil), "No small-object crops: D-FINE did not detect any object (gun/phone/etc.) inside person/car areas, or all were below min size. Try a higher-resolution image."
        return np.array(pil), "No small-object crops (after dedup)."

    # Load encoder + refs for chosen model (cached per refs_dir path).
    if encoder_choice == "jina":
        if _APP_JINA is None or _APP_REFS_JINA != str(refs_dir):
            jina_encoder = JinaCLIPv2Encoder(device)
            ref_labels, ref_embs = build_refs(jina_encoder, refs_dir, TRUNCATE_DIM, 0.3, batch_size=16)
            _APP_JINA = (jina_encoder, ref_labels, ref_embs)
            _APP_REFS_JINA = str(refs_dir)
        jina_encoder, ref_labels, ref_embs = _APP_JINA
    else:
        if _APP_NOMIC is None or _APP_REFS_NOMIC != str(refs_dir):
            nomic_encoder = NomicVisionEncoder(device)
            nomic_text_encoder = NomicTextEncoder(device)
            ref_labels, ref_embs = build_refs_nomic(
                nomic_encoder, refs_dir, batch_size=16,
                text_encoder=nomic_text_encoder, text_weight=0.3,
            )
            _APP_NOMIC = (nomic_encoder, ref_labels, ref_embs)
            _APP_REFS_NOMIC = str(refs_dir)
        nomic_encoder, ref_labels, ref_embs = _APP_NOMIC

    # 5) Optionally squarify each kept crop, classify it, and paste the
    # labeled crop back onto a copy of the original image.
    lines = []
    out_img = pil.copy()
    for i, (expanded_box, d, gidx, crop_idx) in enumerate(kept):
        if squarify:
            bx1, by1, bx2, by2 = squarify_crop_box(
                expanded_box[0], expanded_box[1], expanded_box[2], expanded_box[3], img_w, img_h
            )
        else:
            bx1, by1, bx2, by2 = expanded_box[0], expanded_box[1], expanded_box[2], expanded_box[3]
        crop_pil = pil.crop((bx1, by1, bx2, by2))
        if encoder_choice == "jina":
            q = jina_encoder.encode_images([crop_pil], TRUNCATE_DIM)
            result = jina_classify(q, ref_labels, ref_embs, conf_threshold, gap_threshold)
        else:
            # Nomic embeddings are scored with the same dual-threshold classifier.
            q = nomic_encoder.encode_images([crop_pil])
            result = jina_classify(q, ref_labels, ref_embs, conf_threshold, gap_threshold)
        pred = result["prediction"] if result["prediction"] in ref_labels else f"unknown ({d['label']})"
        conf = result["confidence"]
        lines.append(f"Crop {i+1}: {pred} ({conf:.2f})")
        # draw_label_on_image adds a label bar, so the pasted crop is taller
        # than the region it came from; it may overlap neighboring content.
        labeled = draw_label_on_image(crop_pil, pred, conf)
        out_img.paste(labeled, (bx1, by1))

    result_text = "\n".join(lines) if lines else "No crops"
    return np.array(out_img), result_text
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
# CLI entry point: run the batch pipeline defined in main().
if __name__ == "__main__":
    main()
jina_fewshot.py
ADDED
|
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Few-shot object classification using jina-clip-v2 (EVA02-L, 304M).
|
| 3 |
+
|
| 4 |
+
Combines IMAGE embeddings from reference photos + TEXT embeddings
|
| 5 |
+
from class names. Dual threshold: confidence + gap between top-1 and top-2.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python jina_fewshot.py \
|
| 9 |
+
--refs refs/ \
|
| 10 |
+
--input crops/ \
|
| 11 |
+
--output results/ \
|
| 12 |
+
--text-weight 0.3 \
|
| 13 |
+
--conf-threshold 0.75 \
|
| 14 |
+
--gap-threshold 0.05
|
| 15 |
+
|
| 16 |
+
refs/ folder structure (3-10 images per class recommended):
|
| 17 |
+
refs/
|
| 18 |
+
├── cigarette/
|
| 19 |
+
├── gun/
|
| 20 |
+
├── knife/
|
| 21 |
+
├── phone/
|
| 22 |
+
└── nothing/ (empty hands, random objects)
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import argparse
|
| 26 |
+
import csv
|
| 27 |
+
import json
|
| 28 |
+
import time
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
|
| 31 |
+
import numpy as np
|
| 32 |
+
import torch
|
| 33 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 34 |
+
from transformers import AutoModel
|
| 35 |
+
|
| 36 |
+
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tiff"}
|
| 37 |
+
TRUNCATE_DIM = 1024
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _to_numpy(embs):
|
| 41 |
+
"""Convert to numpy; Jina may return tensor on some code paths."""
|
| 42 |
+
if hasattr(embs, "cpu"):
|
| 43 |
+
embs = embs.cpu().float().numpy()
|
| 44 |
+
return np.asarray(embs, dtype=np.float64)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def draw_label_on_image(img: Image.Image, label: str, confidence: float) -> Image.Image:
    """Return a new image with a full-width black label bar stacked on top.

    The bar shows "<label> (<confidence>)" centered in white text; the
    original image is pasted unchanged below the bar.
    """
    img = img.convert("RGB")
    w, h = img.width, img.height
    text = f"{label} ({confidence:.2f})"
    margin = 8
    max_text_w = max(1, w - 2 * margin)

    font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
    try:
        font_size = max(10, min(h, w) // 12)
        font = ImageFont.truetype(font_path, size=font_size)
    except OSError:
        # DejaVu not installed: fall back to PIL's built-in bitmap font,
        # which has a fixed size (font_size=None disables the shrink loop).
        font = ImageFont.load_default()
        font_size = None

    # Measure the text on a throwaway 1x1 canvas.
    measurer = ImageDraw.Draw(Image.new("RGB", (1, 1)))
    box = measurer.textbbox((0, 0), text, font=font)
    tw, th = box[2] - box[0], box[3] - box[1]
    # Shrink the truetype font in steps of 2 until the text fits the width.
    if font_size is not None:
        while tw > max_text_w and font_size > 8:
            font_size = max(8, font_size - 2)
            font = ImageFont.truetype(font_path, size=font_size)
            box = measurer.textbbox((0, 0), text, font=font)
            tw, th = box[2] - box[0], box[3] - box[1]
    bar_height = th + 2 * margin

    # Compose: black bar with centered white text on top, image below.
    canvas = Image.new("RGB", (w, bar_height + h), color=(255, 255, 255))
    painter = ImageDraw.Draw(canvas)
    painter.rectangle([0, 0, w, bar_height], fill=(0, 0, 0))
    painter.text(((w - tw) // 2, margin), text, fill=(255, 255, 255), font=font)
    canvas.paste(img, (0, bar_height))
    return canvas
|
| 85 |
+
|
| 86 |
+
# Per-class text prompts. These are encoded and averaged into each class
# prototype alongside the reference-image embeddings (weighted by
# --text-weight); classes without an entry fall back to generic
# "a {name}" / "a person holding a {name}" prompts.
CLASS_PROMPTS = {
    "knife": [
        "a knife",
        "a person holding a knife",
        "a sharp blade knife",
    ],
    "gun": [
        "a gun",
        "a pistol",
        "a handgun",
        "a person holding a gun",
        "a person holding a pistol",
        "a firearm weapon",
    ],
    "cigarette": [
        "a cigarette",
        "a person smoking a cigarette",
        "a lit cigarette in hand",
    ],
    "phone": [
        "a phone",
        "a person holding a smartphone",
        "a mobile phone cell phone",
    ],
    "nothing": [
        "a person with empty hands",
        "a person standing with no objects",
        "empty hands no weapon",
    ],
}
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def parse_args():
    """Define and parse the CLI arguments for the few-shot classifier."""
    parser = argparse.ArgumentParser(description="Jina-CLIP-v2 few-shot classifier")
    # Required input/output locations.
    parser.add_argument("--refs", required=True, help="Reference images folder")
    parser.add_argument("--input", required=True, help="Query crop images folder")
    parser.add_argument("--output", default="jinaclip_results", help="Output folder")
    # Embedding and scoring knobs.
    parser.add_argument("--dim", type=int, default=TRUNCATE_DIM,
                        help="Embedding dim (64-1024)")
    parser.add_argument("--text-weight", type=float, default=0.3,
                        help="Text embedding weight (0.0=image only, default 0.3)")
    parser.add_argument("--conf-threshold", type=float, default=0.75,
                        help="Min confidence to accept prediction (default 0.75)")
    parser.add_argument("--gap-threshold", type=float, default=0.05,
                        help="Min gap between top-1 and top-2 (default 0.05)")
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--save-refs", action="store_true",
                        help="Save reference embeddings to .npy for fast reload")
    return parser.parse_args()
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class JinaCLIPv2Encoder:
    """Wrapper around jinaai/jina-clip-v2 for image/text embeddings.

    All encode_* methods return L2-normalized float32 numpy arrays with
    Matryoshka truncation to `dim`, so cosine similarity reduces to a dot
    product. NaN/inf values from the model are zeroed defensively.
    """

    def __init__(self, device="cuda"):
        """Load the model once and move it to `device` ("cuda" or "cpu")."""
        self.device = device
        print("[*] Loading jina-clip-v2...")
        t0 = time.perf_counter()
        # On HF Spaces, accelerate is pre-installed and transformers uses its meta-device
        # context during init, so set_default_device("cpu") is overridden. Jina's
        # eva_model.py:606 does torch.linspace(0, drop_path_rate, depth).item() and
        # crashes on meta tensors. Monkey-patch linspace to force device="cpu" for init.
        _orig_linspace = torch.linspace

        def _safe_linspace(*args, **kwargs):
            kwargs.pop("device", None)
            return _orig_linspace(*args, **kwargs, device="cpu")

        torch.linspace = _safe_linspace
        try:
            self.model = AutoModel.from_pretrained(
                "jinaai/jina-clip-v2",
                trust_remote_code=True,
                low_cpu_mem_usage=False,
                revision="main",
            )
        finally:
            # Always restore the real torch.linspace, even if loading fails.
            torch.linspace = _orig_linspace
        self.model = self.model.to(device).eval()
        if device == "cpu":
            # Half-precision weights on CPU can yield zeros/NaNs; force fp32.
            self.model = self.model.float()
        print(f"[*] Loaded in {time.perf_counter() - t0:.1f}s (device={device})\n")

    def encode_images(self, images: list[Image.Image], dim: int = TRUNCATE_DIM) -> np.ndarray:
        """Encode PIL images -> (N, dim) L2-normalized float32 embeddings."""
        rgb = [img.convert("RGB") for img in images]
        with torch.no_grad():
            embs = self.model.encode_image(rgb, truncate_dim=dim)
        embs = _to_numpy(embs)
        embs = np.nan_to_num(embs, nan=0.0, posinf=0.0, neginf=0.0)
        norms = np.linalg.norm(embs, axis=-1, keepdims=True)
        norms = np.maximum(norms, 1e-12)  # guard against division by zero
        return (embs / norms).astype(np.float32)

    def encode_texts(self, texts: list[str], dim: int = TRUNCATE_DIM) -> np.ndarray:
        """Encode text strings -> (N, dim) L2-normalized float32 embeddings."""
        with torch.no_grad():
            embs = self.model.encode_text(texts, truncate_dim=dim)
        embs = _to_numpy(embs)
        embs = np.nan_to_num(embs, nan=0.0, posinf=0.0, neginf=0.0)
        norms = np.linalg.norm(embs, axis=-1, keepdims=True)
        norms = np.maximum(norms, 1e-12)  # guard against division by zero
        return (embs / norms).astype(np.float32)

    def encode_image_paths(self, paths: list[str], dim: int = TRUNCATE_DIM,
                           batch_size: int = 16) -> np.ndarray:
        """Encode image files in batches -> (N, dim) embeddings.

        Fix: the original opened files with a bare list comprehension and
        never closed them, leaking file descriptors. Each file is now opened
        in a context manager; convert("RGB") copies the pixel data, so the
        resulting image remains usable after the file is closed.
        """
        all_embs = []
        for i in range(0, len(paths), batch_size):
            batch = []
            for p in paths[i:i + batch_size]:
                with Image.open(p) as im:
                    batch.append(im.convert("RGB"))
            all_embs.append(self.encode_images(batch, dim))
        return np.concatenate(all_embs, axis=0)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def build_refs(encoder: JinaCLIPv2Encoder, refs_dir: Path,
               dim: int, text_weight: float, batch_size: int):
    """Build one prototype embedding per class subfolder of refs_dir.

    Each prototype is (1 - text_weight) * mean(image embeddings)
    + text_weight * mean(text-prompt embeddings), re-normalized to unit
    length. Returns (labels, (C, dim) stacked embedding matrix).
    Raises ValueError when refs_dir contains no class subfolders.
    NOTE(review): if every class folder is empty, labels stays empty and the
    final np.stack([]) raises — assumes at least one class has images.
    """
    class_dirs = sorted(d for d in refs_dir.iterdir() if d.is_dir())
    if not class_dirs:
        raise ValueError(f"No subfolders in {refs_dir}")

    labels, embeddings = [], []
    _device = getattr(encoder, "device", "?")
    print(f" Device: {_device} | Text weight: {text_weight:.1f} | Image weight: {1 - text_weight:.1f}\n")
    if _device == "cpu":
        print(" [WARNING] Jina is on CPU. Ref embeddings are often all zeros on CPU. Use a Space with GPU (e.g. T4) for D-FINE + Classify.\n")

    for d in class_dirs:
        name = d.name
        paths = sorted(str(p) for p in d.iterdir() if p.suffix.lower() in IMAGE_EXTS)
        if not paths:
            # Skip classes that have no usable reference images.
            continue

        # Image embeddings: mean of all reference photos for this class.
        img_embs = encoder.encode_image_paths(paths, dim, batch_size)
        img_avg = np.nan_to_num(img_embs.mean(axis=0), nan=0.0, posinf=0.0, neginf=0.0)

        # Text embeddings: class-specific prompts, or generic fallbacks.
        prompts = CLASS_PROMPTS.get(name, [f"a {name}", f"a person holding a {name}"])
        text_embs = encoder.encode_texts(prompts, dim)
        text_avg = np.nan_to_num(text_embs.mean(axis=0), nan=0.0, posinf=0.0, neginf=0.0)

        # Combine image and text prototypes, then re-normalize to unit length.
        combined = (1.0 - text_weight) * img_avg + text_weight * text_avg
        combined = np.nan_to_num(combined, nan=0.0, posinf=0.0, neginf=0.0)
        combined = combined / (np.linalg.norm(combined) + 1e-12)

        labels.append(name)
        embeddings.append(combined)

        # Diagnostic: cosine similarity between image and text prototypes.
        img_norm = img_avg / (np.linalg.norm(img_avg) + 1e-12)
        text_norm = text_avg / (np.linalg.norm(text_avg) + 1e-12)
        sim = float(np.nan_to_num(np.dot(img_norm, text_norm), nan=0.0))
        print(f" {name:<14}: {len(paths)} imgs + {len(prompts)} prompts | "
              f"img-text sim: {sim:.4f}")
    # All-zero prototypes indicate the encoder silently failed (observed on CPU).
    if labels and np.allclose(np.stack(embeddings), 0.0):
        print("\n [WARNING] All ref embeddings are zero. Jina-CLIP often returns zeros on CPU. "
              "Use a Space with GPU (e.g. T4) for D-FINE + Classify to work correctly.")

    return labels, np.stack(embeddings)
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def classify(query_emb: np.ndarray, ref_labels: list[str], ref_embs: np.ndarray,
             conf_threshold: float, gap_threshold: float) -> dict:
    """Classify one query embedding against class prototypes with a dual threshold.

    A prediction is accepted only when the top-1 cosine similarity is at
    least conf_threshold AND the margin (top-1 minus top-2) is at least
    gap_threshold; otherwise the prediction is "unknown".

    query_emb: (1, D) normalized embedding. ref_labels / ref_embs: C labels
    and a (C, D) normalized prototype matrix.
    Returns a dict with prediction, raw_prediction, confidence, gap,
    runner-up info, a status string, and per-class similarities.

    Fix: the original indexed sorted_idx[1] unconditionally and raised
    IndexError for a single-class reference set; with one class the runner-up
    is now reported as ("", 0.0) and the gap equals the top confidence.
    """
    sims = (query_emb @ ref_embs.T).squeeze(0)
    sims = np.nan_to_num(sims.astype(np.float64), nan=0.0, posinf=0.0, neginf=0.0)
    sims = np.atleast_1d(sims)  # keep indexing valid when C == 1
    sorted_idx = np.argsort(sims)[::-1]

    best_idx = sorted_idx[0]
    conf = float(sims[best_idx])
    if len(sorted_idx) > 1:
        second_idx = sorted_idx[1]
        second_label = ref_labels[second_idx]
        second_conf = float(sims[second_idx])
    else:
        # Single-class reference set: no runner-up exists.
        second_label = ""
        second_conf = 0.0
    gap = conf - second_conf

    # Dual threshold
    conf_ok = conf >= conf_threshold
    gap_ok = gap >= gap_threshold

    if conf_ok and gap_ok:
        prediction = ref_labels[best_idx]
        status = "accepted"
    else:
        prediction = "unknown"
        reasons = []
        if not conf_ok:
            reasons.append(f"conf {conf:.4f} < {conf_threshold}")
        if not gap_ok:
            reasons.append(f"gap {gap:.4f} < {gap_threshold}")
        status = "rejected: " + ", ".join(reasons)

    return {
        "prediction": prediction,
        "raw_prediction": ref_labels[best_idx],
        "confidence": conf,
        "gap": gap,
        "second_best": second_label,
        "second_conf": second_conf,
        "status": status,
        "all_sims": {ref_labels[j]: float(sims[j]) for j in range(len(ref_labels))},
    }
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def main():
    """CLI entry point: build class refs, classify every crop in --input,
    write classifications.csv and annotated images to --output, and print a
    per-image table plus a summary.

    Fixes vs. original: CSV is written through a context manager (the bare
    open()/close() pair leaked the handle on any exception), query images are
    closed promptly, and the encoder falls back to CPU instead of crashing
    when CUDA is unavailable.
    """
    args = parse_args()
    input_dir, output_dir = Path(args.input), Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    paths = sorted(p for p in input_dir.iterdir() if p.suffix.lower() in IMAGE_EXTS)
    if not paths:
        print(f"[!] No images in {input_dir}")
        return

    print(f"[*] {len(paths)} query images")
    print(f"[*] Conf threshold: {args.conf_threshold} | Gap threshold: {args.gap_threshold}\n")

    # Fall back to CPU when CUDA is unavailable (original hard-coded "cuda").
    encoder = JinaCLIPv2Encoder("cuda" if torch.cuda.is_available() else "cpu")

    # Build references
    print("[*] Building references...")
    ref_labels, ref_embs = build_refs(
        encoder, Path(args.refs), args.dim, args.text_weight, args.batch_size
    )
    print(f"\n[*] {len(ref_labels)} classes: {ref_labels}\n")

    # Save refs if requested (fast reload by downstream tools)
    if args.save_refs:
        np.save(output_dir / "ref_embeddings.npy", ref_embs)
        with open(output_dir / "ref_labels.json", "w") as jf:
            json.dump(ref_labels, jf)
        print(f"[*] Saved refs to {output_dir}\n")

    # CSV — context manager guarantees the file is closed even on errors.
    csv_path = output_dir / "classifications.csv"
    with open(csv_path, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow(["image", "prediction", "raw_prediction", "confidence", "gap",
                    "second_best", "second_conf", "status"] +
                   [f"sim_{l}" for l in ref_labels] + ["time_ms"])

        # Stats
        times = []
        counts = {"unknown": 0}
        for l in ref_labels:
            counts[l] = 0
        accepted, rejected = 0, 0

        # Header for the per-image console table
        hdr = " ".join(f"{l:>10}" for l in ref_labels)
        print(f"{'Image':<30} {'Result':<10} {'Conf':>6} {'Gap':>6} {hdr} {'Status'}")
        print("=" * (30 + 10 + 14 + len(hdr) + 40))

        # Classify each crop
        for p in paths:
            t0 = time.perf_counter()
            with Image.open(p) as img:
                q = encoder.encode_images([img], args.dim)
                ms = (time.perf_counter() - t0) * 1000
                times.append(ms)

                result = classify(q, ref_labels, ref_embs, args.conf_threshold, args.gap_threshold)
                counts[result["prediction"]] += 1

                if result["prediction"] != "unknown":
                    accepted += 1
                else:
                    rejected += 1

                # Draw label on image and save to output folder
                annotated = draw_label_on_image(img, result["prediction"], result["confidence"])
            out_path = output_dir / p.name
            annotated.save(out_path)

            sim_str = " ".join(f"{result['all_sims'][l]:>10.4f}" for l in ref_labels)
            print(f"{p.name:<30} {result['prediction']:<10} "
                  f"{result['confidence']:>6.4f} {result['gap']:>6.4f} "
                  f"{sim_str} {result['status']}")

            w.writerow([
                p.name,
                result["prediction"],
                result["raw_prediction"],
                f"{result['confidence']:.4f}",
                f"{result['gap']:.4f}",
                result["second_best"],
                f"{result['second_conf']:.4f}",
                result["status"],
            ] + [f"{result['all_sims'][l]:.4f}" for l in ref_labels] +
                [f"{ms:.1f}"])

    # Summary
    n = len(times)
    total = sum(times)
    print(f"\n{'='*70}")
    print("SUMMARY")
    print(f"{'='*70}")
    print(f" Model : jina-clip-v2 (EVA02-L, 304M, CLS pooling)")
    print(f" Embed dim : {args.dim}")
    print(f" Text weight : {args.text_weight}")
    print(f" Conf threshold : {args.conf_threshold}")
    print(f" Gap threshold : {args.gap_threshold}")
    print(f" Images : {n}")
    if n:
        print(f" Accepted : {accepted} ({accepted/n*100:.1f}%)")
        print(f" Rejected : {rejected} ({rejected/n*100:.1f}%)")
    print(f" ──────────────────────────────────────────")
    for l in ref_labels + ["unknown"]:
        c = counts.get(l, 0)
        pct = (c / n * 100) if n else 0
        print(f" {l:<14}: {c:>4} ({pct:.1f}%)")
    print(f" ──────────────────────────────────────────")
    if n:
        print(f" Total : {total:.0f}ms ({total/1000:.2f}s)")
        print(f" Avg/image : {total/n:.1f}ms")
        print(f" Throughput : {n/(total/1000):.1f} img/s")
    print(f" CSV : {csv_path}")
    print(f" Annotated imgs : {output_dir}")
    print(f"{'='*70}")
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
# Allow running this module directly as a CLI script.
if __name__ == "__main__":
    main()
|
models/README.md
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Model versions
|
| 2 |
+
|
| 3 |
+
- **v1** — `v1/best.pt` (current)
|
| 4 |
+
- **v2** — add `v2/best.pt` and a new tab in `app.py` when ready
|
| 5 |
+
|
| 6 |
+
Each version folder should contain `best.pt` (or the weights file used by the app).
|
models/v1/best.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7ce9c6f1f6193256572eae61176c486e52a11aa7f5885778aa8f3a445e04d1e5
|
| 3 |
+
size 44256473
|
nomic_fewshot.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Few-shot object classification using Nomic embed-vision-v1.5 + embed-text-v1.5.
|
| 3 |
+
|
| 4 |
+
Same treatment as Jina: image refs + text prompts, combined with text_weight (default 0.3).
|
| 5 |
+
Used by dfine_jina_pipeline.py and tune_thresholds.py for Nomic crop classification.
|
| 6 |
+
"""
|
| 7 |
+
import time
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from PIL import Image
|
| 14 |
+
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer
|
| 15 |
+
from transformers import modeling_utils
|
| 16 |
+
|
| 17 |
+
from jina_fewshot import CLASS_PROMPTS, IMAGE_EXTS
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _patch_tied_weights_for_nomic():
    """Shim ``PreTrainedModel.mark_tied_weights_as_initialized`` so Nomic models load.

    NomicVisionModel (loaded via ``trust_remote_code``) still sets the legacy
    ``_tied_weights_keys`` attribute, while newer transformers versions expect
    ``all_tied_weights_keys``. The patched method backfills the new attribute from
    the legacy one (defaulting to ``{}``) before delegating to the original.

    No-ops on older transformers where the method does not exist, and is
    idempotent: repeated calls (one per encoder instantiation) no longer stack
    additional wrappers around the previous patch.
    """
    if not hasattr(modeling_utils.PreTrainedModel, "mark_tied_weights_as_initialized"):
        return
    # Guard against double-patching: without this, every NomicVisionEncoder
    # __init__ would wrap the already-wrapped method again, growing the call
    # chain on each instantiation.
    current = modeling_utils.PreTrainedModel.mark_tied_weights_as_initialized
    if getattr(current, "_nomic_tied_weights_patch", False):
        return
    _orig = current

    def _patched(self, loading_info):
        if not hasattr(self, "all_tied_weights_keys"):
            # Fall back to the legacy attribute; an empty dict means "no tied weights".
            self.all_tied_weights_keys = getattr(self, "_tied_weights_keys", None) or {}
        return _orig(self, loading_info)

    _patched._nomic_tied_weights_patch = True  # sentinel checked above
    modeling_utils.PreTrainedModel.mark_tied_weights_as_initialized = _patched
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _nomic_mean_pool(last_hidden_state, attention_mask):
|
| 36 |
+
mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
|
| 37 |
+
return torch.sum(last_hidden_state * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class NomicTextEncoder:
    """Nomic embed-text-v1.5: text → normalized embedding (aligned to vision space)."""

    def __init__(self, device="cuda"):
        """Load tokenizer and model from the Hub and move the model to `device`.

        The model is loaded with the default torch device forced to CPU (see
        inline comment) and put in eval mode; `device` is where inference runs.
        """
        self.device = device
        print("[*] Loading nomic-embed-text-v1.5...")
        t0 = time.perf_counter()
        self.tokenizer = AutoTokenizer.from_pretrained("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)
        # Force CPU as the default device during from_pretrained so the
        # trust_remote_code model materializes on CPU before the explicit
        # .to(device) below.
        if hasattr(torch, "set_default_device"):
            torch.set_default_device("cpu")
        try:
            self.model = AutoModel.from_pretrained(
                "nomic-ai/nomic-embed-text-v1.5",
                trust_remote_code=True,
                # low_cpu_mem_usage=False avoids meta-device init paths that the
                # remote-code model may not support.
                low_cpu_mem_usage=False,
            )
        finally:
            # NOTE(review): this re-sets the default to "cpu" rather than
            # restoring whatever it was before the try — presumably intentional
            # (keep CPU as the process default); confirm if a CUDA default is
            # ever expected elsewhere.
            if hasattr(torch, "set_default_device"):
                torch.set_default_device("cpu")
        self.model = self.model.to(device).eval()
        print(f"[*] Loaded in {time.perf_counter() - t0:.1f}s\n")

    def encode_texts(self, texts: list[str]) -> np.ndarray:
        """Embed `texts` and return L2-normalized float32 embeddings, shape (len(texts), dim).

        Each text is prefixed with "classification: " — the task prefix the
        Nomic text model expects for classification-style embeddings.
        """
        prefixed = [f"classification: {t}" for t in texts]
        inputs = self.tokenizer(prefixed, padding=True, truncation=True, return_tensors="pt", max_length=512)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            out = self.model(**inputs)
            # Masked mean-pool over tokens, then unit-normalize rows.
            embs = _nomic_mean_pool(out.last_hidden_state, inputs["attention_mask"])
            embs = F.normalize(embs, p=2, dim=1)
        return embs.cpu().float().numpy()
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class NomicVisionEncoder:
    """Nomic embed-vision-v1.5: image → normalized CLS embedding."""

    def __init__(self, device="cuda"):
        """Load image processor and model from the Hub and move the model to `device`.

        Applies the tied-weights compatibility patch before loading (the
        remote-code vision model needs it on newer transformers), and forces
        CPU as the default torch device during from_pretrained.
        """
        self.device = device
        print("[*] Loading nomic-embed-vision-v1.5...")
        t0 = time.perf_counter()
        self.processor = AutoImageProcessor.from_pretrained("nomic-ai/nomic-embed-vision-v1.5")
        # Compatibility shim for newer transformers (see helper's docstring).
        _patch_tied_weights_for_nomic()
        # Force CPU as the default device while materializing the model; the
        # explicit .to(device) below moves it afterwards.
        if hasattr(torch, "set_default_device"):
            torch.set_default_device("cpu")
        try:
            self.model = AutoModel.from_pretrained(
                "nomic-ai/nomic-embed-vision-v1.5",
                trust_remote_code=True,
                # low_cpu_mem_usage=False avoids meta-device init paths that the
                # remote-code model may not support.
                low_cpu_mem_usage=False,
            )
        finally:
            # NOTE(review): re-sets the default to "cpu" rather than restoring
            # the prior default — presumably intentional; mirrors NomicTextEncoder.
            if hasattr(torch, "set_default_device"):
                torch.set_default_device("cpu")
        self.model = self.model.to(device).eval()
        print(f"[*] Loaded in {time.perf_counter() - t0:.1f}s\n")

    def encode_images(self, images: list) -> np.ndarray:
        """Encode images to L2-normalized embeddings (CLS token).

        `images` is a list of PIL images (or anything the processor accepts);
        returns a float32 array of shape (len(images), dim).
        """
        inputs = self.processor(images=images, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            out = self.model(**inputs).last_hidden_state
            # CLS token, then normalize
            embs = F.normalize(out[:, 0], p=2, dim=1)
        return embs.cpu().float().numpy()
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def build_refs_nomic(
    encoder: NomicVisionEncoder,
    refs_dir: Path,
    batch_size: int = 16,
    text_encoder: NomicTextEncoder | None = None,
    text_weight: float = 0.3,
):
    """Build one reference embedding per class subfolder of `refs_dir`.

    Same treatment as Jina: the per-class embedding is the mean of the image
    reference embeddings, optionally blended with the mean of text-prompt
    embeddings at `text_weight` (image weight is ``1 - text_weight``), then
    re-normalized to unit length.

    Args:
        encoder: vision encoder used for the reference crops.
        refs_dir: directory whose immediate subfolders are class names, each
            containing reference images (extensions in IMAGE_EXTS).
        batch_size: images encoded per forward pass.
        text_encoder: optional text encoder; when given, class prompts from
            CLASS_PROMPTS (or generic fallbacks) are blended in.
        text_weight: blend weight for the text embedding, default 0.3.

    Returns:
        (labels, embeddings): list of class names and a (num_classes, dim)
        array of unit-norm embeddings, in the same order.

    Raises:
        ValueError: if `refs_dir` has no subfolders, or no subfolder contains
            any usable reference image.
    """
    class_dirs = sorted(d for d in refs_dir.iterdir() if d.is_dir())
    if not class_dirs:
        raise ValueError(f"No subfolders in {refs_dir}")
    labels = []
    embeddings = []
    if text_encoder is not None:
        print(f" Text weight: {text_weight:.1f} | Image weight: {1 - text_weight:.1f}\n")
    for d in class_dirs:
        name = d.name
        paths = sorted(str(p) for p in d.iterdir() if p.suffix.lower() in IMAGE_EXTS)
        if not paths:
            continue  # class folder with no usable images — skip it
        all_embs = []
        for i in range(0, len(paths), batch_size):
            batch = [Image.open(p).convert("RGB") for p in paths[i : i + batch_size]]
            all_embs.append(encoder.encode_images(batch))
        img_avg = np.concatenate(all_embs, axis=0).mean(axis=0)
        if text_encoder is not None:
            prompts = CLASS_PROMPTS.get(name, [f"a {name}", f"a person holding a {name}"])
            text_avg = text_encoder.encode_texts(prompts).mean(axis=0)
            combined = (1.0 - text_weight) * img_avg + text_weight * text_avg
            print(f" {name:<14}: {len(paths)} imgs + {len(prompts)} prompts")
        else:
            combined = img_avg
            print(f" {name:<14}: {len(paths)} imgs")
        # Means of unit vectors are not unit-length; re-normalize (epsilon
        # guards against a zero vector).
        combined = combined / (np.linalg.norm(combined) + 1e-12)
        labels.append(name)
        embeddings.append(combined)
    if not embeddings:
        # Previously np.stack([]) raised an opaque "need at least one array"
        # error; fail with an actionable message instead.
        raise ValueError(f"No reference images found in any subfolder of {refs_dir}")
    return labels, np.stack(embeddings)
|
refs/cigarette/c2.png
ADDED
|
Git LFS Details
|
refs/cigarette/c3.png
ADDED
|
Git LFS Details
|
refs/cigarette/c4.png
ADDED
|
Git LFS Details
|
refs/cigarette/c5.png
ADDED
|
Git LFS Details
|
refs/cigarette/c6.png
ADDED
|
Git LFS Details
|
refs/cigarette/c7.png
ADDED
|
Git LFS Details
|
refs/cigarette/c9.png
ADDED
|
Git LFS Details
|
refs/cigarette/cigarette.jpg
ADDED
|
Git LFS Details
|
refs/gun/g1.png
ADDED
|
Git LFS Details
|
refs/gun/g2.png
ADDED
|
Git LFS Details
|
refs/gun/g3.png
ADDED
|
Git LFS Details
|
refs/gun/g4.png
ADDED
|
Git LFS Details
|
refs/gun/g5.png
ADDED
|
Git LFS Details
|
refs/gun/g6.png
ADDED
|
Git LFS Details
|
refs/gun/g7.png
ADDED
|
Git LFS Details
|
refs/gun/g8.png
ADDED
|
Git LFS Details
|
refs/gun/g9.png
ADDED
|
Git LFS Details
|
refs/gun/pistol.jpeg
ADDED
|
Git LFS Details
|
refs/knife/k1.png
ADDED
|
Git LFS Details
|
refs/knife/k2.png
ADDED
|
Git LFS Details
|
refs/knife/k3.png
ADDED
|
Git LFS Details
|
refs/knife/k4.png
ADDED
|
Git LFS Details
|
refs/knife/k5.png
ADDED
|
Git LFS Details
|
refs/knife/k6.png
ADDED
|
Git LFS Details
|
refs/knife/k7.png
ADDED
|
Git LFS Details
|
refs/knife/k8.png
ADDED
|
Git LFS Details
|
refs/knife/k9.png
ADDED
|
Git LFS Details
|
refs/knife/knife.jpeg
ADDED
|
Git LFS Details
|
refs/phone/p1.png
ADDED
|
Git LFS Details
|
refs/phone/p2.png
ADDED
|
Git LFS Details
|
refs/phone/p3.png
ADDED
|
Git LFS Details
|
refs/phone/p4.png
ADDED
|
Git LFS Details
|
refs/phone/p5.png
ADDED
|
Git LFS Details
|
refs/phone/p6.png
ADDED
|
Git LFS Details
|
refs/phone/p7.png
ADDED
|
Git LFS Details
|
refs/phone/p8.png
ADDED
|
Git LFS Details
|
refs/phone/p9.jpg
ADDED
|
Git LFS Details
|
refs/phone/phone.jpg
ADDED
|
Git LFS Details
|
requirements-lock.txt
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Pinned versions for Docker so Space matches local. Regenerate from your venv with:
|
| 2 |
+
# pip freeze > requirements-lock.txt
|
| 3 |
+
# Then rebuild the Docker image.
|
| 4 |
+
gradio==6.0.0
|
| 5 |
+
ultralytics==8.3.0
|
| 6 |
+
torch==2.2.2
|
| 7 |
+
torchvision==0.17.2
|
| 8 |
+
transformers==4.44.2
|
| 9 |
+
accelerate==0.33.0
|
| 10 |
+
pillow>=9.0.0
|
| 11 |
+
numpy>=1.24.0
|
| 12 |
+
huggingface_hub>=0.20.0
|
| 13 |
+
matplotlib>=3.5.0
|
| 14 |
+
requests>=2.28.0
|
| 15 |
+
einops>=0.7.0
|
| 16 |
+
timm>=0.9.0
|
| 17 |
+
sentencepiece>=0.1.99
|