# CIAT_Demo / app.py
# Provenance: lyimo's Hugging Face Space — commit 9c1c3e8 ("Update app.py", verified)
import gradio as gr
import supervision as sv
from PIL import Image
from rfdetr import RFDETRMedium, RFDETRSegPreview
from rfdetr.detr import RFDETR
from rfdetr.util.coco_classes import COCO_CLASSES
# Markdown header rendered at the top of the Gradio page.
MARKDOWN = """
# (RF-DETR Demo-Alliance Bioversity-CIAT)
RF-DETR object detection and segmentation
"""

# Fixed color palette shared by all annotators; colors are assigned per
# detection class by the supervision annotators.
COLOR = sv.ColorPalette.from_hex([
    "#ffff00", "#ff9b00", "#ff8080", "#ff66b2", "#ff66ff", "#b266ff",
    "#9999ff", "#3399ff", "#66ffff", "#33ff99", "#66ff66", "#99ff00"
])
def load_model(resolution: int, checkpoint: str) -> RFDETR:
if checkpoint == "medium (object detection)":
return RFDETRMedium(resolution=resolution)
if checkpoint == "segmentation preview":
return RFDETRSegPreview(resolution=resolution)
raise TypeError("Checkpoint must be medium (object detection) or segmentation preview.")
def adjust_resolution(checkpoint: str, resolution: int) -> int:
    """Snap *resolution* to the nearest multiple of the checkpoint's stride.

    The detection checkpoint requires multiples of 32, the segmentation
    preview multiples of 24. Exact ties round upward.

    Raises:
        ValueError: If ``checkpoint`` is not a known choice.
    """
    stride_by_checkpoint = {
        "segmentation preview": 24,
        "medium (object detection)": 32,
    }
    divisor = stride_by_checkpoint.get(checkpoint)
    if divisor is None:
        raise ValueError(f"Unknown checkpoint: {checkpoint}")

    offset = resolution % divisor
    if offset == 0:
        return resolution

    floor_value = resolution - offset
    # Round to whichever multiple is closer; when equidistant (2*offset ==
    # divisor) prefer the larger one, matching the original behavior.
    if offset * 2 < divisor:
        return floor_value
    return floor_value + divisor
def image_processing_inference(
    input_image: Image.Image,
    confidence: float,
    resolution: int,
    checkpoint: str
):
    """Run RF-DETR on *input_image* and return an annotated copy.

    Args:
        input_image: PIL image from the upload widget, or ``None`` if the
            user submitted without uploading.
        confidence: Detection threshold passed to ``model.predict``.
        resolution: Requested inference resolution (snapped to the
            checkpoint's stride before the model is built).
        checkpoint: Dropdown choice selecting detection vs. segmentation.

    Returns:
        The annotated PIL image, or ``None`` when no image was provided.
    """
    if input_image is None:
        return None

    resolution = adjust_resolution(checkpoint=checkpoint, resolution=resolution)
    # NOTE(review): the model (and its weights) is rebuilt on every click;
    # caching per (checkpoint, resolution) would speed up repeated runs.
    model = load_model(resolution=resolution, checkpoint=checkpoint)
    detections = model.predict(input_image, threshold=confidence)

    # Scale annotation geometry to the image size.
    resolution_wh = (input_image.width, input_image.height)
    text_scale = sv.calculate_optimal_text_scale(resolution_wh=resolution_wh) - 0.2
    thickness = sv.calculate_optimal_line_thickness(resolution_wh=resolution_wh)

    # Assumes class ids are COCO ids (pretrained checkpoints) — a custom
    # checkpoint with other ids would raise a KeyError here.
    labels = [
        f"{COCO_CLASSES[class_id]} {score:.2f}"
        for class_id, score
        in zip(detections.class_id, detections.confidence)
    ]

    annotated_image = input_image.copy()
    # Draw masks FIRST so boxes and labels stay visible on top; drawing them
    # last (as before) painted the masks over the labels and hid them.
    if checkpoint == "segmentation preview":
        annotated_image = sv.MaskAnnotator(color=COLOR).annotate(
            annotated_image, detections
        )
    annotated_image = sv.BoxAnnotator(color=COLOR, thickness=thickness).annotate(
        annotated_image, detections
    )
    annotated_image = sv.LabelAnnotator(
        color=COLOR,
        text_color=sv.Color.BLACK,
        text_scale=text_scale
    ).annotate(annotated_image, detections, labels)
    return annotated_image
# ---------------------------------------------------------------------------
# Gradio UI: two image panes (input/output), parameter controls, and a submit
# button wired to image_processing_inference.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    # Side-by-side upload and result previews.
    with gr.Row():
        image_processing_input_image = gr.Image(
            label="Upload image",
            image_mode="RGB",
            type="pil",
            height=600
        )
        image_processing_output_image = gr.Image(
            label="Output image",
            image_mode="RGB",
            type="pil",
            height=600
        )
    with gr.Row():
        with gr.Column():
            # Detection threshold forwarded to model.predict().
            image_processing_confidence_slider = gr.Slider(
                label="Confidence",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.5,
            )
            # Free-form resolution; snapped to the checkpoint's stride
            # (24 or 32) by adjust_resolution() before the model is built.
            image_processing_resolution_slider = gr.Slider(
                label="Inference resolution",
                minimum=224,
                maximum=2240,
                step=1,
                value=896,
            )
            # Choices must match the strings checked in load_model() and
            # adjust_resolution().
            image_processing_checkpoint_dropdown = gr.Dropdown(
                label="Checkpoint",
                choices=[
                    "medium (object detection)",
                    "segmentation preview"
                ],
                value="medium (object detection)"
            )
        with gr.Column():
            image_processing_submit_button = gr.Button("Submit")

    # Wire the button: inputs map positionally onto the parameters of
    # image_processing_inference.
    image_processing_submit_button.click(
        image_processing_inference,
        inputs=[
            image_processing_input_image,
            image_processing_confidence_slider,
            image_processing_resolution_slider,
            image_processing_checkpoint_dropdown
        ],
        outputs=image_processing_output_image,
    )

if __name__ == "__main__":
    demo.launch(debug=False, show_error=True)