| import gradio as gr | |
| import supervision as sv | |
| from PIL import Image | |
| from rfdetr import RFDETRMedium, RFDETRSegPreview | |
| from rfdetr.detr import RFDETR | |
| from rfdetr.util.coco_classes import COCO_CLASSES | |
# Header text rendered at the top of the demo page via gr.Markdown.
MARKDOWN = (
    "\n"
    "# (RF-DETR Demo-Alliance Bioversity-CIAT)\n"
    "RF-DETR object detection and segmentation\n"
)
# Hex codes of the annotation palette, in the order they are handed to
# supervision.
_PALETTE_HEX = (
    "#ffff00",
    "#ff9b00",
    "#ff8080",
    "#ff66b2",
    "#ff66ff",
    "#b266ff",
    "#9999ff",
    "#3399ff",
    "#66ffff",
    "#33ff99",
    "#66ff66",
    "#99ff00",
)

# Shared color palette used by every annotator in this file.
COLOR = sv.ColorPalette.from_hex(list(_PALETTE_HEX))
def load_model(resolution: int, checkpoint: str) -> RFDETR:
    """Build the RF-DETR model that matches a checkpoint label.

    Args:
        resolution: Value forwarded to the model constructor as
            ``resolution``.
        checkpoint: Either ``"medium (object detection)"`` or
            ``"segmentation preview"``.

    Returns:
        A newly constructed ``RFDETRMedium`` or ``RFDETRSegPreview``.

    Raises:
        TypeError: If ``checkpoint`` is not one of the two labels above.
    """
    model_classes = {
        "medium (object detection)": RFDETRMedium,
        "segmentation preview": RFDETRSegPreview,
    }
    model_class = model_classes.get(checkpoint)
    if model_class is None:
        raise TypeError("Checkpoint must be medium (object detection) or segmentation preview.")
    return model_class(resolution=resolution)
def adjust_resolution(checkpoint: str, resolution: int) -> int:
    """Snap a requested resolution to the nearest valid multiple.

    The multiple depends on the checkpoint: 24 for
    ``"segmentation preview"`` and 32 for ``"medium (object detection)"``.
    When the request lies exactly halfway between two multiples, the
    larger one is chosen.

    Args:
        checkpoint: Checkpoint label selecting the divisor.
        resolution: Requested resolution.

    Returns:
        The multiple of the divisor closest to ``resolution``, never
        smaller than the divisor itself.

    Raises:
        ValueError: If ``checkpoint`` is not a known label.
    """
    if checkpoint == "segmentation preview":
        divisor = 24
    elif checkpoint == "medium (object detection)":
        divisor = 32
    else:
        raise ValueError(f"Unknown checkpoint: {checkpoint}")

    lower = resolution - resolution % divisor
    upper = lower + divisor
    # Strict "<" keeps the tie-breaking rule: exact midpoints round up.
    nearest = lower if resolution - lower < upper - resolution else upper

    # Requests below half a divisor (or non-positive ones) would otherwise
    # snap to 0, which is not a usable resolution; fall back to the smallest
    # positive multiple. int() keeps the annotated return type even if the
    # caller hands in a float.
    return max(divisor, int(nearest))
def image_processing_inference(
    input_image: Image.Image,
    confidence: float,
    resolution: int,
    checkpoint: str,
):
    """Run RF-DETR on an image and return an annotated copy.

    Args:
        input_image: Image to analyse; ``None`` when nothing was uploaded.
        confidence: Threshold passed to ``model.predict``.
        resolution: Requested inference resolution; it is snapped to a
            valid multiple by ``adjust_resolution`` before the model is
            built.
        checkpoint: Checkpoint label understood by ``adjust_resolution``
            and ``load_model``.

    Returns:
        A copy of ``input_image`` with boxes and labels drawn (plus masks
        for the ``"segmentation preview"`` checkpoint), or ``None`` when
        ``input_image`` is ``None``.

    Raises:
        ValueError: Propagated from ``adjust_resolution`` for an unknown
            checkpoint.
    """
    if input_image is None:
        return None

    resolution = adjust_resolution(checkpoint=checkpoint, resolution=resolution)
    # NOTE(review): a model is constructed on every request; consider caching
    # if latency or memory churn becomes a problem.
    model = load_model(resolution=resolution, checkpoint=checkpoint)
    detections = model.predict(input_image, threshold=confidence)

    resolution_wh = (input_image.width, input_image.height)
    # The optimal scale is derived from the image size, so subtracting 0.2
    # can reach zero or go negative for small images; clamp to a floor so
    # labels always render with a positive font scale.
    min_text_scale = 0.2
    text_scale = max(
        sv.calculate_optimal_text_scale(resolution_wh=resolution_wh) - 0.2,
        min_text_scale,
    )
    thickness = sv.calculate_optimal_line_thickness(resolution_wh=resolution_wh)

    bbox_annotator = sv.BoxAnnotator(color=COLOR, thickness=thickness)
    label_annotator = sv.LabelAnnotator(
        color=COLOR,
        text_color=sv.Color.BLACK,
        text_scale=text_scale,
    )

    # "score" avoids shadowing the ``confidence`` threshold parameter.
    labels = [
        f"{COCO_CLASSES[class_id]} {score:.2f}"
        for class_id, score in zip(detections.class_id, detections.confidence)
    ]

    annotated_image = input_image.copy()
    # Masks go first so the overlay does not tint the boxes and label text
    # drawn afterwards.
    if checkpoint == "segmentation preview":
        mask_annotator = sv.MaskAnnotator(color=COLOR)
        annotated_image = mask_annotator.annotate(annotated_image, detections)
    annotated_image = bbox_annotator.annotate(annotated_image, detections)
    annotated_image = label_annotator.annotate(annotated_image, detections, labels)
    return annotated_image
# Settings shared by the input and output image panes.
_IMAGE_KWARGS = {"image_mode": "RGB", "type": "pil", "height": 600}

# Checkpoint labels offered in the dropdown; the first one is the default.
_CHECKPOINT_CHOICES = [
    "medium (object detection)",
    "segmentation preview",
]

with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)

    # Top row: uploaded image on the left, annotated result on the right.
    with gr.Row():
        image_processing_input_image = gr.Image(
            label="Upload image",
            **_IMAGE_KWARGS,
        )
        image_processing_output_image = gr.Image(
            label="Output image",
            **_IMAGE_KWARGS,
        )

    # Bottom row: inference controls next to the submit button.
    with gr.Row():
        with gr.Column():
            image_processing_confidence_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.5,
                step=0.05,
                label="Confidence",
            )
            image_processing_resolution_slider = gr.Slider(
                minimum=224,
                maximum=2240,
                value=896,
                step=1,
                label="Inference resolution",
            )
            image_processing_checkpoint_dropdown = gr.Dropdown(
                choices=_CHECKPOINT_CHOICES,
                value="medium (object detection)",
                label="Checkpoint",
            )
        with gr.Column():
            image_processing_submit_button = gr.Button("Submit")

    # The inputs are listed in the same order as the parameters of
    # image_processing_inference.
    image_processing_submit_button.click(
        image_processing_inference,
        inputs=[
            image_processing_input_image,
            image_processing_confidence_slider,
            image_processing_resolution_slider,
            image_processing_checkpoint_dropdown,
        ],
        outputs=image_processing_output_image,
    )

if __name__ == "__main__":
    demo.launch(show_error=True, debug=False)