# NOTE(review): "Spaces: Sleeping" is the Hugging Face Space status banner
# captured by the page scrape — it is not part of app.py and should be removed.
"""app.py -- OWLv2 / SAM2 image labeling UI.

Tab 1 (Test): upload one image, pick Detection or Segmentation mode,
tune prompts/threshold/size, see instant annotated results.

Tab 2 (Batch): upload multiple images, run in the chosen mode, download a
ZIP containing resized images + coco_export.json.

All artifacts live in a system temp directory -- nothing is written to the
project.
"""
from __future__ import annotations

# spaces MUST be imported before torch initialises CUDA (i.e. before any
# autolabel import).  Do this first, before everything else.
try:
    import spaces as _spaces  # type: ignore
    _ZERO_GPU = True  # running under Hugging Face ZeroGPU
except (ImportError, RuntimeError):
    # Local run: the `spaces` package is absent or refused to initialise.
    _spaces = None
    _ZERO_GPU = False

import os

# Let torch fall back to CPU for ops the Apple MPS backend does not implement.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import logging
import shutil
import tempfile
import zipfile
from pathlib import Path
from typing import Optional

import gradio as gr
import numpy as np
from dotenv import load_dotenv
from PIL import Image, ImageDraw, ImageFont

# .env must be loaded before autolabel.config builds its settings object.
load_dotenv()

from autolabel.config import settings
from autolabel.detect import infer as _owlv2_infer
from autolabel.export import build_coco
from autolabel.segment import load_sam2, segment_with_boxes
from autolabel.utils import save_json, setup_logging

setup_logging(logging.INFO)
logger = logging.getLogger(__name__)

# Temp directory for this session -- cleaned up by the OS on reboot.
_TMPDIR = Path(tempfile.mkdtemp(prefix="autolabel_"))
logger.info("Session temp dir: %s", _TMPDIR)
| # --------------------------------------------------------------------------- | |
| # Image sizing | |
| # --------------------------------------------------------------------------- | |
| _SIZE_OPTIONS = { | |
| "As is": None, | |
| "416 Γ 416": (416, 416), | |
| "480 Γ 480": (480, 480), | |
| "512 Γ 512": (512, 512), | |
| "640 Γ 640": (640, 640), | |
| "736 Γ 736": (736, 736), | |
| "896 Γ 896": (896, 896), | |
| "1024 Γ 1024": (1024, 1024), | |
| } | |
| _SIZE_LABELS = list(_SIZE_OPTIONS.keys()) | |
| def _resize(pil: Image.Image, size_label: str) -> Image.Image: | |
| target = _SIZE_OPTIONS[size_label] | |
| if target is None: | |
| return pil | |
| return pil.resize(target, Image.LANCZOS) | |
| # --------------------------------------------------------------------------- | |
| # Colours & annotation | |
| # --------------------------------------------------------------------------- | |
| _PALETTE = [ | |
| (52, 211, 153), (251, 146, 60), (96, 165, 250), (248, 113, 113), | |
| (167, 139, 250),(250, 204, 21), (34, 211, 238), (244, 114, 182), | |
| (74, 222, 128), (232, 121, 249), (125, 211, 252), (253, 186, 116), | |
| (110, 231, 183),(196, 181, 253), (253, 164, 175), (134, 239, 172), | |
| ] | |
| def _colour_for(label: str, prompts: list[str]) -> tuple[int, int, int]: | |
| try: | |
| return _PALETTE[prompts.index(label) % len(_PALETTE)] | |
| except ValueError: | |
| return _PALETTE[hash(label) % len(_PALETTE)] | |
def _annotate(
    pil_image: Image.Image,
    detections: list[dict],
    prompts: list[str],
    mode: str = "Detection",
) -> Image.Image:
    """Draw bounding boxes (+ mask overlays in Segmentation mode) on *pil_image*.

    Args:
        pil_image: source image; a copy is annotated, the original is untouched.
        detections: dicts with "label", "score", "box_xyxy" and, in
            Segmentation mode, optionally a boolean numpy "mask" array.
        prompts: prompt list used to pick a consistent colour per label.
        mode: "Detection" (boxes only) or "Segmentation" (masks + boxes).

    Returns:
        A new RGB image with the annotations painted on.
    """
    img = pil_image.copy().convert("RGBA")
    # --- Segmentation: paint semi-transparent mask overlays first, so the
    # boxes and text drawn afterwards stay on top ---
    if mode == "Segmentation":
        overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
        for det in detections:
            mask = det.get("mask")
            # Skip detections whose mask is missing or not a numpy array.
            if mask is None or not isinstance(mask, np.ndarray):
                continue
            r, g, b = _colour_for(det["label"], prompts)
            mask_rgba = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
            mask_rgba[mask] = [r, g, b, 100]  # semi-transparent fill
            overlay = Image.alpha_composite(overlay, Image.fromarray(mask_rgba, "RGBA"))
        img = Image.alpha_composite(img, overlay)
    # --- Bounding boxes and labels (both modes) ---
    draw = ImageDraw.Draw(img, "RGBA")
    try:
        # macOS system font; falls back to PIL's built-in bitmap font elsewhere.
        font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", size=18)
    except Exception:
        font = ImageFont.load_default()
    for det in detections:
        x1, y1, x2, y2 = det["box_xyxy"]
        r, g, b = _colour_for(det["label"], prompts)
        draw.rectangle([x1, y1, x2, y2], outline=(r, g, b), width=3)
        tag = f"{det['label']} {det['score']:.2f}"
        # Near-opaque filled background behind the tag text keeps it legible.
        bbox = draw.textbbox((x1, y1), tag, font=font)
        draw.rectangle([bbox[0]-3, bbox[1]-3, bbox[2]+3, bbox[3]+3], fill=(r, g, b, 210))
        draw.text((x1, y1), tag, fill=(255, 255, 255), font=font)
    return img.convert("RGB")
| # --------------------------------------------------------------------------- | |
| # OWLv2 model (cached) | |
| # --------------------------------------------------------------------------- | |
# Single-entry cache: maps the configured checkpoint name to (processor, model).
_owlv2_cache: dict = {}


def _get_owlv2():
    """Return the cached (processor, model) pair for settings.model, loading on miss."""
    key = settings.model
    if key not in _owlv2_cache:
        # Only one checkpoint stays resident -- a config change evicts the old one.
        _owlv2_cache.clear()
        from transformers import Owlv2ForObjectDetection, Owlv2Processor

        logger.info("Loading OWLv2 %s on %s β¦", key, settings.device)
        proc = Owlv2Processor.from_pretrained(key)
        net = Owlv2ForObjectDetection.from_pretrained(
            key, torch_dtype=settings.torch_dtype
        ).to(settings.device)
        net.eval()
        _owlv2_cache[key] = (proc, net)
        logger.info("OWLv2 ready.")
    return _owlv2_cache[key]
| # --------------------------------------------------------------------------- | |
| # SAM2 model (cached) | |
| # --------------------------------------------------------------------------- | |
# Cache keyed by checkpoint id (only one SAM2 variant is ever used here).
_sam2_cache: dict = {}
_SAM2_MODEL_ID = "facebook/sam2-hiera-tiny"


def _get_sam2():
    """Return the cached (processor, model) pair, loading SAM2 on first use."""
    try:
        return _sam2_cache[_SAM2_MODEL_ID]
    except KeyError:
        pair = load_sam2(settings.device, _SAM2_MODEL_ID)
        _sam2_cache[_SAM2_MODEL_ID] = pair
        return pair
| # --------------------------------------------------------------------------- | |
| # Shared inference helpers | |
| # --------------------------------------------------------------------------- | |
def _infer_on_device(
    pil_image: Image.Image,
    prompts: list[str],
    threshold: float,
    mode: str,
    device: str,
    dtype,
) -> list[dict]:
    """Run OWLv2 (+ optional SAM2) with explicit device/dtype.

    In ZeroGPU mode this is called inside the @spaces.GPU context so CUDA is
    available; locally it uses whatever settings.device resolved to.

    Returns the detection dicts from OWLv2; when *mode* is "Segmentation" and
    at least one box was found, the dicts are augmented by SAM2 via
    segment_with_boxes.
    """
    processor, owlv2 = _get_owlv2()
    owlv2.to(device)
    try:
        detections = _owlv2_infer(
            pil_image, processor, owlv2, prompts, threshold, device, dtype,
        )
    finally:
        if _ZERO_GPU:
            owlv2.to("cpu")  # release VRAM back to ZeroGPU pool
    # SAM2 only runs when there is at least one box to refine.
    if mode == "Segmentation" and detections:
        sam2_processor, sam2_model = _get_sam2()
        sam2_model.to(device)
        try:
            detections = segment_with_boxes(
                pil_image, detections, sam2_processor, sam2_model, device
            )
        finally:
            if _ZERO_GPU:
                sam2_model.to("cpu")  # same VRAM-release dance as OWLv2 above
    return detections
if _ZERO_GPU:

    # BUGFIX: the ZeroGPU branch previously defined this function without the
    # @spaces.GPU decorator.  The decorator is what actually requests a GPU
    # from the ZeroGPU pool for the duration of the call; without it the
    # process has no CUDA device and device="cuda" below fails.
    @_spaces.GPU
    def _run_detection(
        pil_image: Image.Image,
        prompts: list[str],
        threshold: float,
        mode: str,
    ) -> list[dict]:
        """ZeroGPU entry-point: GPU is allocated for the duration of this call."""
        import torch  # deferred so torch does not touch CUDA at module import

        return _infer_on_device(
            pil_image, prompts, threshold, mode,
            device="cuda", dtype=torch.float16,
        )

else:

    def _run_detection(
        pil_image: Image.Image,
        prompts: list[str],
        threshold: float,
        mode: str,
    ) -> list[dict]:
        """Local entry-point: run on whatever device/dtype settings resolved to."""
        return _infer_on_device(
            pil_image, prompts, threshold, mode,
            device=settings.device, dtype=settings.torch_dtype,
        )
| def _parse_prompts(text: str) -> list[str]: | |
| return [p.strip() for p in text.split(",") if p.strip()] | |
| # --------------------------------------------------------------------------- | |
| # Object crops | |
| # --------------------------------------------------------------------------- | |
def _make_crops(
    pil_image: Image.Image,
    detections: list[dict],
    prompts: list[str],
    mode: str,
) -> list[tuple[Image.Image, str]]:
    """Return one (cropped PIL image, caption) pair per detection.

    Detection mode: plain bounding-box crop with a coloured border.
    Segmentation mode: tight crop around the mask's nonzero region; pixels
    outside the mask are set to white for a clean cutout.  Detections whose
    clamped box is empty are skipped.
    """
    crops: list[tuple[Image.Image, str]] = []
    img_w, img_h = pil_image.size
    for det in detections:
        x1, y1, x2, y2 = det["box_xyxy"]
        # Clamp the box to the image bounds; skip degenerate boxes.
        x1 = max(0, int(x1)); y1 = max(0, int(y1))
        x2 = min(img_w, int(x2)); y2 = min(img_h, int(y2))
        if x2 <= x1 or y2 <= y1:
            continue
        r, g, b = _colour_for(det["label"], prompts)
        crop_rgb = None
        # isinstance(None, np.ndarray) is False, so a missing mask falls through.
        mask = det.get("mask") if mode == "Segmentation" else None
        if isinstance(mask, np.ndarray):
            # Tight bounding box of the mask's nonzero region.  np.flatnonzero
            # replaces the previous duplicated np.where scans per axis.
            rows = np.flatnonzero(mask.any(axis=1))
            cols = np.flatnonzero(mask.any(axis=0))
            if rows.size and cols.size:
                r_min, r_max = int(rows[0]), int(rows[-1])
                c_min, c_max = int(cols[0]), int(cols[-1])
                mask_tight = mask[r_min:r_max + 1, c_min:c_max + 1]
                region = np.array(
                    pil_image.crop((c_min, r_min, c_max + 1, r_max + 1)).convert("RGB")
                )
                # White background outside the mask
                region[~mask_tight] = [255, 255, 255]
                crop_rgb = Image.fromarray(region)
        if crop_rgb is None:
            # Single fallback for Detection mode, missing masks, and all-empty
            # masks (this plain crop was previously duplicated in 3 branches).
            crop_rgb = pil_image.crop((x1, y1, x2, y2)).convert("RGB")
        # Coloured border matching the detection's palette colour.
        bordered = Image.new("RGB", (crop_rgb.width + 6, crop_rgb.height + 6), (r, g, b))
        bordered.paste(crop_rgb, (3, 3))
        caption = f"{det['label']} {det['score']:.2f}"
        crops.append((bordered, caption))
    return crops
| # --------------------------------------------------------------------------- | |
| # Tab 1 β Test | |
| # --------------------------------------------------------------------------- | |
def run_test(
    image_np: Optional[np.ndarray],
    prompts_text: str,
    threshold: float,
    size_label: str,
    mode: str,
):
    """Handle one request from the Test tab.

    Returns (annotated image array, detection table rows, crop gallery).
    When there is no image or no usable prompt, the input image is echoed
    back with empty table and gallery.
    """
    if image_np is None or not prompts_text.strip():
        return image_np, [], []
    prompts = _parse_prompts(prompts_text)
    if not prompts:
        return image_np, [], []
    pil = _resize(Image.fromarray(image_np), size_label)
    detections = _run_detection(pil, prompts, threshold, mode)
    # One table row per detection: index, label, score, rounded box coords.
    rows = []
    for idx, det in enumerate(detections, start=1):
        bx = det["box_xyxy"]
        rows.append([
            idx,
            det["label"],
            f"{det['score']:.3f}",
            f"[{bx[0]:.0f}, {bx[1]:.0f}, {bx[2]:.0f}, {bx[3]:.0f}]",
        ])
    annotated = np.array(_annotate(pil, detections, prompts, mode))
    return annotated, rows, _make_crops(pil, detections, prompts, mode)
| # --------------------------------------------------------------------------- | |
| # Tab 2 β Batch | |
| # --------------------------------------------------------------------------- | |
def run_batch(files, prompts_text: str, threshold: float, size_label: str, mode: str):
    """Label every uploaded image and package results into a downloadable ZIP.

    Returns:
        (gallery, stats, zip_path): annotated PIL images, a one-line summary
        string, and the ZIP path (None when inputs were missing/invalid).
    """
    if not files or not prompts_text.strip():
        return [], "Upload images and enter prompts to get started.", None
    prompts = _parse_prompts(prompts_text)
    if not prompts:
        return [], "No valid prompts.", None
    # Fresh temp dir for this run -- wipes the previous run's artifacts.
    run_dir = _TMPDIR / "current_run"
    if run_dir.exists():
        shutil.rmtree(run_dir)
    images_dir = run_dir / "images"
    images_dir.mkdir(parents=True)  # parents=True also creates run_dir
    gallery: list[Image.Image] = []
    total_dets = 0
    for f in files:
        try:
            # gradio may hand us file-like objects with .name or plain paths.
            src = Path(f.name if hasattr(f, "name") else str(f))
            pil = _resize(Image.open(src).convert("RGB"), size_label)
            w, h = pil.size
            detections = _run_detection(pil, prompts, threshold, mode)
            total_dets += len(detections)
            # Save resized image (included in the ZIP)
            img_name = src.name
            pil.save(images_dir / img_name)
            # Per-image JSON consumed by build_coco.
            # Drop numpy mask arrays -- they are not JSON-serialisable.
            json_dets = [
                {k: v for k, v in det.items() if k != "mask"}
                for det in detections
            ]
            save_json(
                {"image_path": img_name, "image_width": w,
                 "image_height": h, "detections": json_dets},
                run_dir / (src.stem + ".json"),
            )
            gallery.append(_annotate(pil, detections, prompts, mode))
        except Exception:
            # Best-effort batch: one bad image must not sink the whole run.
            logger.exception("Failed to process %s", f)
    # Build COCO JSON from the per-image JSON files written above.
    coco = build_coco(run_dir)
    coco_path = run_dir / "coco_export.json"
    if coco:
        save_json(coco, coco_path)
    # Package everything into a ZIP (COCO file first, then resized images).
    zip_path = run_dir / "autolabel_export.zip"
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        if coco_path.exists():
            zf.write(coco_path, "coco_export.json")
        for img_file in sorted(images_dir.iterdir()):
            zf.write(img_file, f"images/{img_file.name}")
    n_ann = len(coco.get("annotations", [])) if coco else 0
    size_note = f" Β· resized to {size_label}" if size_label != "As is" else ""
    mode_note = f" Β· {mode.lower()}"
    stats = (
        f"{len(gallery)} image(s) Β· {total_dets} detection(s) Β· "
        f"{n_ann} annotations{size_note}{mode_note}"
    )
    return gallery, stats, str(zip_path)
| # --------------------------------------------------------------------------- | |
| # UI | |
| # --------------------------------------------------------------------------- | |
# The first eight configured prompts pre-fill both tabs' prompt boxes.
_DEFAULT_PROMPTS = ", ".join(settings.prompts[:8])

# Markdown body of the collapsible "How it works" accordion.
_HOW_IT_WORKS_MD = """\
## How it works
| Mode | Models | Output |
|------|--------|--------|
| **Detection** | OWLv2 | Bounding boxes + class labels |
| **Segmentation** | OWLv2 β SAM2 | Bounding boxes + pixel masks + COCO polygons |
**Detection** uses [OWLv2](https://huggingface.co/google/owlv2-large-patch14-finetuned), an
open-vocabulary detector that converts your text prompts directly into bounding boxes β no
fixed class list required.
**Segmentation** uses the **Grounded SAM2** pattern:
1. **OWLv2** reads your text prompts and produces bounding boxes
2. **SAM2** (`sam2-hiera-tiny`) takes each box as a spatial prompt and refines it into a
pixel-level mask
SAM2 has no concept of text β it only understands spatial prompts (boxes, points, masks).
OWLv2 acts as the *grounding* step, translating words into coordinates that SAM2 can use.
Both models must run in Segmentation mode; Detection mode skips SAM2 entirely.
"""
# BUGFIX: `theme` is a gr.Blocks constructor argument; it was previously passed
# to demo.launch(), which does not accept a `theme` keyword and raises
# TypeError.  Moved here so the Soft theme actually applies.
with gr.Blocks(title="autolabel", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# autolabel β OWLv2 + SAM2")
    with gr.Accordion("βΉοΈ How it works", open=False):
        gr.Markdown(_HOW_IT_WORKS_MD)

    with gr.Tabs():
        # -- Tab 1: Test ---------------------------------------------------
        with gr.Tab("π§ͺ Test"):
            with gr.Row():
                with gr.Column(scale=1):
                    t1_image = gr.Image(label="Image β upload, paste, or pick a sample below",
                                        type="numpy", sources=["upload", "clipboard"])
                    t1_mode = gr.Radio(
                        ["Detection", "Segmentation"],
                        label="Mode", value="Detection",
                        info="Detection: OWLv2 β boxes only. "
                             "Segmentation: OWLv2 β boxes β SAM2 β pixel masks.",
                    )
                    t1_prompts = gr.Textbox(label="Prompts (comma-separated)",
                                            value=_DEFAULT_PROMPTS, lines=2)
                    t1_threshold = gr.Slider(label="Threshold", minimum=0.01,
                                             maximum=0.9, step=0.01, value=settings.threshold)
                    t1_size = gr.Dropdown(label="Input size", choices=_SIZE_LABELS,
                                          value="As is")
                    t1_btn = gr.Button("Detect", variant="primary")
                with gr.Column(scale=1):
                    t1_output = gr.Image(label="Result", type="numpy")
                    t1_table = gr.Dataframe(
                        headers=["#", "Label", "Score", "Box (xyxy)"],
                        row_count=(0, "dynamic"), column_count=(4, "fixed"),
                    )
                    t1_crops = gr.Gallery(
                        label="Object crops",
                        columns=4, height=220,
                        object_fit="contain", show_label=True,
                    )
            # Sample images: clicking a thumbnail loads it into the inputs.
            _SAMPLES_DIR = Path(__file__).parent / "samples"
            gr.Examples(
                label="Sample images (click to load)",
                examples=[
                    [str(_SAMPLES_DIR / "animals.jpg"), "Detection",
                     "crown, necklace, ball, animal eye", 0.40, "As is"],
                    [str(_SAMPLES_DIR / "kitchen.jpg"), "Detection",
                     "apple, banana, orange, broccoli, carrot, bottle, bowl", 0.40, "As is"],
                    [str(_SAMPLES_DIR / "dog.jpg"), "Detection",
                     "dog", 0.40, "As is"],
                    [str(_SAMPLES_DIR / "cat.jpg"), "Detection",
                     "cat", 0.40, "As is"],
                ],
                inputs=[t1_image, t1_mode, t1_prompts, t1_threshold, t1_size],
                examples_per_page=5,
                cache_examples=False,
            )
            # Run on button click or on Enter in the prompt box -- same handler.
            t1_btn.click(
                run_test,
                inputs=[t1_image, t1_prompts, t1_threshold, t1_size, t1_mode],
                outputs=[t1_output, t1_table, t1_crops],
            )
            t1_prompts.submit(
                run_test,
                inputs=[t1_image, t1_prompts, t1_threshold, t1_size, t1_mode],
                outputs=[t1_output, t1_table, t1_crops],
            )

        # -- Tab 2: Batch --------------------------------------------------
        with gr.Tab("π Batch"):
            with gr.Row():
                with gr.Column(scale=1):
                    t2_files = gr.File(label="Images", file_count="multiple",
                                       file_types=["image"])
                    t2_mode = gr.Radio(
                        ["Detection", "Segmentation"],
                        label="Mode", value="Detection",
                        info="Detection: OWLv2 β boxes only. "
                             "Segmentation: OWLv2 β boxes β SAM2 β pixel masks.",
                    )
                    t2_prompts = gr.Textbox(label="Prompts (comma-separated)",
                                            value=_DEFAULT_PROMPTS, lines=2)
                    t2_threshold = gr.Slider(label="Threshold", minimum=0.01,
                                             maximum=0.9, step=0.01, value=settings.threshold)
                    t2_size = gr.Dropdown(label="Input size", choices=_SIZE_LABELS,
                                          value="640 Γ 640")
                    t2_btn = gr.Button("Run", variant="primary")
                    t2_stats = gr.Textbox(label="Stats", interactive=False)
                    t2_download = gr.DownloadButton(
                        label="Download ZIP (images + COCO JSON)",
                        visible=False, variant="secondary", size="sm",
                    )
                with gr.Column(scale=2):
                    t2_gallery = gr.Gallery(label="Results", columns=3,
                                            height="auto", object_fit="contain")

            def _run_and_reveal(files, prompts_text, threshold, size_label, mode):
                """Run the batch, then reveal the download button once a ZIP exists."""
                gallery, stats, zip_path = run_batch(
                    files, prompts_text, threshold, size_label, mode
                )
                return gallery, stats, gr.update(value=zip_path, visible=zip_path is not None)

            t2_btn.click(
                _run_and_reveal,
                inputs=[t2_files, t2_prompts, t2_threshold, t2_size, t2_mode],
                outputs=[t2_gallery, t2_stats, t2_download],
            )

# Bound the request queue so a burst of jobs cannot pile up on the GPU.
demo.queue(max_size=5)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860,
                share=False, inbrowser=True)