lyimo commited on
Commit
721a039
·
verified ·
1 Parent(s): b42c806

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -0
app.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import supervision as sv
3
+ from inference import get_model
4
+ from PIL import Image
5
+
6
+ # Model IDs from RF-DETR docs
7
+ DET_MODEL_ID = "rfdetr-base"
8
+ SEG_MODEL_ID = "rfdetr-seg-preview"
9
+
10
+ det_model = None
11
+ seg_model = None
12
+
13
+
14
+ def load_model(task: str):
15
+ """
16
+ Lazily load the selected model once, then reuse it.
17
+ """
18
+ global det_model, seg_model
19
+
20
+ if task == "Object Detection":
21
+ if det_model is None:
22
+ det_model = get_model(DET_MODEL_ID)
23
+ return det_model
24
+ else: # "Segmentation"
25
+ if seg_model is None:
26
+ seg_model = get_model(SEG_MODEL_ID)
27
+ return seg_model
28
+
29
+
30
+ def run_inference(image: Image.Image, task: str, confidence: float):
31
+ if image is None:
32
+ return None
33
+
34
+ model = load_model(task)
35
+ # Run inference
36
+ predictions = model.infer(image, confidence=confidence)[0]
37
+
38
+ # Convert to supervision.Detections
39
+ detections = sv.Detections.from_inference(predictions)
40
+
41
+ # Labels (class names)
42
+ labels = [prediction.class_name for prediction in predictions.predictions]
43
+
44
+ annotated_image = image.copy()
45
+
46
+ if task == "Object Detection":
47
+ box_annotator = sv.BoxAnnotator(color=sv.ColorPalette.ROBOFLOW)
48
+ label_annotator = sv.LabelAnnotator(color=sv.ColorPalette.ROBOFLOW)
49
+
50
+ annotated_image = box_annotator.annotate(annotated_image, detections)
51
+ annotated_image = label_annotator.annotate(annotated_image, detections, labels)
52
+
53
+ else: # Segmentation
54
+ mask_annotator = sv.MaskAnnotator(color=sv.ColorPalette.ROBOFLOW)
55
+ label_annotator = sv.LabelAnnotator(color=sv.ColorPalette.ROBOFLOW)
56
+
57
+ annotated_image = mask_annotator.annotate(annotated_image, detections)
58
+ annotated_image = label_annotator.annotate(annotated_image, detections, labels)
59
+
60
+ return annotated_image
61
+
62
+
63
+ with gr.Blocks() as demo:
64
+ gr.Markdown(
65
+ """
66
+ # CIAT RF-DETR Demo
67
+ Upload an image and choose **Object Detection** or **Segmentation**.
68
+ """
69
+ )
70
+
71
+ with gr.Row():
72
+ with gr.Column():
73
+ image_input = gr.Image(
74
+ label="Input image",
75
+ type="pil"
76
+ )
77
+ task_input = gr.Radio(
78
+ choices=["Object Detection", "Segmentation"],
79
+ value="Object Detection",
80
+ label="Task"
81
+ )
82
+ conf_input = gr.Slider(
83
+ minimum=0.1,
84
+ maximum=1.0,
85
+ value=0.5,
86
+ step=0.05,
87
+ label="Confidence threshold"
88
+ )
89
+ run_button = gr.Button("Run RF-DETR")
90
+
91
+ with gr.Column():
92
+ image_output = gr.Image(
93
+ label="Annotated output"
94
+ )
95
+
96
+ run_button.click(
97
+ fn=run_inference,
98
+ inputs=[image_input, task_input, conf_input],
99
+ outputs=image_output
100
+ )
101
+
102
+ if __name__ == "__main__":
103
+ demo.launch()