lyimo commited on
Commit
5045d1c
·
verified ·
1 Parent(s): f00971f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -25
app.py CHANGED
@@ -1,8 +1,16 @@
 
1
  import gradio as gr
2
  import supervision as sv
3
  from inference import get_model
4
  from PIL import Image
5
 
 
 
 
 
 
 
 
6
  DET_MODEL_ID = "rfdetr-base"
7
  SEG_MODEL_ID = "rfdetr-seg-preview"
8
 
@@ -11,48 +19,41 @@ seg_model = None
11
 
12
 
13
  def load_model(task: str):
14
- """
15
- Lazily load the selected model once, then reuse it.
16
- """
17
  global det_model, seg_model
18
-
19
- if task == "Object Detection":
20
- if det_model is None:
21
- det_model = get_model(DET_MODEL_ID)
22
- return det_model
23
- else: # "Segmentation"
24
- if seg_model is None:
25
- seg_model = get_model(SEG_MODEL_ID)
26
- return seg_model
 
 
 
 
 
 
27
 
28
 
29
  def run_inference(image: Image.Image, task: str, confidence: float):
30
  if image is None:
31
- return None
32
-
33
  model = load_model(task)
34
- # Run inference
35
  predictions = model.infer(image, confidence=confidence)[0]
36
-
37
- # Convert to supervision.Detections
38
  detections = sv.Detections.from_inference(predictions)
39
-
40
- # Labels (class names)
41
  labels = [prediction.class_name for prediction in predictions.predictions]
42
-
43
  annotated_image = image.copy()
44
 
45
  if task == "Object Detection":
46
  box_annotator = sv.BoxAnnotator(color=sv.ColorPalette.ROBOFLOW)
47
  label_annotator = sv.LabelAnnotator(color=sv.ColorPalette.ROBOFLOW)
48
-
49
  annotated_image = box_annotator.annotate(annotated_image, detections)
50
  annotated_image = label_annotator.annotate(annotated_image, detections, labels)
51
-
52
- else: # Segmentation
53
  mask_annotator = sv.MaskAnnotator(color=sv.ColorPalette.ROBOFLOW)
54
  label_annotator = sv.LabelAnnotator(color=sv.ColorPalette.ROBOFLOW)
55
-
56
  annotated_image = mask_annotator.annotate(annotated_image, detections)
57
  annotated_image = label_annotator.annotate(annotated_image, detections, labels)
58
 
@@ -62,7 +63,7 @@ def run_inference(image: Image.Image, task: str, confidence: float):
62
  with gr.Blocks() as demo:
63
  gr.Markdown(
64
  """
65
- # CIAT RF-DETR Demo
66
  Upload an image and choose **Object Detection** or **Segmentation**.
67
  """
68
  )
 
1
+ import os
2
  import gradio as gr
3
  import supervision as sv
4
  from inference import get_model
5
  from PIL import Image
6
 
7
+ ROBOFLOW_API_KEY = os.getenv("ROBOFLOW_API_KEY")
8
+
9
+ if ROBOFLOW_API_KEY is None:
10
+ raise RuntimeError(
11
+ "ROBOFLOW_API_KEY ERROR. "
12
+ )
13
+
14
  DET_MODEL_ID = "rfdetr-base"
15
  SEG_MODEL_ID = "rfdetr-seg-preview"
16
 
 
19
 
20
 
21
  def load_model(task: str):
 
 
 
22
  global det_model, seg_model
23
+ try:
24
+ if task == "Object Detection":
25
+ if det_model is None:
26
+ det_model = get_model(DET_MODEL_ID, api_key=ROBOFLOW_API_KEY)
27
+ return det_model
28
+ else:
29
+ if seg_model is None:
30
+ seg_model = get_model(SEG_MODEL_ID, api_key=ROBOFLOW_API_KEY)
31
+ return seg_model
32
+ except Exception as e:
33
+ raise gr.Error(
34
+ "Failed to load model from Roboflow. "
35
+ "Check that your API key is correct and has access to this model.\n\n"
36
+ f"Details: {type(e).__name__}: {e}"
37
+ )
38
 
39
 
40
  def run_inference(image: Image.Image, task: str, confidence: float):
41
  if image is None:
42
+ raise gr.Error("Please upload an image first.")
 
43
  model = load_model(task)
 
44
  predictions = model.infer(image, confidence=confidence)[0]
 
 
45
  detections = sv.Detections.from_inference(predictions)
 
 
46
  labels = [prediction.class_name for prediction in predictions.predictions]
 
47
  annotated_image = image.copy()
48
 
49
  if task == "Object Detection":
50
  box_annotator = sv.BoxAnnotator(color=sv.ColorPalette.ROBOFLOW)
51
  label_annotator = sv.LabelAnnotator(color=sv.ColorPalette.ROBOFLOW)
 
52
  annotated_image = box_annotator.annotate(annotated_image, detections)
53
  annotated_image = label_annotator.annotate(annotated_image, detections, labels)
54
+ else:
 
55
  mask_annotator = sv.MaskAnnotator(color=sv.ColorPalette.ROBOFLOW)
56
  label_annotator = sv.LabelAnnotator(color=sv.ColorPalette.ROBOFLOW)
 
57
  annotated_image = mask_annotator.annotate(annotated_image, detections)
58
  annotated_image = label_annotator.annotate(annotated_image, detections, labels)
59
 
 
63
  with gr.Blocks() as demo:
64
  gr.Markdown(
65
  """
66
+ # CIAT RF-DETR Demo
67
  Upload an image and choose **Object Detection** or **Segmentation**.
68
  """
69
  )