Spaces:
Running
Running
Upload 8 files
Browse files- models/__init__.py +4 -0
- models/change_detector.py +129 -0
- models/cloud_detector.py +86 -0
- utils/__init__.py +16 -0
- utils/evaluation.py +54 -0
- utils/metrics.py +95 -0
- utils/preprocessing.py +79 -0
- utils/visualization.py +96 -0
models/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .cloud_detector import CloudDetector
|
| 2 |
+
from .change_detector import ChangeDetector
|
| 3 |
+
|
| 4 |
+
__all__ = ["CloudDetector", "ChangeDetector"]
|
models/change_detector.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from transformers import ViTModel, ViTImageProcessor
|
| 5 |
+
import numpy as np
|
| 6 |
+
from typing import Tuple, Optional
|
| 7 |
+
import cv2
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class ChangeDetector:
    """
    Change detection model using a Siamese ViT architecture.

    Both temporal images are passed through the same pretrained ViT encoder;
    per-patch token pairs are concatenated and scored by a small MLP head,
    and the resulting patch grid is upsampled to a full-resolution 2D
    confidence map plus a binary change mask.

    NOTE(review): ``patch_head`` is randomly initialized below and no weight
    loading or training step is visible in this module — until trained
    weights are provided, its per-patch scores are uncalibrated.
    """

    def __init__(
        self,
        model_name: str = "google/vit-base-patch16-224",
        device: Optional[str] = None
    ):
        # Prefer the caller's device; otherwise use CUDA when available.
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = model_name

        # Pretrained ViT backbone, used as a frozen (eval-mode) feature
        # extractor; gradients are never required in this class.
        self.processor = ViTImageProcessor.from_pretrained(model_name)
        self.encoder = ViTModel.from_pretrained(model_name)
        self.encoder.to(self.device)
        self.encoder.eval()

        hidden_size = self.encoder.config.hidden_size
        # Lightweight head to score each patch token pair (before ⊕ after
        # concatenation doubles the feature width, hence hidden_size * 2).
        self.patch_head = nn.Sequential(
            nn.Linear(hidden_size * 2, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        ).to(self.device)

    def _encode_patches(self, image: np.ndarray) -> torch.Tensor:
        """
        Encode the image and return ALL patch tokens (not just CLS).

        Args:
            image: Input image array; uint8 [0,255] or float assumed in [0,1]
                (floats are clipped to [0,1] before conversion).

        Returns:
            Tensor of shape (num_patches, hidden_size)
        """
        # Ensure uint8 [0,255] for processor
        if image.dtype != np.uint8:
            img_uint8 = (np.clip(image, 0, 1) * 255).astype(np.uint8)
        else:
            img_uint8 = image

        # NOTE(review): the processor presumably resizes/normalises to the
        # model's native input (224x224 for this checkpoint), which is what
        # keeps the before/after patch grids aligned — confirm against the
        # ViTImageProcessor defaults.
        inputs = self.processor(images=img_uint8, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.encoder(**inputs)
            # last_hidden_state: (1, 1+num_patches, hidden_size)
            # index 0 is CLS, 1: are patch tokens
            patch_tokens = outputs.last_hidden_state[0, 1:, :]  # (num_patches, H)

        return patch_tokens

    def detect_changes(
        self,
        before_image: np.ndarray,
        after_image: np.ndarray,
        threshold: float = 0.5
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Detect changes between two temporal images.

        Args:
            before_image: Earlier image (H, W, 3)
            after_image: Later image; assumed co-registered with before_image
            threshold: Confidence cutoff in [0, 1] for the binary mask

        Returns:
            change_mask: 2D binary array (H, W) matching input image size
            confidence_map: 2D float array (H, W) in [0, 1]
        """
        # Output size is taken from the *before* image; the after image is
        # assumed to share it.
        h, w = before_image.shape[:2]

        before_patches = self._encode_patches(before_image)  # (N, D)
        after_patches = self._encode_patches(after_image)  # (N, D)

        # Concatenate patch-wise features
        combined = torch.cat([before_patches, after_patches], dim=-1)  # (N, 2D)

        with torch.no_grad():
            patch_scores = self.patch_head(combined).squeeze(-1)  # (N,)

        patch_scores_np = patch_scores.cpu().numpy()  # shape (num_patches,)

        # ViT-base/16 on 224x224 → 14x14 = 196 patches
        n = patch_scores_np.shape[0]
        grid = int(np.sqrt(n))
        if grid * grid != n:
            # Fallback: pad to nearest square
            # (zero-padding means padded cells read as "no change")
            grid = int(np.ceil(np.sqrt(n)))
            pad = grid * grid - n
            patch_scores_np = np.concatenate([patch_scores_np, np.zeros(pad)])

        patch_map = patch_scores_np.reshape(grid, grid)

        # Upsample patch-level map to original image size
        # (cv2.resize takes (width, height), hence (w, h))
        confidence_map = cv2.resize(
            patch_map.astype(np.float32),
            (w, h),
            interpolation=cv2.INTER_LINEAR
        )
        confidence_map = np.clip(confidence_map, 0.0, 1.0)

        # Threshold to binary mask
        change_mask = (confidence_map > threshold).astype(np.uint8)

        # Morphological cleanup to reduce noise
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        change_mask = cv2.morphologyEx(change_mask, cv2.MORPH_OPEN, kernel, iterations=1)

        return change_mask, confidence_map

    def batch_detect_changes(
        self,
        before_images: np.ndarray,
        after_images: np.ndarray,
        threshold: float = 0.5
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Run detect_changes over paired image batches.

        Args:
            before_images: Iterable/array of earlier images
            after_images: Iterable/array of later images (same length)
            threshold: Confidence cutoff passed through to detect_changes

        Returns:
            Stacked (masks, confidence_maps); stacking assumes all images
            share the same spatial size.
        """
        masks, confidences = [], []
        for b, a in zip(before_images, after_images):
            mask, conf = self.detect_changes(b, a, threshold)
            masks.append(mask)
            confidences.append(conf)
        return np.array(masks), np.array(confidences)
|
models/cloud_detector.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from typing import Tuple, Optional
|
| 3 |
+
import cv2
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class CloudDetector:
    """
    Heuristic cloud detector for RGB satellite imagery.

    Each pixel is scored from two cues:
      * high brightness (RGB values close to white), and
      * low saturation (little colour),
    after which the thresholded mask is cleaned with a small
    morphological opening to suppress speckle noise.
    """

    def __init__(self, device: Optional[str] = None):
        # Kept for interface symmetry with model-based detectors; this
        # detector is pure NumPy/OpenCV and effectively always runs on CPU.
        self.device = device or "cpu"

    @staticmethod
    def _as_float_rgb(image: np.ndarray) -> np.ndarray:
        """Return *image* as float32 RGB in [0, 1], replicating grayscale."""
        if image.dtype == np.uint8:
            scaled = image.astype(np.float32) / 255.0
        else:
            scaled = np.clip(image, 0, 1).astype(np.float32)

        if scaled.ndim == 2:
            scaled = np.stack([scaled, scaled, scaled], axis=-1)
        elif scaled.shape[2] == 1:
            scaled = np.concatenate([scaled, scaled, scaled], axis=-1)
        return scaled

    def detect_clouds(
        self,
        image: np.ndarray,
        threshold: float = 0.5
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Detect clouds in a satellite image.

        Args:
            image: Input image (H, W, 3), uint8 [0,255] or float [0,1]
            threshold: Cloud confidence threshold (0-1)

        Returns:
            cloud_mask: 2D binary array (H, W)
            cloud_confidence: 2D float array (H, W) in [0,1]
        """
        rgb = self._as_float_rgb(image)

        red = rgb[:, :, 0]
        green = rgb[:, :, 1]
        blue = rgb[:, :, 2]

        # Brightness cue: per-pixel mean of the three channels.
        brightness = (red + green + blue) / 3.0

        # Saturation cue (HSV-style), guarded for all-black pixels.
        max_rgb = np.maximum(np.maximum(red, green), blue)
        min_rgb = np.minimum(np.minimum(red, green), blue)
        saturation = np.where(
            max_rgb > 0,
            (max_rgb - min_rgb) / (max_rgb + 1e-8),
            0.0
        )

        # Bright pixels score high; strongly coloured pixels score low.
        brightness_score = np.clip((brightness - 0.4) / 0.6, 0, 1)
        saturation_score = np.clip((0.4 - saturation) / 0.4, 0, 1)

        cloud_confidence = (0.6 * brightness_score + 0.4 * saturation_score).astype(np.float32)

        # Binarise, then open with a small ellipse to drop isolated noise.
        cloud_mask = (cloud_confidence > threshold).astype(np.uint8)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
        cloud_mask = cv2.morphologyEx(cloud_mask, cv2.MORPH_OPEN, kernel, iterations=1)

        return cloud_mask, cloud_confidence

    def batch_detect(
        self,
        images: np.ndarray,
        threshold: float = 0.5
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Run detect_clouds over a batch; returns stacked masks and maps."""
        results = [self.detect_clouds(img, threshold) for img in images]
        stacked_masks = np.array([mask for mask, _ in results])
        stacked_confs = np.array([conf for _, conf in results])
        return stacked_masks, stacked_confs
|
utils/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utility modules for satellite change detection."""
|
| 2 |
+
|
| 3 |
+
from .preprocessing import preprocess_image, mask_clouds
|
| 4 |
+
from .visualization import create_overlay, visualize_predictions
|
| 5 |
+
from .evaluation import calculate_metrics
|
| 6 |
+
from .metrics import calculate_change_statistics, compare_with_without_masking
|
| 7 |
+
|
| 8 |
+
__all__ = [
|
| 9 |
+
"preprocess_image",
|
| 10 |
+
"mask_clouds",
|
| 11 |
+
"create_overlay",
|
| 12 |
+
"visualize_predictions",
|
| 13 |
+
"calculate_metrics",
|
| 14 |
+
"calculate_change_statistics",
|
| 15 |
+
"compare_with_without_masking",
|
| 16 |
+
]
|
utils/evaluation.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Evaluation metrics for change and cloud detection."""
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
from typing import Dict, Optional
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def calculate_metrics(
    pred_mask: np.ndarray,
    gt_mask: np.ndarray,
    threshold: float = 0.5
) -> Dict[str, float]:
    """
    Calculate pixel-level classification metrics.

    Args:
        pred_mask: Predicted binary mask (H, W) or confidence map (H, W)
        gt_mask: Ground truth binary mask (H, W)
        threshold: Threshold to binarise pred_mask if it's a confidence map

    Returns:
        Dict with keys: accuracy, precision, recall, f1, iou, plus the
        raw confusion counts tp, tn, fp, fn
    """
    # A uint8 array whose values are already in {0, 1} is used as-is;
    # anything else (floats, or uint8 with values > 1) gets thresholded.
    if pred_mask.dtype == np.uint8 and pred_mask.max() <= 1:
        pred = pred_mask.astype(np.uint8)
    else:
        pred = (pred_mask > threshold).astype(np.uint8)

    gt = (gt_mask > 0).astype(np.uint8)

    pred_pos = pred == 1
    gt_pos = gt == 1
    tp = int(np.sum(pred_pos & gt_pos))
    tn = int(np.sum(~pred_pos & ~gt_pos))
    fp = int(np.sum(pred_pos & ~gt_pos))
    fn = int(np.sum(~pred_pos & gt_pos))

    # Each ratio guards its own zero denominator and falls back to 0.0.
    total = tp + tn + fp + fn
    accuracy = (tp + tn) / total if total > 0 else 0.0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom > 0 else 0.0
    iou = tp / (tp + fp + fn) if (tp + fp + fn) > 0 else 0.0

    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "iou": iou,
        "tp": tp,
        "tn": tn,
        "fp": fp,
        "fn": fn,
    }
|
utils/metrics.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Advanced comparison metrics for change detection evaluation."""
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
from typing import Dict, Optional
|
| 5 |
+
from .evaluation import calculate_metrics
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def calculate_change_statistics(
    change_mask: np.ndarray,
    change_confidence: np.ndarray
) -> Dict:
    """
    Calculate statistics from a 2D change mask and confidence map.

    Args:
        change_mask: Binary 2D array (H, W), 1 = changed
        change_confidence: Float 2D array (H, W) in [0, 1]

    Returns:
        Dict with total_pixels, changed_pixels, unchanged_pixels,
        change_percentage, mean_confidence, min_confidence, max_confidence,
        change_confidence_mean (mean conf among changed pixels only)
    """
    binary = change_mask.astype(np.uint8)
    scores = change_confidence.astype(np.float32)

    n_total = int(binary.size)
    n_changed = int(np.sum(binary == 1))
    pct = 100.0 * n_changed / n_total if n_total > 0 else 0.0

    # Mean confidence restricted to changed pixels; 0.0 when nothing changed
    # (an empty slice would otherwise produce NaN).
    conf_over_changed = float(scores[binary == 1].mean()) if n_changed > 0 else 0.0

    return {
        "total_pixels": n_total,
        "changed_pixels": n_changed,
        "unchanged_pixels": n_total - n_changed,
        "change_percentage": pct,
        "mean_confidence": float(scores.mean()),
        "min_confidence": float(scores.min()),
        "max_confidence": float(scores.max()),
        "change_confidence_mean": conf_over_changed,
    }
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def compare_with_without_masking(
    pred_with_mask: np.ndarray,
    pred_without_mask: np.ndarray,
    gt_mask: Optional[np.ndarray] = None
) -> Dict:
    """
    Compare detection results with and without cloud masking.

    Args:
        pred_with_mask: Change mask produced WITH cloud masking (H, W)
        pred_without_mask: Change mask produced WITHOUT cloud masking (H, W)
        gt_mask: Optional ground truth mask for metric computation (H, W)

    Returns:
        Dict with pixel-level comparison and optional metric differences
    """
    total = int(pred_with_mask.size)
    agreement = int(np.sum(pred_with_mask == pred_without_mask))

    result = {
        "agreement_pixels": agreement,
        "total_pixels": total,
        "agreement_percentage": 100.0 * agreement / total if total > 0 else 0.0,
        "changed_with_mask": int(np.sum(pred_with_mask)),
        "changed_without_mask": int(np.sum(pred_without_mask)),
    }

    if gt_mask is None:
        return result

    # With ground truth available, report IoU/F1 for both variants plus
    # the delta attributable to cloud masking.
    metrics_with = calculate_metrics(pred_with_mask, gt_mask)
    metrics_without = calculate_metrics(pred_without_mask, gt_mask)

    for name in ("iou", "f1"):
        result[f"{name}_with_mask"] = metrics_with[name]
        result[f"{name}_without_mask"] = metrics_without[name]
        result[f"{name}_improvement"] = metrics_with[name] - metrics_without[name]

    return result
|
utils/preprocessing.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Preprocessing utilities for satellite imagery."""
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import cv2
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def preprocess_image(
    image: np.ndarray,
    target_size: Optional[tuple] = None,
    normalize: bool = True
) -> np.ndarray:
    """
    Preprocess a satellite image for model input.

    Args:
        image: Input image (H, W, C), uint8 or float
        target_size: Optional (width, height) to resize to
        normalize: If True, output is float32 in [0, 1]

    Returns:
        Preprocessed image as float32 [0,1] or uint8 [0,255]
    """
    if image is None:
        raise ValueError("Input image is None")

    img = image.copy()

    # Force exactly three channels: replicate grayscale, drop extras (alpha).
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)
    elif img.shape[2] == 1:
        img = np.concatenate([img] * 3, axis=-1)
    elif img.shape[2] > 3:
        img = img[:, :, :3]

    # Resize if requested; cv2 expects (width, height).
    if target_size is not None:
        img = cv2.resize(img, target_size, interpolation=cv2.INTER_LINEAR)

    # Convert dtype last. Float inputs are assumed to already live in [0, 1]
    # and are clipped to that range before any scaling.
    is_uint8 = img.dtype == np.uint8
    if normalize:
        if is_uint8:
            img = img.astype(np.float32) / 255.0
        else:
            img = np.clip(img, 0, 1).astype(np.float32)
    elif not is_uint8:
        img = (np.clip(img, 0, 1) * 255).astype(np.uint8)

    return img
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def mask_clouds(
    image: np.ndarray,
    cloud_mask: np.ndarray,
    fill_value: float = 0.0
) -> np.ndarray:
    """
    Apply cloud mask to image, replacing cloud pixels with fill_value.

    Args:
        image: Input image (H, W, C) or grayscale (H, W)
        cloud_mask: Binary mask (H, W), 1 = cloud
        fill_value: Value to fill masked pixels with

    Returns:
        Masked image, same dtype as input

    Notes:
        Previously this looped over channels with ``masked[:, :, c][mask]``,
        which both was slower than necessary and raised IndexError on 2-D
        grayscale input. NumPy boolean-mask assignment broadcasts the fill
        value across the channel axis in a single statement and works for
        both (H, W) and (H, W, C) inputs.
    """
    # astype with copy=True gives us a private float working buffer.
    masked = image.astype(np.float32, copy=True)
    mask_bool = cloud_mask.astype(bool)

    # A 2-D boolean mask selects whole pixels; for (H, W, C) the scalar
    # fill broadcasts across all channels of each selected pixel.
    masked[mask_bool] = fill_value

    # Restore the caller's dtype for uint8 inputs (clip guards against
    # out-of-range fill values).
    if image.dtype == np.uint8:
        masked = np.clip(masked, 0, 255).astype(np.uint8)

    return masked
|
utils/visualization.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Visualization utilities for satellite change detection."""
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import cv2
|
| 5 |
+
from typing import Tuple, Optional
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def create_overlay(
    image: np.ndarray,
    mask: np.ndarray,
    alpha: float = 0.5,
    color: Tuple[int, int, int] = (255, 0, 0)
) -> np.ndarray:
    """
    Overlay a binary mask on an image with transparency.

    Args:
        image: Base image (H, W, 3), uint8 or float [0,1]
        mask: Binary mask (H, W), values 0 or 1
        alpha: Overlay transparency (0 = invisible, 1 = opaque)
        color: RGB color for the overlay

    Returns:
        Blended image as uint8 (H, W, 3)
    """
    # Bring the base image to uint8 [0, 255].
    if image.dtype == np.uint8:
        base = image.copy()
    else:
        base = (np.clip(image, 0, 1) * 255).astype(np.uint8)

    # Promote grayscale to 3 channels so colour painting works.
    if base.ndim == 2:
        base = cv2.cvtColor(base, cv2.COLOR_GRAY2RGB)

    # Paint the masked region solid, then alpha-blend against the base.
    painted = base.copy()
    painted[mask.astype(bool)] = [color[0], color[1], color[2]]

    blended = cv2.addWeighted(painted, alpha, base, 1 - alpha, 0)
    return blended.astype(np.uint8)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def visualize_predictions(
    image: np.ndarray,
    pred_mask: np.ndarray,
    gt_mask: Optional[np.ndarray] = None,
    confidence: Optional[np.ndarray] = None
) -> np.ndarray:
    """
    Create a side-by-side visualization of image, prediction, and (optionally) ground truth.

    Args:
        image: Original image (H, W, 3)
        pred_mask: Predicted binary mask (H, W)
        gt_mask: Optional ground truth mask (H, W)
        confidence: Optional confidence map (H, W)

    Returns:
        Combined visualization as uint8 (H, W*N, 3)
    """
    # Base panel in uint8.
    if image.dtype == np.uint8:
        img_u8 = image.copy()
    else:
        img_u8 = (np.clip(image, 0, 1) * 255).astype(np.uint8)

    h, w = img_u8.shape[:2]

    # Original + prediction overlay (red) are always shown.
    panels = [
        img_u8,
        create_overlay(img_u8, pred_mask, alpha=0.5, color=(255, 0, 0)),
    ]

    # Ground truth overlay (green), when provided.
    if gt_mask is not None:
        panels.append(create_overlay(img_u8, gt_mask, alpha=0.5, color=(0, 255, 0)))

    # Confidence as a JET heatmap (converted back to RGB ordering).
    if confidence is not None:
        conf_u8 = (np.clip(confidence, 0, 1) * 255).astype(np.uint8)
        jet_bgr = cv2.applyColorMap(conf_u8, cv2.COLORMAP_JET)
        panels.append(cv2.cvtColor(jet_bgr, cv2.COLOR_BGR2RGB))

    # Normalise panel sizes before horizontal concatenation.
    uniform = []
    for panel in panels:
        if panel.shape[:2] != (h, w):
            panel = cv2.resize(panel, (w, h), interpolation=cv2.INTER_LINEAR)
        uniform.append(panel)

    return np.concatenate(uniform, axis=1)
|