# SentinelWatch / models/cloud_detector.py
# Uploaded by VishaliniS456 ("Upload 8 files", commit 9875bf8, verified).
import numpy as np
from typing import Tuple, Optional
import cv2
class CloudDetector:
    """
    Cloud detection for satellite imagery using brightness and saturation analysis.

    For RGB satellite images, clouds are detected based on:
    - High brightness (mean of the RGB channels near the top of the range)
    - Low saturation (near-white / grey appearance)
    - Spatial cleanup: a 3x3 morphological opening removes small speckles
      from the binary mask
    """

    def __init__(self, device: Optional[str] = None):
        # Stored for API compatibility; nothing in this class reads it --
        # all computation below runs in NumPy / OpenCV.
        self.device = device or "cpu"

    @staticmethod
    def _to_float_rgb(image: np.ndarray) -> np.ndarray:
        """Convert *image* to a float32 array in [0, 1] with at least 3 channels.

        - uint8 is divided by 255.
        - bool maps False/True to 0.0/1.0.
        - Other integer dtypes (e.g. uint16) are divided by the dtype's
          maximum value and clipped to [0, 1]. If integer data does not span
          the dtype's full range (e.g. 12-bit counts stored in uint16),
          convert it to float in [0, 1] before calling.
        - Floating-point input is assumed to be in [0, 1] and is clipped.

        A 2-D (grayscale) or single-channel image is replicated to three
        channels; images with more channels are returned unchanged.
        """
        image = np.asarray(image)
        if image.dtype == np.uint8:
            img = image.astype(np.float32) / 255.0
        elif image.dtype == np.bool_:
            img = image.astype(np.float32)
        elif np.issubdtype(image.dtype, np.integer):
            # Clipping raw integer values to [0, 1] (the float path) would
            # saturate every non-zero pixel to pure white and flag it as
            # cloud, so scale by the dtype's full range instead.
            scale = np.iinfo(image.dtype).max
            img = np.clip(image.astype(np.float32) / scale, 0, 1)
        else:
            img = np.clip(image, 0, 1).astype(np.float32)

        # Replicate grayscale input so the RGB formulas still apply.
        if img.ndim == 2:
            img = np.stack([img, img, img], axis=-1)
        elif img.shape[2] == 1:
            img = np.concatenate([img, img, img], axis=-1)
        return img

    def detect_clouds(
        self,
        image: np.ndarray,
        threshold: float = 0.5
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Detect clouds in a satellite image.

        Args:
            image: Input image of shape (H, W), (H, W, 1) or (H, W, C>=3).
                uint8 is read as [0, 255], other integer dtypes as
                [0, dtype max], floats as [0, 1] (out-of-range values are
                clipped). Only the first three channels are used.
            threshold: Cloud confidence threshold (0-1); pixels whose
                confidence is strictly greater than this are marked as cloud.

        Returns:
            cloud_mask: 2D uint8 array (H, W) of 0/1 values, after a 3x3
                morphological opening.
            cloud_confidence: 2D float32 array (H, W) in [0, 1]. It is
                computed before the opening, so it can exceed ``threshold``
                at pixels the opening removed from ``cloud_mask``.
        """
        img = self._to_float_rgb(image)
        red = img[:, :, 0]
        green = img[:, :, 1]
        blue = img[:, :, 2]

        # Brightness: mean of RGB channels.
        brightness = (red + green + blue) / 3.0

        # HSV-style saturation: (max - min) / max, defined as 0 for black.
        max_rgb = np.maximum(np.maximum(red, green), blue)
        min_rgb = np.minimum(np.minimum(red, green), blue)
        saturation = np.where(
            max_rgb > 0,
            # np.where evaluates both branches, so the epsilon is what
            # prevents a divide-by-zero warning on black pixels.
            (max_rgb - min_rgb) / (max_rgb + 1e-8),
            0.0
        )

        # Linear ramps: brightness scores 0 at 0.4 rising to 1 at 1.0;
        # saturation scores 0 at 0.4 rising to 1 at 0.0.
        brightness_score = np.clip((brightness - 0.4) / 0.6, 0, 1)
        saturation_score = np.clip((0.4 - saturation) / 0.4, 0, 1)
        cloud_confidence = (
            0.6 * brightness_score + 0.4 * saturation_score
        ).astype(np.float32)

        # Threshold, then open with a 3x3 elliptical kernel to drop small
        # speckles narrower than the kernel.
        cloud_mask = (cloud_confidence > threshold).astype(np.uint8)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
        cloud_mask = cv2.morphologyEx(cloud_mask, cv2.MORPH_OPEN, kernel, iterations=1)
        return cloud_mask, cloud_confidence

    def batch_detect(
        self,
        images: np.ndarray,
        threshold: float = 0.5
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Run ``detect_clouds`` on every image in ``images``.

        Args:
            images: Iterable of images (e.g. an (N, H, W, C) array), each in
                a form accepted by ``detect_clouds``. Images are presumably
                expected to share one size, since results are combined with
                ``np.array``.
            threshold: Cloud confidence threshold used for every image.

        Returns:
            (masks, confidences): the per-image results combined with
            ``np.array`` along a new leading axis, e.g. (N, H, W) each.
        """
        masks, confidences = [], []
        for image in images:
            mask, conf = self.detect_clouds(image, threshold)
            masks.append(mask)
            confidences.append(conf)
        return np.array(masks), np.array(confidences)