# RFDetr / app.py
# stevenbucaille's picture
# Fix org name
# 34aa29c
import gradio as gr
import spaces
import supervision as sv
import torch
from transformers import (
AutoImageProcessor,
RfDetrForInstanceSegmentation,
RfDetrForObjectDetection,
)
def _is_segmentation_model(model_basename: str) -> bool:
return "seg" in model_basename
@spaces.GPU
def infer(model_name, image, confidence_threshold):
    """Run an RF-DETR checkpoint on ``image`` and return the annotated image.

    Args:
        model_name: Checkpoint basename. It is appended to the
            ``stevenbucaille/`` Hub namespace to form the model id. Names
            for which ``_is_segmentation_model`` is true are loaded as
            instance-segmentation models, all others as object-detection
            models.
        image: Input picture as a PIL image (``.size`` is read as
            ``(width, height)``).
        confidence_threshold: Score threshold forwarded to the processor's
            post-processing step.

    Returns:
        The image with masks (segmentation) or boxes (detection) plus
        ``"<class name> <score>"`` labels drawn on it. When nothing is
        detected the input image is returned unchanged.

    Raises:
        gr.Error: If ``image`` is ``None`` (nothing was uploaded).
    """
    if image is None:
        # Surface a readable message in the UI instead of an AttributeError
        # on ``image.size`` further down.
        raise gr.Error("Please upload an image before running inference.")

    hub_model_id = f"stevenbucaille/{model_name}"
    segmentation = _is_segmentation_model(model_name)

    processor = AutoImageProcessor.from_pretrained(hub_model_id)
    if segmentation:
        model = RfDetrForInstanceSegmentation.from_pretrained(hub_model_id)
    else:
        model = RfDetrForObjectDetection.from_pretrained(hub_model_id)

    # ``from_pretrained`` leaves the weights on CPU; move them explicitly so
    # the GPU requested through ``@spaces.GPU`` is actually used.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    inputs = processor(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports (width, height); post-processing expects (height, width).
    if segmentation:
        target_sizes = [image.size[::-1]]
        results = processor.post_process_instance_segmentation(
            outputs,
            target_sizes=target_sizes,
            threshold=confidence_threshold,
        )[0]
    else:
        target_sizes = torch.tensor([image.size[::-1]])
        results = processor.post_process_object_detection(
            outputs, target_sizes=target_sizes, threshold=confidence_threshold
        )[0]

    detections = sv.Detections.from_transformers(
        transformers_results=results, id2label=model.config.id2label
    )

    # Nothing passed the threshold: there is nothing to draw, and an empty
    # result may not carry class names / confidences to build labels from.
    if len(detections) == 0:
        return image

    labels = [
        f"{class_name} {confidence:.2f}"
        for class_name, confidence in zip(
            detections["class_name"], detections.confidence
        )
    ]

    # Dynamically scale text based on image width.
    width, _height = image.size
    text_scale = (width / 1000) * 0.5
    text_thickness = max(1, int(round(width / 500)))
    label_annotator = sv.LabelAnnotator(
        text_padding=4,
        text_scale=text_scale,
        text_thickness=text_thickness,
        smart_position=True,
    )

    # Masks for segmentation checkpoints, boxes otherwise; only the
    # annotator that is needed gets built.
    if segmentation:
        shape_annotator = sv.MaskAnnotator(
            color_lookup=sv.ColorLookup.CLASS,
            opacity=0.5,
        )
    else:
        shape_annotator = sv.BoxAnnotator()

    image = shape_annotator.annotate(image, detections)
    image = label_annotator.annotate(image, detections, labels)
    return image
# Checkpoint basenames offered in the model selector; the selected value is
# passed to ``infer`` as ``model_name``.
_MODEL_CHOICES = [
    "rf-detr-base",
    "rf-detr-base-2",
    "rf-detr-large",
    "rf-detr-medium",
    "rf-detr-nano",
    "rf-detr-seg-large",
    "rf-detr-seg-medium",
    "rf-detr-seg-nano",
    "rf-detr-seg-preview",
    "rf-detr-seg-small",
    "rf-detr-seg-xlarge",
    "rf-detr-seg-xxlarge",
    "rf-detr-segmentation",
    "rf-detr-small",
]

# Sample pictures offered as one-click inputs.
_EXAMPLE_IMAGES = [
    "samples/cats.jpg",
    "samples/detectron2.png",
    "samples/cat.jpg",
    "samples/hotdog.jpg",
]

# Markdown shown at the top of the page.
_TITLE_MD = "# RF-DETR Object Detection"
_ABOUT_MD = "RF-DETR is a transformer-based object detection model that is trained on the Objects365 and COCO datasets."
_USAGE_MD = "This space is a demo of the RF-DETR model. You can select a model and an image and see the results."

with gr.Blocks() as demo:
    # Page header: title followed by two descriptive paragraphs.
    gr.Markdown(_TITLE_MD)
    gr.Markdown(_ABOUT_MD)
    gr.Markdown(_USAGE_MD)

    with gr.Row():
        # Left column: every input control plus the trigger button.
        with gr.Column():
            model = gr.Radio(
                _MODEL_CHOICES,
                value="rf-detr-base",
                label="Model",
            )
            confidence_threshold = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.3,
                step=0.1,
                label="Confidence Threshold",
            )
            input_image = gr.Image(label="Input Image", type="pil")
            send_btn = gr.Button("Infer", variant="primary")

        # Right column: the annotated result.
        with gr.Column():
            output_image = gr.Image(label="Output Image", type="pil")

    # NOTE(review): the original indentation was lost, so the nesting of the
    # examples gallery is a best guess (Blocks level, below the row) - confirm
    # against the deployed layout.
    gr.Examples(
        examples=_EXAMPLE_IMAGES,
        inputs=input_image,
    )

    # Clicking the button runs ``infer`` on the current control values and
    # shows its return value in the output image.
    send_btn.click(
        fn=infer,
        inputs=[model, input_image, confidence_threshold],
        outputs=[output_image],
    )

demo.launch(debug=True)