Spaces:
Sleeping
Sleeping
Added siglip multiple res models
Browse files- app.py +10 -3
- dfine_jina_pipeline.py +7 -5
- siglip_zeroshot.py +12 -4
app.py
CHANGED
|
@@ -11,7 +11,7 @@ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
|
| 11 |
DEFAULT_LABELS = "gun, knife, cigarette, phone"
|
| 12 |
|
| 13 |
|
| 14 |
-
def run_dfine_classify(image, dfine_threshold, dfine_model_choice, siglip_threshold, labels_text):
|
| 15 |
"""D-FINE first, then classify crops with SigLIP.
|
| 16 |
Returns (group_crop_gallery, known_crop_gallery, status_message).
|
| 17 |
"""
|
|
@@ -24,6 +24,7 @@ def run_dfine_classify(image, dfine_threshold, dfine_model_choice, siglip_thresh
|
|
| 24 |
|
| 25 |
dfine_model = dfine_model_choice.strip().lower() if dfine_model_choice else "medium-obj2coco"
|
| 26 |
conf_thresh = float(siglip_threshold)
|
|
|
|
| 27 |
|
| 28 |
group_crops, known_crops, status = run_single_image(
|
| 29 |
image,
|
|
@@ -34,7 +35,7 @@ def run_dfine_classify(image, dfine_threshold, dfine_model_choice, siglip_thresh
|
|
| 34 |
min_side=24,
|
| 35 |
crop_dedup_iou=0.4,
|
| 36 |
min_display_conf=conf_thresh,
|
| 37 |
-
classifier="siglip",
|
| 38 |
labels=labels,
|
| 39 |
)
|
| 40 |
|
|
@@ -73,6 +74,12 @@ with gr.Blocks(title="Small Object Detection") as app:
|
|
| 73 |
label="D-FINE model",
|
| 74 |
)
|
| 75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
dfine_threshold_slider = gr.Slider(
|
| 77 |
minimum=0.05,
|
| 78 |
maximum=0.5,
|
|
@@ -137,7 +144,7 @@ with gr.Blocks(title="Small Object Detection") as app:
|
|
| 137 |
|
| 138 |
btn_dfine.click(
|
| 139 |
fn=run_dfine_classify,
|
| 140 |
-
inputs=[inp_dfine, dfine_threshold_slider, dfine_model_radio, siglip_threshold_slider, labels_input],
|
| 141 |
outputs=[out_gallery_dfine, out_gallery_known, out_status_dfine],
|
| 142 |
concurrency_limit=1,
|
| 143 |
)
|
|
|
|
| 11 |
DEFAULT_LABELS = "gun, knife, cigarette, phone"
|
| 12 |
|
| 13 |
|
| 14 |
+
def run_dfine_classify(image, dfine_threshold, dfine_model_choice, classifier_choice, siglip_threshold, labels_text):
|
| 15 |
"""D-FINE first, then classify crops with SigLIP.
|
| 16 |
Returns (group_crop_gallery, known_crop_gallery, status_message).
|
| 17 |
"""
|
|
|
|
| 24 |
|
| 25 |
dfine_model = dfine_model_choice.strip().lower() if dfine_model_choice else "medium-obj2coco"
|
| 26 |
conf_thresh = float(siglip_threshold)
|
| 27 |
+
classifier = classifier_choice.strip() if classifier_choice else "siglip-224"
|
| 28 |
|
| 29 |
group_crops, known_crops, status = run_single_image(
|
| 30 |
image,
|
|
|
|
| 35 |
min_side=24,
|
| 36 |
crop_dedup_iou=0.4,
|
| 37 |
min_display_conf=conf_thresh,
|
| 38 |
+
classifier=classifier,
|
| 39 |
labels=labels,
|
| 40 |
)
|
| 41 |
|
|
|
|
| 74 |
label="D-FINE model",
|
| 75 |
)
|
| 76 |
|
| 77 |
+
classifier_dropdown = gr.Dropdown(
|
| 78 |
+
choices=["siglip-224", "siglip-256", "siglip-384"],
|
| 79 |
+
value="siglip-224",
|
| 80 |
+
label="Classifier model",
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
dfine_threshold_slider = gr.Slider(
|
| 84 |
minimum=0.05,
|
| 85 |
maximum=0.5,
|
|
|
|
| 144 |
|
| 145 |
btn_dfine.click(
|
| 146 |
fn=run_dfine_classify,
|
| 147 |
+
inputs=[inp_dfine, dfine_threshold_slider, dfine_model_radio, classifier_dropdown, siglip_threshold_slider, labels_input],
|
| 148 |
outputs=[out_gallery_dfine, out_gallery_known, out_status_dfine],
|
| 149 |
concurrency_limit=1,
|
| 150 |
)
|
dfine_jina_pipeline.py
CHANGED
|
@@ -516,7 +516,7 @@ DFINE_MODEL_IDS = {
|
|
| 516 |
"large-obj2coco": "ustc-community/dfine-large-obj2coco-e25",
|
| 517 |
}
|
| 518 |
|
| 519 |
-
CLASSIFIER_CHOICES = ["jina", "siglip", "siglip2_onnx"]
|
| 520 |
|
| 521 |
|
| 522 |
def _load_classifier(classifier_name, device, refs_dir=None, labels=None):
|
|
@@ -529,9 +529,11 @@ def _load_classifier(classifier_name, device, refs_dir=None, labels=None):
|
|
| 529 |
ref_labels, ref_embs = build_refs(jina_encoder, refs_dir, TRUNCATE_DIM, 0.3, batch_size=16)
|
| 530 |
return ("jina_wrapped", jina_encoder, ref_labels, ref_embs)
|
| 531 |
|
| 532 |
-
if classifier_name == "siglip":
|
| 533 |
-
from siglip_zeroshot import SigLIPClassifier
|
| 534 |
-
clf = SigLIPClassifier(device)
|
|
|
|
|
|
| 535 |
clf.build_refs(refs_dir=refs_dir, labels=labels)
|
| 536 |
return clf
|
| 537 |
|
|
@@ -566,7 +568,7 @@ def run_single_image(
|
|
| 566 |
crop_dedup_iou=0.35,
|
| 567 |
squarify=True,
|
| 568 |
min_display_conf=None,
|
| 569 |
-
classifier="siglip",
|
| 570 |
labels=None,
|
| 571 |
):
|
| 572 |
"""
|
|
|
|
| 516 |
"large-obj2coco": "ustc-community/dfine-large-obj2coco-e25",
|
| 517 |
}
|
| 518 |
|
| 519 |
+
CLASSIFIER_CHOICES = ["jina", "siglip-224", "siglip-256", "siglip-384", "siglip2_onnx"]
|
| 520 |
|
| 521 |
|
| 522 |
def _load_classifier(classifier_name, device, refs_dir=None, labels=None):
|
|
|
|
| 529 |
ref_labels, ref_embs = build_refs(jina_encoder, refs_dir, TRUNCATE_DIM, 0.3, batch_size=16)
|
| 530 |
return ("jina_wrapped", jina_encoder, ref_labels, ref_embs)
|
| 531 |
|
| 532 |
+
if classifier_name.startswith("siglip-"):
|
| 533 |
+
from siglip_zeroshot import SigLIPClassifier, SIGLIP_MODELS
|
| 534 |
+
if classifier_name not in SIGLIP_MODELS:
|
| 535 |
+
raise ValueError(f"Unknown SigLIP model: {classifier_name}. Choose from {list(SIGLIP_MODELS.keys())}")
|
| 536 |
+
clf = SigLIPClassifier(device, model_key=classifier_name)
|
| 537 |
clf.build_refs(refs_dir=refs_dir, labels=labels)
|
| 538 |
return clf
|
| 539 |
|
|
|
|
| 568 |
crop_dedup_iou=0.35,
|
| 569 |
squarify=True,
|
| 570 |
min_display_conf=None,
|
| 571 |
+
classifier="siglip-224",
|
| 572 |
labels=None,
|
| 573 |
):
|
| 574 |
"""
|
siglip_zeroshot.py
CHANGED
|
@@ -11,17 +11,25 @@ import numpy as np
|
|
| 11 |
import torch
|
| 12 |
from transformers import SiglipModel, AutoProcessor
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
class SigLIPClassifier:
|
| 15 |
"""Zero-shot crop classifier using SigLIP (PyTorch)."""
|
| 16 |
|
| 17 |
-
def __init__(self, device="cuda"):
|
| 18 |
-
print("[*] Loading SigLIP...")
|
|
|
|
| 19 |
t0 = time.perf_counter()
|
| 20 |
|
| 21 |
self.device = device
|
| 22 |
-
self.model = SiglipModel.from_pretrained("google/siglip-base-patch16-224")
|
|
|
|
| 23 |
self.model = self.model.to(device).eval()
|
| 24 |
-
self.processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
|
| 25 |
|
| 26 |
self.labels = []
|
| 27 |
|
|
|
|
| 11 |
import torch
|
| 12 |
from transformers import SiglipModel, AutoProcessor
|
| 13 |
|
| 14 |
+
SIGLIP_MODELS = {
|
| 15 |
+
"siglip-224": "google/siglip-base-patch16-224",
|
| 16 |
+
"siglip-256": "google/siglip-base-patch16-256",
|
| 17 |
+
"siglip-384": "google/siglip-base-patch16-384",
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
class SigLIPClassifier:
|
| 21 |
"""Zero-shot crop classifier using SigLIP (PyTorch)."""
|
| 22 |
|
| 23 |
+
def __init__(self, device="cuda", model_key="siglip-224"):
|
| 24 |
+
model_id = SIGLIP_MODELS.get(model_key, model_key)
|
| 25 |
+
print(f"[*] Loading SigLIP ({model_id})...")
|
| 26 |
t0 = time.perf_counter()
|
| 27 |
|
| 28 |
self.device = device
|
| 29 |
+
self.model_key = model_key
|
| 30 |
+
self.model = SiglipModel.from_pretrained(model_id)
|
| 31 |
self.model = self.model.to(device).eval()
|
| 32 |
+
self.processor = AutoProcessor.from_pretrained(model_id)
|
| 33 |
|
| 34 |
self.labels = []
|
| 35 |
|