"""Siamese ViT change detection for pairs of temporal satellite images."""
from typing import Optional, Tuple

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ViTImageProcessor, ViTModel
class ChangeDetector:
    """
    Change detection model using a Siamese ViT architecture.

    Both temporal images are encoded by the same frozen ViT backbone; each
    pair of corresponding patch tokens is scored by a small MLP head, and
    the patch-level scores are upsampled to the input resolution, yielding
    a 2D confidence map and a binary change mask.

    NOTE(review): ``patch_head`` is randomly initialized here — its weights
    must be trained or loaded elsewhere before the scores are meaningful.
    """

    def __init__(
        self,
        model_name: str = "google/vit-base-patch16-224",
        device: Optional[str] = None
    ):
        """
        Args:
            model_name: HuggingFace model id of the ViT backbone.
            device: Torch device string; defaults to CUDA when available.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = model_name
        self.processor = ViTImageProcessor.from_pretrained(model_name)
        self.encoder = ViTModel.from_pretrained(model_name)
        self.encoder.to(self.device)
        self.encoder.eval()
        hidden_size = self.encoder.config.hidden_size
        # Lightweight head scoring each (before, after) patch-token pair.
        self.patch_head = nn.Sequential(
            nn.Linear(hidden_size * 2, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        ).to(self.device)
        # BUG FIX: the head was previously left in train mode, so Dropout
        # stayed active at inference, making detect_changes() output
        # non-deterministic. Inference-only usage requires eval mode.
        self.patch_head.eval()

    def _encode_patches(self, image: np.ndarray) -> torch.Tensor:
        """
        Encode the image and return ALL patch tokens (CLS excluded).

        Args:
            image: HxWxC array; uint8 in [0, 255], or float assumed to be
                in [0, 1] (it is clipped and rescaled to uint8).

        Returns:
            Tensor of shape (num_patches, hidden_size) on ``self.device``.
        """
        # The HF processor expects uint8 [0, 255]; float inputs are assumed
        # normalized to [0, 1] and converted accordingly.
        if image.dtype != np.uint8:
            img_uint8 = (np.clip(image, 0, 1) * 255).astype(np.uint8)
        else:
            img_uint8 = image
        inputs = self.processor(images=img_uint8, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.encoder(**inputs)
        # last_hidden_state: (1, 1 + num_patches, hidden_size);
        # index 0 is the CLS token, the rest are patch tokens.
        patch_tokens = outputs.last_hidden_state[0, 1:, :]  # (num_patches, H)
        return patch_tokens

    def detect_changes(
        self,
        before_image: np.ndarray,
        after_image: np.ndarray,
        threshold: float = 0.5
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Detect changes between two temporal images.

        Args:
            before_image: Earlier image, HxWxC.
            after_image: Later image, HxWxC (resized by the processor, so
                it need not match ``before_image`` exactly).
            threshold: Confidence cutoff in [0, 1] for the binary mask.

        Returns:
            change_mask: 2D uint8 binary array (H, W) matching
                ``before_image``'s spatial size.
            confidence_map: 2D float32 array (H, W) in [0, 1].
        """
        h, w = before_image.shape[:2]
        before_patches = self._encode_patches(before_image)  # (N, D)
        after_patches = self._encode_patches(after_image)    # (N, D)
        # Concatenate patch-wise features from both dates.
        combined = torch.cat([before_patches, after_patches], dim=-1)  # (N, 2D)
        with torch.no_grad():
            patch_scores = self.patch_head(combined).squeeze(-1)  # (N,)
        patch_scores_np = patch_scores.cpu().numpy()  # shape (num_patches,)
        # ViT-base/16 on 224x224 → 14x14 = 196 patches.
        n = patch_scores_np.shape[0]
        grid = int(np.sqrt(n))
        if grid * grid != n:
            # Fallback: pad to the nearest square so reshape succeeds.
            # dtype is pinned to avoid silently upcasting float32 → float64.
            grid = int(np.ceil(np.sqrt(n)))
            pad = grid * grid - n
            patch_scores_np = np.concatenate(
                [patch_scores_np, np.zeros(pad, dtype=patch_scores_np.dtype)]
            )
        patch_map = patch_scores_np.reshape(grid, grid)
        # Upsample the patch-level map to the original image size.
        # cv2.resize takes dsize as (width, height).
        confidence_map = cv2.resize(
            patch_map.astype(np.float32),
            (w, h),
            interpolation=cv2.INTER_LINEAR
        )
        confidence_map = np.clip(confidence_map, 0.0, 1.0)
        # Threshold to a binary mask.
        change_mask = (confidence_map > threshold).astype(np.uint8)
        # Morphological opening to suppress isolated noisy detections.
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        change_mask = cv2.morphologyEx(change_mask, cv2.MORPH_OPEN, kernel, iterations=1)
        return change_mask, confidence_map

    def batch_detect_changes(
        self,
        before_images: np.ndarray,
        after_images: np.ndarray,
        threshold: float = 0.5
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Run :meth:`detect_changes` over paired batches of images.

        Extra images in the longer batch are ignored (``zip`` semantics).

        Returns:
            Tuple of stacked (masks, confidence_maps) arrays with a
            leading batch dimension.
        """
        masks, confidences = [], []
        for b, a in zip(before_images, after_images):
            mask, conf = self.detect_changes(b, a, threshold)
            masks.append(mask)
            confidences.append(conf)
        return np.array(masks), np.array(confidences)