# small_object_detection / siglip_zeroshot.py
# Added siglip multiple res models (commit 33708c6)
"""
SigLIP zero-shot classifier for crop classification.
Uses google/siglip-base-patch16-224 via PyTorch.
Zero-shot: text prompts only, no reference images needed (folder names used for class labels).
"""
import time
from pathlib import Path
import numpy as np
import torch
from transformers import SiglipModel, AutoProcessor
# Registry of supported SigLIP checkpoints: short key -> HuggingFace model id.
# Variants differ only in input resolution (224/256/384 px square crops).
SIGLIP_MODELS = {
"siglip-224": "google/siglip-base-patch16-224",
"siglip-256": "google/siglip-base-patch16-256",
"siglip-384": "google/siglip-base-patch16-384",
}
class SigLIPClassifier:
    """Zero-shot crop classifier using SigLIP (PyTorch).

    Labels are plain text class names (e.g. taken from folder names); no
    reference images are needed — classification is image/text similarity only.
    """

    def __init__(self, device="cuda", model_key="siglip-224"):
        """Load a SigLIP model + processor.

        Args:
            device: torch device string the model is moved to ("cuda", "cpu", ...).
            model_key: key into SIGLIP_MODELS; an unknown key is passed through
                unchanged so a raw HuggingFace model id also works.
        """
        model_id = SIGLIP_MODELS.get(model_key, model_key)
        print(f"[*] Loading SigLIP ({model_id})...")
        t0 = time.perf_counter()
        self.device = device
        self.model_key = model_key
        self.model = SiglipModel.from_pretrained(model_id)
        self.model = self.model.to(device).eval()
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.labels = []  # class names; must be populated before classify_crop()
        print(f"[*] SigLIP loaded in {time.perf_counter() - t0:.1f}s (device={device})")

    def set_labels(self, labels):
        """Set class labels directly from an iterable of strings.

        Raises:
            ValueError: if the resulting label list is empty.
        """
        self.labels = list(labels)
        if not self.labels:
            raise ValueError("No labels provided")
        print(f" SigLIP labels: {self.labels}")

    def build_refs(self, refs_dir=None, labels=None, **kwargs):
        """Set labels from a list or extract them from refs_dir subfolders.

        Args:
            refs_dir: directory whose immediate subfolder names become labels
                (sorted alphabetically). Used only when `labels` is falsy.
            labels: explicit list of label strings; takes precedence.
            **kwargs: ignored; kept for interface parity with other classifiers.

        Raises:
            ValueError: if neither labels nor refs_dir is provided, or if the
                resolved label list is empty.
        """
        if labels:
            self.set_labels(labels)
        elif refs_dir:
            refs_dir = Path(refs_dir)
            self.set_labels(sorted(d.name for d in refs_dir.iterdir() if d.is_dir()))
        else:
            raise ValueError("Provide either labels or refs_dir")

    def classify_crop(self, crop, conf_threshold, gap_threshold):
        """Classify a single crop image using zero-shot SigLIP.

        Args:
            crop: image accepted by the processor's `images` argument (e.g. PIL).
            conf_threshold: minimum sigmoid probability to accept the top label.
            gap_threshold: NOTE(review) — accepted for signature parity with the
                few-shot classifier but not used for rejection here; the top-1
                vs top-2 gap is only reported. Confirm intended behavior against
                jina_fewshot.classify().

        Returns:
            dict matching jina_fewshot.classify() format: prediction,
            raw_prediction, confidence, gap, second_best, second_conf, status,
            all_sims.

        Raises:
            RuntimeError: if labels were never set.
        """
        if not self.labels:
            raise RuntimeError("Labels not set; call set_labels() or build_refs() first")
        inputs = self.processor(
            text=self.labels,
            images=crop,
            return_tensors="pt",
            padding="max_length",  # SigLIP's text tower expects fixed-length padding
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits_per_image
            # SigLIP scores labels with independent sigmoids (not a softmax),
            # so probabilities do not sum to 1 across labels.
            probs = torch.sigmoid(logits).cpu().numpy().squeeze(0)
        probs = np.nan_to_num(probs.astype(np.float64), nan=0.0)
        # squeeze(0) produces a 0-d array when there is a single label; promote
        # to 1-d so indexing below is uniform.
        probs = np.atleast_1d(probs)
        sorted_idx = np.argsort(probs)[::-1]
        best_idx = int(sorted_idx[0])
        conf = float(probs[best_idx])
        # Bug fix: original indexed sorted_idx[1] unconditionally, which raised
        # IndexError whenever only one label was configured.
        if len(sorted_idx) > 1:
            second_idx = int(sorted_idx[1])
            second_best = self.labels[second_idx]
            second_conf = float(probs[second_idx])
        else:
            second_best = None
            second_conf = 0.0
        gap = conf - second_conf
        if conf >= conf_threshold:
            prediction = self.labels[best_idx]
            status = "accepted"
        else:
            prediction = "unknown"
            status = f"rejected: conf {conf:.4f} < {conf_threshold}"
        return {
            "prediction": prediction,
            "raw_prediction": self.labels[best_idx],
            "confidence": conf,
            "gap": gap,
            "second_best": second_best,
            "second_conf": second_conf,
            "status": status,
            "all_sims": {label: float(p) for label, p in zip(self.labels, probs)},
        }