# YOLO26 / app.py — Hugging Face Space by openvision
# Last commit: dbd2e13 "Update app.py" (verified)
import gradio as gr
from PIL import Image
from ultralytics import ASSETS, YOLO
from ultralytics.utils.downloads import safe_download
from huggingface_hub import hf_hub_download
# Fetch the boats.jpg sample used by the OBB example on first launch, unless a
# copy is already present next to the ultralytics assets directory.
OBB_IMAGE = ASSETS.parent / "boats.jpg"
if not OBB_IMAGE.exists():
    safe_download("https://ultralytics.com/images/boats.jpg", dir=OBB_IMAGE.parent)
# Hugging Face Hub repo id template per YOLO26 task; "{scale}" is filled with the
# lowercase model scale (e.g. "n"). Insertion order is the order of the Task dropdown.
TASK_TO_REPO_TEMPLATE = {
    task: "openvision/yolo26-{scale}" + suffix
    for task, suffix in (
        ("Detection", ""),
        ("Segmentation", "-seg"),
        ("Classification", "-cls"),
        ("Pose", "-pose"),
        ("OBB", "-obb"),
    )
}
# Repo id template for the open-vocabulary YOLOE26 segmentation models.
YOLOE_REPO_TEMPLATE = "openvision/yoloe26-{scale}-seg"
# Memoisation tables keyed by "<repo_id>::model.pt":
# weights_cache maps to the local weight path, model_cache to the loaded YOLO model.
weights_cache, model_cache = {}, {}
def _scale_from_ui_name(model_name: str) -> str:
return model_name.split("-")[-1].strip().lower()
def _get_weights(repo_id: str) -> str:
    """Return the local path of ``model.pt`` from the Hub repo ``repo_id``.

    The first call downloads via ``hf_hub_download``; the resulting path is memoised
    in ``weights_cache`` so later calls for the same repo return it directly.
    """
    key = f"{repo_id}::model.pt"
    if key in weights_cache:
        return weights_cache[key]
    local_path = hf_hub_download(repo_id=repo_id, filename="model.pt")
    weights_cache[key] = local_path
    return local_path
def _get_model(repo_id: str) -> YOLO:
    """Return a shared ``YOLO`` model for ``repo_id``, loading it on first use.

    Instances are memoised in ``model_cache`` and reused across requests. This is
    used for the fixed-vocabulary YOLO26 tasks; ``predict_yoloe26`` builds its own
    instance instead.
    """
    key = f"{repo_id}::model.pt"
    model = model_cache.get(key)
    if model is None:
        model = model_cache[key] = YOLO(_get_weights(repo_id))
    return model
def predict_yolo26(image, model_name, task, conf, iou, retina):
    """Run a YOLO26 model on one image for the selected task.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None`` when
            nothing has been uploaded.
        model_name: Dropdown label such as ``"YOLO26-N"``; the part after the last
            ``-`` selects the model scale.
        task: One of the keys of ``TASK_TO_REPO_TEMPLATE``.
        conf: Confidence threshold forwarded to ``model.predict``.
        iou: IoU threshold forwarded to ``model.predict``.
        retina: Request retina masks; only honoured for the Segmentation task.

    Returns:
        ``(annotated_image, None)`` for Detection/Segmentation/Pose/OBB, or
        ``(None, {class_name: probability})`` with the top-5 classes for
        Classification. The pair matches the ``[y26_output, y26_label]`` outputs.

    Raises:
        gr.Error: If no image was provided.
    """
    # Ultralytics' Model.predict falls back to its bundled sample assets when
    # ``source`` is None, which would show predictions for an image the user
    # never uploaded. Reject the request instead.
    if image is None:
        raise gr.Error("Upload an image first.")
    repo_id = TASK_TO_REPO_TEMPLATE[task].format(scale=_scale_from_ui_name(model_name))
    model = _get_model(repo_id)
    use_retina = bool(retina) and task == "Segmentation"
    result = model.predict(source=image, conf=conf, iou=iou, imgsz=640, retina_masks=use_retina)[0]
    if task == "Classification":
        probs = result.probs
        return None, {result.names[i]: float(p) for i, p in zip(probs.top5, probs.top5conf)}
    # plot() returns a BGR array; reverse the channel axis to get an RGB PIL image.
    return Image.fromarray(result.plot()[..., ::-1]), None
def _parse_classes(classes_text: str):
if classes_text is None:
return []
names = [c.strip() for c in classes_text.split(",") if c.strip()]
# de-dup while preserving order
seen = set()
out = []
for n in names:
if n.lower() not in seen:
out.append(n)
seen.add(n.lower())
return out
def predict_yoloe26(image, model_name, classes_text, conf, retina):
    """Segment ``image`` with a YOLOE26 model prompted by free-text class names.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None`` when
            nothing has been uploaded.
        model_name: Dropdown label such as ``"YOLOE26-N"``; the part after the last
            ``-`` selects the model scale.
        classes_text: Comma-separated class prompts, e.g. ``"cat, dog, bicycle"``.
        conf: Confidence threshold forwarded to ``model.predict``.
        retina: Request retina (higher-quality) masks.

    Returns:
        The annotated result as an RGB ``PIL.Image``.

    Raises:
        gr.Error: If no image was provided, or ``classes_text`` contains no class names.
    """
    # Ultralytics' Model.predict falls back to its bundled sample assets when
    # ``source`` is None; reject a missing upload instead of returning an
    # unrelated prediction.
    if image is None:
        raise gr.Error("Upload an image first.")
    names = _parse_classes(classes_text)
    if not names:
        raise gr.Error("Enter at least 1 class (comma-separated). Example: 'cat, dog, bicycle'")
    repo_id = YOLOE_REPO_TEMPLATE.format(scale=_scale_from_ui_name(model_name))
    # NOTE: a fresh model per request rather than the shared ``_get_model`` cache.
    # This is presumably because ``set_classes`` changes the model's vocabulary in
    # place, so a shared instance could leak one request's prompts into another.
    # Only the weight path is cached.
    model = YOLO(_get_weights(repo_id))
    model.set_classes(names, model.get_text_pe(names))
    result = model.predict(source=image, conf=conf, imgsz=640, retina_masks=bool(retina))[0]
    # plot() returns a BGR array; reverse the channel axis to get an RGB PIL image.
    return Image.fromarray(result.plot()[..., ::-1])
# Base Gradio theme with dark navy (#111F68) primary buttons that turn
# bright blue (#042AFF) on hover.
theme = gr.themes.Base().set(
    **{
        "button_primary_background_fill": "#111F68",
        "button_primary_background_fill_hover": "#042AFF",
    }
)
# Two-tab UI: fixed-vocabulary YOLO26 tasks, and YOLOE26 text-prompted segmentation.
with gr.Blocks(title="Ultralytics YOLO26 & YOLOE26 Demo", theme=theme) as demo:
    gr.Markdown(
        "# 🚀 Ultralytics YOLO26 & YOLOE26 Demo\n"
        "YOLO26 tasks + YOLOE26 open-vocabulary segmentation."
    )
    with gr.Tabs():
        with gr.Tab("YOLO26 Tasks"):
            gr.Markdown("### Detection, Segmentation, Pose, OBB, Classification")
            with gr.Row():
                with gr.Column():
                    y26_image = gr.Image(type="pil", label="Upload Image")
                    with gr.Row():
                        y26_model = gr.Dropdown(["YOLO26-N"], value="YOLO26-N", label="Model")
                        y26_task = gr.Dropdown(list(TASK_TO_REPO_TEMPLATE.keys()), value="Detection", label="Task")
                    with gr.Accordion("Advanced Settings", open=False):
                        y26_conf = gr.Slider(0, 1, value=0.25, label="Confidence Threshold")
                        y26_iou = gr.Slider(0, 1, value=0.45, label="IoU Threshold")
                        y26_retina = gr.Checkbox(value=True, label="Retina Masks", info="Higher quality masks, slower inference")
                    y26_btn = gr.Button("Run Inference", variant="primary")
                with gr.Column():
                    y26_output = gr.Image(type="pil", label="Result")
                    y26_label = gr.Label(label="Classification Results", visible=False)
            # Show the image output for detection-style tasks and the label output for Classification.
            y26_task.change(
                lambda t: (gr.update(visible=t != "Classification"), gr.update(visible=t == "Classification")),
                y26_task,
                [y26_output, y26_label],
            )
            # Outputs for these examples are cached by Gradio (cache_examples=True).
            gr.Examples(
                examples=[
                    [str(ASSETS / "bus.jpg"), "YOLO26-N", "Detection", 0.25, 0.45, True],
                    [str(ASSETS / "bus.jpg"), "YOLO26-N", "Segmentation", 0.25, 0.45, True],
                    [str(ASSETS / "zidane.jpg"), "YOLO26-N", "Pose", 0.25, 0.45, True],
                    [str(OBB_IMAGE), "YOLO26-N", "OBB", 0.25, 0.45, True],
                    [str(ASSETS / "bus.jpg"), "YOLO26-N", "Classification", 0.25, 0.45, True],
                ],
                inputs=[y26_image, y26_model, y26_task, y26_conf, y26_iou, y26_retina],
                outputs=[y26_output, y26_label],
                fn=predict_yolo26,
                cache_examples=True,
            )
            y26_btn.click(
                predict_yolo26,
                [y26_image, y26_model, y26_task, y26_conf, y26_iou, y26_retina],
                [y26_output, y26_label],
            )
        with gr.Tab("YOLOE26 Open-Vocabulary"):
            gr.Markdown("### Open-Vocabulary Segmentation (text prompts)")
            with gr.Row():
                with gr.Column():
                    ye_image = gr.Image(type="pil", label="Upload Image")
                    ye_model = gr.Dropdown(["YOLOE26-N"], value="YOLOE26-N", label="Model")
                    ye_classes = gr.Textbox(
                        label="Classes (comma-separated)",
                        placeholder="e.g. cat, dog, bicycle",
                        value="person, bus, car",
                    )
                    with gr.Accordion("Advanced Settings", open=False):
                        ye_conf = gr.Slider(0, 1, value=0.2, label="Confidence Threshold")
                        ye_retina = gr.Checkbox(value=True, label="Retina Masks", info="Higher quality masks, slower inference")
                    ye_btn = gr.Button("Run Inference", variant="primary")
                with gr.Column():
                    ye_output = gr.Image(type="pil", label="Result")
            gr.Examples(
                examples=[
                    [str(ASSETS / "bus.jpg"), "YOLOE26-N", "person, bus, car", 0.2, True],
                    [str(ASSETS / "zidane.jpg"), "YOLOE26-N", "person, football, grass", 0.2, True],
                    [str(ASSETS / "bus.jpg"), "YOLOE26-N", "bicycle, traffic light, road", 0.2, True],
                ],
                inputs=[ye_image, ye_model, ye_classes, ye_conf, ye_retina],
                outputs=ye_output,
                fn=predict_yoloe26,
            )
            # Read the textbox directly (as the Examples above do). A gr.State mirror
            # updated by a per-keystroke .change round-trip could lag behind and submit
            # a stale prompt.
            ye_btn.click(
                predict_yoloe26,
                [ye_image, ye_model, ye_classes, ye_conf, ye_retina],
                ye_output,
            )
if __name__ == "__main__":
    # Allow Gradio to serve files from the ultralytics assets directory and its
    # parent (where boats.jpg is downloaded); the examples reference both.
    demo.launch(allowed_paths=[str(path) for path in (ASSETS, ASSETS.parent)])