"""LLMDet Arena: compare the tiny, base and large LLMDet checkpoints side by side.

A Gradio app that runs zero-shot object detection with each of the three LLMDet
model sizes on the same image and the same comma-separated text prompts. It
shows the three annotated results next to each other.
"""

import gradio as gr
import PIL.Image
import torch
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
from transformers.image_utils import load_image

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class Detector:
    """One zero-shot object detection checkpoint together with its processor."""

    def __init__(self, model_id: str):
        """Load the processor and model for ``model_id`` and move the model to ``DEVICE``.

        Args:
            model_id: Hugging Face Hub id (or local path) passed to ``from_pretrained``.
        """
        self.device = DEVICE
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(
            self.device
        )

    def detect(
        self,
        image: PIL.Image.Image,
        text_labels: list[str],
        threshold: float = 0.4,
    ) -> list[dict]:
        """Detect objects described by ``text_labels`` in ``image``.

        Args:
            image: Image to run detection on.
            text_labels: Text prompts, one per object category to look for.
            threshold: Minimum score a box must reach to be kept.

        Returns:
            One dict per kept box, with these keys:
            ``label`` is the matched prompt text.
            ``confidence`` is the score rounded to 3 decimals.
            ``box`` holds the four box coordinates returned by the processor's
            post-processing, scaled to the image size via ``target_sizes`` and
            rounded to 2 decimals.
        """
        # The processor expects a batch: one list of prompts per image.
        inputs = self.processor(
            images=image, text=[text_labels], return_tensors="pt"
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)

        # target_sizes is (height, width) so boxes come back in pixel units.
        results = self.processor.post_process_grounded_object_detection(
            outputs, threshold=threshold, target_sizes=[(image.height, image.width)]
        )
        result = results[0]  # single image in the batch

        # Newer transformers releases expose the matched prompt text as
        # "text_labels" and deprecate reading it via "labels" (FutureWarning).
        # Fall back to "labels" for older releases that only provide that key.
        label_key = "text_labels" if "text_labels" in result else "labels"

        return [
            {
                "label": label,
                "confidence": round(score.item(), 3),
                "box": [round(coord, 2) for coord in box.tolist()],
            }
            for box, score, label in zip(
                result["boxes"], result["scores"], result[label_key]
            )
        ]


models = dict(
    tiny=Detector("iSEE-Laboratory/llmdet_tiny"),
    base=Detector("iSEE-Laboratory/llmdet_base"),
    large=Detector("iSEE-Laboratory/llmdet_large"),
)


def _postprocess(detections):
    """Convert ``Detector.detect`` output into ``gr.AnnotatedImage`` annotations.

    Args:
        detections: List of dicts with ``label``, ``confidence`` and ``box`` keys.

    Returns:
        A list of ``(bbox, caption)`` pairs. ``bbox`` is the first four box
        coordinates truncated to ints. ``caption`` is ``"<label> (<confidence>)"``,
        with the confidence shown to 2 decimals.
    """
    annotations = []
    for detection in detections:
        box = detection["box"]
        bbox = (int(box[0]), int(box[1]), int(box[2]), int(box[3]))
        caption = f"{detection['label']} ({detection['confidence']:.2f})"
        annotations.append((bbox, caption))
    return annotations


def detect_objects(image, labels, confidence_threshold):
    """Run every model in ``models`` on ``image`` and return their annotated outputs.

    Args:
        image: Input image from the ``gr.Image`` component; ``None`` when cleared.
        labels: Comma-separated text prompts.
        confidence_threshold: Score threshold forwarded to each detector.

    Returns:
        One ``(image, annotations)`` tuple per model, in ``models`` order
        (tiny, base, large), matching the three output components.

    Raises:
        gr.Error: If no image is provided, or if the text contains no non-empty label.
    """
    if image is None:
        raise gr.Error("Please provide an input image.")

    # Drop empty entries such as those produced by "a,,b", a trailing comma,
    # or an empty textbox; they would otherwise be sent to the processor as prompts.
    text_labels = [label.strip() for label in labels.split(",") if label.strip()]
    if not text_labels:
        raise gr.Error("Please enter at least one object label.")

    return tuple(
        (
            image,
            _postprocess(
                detector.detect(image, text_labels, threshold=confidence_threshold)
            ),
        )
        for detector in models.values()
    )


with gr.Blocks() as demo:
    gr.Markdown("# [LLMDet](https://arxiv.org/abs/2501.18954) Arena ✨")

    with gr.Row():
        with gr.Column():
            gr.Markdown("## Input Image")
            image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
            image = load_image(image_url)
            image_input = gr.Image(type="pil", image_mode="RGB", value=image)

        with gr.Column():
            gr.Markdown("## Settings")
            confidence_slider = gr.Slider(
                0,
                1,
                value=0.4,
                step=0.01,
                interactive=True,
                label="Confidence threshold:",
            )
            labels = ["a cat", "a remote control"]
            text_input = gr.Textbox(
                label="Object labels (comma separated):",
                placeholder=",".join(labels),
                lines=1,
                value=",".join(labels),
            )

    with gr.Row():
        detect_button = gr.Button("Run Object Detection")

    with gr.Row():
        gr.Markdown("## Output Annotated Images")

    with gr.Row():
        output_annotated_image_tiny = gr.AnnotatedImage(label="TINY")
        output_annotated_image_base = gr.AnnotatedImage(label="BASE")
        output_annotated_image_large = gr.AnnotatedImage(label="LARGE")

    # Connect the button to the detection function; output order must match
    # the iteration order of `models` (tiny, base, large).
    detect_button.click(
        fn=detect_objects,
        inputs=[image_input, text_input, confidence_slider],
        outputs=[
            output_annotated_image_tiny,
            output_annotated_image_base,
            output_annotated_image_large,
        ],
    )

if __name__ == "__main__":
    demo.launch()