Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import PIL.Image | |
| import torch | |
| from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor | |
# Run every model on the GPU when CUDA is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
class Detector:
    """Wrap one zero-shot object detection checkpoint from the Hugging Face Hub.

    The processor and model are loaded once at construction time, and the
    model is moved to the module-level ``DEVICE``.
    """

    def __init__(self, model_id: str):
        """Load the processor and model for *model_id* and place the model on ``DEVICE``."""
        self.device = DEVICE
        self.processor = AutoProcessor.from_pretrained(model_id)
        loaded = AutoModelForZeroShotObjectDetection.from_pretrained(model_id)
        self.model = loaded.to(self.device)

    def detect(
        self,
        image: PIL.Image.Image,
        text_labels: list[str],
        threshold: float = 0.4,
    ):
        """Find the objects described by *text_labels* in *image*.

        Args:
            image: Input picture. Its ``height`` and ``width`` are passed as
                the target size for post-processing.
            text_labels: Free-text object descriptions. They are sent to the
                processor as a single batch entry.
            threshold: Forwarded as ``threshold`` to
                ``post_process_grounded_object_detection``.

        Returns:
            A list with one dict per detection reported for the first (and
            only) batch entry. Each dict has the keys ``label``,
            ``confidence`` (rounded to 3 decimals) and ``box`` (a list of
            coordinates, each rounded to 2 decimals).
        """
        encoded = self.processor(
            images=image, text=[text_labels], return_tensors="pt"
        ).to(self.device)

        # Pure inference: skip autograd bookkeeping.
        with torch.no_grad():
            raw_outputs = self.model(**encoded)

        first = self.processor.post_process_grounded_object_detection(
            raw_outputs,
            threshold=threshold,
            target_sizes=[(image.height, image.width)],
        )[0]

        return [
            {
                "label": text_label,
                "confidence": round(score.item(), 3),
                "box": [round(coord, 2) for coord in box.tolist()],
            }
            for box, score, text_label in zip(
                first["boxes"], first["scores"], first["text_labels"]
            )
        ]
# One detector per LLMDet checkpoint size. All three are loaded when this
# module is imported. The insertion order (tiny, base, large) is the order in
# which detect_objects produces its annotated outputs.
models = {
    "tiny": Detector("iSEE-Laboratory/llmdet_tiny"),
    "base": Detector("iSEE-Laboratory/llmdet_base"),
    "large": Detector("iSEE-Laboratory/llmdet_large"),
}
| def _postprocess(detections): | |
| annotations = [] | |
| for detection in detections: | |
| box = detection["box"] | |
| mask = (int(box[0]), int(box[1]), int(box[2]), int(box[3])) | |
| label = f"{detection['label']} ({detection['confidence']:.2f})" | |
| annotations.append((mask, label)) | |
| return annotations | |
def detect_objects(image, labels, confidence_threshold):
    """Run every loaded LLMDet model on *image* and collect annotations.

    Args:
        image: Image from the ``gr.Image(type="pil")`` input, or ``None``
            when nothing has been uploaded.
        labels: Comma-separated object descriptions typed by the user.
            Surrounding whitespace and empty entries are ignored. ``None``
            is treated as an empty string.
        confidence_threshold: Score threshold forwarded to each
            ``Detector.detect`` call.

    Returns:
        A tuple with one ``(image, annotations)`` pair per entry of
        ``models``, in the dict's insertion order. This matches the order of
        the output components wired to this function.

    Raises:
        gr.Error: If no image was provided, or if no non-empty label remains
            after splitting on commas. The UI then shows a message instead of
            a traceback, and no empty prompt reaches the models.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # Strip each entry and drop blanks, so "" or "a cat," never becomes an
    # empty-string prompt.
    text_labels = [label.strip() for label in (labels or "").split(",")]
    text_labels = [label for label in text_labels if label]
    if not text_labels:
        raise gr.Error("Please enter at least one object label (comma separated).")

    return tuple(
        (
            image,
            _postprocess(
                detector.detect(image, text_labels, threshold=confidence_threshold)
            ),
        )
        for detector in models.values()
    )
# Gradio UI, from top to bottom:
#   - input image and settings,
#   - a detect button,
#   - one annotated output per model size,
#   - example inputs.
# Components are laid out in the order they are created inside each ``with``
# context.
# NOTE(review): delete_cache=(5, 10) is presumably Gradio's (frequency, age)
# in seconds for purging this app's temporary files — confirm against the
# pinned gradio version.
with gr.Blocks(delete_cache=(5, 10)) as demo:
    gr.Markdown(
        "# LLMDet Arena ✨\n ### [Paper](https://arxiv.org/abs/2501.18954) - [Repository](https://github.com/iSEE-Laboratory/LLMDet)"
    )
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Input Image")
            # type="pil" hands detect_objects a PIL image, which is the type
            # Detector.detect is annotated to accept.
            image_input = gr.Image(type="pil", image_mode="RGB", format="jpeg")
        with gr.Column():
            gr.Markdown("## Settings")
            # The slider value is passed to detect_objects as
            # confidence_threshold.
            confidence_slider = gr.Slider(
                0,
                1,
                value=0.3,
                step=0.01,
                interactive=True,
                label="Confidence threshold:",
            )
            # Used only to build the placeholder hint. No value= is passed to
            # the textbox.
            labels = ["a cat", "a remote control"]
            text_input = gr.Textbox(
                label="Object labels (comma separated):",
                placeholder=",".join(labels),
                lines=1,
            )
    with gr.Row():
        detect_button = gr.Button("Detect Objects")
    with gr.Row():
        gr.Markdown("## Output Annotated Images")
    with gr.Row():
        output_annotated_image_tiny = gr.AnnotatedImage(label="TINY", format="jpeg")
        output_annotated_image_base = gr.AnnotatedImage(label="BASE", format="jpeg")
        output_annotated_image_large = gr.AnnotatedImage(label="LARGE", format="jpeg")
    # Connect the button to the detection function
    # detect_objects returns one (image, annotations) pair per entry of
    # ``models``, in insertion order (tiny, base, large). This outputs list
    # must stay in that same order.
    detect_button.click(
        fn=detect_objects,
        inputs=[image_input, text_input, confidence_slider],
        outputs=[
            output_annotated_image_tiny,
            output_annotated_image_base,
            output_annotated_image_large,
        ],
    )
    with gr.Row():
        gr.Markdown("## Examples")
    with gr.Row():
        # Each example row is [image URL, comma-separated labels, threshold],
        # matching ``inputs``.
        # NOTE(review): with cache_examples=True, Gradio presumably runs
        # detect_objects on every example ahead of time (all three models per
        # example). Confirm the startup cost is acceptable.
        gr.Examples(
            examples=[
                [
                    "http://images.cocodataset.org/val2017/000000039769.jpg",
                    "a cat, a remote control",
                    0.3,
                ],
                [
                    "http://images.cocodataset.org/val2017/000000370486.jpg",
                    "a person",
                    0.3,
                ],
                [
                    "http://images.cocodataset.org/train2017/000000345263.jpg",
                    "a red apple, a green apple",
                    0.3,
                ],
            ],
            inputs=[image_input, text_input, confidence_slider],
            outputs=[
                output_annotated_image_tiny,
                output_annotated_image_base,
                output_annotated_image_large,
            ],
            fn=detect_objects,
            cache_examples=True,
        )


# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()