openvision committed on
Commit
dbd2e13
·
verified ·
1 Parent(s): 48750e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -29
app.py CHANGED
@@ -20,34 +20,34 @@ TASK_TO_REPO_TEMPLATE = {
20
 
21
  YOLOE_REPO_TEMPLATE = "openvision/yoloe26-{scale}-seg"
22
 
 
23
  model_cache = {}
24
 
25
 
26
  def _scale_from_ui_name(model_name: str) -> str:
27
- """
28
- Convert dropdown model string to scale token used in repo names.
29
- Examples:
30
- "YOLO26-N" -> "n"
31
- "YOLOE26-N" -> "n"
32
- """
33
  return model_name.split("-")[-1].strip().lower()
34
 
35
 
 
 
 
 
 
 
 
 
def _get_model(repo_id: str) -> YOLO:
    """Return the cached YOLO model for *repo_id*, loading it on first request.

    The repo's ``model.pt`` is fetched with ``hf_hub_download`` and wrapped in
    ``YOLO``. The instance is stored in ``model_cache`` so later calls for the
    same repo reuse it.
    """
    cache_key = f"{repo_id}::model.pt"
    if cache_key in model_cache:
        return model_cache[cache_key]

    weights_path = hf_hub_download(repo_id=repo_id, filename="model.pt")
    loaded = YOLO(weights_path)
    model_cache[cache_key] = loaded
    return loaded
43
 
44
 
45
  def predict_yolo26(image, model_name, task, conf, iou, retina):
46
- """Run YOLO26 inference for various tasks."""
47
  scale = _scale_from_ui_name(model_name)
48
- repo_tmpl = TASK_TO_REPO_TEMPLATE[task]
49
- repo_id = repo_tmpl.format(scale=scale)
50
-
51
  model = _get_model(repo_id)
52
 
53
  use_retina = bool(retina) and task == "Segmentation"
@@ -60,35 +60,50 @@ def predict_yolo26(image, model_name, task, conf, iou, retina):
60
  return Image.fromarray(results[0].plot()[..., ::-1]), None
61
 
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
def predict_yoloe26(image, model_name, classes_text, conf, retina):
    """Run YOLOE26 open-vocabulary inference with text prompts.

    Args:
        image: Image to run inference on.
        model_name: Dropdown label such as "YOLOE26-N"; its suffix selects
            the repo scale.
        classes_text: Comma-separated class names used as text prompts.
        conf: Confidence threshold forwarded to ``predict``.
        retina: Truthy to request retina masks.

    Returns:
        A PIL image built from the plotted first result.

    Raises:
        gr.Error: If no image was supplied or no class name remains after
            parsing ``classes_text``.
    """
    # Validate user input first so bad requests fail with a readable message
    # before any model is downloaded or loaded.
    if image is None:
        raise gr.Error("Upload an image before running inference.")

    names = [c.strip() for c in (classes_text or "").split(",") if c.strip()]
    if not names:
        raise gr.Error("Enter at least 1 class (comma-separated). Example: 'cat, dog, bicycle'")

    scale = _scale_from_ui_name(model_name)
    repo_id = YOLOE_REPO_TEMPLATE.format(scale=scale)

    # NOTE(review): _get_model returns a cached instance, so set_classes()
    # changes it for later calls that use the same repo - confirm this is OK.
    model = _get_model(repo_id)
    model.set_classes(names, model.get_text_pe(names))

    res = model.predict(source=image, conf=conf, imgsz=640, retina_masks=bool(retina))[0]
    # The plotted array has its last axis reversed before conversion
    # (presumably BGR -> RGB; verify against the plotting API).
    return Image.fromarray(res.plot()[..., ::-1])
75
 
76
 
77
  theme = gr.themes.Base().set(
78
- button_primary_background_fill="#111F68", button_primary_background_fill_hover="#042AFF"
 
79
  )
80
 
81
- # Build interface
82
- with gr.Blocks(title="Ultralytics YOLO26 & YOLOE26 Demo") as demo:
83
  gr.Markdown(
84
  "# 🚀 Ultralytics YOLO26 & YOLOE26 Demo\n"
85
- "Showcasing YOLO26 tasks and YOLOE26 open-vocabulary segmentation. "
86
- "[GitHub](https://github.com/ultralytics/ultralytics) | [Docs](https://docs.ultralytics.com/models/yolo26/)"
87
  )
88
 
89
  with gr.Tabs():
90
  with gr.Tab("YOLO26 Tasks"):
91
- gr.Markdown("### Ultralytics YOLO26: Detection, Segmentation, Pose, OBB, Classification")
92
  with gr.Row():
93
  with gr.Column():
94
  y26_image = gr.Image(type="pil", label="Upload Image")
@@ -100,6 +115,7 @@ with gr.Blocks(title="Ultralytics YOLO26 & YOLOE26 Demo") as demo:
100
  y26_iou = gr.Slider(0, 1, value=0.45, label="IoU Threshold")
101
  y26_retina = gr.Checkbox(value=True, label="Retina Masks", info="Higher quality masks, slower inference")
102
  y26_btn = gr.Button("Run Inference", variant="primary")
 
103
  with gr.Column():
104
  y26_output = gr.Image(type="pil", label="Result")
105
  y26_label = gr.Label(label="Classification Results", visible=False)
@@ -131,32 +147,45 @@ with gr.Blocks(title="Ultralytics YOLO26 & YOLOE26 Demo") as demo:
131
  )
132
 
133
  with gr.Tab("YOLOE26 Open-Vocabulary"):
134
- gr.Markdown("### Ultralytics YOLOE26: Open-Vocabulary Segmentation - Detect any object by text description")
 
135
  with gr.Row():
136
  with gr.Column():
137
  ye_image = gr.Image(type="pil", label="Upload Image")
138
- with gr.Row():
139
- ye_model = gr.Dropdown(["YOLOE26-N"], value="YOLOE26-N", label="Model")
140
- ye_classes = gr.Textbox(label="Classes")
 
 
 
141
  with gr.Accordion("Advanced Settings", open=False):
142
  ye_conf = gr.Slider(0, 1, value=0.2, label="Confidence Threshold")
143
  ye_retina = gr.Checkbox(value=True, label="Retina Masks", info="Higher quality masks, slower inference")
144
  ye_btn = gr.Button("Run Inference", variant="primary")
 
145
  with gr.Column():
146
  ye_output = gr.Image(type="pil", label="Result")
147
 
 
 
 
 
148
  gr.Examples(
149
  examples=[
150
  [str(ASSETS / "bus.jpg"), "YOLOE26-N", "person, bus, car", 0.2, True],
151
  [str(ASSETS / "zidane.jpg"), "YOLOE26-N", "person, football, grass", 0.2, True],
 
152
  ],
153
  inputs=[ye_image, ye_model, ye_classes, ye_conf, ye_retina],
154
  outputs=ye_output,
155
  fn=predict_yoloe26,
156
- #cache_examples=True,
157
  )
158
 
159
- ye_btn.click(predict_yoloe26, [ye_image, ye_model, ye_classes, ye_conf, ye_retina], ye_output)
 
 
 
 
160
 
161
  if __name__ == "__main__":
162
- demo.launch(theme=theme, allowed_paths=[str(ASSETS), str(ASSETS.parent)])
 
20
 
21
  YOLOE_REPO_TEMPLATE = "openvision/yoloe26-{scale}-seg"
22
 
23
+ weights_cache = {}
24
  model_cache = {}
25
 
26
 
27
  def _scale_from_ui_name(model_name: str) -> str:
 
 
 
 
 
 
28
  return model_name.split("-")[-1].strip().lower()
29
 
30
 
def _get_weights(repo_id: str) -> str:
    """Return the local path of ``model.pt`` for *repo_id*.

    The path returned by ``hf_hub_download`` is remembered in
    ``weights_cache``, so the download call is made at most once per repo
    within this process.
    """
    cache_key = f"{repo_id}::model.pt"
    if cache_key in weights_cache:
        return weights_cache[cache_key]

    local_path = hf_hub_download(repo_id=repo_id, filename="model.pt")
    weights_cache[cache_key] = local_path
    return local_path
37
+
38
+
def _get_model(repo_id: str) -> YOLO:
    """Return a shared YOLO instance for *repo_id*, creating it on first use.

    Weights are resolved through ``_get_weights`` and the resulting model is
    kept in ``model_cache``, so repeated calls for the same repo return the
    same object.
    """
    cache_key = f"{repo_id}::model.pt"
    if cache_key in model_cache:
        return model_cache[cache_key]

    loaded = YOLO(_get_weights(repo_id))
    model_cache[cache_key] = loaded
    return loaded
46
 
47
 
48
  def predict_yolo26(image, model_name, task, conf, iou, retina):
 
49
  scale = _scale_from_ui_name(model_name)
50
+ repo_id = TASK_TO_REPO_TEMPLATE[task].format(scale=scale)
 
 
51
  model = _get_model(repo_id)
52
 
53
  use_retina = bool(retina) and task == "Segmentation"
 
60
  return Image.fromarray(results[0].plot()[..., ::-1]), None
61
 
62
 
63
+ def _parse_classes(classes_text: str):
64
+ if classes_text is None:
65
+ return []
66
+ names = [c.strip() for c in classes_text.split(",") if c.strip()]
67
+ # de-dup while preserving order
68
+ seen = set()
69
+ out = []
70
+ for n in names:
71
+ if n.lower() not in seen:
72
+ out.append(n)
73
+ seen.add(n.lower())
74
+ return out
75
+
76
+
def predict_yoloe26(image, model_name, classes_text, conf, retina):
    """Run YOLOE26 open-vocabulary segmentation driven by text prompts.

    Args:
        image: Image to run inference on (the UI image input uses
            ``type="pil"``).
        model_name: Dropdown label such as "YOLOE26-N"; its suffix selects
            the repo scale.
        classes_text: Comma-separated class names used as text prompts.
        conf: Confidence threshold forwarded to ``predict``.
        retina: Truthy to request retina masks.

    Returns:
        A PIL image built from the plotted first result.

    Raises:
        gr.Error: If no image was supplied or ``classes_text`` contains no
            class name.
    """
    # Fail early with a readable message instead of handing a missing image
    # to the model.
    if image is None:
        raise gr.Error("Upload an image before running inference.")

    names = _parse_classes(classes_text)
    if not names:
        raise gr.Error("Enter at least 1 class (comma-separated). Example: 'cat, dog, bicycle'")

    scale = _scale_from_ui_name(model_name)
    repo_id = YOLOE_REPO_TEMPLATE.format(scale=scale)

    # A new YOLO object is built on every call rather than taken from
    # _get_model; presumably so set_classes() never alters a shared cached
    # instance - confirm before switching to the model cache.
    model = YOLO(_get_weights(repo_id))
    model.set_classes(names, model.get_text_pe(names))

    res = model.predict(source=image, conf=conf, imgsz=640, retina_masks=bool(retina))[0]
    # The plotted array has its last axis reversed before conversion
    # (presumably BGR -> RGB; verify against the plotting API).
    return Image.fromarray(res.plot()[..., ::-1])
91
 
92
 
93
  theme = gr.themes.Base().set(
94
+ button_primary_background_fill="#111F68",
95
+ button_primary_background_fill_hover="#042AFF",
96
  )
97
 
98
+ with gr.Blocks(title="Ultralytics YOLO26 & YOLOE26 Demo", theme=theme) as demo:
 
99
  gr.Markdown(
100
  "# 🚀 Ultralytics YOLO26 & YOLOE26 Demo\n"
101
+ "YOLO26 tasks + YOLOE26 open-vocabulary segmentation."
 
102
  )
103
 
104
  with gr.Tabs():
105
  with gr.Tab("YOLO26 Tasks"):
106
+ gr.Markdown("### Detection, Segmentation, Pose, OBB, Classification")
107
  with gr.Row():
108
  with gr.Column():
109
  y26_image = gr.Image(type="pil", label="Upload Image")
 
115
  y26_iou = gr.Slider(0, 1, value=0.45, label="IoU Threshold")
116
  y26_retina = gr.Checkbox(value=True, label="Retina Masks", info="Higher quality masks, slower inference")
117
  y26_btn = gr.Button("Run Inference", variant="primary")
118
+
119
  with gr.Column():
120
  y26_output = gr.Image(type="pil", label="Result")
121
  y26_label = gr.Label(label="Classification Results", visible=False)
 
147
  )
148
 
149
  with gr.Tab("YOLOE26 Open-Vocabulary"):
150
+ gr.Markdown("### Open-Vocabulary Segmentation (text prompts)")
151
+
152
  with gr.Row():
153
  with gr.Column():
154
  ye_image = gr.Image(type="pil", label="Upload Image")
155
+ ye_model = gr.Dropdown(["YOLOE26-N"], value="YOLOE26-N", label="Model")
156
+ ye_classes = gr.Textbox(
157
+ label="Classes (comma-separated)",
158
+ placeholder="e.g. cat, dog, bicycle",
159
+ value="person, bus, car",
160
+ )
161
  with gr.Accordion("Advanced Settings", open=False):
162
  ye_conf = gr.Slider(0, 1, value=0.2, label="Confidence Threshold")
163
  ye_retina = gr.Checkbox(value=True, label="Retina Masks", info="Higher quality masks, slower inference")
164
  ye_btn = gr.Button("Run Inference", variant="primary")
165
+
166
  with gr.Column():
167
  ye_output = gr.Image(type="pil", label="Result")
168
 
169
+ ye_prompt_state = gr.State(ye_classes.value)
170
+
171
+ ye_classes.change(lambda s: s, ye_classes, ye_prompt_state)
172
+
173
  gr.Examples(
174
  examples=[
175
  [str(ASSETS / "bus.jpg"), "YOLOE26-N", "person, bus, car", 0.2, True],
176
  [str(ASSETS / "zidane.jpg"), "YOLOE26-N", "person, football, grass", 0.2, True],
177
+ [str(ASSETS / "bus.jpg"), "YOLOE26-N", "bicycle, traffic light, road", 0.2, True],
178
  ],
179
  inputs=[ye_image, ye_model, ye_classes, ye_conf, ye_retina],
180
  outputs=ye_output,
181
  fn=predict_yoloe26,
 
182
  )
183
 
184
+ ye_btn.click(
185
+ predict_yoloe26,
186
+ [ye_image, ye_model, ye_prompt_state, ye_conf, ye_retina],
187
+ ye_output,
188
+ )
189
 
190
  if __name__ == "__main__":
191
+ demo.launch(allowed_paths=[str(ASSETS), str(ASSETS.parent)])