Spaces: Running on Zero
import functools
import cv2
import gradio as gr
import numpy as np
from PIL import Image
import torch
import spaces

import utils.visualize as vis

# Disable the library's segmentation-visualization hook before the model
# modules are imported; presumably this suppresses plotting/file-writing side
# effects during inference — TODO confirm against utils.visualize.
vis.visualize_segmentation = lambda *args, **kwargs: None  # type: ignore

from models.model_bank_knn import PatchKNNDetector
from backbones import get_backbone
from segmenters import SAM3Segmenter

# Models are constructed on CPU first; infer() moves them to CUDA once a
# ZeroGPU slice has been granted.
DEFAULT_DEVICE = "cpu"
@functools.lru_cache(maxsize=None)
def load_backbone(name: str = "dinov3_small"):
    """Build (and cache) the feature backbone.

    The backbone is created on CPU in eval mode; infer() moves it to the GPU
    after the ZeroGPU slice is granted. lru_cache keeps one instance per
    backbone name so the weights are not re-instantiated on every request
    (functools was imported but previously unused).

    Args:
        name: Backbone identifier understood by get_backbone().

    Returns:
        The backbone module in eval mode.
    """
    return get_backbone(name).to(DEFAULT_DEVICE).eval()
@functools.lru_cache(maxsize=2)
def load_sam3(prompt: str, device: str):
    """
    Cache SAM3 per (prompt, device) to avoid reloading weights.
    Created only inside inference after GPU slice is granted.
    """
    # The docstring promised per-(prompt, device) caching, but none was wired
    # up — lru_cache now provides it. Failed loads raise, so errors are not
    # cached and a retry (e.g. after adding HF_TOKEN) loads fresh.
    if SAM3Segmenter is None:
        raise gr.Error(
            "SAM3 is unavailable (transformers build missing Sam3Processor/Sam3Model). "
            "This Space installs transformers from GitHub; if you still see this, restart the Space "
            "to rebuild with the latest image."
        )
    try:
        return SAM3Segmenter(text_prompt=prompt, device=device)
    except OSError as e:
        # Common for gated models when HF_TOKEN is not set; chain the cause so
        # the original loader traceback is preserved in the logs.
        raise gr.Error(
            "Failed to load SAM3. The model is gated; please add your HF token as a Space secret "
            "named HF_TOKEN (Settings → Secrets) and restart the Space.\n\n"
            f"Loader error: {e}"
        ) from e
def _make_overlay(rgb_image: np.ndarray, anomaly_map: np.ndarray) -> Image.Image:
    """Blend a JET-colormapped anomaly heatmap (40%) over the RGB image (60%)."""
    scores = anomaly_map.astype(np.float32)
    lo, hi = scores.min(), scores.max()
    # Min-max normalize to [0, 1]; the epsilon guards a perfectly flat map.
    normed = (scores - lo) / (hi - lo + 1e-8)
    heatmap_bgr = cv2.applyColorMap((normed * 255).astype(np.uint8), cv2.COLORMAP_JET)
    image_bgr = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
    blended = cv2.addWeighted(image_bgr, 0.6, heatmap_bgr, 0.4, 0)
    return Image.fromarray(cv2.cvtColor(blended, cv2.COLOR_BGR2RGB))
@spaces.GPU
def infer(ref_files, test_file, use_sam3, sam_prompt):
    """Fit a PatchKNN detector on the reference images and score the test image.

    Decorated with @spaces.GPU so ZeroGPU grants a GPU slice for the duration
    of the call — the `spaces` module was imported but previously unused, so
    on ZeroGPU hardware torch.cuda.is_available() stayed False and everything
    silently ran on CPU.

    Args:
        ref_files: Uploaded reference ("good") images (gr.File values).
        test_file: Single uploaded test image (gr.File value).
        use_sam3: Whether to segment the foreground with SAM3 first.
        sam_prompt: Text prompt naming the object, passed to SAM3.

    Returns:
        Tuple of (heatmap-overlay PIL image, status markdown string).

    Raises:
        gr.Error: If inputs are missing or SAM3 fails to load.
    """
    if not ref_files:
        raise gr.Error("Upload at least one reference image.")
    if test_file is None:
        raise gr.Error("Upload a test image.")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # gr.File may hand back tempfile-like objects (with .name) or plain paths.
    ref_paths = [f.name if hasattr(f, "name") else f for f in ref_files]
    test_path = test_file.name if hasattr(test_file, "name") else test_file
    status_lines = []
    segmenter = None
    if use_sam3:
        if device == "cuda":
            status_lines.append("Using SAM3 on GPU (ZeroGPU slice).")
        else:
            status_lines.append(
                "SAM3 requested but GPU unavailable; running on CPU will be slow."
            )
        segmenter = load_sam3(sam_prompt, device)
    backbone = load_backbone().to(device)
    model = PatchKNNDetector(
        backbone=backbone,
        segmenter=segmenter,
        device=device,
        k_neighbors=1,
    )
    model.fit(ref_paths, n_ref=len(ref_paths))
    # predict() also returns a scalar anomaly score; only the map is shown.
    image, amap, score = model.predict(test_path)
    overlay = _make_overlay(image, amap)
    status_text = "\n".join(status_lines) if status_lines else f"Ran on {device}."
    return overlay, status_text
def build_demo():
    """Assemble and return the Gradio Blocks UI (layout and wiring only)."""
    with gr.Blocks(title="Patch KNN Anomaly Detection") as demo:
        gr.Markdown(
            "# Patch KNN Anomaly Detection\n"
            "Upload reference (good) images, one test image, and optionally run SAM3 to segment the specific foreground object."
        )

        with gr.Row():
            refs_box = gr.File(
                label="Reference images (good)",
                file_types=["image"],
                file_count="multiple",
            )
            test_box = gr.File(
                label="Test image", file_types=["image"], file_count="single"
            )

        with gr.Accordion("Foreground segmentation (optional)", open=False):
            sam_toggle = gr.Checkbox(label="Use SAM3", value=False)
            prompt_box = gr.Textbox(
                label="Object text prompt (e.g. 'bottle')",
                value="object",
                visible=False,
            )
            # Show the prompt box only while the SAM3 checkbox is ticked.
            sam_toggle.change(
                lambda checked: gr.update(visible=checked),
                inputs=sam_toggle,
                outputs=prompt_box,
            )

        run_button = gr.Button("Run", variant="primary")
        heatmap_out = gr.Image(label="Heatmap overlay", type="pil")
        status_md = gr.Markdown(label="Status / tips")

        run_button.click(
            infer,
            inputs=[refs_box, test_box, sam_toggle, prompt_box],
            outputs=[heatmap_out, status_md],
        )
    return demo
# Built at import time so `demo` exists as a module attribute — presumably so
# the Space runner / `gradio` CLI can discover it without executing the
# __main__ guard; TODO confirm against the Space's run configuration.
demo = build_demo()

if __name__ == "__main__":
    # Gradio 6 removed concurrency_count kwarg; use default queue
    demo.queue().launch()