# Hugging Face Space header (scrape residue, kept for provenance):
# uploaded by V4ldeLund, commit ca079c9 (verified) —
# "Update code/comments from local workspace".
import functools
import cv2
import gradio as gr
import numpy as np
from PIL import Image
import torch
import spaces
import utils.visualize as vis
# Stub out the visualization helper to a no-op — presumably to suppress
# plot/GUI side effects when running headless on a Space (TODO confirm).
vis.visualize_segmentation = lambda *args, **kwargs: None # type: ignore
from models.model_bank_knn import PatchKNNDetector
from backbones import get_backbone
from segmenters import SAM3Segmenter
# Models are constructed on CPU first; `infer` moves them to CUDA only
# after a ZeroGPU slice has been granted.
DEFAULT_DEVICE = "cpu"
@functools.lru_cache(maxsize=1)
def load_backbone(name: str = "dinov3_small"):
    """Build and cache a single feature backbone in eval mode.

    The module starts on CPU (DEFAULT_DEVICE); callers move it to the GPU
    once a ZeroGPU slice is available.
    """
    backbone = get_backbone(name)
    return backbone.eval().to(DEFAULT_DEVICE)
@functools.lru_cache(maxsize=4)
def load_sam3(prompt: str, device: str):
    """
    Cache SAM3 per (prompt, device) to avoid reloading weights.
    Created only inside inference after GPU slice is granted.

    Raises:
        gr.Error: if the SAM3 segmenter class is unavailable in this
            transformers build, or if loading the (gated) weights fails.
    """
    if SAM3Segmenter is None:
        raise gr.Error(
            "SAM3 is unavailable (transformers build missing Sam3Processor/Sam3Model). "
            "This Space installs transformers from GitHub; if you still see this, restart the Space "
            "to rebuild with the latest image."
        )
    try:
        return SAM3Segmenter(text_prompt=prompt, device=device)
    except OSError as e:
        # Common for gated models when HF_TOKEN is not set.
        # Chain with `from e` so the original loader traceback survives
        # as __cause__ instead of being buried in the message only.
        raise gr.Error(
            "Failed to load SAM3. The model is gated; please add your HF token as a Space secret "
            "named HF_TOKEN (Settings → Secrets) and restart the Space.\n\n"
            f"Loader error: {e}"
        ) from e
def _make_overlay(rgb_image: np.ndarray, anomaly_map: np.ndarray) -> Image.Image:
    """Blend a JET-colormapped anomaly heatmap 40/60 over the RGB input.

    The anomaly map is min-max normalized (the epsilon guards a constant
    map against division by zero) before colour-mapping.
    """
    scores = anomaly_map.astype(np.float32)
    lo, hi = scores.min(), scores.max()
    scores = (scores - lo) / (hi - lo + 1e-8)
    heatmap_bgr = cv2.applyColorMap(np.uint8(scores * 255), cv2.COLORMAP_JET)
    image_bgr = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
    blended_bgr = cv2.addWeighted(image_bgr, 0.6, heatmap_bgr, 0.4, 0)
    return Image.fromarray(cv2.cvtColor(blended_bgr, cv2.COLOR_BGR2RGB))
@spaces.GPU
def infer(ref_files, test_file, use_sam3, sam_prompt):
    """Run patch-KNN anomaly detection on one test image.

    Fits the detector on the uploaded reference (good) images, optionally
    restricting features to a SAM3 foreground mask, then returns a heatmap
    overlay plus a status string for the UI.
    """
    if not ref_files:
        raise gr.Error("Upload at least one reference image.")
    if test_file is None:
        raise gr.Error("Upload a test image.")

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Gradio file components may hand back wrapper objects or raw paths.
    def _as_path(f):
        return f.name if hasattr(f, "name") else f

    ref_paths = [_as_path(f) for f in ref_files]
    test_path = _as_path(test_file)

    notes = []
    segmenter = None
    if use_sam3:
        if device == "cuda":
            notes.append("Using SAM3 on GPU (ZeroGPU slice).")
        else:
            notes.append(
                "SAM3 requested but GPU unavailable; running on CPU will be slow."
            )
        segmenter = load_sam3(sam_prompt, device)

    detector = PatchKNNDetector(
        backbone=load_backbone().to(device),
        segmenter=segmenter,
        device=device,
        k_neighbors=1,
    )
    detector.fit(ref_paths, n_ref=len(ref_paths))
    image, amap, _score = detector.predict(test_path)

    status = "\n".join(notes) if notes else f"Ran on {device}."
    return _make_overlay(image, amap), status
def build_demo():
    """Assemble the Gradio Blocks UI and return the (unlaunched) demo."""
    with gr.Blocks(title="Patch KNN Anomaly Detection") as demo:
        gr.Markdown(
            "# Patch KNN Anomaly Detection\n"
            "Upload reference (good) images, one test image, and optionally run SAM3 to segment the specific foreground object."
        )
        with gr.Row():
            reference_upload = gr.File(
                label="Reference images (good)", file_types=["image"], file_count="multiple"
            )
            test_upload = gr.File(label="Test image", file_types=["image"], file_count="single")
        with gr.Accordion("Foreground segmentation (optional)", open=False):
            sam_checkbox = gr.Checkbox(label="Use SAM3", value=False)
            prompt_box = gr.Textbox(
                label="Object text prompt (e.g. 'bottle')", value="object", visible=False
            )
            # Show the text prompt only while SAM3 is enabled.
            sam_checkbox.change(
                lambda enabled: gr.update(visible=enabled),
                inputs=sam_checkbox,
                outputs=prompt_box,
            )
        run_button = gr.Button("Run", variant="primary")
        heatmap_output = gr.Image(label="Heatmap overlay", type="pil")
        status_output = gr.Markdown(label="Status / tips")
        run_button.click(
            infer,
            inputs=[reference_upload, test_upload, sam_checkbox, prompt_box],
            outputs=[heatmap_output, status_output],
        )
    return demo
# Built at import time so the hosting platform (e.g. HF Spaces) can pick up
# the module-level `demo` object without running the __main__ guard.
demo = build_demo()
if __name__ == "__main__":
    # Gradio 6 removed concurrency_count kwarg; use default queue
    demo.queue().launch()