import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ViTModel, ViTImageProcessor
import numpy as np
from typing import Tuple, Optional
import cv2


class ChangeDetector:
    """
    Change detection model using a Siamese ViT architecture.

    Both temporal images are encoded with the same frozen ViT backbone; the
    per-patch features from the two passes are concatenated pairwise and
    scored by a small MLP head, producing a patch-level change score that is
    upsampled to a full-resolution confidence map and thresholded into a
    binary change mask.

    NOTE(review): `patch_head` is randomly initialized here and never trained
    in this file — its scores are only meaningful once trained weights are
    loaded. Confirm against the training/checkpoint pipeline.
    """

    def __init__(
        self,
        model_name: str = "google/vit-base-patch16-224",
        device: Optional[str] = None
    ):
        """
        Args:
            model_name: HuggingFace hub id of the ViT backbone.
            device: torch device string; auto-selects CUDA when available.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = model_name

        self.processor = ViTImageProcessor.from_pretrained(model_name)
        self.encoder = ViTModel.from_pretrained(model_name)
        self.encoder.to(self.device)
        self.encoder.eval()

        hidden_size = self.encoder.config.hidden_size

        # Lightweight head scoring each (before, after) patch-token pair.
        self.patch_head = nn.Sequential(
            nn.Linear(hidden_size * 2, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        ).to(self.device)
        # BUG FIX: the head must be switched to eval mode like the encoder;
        # otherwise Dropout(0.2) stays active at inference time and the
        # change scores are non-deterministic from call to call.
        self.patch_head.eval()

    def _encode_patches(self, image: np.ndarray) -> torch.Tensor:
        """
        Encode one image and return ALL patch tokens (CLS excluded).

        Args:
            image: HxWxC array; uint8 in [0, 255], or float assumed in [0, 1].

        Returns:
            Tensor of shape (num_patches, hidden_size) on self.device.
        """
        # The processor expects uint8 [0, 255]; float inputs are assumed to
        # be normalized to [0, 1] and are clipped before conversion.
        if image.dtype != np.uint8:
            img_uint8 = (np.clip(image, 0, 1) * 255).astype(np.uint8)
        else:
            img_uint8 = image

        inputs = self.processor(images=img_uint8, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.encoder(**inputs)

        # last_hidden_state: (1, 1 + num_patches, hidden_size);
        # index 0 is the CLS token, the remainder are patch tokens.
        return outputs.last_hidden_state[0, 1:, :]

    def detect_changes(
        self,
        before_image: np.ndarray,
        after_image: np.ndarray,
        threshold: float = 0.5
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Detect changes between two temporal images.

        Args:
            before_image: earlier image, HxWxC.
            after_image: later image of the same scene.
            threshold: confidence cutoff in (0, 1) for the binary mask.

        Returns:
            change_mask: 2D uint8 binary array (H, W) matching before_image.
            confidence_map: 2D float32 array (H, W) in [0, 1].
        """
        h, w = before_image.shape[:2]

        before_patches = self._encode_patches(before_image)  # (N, D)
        after_patches = self._encode_patches(after_image)    # (N, D)

        # Score each spatially-aligned patch pair.
        combined = torch.cat([before_patches, after_patches], dim=-1)  # (N, 2D)
        with torch.no_grad():
            patch_scores = self.patch_head(combined).squeeze(-1)  # (N,)
        patch_scores_np = patch_scores.cpu().numpy()  # shape (num_patches,)

        # ViT-base/16 on 224x224 → 14x14 = 196 patches; reshape to a grid.
        n = patch_scores_np.shape[0]
        grid = int(np.sqrt(n))
        if grid * grid != n:
            # Fallback: zero-pad up to the nearest square grid.
            grid = int(np.ceil(np.sqrt(n)))
            pad = grid * grid - n
            # Keep the score dtype (np.zeros defaults to float64, which
            # would silently upcast the float32 score vector).
            patch_scores_np = np.concatenate(
                [patch_scores_np, np.zeros(pad, dtype=patch_scores_np.dtype)]
            )
        patch_map = patch_scores_np.reshape(grid, grid)

        # Upsample the patch-level map to the original image size.
        # cv2.resize takes dsize as (width, height).
        confidence_map = cv2.resize(
            patch_map.astype(np.float32),
            (w, h),
            interpolation=cv2.INTER_LINEAR
        )
        confidence_map = np.clip(confidence_map, 0.0, 1.0)

        # Threshold to a binary mask.
        change_mask = (confidence_map > threshold).astype(np.uint8)

        # Morphological opening removes small speckle noise from the mask.
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        change_mask = cv2.morphologyEx(change_mask, cv2.MORPH_OPEN, kernel, iterations=1)

        return change_mask, confidence_map

    def batch_detect_changes(
        self,
        before_images: np.ndarray,
        after_images: np.ndarray,
        threshold: float = 0.5
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Run detect_changes over paired sequences of images.

        Args:
            before_images: iterable/stack of earlier images.
            after_images: iterable/stack of later images (pairs with before).
            threshold: confidence cutoff forwarded to detect_changes.

        Returns:
            Stacked (masks, confidence_maps) arrays. All images must share
            one size for np.array stacking to produce a dense array.
        """
        masks, confidences = [], []
        for before, after in zip(before_images, after_images):
            mask, conf = self.detect_changes(before, after, threshold)
            masks.append(mask)
            confidences.append(conf)
        return np.array(masks), np.array(confidences)