orik-ss committed
Commit 33708c6 · Parent: 80cacd4

Added siglip multiple res models

Files changed (3)
  1. app.py +10 -3
  2. dfine_jina_pipeline.py +7 -5
  3. siglip_zeroshot.py +12 -4
app.py CHANGED

@@ -11,7 +11,7 @@ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 DEFAULT_LABELS = "gun, knife, cigarette, phone"
 
 
-def run_dfine_classify(image, dfine_threshold, dfine_model_choice, siglip_threshold, labels_text):
+def run_dfine_classify(image, dfine_threshold, dfine_model_choice, classifier_choice, siglip_threshold, labels_text):
     """D-FINE first, then classify crops with SigLIP.
     Returns (group_crop_gallery, known_crop_gallery, status_message).
     """
@@ -24,6 +24,7 @@ def run_dfine_classify(image, dfine_threshold, dfine_model_choice, siglip_thresh
 
     dfine_model = dfine_model_choice.strip().lower() if dfine_model_choice else "medium-obj2coco"
     conf_thresh = float(siglip_threshold)
+    classifier = classifier_choice.strip() if classifier_choice else "siglip-224"
 
     group_crops, known_crops, status = run_single_image(
         image,
@@ -34,7 +35,7 @@ def run_dfine_classify(image, dfine_threshold, dfine_model_choice, siglip_thresh
         min_side=24,
         crop_dedup_iou=0.4,
         min_display_conf=conf_thresh,
-        classifier="siglip",
+        classifier=classifier,
         labels=labels,
     )
 
@@ -73,6 +74,12 @@ with gr.Blocks(title="Small Object Detection") as app:
         label="D-FINE model",
     )
 
+    classifier_dropdown = gr.Dropdown(
+        choices=["siglip-224", "siglip-256", "siglip-384"],
+        value="siglip-224",
+        label="Classifier model",
+    )
+
     dfine_threshold_slider = gr.Slider(
         minimum=0.05,
         maximum=0.5,
@@ -137,7 +144,7 @@ with gr.Blocks(title="Small Object Detection") as app:
 
     btn_dfine.click(
         fn=run_dfine_classify,
-        inputs=[inp_dfine, dfine_threshold_slider, dfine_model_radio, siglip_threshold_slider, labels_input],
+        inputs=[inp_dfine, dfine_threshold_slider, dfine_model_radio, classifier_dropdown, siglip_threshold_slider, labels_input],
         outputs=[out_gallery_dfine, out_gallery_known, out_status_dfine],
         concurrency_limit=1,
     )
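For context, a minimal sketch of how the new dropdown value reaches the pipeline: the click handler now receives `classifier_choice` as a plain string and forwards it to `run_single_image`. The call below is illustrative only; the image path and threshold values are assumptions, not part of the commit.

```python
# Illustrative only: argument values below are assumptions, not from the commit.
from PIL import Image

from app import run_dfine_classify  # importing app builds (but does not launch) the Gradio UI

image = Image.open("example.jpg")  # hypothetical test image

group_crops, known_crops, status = run_dfine_classify(
    image,
    0.3,                              # dfine_threshold (UI slider range is 0.05-0.5)
    "medium-obj2coco",                # dfine_model_choice
    "siglip-384",                     # classifier_choice -- new argument wired to the dropdown
    0.5,                              # siglip_threshold (assumed value)
    "gun, knife, cigarette, phone",   # labels_text, same format as DEFAULT_LABELS
)
print(status)
```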
dfine_jina_pipeline.py CHANGED

@@ -516,7 +516,7 @@ DFINE_MODEL_IDS = {
     "large-obj2coco": "ustc-community/dfine-large-obj2coco-e25",
 }
 
-CLASSIFIER_CHOICES = ["jina", "siglip", "siglip2_onnx"]
+CLASSIFIER_CHOICES = ["jina", "siglip-224", "siglip-256", "siglip-384", "siglip2_onnx"]
 
 
 def _load_classifier(classifier_name, device, refs_dir=None, labels=None):
@@ -529,9 +529,11 @@ def _load_classifier(classifier_name, device, refs_dir=None, labels=None):
         ref_labels, ref_embs = build_refs(jina_encoder, refs_dir, TRUNCATE_DIM, 0.3, batch_size=16)
         return ("jina_wrapped", jina_encoder, ref_labels, ref_embs)
 
-    if classifier_name == "siglip":
-        from siglip_zeroshot import SigLIPClassifier
-        clf = SigLIPClassifier(device)
+    if classifier_name.startswith("siglip-"):
+        from siglip_zeroshot import SigLIPClassifier, SIGLIP_MODELS
+        if classifier_name not in SIGLIP_MODELS:
+            raise ValueError(f"Unknown SigLIP model: {classifier_name}. Choose from {list(SIGLIP_MODELS.keys())}")
+        clf = SigLIPClassifier(device, model_key=classifier_name)
         clf.build_refs(refs_dir=refs_dir, labels=labels)
         return clf
 
@@ -566,7 +568,7 @@ def run_single_image(
     crop_dedup_iou=0.35,
     squarify=True,
     min_display_conf=None,
-    classifier="siglip",
+    classifier="siglip-224",
     labels=None,
 ):
     """
siglip_zeroshot.py CHANGED

@@ -11,17 +11,25 @@ import numpy as np
 import torch
 from transformers import SiglipModel, AutoProcessor
 
+SIGLIP_MODELS = {
+    "siglip-224": "google/siglip-base-patch16-224",
+    "siglip-256": "google/siglip-base-patch16-256",
+    "siglip-384": "google/siglip-base-patch16-384",
+}
+
 class SigLIPClassifier:
     """Zero-shot crop classifier using SigLIP (PyTorch)."""
 
-    def __init__(self, device="cuda"):
-        print("[*] Loading SigLIP (google/siglip-base-patch16-224)...")
+    def __init__(self, device="cuda", model_key="siglip-224"):
+        model_id = SIGLIP_MODELS.get(model_key, model_key)
+        print(f"[*] Loading SigLIP ({model_id})...")
         t0 = time.perf_counter()
 
         self.device = device
-        self.model = SiglipModel.from_pretrained("google/siglip-base-patch16-224")
+        self.model_key = model_key
+        self.model = SiglipModel.from_pretrained(model_id)
         self.model = self.model.to(device).eval()
-        self.processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
+        self.processor = AutoProcessor.from_pretrained(model_id)
 
         self.labels = []
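The classifier can also be exercised on its own. Because `__init__` resolves the key with `SIGLIP_MODELS.get(model_key, model_key)`, an unknown key falls through as a raw Hugging Face checkpoint id. A minimal sketch, assuming a CPU-only environment and the default label set:

```python
# Minimal sketch, assuming CPU and the default labels; not part of the commit.
from siglip_zeroshot import SIGLIP_MODELS, SigLIPClassifier

print(list(SIGLIP_MODELS))
# ['siglip-224', 'siglip-256', 'siglip-384']

# Short key resolved through SIGLIP_MODELS -> google/siglip-base-patch16-384
clf = SigLIPClassifier(device="cpu", model_key="siglip-384")

# Equivalent: an unrecognised key is passed through unchanged by .get(key, key),
# so a raw model id also works:
# clf = SigLIPClassifier(device="cpu", model_key="google/siglip-base-patch16-384")

clf.build_refs(refs_dir=None, labels=["gun", "knife", "cigarette", "phone"])
```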