import hashlib
import os
import time

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from sklearn.decomposition import PCA
from transformers import AutoImageProcessor, AutoModel

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_CKPT = "assets/dinov2-base"
IMAGE_RES = 448
LAYERS_STR = "-1, -4,-5"
PCA_EV = 0.99
AUG_COUNT = 30
AUG_LIST = ["rotate"]
BATCH_SIZE = 4
EPS = 1e-6


def parse_layer_indices(arg_str: str):
    """Parse a comma-separated list of integer hidden-state indices.

    Whitespace around each entry is ignored, and empty entries (for example
    from a trailing comma such as "-1, -4,") are skipped instead of raising.

    Args:
        arg_str: String such as "-1, -4,-5".

    Returns:
        List of ints in the order they appear.

    Raises:
        ValueError: If a non-empty entry is not an integer.
    """
    stripped = (part.strip() for part in arg_str.split(","))
    return [int(part) for part in stripped if part]


LAYERS = parse_layer_indices(LAYERS_STR)


def get_augmentation_transform(aug_list: list):
    """Build the augmentation pipeline applied to the reference image.

    Only "rotate" is recognized (a random rotation drawn from [0, 345]
    degrees); other names are ignored. If nothing is recognized, an identity
    callable is returned.
    """
    # Local import: torchvision is only needed to build this pipeline.
    import torchvision.transforms as T

    transforms_list = []
    for aug_name in aug_list:
        if aug_name == "rotate":
            transforms_list.append(T.RandomRotation(degrees=(0, 345)))
    if not transforms_list:
        return lambda x: x
    return T.Compose(transforms_list)


AUG_TRANSFORM = get_augmentation_transform(AUG_LIST)


def min_max_norm(x: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Scale ``x`` to [0, 1] over its last two axes.

    NaN and +/-inf are replaced by 0 first. ``eps`` prevents division by zero,
    so a constant map normalizes to all zeros.
    """
    x = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
    x_min = np.min(x, axis=(-1, -2), keepdims=True)
    x_max = np.max(x, axis=(-1, -2), keepdims=True)
    x_norm = (x - x_min) / (x_max - x_min + eps)
    return np.clip(x_norm, 0.0, 1.0)


def pca_reconstruct(X: np.ndarray, pca: dict, drop_k: int = 0) -> np.ndarray:
    """Project rows of ``X`` onto the PCA subspace and map them back.

    Args:
        X: Feature matrix, one sample per row.
        pca: Dict with "mu" (mean vector), "components" (features x >=k,
            column-wise principal axes) and "k" (number of axes to keep).
        drop_k: Number of leading principal components to zero out before
            reconstructing (all of them if ``drop_k >= k``).

    Returns:
        Reconstruction of ``X`` in the original feature space.
    """
    mu = np.asarray(pca["mu"], dtype=X.dtype)
    C = np.asarray(pca["components"][:, : pca["k"]], dtype=X.dtype)
    X0 = X - mu
    Z = X0 @ C
    if drop_k > 0:
        if drop_k >= Z.shape[1]:
            Z[:] = 0.0
        else:
            Z[:, :drop_k] = 0.0
    X_recon = (Z @ C.T) + mu
    return X_recon


def _calculate_pca_scores(X: np.ndarray, pca: dict, method: str, drop_k: int = 0):
    """Return one anomaly score per row of ``X``.

    Only "reconstruction" is supported: the squared L2 residual between each
    row and its PCA reconstruction.

    Raises:
        ValueError: For any other ``method``.
    """
    if method == "reconstruction":
        X_recon = pca_reconstruct(X, pca, drop_k=drop_k)
        return np.sum((X - X_recon) ** 2, axis=1)
    raise ValueError(f"Unknown scoring method '{method}'.")


def calculate_anomaly_scores(
    X: np.ndarray, pca: dict, method: str = "reconstruction", drop_k: int = 0
):
    """Public entry point for per-row anomaly scores (defaults to reconstruction error)."""
    return _calculate_pca_scores(X, pca, method, drop_k)


def post_process_map(anomaly_map: np.ndarray, res, blur: bool = True):
    """Upsample a patch-level anomaly map and optionally smooth it.

    Args:
        anomaly_map: 2-D score map.
        res: Output size; an int for a square map, otherwise (height, width).
        blur: Apply a Gaussian blur after resizing.

    Returns:
        float32 map of the requested size.
    """
    if anomaly_map.dtype != np.float32:
        anomaly_map = anomaly_map.astype(np.float32)
    # cv2.resize expects dsize as (width, height).
    dsize = (res, res) if isinstance(res, int) else (res[1], res[0])
    map_resized = cv2.resize(anomaly_map, dsize, interpolation=cv2.INTER_LINEAR)
    if blur:
        sigma = 4.0
        # NOTE(review): with a 3x3 kernel a sigma=4 Gaussian is nearly flat,
        # so this behaves close to a 3x3 box filter; confirm that is intended.
        k_size = 3
        return cv2.GaussianBlur(map_resized, (k_size, k_size), sigma)
    return map_resized


def _create_heatmap(anom_map_norm_float: np.ndarray) -> np.ndarray:
    """Colorize a [0, 1] map with OpenCV's JET colormap (BGR uint8 output)."""
    anom_map_u8 = (anom_map_norm_float * 255).astype(np.uint8)
    return cv2.applyColorMap(anom_map_u8, cv2.COLORMAP_JET)


def blend_visualization(img: Image.Image, anom_map_norm_float: np.ndarray) -> Image.Image:
    """Overlay the anomaly heatmap on ``img`` only inside detected regions.

    The map is binarized with Otsu's threshold (an empty mask if OpenCV
    rejects it), cleaned with a morphological opening plus one dilation, and
    the heatmap is alpha-blended into the image wherever the mask is set.
    Outside the mask the image is left unchanged.

    Args:
        img: Input image; converted to RGB and resized to the map's size.
        anom_map_norm_float: 2-D map with values in [0, 1].

    Returns:
        RGB PIL image with the same height and width as the map.
    """
    overlay_intensity = 0.4
    kernel_size = 5
    img_h, img_w = anom_map_norm_float.shape
    # convert("RGB") so grayscale / palette / RGBA inputs don't break the
    # RGB->BGR conversion below.
    img_np = np.array(img.convert("RGB").resize((img_w, img_h)))
    img_np_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
    heatmap = _create_heatmap(anom_map_norm_float)

    anom_map_u8 = (anom_map_norm_float * 255).astype(np.uint8)
    try:
        _, binary_mask = cv2.threshold(
            anom_map_u8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
        )
    except cv2.error:
        binary_mask = np.zeros_like(anom_map_u8)

    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    denoised_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel)
    denoised_mask = cv2.dilate(denoised_mask, kernel, iterations=1)

    overlay = cv2.addWeighted(
        img_np_bgr, (1.0 - overlay_intensity), heatmap, overlay_intensity, 0
    )
    mask_3d = np.stack([denoised_mask] * 3, axis=-1)
    final_image = np.where(mask_3d > 0, overlay, img_np_bgr)
    return Image.fromarray(cv2.cvtColor(final_image, cv2.COLOR_BGR2RGB))


def compute_image_fingerprint(img: Image.Image):
    """Return a content fingerprint used to detect a changed reference image.

    The image is converted to RGB and its size plus a BLAKE2b digest of the
    raw pixel bytes are returned. Unlike a thumbnail mean, this distinguishes
    different images that share dimensions and average brightness, so a new
    reference image reliably triggers PCA retraining.
    """
    rgb = img.convert("RGB")
    digest = hashlib.blake2b(rgb.tobytes(), digest_size=16).hexdigest()
    return (rgb.size, digest)


class FeatureExtractor:
    """Wraps a Hugging Face image backbone and its processor for patch-token extraction."""

    def __init__(self, model_ckpt: str):
        # Decide if we're loading from a local folder or from HF Hub.
        is_local = os.path.isdir(model_ckpt)
        load_kwargs = {
            "local_files_only": is_local,  # don't hit network if local
        }
        # Processor
        self.processor = AutoImageProcessor.from_pretrained(
            model_ckpt,
            **load_kwargs,
        )
        # Avoid meta tensors by disabling low_cpu_mem_usage and forcing device_map.
        device_map = {"": DEVICE}
        self.model = AutoModel.from_pretrained(
            model_ckpt,
            device_map=device_map,
            dtype=torch.float32,
            low_cpu_mem_usage=False,
            **load_kwargs,
        ).eval()
        self.device = next(self.model.parameters()).device
        self.config = self.model.config

    @torch.no_grad()
    def extract_tokens(self, pil_imgs: list, res: int, layers: list, agg_method: str):
        """Return patch tokens from selected hidden layers, fused by ``agg_method``.

        Args:
            pil_imgs: Batch of PIL images.
            res: Square input resolution; the token grid is
                (res // patch_size) x (res // patch_size).
            layers: Indices into the model's ``hidden_states``.
            agg_method: Only "mean" is supported.

        Returns:
            ``(tokens, (h_p, w_p))`` where ``tokens`` is a numpy array shaped
            (batch, h_p, w_p, hidden_dim).

        Raises:
            ValueError: For an unknown ``agg_method``.
        """
        size = {"height": res, "width": res}
        inputs = self.processor(
            images=pil_imgs,
            return_tensors="pt",
            do_resize=True,
            size=size,
            do_center_crop=False,
        ).to(self.device)
        outputs = self.model(**inputs, output_hidden_states=True)
        hidden_states = outputs.hidden_states

        ps = self.config.patch_size
        num_reg = getattr(self.config, "num_register_tokens", 0)
        # Skip the leading token (presumably CLS) and any register tokens so
        # only the patch tokens remain.
        drop_front = 1 + num_reg
        h_p, w_p = res // ps, res // ps
        n_expected = h_p * w_p

        def _spatial_converter(x):
            return x[:, drop_front: drop_front + n_expected, :].reshape(
                x.shape[0], h_p, w_p, x.shape[-1]
            )

        feats = [_spatial_converter(hidden_states[li]) for li in layers]
        if agg_method == "mean":
            fused = torch.stack(feats, dim=0).mean(dim=0)
        else:
            raise ValueError(f"Unknown aggregation method: '{agg_method}'")
        return fused.cpu().numpy(), (h_p, w_p)


GLOBAL_EXTRACTOR = None


def get_extractor(logs=None) -> FeatureExtractor:
    """Return the process-wide FeatureExtractor, loading it on first use.

    Args:
        logs: Optional list; when given (even if empty), progress messages
            are appended to it while the backbone loads.
    """
    global GLOBAL_EXTRACTOR
    if GLOBAL_EXTRACTOR is None:
        # `is not None` rather than truthiness: callers pass an empty list
        # and still expect the loading messages to be recorded.
        if logs is not None:
            logs.append("Loading DINOv2-Base backbone (first run only)...")
        t0 = time.time()
        GLOBAL_EXTRACTOR = FeatureExtractor(MODEL_CKPT)
        if logs is not None:
            logs.append(f"Backbone loaded in {time.time() - t0:.1f}s.")
    return GLOBAL_EXTRACTOR


INITIAL_STATE = {
    "pca_params": None,
    "h_p": None,
    "w_p": None,
    "feature_dim": None,
    "calib_p99": None,
    "ref_fingerprint": None,  # track which reference image PCA was trained on
}


def train_pca_model(reference_image: Image.Image, current_state: dict, logs=None):
    """Fit the normal-appearance PCA subspace from one reference image.

    The reference plus ``AUG_COUNT`` augmented copies are embedded, all patch
    tokens are pooled, and a PCA keeping ``PCA_EV`` explained variance is
    fitted. The 99th percentile of the training reconstruction errors and the
    reference fingerprint are stored in ``current_state`` (mutated in place).

    Args:
        reference_image: Normal (defect-free) image, or None.
        current_state: Session state dict shaped like ``INITIAL_STATE``.
        logs: Optional list that progress messages are appended to.

    Returns:
        ``(log_text, current_state)``; if ``reference_image`` is None the
        first element is a prompt message and the state is unchanged.
    """
    if reference_image is None:
        msg = "Please upload a normal reference image first."
        return msg, current_state
    if logs is None:
        logs = []

    extractor = get_extractor(logs)

    all_imgs = [reference_image]
    for _ in range(AUG_COUNT):
        all_imgs.append(AUG_TRANSFORM(reference_image))

    total_samples = len(all_imgs)
    logs.append(f"Extracting features from {total_samples} samples...")

    all_tokens_list = []
    t0 = time.time()
    for i in range(0, total_samples, BATCH_SIZE):
        img_batch = all_imgs[i: i + BATCH_SIZE]
        tokens_batch, (h_p, w_p) = extractor.extract_tokens(
            img_batch, IMAGE_RES, LAYERS, "mean"
        )
        b, h, w, c = tokens_batch.shape
        # Flatten every patch of every image into one row of the training matrix.
        all_tokens_list.append(tokens_batch.reshape(b * h * w, c))
    feat_time = time.time() - t0
    logs.append(f"Feature extraction done in {feat_time:.1f}s.")

    all_train_tokens = np.concatenate(all_tokens_list)
    current_state["h_p"], current_state["w_p"], current_state["feature_dim"] = h_p, w_p, c

    logs.append(f"Fitting PCA (EV={PCA_EV})...")
    t0 = time.time()
    pca = PCA(n_components=PCA_EV, svd_solver="full")
    pca.fit(all_train_tokens)
    pca_time = time.time() - t0

    current_state["pca_params"] = {
        "mu": pca.mean_.astype(np.float32),
        # Stored transposed so columns are principal axes (features x k).
        "components": pca.components_.T.astype(np.float32),
        "eigvals": pca.explained_variance_.astype(np.float32),
        "k": pca.n_components_,
        "eps": EPS,
        "whiten": False,
    }

    train_scores = calculate_anomaly_scores(all_train_tokens, current_state["pca_params"])
    calib_p99 = float(np.quantile(train_scores, 0.99))
    current_state["calib_p99"] = calib_p99

    # Store fingerprint of this reference image.
    current_state["ref_fingerprint"] = compute_image_fingerprint(reference_image)

    logs.append(
        f"PCA fitted in {pca_time:.1f}s. "
        f"Normal residual calibration (p99): {calib_p99:.3e}"
    )
    return "\n".join(logs), current_state


def segment_anomaly(test_image: Image.Image, reference_image: Image.Image, current_state: dict):
    """Gradio handler: produce an anomaly overlay for ``test_image``.

    The PCA model is (re)fitted when none exists yet or when the reference
    image's fingerprint differs from the one the model was fitted on.

    Returns:
        ``(overlay_or_None, log_text, current_state)``.
    """
    logs = []
    if test_image is None:
        return None, "Please upload a test image.", current_state

    # Decide if we need to (re)train PCA.
    need_train = current_state["pca_params"] is None
    if reference_image is not None:
        new_fp = compute_image_fingerprint(reference_image)
        old_fp = current_state.get("ref_fingerprint", None)
        if (old_fp is None) or (new_fp != old_fp):
            # Reference image changed -> retrain PCA.
            need_train = True

    if need_train:
        if reference_image is None:
            return None, "Please upload a normal reference image first.", current_state
        _, current_state = train_pca_model(reference_image, current_state, logs)

    extractor = get_extractor()
    pca_params = current_state["pca_params"]
    calib_p99 = current_state.get("calib_p99", None)

    logs.append("Extracting DINOv2 features for test image...")
    t0 = time.time()
    tokens, _grid = extractor.extract_tokens([test_image], IMAGE_RES, LAYERS, "mean")
    b, h, w, c = tokens.shape
    tokens_reshaped = tokens.reshape(b * h * w, c)
    logs.append(f"Feature extraction done in {time.time() - t0:.1f}s.")

    logs.append("Computing reconstruction error...")
    scores = calculate_anomaly_scores(tokens_reshaped, pca_params)
    # NOTE(review): bilinear resize, Gaussian blur and min-max normalization
    # below are all insensitive to a constant offset, so this subtraction does
    # not change the rendered overlay. If the intent was to suppress residuals
    # below the normal p99, clipping (np.maximum(..., 0)) would be needed — confirm.
    if calib_p99 is not None and calib_p99 > 0:
        scores = scores - calib_p99
    # b == 1 here, so the scores form a single h x w patch grid.
    anomaly_map_raw = scores.reshape(h, w)

    logs.append("Post-processing anomaly map...")
    anomaly_map_final = post_process_map(anomaly_map_raw, IMAGE_RES)
    anomaly_map_normalized = min_max_norm(anomaly_map_final)

    overlay = blend_visualization(test_image, anomaly_map_normalized)
    logs.append("Segmentation complete.")

    return overlay, "\n".join(logs), current_state


def warmup():
    """Load the backbone at app start so the first click is fast; returns log text."""
    logs = ["Initializing model on server..."]
    get_extractor(logs)
    return "\n".join(logs)


with gr.Blocks(title="SubspaceAD – One-Shot Anomaly Segmentation") as demo:
    gr.Markdown(
        """
        # SubspaceAD – One-Shot Anomaly Segmentation (Demo)
        Upload a normal reference image and a test image. SubspaceAD fits a PCA subspace over DINOv2 patch embeddings and highlights deviations.
        """
    )

    # Use a copy so the dict object isn't shared unexpectedly.
    pca_state = gr.State(INITIAL_STATE.copy())

    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("### Reference – define normal appearance")
            ref_image_input = gr.Image(label="Reference image (normal)", type="pil", height=448)

            gr.Markdown("### Test – segment anomalies")
            test_image_input = gr.Image(
                label="Test image (normal or anomalous)", type="pil", height=448
            )

            segment_button = gr.Button("Run anomaly segmentation")

            gr.Markdown("### Try it instantly – click an example")
            gr.Examples(
                examples=[
                    ["./assets/example_hazelnut_ref.png", "./assets/example_hazelnut_test.png"],
                    ["./assets/example_bottle_ref.png", "./assets/example_bottle_test.png"],
                ],
                inputs=[ref_image_input, test_image_input],
                label="MVTec-AD Examples",
            )

        with gr.Column(scale=3):
            gr.Markdown("### Output")
            output_image = gr.Image(
                label="Anomaly overlay (448×448; red/yellow ≈ high anomaly)",
                type="pil",
                height=448,
            )

            with gr.Accordion("Paper qualitative examples", open=False):
                gr.Image("./assets/mvtec_examples.png", interactive=False)
                gr.Image("./assets/visa_examples.png", interactive=False)

            status_box = gr.Textbox(
                label="Log",
                value="Model is initializing. Upload images or click the hazelnut example.",
                lines=8,
            )

    demo.load(fn=warmup, inputs=None, outputs=status_box)

    segment_button.click(
        fn=segment_anomaly,
        inputs=[test_image_input, ref_image_input, pca_state],
        outputs=[output_image, status_box, pca_state],
    )


if __name__ == "__main__":
    demo.launch(theme=gr.themes.Soft())