File size: 4,589 Bytes
2a98dc7
 
ce620b0
2a98dc7
 
 
 
ce620b0
2a98dc7
 
ce620b0
2a98dc7
 
 
 
ca079c9
2a98dc7
 
 
 
 
ca079c9
2a98dc7
 
 
 
 
 
 
 
 
24113f1
58606bd
 
 
 
 
281a357
 
 
 
 
 
 
 
 
2a98dc7
 
 
 
 
 
 
 
 
 
 
ca079c9
2a98dc7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
722c3cc
2a98dc7
 
 
 
 
 
ca079c9
2a98dc7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
722c3cc
2a98dc7
 
 
 
 
 
 
 
639defb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import functools
import cv2
import gradio as gr
import numpy as np
from PIL import Image
import torch
import spaces

import utils.visualize as vis
# Globally disable segmentation visualization: replace the plotting hook with a
# no-op so any model code that calls it produces no figures/side effects in the
# Space. Must happen before the model modules below import/use it.
vis.visualize_segmentation = lambda *args, **kwargs: None  # type: ignore

from models.model_bank_knn import PatchKNNDetector
from backbones import get_backbone
from segmenters import SAM3Segmenter


# Device models are created on at load time; `infer` moves them to CUDA
# only after a ZeroGPU slice is granted.
DEFAULT_DEVICE = "cpu"


@functools.lru_cache(maxsize=1)
def load_backbone(name: str = "dinov3_small"):
    """Build the feature backbone once and cache it (CPU, eval mode).

    Cached via lru_cache so repeated requests reuse the same weights;
    callers move the returned module to the GPU when one is available.
    """
    backbone = get_backbone(name)
    backbone = backbone.to(DEFAULT_DEVICE)
    return backbone.eval()


@functools.lru_cache(maxsize=4)
def load_sam3(prompt: str, device: str):
    """
    Cache SAM3 per (prompt, device) to avoid reloading weights.
    Created only inside inference after GPU slice is granted.

    Raises:
        gr.Error: if this transformers build lacks the SAM3 classes, or the
            (gated) weights cannot be downloaded (e.g. missing HF_TOKEN).
    """
    if SAM3Segmenter is None:
        raise gr.Error(
            "SAM3 is unavailable (transformers build missing Sam3Processor/Sam3Model). "
            "This Space installs transformers from GitHub; if you still see this, restart the Space "
            "to rebuild with the latest image."
        )
    try:
        return SAM3Segmenter(text_prompt=prompt, device=device)
    except OSError as e:
        # Common for gated models when HF_TOKEN is not set. Chain the original
        # loader error (`from e`) so the root cause stays in the traceback.
        raise gr.Error(
            "Failed to load SAM3. The model is gated; please add your HF token as a Space secret "
            "named HF_TOKEN (Settings → Secrets) and restart the Space.\n\n"
            f"Loader error: {e}"
        ) from e


def _make_overlay(rgb_image: np.ndarray, anomaly_map: np.ndarray) -> Image.Image:
    """Render `anomaly_map` as a JET heatmap alpha-blended over `rgb_image`.

    The map is min-max normalized (the epsilon guards a constant map),
    colorized with OpenCV's JET palette, and blended 60/40 with the image.
    """
    scores = anomaly_map.astype(np.float32)
    lo, hi = scores.min(), scores.max()
    norm = (scores - lo) / (hi - lo + 1e-8)
    heatmap = cv2.applyColorMap((norm * 255).astype(np.uint8), cv2.COLORMAP_JET)
    base = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
    blended = cv2.addWeighted(base, 0.6, heatmap, 0.4, 0)
    return Image.fromarray(cv2.cvtColor(blended, cv2.COLOR_BGR2RGB))


@spaces.GPU
def infer(ref_files, test_file, use_sam3, sam_prompt):
    """Run patch-KNN anomaly detection on one test image.

    Args:
        ref_files: uploaded reference ("good") images — gradio file wrappers
            (with a ``.name`` path) or plain path strings.
        test_file: uploaded test image, same forms as ``ref_files``.
        use_sam3: whether to segment the foreground with SAM3 first.
        sam_prompt: text prompt naming the target object for SAM3.

    Returns:
        Tuple of (heatmap-overlay PIL image, status markdown string).

    Raises:
        gr.Error: if inputs are missing or SAM3 cannot be loaded.
    """
    if not ref_files:
        raise gr.Error("Upload at least one reference image.")
    if test_file is None:
        raise gr.Error("Upload a test image.")

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Gradio may hand back tempfile wrappers (with .name) or plain paths.
    ref_paths = [f.name if hasattr(f, "name") else f for f in ref_files]
    test_path = test_file.name if hasattr(test_file, "name") else test_file

    status_lines = []

    segmenter = None
    if use_sam3:
        if device == "cuda":
            status_lines.append("Using SAM3 on GPU (ZeroGPU slice).")
        else:
            status_lines.append(
                "SAM3 requested but GPU unavailable; running on CPU will be slow."
            )
        segmenter = load_sam3(sam_prompt, device)

    backbone = load_backbone().to(device)
    try:
        model = PatchKNNDetector(
            backbone=backbone,
            segmenter=segmenter,
            device=device,
            k_neighbors=1,
        )

        model.fit(ref_paths, n_ref=len(ref_paths))
        image, amap, score = model.predict(test_path)
    finally:
        # nn.Module.to() moves the module in place, so the line above also
        # moved the lru_cached backbone singleton. Park it back on CPU so the
        # cached instance is never left on CUDA after the ZeroGPU slice is
        # released (subsequent calls would otherwise touch stale CUDA state).
        if device != DEFAULT_DEVICE:
            backbone.to(DEFAULT_DEVICE)

    overlay = _make_overlay(image, amap)

    status_text = "\n".join(status_lines) if status_lines else f"Ran on {device}."
    return overlay, status_text


def build_demo():
    """Assemble the Gradio Blocks UI and wire its Run button to `infer`."""
    with gr.Blocks(title="Patch KNN Anomaly Detection") as demo:
        gr.Markdown(
            "# Patch KNN Anomaly Detection\n"
            "Upload reference (good) images, one test image, and optionally run SAM3 to segment the specific foreground object."
        )

        with gr.Row():
            reference_files = gr.File(
                label="Reference images (good)", file_types=["image"], file_count="multiple"
            )
            test_upload = gr.File(label="Test image", file_types=["image"], file_count="single")

        with gr.Accordion("Foreground segmentation (optional)", open=False):
            sam_enabled = gr.Checkbox(label="Use SAM3", value=False)
            prompt_box = gr.Textbox(
                label="Object text prompt (e.g. 'bottle')", value="object", visible=False
            )
        # Show the prompt box only while the SAM3 checkbox is ticked.
        sam_enabled.change(
            lambda ticked: gr.update(visible=ticked),
            inputs=sam_enabled,
            outputs=prompt_box,
        )

        run_button = gr.Button("Run", variant="primary")
        heatmap_view = gr.Image(label="Heatmap overlay", type="pil")
        status_view = gr.Markdown(label="Status / tips")

        run_button.click(
            infer,
            inputs=[reference_files, test_upload, sam_enabled, prompt_box],
            outputs=[heatmap_view, status_view],
        )

    return demo


# Build the UI at import time so Spaces / `gradio app.py` can find `demo`.
demo = build_demo()

if __name__ == "__main__":
    # Gradio 6 removed concurrency_count kwarg; use default queue
    demo.queue().launch()