orik-ss's picture
Made siglip-256 as default classifier
d5c03fb
""" Gradio app: D-FINE + SigLIP Classify. """
import os
import gradio as gr
from pathlib import Path
from dfine_jina_pipeline import run_single_image
# Absolute path of the directory that contains this file.
# NOTE(review): not referenced again in the visible part of this file.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Comma-separated labels used as the initial value of the labels textbox.
DEFAULT_LABELS = "gun, knife, cigarette, phone"
def run_dfine_classify(
    image,
    dfine_threshold,
    dfine_model_choice,
    classifier_choice,
    siglip_threshold,
    labels_text,
):
    """Run D-FINE first, then classify crops with SigLIP.

    Validates the UI inputs, normalises the model choices and delegates the
    actual work to ``run_single_image``.

    Args:
        image: Input image from the UI; ``None`` when nothing was uploaded.
        dfine_threshold: Detection threshold, forwarded as ``det_threshold``.
        dfine_model_choice: D-FINE model name. It is stripped and lower-cased;
            a missing or blank value falls back to ``"medium-obj2coco"``.
        classifier_choice: Classifier name. It is stripped; a missing or blank
            value falls back to ``"siglip-256"``.
        siglip_threshold: Minimum classifier confidence, forwarded as both
            ``conf_threshold`` and ``min_display_conf``.
        labels_text: Comma-separated class labels. ``None`` is treated like an
            empty string.

    Returns:
        ``(group_crop_gallery, known_crop_gallery, status_message)``. Each
        gallery is a list of ``(crop, None)`` tuples; the status is a string
        (empty when the pipeline returns no status).
    """
    if image is None:
        return [], [], "Upload an image."

    # The textbox value may arrive as None (e.g. a cleared field or an API
    # call); treat that as "no labels" instead of crashing on .split().
    labels = [part.strip() for part in (labels_text or "").split(",") if part.strip()]
    if not labels:
        return [], [], "Enter at least one label."

    # A blank (whitespace-only) choice falls back to the default as well,
    # rather than forwarding an empty model name to the pipeline.
    dfine_model = (dfine_model_choice or "").strip().lower() or "medium-obj2coco"
    classifier = (classifier_choice or "").strip() or "siglip-256"
    conf_thresh = float(siglip_threshold)

    group_crops, known_crops, status = run_single_image(
        image,
        dfine_model=dfine_model,
        det_threshold=float(dfine_threshold),
        conf_threshold=conf_thresh,
        # Fixed pipeline settings that are not exposed in the UI.
        gap_threshold=0.0,
        min_side=24,
        crop_dedup_iou=0.4,
        # The same threshold is used for classification and for display.
        min_display_conf=conf_thresh,
        classifier=classifier,
        labels=labels,
    )

    # Galleries take (image, caption) pairs; no caption is attached here.
    group_gallery = [(crop, None) for crop in (group_crops or [])]
    known_gallery = [(crop, None) for crop in (known_crops or [])]
    return group_gallery, known_gallery, status or ""
# Height passed to the input image and to both output galleries.
IMG_HEIGHT = 400
# UI definition: inputs and parameters first, then the result widgets, then
# the event wiring.
# NOTE(review): the Row/Column nesting is inferred from statement order -
# confirm it matches the intended layout.
with gr.Blocks(title="Small Object Detection") as app:
    gr.Markdown("# Small Object Detection")
    gr.Markdown(
        "**D-FINE** detects persons/cars, then small-object crops are classified with **SigLIP** (zero-shot). "
        "Choose a D-FINE model and enter comma-separated class labels for SigLIP."
    )
    with gr.Row():
        # Input image and every run parameter.
        with gr.Column(scale=1):
            inp_dfine = gr.Image(
                type="pil",
                label="Input image",
                height=IMG_HEIGHT
            )
            # Despite the "_radio" suffix this is a Dropdown.
            dfine_model_radio = gr.Dropdown(
                choices=[
                    "small-obj365", "medium-obj365", "large-obj365",
                    "small-coco", "medium-coco", "large-coco",
                    "small-obj2coco", "medium-obj2coco", "large-obj2coco",
                ],
                value="medium-obj2coco",
                label="D-FINE model",
            )
            classifier_dropdown = gr.Dropdown(
                choices=["siglip-224", "siglip-256", "siglip-384"],
                value="siglip-256",
                label="Classifier model",
            )
            dfine_threshold_slider = gr.Slider(
                minimum=0.05,
                maximum=0.5,
                value=0.15,
                step=0.05,
                label="D-FINE detection threshold",
            )
            def update_dfine_threshold_default(choice):
                """Return a slider update with the default threshold for the chosen model.

                The model size is the part of the name before the first "-"
                ("small", "medium" or "large"). An empty choice or an
                unrecognised size falls back to 0.15.
                """
                if not choice:
                    return gr.update(value=0.15)
                size = choice.strip().lower().split("-")[0]
                defaults = {"large": 0.2, "medium": 0.15, "small": 0.1}
                return gr.update(value=defaults.get(size, 0.15))
            # Changing the model resets the detection threshold slider to that
            # model size's default.
            dfine_model_radio.change(
                fn=update_dfine_threshold_default,
                inputs=[dfine_model_radio],
                outputs=[dfine_threshold_slider],
            )
            siglip_threshold_slider = gr.Slider(
                minimum=0.001,
                maximum=0.1,
                value=0.005,
                step=0.001,
                label="SigLIP: min confidence threshold",
            )
            labels_input = gr.Textbox(
                label="Labels (comma-separated)",
                value=DEFAULT_LABELS,
                placeholder="e.g. gun, knife, cigarette, phone",
            )
            btn_dfine = gr.Button(
                "Run D-FINE + Classify",
                variant="primary"
            )
        # Result widgets: two galleries and a read-only status box.
        with gr.Column(scale=1):
            out_gallery_dfine = gr.Gallery(
                label="Person/car crops (all D-FINE objects inside drawn with label + score)",
                height=IMG_HEIGHT,
                columns=2,
                object_fit="contain",
            )
            out_gallery_known = gr.Gallery(
                label="Known objects (class + score above each crop)",
                height=IMG_HEIGHT,
                columns=4,
                object_fit="contain",
            )
            out_status_dfine = gr.Textbox(
                label="Classification details",
                lines=8,
                interactive=False,
            )
    # The inputs list follows run_dfine_classify's parameter order and the
    # outputs list follows its three return values.
    # concurrency_limit=1 is passed so that, per the Gradio event API, only
    # one run of this handler executes at a time.
    btn_dfine.click(
        fn=run_dfine_classify,
        inputs=[inp_dfine, dfine_threshold_slider, dfine_model_radio, classifier_dropdown, siglip_threshold_slider, labels_input],
        outputs=[out_gallery_dfine, out_gallery_known, out_status_dfine],
        concurrency_limit=1,
    )
# Start the server. The host comes from GRADIO_SERVER_NAME (default: all
# interfaces). The port comes from PORT, then GRADIO_SERVER_PORT, then 7860.
_server_host = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
_fallback_port = os.environ.get("GRADIO_SERVER_PORT", 7860)
_server_port = int(os.environ.get("PORT", _fallback_port))
app.launch(server_name=_server_host, server_port=_server_port)