"""
SigLIP zero-shot classifier for crop classification.
Uses google/siglip-base-patch16-224 via PyTorch.
Zero-shot: text prompts only, no reference images needed (folder names are used as class labels).
"""
import time
from pathlib import Path

import numpy as np
import torch
from transformers import SiglipModel, AutoProcessor
# Short alias -> Hugging Face checkpoint id for each supported SigLIP
# base model (one checkpoint per input resolution).
SIGLIP_MODELS = {
    f"siglip-{res}": f"google/siglip-base-patch16-{res}"
    for res in ("224", "256", "384")
}
class SigLIPClassifier:
    """Zero-shot crop classifier using SigLIP (PyTorch).

    Class labels double as the text prompts, so no reference images are
    needed. Output dict matches jina_fewshot.classify() format.
    """

    def __init__(self, device="cuda", model_key="siglip-224"):
        """Load the SigLIP model and processor.

        Args:
            device: torch device string (e.g. "cuda" or "cpu").
            model_key: short key from SIGLIP_MODELS, or any raw HF model id
                (unknown keys fall through unchanged on purpose).
        """
        model_id = SIGLIP_MODELS.get(model_key, model_key)
        print(f"[*] Loading SigLIP ({model_id})...")
        t0 = time.perf_counter()
        self.device = device
        self.model_key = model_key
        self.model = SiglipModel.from_pretrained(model_id)
        # eval() disables dropout etc.; inference only.
        self.model = self.model.to(device).eval()
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.labels = []  # class label strings; set via set_labels()/build_refs()
        print(f"[*] SigLIP loaded in {time.perf_counter() - t0:.1f}s (device={device})")

    def set_labels(self, labels):
        """Set class labels directly from a list of strings.

        Raises:
            ValueError: if *labels* is empty.
        """
        self.labels = list(labels)
        if not self.labels:
            raise ValueError("No labels provided")
        print(f" SigLIP labels: {self.labels}")

    def build_refs(self, refs_dir=None, labels=None, **kwargs):
        """Set labels from a list or extract from refs_dir subfolders.

        Args:
            refs_dir: directory whose immediate subfolder names become labels
                (sorted for determinism). Used only when *labels* is falsy.
            labels: explicit label strings; takes precedence over refs_dir.
            **kwargs: ignored; kept for interface parity with few-shot builders.

        Raises:
            ValueError: if neither labels nor refs_dir is given.
        """
        if labels:
            self.set_labels(labels)
        elif refs_dir:
            refs_dir = Path(refs_dir)
            self.set_labels(sorted(d.name for d in refs_dir.iterdir() if d.is_dir()))
        else:
            raise ValueError("Provide either labels or refs_dir")

    def classify_crop(self, crop, conf_threshold, gap_threshold):
        """
        Classify a single crop image using zero-shot SigLIP.

        Args:
            crop: image in any format the processor accepts (e.g. PIL image).
            conf_threshold: minimum top-1 sigmoid probability to accept.
            gap_threshold: minimum margin over the runner-up to accept.

        Returns dict matching jina_fewshot.classify() format.

        Raises:
            ValueError: if no labels have been configured yet.
        """
        if not self.labels:
            raise ValueError("No labels set; call set_labels() or build_refs() first")
        inputs = self.processor(
            text=self.labels,
            images=crop,
            return_tensors="pt",
            padding="max_length",  # SigLIP text tower expects fixed-length padding
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits_per_image
        # SigLIP scores are independent sigmoids (not a softmax distribution).
        probs = torch.sigmoid(logits).cpu().numpy().squeeze(0)
        probs = np.nan_to_num(probs.astype(np.float64), nan=0.0)
        sorted_idx = np.argsort(probs)[::-1]
        best_idx = int(sorted_idx[0])
        conf = float(probs[best_idx])
        # BUG FIX: with a single label, sorted_idx[1] raised IndexError.
        if len(sorted_idx) > 1:
            second_idx = int(sorted_idx[1])
            second_best = self.labels[second_idx]
            second_conf = float(probs[second_idx])
        else:
            second_best = None
            second_conf = 0.0
        gap = conf - second_conf
        # BUG FIX: gap_threshold was accepted but never applied; reject
        # ambiguous predictions the same way low-confidence ones are.
        if conf < conf_threshold:
            prediction = "unknown"
            status = f"rejected: conf {conf:.4f} < {conf_threshold}"
        elif gap < gap_threshold:
            prediction = "unknown"
            status = f"rejected: gap {gap:.4f} < {gap_threshold}"
        else:
            prediction = self.labels[best_idx]
            status = "accepted"
        return {
            "prediction": prediction,
            "raw_prediction": self.labels[best_idx],
            "confidence": conf,
            "gap": gap,
            "second_best": second_best,
            "second_conf": second_conf,
            "status": status,
            "all_sims": {label: float(probs[j]) for j, label in enumerate(self.labels)},
        }