File size: 4,478 Bytes
721a039 9b17214 af13d88 721a039 9b17214 9c1c3e8 32d6a70 9b17214 721a039 9b17214 af13d88 9b17214 af13d88 9b17214 af13d88 9b17214 5045d1c 9b17214 721a039 9b17214 721a039 9b17214 721a039 9b17214 721a039 9b17214 721a039 9b17214 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
from functools import lru_cache

import gradio as gr
import supervision as sv
from PIL import Image
from rfdetr import RFDETRMedium, RFDETRSegPreview
from rfdetr.detr import RFDETR
from rfdetr.util.coco_classes import COCO_CLASSES
# Header rendered at the top of the Gradio page.
MARKDOWN = (
    "\n"
    "# (RF-DETR Demo-Alliance Bioversity-CIAT)\n"
    "RF-DETR object detection and segmentation\n"
)

# Twelve-colour palette used when drawing detections.
COLOR = sv.ColorPalette.from_hex(
    [
        "#ffff00", "#ff9b00", "#ff8080", "#ff66b2",
        "#ff66ff", "#b266ff", "#9999ff", "#3399ff",
        "#66ffff", "#33ff99", "#66ff66", "#99ff00",
    ]
)
@lru_cache(maxsize=4)
def load_model(resolution: int, checkpoint: str) -> RFDETR:
    """Return the RF-DETR model for ``checkpoint`` at ``resolution``.

    Instances are memoised per (resolution, checkpoint) pair so repeated
    requests with the same settings reuse one model instead of constructing
    (and presumably re-loading weights for) a new one on every call. The
    cache is bounded because every entry keeps a whole model alive.

    Args:
        resolution: Inference resolution; expected to already be a valid
            multiple for the checkpoint (see ``adjust_resolution``).
        checkpoint: Either "medium (object detection)" or
            "segmentation preview".

    Returns:
        The model instance. Cached instances are shared between callers.

    Raises:
        TypeError: If ``checkpoint`` is not one of the supported labels.
            (Kept as ``TypeError`` for backward compatibility, although
            ``ValueError`` would be the conventional choice.) Exceptions
            are not cached.
    """
    if checkpoint == "medium (object detection)":
        return RFDETRMedium(resolution=resolution)
    if checkpoint == "segmentation preview":
        return RFDETRSegPreview(resolution=resolution)
    raise TypeError("Checkpoint must be medium (object detection) or segmentation preview.")
def adjust_resolution(checkpoint: str, resolution: int) -> int:
    """Snap ``resolution`` to the nearest multiple accepted by ``checkpoint``.

    "segmentation preview" uses a multiple of 24 and
    "medium (object detection)" a multiple of 32. When ``resolution`` is
    exactly halfway between two multiples it is rounded up. The result is
    never smaller than one multiple, so very small (or zero/negative)
    requests yield the smallest valid resolution instead of 0.

    Args:
        checkpoint: Checkpoint label as shown in the UI dropdown.
        resolution: Requested inference resolution in pixels.

    Returns:
        The adjusted resolution.

    Raises:
        ValueError: If ``checkpoint`` is not a known label.
    """
    if checkpoint == "segmentation preview":
        divisor = 24
    elif checkpoint == "medium (object detection)":
        divisor = 32
    else:
        raise ValueError(f"Unknown checkpoint: {checkpoint}")

    # Python's % is non-negative for a positive divisor, so `lower` is always
    # the multiple at or below `resolution`.
    remainder = resolution % divisor
    if remainder == 0:
        adjusted = resolution
    else:
        lower = resolution - remainder
        upper = lower + divisor
        # Distance to `lower` is `remainder`; distance to `upper` is
        # `divisor - remainder`. Ties go to `upper`.
        adjusted = lower if remainder < divisor - remainder else upper

    # Rounding down can produce 0 (or less), which is never a usable input
    # size; fall back to the smallest valid multiple.
    return max(adjusted, divisor)
def image_processing_inference(
    input_image: Image.Image,
    confidence: float,
    resolution: int,
    checkpoint: str
):
    """Run RF-DETR on ``input_image`` and return an annotated copy.

    Args:
        input_image: Image from the upload widget, or ``None`` when nothing
            has been uploaded yet.
        confidence: Detection score threshold passed to ``model.predict``.
        resolution: Requested inference resolution; snapped to a valid value
            with ``adjust_resolution`` before the model is obtained.
        checkpoint: Checkpoint label from the dropdown.

    Returns:
        A copy of ``input_image`` with boxes and "<class> <score>" labels drawn
        on it (plus masks for the segmentation checkpoint), or ``None`` if no
        image was given. The input image itself is not modified.
    """
    if input_image is None:
        return None

    resolution = adjust_resolution(checkpoint=checkpoint, resolution=resolution)
    model = load_model(resolution=resolution, checkpoint=checkpoint)
    detections = model.predict(input_image, threshold=confidence)

    resolution_wh = (input_image.width, input_image.height)
    # Slightly smaller than the "optimal" text, but kept strictly positive:
    # subtracting 0.2 could otherwise give a zero/negative scale (presumably
    # for small images).
    text_scale = max(
        sv.calculate_optimal_text_scale(resolution_wh=resolution_wh) - 0.2,
        0.1,
    )
    thickness = sv.calculate_optimal_line_thickness(resolution_wh=resolution_wh)

    mask_annotator = sv.MaskAnnotator(color=COLOR)
    bbox_annotator = sv.BoxAnnotator(color=COLOR, thickness=thickness)
    label_annotator = sv.LabelAnnotator(
        color=COLOR,
        text_color=sv.Color.BLACK,
        text_scale=text_scale
    )

    # `score` (not `confidence`) so the per-detection value is not confused
    # with the threshold parameter of the same name.
    labels = [
        f"{COCO_CLASSES[class_id]} {score:.2f}"
        for class_id, score in zip(detections.class_id, detections.confidence)
    ]

    annotated_image = input_image.copy()
    # Draw masks first so the mask overlay does not paint over the boxes and
    # label text drawn afterwards.
    if checkpoint == "segmentation preview":
        annotated_image = mask_annotator.annotate(annotated_image, detections)
    annotated_image = bbox_annotator.annotate(annotated_image, detections)
    annotated_image = label_annotator.annotate(annotated_image, detections, labels)
    return annotated_image
# Page layout: the input and output images side by side, then a row holding
# the inference controls (left column) and the submit button (right column).
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)

    with gr.Row():
        image_processing_input_image = gr.Image(
            type="pil",
            image_mode="RGB",
            height=600,
            label="Upload image",
        )
        image_processing_output_image = gr.Image(
            type="pil",
            image_mode="RGB",
            height=600,
            label="Output image",
        )

    with gr.Row():
        with gr.Column():
            image_processing_confidence_slider = gr.Slider(
                value=0.5,
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                label="Confidence",
            )
            # Any integer is accepted here; the handler snaps it to a value
            # the chosen checkpoint supports.
            image_processing_resolution_slider = gr.Slider(
                value=896,
                minimum=224,
                maximum=2240,
                step=1,
                label="Inference resolution",
            )
            image_processing_checkpoint_dropdown = gr.Dropdown(
                value="medium (object detection)",
                choices=["medium (object detection)", "segmentation preview"],
                label="Checkpoint",
            )
        with gr.Column():
            image_processing_submit_button = gr.Button("Submit")

    # Argument order of `inputs` must match image_processing_inference's
    # parameters: image, confidence, resolution, checkpoint.
    image_processing_submit_button.click(
        fn=image_processing_inference,
        inputs=[
            image_processing_input_image,
            image_processing_confidence_slider,
            image_processing_resolution_slider,
            image_processing_checkpoint_dropdown,
        ],
        outputs=image_processing_output_image,
    )
if __name__ == "__main__":
    # Start the app when run as a script (debug off, show_error on).
    demo.launch(show_error=True, debug=False)
|