|
|
from pathlib import Path
from typing import Optional

import numpy as np
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
from tqdm import tqdm
|
|
|
|
|
|
|
|
class MSSSegmentationModel(pl.LightningModule):
    """U-Net for cloud segmentation of MSS imagery.

    Thin LightningModule wrapper around ``smp.Unet``; only ``forward`` is
    defined here. ``lr`` and ``weight_decay`` are not used by this class but,
    like every other constructor argument, are recorded in ``self.hparams``
    by ``save_hyperparameters()``.
    """

    def __init__(
        self,
        in_channels: int = 4,
        num_classes: int = 4,
        encoder: str = "efficientnet-b3",
        lr: float = 3e-4,
        weight_decay: float = 1e-4,
    ):
        super().__init__()
        # Records the constructor arguments in self.hparams (e.g. callers
        # later read hparams["num_classes"]).
        self.save_hyperparameters()

        # encoder_weights=None: no pretrained encoder is downloaded; weights
        # presumably come from a checkpoint. activation=None: the network
        # outputs raw per-class logits.
        unet_kwargs = {
            "encoder_name": encoder,
            "encoder_weights": None,
            "in_channels": in_channels,
            "classes": num_classes,
            "encoder_depth": 5,
            "activation": None,
            "decoder_attention_type": "scse",
        }
        self.model = smp.Unet(**unet_kwargs)

    def forward(self, x):
        """Return the U-Net's per-class logits for the input batch ``x``."""
        return self.model(x)
|
|
|
|
|
|
|
|
def get_spline_window(size: int, power: int = 2) -> np.ndarray:
    """Return a 2-D Hann-based blending window of shape ``(size, size)``.

    The 1-D taper is ``np.hanning(size + 2)[1:-1]``: a Hann window with its
    two zero-valued endpoints removed, so every weight is strictly positive.
    (A plain ``np.hanning(size)`` is exactly 0 on the border, which gives the
    outermost tile pixels no weight at all when blending overlapping tiles.)

    Args:
        size: Side length of the square window.
        power: Exponent applied to the 2-D window; larger values concentrate
            weight towards the centre.

    Returns:
        ``float32`` array of shape ``(size, size)``, symmetric in both axes.
    """
    taper = np.hanning(size + 2)[1:-1]
    window_2d = np.outer(taper, taper)
    return (window_2d ** power).astype(np.float32)
|
|
|
|
|
|
|
|
def apply_physical_rules(
    pred: np.ndarray,
    image: np.ndarray,
    merge_clouds: bool = False,
    saturation_threshold: float = 0.4,
    band1_ratio: float = 0.80,
) -> np.ndarray:
    """Apply physical rules for saturated thick clouds and nodata pixels.

    A pixel is "saturated" when band 0 exceeds ``saturation_threshold`` and
    band 1 exceeds ``saturation_threshold * band1_ratio``; such pixels are
    forced to the cloud class. Pixels where every band is exactly 0 are
    treated as nodata and set to class 0. The nodata rule is applied last, so
    it wins over the saturation rule.

    Args:
        pred: ``(H, W)`` class map; not modified (a copy is returned).
        image: ``(C, H, W)`` band stack with ``C >= 2``, same ``H, W`` as
            ``pred``.
        merge_clouds: If True, saturated pixels become class 1 (3-class
            scheme: clear, cloud, shadow); otherwise class 2 (4-class
            scheme: clear, thin, thick, shadow).
        saturation_threshold: Brightness threshold for band 0.
        band1_ratio: Band 1 must exceed ``saturation_threshold * band1_ratio``.

    Returns:
        A new class map with the rules applied.

    Raises:
        ValueError: If ``image`` is not ``(C>=2, H, W)`` matching ``pred``.
    """
    if image.ndim != 3 or image.shape[0] < 2 or tuple(image.shape[1:]) != tuple(pred.shape):
        raise ValueError(
            "expected image of shape (C>=2, H, W) matching pred (H, W); "
            f"got image {image.shape} and pred {pred.shape}"
        )

    pred = pred.copy()

    # All bands exactly zero -> no valid measurement.
    nodata_mask = np.all(image == 0, axis=0)

    saturated_mask = (image[0] > saturation_threshold) & (
        image[1] > saturation_threshold * band1_ratio
    )

    pred[saturated_mask] = 1 if merge_clouds else 2
    pred[nodata_mask] = 0
    return pred
|
|
|
|
|
|
|
|
def compiled_model(
    model_dir: Path,
    stac_item=None,
    device: str = "cpu",
    merge_clouds: bool = False,
    **kwargs
) -> nn.Module:
    """
    Load compiled model for inference.

    Args:
        model_dir: Directory (``Path`` or ``str``) containing the .ckpt file.
            If several ``*.ckpt`` files are present, the first one in sorted
            filename order is loaded, so the choice is reproducible.
        stac_item: STAC item metadata (optional; not used here).
        device: 'cpu' or 'cuda'.
        merge_clouds: Stored on the returned model as ``model.merge_clouds``.
            True means 3 classes (clear, cloud, shadow); False means
            4 classes (clear, thin, thick, shadow).
        **kwargs: Ignored.

    Returns:
        Loaded model in eval mode, on ``device``, with gradients disabled.

    Raises:
        FileNotFoundError: If ``model_dir`` contains no ``*.ckpt`` file.
    """
    model_dir = Path(model_dir)
    # Path.glob() yields entries in filesystem order; sort so the same
    # checkpoint is chosen on every machine.
    ckpt_files = sorted(model_dir.glob("*.ckpt"))
    if not ckpt_files:
        raise FileNotFoundError(f"No .ckpt file found in {model_dir}")

    ckpt_path = ckpt_files[0]
    if len(ckpt_files) > 1:
        print(f"⚠️ {len(ckpt_files)} checkpoints found, using {ckpt_path.name}")

    model = MSSSegmentationModel.load_from_checkpoint(
        ckpt_path,
        map_location=device,
    )
    model.eval()
    model.to(device)

    # Inference only: stop autograd from tracking the weights.
    for param in model.parameters():
        param.requires_grad = False

    model.merge_clouds = merge_clouds

    print(f"✅ Model loaded from {ckpt_path.name}")
    print(f" Device: {device}")
    print(f" Classes: {'3 (merged)' if merge_clouds else '4 (original)'}")

    return model
|
|
|
|
|
def _tile_pad(length: int, chunk_size: int, step: int) -> int:
    """Return the padding needed so tiles of ``chunk_size`` at stride ``step`` cover ``length`` exactly."""
    if length <= chunk_size:
        # Shorter than one tile: grow to exactly one tile.
        return chunk_size - length
    return -(length - chunk_size) % step


def _probs_to_classes(
    probs: np.ndarray, is_3class_model: bool, merge_clouds: bool
) -> np.ndarray:
    """Convert a ``(num_classes, H, W)`` probability stack into a ``uint8`` label map."""
    if merge_clouds and not is_3class_model:
        # 4 -> 3 classes: clear, thin + thick cloud, shadow.
        probs = np.stack([probs[0], probs[1] + probs[2], probs[3]])
    return np.argmax(probs, axis=0).astype(np.uint8)


def _predict_tiled(
    image: np.ndarray,
    model: nn.Module,
    num_classes: int,
    chunk_size: int,
    overlap: int,
    batch_size: int,
    device: str,
) -> np.ndarray:
    """Run overlapping-tile inference; return blended ``(num_classes, H, W)`` probabilities."""
    if chunk_size <= 0 or not 0 <= overlap < chunk_size:
        raise ValueError(
            "overlap must satisfy 0 <= overlap < chunk_size "
            f"(got overlap={overlap}, chunk_size={chunk_size})"
        )
    step = chunk_size - overlap
    _, height, width = image.shape

    # Reflect-pad (split evenly on both sides) so the tile grid ends exactly
    # on the padded border, and so no side is shorter than one tile.
    pad_h = _tile_pad(height, chunk_size, step)
    pad_w = _tile_pad(width, chunk_size, step)
    top, left = pad_h // 2, pad_w // 2
    padded = np.pad(
        image,
        ((0, 0), (top, pad_h - top), (left, pad_w - left)),
        mode="reflect",
    )
    _, h_pad, w_pad = padded.shape

    probs_sum = np.zeros((num_classes, h_pad, w_pad), dtype=np.float32)
    weight_sum = np.zeros((h_pad, w_pad), dtype=np.float32)
    # Floor the blending weights: a Hann taper can be exactly 0 on the tile
    # border, which would leave pixels on the outer edge of the padded image
    # with no weight at all and force them to class 0.
    window = np.maximum(get_spline_window(chunk_size, power=2), 1e-6)

    coords = [
        (r, c)
        for r in range(0, h_pad - chunk_size + 1, step)
        for c in range(0, w_pad - chunk_size + 1, step)
    ]

    with torch.no_grad():
        for start in range(0, len(coords), batch_size):
            batch_coords = coords[start:start + batch_size]
            tiles = np.stack([
                padded[:, r:r + chunk_size, c:c + chunk_size]
                for r, c in batch_coords
            ])
            logits = model(torch.from_numpy(tiles).float().to(device))
            tile_probs = torch.softmax(logits, dim=1).cpu().numpy()

            for probs_j, (r, c) in zip(tile_probs, batch_coords):
                probs_sum[:, r:r + chunk_size, c:c + chunk_size] += probs_j * window
                weight_sum[r:r + chunk_size, c:c + chunk_size] += window

    # Weighted average; any pixel with no weight stays at probability 0.
    probs = np.divide(
        probs_sum, weight_sum, out=np.zeros_like(probs_sum), where=weight_sum > 0
    )
    return probs[:, top:top + height, left:left + width]


def predict_large(
    image: np.ndarray,
    model: nn.Module,
    chunk_size: int = 512,
    overlap: Optional[int] = None,
    batch_size: int = 1,
    device: str = "cpu",
    merge_clouds: Optional[bool] = None,
    apply_rules: bool = False,
    max_direct_size: int = 1024,
    size_multiple: int = 32,
    **kwargs
) -> np.ndarray:
    """
    Predict on images of any size.

    Automatically detects if model has 3 or 4 classes (from
    ``model.hparams["num_classes"]``, default 4).

    Images whose longest side is ``<= max_direct_size`` are run in one pass.
    The input is reflect-padded (bottom/right) up to a multiple of
    ``size_multiple`` and the output is cropped back. The
    ``MSSSegmentationModel`` U-Net uses ``encoder_depth=5`` and therefore
    needs H and W divisible by 2**5 = 32. Larger images are split into
    overlapping ``chunk_size`` tiles whose softmax outputs are blended with
    ``get_spline_window``.

    Args:
        image: ``(C, H, W)`` band stack.
        model: Segmentation network returning ``(N, num_classes, H, W)``
            logits.
        chunk_size: Tile side for the tiled path.
        overlap: Tile overlap in pixels (default ``chunk_size // 2``); must
            satisfy ``0 <= overlap < chunk_size`` when tiling is used.
        batch_size: Number of tiles per forward pass.
        device: Torch device string.
        merge_clouds: Merge thin + thick cloud of a 4-class model into one
            cloud class. If None, uses ``model.merge_clouds`` when present
            (as set by ``compiled_model``), otherwise False.
        apply_rules: Post-process with ``apply_physical_rules``.
        max_direct_size: Largest side length predicted without tiling.
        size_multiple: Direct-path H/W are padded to a multiple of this
            (values <= 1 disable padding).
        **kwargs: Ignored.

    Returns:
        ``(H, W)`` ``uint8`` class map.

    Raises:
        ValueError: On an invalid ``overlap``/``chunk_size`` in the tiled
            path.
    """
    model.eval()
    model.to(device)

    if merge_clouds is None:
        merge_clouds = bool(getattr(model, "merge_clouds", False))

    hparams = getattr(model, "hparams", None) or {}
    num_classes = hparams.get("num_classes", 4)
    is_3class_model = num_classes == 3

    _, height, width = image.shape

    if overlap is None:
        overlap = chunk_size // 2

    if max(height, width) <= max_direct_size:
        pad_h = -height % size_multiple if size_multiple > 1 else 0
        pad_w = -width % size_multiple if size_multiple > 1 else 0
        net_input = image
        if pad_h or pad_w:
            net_input = np.pad(
                image, ((0, 0), (0, pad_h), (0, pad_w)), mode="reflect"
            )
        with torch.no_grad():
            batch = torch.from_numpy(net_input).unsqueeze(0).float().to(device)
            # Index the batch dim instead of squeeze(): keeps H or W == 1.
            logits = model(batch)[0, :, :height, :width]
            probs = torch.softmax(logits, dim=0).cpu().numpy()
    else:
        probs = _predict_tiled(
            image, model, num_classes, chunk_size, overlap, batch_size, device
        )

    pred = _probs_to_classes(probs, is_3class_model, merge_clouds)

    if apply_rules:
        pred = apply_physical_rules(
            pred, image, merge_clouds=is_3class_model or merge_clouds
        )

    return pred
|
|
|
|
|
|
|
|
def example_data(model_dir: Path, seed: int = 0, **kwargs):
    """Load example data for testing.

    Args:
        model_dir: Directory (``Path`` or ``str``) that may contain
            ``example_mss.npy``.
        seed: Seed for the synthetic fallback, so repeated calls return the
            same array.
        **kwargs: Ignored.

    Returns:
        The array stored in ``example_mss.npy`` (loaded with ``np.load``,
        whose default ``allow_pickle=False`` refuses pickled objects).
        If that file is missing, a synthetic ``float32`` array of shape
        ``(4, 512, 512)`` with values uniform in ``[0, 0.5)``.
    """
    example_path = Path(model_dir) / "example_mss.npy"

    if not example_path.exists():
        print("⚠️ No example data found, generating synthetic")
        rng = np.random.default_rng(seed)
        return rng.random((4, 512, 512), dtype=np.float32) * 0.5

    return np.load(example_path)
|
|
|
|
|
|
|
|
def display_results(
    model_dir: Path,
    image: np.ndarray,
    prediction: np.ndarray,
    stac_item=None,
    merge_clouds: Optional[bool] = None,
    **kwargs
):
    """Display prediction results.

    Shows a colour composite of ``image`` next to the class map.

    Args:
        model_dir: Not used; accepted for interface compatibility.
        image: Band stack indexed as ``image[band]``; bands 1, 0, 2 are
            shown as R, G, B.
        prediction: 2-D integer class map.
        stac_item: Not used.
        merge_clouds: True for a 3-class map (Clear, Cloud, Shadow), False
            for a 4-class map (Clear, Thin Cloud, Thick Cloud, Shadow). If
            None, it is guessed from ``prediction.max() <= 2``. That guess
            mislabels a 4-class map that has no shadow pixels, so pass the
            flag whenever it is known.
        **kwargs: Ignored.
    """
    if prediction.size == 0:
        print("⚠️ Empty prediction, nothing to display")
        return

    try:
        import matplotlib.pyplot as plt
        from matplotlib.colors import ListedColormap
    except ImportError:
        print("⚠️ matplotlib not installed, skipping visualization")
        return

    if merge_clouds is None:
        merge_clouds = int(prediction.max()) <= 2

    if merge_clouds:
        colors = ['#2E7D32', '#FFFFFF', '#424242']
        labels = ['Clear', 'Cloud', 'Shadow']
    else:
        colors = ['#2E7D32', '#B3E5FC', '#FFFFFF', '#424242']
        labels = ['Clear', 'Thin Cloud', 'Thick Cloud', 'Shadow']

    cmap = ListedColormap(colors)
    n_classes = len(labels)

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Fixed x3 brightness stretch, clipped to [0, 1]; assumes band values of
    # roughly [0, 1/3] — TODO confirm for the sensor's scaling.
    rgb = np.stack([image[1], image[0], image[2]], axis=-1)
    rgb = np.clip(rgb * 3, 0, 1)
    axes[0].imshow(rgb)
    axes[0].set_title("MSS RGB Composite")
    axes[0].axis('off')

    # vmin/vmax at +/-0.5 put each integer class in the middle of its colour
    # bin (ticks centred); nearest interpolation avoids blended in-between
    # values being drawn as a different class when the map is resampled.
    im = axes[1].imshow(
        prediction,
        cmap=cmap,
        vmin=-0.5,
        vmax=n_classes - 0.5,
        interpolation="nearest",
    )
    axes[1].set_title("Cloud Detection")
    axes[1].axis('off')

    cbar = plt.colorbar(im, ax=axes[1], ticks=range(n_classes))
    cbar.ax.set_yticklabels(labels)

    plt.tight_layout()
    plt.show()