# SentinelWatch / models / change_detector.py
# (Hugging Face page residue removed: uploader "VishaliniS456",
#  commit "Upload 8 files", hash 9875bf8 — kept here as a comment
#  so the module parses as valid Python.)
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ViTModel, ViTImageProcessor
import numpy as np
from typing import Tuple, Optional
import cv2
class ChangeDetector:
    """
    Change detection model using a Siamese ViT architecture.

    Detects changes between two temporal satellite images by encoding
    both with a shared (frozen, pre-trained) ViT backbone, scoring each
    (before, after) patch-token pair with a small MLP head, and
    upsampling the patch-level scores to a full-resolution confidence
    map and binary change mask.

    NOTE(review): ``patch_head`` is randomly initialized here and no
    checkpoint load is visible in this file — its scores are only
    meaningful once the head has been trained. Confirm weights are
    loaded elsewhere.
    """

    def __init__(
        self,
        model_name: str = "google/vit-base-patch16-224",
        device: Optional[str] = None
    ):
        """
        Args:
            model_name: Hugging Face model id of the ViT backbone.
            device: torch device string; auto-selects CUDA when available.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = model_name
        self.processor = ViTImageProcessor.from_pretrained(model_name)
        self.encoder = ViTModel.from_pretrained(model_name)
        self.encoder.to(self.device)
        self.encoder.eval()
        hidden_size = self.encoder.config.hidden_size
        # Lightweight head scoring each (before, after) patch-token pair.
        self.patch_head = nn.Sequential(
            nn.Linear(hidden_size * 2, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        ).to(self.device)
        # BUG FIX: the head must be in eval mode for inference —
        # otherwise the Dropout layer stays active (Modules default to
        # train mode) and the change scores are non-deterministic.
        self.patch_head.eval()

    def _encode_patches(self, image: np.ndarray) -> torch.Tensor:
        """
        Encode an image and return ALL patch tokens (CLS excluded).

        Args:
            image: HxWx3 array; either uint8 in [0, 255] or float
                assumed to lie in [0, 1] (float values are clipped to
                [0, 1] before scaling — TODO confirm callers pass
                normalized floats, not floats in [0, 255]).

        Returns:
            Tensor of shape (num_patches, hidden_size) on ``self.device``.
        """
        # The HF image processor expects uint8 pixel values in [0, 255].
        if image.dtype != np.uint8:
            img_uint8 = (np.clip(image, 0, 1) * 255).astype(np.uint8)
        else:
            img_uint8 = image
        inputs = self.processor(images=img_uint8, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.encoder(**inputs)
        # last_hidden_state: (1, 1 + num_patches, hidden_size);
        # index 0 is the CLS token, the remainder are patch tokens.
        patch_tokens = outputs.last_hidden_state[0, 1:, :]  # (num_patches, H)
        return patch_tokens

    def detect_changes(
        self,
        before_image: np.ndarray,
        after_image: np.ndarray,
        threshold: float = 0.5
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Detect changes between two temporal images.

        Args:
            before_image: earlier image, HxWx3.
            after_image: later image, HxWx3 (presumably co-registered
                with ``before_image`` — TODO confirm against callers).
            threshold: confidence cutoff for the binary mask.

        Returns:
            change_mask: 2D uint8 binary array (H, W) matching the
                *before* image's spatial size.
            confidence_map: 2D float32 array (H, W) in [0, 1].
        """
        h, w = before_image.shape[:2]
        before_patches = self._encode_patches(before_image)  # (N, D)
        after_patches = self._encode_patches(after_image)    # (N, D)
        # Concatenate patch-wise so the head sees both epochs per patch.
        combined = torch.cat([before_patches, after_patches], dim=-1)  # (N, 2D)
        with torch.no_grad():
            patch_scores = self.patch_head(combined).squeeze(-1)  # (N,)
        patch_scores_np = patch_scores.cpu().numpy()  # shape (num_patches,)
        # ViT-base/16 on 224x224 -> 14x14 = 196 patches.
        n = patch_scores_np.shape[0]
        grid = int(np.sqrt(n))
        if grid * grid != n:
            # Fallback: zero-pad to the nearest square so reshape works.
            # Explicit dtype avoids silently promoting to float64.
            grid = int(np.ceil(np.sqrt(n)))
            pad = grid * grid - n
            patch_scores_np = np.concatenate(
                [patch_scores_np, np.zeros(pad, dtype=patch_scores_np.dtype)]
            )
        patch_map = patch_scores_np.reshape(grid, grid)
        # Upsample the patch-level map to the original image size.
        # Note: cv2.resize takes (width, height), not (height, width).
        confidence_map = cv2.resize(
            patch_map.astype(np.float32),
            (w, h),
            interpolation=cv2.INTER_LINEAR
        )
        confidence_map = np.clip(confidence_map, 0.0, 1.0)
        # Threshold to a binary mask.
        change_mask = (confidence_map > threshold).astype(np.uint8)
        # Morphological opening removes small isolated noise regions.
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        change_mask = cv2.morphologyEx(change_mask, cv2.MORPH_OPEN, kernel, iterations=1)
        return change_mask, confidence_map

    def batch_detect_changes(
        self,
        before_images: np.ndarray,
        after_images: np.ndarray,
        threshold: float = 0.5
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Run :meth:`detect_changes` over paired batches of images.

        Args:
            before_images: sequence/array of earlier images.
            after_images: sequence/array of later images; pairs are
                formed positionally and any surplus in the longer
                sequence is ignored (``zip`` semantics).
            threshold: confidence cutoff for the binary masks.

        Returns:
            Tuple ``(masks, confidence_maps)``, each stacked along a
            new leading batch axis.
        """
        masks, confidences = [], []
        for before, after in zip(before_images, after_images):
            mask, conf = self.detect_changes(before, after, threshold)
            masks.append(mask)
            confidences.append(conf)
        return np.array(masks), np.array(confidences)