V4ldeLund committed on
Commit
2a98dc7
·
verified ·
1 Parent(s): ce620b0

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +125 -4
app.py CHANGED
@@ -1,7 +1,128 @@
 
 
1
  import gradio as gr
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import functools
import cv2
import gradio as gr
import numpy as np
from PIL import Image
import torch
import spaces

# Disable matplotlib visualizations inside the backend call (Spaces are headless).
# NOTE: this monkeypatch must run BEFORE the model imports below, since they
# presumably bind `visualize_segmentation` at import time — TODO confirm.
import utils.visualize as vis
vis.visualize_segmentation = lambda *args, **kwargs: None  # type: ignore

from models.model_bank_knn import PatchKNNDetector
from backbones import get_backbone
from segmenters import SAM3Segmenter

# ZeroGPU: avoid initializing CUDA at import time. Keep everything on CPU until the
# GPU-decorated inference runs and a slice is attached.
DEFAULT_DEVICE = "cpu"
20
+
21
+
22
@functools.lru_cache(maxsize=1)
def load_backbone(name: str = "dinov3_small"):
    """Build the feature backbone *name* exactly once (memoised).

    The model is created on CPU and switched to eval mode; the caller is
    responsible for moving it to the GPU once a slice is available.
    """
    backbone = get_backbone(name)
    backbone = backbone.to(DEFAULT_DEVICE)
    return backbone.eval()
26
+
27
+
28
@functools.lru_cache(maxsize=4)
def load_sam3(prompt: str, device: str):
    """Return a SAM3 segmenter for the given (prompt, device) pair.

    Instances are memoised so the heavy weights are not reloaded per call;
    construction happens only inside inference, after the GPU slice is granted.
    """
    segmenter = SAM3Segmenter(text_prompt=prompt, device=device)
    return segmenter
35
+
36
+
37
def _make_overlay(rgb_image: np.ndarray, anomaly_map: np.ndarray) -> Image.Image:
    """Blend a JET heatmap of *anomaly_map* over *rgb_image* (60/40 mix)."""
    scores = anomaly_map.astype(np.float32)
    lo, hi = scores.min(), scores.max()
    # Epsilon keeps the division finite when the map is constant.
    normed = (scores - lo) / (hi - lo + 1e-8)
    heatmap = cv2.applyColorMap(np.uint8(normed * 255), cv2.COLORMAP_JET)
    base_bgr = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
    blended = cv2.addWeighted(base_bgr, 0.6, heatmap, 0.4, 0)
    return Image.fromarray(cv2.cvtColor(blended, cv2.COLOR_BGR2RGB))
44
+
45
+
46
@spaces.GPU  # When running on ZeroGPU, this grants a short-lived GPU slice for the call.
def infer(ref_files, test_file, use_sam3, sam_prompt):
    """Run patch-KNN anomaly detection on one test image against reference images.

    Args:
        ref_files: uploaded reference ("good") images; file objects or path strings.
        test_file: single uploaded test image; file object or path string.
        use_sam3: whether to restrict scoring to a SAM3-segmented foreground.
        sam_prompt: text prompt passed to SAM3 when `use_sam3` is set.

    Returns:
        (overlay PIL image, normalized anomaly-map PIL image (uint8), float image
        score, status text).

    Raises:
        gr.Error: if no reference images or no test image were provided.
    """
    if not ref_files:
        raise gr.Error("Upload at least one reference image.")
    if test_file is None:
        raise gr.Error("Upload a test image.")

    # Decided here, not at import time: under ZeroGPU, CUDA only exists inside
    # this decorated call.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Gradio may hand back tempfile objects (with .name) or plain path strings.
    ref_paths = [f.name if hasattr(f, "name") else f for f in ref_files]
    test_path = test_file.name if hasattr(test_file, "name") else test_file

    status_lines = []

    segmenter = None
    if use_sam3:
        if device == "cuda":
            status_lines.append("Using SAM3 on GPU (ZeroGPU slice).")
        else:
            status_lines.append(
                "SAM3 requested but GPU unavailable; running on CPU will be slow."
            )
        segmenter = load_sam3(sam_prompt, device)

    # NOTE(review): load_backbone() is lru_cached and nn.Module.to() mutates in
    # place, so after this call the cached singleton lives on `device`. On
    # ZeroGPU the slice is released when this function returns — confirm a later
    # CPU-only call does not receive a CUDA-resident cached model.
    backbone = load_backbone().to(device)

    model = PatchKNNDetector(
        backbone=backbone,
        segmenter=segmenter,
        device=device,
        k_neighbors=1,
    )

    # Fit the reference memory bank, then score the test image.
    model.fit(ref_paths, n_ref=len(ref_paths))
    image, amap, score = model.predict(test_path)

    overlay = _make_overlay(image, amap)
    # Min-max normalize to uint8 for display; epsilon guards a constant map.
    amap_norm = (amap - amap.min()) / (amap.max() - amap.min() + 1e-8)
    amap_img = Image.fromarray((amap_norm * 255).astype(np.uint8))

    status_text = "\n".join(status_lines) if status_lines else f"Ran on {device}."
    return overlay, amap_img, float(score), status_text
88
+
89
+
90
def build_demo():
    """Assemble the Gradio Blocks UI and wire its Run button to `infer`."""
    with gr.Blocks(title="Patch KNN Anomaly Detection") as demo:
        gr.Markdown(
            "# Patch KNN Anomaly Detection\n"
            "Upload reference (normal) images, one test image, and optionally run SAM3 to focus on a specific object."
        )

        # Inputs: many reference images side by side with a single test image.
        with gr.Row():
            references = gr.File(
                label="Reference images (good)", file_types=["image"], file_count="multiple"
            )
            test_image = gr.File(label="Test image", file_types=["image"], file_count="single")

        # Optional SAM3 foreground restriction, tucked into a collapsed accordion.
        with gr.Accordion("Foreground segmentation (optional)", open=False):
            sam_enabled = gr.Checkbox(label="Use SAM3", value=False)
            prompt_box = gr.Textbox(
                label="Object text prompt (e.g. 'bottle')", value="object", visible=False
            )
            # Show the prompt box only while SAM3 is enabled.
            sam_enabled.change(lambda s: gr.update(visible=s), inputs=sam_enabled, outputs=prompt_box)

        run_button = gr.Button("Run", variant="primary")

        # Outputs: blended heatmap, raw grayscale map, scalar score, status text.
        heatmap_view = gr.Image(label="Heatmap overlay", type="pil")
        raw_map_view = gr.Image(label="Raw anomaly map", type="pil", image_mode="L")
        score_view = gr.Number(label="Image anomaly score (mean top 1% distance)")
        status_view = gr.Markdown(label="Status / tips")

        run_button.click(
            infer,
            inputs=[references, test_image, sam_enabled, prompt_box],
            outputs=[heatmap_view, raw_map_view, score_view, status_view],
        )

    return demo
123
+
124
+
125
# Build at module level so `gradio deploy` / Spaces can pick up `demo` directly.
demo = build_demo()

if __name__ == "__main__":
    # Bug fix: Gradio 4 removed `queue(concurrency_count=...)` — passing it
    # raises a TypeError at startup. The per-event limit is now
    # `default_concurrency_limit`; keep it at 1 so only one GPU job runs at a time.
    demo.queue(default_concurrency_limit=1).launch()