"""Cloud / cloud-shadow segmentation for MSS imagery: model, loading, inference."""

from pathlib import Path
from typing import Optional

import numpy as np
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
from tqdm import tqdm  # noqa: F401  (not used below; import kept as in the original file)

# np.hanning() is exactly 0 at both end points, so the outermost ring of every
# tile would contribute nothing. Pixels on the border of the mosaic are covered
# only by such rings; the floor guarantees they still receive a prediction.
_MIN_BLEND_WEIGHT = 1e-6

# The UNet below is built with encoder_depth=5, i.e. five 2x downsamplings, so
# the spatial size fed to it has to be a multiple of 2**5.
_MODEL_STRIDE = 32


class MSSSegmentationModel(pl.LightningModule):
    """UNet for cloud segmentation on MSS imagery.

    Wraps an ``smp.Unet`` (no pretrained encoder weights, SCSE decoder
    attention, raw logits as output). All constructor arguments are stored
    through ``save_hyperparameters()`` and are therefore available as
    ``self.hparams`` (``predict_large`` reads ``hparams['num_classes']``).

    Args:
        in_channels: Number of input bands.
        num_classes: Number of output classes (logit channels).
        encoder: Encoder name understood by segmentation_models_pytorch.
        lr: Learning rate (stored as a hyperparameter only; no optimizer is
            configured in this file).
        weight_decay: Weight decay (stored as a hyperparameter only).
    """

    def __init__(
        self,
        in_channels: int = 4,
        num_classes: int = 4,
        encoder: str = "efficientnet-b3",
        lr: float = 3e-4,
        weight_decay: float = 1e-4,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.model = smp.Unet(
            encoder_name=encoder,
            encoder_weights=None,
            in_channels=in_channels,
            classes=num_classes,
            encoder_depth=5,
            activation=None,
            decoder_attention_type="scse",
        )

    def forward(self, x):
        """Return the raw (un-normalised) class logits for batch ``x``."""
        return self.model(x)


def get_spline_window(size: int, power: int = 2) -> np.ndarray:
    """Return a 2-D Hann window used to blend overlapping tiles.

    The window is the outer product of ``np.hanning(size)`` with itself,
    raised to ``power``. Note that its first and last row/column are exactly
    zero because ``np.hanning`` touches zero at both ends.

    Args:
        size: Side length of the square window.
        power: Exponent applied to the 2-D window.

    Returns:
        ``float32`` array of shape ``(size, size)``.
    """
    taper = np.hanning(size)
    return (np.outer(taper, taper) ** power).astype(np.float32)


def apply_physical_rules(
    pred: np.ndarray,
    image: np.ndarray,
    merge_clouds: bool = False,
    saturation_threshold: float = 0.4,
) -> np.ndarray:
    """Override predictions with simple brightness / nodata rules.

    Pixels that are bright in the first two bands are forced to the (thick)
    cloud class; pixels where every band equals 0 are forced to class 0.
    The nodata rule is applied last and therefore wins.

    Args:
        pred: Label map of shape ``(H, W)``. It is not modified; a copy is
            returned.
        image: Array of shape ``(C, H, W)`` with at least two bands.
        merge_clouds: If True the cloud label is 1 (merged scheme), otherwise
            saturated pixels get label 2 (thick cloud).
        saturation_threshold: Band-0 brightness above which a pixel counts as
            saturated; band 1 uses 80 % of this value.

    Returns:
        Corrected copy of ``pred``.
    """
    corrected = pred.copy()

    # A pixel is nodata when it is 0 in every band.
    nodata_mask = np.all(image == 0, axis=0)

    # Saturated clouds: high values in both of the first two bands.
    saturated_mask = (image[0] > saturation_threshold) & (
        image[1] > saturation_threshold * 0.80
    )

    corrected[saturated_mask] = 1 if merge_clouds else 2
    corrected[nodata_mask] = 0  # nodata is reported as clear
    return corrected


def compiled_model(
    model_dir: Path,
    stac_item=None,
    device: str = "cpu",
    merge_clouds: bool = False,
    **kwargs
) -> nn.Module:
    """Load a checkpoint from ``model_dir`` and prepare it for inference.

    Args:
        model_dir: Directory containing the .ckpt file. When several are
            present the first one in sorted (lexicographic) order is used,
            so the choice is reproducible.
        stac_item: STAC item metadata (optional, unused here).
        device: 'cpu' or 'cuda'.
        merge_clouds: If True, output 3 classes (clear, cloud, shadow).
            If False, output 4 classes (clear, thin, thick, shadow).
            The flag is only recorded on the returned model as
            ``model.merge_clouds``; ``predict_large`` takes its own argument.
        **kwargs: Ignored.

    Returns:
        Loaded model in eval mode with gradients disabled.

    Raises:
        FileNotFoundError: If ``model_dir`` contains no ``*.ckpt`` file.
    """
    model_dir = Path(model_dir)

    # glob() yields files in arbitrary order; sort to make the pick stable.
    ckpt_files = sorted(model_dir.glob("*.ckpt"))
    if not ckpt_files:
        raise FileNotFoundError(f"No .ckpt file found in {model_dir}")
    ckpt_path = ckpt_files[0]

    model = MSSSegmentationModel.load_from_checkpoint(
        ckpt_path,
        map_location=device,
    )
    model.eval()
    model.to(device)
    model.requires_grad_(False)

    model.merge_clouds = merge_clouds

    print(f"✅ Model loaded from {ckpt_path.name}")
    print(f" Device: {device}")
    print(f" Classes: {'3 (merged)' if merge_clouds else '4 (original)'}")

    return model


def _scores_to_labels(scores: np.ndarray, merge_4_to_3: bool) -> np.ndarray:
    """Turn a ``(C, H, W)`` score volume into a ``uint8`` ``(H, W)`` label map."""
    if merge_4_to_3:
        # clear | thin + thick -> cloud | shadow
        scores = np.stack([scores[0], scores[1] + scores[2], scores[3]])
    return np.argmax(scores, axis=0).astype(np.uint8)


def _tiling_pad(length: int, chunk_size: int, step: int) -> int:
    """Return the padding that makes ``length`` hold a whole number of tiles."""
    # Never go below one full tile, otherwise no tile position exists at all.
    target = max(length, chunk_size)
    return (target - length) + (-(target - chunk_size)) % step


def _predict_direct(
    image: np.ndarray,
    model: nn.Module,
    device: str,
    as_probs: bool,
) -> np.ndarray:
    """Run the model once on the whole image and return ``(C, H, W)`` scores."""
    _, height, width = image.shape

    # Pad bottom/right up to the network stride and crop the result back, so
    # sizes that are not a multiple of the stride can be processed too.
    pad_h = -height % _MODEL_STRIDE
    pad_w = -width % _MODEL_STRIDE
    if pad_h or pad_w:
        image = np.pad(image, ((0, 0), (0, pad_h), (0, pad_w)), mode="reflect")

    with torch.no_grad():
        batch = torch.from_numpy(image).unsqueeze(0).float().to(device)
        scores = model(batch)
        if as_probs:
            # Class merging adds probabilities, so logits are not enough.
            scores = torch.softmax(scores, dim=1)

    # Index the batch axis explicitly: a bare squeeze() would also drop H or W
    # when one of them is 1.
    return scores[0, :, :height, :width].cpu().numpy()


def _predict_tiled(
    image: np.ndarray,
    model: nn.Module,
    device: str,
    chunk_size: int,
    overlap: int,
    batch_size: int,
) -> np.ndarray:
    """Return blended ``(C, H, W)`` class probabilities from overlapping tiles."""
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size, got {overlap}"
        )
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    _, height, width = image.shape
    step = chunk_size - overlap

    pad_h = _tiling_pad(height, chunk_size, step)
    pad_w = _tiling_pad(width, chunk_size, step)
    pad_top = pad_h // 2
    pad_left = pad_w // 2

    image_padded = np.pad(
        image,
        ((0, 0), (pad_top, pad_h - pad_top), (pad_left, pad_w - pad_left)),
        mode="reflect",
    )
    _, padded_h, padded_w = image_padded.shape

    window = np.maximum(
        get_spline_window(chunk_size, power=2), _MIN_BLEND_WEIGHT
    ).astype(np.float32)

    coords = [
        (r, c)
        for r in range(0, padded_h - chunk_size + 1, step)
        for c in range(0, padded_w - chunk_size + 1, step)
    ]

    # Allocated on the first batch so the channel count always matches what
    # the model actually returns.
    probs_sum = None
    weight_sum = np.zeros((padded_h, padded_w), dtype=np.float32)

    with torch.no_grad():
        for start in range(0, len(coords), batch_size):
            batch_coords = coords[start:start + batch_size]
            tiles = np.stack([
                image_padded[:, r:r + chunk_size, c:c + chunk_size]
                for r, c in batch_coords
            ])

            logits = model(torch.from_numpy(tiles).float().to(device))
            probs = torch.softmax(logits, dim=1).cpu().numpy()

            if probs_sum is None:
                probs_sum = np.zeros(
                    (probs.shape[1], padded_h, padded_w), dtype=np.float32
                )

            for tile_probs, (r, c) in zip(probs, batch_coords):
                rows = slice(r, r + chunk_size)
                cols = slice(c, c + chunk_size)
                probs_sum[:, rows, cols] += tile_probs * window
                weight_sum[rows, cols] += window

    # Guard against division by zero (cannot trigger with the weight floor,
    # kept as a cheap safety net).
    blended = probs_sum / np.maximum(weight_sum, 1e-8)
    return blended[:, pad_top:pad_top + height, pad_left:pad_left + width]


def predict_large(
    image: np.ndarray,
    model: nn.Module,
    chunk_size: int = 512,
    overlap: Optional[int] = None,
    batch_size: int = 1,
    device: str = "cpu",
    merge_clouds: bool = False,
    apply_rules: bool = False,
    max_direct_size: int = 1024,
    **kwargs
) -> np.ndarray:
    """Predict a label map for an image of any size.

    Images whose longest side is at most ``max_direct_size`` are sent through
    the model in one pass (padded to the network stride and cropped back).
    Larger images are processed as overlapping ``chunk_size`` tiles whose
    softmax probabilities are blended with a Hann window.

    The number of classes is read from ``model.hparams['num_classes']``
    (4 if unavailable). A 3-class model is used as is; a 4-class model is
    collapsed to 3 classes when ``merge_clouds`` is True by adding the
    probabilities of classes 1 and 2.

    Args:
        image: Array of shape ``(C, H, W)``.
        model: Segmentation model returning ``(N, classes, H, W)`` logits.
        chunk_size: Tile side length for the sliding window.
        overlap: Overlap between neighbouring tiles; defaults to
            ``chunk_size // 2``.
        batch_size: Number of tiles per forward pass.
        device: Device the model and the tiles are moved to.
        merge_clouds: Merge thin and thick cloud into one class.
        apply_rules: Post-process with ``apply_physical_rules``.
        max_direct_size: Largest side length handled without tiling.
        **kwargs: Ignored.

    Returns:
        ``uint8`` label map of shape ``(H, W)``.

    Raises:
        ValueError: If ``image`` is not 3-D, or (tiled path only) if
            ``chunk_size``, ``overlap`` or ``batch_size`` are out of range.
    """
    if image.ndim != 3:
        raise ValueError(
            f"image must have shape (C, H, W), got {image.shape}"
        )

    model.eval()
    model.to(device)

    # Detect number of classes in the model.
    hparams = getattr(model, "hparams", None) or {}
    num_classes = hparams.get("num_classes", 4)
    is_3class_model = num_classes == 3
    merge_4_to_3 = merge_clouds and not is_3class_model

    _, height, width = image.shape
    if overlap is None:
        overlap = chunk_size // 2

    if max(height, width) <= max_direct_size:
        scores = _predict_direct(image, model, device, as_probs=merge_4_to_3)
    else:
        scores = _predict_tiled(
            image, model, device, chunk_size, overlap, batch_size
        )

    pred = _scores_to_labels(scores, merge_4_to_3)

    if apply_rules:
        pred = apply_physical_rules(
            pred, image, merge_clouds=is_3class_model or merge_clouds
        )
    return pred


def example_data(model_dir: Path, **kwargs):
    """Load example data for testing.

    Returns the array stored in ``<model_dir>/example_mss.npy``. If that file
    does not exist, a random ``float32`` array of shape ``(4, 512, 512)`` with
    values in ``[0, 0.5)`` is generated instead.
    """
    example_path = Path(model_dir) / "example_mss.npy"
    if example_path.exists():
        return np.load(example_path)

    print("⚠️ No example data found, generating synthetic")
    return np.random.rand(4, 512, 512).astype(np.float32) * 0.5


def display_results(
    model_dir: Path,
    image: np.ndarray,
    prediction: np.ndarray,
    stac_item=None,
    **kwargs
):
    """Show a band composite next to the predicted label map.

    Args:
        model_dir: Unused.
        image: Array of shape ``(C, H, W)`` with at least three bands.
        prediction: Label map of shape ``(H, W)``.
        stac_item: Unused.
        **kwargs: ``merge_clouds`` (bool, optional) selects the 3-class or
            4-class legend explicitly. When omitted the legend is guessed
            from ``prediction.max() <= 2``, which mislabels a 4-class map
            that happens to contain no shadow pixels.

    Returns:
        None. Does nothing except print a warning when matplotlib is missing.
    """
    try:
        import matplotlib.pyplot as plt
        from matplotlib.colors import ListedColormap
    except ImportError:
        print("⚠️ matplotlib not installed, skipping visualization")
        return

    merge_clouds = kwargs.get("merge_clouds")
    if merge_clouds is None:
        merge_clouds = prediction.max() <= 2

    if merge_clouds:
        colors = ['#2E7D32', '#FFFFFF', '#424242']
        labels = ['Clear', 'Cloud', 'Shadow']
    else:
        colors = ['#2E7D32', '#B3E5FC', '#FFFFFF', '#424242']
        labels = ['Clear', 'Thin Cloud', 'Thick Cloud', 'Shadow']

    cmap = ListedColormap(colors)

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Composite from bands 1, 0, 2 (in that order), brightened 3x for display.
    rgb = np.stack([image[1], image[0], image[2]], axis=-1)
    rgb = np.clip(rgb * 3, 0, 1)
    axes[0].imshow(rgb)
    axes[0].set_title("MSS RGB Composite")
    axes[0].axis('off')

    # Prediction
    im = axes[1].imshow(prediction, cmap=cmap, vmin=0, vmax=len(labels) - 1)
    axes[1].set_title("Cloud Detection")
    axes[1].axis('off')

    # Colorbar with one tick per class
    cbar = plt.colorbar(im, ax=axes[1], ticks=range(len(labels)))
    cbar.ax.set_yticklabels(labels)

    plt.tight_layout()
    plt.show()