Spaces:
Sleeping
Sleeping
| """ Gradio app: D-FINE + SigLIP Classify. """ | |
| import os | |
| import gradio as gr | |
| from pathlib import Path | |
| from dfine_jina_pipeline import run_single_image | |
# Absolute path of the directory that holds this script.
# NOTE(review): not referenced in the visible code; kept for external users.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]

# Starter labels pre-filled into the "Labels (comma-separated)" textbox.
DEFAULT_LABELS = ", ".join(("gun", "knife", "cigarette", "phone"))
def run_dfine_classify(image, dfine_threshold, dfine_model_choice, classifier_choice, siglip_threshold, labels_text):
    """Run D-FINE detection first, then classify the resulting crops with SigLIP.

    Args:
        image: The uploaded image, or ``None`` when nothing has been uploaded.
            It is presumably a PIL image, because the UI's ``gr.Image`` uses
            ``type="pil"``.
        dfine_threshold: D-FINE detection score threshold, converted with
            ``float()``.
        dfine_model_choice: D-FINE model name such as ``"medium-obj2coco"``.
            It is stripped and lower-cased. An empty, blank or ``None`` value
            falls back to ``"medium-obj2coco"``.
        classifier_choice: Classifier name such as ``"siglip-256"``. It is
            stripped. An empty, blank or ``None`` value falls back to
            ``"siglip-256"``.
        siglip_threshold: Minimum classifier confidence, converted with
            ``float()``. It is used both as the classification threshold and
            as the minimum confidence for displayed crops.
        labels_text: Comma-separated class labels. Blank entries are dropped,
            and ``None`` is treated as empty.

    Returns:
        tuple: ``(group_crop_gallery, known_crop_gallery, status_message)``.
        Each gallery is a list of ``(image, None)`` pairs, i.e. gallery items
        with no caption, built from the crops returned by
        ``run_single_image``. A ``None`` crop list or status from the pipeline
        becomes an empty list or ``""``. If no image is given or no labels
        remain, two empty galleries and a hint message are returned instead.
    """
    if image is None:
        return [], [], "Upload an image."

    # ``or ""`` guards against a None textbox value. Empty pieces are dropped,
    # so input like "gun, , knife," still works.
    labels = [label.strip() for label in (labels_text or "").split(",") if label.strip()]
    if not labels:
        return [], [], "Enter at least one label."

    # Apply ``or default`` after strip() so that whitespace-only choices also
    # fall back, instead of being forwarded as an empty model name.
    dfine_model = (dfine_model_choice or "").strip().lower() or "medium-obj2coco"
    conf_thresh = float(siglip_threshold)
    classifier = (classifier_choice or "").strip() or "siglip-256"

    # gap_threshold, min_side and crop_dedup_iou are fixed pipeline knobs.
    # Their exact meaning is defined by dfine_jina_pipeline.run_single_image.
    group_crops, known_crops, status = run_single_image(
        image,
        dfine_model=dfine_model,
        det_threshold=float(dfine_threshold),
        conf_threshold=conf_thresh,
        gap_threshold=0.0,
        min_side=24,
        crop_dedup_iou=0.4,
        min_display_conf=conf_thresh,
        classifier=classifier,
        labels=labels,
    )
    return (
        [(crop, None) for crop in (group_crops or [])],
        [(crop, None) for crop in (known_crops or [])],
        status or "",
    )
IMG_HEIGHT = 400

# D-FINE model that is pre-selected in the dropdown.
DEFAULT_DFINE_MODEL = "medium-obj2coco"
# Default D-FINE detection threshold, keyed by model size. The size is the part
# of the model name before the first "-" (for example "small-coco" -> "small").
DFINE_THRESHOLD_DEFAULTS = {"large": 0.2, "medium": 0.15, "small": 0.1}
# Used for an empty model choice or an unrecognised size prefix.
FALLBACK_DFINE_THRESHOLD = 0.15


def _default_dfine_threshold(choice):
    """Return the default D-FINE detection threshold for model name *choice*."""
    if not choice:
        return FALLBACK_DFINE_THRESHOLD
    size = choice.strip().lower().split("-")[0]
    return DFINE_THRESHOLD_DEFAULTS.get(size, FALLBACK_DFINE_THRESHOLD)


def update_dfine_threshold_default(choice):
    """Handle D-FINE model changes by resetting the threshold slider to that model's default."""
    return gr.update(value=_default_dfine_threshold(choice))


# Inside gr.Blocks, components render in the order they are created, so the
# statement order below defines the page layout.
with gr.Blocks(title="Small Object Detection") as app:
    gr.Markdown("# Small Object Detection")
    gr.Markdown(
        "**D-FINE** detects persons/cars, then small-object crops are classified with **SigLIP** (zero-shot). "
        "Choose a D-FINE model and enter comma-separated class labels for SigLIP."
    )
    with gr.Row():
        # Left column: input image and controls.
        with gr.Column(scale=1):
            inp_dfine = gr.Image(
                type="pil",
                label="Input image",
                height=IMG_HEIGHT,
            )
            # NOTE: the variable says "radio", but the component is a Dropdown.
            dfine_model_radio = gr.Dropdown(
                choices=[
                    "small-obj365", "medium-obj365", "large-obj365",
                    "small-coco", "medium-coco", "large-coco",
                    "small-obj2coco", "medium-obj2coco", "large-obj2coco",
                ],
                value=DEFAULT_DFINE_MODEL,
                label="D-FINE model",
            )
            classifier_dropdown = gr.Dropdown(
                choices=["siglip-224", "siglip-256", "siglip-384"],
                value="siglip-256",
                label="Classifier model",
            )
            dfine_threshold_slider = gr.Slider(
                minimum=0.05,
                maximum=0.5,
                # The change handler reads the same table, so the initial value
                # always matches the pre-selected model.
                value=_default_dfine_threshold(DEFAULT_DFINE_MODEL),
                step=0.05,
                label="D-FINE detection threshold",
            )
            # Selecting a different model resets the threshold slider.
            dfine_model_radio.change(
                fn=update_dfine_threshold_default,
                inputs=[dfine_model_radio],
                outputs=[dfine_threshold_slider],
            )
            siglip_threshold_slider = gr.Slider(
                minimum=0.001,
                maximum=0.1,
                value=0.005,
                step=0.001,
                label="SigLIP: min confidence threshold",
            )
            labels_input = gr.Textbox(
                label="Labels (comma-separated)",
                value=DEFAULT_LABELS,
                placeholder="e.g. gun, knife, cigarette, phone",
            )
            btn_dfine = gr.Button(
                "Run D-FINE + Classify",
                variant="primary",
            )
        # Right column: results.
        with gr.Column(scale=1):
            out_gallery_dfine = gr.Gallery(
                label="Person/car crops (all D-FINE objects inside drawn with label + score)",
                height=IMG_HEIGHT,
                columns=2,
                object_fit="contain",
            )
            out_gallery_known = gr.Gallery(
                label="Known objects (class + score above each crop)",
                height=IMG_HEIGHT,
                columns=4,
                object_fit="contain",
            )
            out_status_dfine = gr.Textbox(
                label="Classification details",
                lines=8,
                interactive=False,
            )
    # The input order must match run_dfine_classify's positional parameters.
    btn_dfine.click(
        fn=run_dfine_classify,
        inputs=[
            inp_dfine,
            dfine_threshold_slider,
            dfine_model_radio,
            classifier_dropdown,
            siglip_threshold_slider,
            labels_input,
        ],
        outputs=[out_gallery_dfine, out_gallery_known, out_status_dfine],
        # Only one run of this event executes at a time. Presumably this bounds
        # model memory and compute; confirm.
        concurrency_limit=1,
    )
# Bind address: GRADIO_SERVER_NAME if set, otherwise all IPv4 interfaces.
_launch_host = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
# Port precedence is PORT, then GRADIO_SERVER_PORT, then 7860. Environment
# values are strings, so int() converts them, and a non-numeric value raises
# ValueError.
_launch_port = int(os.environ.get("PORT", os.environ.get("GRADIO_SERVER_PORT", 7860)))

app.launch(server_name=_launch_host, server_port=_launch_port)