# Page-status text captured together with this listing ("Spaces: Sleeping");
# it is not part of the program.
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from transformers import AutoImageProcessor, AutoModel | |
| import cv2 | |
| from sklearn.decomposition import PCA | |
| import time | |
| import os | |
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Checkpoint handed to the Hugging Face loaders (a local folder or a Hub id).
MODEL_CKPT = "assets/dinov2-base"
# Side length, in pixels, that images are resized to before feature extraction.
IMAGE_RES = 448
# Comma-separated hidden-state indices whose patch tokens are averaged.
LAYERS_STR = "-1, -4,-5"
# Passed to sklearn's PCA as n_components (fraction of variance to retain).
PCA_EV = 0.99
# Number of augmented copies of the reference image added to the training set.
AUG_COUNT = 30
# Names of the augmentations applied to the reference image.
AUG_LIST = ["rotate"]
# Number of images per forward pass while extracting training features.
BATCH_SIZE = 4
# Small constant stored alongside the PCA parameters under the "eps" key.
EPS = 1e-6
def parse_layer_indices(arg_str: str):
    """Parse a comma-separated string of layer indices into a list of ints.

    Whitespace around each entry is ignored, and empty entries (for example
    the one produced by a trailing comma) are skipped instead of crashing.

    Args:
        arg_str: Text such as ``"-1, -4,-5"``.

    Returns:
        The indices as a list of ints, in the order they appear.

    Raises:
        ValueError: If a non-empty entry is not a valid integer.
    """
    tokens = (part.strip() for part in arg_str.split(","))
    return [int(token) for token in tokens if token]
# Layer indices parsed once, at import time, from LAYERS_STR.
LAYERS = parse_layer_indices(LAYERS_STR)
def get_augmentation_transform(aug_list: list):
    """Build the augmentation callable described by ``aug_list``.

    Args:
        aug_list: Augmentation names. Only ``"rotate"`` is supported.

    Returns:
        An identity function when ``aug_list`` is empty, otherwise a
        ``torchvision.transforms.Compose`` of the requested transforms.

    Raises:
        ValueError: If ``aug_list`` contains an unsupported name. A typo
            would otherwise silently disable augmentation.
    """
    supported = ("rotate",)
    unknown = [name for name in aug_list if name not in supported]
    if unknown:
        raise ValueError(
            f"Unknown augmentation(s): {unknown}. Supported: {list(supported)}."
        )
    if not aug_list:
        return lambda x: x

    # Imported lazily so torchvision is only needed when a transform is built.
    import torchvision.transforms as T

    transforms_list = []
    for aug_name in aug_list:
        if aug_name == "rotate":
            transforms_list.append(T.RandomRotation(degrees=(0, 345)))
    return T.Compose(transforms_list)
# Augmentation pipeline built once, at import time, from AUG_LIST.
AUG_TRANSFORM = get_augmentation_transform(AUG_LIST)
def min_max_norm(x: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Rescale ``x`` into [0, 1] using the min and max over its last two axes.

    NaN and infinite entries are replaced by 0 first. ``eps`` is added to the
    value range so a constant map divides by a non-zero number.
    """
    finite = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
    lowest = np.min(finite, axis=(-1, -2), keepdims=True)
    highest = np.max(finite, axis=(-1, -2), keepdims=True)
    span = highest - lowest + eps
    scaled = (finite - lowest) / span
    return np.clip(scaled, 0.0, 1.0)
def pca_reconstruct(X: np.ndarray, pca: dict, drop_k: int = 0) -> np.ndarray:
    """Project the rows of ``X`` onto the stored PCA basis and map them back.

    Args:
        X: Feature rows to reconstruct.
        pca: Dict providing ``"mu"``, ``"components"`` and ``"k"``; only the
            first ``k`` columns of ``components`` are used.
        drop_k: Number of leading coefficients to zero before mapping back.

    Returns:
        The reconstruction, computed in ``X``'s dtype.
    """
    dtype = X.dtype
    mean = np.asarray(pca["mu"], dtype=dtype)
    basis = np.asarray(pca["components"][:, : pca["k"]], dtype=dtype)
    coeffs = (X - mean) @ basis
    if drop_k > 0:
        # A slice that runs past the last column is clipped, so this also
        # zeroes everything when drop_k is at least the number of components.
        coeffs[:, :drop_k] = 0.0
    return (coeffs @ basis.T) + mean
| def _calculate_pca_scores(X: np.ndarray, pca: dict, method: str, drop_k: int = 0): | |
| if method == "reconstruction": | |
| X_recon = pca_reconstruct(X, pca, drop_k=drop_k) | |
| return np.sum((X - X_recon) ** 2, axis=1) | |
| raise ValueError(f"Unknown scoring method '{method}'.") | |
def calculate_anomaly_scores(X: np.ndarray, pca: dict, method: str = "reconstruction", drop_k: int = 0):
    """Public entry point for scoring; forwards everything to ``_calculate_pca_scores``."""
    scores = _calculate_pca_scores(X, pca, method, drop_k)
    return scores
def post_process_map(anomaly_map: np.ndarray, res, blur: bool = True):
    """Resize an anomaly map to ``res`` and optionally smooth it.

    Args:
        anomaly_map: 2-D score map; cast to float32 if it is not already.
        res: Either an int (square output) or an indexable pair whose
            elements are swapped to form cv2's ``dsize`` — presumably given
            as (height, width); confirm against callers.
        blur: When true, apply a 3x3 Gaussian blur with sigma 4.0.

    Returns:
        The resized (and possibly blurred) float32 map.
    """
    as_float = anomaly_map
    if as_float.dtype != np.float32:
        as_float = as_float.astype(np.float32)

    if isinstance(res, int):
        target_size = (res, res)
    else:
        target_size = (res[1], res[0])

    upsampled = cv2.resize(as_float, target_size, interpolation=cv2.INTER_LINEAR)
    if not blur:
        return upsampled

    kernel_shape, sigma = (3, 3), 4.0
    return cv2.GaussianBlur(upsampled, kernel_shape, sigma)
def _create_heatmap(anom_map_norm_float: np.ndarray) -> np.ndarray:
    """Colour a map with OpenCV's JET colormap.

    The input is multiplied by 255 and cast to uint8 first, so it is
    expected to already lie in [0, 1].
    """
    scaled = anom_map_norm_float * 255
    return cv2.applyColorMap(scaled.astype(np.uint8), cv2.COLORMAP_JET)
def blend_visualization(img: Image.Image, anom_map_norm_float: np.ndarray) -> Image.Image:
    """Overlay the anomaly heatmap on ``img`` inside the thresholded region.

    Args:
        img: Image to draw on; any PIL mode is accepted and converted to RGB.
        anom_map_norm_float: 2-D map, expected in [0, 1]; its shape sets the
            output size.

    Returns:
        An RGB PIL image in which pixels selected by the Otsu mask show a
        0.6/0.4 blend of the image and the heatmap, and all other pixels
        show the resized input unchanged.
    """
    overlay_intensity = 0.4
    kernel_size = 5
    map_h, map_w = anom_map_norm_float.shape

    # Force three channels: grayscale or palette images would otherwise give
    # a 2-D array that cvtColor rejects and that cannot be blended with the
    # 3-channel heatmap.
    img_rgb = np.array(img.convert("RGB").resize((map_w, map_h)))
    img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)

    heatmap = _create_heatmap(anom_map_norm_float)
    anom_map_u8 = (anom_map_norm_float * 255).astype(np.uint8)

    # Otsu picks the threshold automatically; if OpenCV rejects the input,
    # fall back to an empty mask so the plain image is returned.
    try:
        _, binary_mask = cv2.threshold(anom_map_u8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    except cv2.error:
        binary_mask = np.zeros_like(anom_map_u8)

    # Opening removes small specks; the dilation then grows what is left.
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    denoised_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel)
    denoised_mask = cv2.dilate(denoised_mask, kernel, iterations=1)

    overlay = cv2.addWeighted(img_bgr, (1.0 - overlay_intensity), heatmap, overlay_intensity, 0)
    mask_3d = np.stack([denoised_mask] * 3, axis=-1)
    final_image = np.where(mask_3d > 0, overlay, img_bgr)
    return Image.fromarray(cv2.cvtColor(final_image, cv2.COLOR_BGR2RGB))
def compute_image_fingerprint(img: Image.Image):
    """
    Cheap fingerprint used to notice when the reference image has changed.

    A copy of the image is converted to RGB and shrunk to fit within 64x64;
    the result is ``(thumbnail_size, mean_pixel_value)``. Different images
    can share a fingerprint, since only the size and the mean are kept.
    """
    thumb = img.convert("RGB").copy()
    thumb.thumbnail((64, 64))
    pixels = np.array(thumb, dtype=np.float32)
    mean_value = float(pixels.mean())
    return (thumb.size, mean_value)
class FeatureExtractor:
    """Wraps a Hugging Face image processor and backbone to produce patch-token grids."""

    def __init__(self, model_ckpt: str):
        """Load the processor and the model for ``model_ckpt`` onto ``DEVICE``.

        Args:
            model_ckpt: Local directory or Hub identifier of the checkpoint.
                When it is an existing directory, loading is restricted to
                local files.
        """
        # Decide if we're loading from a local folder or from HF Hub
        is_local = os.path.isdir(model_ckpt)
        load_kwargs = {
            "local_files_only": is_local,  # don't hit network if local
        }
        # Processor
        self.processor = AutoImageProcessor.from_pretrained(
            model_ckpt,
            **load_kwargs,
        )
        # Avoid meta tensors by disabling low_cpu_mem_usage and forcing device_map
        device_map = {"": DEVICE}
        self.model = AutoModel.from_pretrained(
            model_ckpt,
            device_map=device_map,
            dtype=torch.float32,
            low_cpu_mem_usage=False,
            **load_kwargs,
        ).eval()
        self.device = next(self.model.parameters()).device
        self.config = self.model.config

    def extract_tokens(self, pil_imgs: list, res: int, layers: list, agg_method: str):
        """Return fused patch tokens for a batch of images.

        Args:
            pil_imgs: Images handed to the processor.
            res: Square side length the processor resizes to.
            layers: Indices into the model's ``hidden_states`` tuple.
            agg_method: How to fuse the selected layers; only ``"mean"``.

        Returns:
            ``(tokens, (h_p, w_p))`` where ``tokens`` is a NumPy array of
            shape ``(batch, h_p, w_p, dim)`` and ``h_p == w_p == res // patch_size``.

        Raises:
            ValueError: If ``agg_method`` is not ``"mean"``.
        """
        # Validate first so a bad argument does not cost a forward pass.
        if agg_method != "mean":
            raise ValueError(f"Unknown aggregation method: '{agg_method}'")

        size = {"height": res, "width": res}
        inputs = self.processor(
            images=pil_imgs,
            return_tensors="pt",
            do_resize=True,
            size=size,
            do_center_crop=False,
        ).to(self.device)

        ps = self.config.patch_size
        num_reg = getattr(self.config, "num_register_tokens", 0)
        # Leading tokens to skip: one (presumably the class token) plus any
        # register tokens the config declares.
        drop_front = 1 + num_reg
        h_p, w_p = res // ps, res // ps
        n_expected = h_p * w_p

        def _spatial_converter(x):
            return x[:, drop_front: drop_front + n_expected, :].reshape(
                x.shape[0], h_p, w_p, x.shape[-1]
            )

        # Inference only: without no_grad the outputs carry autograd history,
        # which wastes memory and makes the .numpy() call below raise.
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
            hidden_states = outputs.hidden_states
            feats = [_spatial_converter(hidden_states[li]) for li in layers]
            fused = torch.stack(feats, dim=0).mean(dim=0)
        return fused.cpu().numpy(), (h_p, w_p)
# Process-wide cache so the backbone is constructed at most once.
GLOBAL_EXTRACTOR = None


def get_extractor(logs=None) -> FeatureExtractor:
    """Return the shared ``FeatureExtractor``, creating it on first use.

    Args:
        logs: Optional list that receives progress messages when the
            backbone is actually loaded. An empty list is a valid target.

    Returns:
        The cached extractor instance.
    """
    global GLOBAL_EXTRACTOR
    if GLOBAL_EXTRACTOR is None:
        # Test against None rather than truthiness: callers pass a fresh,
        # still-empty list, and `if logs:` would silently drop the messages.
        if logs is not None:
            logs.append("Loading DINOv2-Base backbone (first run only)...")
        t0 = time.time()
        GLOBAL_EXTRACTOR = FeatureExtractor(MODEL_CKPT)
        if logs is not None:
            logs.append(f"Backbone loaded in {time.time() - t0:.1f}s.")
    return GLOBAL_EXTRACTOR
# Template for the per-session state; every field starts out unset.
INITIAL_STATE = dict(
    pca_params=None,       # dict of fitted PCA parameters
    h_p=None,              # patch-grid height from the last training run
    w_p=None,              # patch-grid width from the last training run
    feature_dim=None,      # size of each patch feature vector
    calib_p99=None,        # 0.99 quantile of the training reconstruction errors
    ref_fingerprint=None,  # track which reference image PCA was trained on
)
def train_pca_model(reference_image: Image.Image, current_state: dict, logs=None):
    """Fit the PCA model of "normal" patches from a single reference image.

    The reference image plus ``AUG_COUNT`` augmented copies are run through
    the extractor in batches of ``BATCH_SIZE``; PCA is fitted on all patch
    tokens, and the 0.99 quantile of their reconstruction error is stored
    for calibration.

    Args:
        reference_image: Image of a normal sample, or ``None``.
        current_state: State dict, updated in place with the fitted model.
        logs: Optional list that progress messages are appended to.

    Returns:
        ``(log_text, current_state)``. When ``reference_image`` is ``None``
        the text is a prompt to upload one and the state is left untouched.
    """
    if reference_image is None:
        return "Please upload a normal reference image first.", current_state

    logs = [] if logs is None else logs
    extractor = get_extractor(logs)

    # Training set: the reference itself followed by its augmented copies.
    samples = [reference_image]
    samples.extend(AUG_TRANSFORM(reference_image) for _ in range(AUG_COUNT))
    n_samples = len(samples)
    logs.append(f"Extracting features from {n_samples} samples...")

    token_chunks = []
    started = time.time()
    for start in range(0, n_samples, BATCH_SIZE):
        batch_tokens, (h_p, w_p) = extractor.extract_tokens(
            samples[start: start + BATCH_SIZE], IMAGE_RES, LAYERS, "mean"
        )
        n_imgs, rows, cols, feature_dim = batch_tokens.shape
        # Flatten to one row per patch.
        token_chunks.append(batch_tokens.reshape(n_imgs * rows * cols, feature_dim))
    extraction_seconds = time.time() - started
    logs.append(f"Feature extraction done in {extraction_seconds:.1f}s.")

    train_tokens = np.concatenate(token_chunks)
    current_state.update(h_p=h_p, w_p=w_p, feature_dim=feature_dim)

    logs.append(f"Fitting PCA (EV={PCA_EV})...")
    started = time.time()
    model = PCA(n_components=PCA_EV, svd_solver="full")
    model.fit(train_tokens)
    fit_seconds = time.time() - started

    # components_ is transposed so that each column is one principal axis.
    params = {
        "mu": model.mean_.astype(np.float32),
        "components": model.components_.T.astype(np.float32),
        "eigvals": model.explained_variance_.astype(np.float32),
        "k": model.n_components_,
        "eps": EPS,
        "whiten": False,
    }
    current_state["pca_params"] = params

    residuals = calculate_anomaly_scores(train_tokens, params)
    p99 = float(np.quantile(residuals, 0.99))
    current_state["calib_p99"] = p99

    # Remember which reference image this model belongs to.
    current_state["ref_fingerprint"] = compute_image_fingerprint(reference_image)

    summary = f"PCA fitted in {fit_seconds:.1f}s. " + f"Normal residual calibration (p99): {p99:.3e}"
    logs.append(summary)
    return "\n".join(logs), current_state
def segment_anomaly(test_image: Image.Image, reference_image: Image.Image, current_state: dict):
    """Produce the anomaly overlay for ``test_image``, training first if needed.

    The PCA model is (re)fitted when none is stored yet, or when a reference
    image is supplied whose fingerprint differs from the stored one.

    Args:
        test_image: Image to inspect, or ``None``.
        reference_image: Normal reference image, or ``None``.
        current_state: State dict holding the fitted model.

    Returns:
        ``(overlay_image_or_None, log_text, current_state)``.
    """
    logs = []
    if test_image is None:
        return None, "Please upload a test image.", current_state

    has_reference = reference_image is not None
    model_missing = current_state["pca_params"] is None

    # A supplied reference counts as "changed" when no fingerprint is stored
    # yet or the stored one no longer matches.
    reference_changed = False
    if has_reference:
        current_fp = compute_image_fingerprint(reference_image)
        stored_fp = current_state.get("ref_fingerprint", None)
        reference_changed = (stored_fp is None) or (current_fp != stored_fp)

    if model_missing or reference_changed:
        if not has_reference:
            return None, "Please upload a normal reference image first.", current_state
        _, current_state = train_pca_model(reference_image, current_state, logs)

    extractor = get_extractor()
    pca_params = current_state["pca_params"]
    calib_p99 = current_state.get("calib_p99", None)

    logs.append("Extracting DINOv2 features for test image...")
    started = time.time()
    tokens, (h_p, w_p) = extractor.extract_tokens([test_image], IMAGE_RES, LAYERS, "mean")
    n_imgs, rows, cols, feature_dim = tokens.shape
    patch_rows = tokens.reshape(n_imgs * rows * cols, feature_dim)
    logs.append(f"Feature extraction done in {time.time() - started:.1f}s.")

    logs.append("Computing reconstruction error...")
    scores = calculate_anomaly_scores(patch_rows, pca_params)
    # NOTE(review): min_max_norm below is shift-invariant, so this constant
    # offset looks like it has no effect on the final overlay — confirm
    # whether a clamp at zero was intended.
    if calib_p99 is not None and calib_p99 > 0:
        scores = scores - calib_p99
    raw_map = scores.reshape(rows, cols)

    logs.append("Post-processing anomaly map...")
    smoothed_map = post_process_map(raw_map, IMAGE_RES)
    normalized_map = min_max_norm(smoothed_map)
    overlay = blend_visualization(test_image, normalized_map)
    logs.append("Segmentation complete.")
    return overlay, "\n".join(logs), current_state
def warmup():
    """Load the shared extractor and return the resulting log text."""
    messages = ["Initializing model on server..."]
    get_extractor(messages)
    log_text = "\n".join(messages)
    return log_text
# Build the Gradio UI: inputs and examples on the left, output and log on the
# right, with the event handlers wired up at the end.
# NOTE(review): the original indentation was lost, so the nesting of the
# layout containers below is reconstructed — confirm it matches the intended
# layout (in particular where the accordion and the log box belong).
with gr.Blocks(title="SubspaceAD – One-Shot Anomaly Segmentation") as demo:
    gr.Markdown(
        """
        # SubspaceAD – One-Shot Anomaly Segmentation (Demo)
        Upload a normal reference image and a test image.
        SubspaceAD fits a PCA subspace over DINOv2 patch embeddings and highlights deviations.
        """
    )
    # Use a copy so the dict object isn't shared unexpectedly
    pca_state = gr.State(INITIAL_STATE.copy())
    with gr.Row():
        # Left column: the two image inputs, the run button and examples.
        with gr.Column(scale=2):
            gr.Markdown("### Reference – define normal appearance")
            ref_image_input = gr.Image(label="Reference image (normal)", type="pil", height=448)
            gr.Markdown("### Test – segment anomalies")
            test_image_input = gr.Image(label="Test image (normal or anomalous)", type="pil", height=448)
            segment_button = gr.Button("Run anomaly segmentation")
            gr.Markdown("### Try it instantly – click an example")
            # Each example row fills the reference and the test input together.
            gr.Examples(
                examples=[
                    ["./assets/example_hazelnut_ref.png", "./assets/example_hazelnut_test.png"],
                    ["./assets/example_bottle_ref.png", "./assets/example_bottle_test.png"],
                ],
                inputs=[ref_image_input, test_image_input],
                label="MVTec-AD Examples"
            )
        # Right column: the overlay result, static figures and the log.
        with gr.Column(scale=3):
            gr.Markdown("### Output")
            output_image = gr.Image(
                label="Anomaly overlay (448×448; red/yellow ≈ high anomaly)",
                type="pil",
                height=448,
            )
            with gr.Accordion("Paper qualitative examples", open=False):
                gr.Image("./assets/mvtec_examples.png", interactive=False)
                gr.Image("./assets/visa_examples.png", interactive=False)
            status_box = gr.Textbox(
                label="Log",
                value="Model is initializing. Upload images or click the hazelnut example.",
                lines=8,
            )
    # Load the backbone when the page opens and show the result in the log.
    demo.load(fn=warmup, inputs=None, outputs=status_box)
    # The state is both an input and an output so the fitted model is kept
    # between clicks.
    segment_button.click(
        fn=segment_anomaly,
        inputs=[test_image_input, ref_image_input, pca_state],
        outputs=[output_image, status_box, pca_state],
    )
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    # NOTE(review): the theme is passed to launch() here — confirm that the
    # installed Gradio version accepts it there rather than on gr.Blocks().
    demo.launch(theme=gr.themes.Soft())