""" app.py — OWLv2 / SAM2 image labeling UI Tab 1 — Test: upload one image, pick Detection or Segmentation mode, tune prompts/threshold/size, see instant annotated results. Tab 2 — Batch: upload multiple images, run in the chosen mode, download a ZIP containing resized images + coco_export.json. All artifacts live in a system temp directory — nothing is written to the project. """ from __future__ import annotations # spaces MUST be imported before torch initialises CUDA (i.e. before any # autolabel import). Do this first, before everything else. try: import spaces as _spaces # type: ignore _ZERO_GPU = True except (ImportError, RuntimeError): _spaces = None _ZERO_GPU = False import os os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1") import logging import shutil import tempfile import zipfile from pathlib import Path from typing import Optional import gradio as gr import numpy as np from dotenv import load_dotenv from PIL import Image, ImageDraw, ImageFont load_dotenv() from autolabel.config import settings from autolabel.detect import infer as _owlv2_infer from autolabel.export import build_coco from autolabel.segment import load_sam2, segment_with_boxes from autolabel.utils import save_json, setup_logging setup_logging(logging.INFO) logger = logging.getLogger(__name__) # Temp directory for this session — cleaned up by the OS on reboot _TMPDIR = Path(tempfile.mkdtemp(prefix="autolabel_")) logger.info("Session temp dir: %s", _TMPDIR) # --------------------------------------------------------------------------- # Image sizing # --------------------------------------------------------------------------- _SIZE_OPTIONS = { "As is": None, "416 × 416": (416, 416), "480 × 480": (480, 480), "512 × 512": (512, 512), "640 × 640": (640, 640), "736 × 736": (736, 736), "896 × 896": (896, 896), "1024 × 1024": (1024, 1024), } _SIZE_LABELS = list(_SIZE_OPTIONS.keys()) def _resize(pil: Image.Image, size_label: str) -> Image.Image: target = _SIZE_OPTIONS[size_label] if target is None: return pil return pil.resize(target, Image.LANCZOS) # --------------------------------------------------------------------------- # Colours & annotation # --------------------------------------------------------------------------- _PALETTE = [ (52, 211, 153), (251, 146, 60), (96, 165, 250), (248, 113, 113), (167, 139, 250),(250, 204, 21), (34, 211, 238), (244, 114, 182), (74, 222, 128), (232, 121, 249), (125, 211, 252), (253, 186, 116), (110, 231, 183),(196, 181, 253), (253, 164, 175), (134, 239, 172), ] def _colour_for(label: str, prompts: list[str]) -> tuple[int, int, int]: try: return _PALETTE[prompts.index(label) % len(_PALETTE)] except ValueError: return _PALETTE[hash(label) % len(_PALETTE)] def _annotate( pil_image: Image.Image, detections: list[dict], prompts: list[str], mode: str = "Detection", ) -> Image.Image: """Draw bounding boxes (+ mask overlays in Segmentation mode) on *pil_image*.""" img = pil_image.copy().convert("RGBA") # --- Segmentation: paint semi-transparent mask overlays first --- if mode == "Segmentation": overlay = Image.new("RGBA", img.size, (0, 0, 0, 0)) for det in detections: mask = det.get("mask") if mask is None or not isinstance(mask, np.ndarray): continue r, g, b = _colour_for(det["label"], prompts) mask_rgba = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8) mask_rgba[mask] = [r, g, b, 100] # semi-transparent fill overlay = Image.alpha_composite(overlay, Image.fromarray(mask_rgba, "RGBA")) img = Image.alpha_composite(img, overlay) # --- Bounding boxes and labels (both modes) --- draw = ImageDraw.Draw(img, "RGBA") try: font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", size=18) except Exception: font = ImageFont.load_default() for det in detections: x1, y1, x2, y2 = det["box_xyxy"] r, g, b = _colour_for(det["label"], prompts) draw.rectangle([x1, y1, x2, y2], outline=(r, g, b), width=3) tag = f"{det['label']} {det['score']:.2f}" bbox = draw.textbbox((x1, y1), tag, font=font) draw.rectangle([bbox[0]-3, bbox[1]-3, bbox[2]+3, bbox[3]+3], fill=(r, g, b, 210)) draw.text((x1, y1), tag, fill=(255, 255, 255), font=font) return img.convert("RGB") # --------------------------------------------------------------------------- # OWLv2 model (cached) # --------------------------------------------------------------------------- _owlv2_cache: dict = {} def _get_owlv2(): if settings.model not in _owlv2_cache: _owlv2_cache.clear() from transformers import Owlv2ForObjectDetection, Owlv2Processor logger.info("Loading OWLv2 %s on %s …", settings.model, settings.device) processor = Owlv2Processor.from_pretrained(settings.model) model = Owlv2ForObjectDetection.from_pretrained( settings.model, torch_dtype=settings.torch_dtype ).to(settings.device) model.eval() _owlv2_cache[settings.model] = (processor, model) logger.info("OWLv2 ready.") return _owlv2_cache[settings.model] # --------------------------------------------------------------------------- # SAM2 model (cached) # --------------------------------------------------------------------------- _sam2_cache: dict = {} _SAM2_MODEL_ID = "facebook/sam2-hiera-tiny" def _get_sam2(): if _SAM2_MODEL_ID not in _sam2_cache: processor, model = load_sam2(settings.device, _SAM2_MODEL_ID) _sam2_cache[_SAM2_MODEL_ID] = (processor, model) return _sam2_cache[_SAM2_MODEL_ID] # --------------------------------------------------------------------------- # Shared inference helpers # --------------------------------------------------------------------------- def _infer_on_device( pil_image: Image.Image, prompts: list[str], threshold: float, mode: str, device: str, dtype, ) -> list[dict]: """Run OWLv2 (+ optional SAM2) with explicit device/dtype. In ZeroGPU mode this is called inside the @spaces.GPU context so CUDA is available; locally it uses whatever settings.device resolved to. """ processor, owlv2 = _get_owlv2() owlv2.to(device) try: detections = _owlv2_infer( pil_image, processor, owlv2, prompts, threshold, device, dtype, ) finally: if _ZERO_GPU: owlv2.to("cpu") # release VRAM back to ZeroGPU pool if mode == "Segmentation" and detections: sam2_processor, sam2_model = _get_sam2() sam2_model.to(device) try: detections = segment_with_boxes( pil_image, detections, sam2_processor, sam2_model, device ) finally: if _ZERO_GPU: sam2_model.to("cpu") return detections if _ZERO_GPU: @_spaces.GPU def _run_detection( pil_image: Image.Image, prompts: list[str], threshold: float, mode: str, ) -> list[dict]: """ZeroGPU entry-point: GPU is allocated for the duration of this call.""" import torch return _infer_on_device( pil_image, prompts, threshold, mode, device="cuda", dtype=torch.float16, ) else: def _run_detection( pil_image: Image.Image, prompts: list[str], threshold: float, mode: str, ) -> list[dict]: return _infer_on_device( pil_image, prompts, threshold, mode, device=settings.device, dtype=settings.torch_dtype, ) def _parse_prompts(text: str) -> list[str]: return [p.strip() for p in text.split(",") if p.strip()] # --------------------------------------------------------------------------- # Object crops # --------------------------------------------------------------------------- def _make_crops( pil_image: Image.Image, detections: list[dict], prompts: list[str], mode: str, ) -> list[tuple[Image.Image, str]]: """Return one (cropped PIL image, caption) pair per detection. Detection mode: plain bounding-box crop with a coloured border. Segmentation mode: tight crop around the mask's nonzero region; pixels outside the mask are set to white for a clean cutout. """ crops: list[tuple[Image.Image, str]] = [] img_w, img_h = pil_image.size for det in detections: x1, y1, x2, y2 = det["box_xyxy"] x1 = max(0, int(x1)); y1 = max(0, int(y1)) x2 = min(img_w, int(x2)); y2 = min(img_h, int(y2)) if x2 <= x1 or y2 <= y1: continue r, g, b = _colour_for(det["label"], prompts) if mode == "Segmentation": mask = det.get("mask") if mask is not None and isinstance(mask, np.ndarray): # Find the tight bounding box of the mask's nonzero region rows = np.any(mask, axis=1) cols = np.any(mask, axis=0) if rows.any() and cols.any(): r_min, r_max = int(np.where(rows)[0][0]), int(np.where(rows)[0][-1]) c_min, c_max = int(np.where(cols)[0][0]), int(np.where(cols)[0][-1]) mask_tight = mask[r_min:r_max + 1, c_min:c_max + 1] region = np.array( pil_image.crop((c_min, r_min, c_max + 1, r_max + 1)).convert("RGB") ) # White background outside the mask region[~mask_tight] = [255, 255, 255] crop_rgb = Image.fromarray(region) else: crop_rgb = pil_image.crop((x1, y1, x2, y2)).convert("RGB") else: crop_rgb = pil_image.crop((x1, y1, x2, y2)).convert("RGB") else: crop_rgb = pil_image.crop((x1, y1, x2, y2)).convert("RGB") # Coloured border bordered = Image.new("RGB", (crop_rgb.width + 6, crop_rgb.height + 6), (r, g, b)) bordered.paste(crop_rgb, (3, 3)) caption = f"{det['label']} {det['score']:.2f}" crops.append((bordered, caption)) return crops # --------------------------------------------------------------------------- # Tab 1 — Test # --------------------------------------------------------------------------- def run_test( image_np: Optional[np.ndarray], prompts_text: str, threshold: float, size_label: str, mode: str, ): if image_np is None or not prompts_text.strip(): return image_np, [], [] prompts = _parse_prompts(prompts_text) if not prompts: return image_np, [], [] pil = _resize(Image.fromarray(image_np), size_label) detections = _run_detection(pil, prompts, threshold, mode) table = [ [i + 1, d["label"], f"{d['score']:.3f}", f"[{d['box_xyxy'][0]:.0f}, {d['box_xyxy'][1]:.0f}, " f"{d['box_xyxy'][2]:.0f}, {d['box_xyxy'][3]:.0f}]"] for i, d in enumerate(detections) ] crops = _make_crops(pil, detections, prompts, mode) return np.array(_annotate(pil, detections, prompts, mode)), table, crops # --------------------------------------------------------------------------- # Tab 2 — Batch # --------------------------------------------------------------------------- def run_batch(files, prompts_text: str, threshold: float, size_label: str, mode: str): if not files or not prompts_text.strip(): return [], "Upload images and enter prompts to get started.", None prompts = _parse_prompts(prompts_text) if not prompts: return [], "No valid prompts.", None # Fresh temp dir for this run run_dir = _TMPDIR / "current_run" if run_dir.exists(): shutil.rmtree(run_dir) images_dir = run_dir / "images" images_dir.mkdir(parents=True) gallery: list[Image.Image] = [] total_dets = 0 for f in files: try: src = Path(f.name if hasattr(f, "name") else str(f)) pil = _resize(Image.open(src).convert("RGB"), size_label) w, h = pil.size detections = _run_detection(pil, prompts, threshold, mode) total_dets += len(detections) # Save resized image (included in the ZIP) img_name = src.name pil.save(images_dir / img_name) # Per-image JSON consumed by build_coco. # Drop numpy mask arrays — they are not JSON-serialisable. json_dets = [ {k: v for k, v in det.items() if k != "mask"} for det in detections ] save_json( {"image_path": img_name, "image_width": w, "image_height": h, "detections": json_dets}, run_dir / (src.stem + ".json"), ) gallery.append(_annotate(pil, detections, prompts, mode)) except Exception: logger.exception("Failed to process %s", f) # Build COCO JSON coco = build_coco(run_dir) coco_path = run_dir / "coco_export.json" if coco: save_json(coco, coco_path) # Package everything into a ZIP zip_path = run_dir / "autolabel_export.zip" with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf: if coco_path.exists(): zf.write(coco_path, "coco_export.json") for img_file in sorted(images_dir.iterdir()): zf.write(img_file, f"images/{img_file.name}") n_ann = len(coco.get("annotations", [])) if coco else 0 size_note = f" · resized to {size_label}" if size_label != "As is" else "" mode_note = f" · {mode.lower()}" stats = ( f"{len(gallery)} image(s) · {total_dets} detection(s) · " f"{n_ann} annotations{size_note}{mode_note}" ) return gallery, stats, str(zip_path) # --------------------------------------------------------------------------- # UI # --------------------------------------------------------------------------- _DEFAULT_PROMPTS = ", ".join(settings.prompts[:8]) _HOW_IT_WORKS_MD = """\ ## How it works | Mode | Models | Output | |------|--------|--------| | **Detection** | OWLv2 | Bounding boxes + class labels | | **Segmentation** | OWLv2 → SAM2 | Bounding boxes + pixel masks + COCO polygons | **Detection** uses [OWLv2](https://huggingface.co/google/owlv2-large-patch14-finetuned), an open-vocabulary detector that converts your text prompts directly into bounding boxes — no fixed class list required. **Segmentation** uses the **Grounded SAM2** pattern: 1. **OWLv2** reads your text prompts and produces bounding boxes 2. **SAM2** (`sam2-hiera-tiny`) takes each box as a spatial prompt and refines it into a pixel-level mask SAM2 has no concept of text — it only understands spatial prompts (boxes, points, masks). OWLv2 acts as the *grounding* step, translating words into coordinates that SAM2 can use. Both models must run in Segmentation mode; Detection mode skips SAM2 entirely. """ with gr.Blocks(title="autolabel") as demo: gr.Markdown("# autolabel — OWLv2 + SAM2") with gr.Accordion("ℹ️ How it works", open=False): gr.Markdown(_HOW_IT_WORKS_MD) with gr.Tabs(): # ── Tab 1: Test ────────────────────────────────────────────────── with gr.Tab("🧪 Test"): with gr.Row(): with gr.Column(scale=1): t1_image = gr.Image(label="Image — upload, paste, or pick a sample below", type="numpy", sources=["upload", "clipboard"]) t1_mode = gr.Radio( ["Detection", "Segmentation"], label="Mode", value="Detection", info="Detection: OWLv2 → boxes only. " "Segmentation: OWLv2 → boxes → SAM2 → pixel masks.", ) t1_prompts = gr.Textbox(label="Prompts (comma-separated)", value=_DEFAULT_PROMPTS, lines=2) t1_threshold = gr.Slider(label="Threshold", minimum=0.01, maximum=0.9, step=0.01, value=settings.threshold) t1_size = gr.Dropdown(label="Input size", choices=_SIZE_LABELS, value="As is") t1_btn = gr.Button("Detect", variant="primary") with gr.Column(scale=1): t1_output = gr.Image(label="Result", type="numpy") t1_table = gr.Dataframe( headers=["#", "Label", "Score", "Box (xyxy)"], row_count=(0, "dynamic"), column_count=(4, "fixed"), ) t1_crops = gr.Gallery( label="Object crops", columns=4, height=220, object_fit="contain", show_label=True, ) # Sample images — click any thumbnail to load it into the image input _SAMPLES_DIR = Path(__file__).parent / "samples" gr.Examples( label="Sample images (click to load)", examples=[ [str(_SAMPLES_DIR / "animals.jpg"), "Detection", "crown, necklace, ball, animal eye", 0.40, "As is"], [str(_SAMPLES_DIR / "kitchen.jpg"), "Detection", "apple, banana, orange, broccoli, carrot, bottle, bowl", 0.40, "As is"], [str(_SAMPLES_DIR / "dog.jpg"), "Detection", "dog", 0.40, "As is"], [str(_SAMPLES_DIR / "cat.jpg"), "Detection", "cat", 0.40, "As is"] ], inputs=[t1_image, t1_mode, t1_prompts, t1_threshold, t1_size], examples_per_page=5, cache_examples=False, ) t1_btn.click( run_test, inputs=[t1_image, t1_prompts, t1_threshold, t1_size, t1_mode], outputs=[t1_output, t1_table, t1_crops], ) t1_prompts.submit( run_test, inputs=[t1_image, t1_prompts, t1_threshold, t1_size, t1_mode], outputs=[t1_output, t1_table, t1_crops], ) # ── Tab 2: Batch ───────────────────────────────────────────────── with gr.Tab("📂 Batch"): with gr.Row(): with gr.Column(scale=1): t2_files = gr.File(label="Images", file_count="multiple", file_types=["image"]) t2_mode = gr.Radio( ["Detection", "Segmentation"], label="Mode", value="Detection", info="Detection: OWLv2 → boxes only. " "Segmentation: OWLv2 → boxes → SAM2 → pixel masks.", ) t2_prompts = gr.Textbox(label="Prompts (comma-separated)", value=_DEFAULT_PROMPTS, lines=2) t2_threshold = gr.Slider(label="Threshold", minimum=0.01, maximum=0.9, step=0.01, value=settings.threshold) t2_size = gr.Dropdown(label="Input size", choices=_SIZE_LABELS, value="640 × 640") t2_btn = gr.Button("Run", variant="primary") t2_stats = gr.Textbox(label="Stats", interactive=False) t2_download = gr.DownloadButton( label="Download ZIP (images + COCO JSON)", visible=False, variant="secondary", size="sm", ) with gr.Column(scale=2): t2_gallery = gr.Gallery(label="Results", columns=3, height="auto", object_fit="contain") def _run_and_reveal(files, prompts_text, threshold, size_label, mode): gallery, stats, zip_path = run_batch( files, prompts_text, threshold, size_label, mode ) return gallery, stats, gr.update(value=zip_path, visible=zip_path is not None) t2_btn.click( _run_and_reveal, inputs=[t2_files, t2_prompts, t2_threshold, t2_size, t2_mode], outputs=[t2_gallery, t2_stats, t2_download], ) demo.queue(max_size=5) if __name__ == "__main__": demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True, theme=gr.themes.Soft())