import functools

import gradio as gr
import supervision as sv
from PIL import Image

from rfdetr import RFDETRMedium, RFDETRSegPreview
from rfdetr.detr import RFDETR
from rfdetr.util.coco_classes import COCO_CLASSES

MARKDOWN = """
# (RF-DETR Demo-Alliance Bioversity-CIAT)

RF-DETR object detection and segmentation
"""

# Palette cycled per class id by the supervision annotators.
COLOR = sv.ColorPalette.from_hex([
    "#ffff00", "#ff9b00", "#ff8080", "#ff66b2", "#ff66ff", "#b266ff",
    "#9999ff", "#3399ff", "#66ffff", "#33ff99", "#66ff66", "#99ff00"
])


@functools.lru_cache(maxsize=4)
def load_model(resolution: int, checkpoint: str) -> RFDETR:
    """Return an RF-DETR model for the given checkpoint at the given resolution.

    Results are cached so repeated Submit clicks with the same settings do not
    reload model weights on every inference call.

    Args:
        resolution: Inference resolution (must already be divisible by the
            checkpoint's required divisor — see ``adjust_resolution``).
        checkpoint: Either "medium (object detection)" or
            "segmentation preview".

    Raises:
        ValueError: If ``checkpoint`` is not one of the two known names.
            (Consistent with ``adjust_resolution``, which raises ValueError
            for the same condition.)
    """
    if checkpoint == "medium (object detection)":
        return RFDETRMedium(resolution=resolution)
    if checkpoint == "segmentation preview":
        return RFDETRSegPreview(resolution=resolution)
    raise ValueError("Checkpoint must be medium (object detection) or segmentation preview.")


def adjust_resolution(checkpoint: str, resolution: int) -> int:
    """Snap ``resolution`` to the nearest multiple required by the checkpoint.

    The segmentation checkpoint requires resolutions divisible by 24; the
    medium detection checkpoint requires divisibility by 32. Ties (exactly
    halfway between two multiples) round up.

    Raises:
        ValueError: If ``checkpoint`` is not a known checkpoint name.
    """
    if checkpoint == "segmentation preview":
        divisor = 24
    elif checkpoint == "medium (object detection)":
        divisor = 32
    else:
        raise ValueError(f"Unknown checkpoint: {checkpoint}")

    remainder = resolution % divisor
    if remainder == 0:
        return resolution

    lower = resolution - remainder
    upper = lower + divisor
    # Choose whichever multiple is closer; the tie goes to the upper multiple.
    if resolution - lower < upper - resolution:
        return lower
    return upper


def image_processing_inference(
    input_image: Image.Image,
    confidence: float,
    resolution: int,
    checkpoint: str
):
    """Run RF-DETR on ``input_image`` and return the annotated image.

    Args:
        input_image: PIL image from the Gradio input component; ``None`` when
            the user submits without uploading, in which case ``None`` is
            returned.
        confidence: Detection confidence threshold in [0, 1].
        resolution: Requested inference resolution; snapped to a valid
            multiple via ``adjust_resolution``.
        checkpoint: Checkpoint name selected in the dropdown.

    Returns:
        The annotated PIL image, or ``None`` if no image was provided.
    """
    if input_image is None:
        return None

    resolution = adjust_resolution(checkpoint=checkpoint, resolution=resolution)
    model = load_model(resolution=resolution, checkpoint=checkpoint)
    detections = model.predict(input_image, threshold=confidence)

    # Scale annotation geometry to the uploaded image's dimensions.
    resolution_wh = (input_image.width, input_image.height)
    text_scale = sv.calculate_optimal_text_scale(resolution_wh=resolution_wh) - 0.2
    thickness = sv.calculate_optimal_line_thickness(resolution_wh=resolution_wh)

    mask_annotator = sv.MaskAnnotator(color=COLOR)
    bbox_annotator = sv.BoxAnnotator(color=COLOR, thickness=thickness)
    label_annotator = sv.LabelAnnotator(
        color=COLOR,
        text_color=sv.Color.BLACK,
        text_scale=text_scale
    )

    labels = [
        f"{COCO_CLASSES[class_id]} {confidence:.2f}"
        for class_id, confidence
        in zip(detections.class_id, detections.confidence)
    ]

    annotated_image = input_image.copy()
    # Draw masks first so the boxes and label text remain visible on top
    # (previously masks were drawn last and painted over the labels).
    if checkpoint == "segmentation preview":
        annotated_image = mask_annotator.annotate(annotated_image, detections)
    annotated_image = bbox_annotator.annotate(annotated_image, detections)
    annotated_image = label_annotator.annotate(annotated_image, detections, labels)
    return annotated_image


with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        image_processing_input_image = gr.Image(
            label="Upload image",
            image_mode="RGB",
            type="pil",
            height=600
        )
        image_processing_output_image = gr.Image(
            label="Output image",
            image_mode="RGB",
            type="pil",
            height=600
        )
    with gr.Row():
        with gr.Column():
            image_processing_confidence_slider = gr.Slider(
                label="Confidence",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.5,
            )
            image_processing_resolution_slider = gr.Slider(
                label="Inference resolution",
                minimum=224,
                maximum=2240,
                step=1,
                value=896,
            )
            image_processing_checkpoint_dropdown = gr.Dropdown(
                label="Checkpoint",
                choices=[
                    "medium (object detection)",
                    "segmentation preview"
                ],
                value="medium (object detection)"
            )
        with gr.Column():
            image_processing_submit_button = gr.Button("Submit")

    image_processing_submit_button.click(
        image_processing_inference,
        inputs=[
            image_processing_input_image,
            image_processing_confidence_slider,
            image_processing_resolution_slider,
            image_processing_checkpoint_dropdown
        ],
        outputs=image_processing_output_image,
    )

if __name__ == "__main__":
    demo.launch(debug=False, show_error=True)