LabelPlayground / app.py
bytestream89's picture
Upload folder using huggingface_hub
cf8afa8 verified
"""
app.py β€” OWLv2 / SAM2 image labeling UI
Tab 1 β€” Test: upload one image, pick Detection or Segmentation mode,
tune prompts/threshold/size, see instant annotated results.
Tab 2 β€” Batch: upload multiple images, run in the chosen mode, download a ZIP
containing resized images + coco_export.json.
All artifacts live in a system temp directory β€” nothing is written to the project.
"""
from __future__ import annotations
# spaces MUST be imported before torch initialises CUDA (i.e. before any
# autolabel import). Do this first, before everything else.
try:
import spaces as _spaces # type: ignore
_ZERO_GPU = True
except (ImportError, RuntimeError):
_spaces = None
_ZERO_GPU = False
import os
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
import logging
import shutil
import tempfile
import zipfile
from pathlib import Path
from typing import Optional
import gradio as gr
import numpy as np
from dotenv import load_dotenv
from PIL import Image, ImageDraw, ImageFont
load_dotenv()
from autolabel.config import settings
from autolabel.detect import infer as _owlv2_infer
from autolabel.export import build_coco
from autolabel.segment import load_sam2, segment_with_boxes
from autolabel.utils import save_json, setup_logging
setup_logging(logging.INFO)
logger = logging.getLogger(__name__)

import atexit  # noqa: E402  (module-level import placed beside its only use)

# Scratch directory for batch exports. There is ONE directory per process,
# shared by every user/request this process serves — not one per browser
# session. tempfile.mkdtemp() never deletes what it creates, so remove it on
# normal interpreter exit; a process killed by a signal without a normal
# shutdown will still leave it behind.
_TMPDIR = Path(tempfile.mkdtemp(prefix="autolabel_"))
atexit.register(shutil.rmtree, _TMPDIR, ignore_errors=True)
logger.info("Session temp dir: %s", _TMPDIR)
# ---------------------------------------------------------------------------
# Image sizing
# ---------------------------------------------------------------------------
# Square edge lengths offered in the "Input size" dropdown, smallest first.
_SQUARE_SIDES = (416, 480, 512, 640, 736, 896, 1024)

# Dropdown label -> (width, height) resize target; None keeps the upload's size.
# "As is" is inserted first so it heads the dropdown.
_SIZE_OPTIONS = {"As is": None}
_SIZE_OPTIONS.update(
    (f"{side} Γ— {side}", (side, side)) for side in _SQUARE_SIDES
)
_SIZE_LABELS = list(_SIZE_OPTIONS)


def _resize(pil: Image.Image, size_label: str) -> Image.Image:
    """Resize *pil* to the target chosen by *size_label*.

    "As is" returns the very same image object untouched; every other label
    stretches the image to its fixed (width, height) with Lanczos resampling
    (the aspect ratio is not preserved).
    """
    dims = _SIZE_OPTIONS[size_label]
    return pil if dims is None else pil.resize(dims, Image.LANCZOS)
# ---------------------------------------------------------------------------
# Colours & annotation
# ---------------------------------------------------------------------------
# Fixed RGB colours cycled through for successive prompts.
_PALETTE = [
    (52, 211, 153), (251, 146, 60), (96, 165, 250), (248, 113, 113),
    (167, 139, 250), (250, 204, 21), (34, 211, 238), (244, 114, 182),
    (74, 222, 128), (232, 121, 249), (125, 211, 252), (253, 186, 116),
    (110, 231, 183), (196, 181, 253), (253, 164, 175), (134, 239, 172),
]


def _colour_for(label: str, prompts: list[str]) -> tuple[int, int, int]:
    """Return the RGB colour used to draw *label*.

    A label found in *prompts* is coloured by its (first) position in the
    prompt list, cycling through ``_PALETTE``. A label not in *prompts*
    falls back to a CRC-32 of its UTF-8 text. The built-in ``hash()`` is
    salted per process for ``str`` (PYTHONHASHSEED). CRC-32 keeps such
    colours stable across restarts.
    """
    try:
        index = prompts.index(label)
    except ValueError:
        import zlib  # stdlib; only needed on this fallback path

        index = zlib.crc32(label.encode("utf-8"))
    return _PALETTE[index % len(_PALETTE)]
# Label fonts tried in order; the first one Pillow can open is used.
_LABEL_FONT_CANDIDATES = (
    "/System/Library/Fonts/Helvetica.ttc",  # macOS
    "DejaVuSans.ttf",  # often installed on Linux; found via Pillow's font search
)
_label_font_cache: dict = {}


def _label_font(size: int = 18):
    """Return a font for box labels at *size*, cached per size.

    Tries each entry of ``_LABEL_FONT_CANDIDATES``. If none can be opened,
    falls back to Pillow's built-in font. On Pillow >= 10.1 that font is
    scaled to *size*; older versions get the fixed-size bitmap font.
    """
    font = _label_font_cache.get(size)
    if font is not None:
        return font
    for candidate in _LABEL_FONT_CANDIDATES:
        try:
            font = ImageFont.truetype(candidate, size=size)
            break
        except Exception:  # missing/unreadable file or no FreeType support
            continue
    else:
        try:
            font = ImageFont.load_default(size=size)
        except TypeError:  # Pillow < 10.1: load_default() takes no size
            font = ImageFont.load_default()
    _label_font_cache[size] = font
    return font


def _annotate(
    pil_image: Image.Image,
    detections: list[dict],
    prompts: list[str],
    mode: str = "Detection",
) -> Image.Image:
    """Return an RGB copy of *pil_image* with the detections drawn on it.

    Every detection gets a coloured box plus a "label score" tag. In
    "Segmentation" mode, each detection's ``"mask"`` ndarray is first painted
    as a semi-transparent overlay (alpha 100) in the same colour. Masks that
    are missing, not ndarrays, or not the image's (height, width) are skipped.
    *pil_image* itself is not modified.
    """
    img = pil_image.copy().convert("RGBA")

    # --- Segmentation: paint semi-transparent mask overlays first ---
    if mode == "Segmentation":
        overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
        for det in detections:
            mask = det.get("mask")
            if not isinstance(mask, np.ndarray):
                continue
            # Boolean indexing below needs a bool mask; a 0/1 integer mask
            # would otherwise be interpreted as row indices.
            mask = mask.astype(bool, copy=False)
            if mask.shape != (img.height, img.width):
                logger.warning(
                    "Skipping mask for %r: shape %s does not match image %dx%d",
                    det.get("label"), mask.shape, img.width, img.height,
                )
                continue
            r, g, b = _colour_for(det["label"], prompts)
            mask_rgba = np.zeros((img.height, img.width, 4), dtype=np.uint8)
            mask_rgba[mask] = (r, g, b, 100)  # semi-transparent fill
            # A uint8 (H, W, 4) array is inferred as an RGBA image.
            overlay = Image.alpha_composite(overlay, Image.fromarray(mask_rgba))
        img = Image.alpha_composite(img, overlay)

    # --- Bounding boxes and labels (both modes) ---
    draw = ImageDraw.Draw(img, "RGBA")
    font = _label_font()
    for det in detections:
        x1, y1, x2, y2 = det["box_xyxy"]
        r, g, b = _colour_for(det["label"], prompts)
        draw.rectangle([x1, y1, x2, y2], outline=(r, g, b), width=3)
        tag = f"{det['label']} {det['score']:.2f}"
        # Filled background behind the tag, padded by 3 px, for legibility.
        bx1, by1, bx2, by2 = draw.textbbox((x1, y1), tag, font=font)
        draw.rectangle([bx1 - 3, by1 - 3, bx2 + 3, by2 + 3], fill=(r, g, b, 210))
        draw.text((x1, y1), tag, fill=(255, 255, 255), font=font)
    return img.convert("RGB")
# ---------------------------------------------------------------------------
# OWLv2 model (cached)
# ---------------------------------------------------------------------------
# Holds at most one entry: {checkpoint id: (processor, model)}.
_owlv2_cache: dict = {}


def _get_owlv2():
    """Return ``(processor, model)`` for ``settings.model``, loading on first use.

    Only one OWLv2 checkpoint is kept in memory. If ``settings.model`` names a
    checkpoint that is not cached, the cache is emptied before the new one is
    loaded onto ``settings.device`` with ``settings.torch_dtype`` and switched
    to eval mode.
    """
    model_id = settings.model
    if model_id in _owlv2_cache:
        return _owlv2_cache[model_id]

    _owlv2_cache.clear()
    # Imported lazily so transformers is only pulled in when a model is needed.
    from transformers import Owlv2ForObjectDetection, Owlv2Processor

    logger.info("Loading OWLv2 %s on %s …", model_id, settings.device)
    processor = Owlv2Processor.from_pretrained(model_id)
    model = Owlv2ForObjectDetection.from_pretrained(
        model_id, torch_dtype=settings.torch_dtype
    )
    model = model.to(settings.device)
    model.eval()
    _owlv2_cache[model_id] = (processor, model)
    logger.info("OWLv2 ready.")
    return _owlv2_cache[model_id]
# ---------------------------------------------------------------------------
# SAM2 model (cached)
# ---------------------------------------------------------------------------
# {SAM2 checkpoint id: (processor, model)}, filled lazily by _get_sam2().
_sam2_cache: dict = {}
_SAM2_MODEL_ID = "facebook/sam2-hiera-tiny"


def _get_sam2():
    """Return the SAM2 ``(processor, model)`` tuple, loading it once on first use.

    Loading goes through ``load_sam2`` on ``settings.device``. Later calls
    return the cached tuple.
    """
    if _SAM2_MODEL_ID in _sam2_cache:
        return _sam2_cache[_SAM2_MODEL_ID]
    processor, model = load_sam2(settings.device, _SAM2_MODEL_ID)
    entry = (processor, model)
    _sam2_cache[_SAM2_MODEL_ID] = entry
    return entry
# ---------------------------------------------------------------------------
# Shared inference helpers
# ---------------------------------------------------------------------------
def _infer_on_device(
    pil_image: Image.Image,
    prompts: list[str],
    threshold: float,
    mode: str,
    device: str,
    dtype,
) -> list[dict]:
    """Detect (and in Segmentation mode, segment) objects on an explicit device.

    OWLv2 always runs. SAM2 runs only when *mode* is ``"Segmentation"`` and
    OWLv2 returned at least one detection. Each model is moved to *device*
    right before use. Under ZeroGPU (this is called from the ``@spaces.GPU``
    wrapper) each model is moved back to the CPU afterwards, even if inference
    raised, to release VRAM back to the ZeroGPU pool. Locally the models stay
    on *device*.
    """

    def _run_with(model, step):
        # Put *model* on the target device, run *step*, and (ZeroGPU only)
        # park the weights back on the CPU however *step* exits.
        model.to(device)
        try:
            return step()
        finally:
            if _ZERO_GPU:
                model.to("cpu")

    processor, owlv2 = _get_owlv2()
    found = _run_with(
        owlv2,
        lambda: _owlv2_infer(
            pil_image, processor, owlv2, prompts, threshold, device, dtype,
        ),
    )
    if mode != "Segmentation" or not found:
        return found

    sam2_processor, sam2_model = _get_sam2()
    return _run_with(
        sam2_model,
        lambda: segment_with_boxes(
            pil_image, found, sam2_processor, sam2_model, device
        ),
    )
if not _ZERO_GPU:
    def _run_detection(
        pil_image: Image.Image,
        prompts: list[str],
        threshold: float,
        mode: str,
    ) -> list[dict]:
        """Local entry-point: infer on ``settings.device`` / ``settings.torch_dtype``."""
        return _infer_on_device(
            pil_image,
            prompts,
            threshold,
            mode,
            device=settings.device,
            dtype=settings.torch_dtype,
        )
else:
    @_spaces.GPU
    def _run_detection(
        pil_image: Image.Image,
        prompts: list[str],
        threshold: float,
        mode: str,
    ) -> list[dict]:
        """ZeroGPU entry-point: a GPU is attached only for the duration of this call.

        Inference runs on ``"cuda"`` in float16 regardless of ``settings``.
        """
        import torch

        return _infer_on_device(
            pil_image,
            prompts,
            threshold,
            mode,
            device="cuda",
            dtype=torch.float16,
        )
def _parse_prompts(text: Optional[str]) -> list[str]:
    """Split user-entered prompt text into a clean, ordered list of prompts.

    Prompts are separated by commas or newlines (the prompt textbox is
    multi-line). Surrounding whitespace is stripped and empty entries are
    dropped. Exact duplicates are removed, keeping the first occurrence's
    position, so each prompt is sent to the detector only once. ``None`` or
    blank input yields ``[]``.
    """
    if not text:
        return []
    parts = (p.strip() for p in text.replace("\n", ",").split(","))
    # dict.fromkeys de-duplicates while preserving first-seen order.
    return list(dict.fromkeys(p for p in parts if p))
# ---------------------------------------------------------------------------
# Object crops
# ---------------------------------------------------------------------------
def _mask_cutout(pil_image: Image.Image, mask) -> Optional[Image.Image]:
    """Return a tight RGB crop around a 2-D *mask*, with non-mask pixels white.

    Returns ``None`` when *mask* is not a 2-D ``np.ndarray`` or selects no
    pixels, so the caller can fall back to a plain box crop. Non-bool masks
    are treated as nonzero = inside. Without the bool conversion, ``~`` on an
    integer mask would be a bitwise NOT rather than a logical inversion.
    """
    if not isinstance(mask, np.ndarray) or mask.ndim != 2:
        return None
    mask = mask.astype(bool, copy=False)
    rows = np.flatnonzero(mask.any(axis=1))
    cols = np.flatnonzero(mask.any(axis=0))
    if rows.size == 0 or cols.size == 0:
        return None
    r_min, r_max = int(rows[0]), int(rows[-1])
    c_min, c_max = int(cols[0]), int(cols[-1])
    mask_tight = mask[r_min:r_max + 1, c_min:c_max + 1]
    region = np.array(
        pil_image.crop((c_min, r_min, c_max + 1, r_max + 1)).convert("RGB")
    )
    region[~mask_tight] = [255, 255, 255]  # white background outside the mask
    return Image.fromarray(region)


def _make_crops(
    pil_image: Image.Image,
    detections: list[dict],
    prompts: list[str],
    mode: str,
) -> list[tuple[Image.Image, str]]:
    """Return one (cropped image, "label score" caption) pair per detection.

    Detection mode crops the bounding box, clamped to the image. Segmentation
    mode crops tightly around the detection's mask with pixels outside the
    mask set to white. It falls back to the box crop when there is no usable
    mask. Every crop gets a 3-px border in the label's colour. Detections
    whose clamped box is empty are skipped.
    """
    crops: list[tuple[Image.Image, str]] = []
    img_w, img_h = pil_image.size
    for det in detections:
        x1, y1, x2, y2 = det["box_xyxy"]
        x1, y1 = max(0, int(x1)), max(0, int(y1))
        x2, y2 = min(img_w, int(x2)), min(img_h, int(y2))
        if x2 <= x1 or y2 <= y1:
            continue

        crop_rgb = None
        if mode == "Segmentation":
            crop_rgb = _mask_cutout(pil_image, det.get("mask"))
        if crop_rgb is None:
            crop_rgb = pil_image.crop((x1, y1, x2, y2)).convert("RGB")

        r, g, b = _colour_for(det["label"], prompts)
        bordered = Image.new("RGB", (crop_rgb.width + 6, crop_rgb.height + 6), (r, g, b))
        bordered.paste(crop_rgb, (3, 3))
        crops.append((bordered, f"{det['label']} {det['score']:.2f}"))
    return crops
# ---------------------------------------------------------------------------
# Tab 1 — Test
# ---------------------------------------------------------------------------
def _detections_table(detections: list[dict]) -> list[list]:
    """Format detections as Dataframe rows: [#, label, score, "[x1, y1, x2, y2]"].

    Rows are numbered from 1. The score has 3 decimals and the box
    coordinates are rounded to whole pixels.
    """
    rows: list[list] = []
    for i, det in enumerate(detections, start=1):
        x1, y1, x2, y2 = det["box_xyxy"]
        rows.append([
            i,
            det["label"],
            f"{det['score']:.3f}",
            f"[{x1:.0f}, {y1:.0f}, {x2:.0f}, {y2:.0f}]",
        ])
    return rows


def run_test(
    image_np: Optional[np.ndarray],
    prompts_text: str,
    threshold: float,
    size_label: str,
    mode: str,
):
    """Gradio handler for the Test tab.

    Returns ``(annotated image array, detections table rows, object crops)``.
    When there is no image, or the prompt text is empty, blank or ``None``,
    the input image is echoed back with empty table and crops. The upload is
    normalised to RGB, matching the Batch tab, before resizing and inference.
    """
    if image_np is None or not (prompts_text or "").strip():
        return image_np, [], []
    prompts = _parse_prompts(prompts_text)
    if not prompts:
        return image_np, [], []

    pil = _resize(Image.fromarray(image_np).convert("RGB"), size_label)
    detections = _run_detection(pil, prompts, threshold, mode)
    crops = _make_crops(pil, detections, prompts, mode)
    annotated = _annotate(pil, detections, prompts, mode)
    return np.array(annotated), _detections_table(detections), crops
# ---------------------------------------------------------------------------
# Tab 2 — Batch
# ---------------------------------------------------------------------------
def _unique_image_name(name: str, used_stems: set[str]) -> str:
    """Return *name*, renamed to ``stem_N.ext`` if its stem is already in *used_stems*.

    Stems (not full names) must be unique because each image's detections are
    written to ``<stem>.json``. ``a.jpg`` and ``a.png`` would otherwise
    overwrite each other's JSON. The chosen stem is added to *used_stems*.
    """
    path = Path(name)
    stem, suffix = path.stem, path.suffix
    candidate, n = stem, 1
    while candidate in used_stems:
        candidate = f"{stem}_{n}"
        n += 1
    used_stems.add(candidate)
    return candidate + suffix


def _write_export_zip(zip_path: Path, coco_path: Path, images_dir: Path) -> None:
    """Write *zip_path* containing ``coco_export.json`` (if present) and ``images/*``."""
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        if coco_path.exists():
            zf.write(coco_path, "coco_export.json")
        for img_file in sorted(images_dir.iterdir()):
            zf.write(img_file, f"images/{img_file.name}")


def run_batch(files, prompts_text: str, threshold: float, size_label: str, mode: str):
    """Gradio handler for the Batch tab.

    Processes every uploaded image. For each one it:
    - resizes the image,
    - runs detection (and segmentation),
    - saves the resized image plus a per-image JSON for ``build_coco``.

    It then writes ``coco_export.json`` and packs everything into a ZIP.
    Images that fail are logged and skipped.

    Returns ``(annotated gallery images, stats text, zip path or None)``.
    The ZIP path is None only when there were no files or no usable prompts.
    """
    if not files or not (prompts_text or "").strip():
        return [], "Upload images and enter prompts to get started.", None
    prompts = _parse_prompts(prompts_text)
    if not prompts:
        return [], "No valid prompts.", None

    # Fresh directory for this run; it replaces the previous run's output.
    run_dir = _TMPDIR / "current_run"
    if run_dir.exists():
        shutil.rmtree(run_dir)
    images_dir = run_dir / "images"
    images_dir.mkdir(parents=True)

    gallery: list[Image.Image] = []
    total_dets = 0
    failed = 0
    # Reserve "coco_export" so an upload with that stem cannot clobber the export.
    used_stems: set[str] = {"coco_export"}
    for f in files:
        try:
            src = Path(f.name if hasattr(f, "name") else str(f))
            # convert() returns a new in-memory image, so the file can be closed here.
            with Image.open(src) as opened:
                pil = _resize(opened.convert("RGB"), size_label)
            w, h = pil.size
            detections = _run_detection(pil, prompts, threshold, mode)

            # Save resized image (included in the ZIP) under a collision-free name.
            img_name = _unique_image_name(src.name, used_stems)
            pil.save(images_dir / img_name)

            # Per-image JSON consumed by build_coco.
            # Drop numpy mask arrays — they are not JSON-serialisable.
            json_dets = [
                {k: v for k, v in det.items() if k != "mask"}
                for det in detections
            ]
            save_json(
                {"image_path": img_name, "image_width": w,
                 "image_height": h, "detections": json_dets},
                run_dir / (Path(img_name).stem + ".json"),
            )
            gallery.append(_annotate(pil, detections, prompts, mode))
            # Counted only once the image has been fully processed.
            total_dets += len(detections)
        except Exception:
            failed += 1
            logger.exception("Failed to process %s", f)

    # Build COCO JSON from the per-image files written above.
    coco = build_coco(run_dir)
    coco_path = run_dir / "coco_export.json"
    if coco:
        save_json(coco, coco_path)

    zip_path = run_dir / "autolabel_export.zip"
    _write_export_zip(zip_path, coco_path, images_dir)

    n_ann = len(coco.get("annotations", [])) if coco else 0
    size_note = f" · resized to {size_label}" if size_label != "As is" else ""
    mode_note = f" · {mode.lower()}"
    fail_note = f" · {failed} failed" if failed else ""
    stats = (
        f"{len(gallery)} image(s) · {total_dets} detection(s) · "
        f"{n_ann} annotations{size_note}{mode_note}{fail_note}"
    )
    return gallery, stats, str(zip_path)
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
# Initial prompt text for both tabs: the first 8 configured prompts.
_DEFAULT_PROMPTS = ", ".join(settings.prompts[:8])

# Markdown shown in the "How it works" accordion. The blank lines are needed.
# Without them, GitHub-flavoured Markdown folds the paragraph after the table
# into the table as an extra row and merges separate paragraphs/list items.
_HOW_IT_WORKS_MD = """\
## How it works

| Mode | Models | Output |
|------|--------|--------|
| **Detection** | OWLv2 | Bounding boxes + class labels |
| **Segmentation** | OWLv2 → SAM2 | Bounding boxes + pixel masks + COCO polygons |

**Detection** uses [OWLv2](https://huggingface.co/google/owlv2-large-patch14-finetuned), an
open-vocabulary detector that converts your text prompts directly into bounding boxes — no
fixed class list required.

**Segmentation** uses the **Grounded SAM2** pattern:

1. **OWLv2** reads your text prompts and produces bounding boxes
2. **SAM2** (`sam2-hiera-tiny`) takes each box as a spatial prompt and refines it into a
   pixel-level mask

SAM2 has no concept of text — it only understands spatial prompts (boxes, points, masks).
OWLv2 acts as the *grounding* step, translating words into coordinates that SAM2 can use.
Segmentation mode runs both models; Detection mode skips SAM2 entirely.
"""
with gr.Blocks(title="autolabel") as demo:
gr.Markdown("# autolabel β€” OWLv2 + SAM2")
with gr.Accordion("ℹ️ How it works", open=False):
gr.Markdown(_HOW_IT_WORKS_MD)
with gr.Tabs():
# ── Tab 1: Test ──────────────────────────────────────────────────
with gr.Tab("πŸ§ͺ Test"):
with gr.Row():
with gr.Column(scale=1):
t1_image = gr.Image(label="Image β€” upload, paste, or pick a sample below",
type="numpy", sources=["upload", "clipboard"])
t1_mode = gr.Radio(
["Detection", "Segmentation"],
label="Mode", value="Detection",
info="Detection: OWLv2 β†’ boxes only. "
"Segmentation: OWLv2 β†’ boxes β†’ SAM2 β†’ pixel masks.",
)
t1_prompts = gr.Textbox(label="Prompts (comma-separated)",
value=_DEFAULT_PROMPTS, lines=2)
t1_threshold = gr.Slider(label="Threshold", minimum=0.01,
maximum=0.9, step=0.01, value=settings.threshold)
t1_size = gr.Dropdown(label="Input size", choices=_SIZE_LABELS,
value="As is")
t1_btn = gr.Button("Detect", variant="primary")
with gr.Column(scale=1):
t1_output = gr.Image(label="Result", type="numpy")
t1_table = gr.Dataframe(
headers=["#", "Label", "Score", "Box (xyxy)"],
row_count=(0, "dynamic"), column_count=(4, "fixed"),
)
t1_crops = gr.Gallery(
label="Object crops",
columns=4, height=220,
object_fit="contain", show_label=True,
)
# Sample images β€” click any thumbnail to load it into the image input
_SAMPLES_DIR = Path(__file__).parent / "samples"
gr.Examples(
label="Sample images (click to load)",
examples=[
[str(_SAMPLES_DIR / "animals.jpg"), "Detection",
"crown, necklace, ball, animal eye", 0.40, "As is"],
[str(_SAMPLES_DIR / "kitchen.jpg"), "Detection",
"apple, banana, orange, broccoli, carrot, bottle, bowl", 0.40, "As is"],
[str(_SAMPLES_DIR / "dog.jpg"), "Detection",
"dog", 0.40, "As is"],
[str(_SAMPLES_DIR / "cat.jpg"), "Detection",
"cat", 0.40, "As is"]
],
inputs=[t1_image, t1_mode, t1_prompts, t1_threshold, t1_size],
examples_per_page=5,
cache_examples=False,
)
t1_btn.click(
run_test,
inputs=[t1_image, t1_prompts, t1_threshold, t1_size, t1_mode],
outputs=[t1_output, t1_table, t1_crops],
)
t1_prompts.submit(
run_test,
inputs=[t1_image, t1_prompts, t1_threshold, t1_size, t1_mode],
outputs=[t1_output, t1_table, t1_crops],
)
# ── Tab 2: Batch ─────────────────────────────────────────────────
with gr.Tab("πŸ“‚ Batch"):
with gr.Row():
with gr.Column(scale=1):
t2_files = gr.File(label="Images", file_count="multiple",
file_types=["image"])
t2_mode = gr.Radio(
["Detection", "Segmentation"],
label="Mode", value="Detection",
info="Detection: OWLv2 β†’ boxes only. "
"Segmentation: OWLv2 β†’ boxes β†’ SAM2 β†’ pixel masks.",
)
t2_prompts = gr.Textbox(label="Prompts (comma-separated)",
value=_DEFAULT_PROMPTS, lines=2)
t2_threshold = gr.Slider(label="Threshold", minimum=0.01,
maximum=0.9, step=0.01, value=settings.threshold)
t2_size = gr.Dropdown(label="Input size", choices=_SIZE_LABELS,
value="640 Γ— 640")
t2_btn = gr.Button("Run", variant="primary")
t2_stats = gr.Textbox(label="Stats", interactive=False)
t2_download = gr.DownloadButton(
label="Download ZIP (images + COCO JSON)",
visible=False, variant="secondary", size="sm",
)
with gr.Column(scale=2):
t2_gallery = gr.Gallery(label="Results", columns=3,
height="auto", object_fit="contain")
def _run_and_reveal(files, prompts_text, threshold, size_label, mode):
gallery, stats, zip_path = run_batch(
files, prompts_text, threshold, size_label, mode
)
return gallery, stats, gr.update(value=zip_path, visible=zip_path is not None)
t2_btn.click(
_run_and_reveal,
inputs=[t2_files, t2_prompts, t2_threshold, t2_size, t2_mode],
outputs=[t2_gallery, t2_stats, t2_download],
)
# Route events through Gradio's queue, holding at most 5 pending requests.
demo.queue(max_size=5)

if __name__ == "__main__":
    # Listen on all interfaces at port 7860 (no public share link) and try to
    # open a local browser tab; the Soft theme is applied at launch.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        inbrowser=True,
        theme=gr.themes.Soft(),
    )