lyimo commited on
Commit
af13d88
·
verified ·
1 Parent(s): 5045d1c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -75
app.py CHANGED
@@ -1,102 +1,84 @@
1
- import os
2
  import gradio as gr
3
  import supervision as sv
4
- from inference import get_model
5
  from PIL import Image
6
-
7
- ROBOFLOW_API_KEY = os.getenv("ROBOFLOW_API_KEY")
8
-
9
- if ROBOFLOW_API_KEY is None:
10
- raise RuntimeError(
11
- "ROBOFLOW_API_KEY ERROR. "
12
- )
13
-
14
- DET_MODEL_ID = "rfdetr-base"
15
- SEG_MODEL_ID = "rfdetr-seg-preview"
16
 
17
  det_model = None
18
  seg_model = None
19
 
20
-
21
- def load_model(task: str):
22
- global det_model, seg_model
23
- try:
24
- if task == "Object Detection":
25
- if det_model is None:
26
- det_model = get_model(DET_MODEL_ID, api_key=ROBOFLOW_API_KEY)
27
- return det_model
28
- else:
29
- if seg_model is None:
30
- seg_model = get_model(SEG_MODEL_ID, api_key=ROBOFLOW_API_KEY)
31
- return seg_model
32
- except Exception as e:
33
- raise gr.Error(
34
- "Failed to load model from Roboflow. "
35
- "Check that your API key is correct and has access to this model.\n\n"
36
- f"Details: {type(e).__name__}: {e}"
37
- )
38
-
39
-
40
- def run_inference(image: Image.Image, task: str, confidence: float):
41
  if image is None:
42
- raise gr.Error("Please upload an image first.")
43
- model = load_model(task)
44
- predictions = model.infer(image, confidence=confidence)[0]
45
- detections = sv.Detections.from_inference(predictions)
46
- labels = [prediction.class_name for prediction in predictions.predictions]
47
- annotated_image = image.copy()
48
-
49
  if task == "Object Detection":
50
- box_annotator = sv.BoxAnnotator(color=sv.ColorPalette.ROBOFLOW)
51
- label_annotator = sv.LabelAnnotator(color=sv.ColorPalette.ROBOFLOW)
52
- annotated_image = box_annotator.annotate(annotated_image, detections)
53
- annotated_image = label_annotator.annotate(annotated_image, detections, labels)
 
 
 
 
 
 
54
  else:
55
- mask_annotator = sv.MaskAnnotator(color=sv.ColorPalette.ROBOFLOW)
56
- label_annotator = sv.LabelAnnotator(color=sv.ColorPalette.ROBOFLOW)
57
- annotated_image = mask_annotator.annotate(annotated_image, detections)
58
- annotated_image = label_annotator.annotate(annotated_image, detections, labels)
59
-
60
- return annotated_image
61
-
 
 
 
 
 
 
62
 
63
  with gr.Blocks() as demo:
64
- gr.Markdown(
65
- """
66
- # CIAT RF-DETR Demo
67
- Upload an image and choose **Object Detection** or **Segmentation**.
68
- """
69
- )
70
-
71
  with gr.Row():
72
  with gr.Column():
73
- image_input = gr.Image(
74
- label="Input image",
75
- type="pil"
76
- )
77
- task_input = gr.Radio(
78
- choices=["Object Detection", "Segmentation"],
79
  value="Object Detection",
80
  label="Task"
81
  )
82
- conf_input = gr.Slider(
83
  minimum=0.1,
84
- maximum=1.0,
85
  value=0.5,
86
  step=0.05,
87
  label="Confidence threshold"
88
  )
89
- run_button = gr.Button("Run RF-DETR")
90
-
91
  with gr.Column():
92
- image_output = gr.Image(
93
- label="Annotated output"
94
- )
95
-
96
- run_button.click(
97
  fn=run_inference,
98
- inputs=[image_input, task_input, conf_input],
99
- outputs=image_output
100
  )
101
 
102
  if __name__ == "__main__":
 
1
+ import io
2
  import gradio as gr
3
  import supervision as sv
 
4
  from PIL import Image
5
+ from rfdetr import RFDETRBase, RFDETRSegPreview
6
+ from rfdetr.util.coco_classes import COCO_CLASSES
 
 
 
 
 
 
 
 
7
 
8
  det_model = None
9
  seg_model = None
10
 
11
+ def load_det_model():
12
+ global det_model
13
+ if det_model is None:
14
+ det_model = RFDETRBase()
15
+ det_model.optimize_for_inference()
16
+ return det_model
17
+
18
+ def load_seg_model():
19
+ global seg_model
20
+ if seg_model is None:
21
+ seg_model = RFDETRSegPreview()
22
+ seg_model.optimize_for_inference()
23
+ return seg_model
24
+
25
+ def run_inference(image, task, threshold):
 
 
 
 
 
 
26
  if image is None:
27
+ return None
28
+ if isinstance(image, str):
29
+ image = Image.open(io.BytesIO(image)).convert("RGB")
30
+ else:
31
+ image = image.convert("RGB")
 
 
32
  if task == "Object Detection":
33
+ model = load_det_model()
34
+ detections = model.predict(image, threshold=threshold)
35
+ labels = [
36
+ f"{COCO_CLASSES[int(class_id)]} {confidence:.2f}"
37
+ for class_id, confidence in zip(detections.class_id, detections.confidence)
38
+ ]
39
+ annotated = image.copy()
40
+ annotated = sv.BoxAnnotator().annotate(annotated, detections)
41
+ annotated = sv.LabelAnnotator().annotate(annotated, detections, labels)
42
+ return annotated
43
  else:
44
+ model = load_seg_model()
45
+ detections = model.predict(image, threshold=threshold)
46
+ labels = [
47
+ f"{COCO_CLASSES[int(class_id)]} {confidence:.2f}"
48
+ for class_id, confidence in zip(detections.class_id, detections.confidence)
49
+ ]
50
+ annotated = image.copy()
51
+ try:
52
+ annotated = sv.MaskAnnotator().annotate(annotated, detections)
53
+ except Exception:
54
+ annotated = sv.BoxAnnotator().annotate(annotated, detections)
55
+ annotated = sv.LabelAnnotator().annotate(annotated, detections, labels)
56
+ return annotated
57
 
58
  with gr.Blocks() as demo:
59
+ gr.Markdown("# RF-DETR: Detection and Segmentation Preview")
 
 
 
 
 
 
60
  with gr.Row():
61
  with gr.Column():
62
+ inp_image = gr.Image(type="pil", label="Input image")
63
+ task = gr.Radio(
64
+ ["Object Detection", "Segmentation"],
 
 
 
65
  value="Object Detection",
66
  label="Task"
67
  )
68
+ threshold = gr.Slider(
69
  minimum=0.1,
70
+ maximum=0.9,
71
  value=0.5,
72
  step=0.05,
73
  label="Confidence threshold"
74
  )
75
+ run_btn = gr.Button("Run")
 
76
  with gr.Column():
77
+ out_image = gr.Image(type="pil", label="Output", interactive=False)
78
+ run_btn.click(
 
 
 
79
  fn=run_inference,
80
+ inputs=[inp_image, task, threshold],
81
+ outputs=out_image
82
  )
83
 
84
  if __name__ == "__main__":