"""Gradio Space: patch-KNN anomaly detection with optional SAM3 foreground segmentation.

Users upload one or more "good" reference images plus a test image; a
PatchKNNDetector scores the test image against reference patch features and
the app returns a heatmap overlay. SAM3 can optionally segment the foreground
object (text-prompted) before scoring.
"""

import functools

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image

import spaces
import utils.visualize as vis
from backbones import get_backbone
from models.model_bank_knn import PatchKNNDetector
from segmenters import SAM3Segmenter

# Disable the library's debug visualization hook — this Space only needs the
# final heatmap overlay, not intermediate segmentation plots.
vis.visualize_segmentation = lambda *args, **kwargs: None  # type: ignore

DEFAULT_DEVICE = "cpu"


@functools.lru_cache(maxsize=1)
def load_backbone(name: str = "dinov3_small"):
    """Load and cache the feature backbone in eval mode.

    Kept on CPU at load time; `infer` moves it to GPU once a ZeroGPU slice
    has actually been granted.
    """
    return get_backbone(name).to(DEFAULT_DEVICE).eval()


@functools.lru_cache(maxsize=4)
def load_sam3(prompt: str, device: str):
    """
    Cache SAM3 per (prompt, device) to avoid reloading weights.
    Created only inside inference after GPU slice is granted.

    Raises:
        gr.Error: if the SAM3 segmenter class is unavailable in this build,
            or if loading the (gated) model weights fails.
    """
    if SAM3Segmenter is None:
        raise gr.Error(
            "SAM3 is unavailable (transformers build missing Sam3Processor/Sam3Model). "
            "This Space installs transformers from GitHub; if you still see this, restart the Space "
            "to rebuild with the latest image."
        )
    try:
        return SAM3Segmenter(text_prompt=prompt, device=device)
    except OSError as e:
        # Common for gated models when HF_TOKEN is not set.
        # Chain the original error (`from e`) so the full traceback survives in logs.
        raise gr.Error(
            "Failed to load SAM3. The model is gated; please add your HF token as a Space secret "
            "named HF_TOKEN (Settings → Secrets) and restart the Space.\n\n"
            f"Loader error: {e}"
        ) from e


def _make_overlay(rgb_image: np.ndarray, anomaly_map: np.ndarray) -> Image.Image:
    """Blend a JET-colormapped anomaly heatmap onto the RGB image (60/40 mix).

    The anomaly map is min-max normalised to [0, 1]; the epsilon in the
    denominator guards against a constant (zero-range) map.
    """
    amap = anomaly_map.astype(np.float32)
    lo, hi = amap.min(), amap.max()  # hoisted: computed once, not per sub-expression
    amap = (amap - lo) / (hi - lo + 1e-8)
    heat = cv2.applyColorMap((amap * 255).astype(np.uint8), cv2.COLORMAP_JET)
    base_bgr = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
    overlay_bgr = cv2.addWeighted(base_bgr, 0.6, heat, 0.4, 0)
    return Image.fromarray(cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB))


@spaces.GPU
def infer(ref_files, test_file, use_sam3, sam_prompt):
    """Fit a patch-KNN detector on the reference images and score the test image.

    Args:
        ref_files: uploaded reference ("good") images (gradio File objects or paths).
        test_file: uploaded test image (gradio File object or path).
        use_sam3: whether to run SAM3 foreground segmentation first.
        sam_prompt: text prompt naming the foreground object for SAM3.

    Returns:
        (overlay, status_text): the PIL heatmap overlay and a status string.

    Raises:
        gr.Error: if references or the test image are missing.
    """
    if not ref_files:
        raise gr.Error("Upload at least one reference image.")
    if test_file is None:
        raise gr.Error("Upload a test image.")

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Gradio may hand us File objects (with .name) or plain path strings.
    ref_paths = [f.name if hasattr(f, "name") else f for f in ref_files]
    test_path = test_file.name if hasattr(test_file, "name") else test_file

    status_lines = []
    segmenter = None
    if use_sam3:
        if device == "cuda":
            status_lines.append("Using SAM3 on GPU (ZeroGPU slice).")
        else:
            status_lines.append(
                "SAM3 requested but GPU unavailable; running on CPU will be slow."
            )
        segmenter = load_sam3(sam_prompt, device)

    # NOTE: .to(device) mutates the cached module in place; this is fine here
    # because the cache holds a single backbone that is re-moved on every call.
    backbone = load_backbone().to(device)
    model = PatchKNNDetector(
        backbone=backbone,
        segmenter=segmenter,
        device=device,
        k_neighbors=1,
    )
    model.fit(ref_paths, n_ref=len(ref_paths))
    image, amap, score = model.predict(test_path)
    overlay = _make_overlay(image, amap)

    status_text = "\n".join(status_lines) if status_lines else f"Ran on {device}."
    return overlay, status_text


def build_demo():
    """Assemble and return the Gradio Blocks UI for the detector."""
    with gr.Blocks(title="Patch KNN Anomaly Detection") as demo:
        gr.Markdown(
            "# Patch KNN Anomaly Detection\n"
            "Upload reference (good) images, one test image, and optionally run SAM3 to segment the specific foreground object."
        )
        with gr.Row():
            ref_in = gr.File(
                label="Reference images (good)", file_types=["image"], file_count="multiple"
            )
            test_in = gr.File(label="Test image", file_types=["image"], file_count="single")
        with gr.Accordion("Foreground segmentation (optional)", open=False):
            use_sam = gr.Checkbox(label="Use SAM3", value=False)
            sam_prompt = gr.Textbox(
                label="Object text prompt (e.g. 'bottle')", value="object", visible=False
            )
            # Show the prompt box only when SAM3 is enabled.
            use_sam.change(lambda s: gr.update(visible=s), inputs=use_sam, outputs=sam_prompt)
        run_btn = gr.Button("Run", variant="primary")
        overlay_out = gr.Image(label="Heatmap overlay", type="pil")
        status_out = gr.Markdown(label="Status / tips")
        run_btn.click(
            infer,
            inputs=[ref_in, test_in, use_sam, sam_prompt],
            outputs=[overlay_out, status_out],
        )
    return demo


demo = build_demo()

if __name__ == "__main__":
    # Gradio 6 removed concurrency_count kwarg; use default queue
    demo.queue().launch()