File size: 2,826 Bytes
9875bf8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import numpy as np
from typing import Tuple, Optional
import cv2


class CloudDetector:
    """
    Cloud detection for satellite imagery using brightness and saturation analysis.

    For RGB satellite images, clouds are detected based on:
    - High brightness (RGB values close to 255)
    - Low saturation (near-white appearance)
    - Spatial clustering (using morphological operations)
    """

    # Brightness score: 0 at or below the floor, rising linearly to 1 at
    # full brightness (floor + span == 1.0).
    _BRIGHTNESS_FLOOR = 0.4
    _BRIGHTNESS_SPAN = 0.6
    # Saturation score: 1 for fully grey pixels, falling linearly to 0 at
    # the ceiling and beyond.
    _SATURATION_CEILING = 0.4
    # Relative contribution of each score to the final confidence (sums to 1).
    _BRIGHTNESS_WEIGHT = 0.6
    _SATURATION_WEIGHT = 0.4
    # Size of the elliptical structuring element used to clean the mask.
    _MORPH_KERNEL_SIZE = (3, 3)

    def __init__(self, device: Optional[str] = None):
        """
        Create a detector.

        Args:
            device: Device label stored on the instance; defaults to "cpu".
                The detection code in this class runs on NumPy/OpenCV and
                does not consult this value.
        """
        self.device = device or "cpu"

    def detect_clouds(
        self,
        image: np.ndarray,
        threshold: float = 0.5
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Detect clouds in satellite image.

        Args:
            image: Input image, either (H, W) / (H, W, 1) grayscale or
                (H, W, C) with C >= 3 (only the first three channels are
                used). uint8 input is scaled from [0, 255] to [0, 1]; any
                other dtype is clipped to [0, 1] without rescaling, so it
                must already be in that range.
            threshold: Cloud confidence threshold (0-1). Pixels whose
                confidence is strictly greater than this are marked cloudy.

        Returns:
            cloud_mask: 2D uint8 array (H, W) with values 0 or 1
            cloud_confidence: 2D float32 array (H, W) in [0,1]

        Raises:
            ValueError: If the image is not 2D or 3D, or if a 3D image has
                a channel count other than 1 or >= 3.
        """
        # Normalise to float [0, 1]
        if image.dtype == np.uint8:
            img = image.astype(np.float32) / 255.0
        else:
            img = np.clip(image, 0, 1).astype(np.float32)

        # Reject layouts the channel slicing below cannot interpret, instead
        # of failing with an opaque IndexError or (for 4D input) silently
        # producing a meaningless result.
        if img.ndim not in (2, 3):
            raise ValueError(
                f"image must be 2D (H, W) or 3D (H, W, C), got shape {img.shape}"
            )
        if img.ndim == 3 and img.shape[2] != 1 and img.shape[2] < 3:
            raise ValueError(
                f"image must have 1 or at least 3 channels, got shape {img.shape}"
            )

        # Handle grayscale by replicating the single channel into RGB
        if img.ndim == 2:
            img = np.stack([img, img, img], axis=-1)
        elif img.shape[2] == 1:
            img = np.concatenate([img, img, img], axis=-1)

        red = img[:, :, 0]
        green = img[:, :, 1]
        blue = img[:, :, 2]

        # Brightness: mean of RGB channels
        brightness = (red + green + blue) / 3.0

        # Saturation (HSV-style for RGB); black pixels get saturation 0
        max_rgb = np.maximum(np.maximum(red, green), blue)
        min_rgb = np.minimum(np.minimum(red, green), blue)
        saturation = np.where(
            max_rgb > 0,
            (max_rgb - min_rgb) / (max_rgb + 1e-8),
            0.0
        )

        # Cloud score: high brightness + low saturation
        brightness_score = np.clip(
            (brightness - self._BRIGHTNESS_FLOOR) / self._BRIGHTNESS_SPAN, 0, 1
        )
        saturation_score = np.clip(
            (self._SATURATION_CEILING - saturation) / self._SATURATION_CEILING, 0, 1
        )

        cloud_confidence = (
            self._BRIGHTNESS_WEIGHT * brightness_score
            + self._SATURATION_WEIGHT * saturation_score
        ).astype(np.float32)

        # Binary mask
        cloud_mask = (cloud_confidence > threshold).astype(np.uint8)

        # A zero-sized image has nothing to clean up; return early rather
        # than handing an empty array to OpenCV.
        # NOTE(review): assumes OpenCV morphology rejects empty inputs - confirm.
        if cloud_mask.size == 0:
            return cloud_mask, cloud_confidence

        # Morphological opening removes isolated specks from the mask
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, self._MORPH_KERNEL_SIZE)
        cloud_mask = cv2.morphologyEx(cloud_mask, cv2.MORPH_OPEN, kernel, iterations=1)

        return cloud_mask, cloud_confidence

    def batch_detect(
        self,
        images: np.ndarray,
        threshold: float = 0.5
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Run detect_clouds on every image of a batch.

        Args:
            images: Iterable of images (e.g. an (N, H, W, 3) array); each
                item is passed to detect_clouds unchanged.
            threshold: Cloud confidence threshold (0-1) applied to every image.

        Returns:
            cloud_masks: Per-image masks stacked along a new first axis
            cloud_confidences: Per-image confidences stacked the same way

            For an empty batch both arrays have shape (0,), with dtypes
            uint8 and float32 to match the non-empty case.
        """
        results = [self.detect_clouds(image, threshold) for image in images]

        # np.array([]) would default to float64; keep dtypes consistent.
        if not results:
            return np.empty((0,), dtype=np.uint8), np.empty((0,), dtype=np.float32)

        masks = [mask for mask, _ in results]
        confidences = [conf for _, conf in results]
        return np.array(masks), np.array(confidences)