VishaliniS456 commited on
Commit
9875bf8
·
verified ·
1 Parent(s): faf1433

Upload 8 files

Browse files
models/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .cloud_detector import CloudDetector
2
+ from .change_detector import ChangeDetector
3
+
4
+ __all__ = ["CloudDetector", "ChangeDetector"]
models/change_detector.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from transformers import ViTModel, ViTImageProcessor
5
+ import numpy as np
6
+ from typing import Tuple, Optional
7
+ import cv2
8
+
9
+
10
+ class ChangeDetector:
11
+ """
12
+ Change detection model using Siamese ViT architecture.
13
+
14
+ Detects changes between two temporal satellite images.
15
+ Produces spatial 2D confidence maps and masks.
16
+ """
17
+
18
+ def __init__(
19
+ self,
20
+ model_name: str = "google/vit-base-patch16-224",
21
+ device: Optional[str] = None
22
+ ):
23
+ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
24
+ self.model_name = model_name
25
+
26
+ self.processor = ViTImageProcessor.from_pretrained(model_name)
27
+ self.encoder = ViTModel.from_pretrained(model_name)
28
+ self.encoder.to(self.device)
29
+ self.encoder.eval()
30
+
31
+ hidden_size = self.encoder.config.hidden_size
32
+ # Lightweight head to score each patch token
33
+ self.patch_head = nn.Sequential(
34
+ nn.Linear(hidden_size * 2, 256),
35
+ nn.ReLU(),
36
+ nn.Dropout(0.2),
37
+ nn.Linear(256, 1),
38
+ nn.Sigmoid()
39
+ ).to(self.device)
40
+
41
+ def _encode_patches(self, image: np.ndarray) -> torch.Tensor:
42
+ """
43
+ Encodes the image and return ALL patch tokens (not just CLS).
44
+
45
+ Returns:
46
+ Tensor of shape (num_patches, hidden_size)
47
+ """
48
+ # Ensure uint8 [0,255] for processor
49
+ if image.dtype != np.uint8:
50
+ img_uint8 = (np.clip(image, 0, 1) * 255).astype(np.uint8)
51
+ else:
52
+ img_uint8 = image
53
+
54
+ inputs = self.processor(images=img_uint8, return_tensors="pt").to(self.device)
55
+
56
+ with torch.no_grad():
57
+ outputs = self.encoder(**inputs)
58
+ # last_hidden_state: (1, 1+num_patches, hidden_size)
59
+ # index 0 is CLS, 1: are patch tokens
60
+ patch_tokens = outputs.last_hidden_state[0, 1:, :] # (num_patches, H)
61
+
62
+ return patch_tokens
63
+
64
+ def detect_changes(
65
+ self,
66
+ before_image: np.ndarray,
67
+ after_image: np.ndarray,
68
+ threshold: float = 0.5
69
+ ) -> Tuple[np.ndarray, np.ndarray]:
70
+ """
71
+ Detect changes between two temporal images.
72
+
73
+ Returns:
74
+ change_mask: 2D binary array (H, W) matching input image size
75
+ confidence_map: 2D float array (H, W) in [0, 1]
76
+ """
77
+ h, w = before_image.shape[:2]
78
+
79
+ before_patches = self._encode_patches(before_image) # (N, D)
80
+ after_patches = self._encode_patches(after_image) # (N, D)
81
+
82
+ # Concatenate patch-wise features
83
+ combined = torch.cat([before_patches, after_patches], dim=-1) # (N, 2D)
84
+
85
+ with torch.no_grad():
86
+ patch_scores = self.patch_head(combined).squeeze(-1) # (N,)
87
+
88
+ patch_scores_np = patch_scores.cpu().numpy() # shape (num_patches,)
89
+
90
+ # ViT-base/16 on 224x224 → 14x14 = 196 patches
91
+ n = patch_scores_np.shape[0]
92
+ grid = int(np.sqrt(n))
93
+ if grid * grid != n:
94
+ # Fallback: pad to nearest square
95
+ grid = int(np.ceil(np.sqrt(n)))
96
+ pad = grid * grid - n
97
+ patch_scores_np = np.concatenate([patch_scores_np, np.zeros(pad)])
98
+
99
+ patch_map = patch_scores_np.reshape(grid, grid)
100
+
101
+ # Upsample patch-level map to original image size
102
+ confidence_map = cv2.resize(
103
+ patch_map.astype(np.float32),
104
+ (w, h),
105
+ interpolation=cv2.INTER_LINEAR
106
+ )
107
+ confidence_map = np.clip(confidence_map, 0.0, 1.0)
108
+
109
+ # Threshold to binary mask
110
+ change_mask = (confidence_map > threshold).astype(np.uint8)
111
+
112
+ # Morphological cleanup to reduce noise
113
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
114
+ change_mask = cv2.morphologyEx(change_mask, cv2.MORPH_OPEN, kernel, iterations=1)
115
+
116
+ return change_mask, confidence_map
117
+
118
+ def batch_detect_changes(
119
+ self,
120
+ before_images: np.ndarray,
121
+ after_images: np.ndarray,
122
+ threshold: float = 0.5
123
+ ) -> Tuple[np.ndarray, np.ndarray]:
124
+ masks, confidences = [], []
125
+ for b, a in zip(before_images, after_images):
126
+ mask, conf = self.detect_changes(b, a, threshold)
127
+ masks.append(mask)
128
+ confidences.append(conf)
129
+ return np.array(masks), np.array(confidences)
models/cloud_detector.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from typing import Tuple, Optional
3
+ import cv2
4
+
5
+
6
+ class CloudDetector:
7
+ """
8
+ Cloud detection for satellite imagery using brightness and saturation analysis.
9
+
10
+ For RGB satellite images, clouds are detected based on:
11
+ - High brightness (RGB values close to 255)
12
+ - Low saturation (near-white appearance)
13
+ - Spatial clustering (using morphological operations)
14
+ """
15
+
16
+ def __init__(self, device: Optional[str] = None):
17
+ self.device = device or "cpu"
18
+
19
+ def detect_clouds(
20
+ self,
21
+ image: np.ndarray,
22
+ threshold: float = 0.5
23
+ ) -> Tuple[np.ndarray, np.ndarray]:
24
+ """
25
+ Detect clouds in satellite image.
26
+
27
+ Args are:
28
+ image: Input image (H, W, 3), uint8 [0,255] or float [0,1]
29
+ threshold: Cloud confidence threshold (0-1)
30
+
31
+ Returns:
32
+ cloud_mask: 2D binary array (H, W)
33
+ cloud_confidence: 2D float array (H, W) in [0,1]
34
+ """
35
+ # Normalise to float [0, 1]
36
+ if image.dtype == np.uint8:
37
+ img = image.astype(np.float32) / 255.0
38
+ else:
39
+ img = np.clip(image, 0, 1).astype(np.float32)
40
+
41
+ # Handle grayscale
42
+ if img.ndim == 2:
43
+ img = np.stack([img, img, img], axis=-1)
44
+ elif img.shape[2] == 1:
45
+ img = np.concatenate([img, img, img], axis=-1)
46
+
47
+ red = img[:, :, 0]
48
+ green = img[:, :, 1]
49
+ blue = img[:, :, 2]
50
+
51
+ # Brightness: mean of RGB channels
52
+ brightness = (red + green + blue) / 3.0
53
+
54
+ # Saturation (HSV-style for RGB)
55
+ max_rgb = np.maximum(np.maximum(red, green), blue)
56
+ min_rgb = np.minimum(np.minimum(red, green), blue)
57
+ saturation = np.where(
58
+ max_rgb > 0,
59
+ (max_rgb - min_rgb) / (max_rgb + 1e-8),
60
+ 0.0
61
+ )
62
+
63
+ # Cloud score: high brightness + low saturation
64
+ brightness_score = np.clip((brightness - 0.4) / 0.6, 0, 1)
65
+ saturation_score = np.clip((0.4 - saturation) / 0.4, 0, 1)
66
+
67
+ cloud_confidence = (0.6 * brightness_score + 0.4 * saturation_score).astype(np.float32)
68
+
69
+ # Binary mask + morphological cleanup
70
+ cloud_mask = (cloud_confidence > threshold).astype(np.uint8)
71
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
72
+ cloud_mask = cv2.morphologyEx(cloud_mask, cv2.MORPH_OPEN, kernel, iterations=1)
73
+
74
+ return cloud_mask, cloud_confidence
75
+
76
+ def batch_detect(
77
+ self,
78
+ images: np.ndarray,
79
+ threshold: float = 0.5
80
+ ) -> Tuple[np.ndarray, np.ndarray]:
81
+ masks, confidences = [], []
82
+ for image in images:
83
+ mask, conf = self.detect_clouds(image, threshold)
84
+ masks.append(mask)
85
+ confidences.append(conf)
86
+ return np.array(masks), np.array(confidences)
utils/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utility modules for satellite change detection."""
2
+
3
+ from .preprocessing import preprocess_image, mask_clouds
4
+ from .visualization import create_overlay, visualize_predictions
5
+ from .evaluation import calculate_metrics
6
+ from .metrics import calculate_change_statistics, compare_with_without_masking
7
+
8
+ __all__ = [
9
+ "preprocess_image",
10
+ "mask_clouds",
11
+ "create_overlay",
12
+ "visualize_predictions",
13
+ "calculate_metrics",
14
+ "calculate_change_statistics",
15
+ "compare_with_without_masking",
16
+ ]
utils/evaluation.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Evaluation metrics for change and cloud detection."""
2
+
3
+ import numpy as np
4
+ from typing import Dict, Optional
5
+
6
+
7
+ def calculate_metrics(
8
+ pred_mask: np.ndarray,
9
+ gt_mask: np.ndarray,
10
+ threshold: float = 0.5
11
+ ) -> Dict[str, float]:
12
+ """
13
+ Calculate pixel-level classification metrics.
14
+
15
+ Args:
16
+ pred_mask: Predicted binary mask (H, W) or confidence map (H, W)
17
+ gt_mask: Ground truth binary mask (H, W)
18
+ threshold: Threshold to binarise pred_mask if it's a confidence map
19
+
20
+ Returns:
21
+ Dict with keys: accuracy, precision, recall, f1, iou
22
+ """
23
+ # Binarise predictions if needed
24
+ if pred_mask.dtype != np.uint8 or pred_mask.max() > 1:
25
+ pred = (pred_mask > threshold).astype(np.uint8)
26
+ else:
27
+ pred = pred_mask.astype(np.uint8)
28
+
29
+ gt = (gt_mask > 0).astype(np.uint8)
30
+
31
+ tp = int(np.sum((pred == 1) & (gt == 1)))
32
+ tn = int(np.sum((pred == 0) & (gt == 0)))
33
+ fp = int(np.sum((pred == 1) & (gt == 0)))
34
+ fn = int(np.sum((pred == 0) & (gt == 1)))
35
+
36
+ total = tp + tn + fp + fn
37
+ accuracy = (tp + tn) / total if total > 0 else 0.0
38
+ precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
39
+ recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
40
+ f1 = (2 * precision * recall / (precision + recall)
41
+ if (precision + recall) > 0 else 0.0)
42
+ iou = tp / (tp + fp + fn) if (tp + fp + fn) > 0 else 0.0
43
+
44
+ return {
45
+ "accuracy": accuracy,
46
+ "precision": precision,
47
+ "recall": recall,
48
+ "f1": f1,
49
+ "iou": iou,
50
+ "tp": tp,
51
+ "tn": tn,
52
+ "fp": fp,
53
+ "fn": fn,
54
+ }
utils/metrics.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Advanced comparison metrics for change detection evaluation."""
2
+
3
+ import numpy as np
4
+ from typing import Dict, Optional
5
+ from .evaluation import calculate_metrics
6
+
7
+
8
+ def calculate_change_statistics(
9
+ change_mask: np.ndarray,
10
+ change_confidence: np.ndarray
11
+ ) -> Dict:
12
+ """
13
+ Calculate statistics from a 2D change mask and confidence map.
14
+
15
+ Args:
16
+ change_mask: Binary 2D array (H, W), 1 = changed
17
+ change_confidence: Float 2D array (H, W) in [0, 1]
18
+
19
+ Returns:
20
+ Dict with total_pixels, changed_pixels, unchanged_pixels,
21
+ change_percentage, mean_confidence, min_confidence, max_confidence,
22
+ change_confidence_mean (mean conf among changed pixels only)
23
+ """
24
+ # Ensure 2D arrays
25
+ mask = change_mask.astype(np.uint8)
26
+ conf = change_confidence.astype(np.float32)
27
+
28
+ total_pixels = int(mask.size)
29
+ changed_pixels = int(np.sum(mask == 1))
30
+ unchanged_pixels = total_pixels - changed_pixels
31
+ change_percentage = 100.0 * changed_pixels / total_pixels if total_pixels > 0 else 0.0
32
+
33
+ mean_confidence = float(conf.mean())
34
+ min_confidence = float(conf.min())
35
+ max_confidence = float(conf.max())
36
+
37
+ # Mean confidence among changed pixels only
38
+ if changed_pixels > 0:
39
+ change_confidence_mean = float(conf[mask == 1].mean())
40
+ else:
41
+ change_confidence_mean = 0.0
42
+
43
+ return {
44
+ "total_pixels": total_pixels,
45
+ "changed_pixels": changed_pixels,
46
+ "unchanged_pixels": unchanged_pixels,
47
+ "change_percentage": change_percentage,
48
+ "mean_confidence": mean_confidence,
49
+ "min_confidence": min_confidence,
50
+ "max_confidence": max_confidence,
51
+ "change_confidence_mean": change_confidence_mean,
52
+ }
53
+
54
+
55
+ def compare_with_without_masking(
56
+ pred_with_mask: np.ndarray,
57
+ pred_without_mask: np.ndarray,
58
+ gt_mask: Optional[np.ndarray] = None
59
+ ) -> Dict:
60
+ """
61
+ Compare detection results with and without cloud masking.
62
+
63
+ Args:
64
+ pred_with_mask: Change mask produced WITH cloud masking (H, W)
65
+ pred_without_mask: Change mask produced WITHOUT cloud masking (H, W)
66
+ gt_mask: Optional ground truth mask for metric computation (H, W)
67
+
68
+ Returns:
69
+ Dict with pixel-level comparison and optional metric differences
70
+ """
71
+ agreement = int(np.sum(pred_with_mask == pred_without_mask))
72
+ total = int(pred_with_mask.size)
73
+ agreement_pct = 100.0 * agreement / total if total > 0 else 0.0
74
+
75
+ result = {
76
+ "agreement_pixels": agreement,
77
+ "total_pixels": total,
78
+ "agreement_percentage": agreement_pct,
79
+ "changed_with_mask": int(np.sum(pred_with_mask)),
80
+ "changed_without_mask": int(np.sum(pred_without_mask)),
81
+ }
82
+
83
+ if gt_mask is not None:
84
+ metrics_with = calculate_metrics(pred_with_mask, gt_mask)
85
+ metrics_without = calculate_metrics(pred_without_mask, gt_mask)
86
+
87
+ result["iou_with_mask"] = metrics_with["iou"]
88
+ result["iou_without_mask"] = metrics_without["iou"]
89
+ result["iou_improvement"] = metrics_with["iou"] - metrics_without["iou"]
90
+
91
+ result["f1_with_mask"] = metrics_with["f1"]
92
+ result["f1_without_mask"] = metrics_without["f1"]
93
+ result["f1_improvement"] = metrics_with["f1"] - metrics_without["f1"]
94
+
95
+ return result
utils/preprocessing.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Preprocessing utilities for satellite imagery."""
2
+
3
+ import numpy as np
4
+ import cv2
5
+ from typing import Optional
6
+
7
+
8
+ def preprocess_image(
9
+ image: np.ndarray,
10
+ target_size: Optional[tuple] = None,
11
+ normalize: bool = True
12
+ ) -> np.ndarray:
13
+ """
14
+ Preprocess a satellite image for model input.
15
+
16
+ Args:
17
+ image: Input image (H, W, C), uint8 or float
18
+ target_size: Optional (width, height) to resize to
19
+ normalize: If True, output is float32 in [0, 1]
20
+
21
+ Returns:
22
+ Preprocessed image as float32 [0,1] or uint8 [0,255]
23
+ """
24
+ if image is None:
25
+ raise ValueError("Input image is None")
26
+
27
+ img = image.copy()
28
+
29
+ # Ensure 3-channel
30
+ if img.ndim == 2:
31
+ img = np.stack([img, img, img], axis=-1)
32
+ elif img.shape[2] == 1:
33
+ img = np.concatenate([img, img, img], axis=-1)
34
+ elif img.shape[2] > 3:
35
+ img = img[:, :, :3]
36
+
37
+ # Resize if requested
38
+ if target_size is not None:
39
+ img = cv2.resize(img, target_size, interpolation=cv2.INTER_LINEAR)
40
+
41
+ # Normalise
42
+ if normalize:
43
+ if img.dtype == np.uint8:
44
+ img = img.astype(np.float32) / 255.0
45
+ else:
46
+ img = np.clip(img, 0, 1).astype(np.float32)
47
+ else:
48
+ if img.dtype != np.uint8:
49
+ img = (np.clip(img, 0, 1) * 255).astype(np.uint8)
50
+
51
+ return img
52
+
53
+
54
+ def mask_clouds(
55
+ image: np.ndarray,
56
+ cloud_mask: np.ndarray,
57
+ fill_value: float = 0.0
58
+ ) -> np.ndarray:
59
+ """
60
+ Apply cloud mask to image, replacing cloud pixels with fill_value.
61
+
62
+ Args:
63
+ image: Input image (H, W, C)
64
+ cloud_mask: Binary mask (H, W), 1 = cloud
65
+ fill_value: Value to fill masked pixels with
66
+
67
+ Returns:
68
+ Masked image same dtype as input
69
+ """
70
+ masked = image.copy().astype(np.float32)
71
+ mask_bool = cloud_mask.astype(bool)
72
+
73
+ for c in range(masked.shape[2]):
74
+ masked[:, :, c][mask_bool] = fill_value
75
+
76
+ if image.dtype == np.uint8:
77
+ masked = np.clip(masked, 0, 255).astype(np.uint8)
78
+
79
+ return masked
utils/visualization.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Visualization utilities for satellite change detection."""
2
+
3
+ import numpy as np
4
+ import cv2
5
+ from typing import Tuple, Optional
6
+
7
+
8
+ def create_overlay(
9
+ image: np.ndarray,
10
+ mask: np.ndarray,
11
+ alpha: float = 0.5,
12
+ color: Tuple[int, int, int] = (255, 0, 0)
13
+ ) -> np.ndarray:
14
+ """
15
+ Overlay a binary mask on an image with transparency.
16
+
17
+ Args:
18
+ image: Base image (H, W, 3), uint8 or float [0,1]
19
+ mask: Binary mask (H, W), values 0 or 1
20
+ alpha: Overlay transparency (0 = invisible, 1 = opaque)
21
+ color: RGB color for the overlay
22
+
23
+ Returns:
24
+ Blended image as uint8 (H, W, 3)
25
+ """
26
+ # Convert image to uint8
27
+ if image.dtype != np.uint8:
28
+ base = (np.clip(image, 0, 1) * 255).astype(np.uint8)
29
+ else:
30
+ base = image.copy()
31
+
32
+ # Ensure 3 channels
33
+ if base.ndim == 2:
34
+ base = cv2.cvtColor(base, cv2.COLOR_GRAY2RGB)
35
+
36
+ overlay = base.copy()
37
+ mask_bool = mask.astype(bool)
38
+
39
+ # Apply colour to masked region
40
+ overlay[mask_bool] = [color[0], color[1], color[2]]
41
+
42
+ # Blend
43
+ result = cv2.addWeighted(overlay, alpha, base, 1 - alpha, 0)
44
+ return result.astype(np.uint8)
45
+
46
+
47
+ def visualize_predictions(
48
+ image: np.ndarray,
49
+ pred_mask: np.ndarray,
50
+ gt_mask: Optional[np.ndarray] = None,
51
+ confidence: Optional[np.ndarray] = None
52
+ ) -> np.ndarray:
53
+ """
54
+ Create a side-by-side visualization of image, prediction, and (optionally) ground truth.
55
+
56
+ Args:
57
+ image: Original image (H, W, 3)
58
+ pred_mask: Predicted binary mask (H, W)
59
+ gt_mask: Optional ground truth mask (H, W)
60
+ confidence: Optional confidence map (H, W)
61
+
62
+ Returns:
63
+ Combined visualization as uint8 (H, W*N, 3)
64
+ """
65
+ if image.dtype != np.uint8:
66
+ img_u8 = (np.clip(image, 0, 1) * 255).astype(np.uint8)
67
+ else:
68
+ img_u8 = image.copy()
69
+
70
+ h, w = img_u8.shape[:2]
71
+ panels = [img_u8]
72
+
73
+ # Prediction overlay (red)
74
+ pred_overlay = create_overlay(img_u8, pred_mask, alpha=0.5, color=(255, 0, 0))
75
+ panels.append(pred_overlay)
76
+
77
+ # Ground truth overlay (green)
78
+ if gt_mask is not None:
79
+ gt_overlay = create_overlay(img_u8, gt_mask, alpha=0.5, color=(0, 255, 0))
80
+ panels.append(gt_overlay)
81
+
82
+ # Confidence heatmap
83
+ if confidence is not None:
84
+ heatmap = (np.clip(confidence, 0, 1) * 255).astype(np.uint8)
85
+ heatmap_color = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
86
+ heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)
87
+ panels.append(heatmap_color)
88
+
89
+ # Resize all panels to same height
90
+ panels_resized = [
91
+ cv2.resize(p, (w, h), interpolation=cv2.INTER_LINEAR)
92
+ if p.shape[:2] != (h, w) else p
93
+ for p in panels
94
+ ]
95
+
96
+ return np.concatenate(panels_resized, axis=1)