lyimo committed on
Commit
9b17214
·
verified ·
1 Parent(s): af13d88

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -65
app.py CHANGED
@@ -1,85 +1,133 @@
1
- import io
2
  import gradio as gr
3
  import supervision as sv
4
  from PIL import Image
5
- from rfdetr import RFDETRBase, RFDETRSegPreview
 
6
  from rfdetr.util.coco_classes import COCO_CLASSES
7
 
8
- det_model = None
9
- seg_model = None
 
 
10
 
11
- def load_det_model():
12
- global det_model
13
- if det_model is None:
14
- det_model = RFDETRBase()
15
- det_model.optimize_for_inference()
16
- return det_model
17
 
18
- def load_seg_model():
19
- global seg_model
20
- if seg_model is None:
21
- seg_model = RFDETRSegPreview()
22
- seg_model.optimize_for_inference()
23
- return seg_model
24
 
25
- def run_inference(image, task, threshold):
26
- if image is None:
27
- return None
28
- if isinstance(image, str):
29
- image = Image.open(io.BytesIO(image)).convert("RGB")
30
  else:
31
- image = image.convert("RGB")
32
- if task == "Object Detection":
33
- model = load_det_model()
34
- detections = model.predict(image, threshold=threshold)
35
- labels = [
36
- f"{COCO_CLASSES[int(class_id)]} {confidence:.2f}"
37
- for class_id, confidence in zip(detections.class_id, detections.confidence)
38
- ]
39
- annotated = image.copy()
40
- annotated = sv.BoxAnnotator().annotate(annotated, detections)
41
- annotated = sv.LabelAnnotator().annotate(annotated, detections, labels)
42
- return annotated
43
  else:
44
- model = load_seg_model()
45
- detections = model.predict(image, threshold=threshold)
46
- labels = [
47
- f"{COCO_CLASSES[int(class_id)]} {confidence:.2f}"
48
- for class_id, confidence in zip(detections.class_id, detections.confidence)
49
- ]
50
- annotated = image.copy()
51
- try:
52
- annotated = sv.MaskAnnotator().annotate(annotated, detections)
53
- except Exception:
54
- annotated = sv.BoxAnnotator().annotate(annotated, detections)
55
- annotated = sv.LabelAnnotator().annotate(annotated, detections, labels)
56
- return annotated
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  with gr.Blocks() as demo:
59
- gr.Markdown("# RF-DETR: Detection and Segmentation Preview")
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  with gr.Row():
61
  with gr.Column():
62
- inp_image = gr.Image(type="pil", label="Input image")
63
- task = gr.Radio(
64
- ["Object Detection", "Segmentation"],
65
- value="Object Detection",
66
- label="Task"
67
- )
68
- threshold = gr.Slider(
69
- minimum=0.1,
70
- maximum=0.9,
71
- value=0.5,
72
  step=0.05,
73
- label="Confidence threshold"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  )
75
- run_btn = gr.Button("Run")
76
  with gr.Column():
77
- out_image = gr.Image(type="pil", label="Output", interactive=False)
78
- run_btn.click(
79
- fn=run_inference,
80
- inputs=[inp_image, task, threshold],
81
- outputs=out_image
 
 
 
 
 
 
82
  )
83
 
84
  if __name__ == "__main__":
85
- demo.launch()
 
 
1
  import gradio as gr
2
  import supervision as sv
3
  from PIL import Image
4
+ from rfdetr import RFDETRMedium, RFDETRSegPreview
5
+ from rfdetr.detr import RFDETR
6
  from rfdetr.util.coco_classes import COCO_CLASSES
7
 
8
+ MARKDOWN = """
9
+ # RF-DETR 🔥
10
+ Medium object detection and segmentation preview
11
+ """
12
 
13
+ COLOR = sv.ColorPalette.from_hex([
14
+ "#ffff00", "#ff9b00", "#ff8080", "#ff66b2", "#ff66ff", "#b266ff",
15
+ "#9999ff", "#3399ff", "#66ffff", "#33ff99", "#66ff66", "#99ff00"
16
+ ])
 
 
17
 
18
+ def load_model(resolution: int, checkpoint: str) -> RFDETR:
19
+ if checkpoint == "medium (object detection)":
20
+ return RFDETRMedium(resolution=resolution)
21
+ if checkpoint == "segmentation preview":
22
+ return RFDETRSegPreview(resolution=resolution)
23
+ raise TypeError("Checkpoint must be medium (object detection) or segmentation preview.")
24
 
25
+ def adjust_resolution(checkpoint: str, resolution: int) -> int:
26
+ if checkpoint == "segmentation preview":
27
+ divisor = 24
28
+ elif checkpoint == "medium (object detection)":
29
+ divisor = 32
30
  else:
31
+ raise ValueError(f"Unknown checkpoint: {checkpoint}")
32
+ remainder = resolution % divisor
33
+ if remainder == 0:
34
+ return resolution
35
+ lower = resolution - remainder
36
+ upper = lower + divisor
37
+ if resolution - lower < upper - resolution:
38
+ return lower
 
 
 
 
39
  else:
40
+ return upper
41
+
42
+ def image_processing_inference(
43
+ input_image: Image.Image,
44
+ confidence: float,
45
+ resolution: int,
46
+ checkpoint: str
47
+ ):
48
+ if input_image is None:
49
+ return None
50
+ resolution = adjust_resolution(checkpoint=checkpoint, resolution=resolution)
51
+ model = load_model(resolution=resolution, checkpoint=checkpoint)
52
+ detections = model.predict(input_image, threshold=confidence)
53
+
54
+ resolution_wh = (input_image.width, input_image.height)
55
+ text_scale = sv.calculate_optimal_text_scale(resolution_wh=resolution_wh) - 0.2
56
+ thickness = sv.calculate_optimal_line_thickness(resolution_wh=resolution_wh)
57
+
58
+ mask_annotator = sv.MaskAnnotator(color=COLOR)
59
+ bbox_annotator = sv.BoxAnnotator(color=COLOR, thickness=thickness)
60
+ label_annotator = sv.LabelAnnotator(
61
+ color=COLOR,
62
+ text_color=sv.Color.BLACK,
63
+ text_scale=text_scale
64
+ )
65
+
66
+ labels = [
67
+ f"{COCO_CLASSES[class_id]} {confidence:.2f}"
68
+ for class_id, confidence
69
+ in zip(detections.class_id, detections.confidence)
70
+ ]
71
+
72
+ annotated_image = input_image.copy()
73
+ annotated_image = bbox_annotator.annotate(annotated_image, detections)
74
+ annotated_image = label_annotator.annotate(annotated_image, detections, labels)
75
+ if checkpoint == "segmentation preview":
76
+ annotated_image = mask_annotator.annotate(annotated_image, detections)
77
+ return annotated_image
78
 
79
  with gr.Blocks() as demo:
80
+ gr.Markdown(MARKDOWN)
81
+ with gr.Row():
82
+ image_processing_input_image = gr.Image(
83
+ label="Upload image",
84
+ image_mode="RGB",
85
+ type="pil",
86
+ height=600
87
+ )
88
+ image_processing_output_image = gr.Image(
89
+ label="Output image",
90
+ image_mode="RGB",
91
+ type="pil",
92
+ height=600
93
+ )
94
  with gr.Row():
95
  with gr.Column():
96
+ image_processing_confidence_slider = gr.Slider(
97
+ label="Confidence",
98
+ minimum=0.0,
99
+ maximum=1.0,
 
 
 
 
 
 
100
  step=0.05,
101
+ value=0.5,
102
+ )
103
+ image_processing_resolution_slider = gr.Slider(
104
+ label="Inference resolution",
105
+ minimum=224,
106
+ maximum=2240,
107
+ step=1,
108
+ value=896,
109
+ )
110
+ image_processing_checkpoint_dropdown = gr.Dropdown(
111
+ label="Checkpoint",
112
+ choices=[
113
+ "medium (object detection)",
114
+ "segmentation preview"
115
+ ],
116
+ value="medium (object detection)"
117
  )
 
118
  with gr.Column():
119
+ image_processing_submit_button = gr.Button("Submit")
120
+
121
+ image_processing_submit_button.click(
122
+ image_processing_inference,
123
+ inputs=[
124
+ image_processing_input_image,
125
+ image_processing_confidence_slider,
126
+ image_processing_resolution_slider,
127
+ image_processing_checkpoint_dropdown
128
+ ],
129
+ outputs=image_processing_output_image,
130
  )
131
 
132
  if __name__ == "__main__":
133
+ demo.launch(debug=False, show_error=True)