# QA4EO-2 / load.py
# JulioContrerasH's picture
# Update load.py
# ebf386f verified
import torch
import torch.nn as nn
import numpy as np
from pathlib import Path
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
from tqdm import tqdm
class MSSSegmentationModel(pl.LightningModule):
    """U-Net for cloud segmentation on MSS imagery.

    Thin Lightning wrapper around ``smp.Unet``. The constructor arguments are
    recorded through ``save_hyperparameters()`` so that they travel with the
    checkpoint and are available afterwards as ``self.hparams``.

    Args:
        in_channels: Number of input bands fed to the network.
        num_classes: Number of output channels (one per class).
        encoder: Name of the encoder backbone passed to ``smp.Unet``.
        lr: Learning rate. Only stored as a hyperparameter in this class.
        weight_decay: Weight decay. Only stored as a hyperparameter in this class.
    """

    def __init__(
        self,
        in_channels: int = 4,
        num_classes: int = 4,
        encoder: str = "efficientnet-b3",
        lr: float = 3e-4,
        weight_decay: float = 1e-4,
    ):
        super().__init__()
        # Must run before anything else touches the arguments so the
        # constructor values are captured as given.
        self.save_hyperparameters()

        unet_config = dict(
            encoder_name=encoder,
            encoder_weights=None,  # no pretrained encoder weights are requested
            in_channels=in_channels,
            classes=num_classes,
            encoder_depth=5,
            activation=None,  # no final activation is configured here
            decoder_attention_type="scse",
        )
        # The attribute name ``model`` is part of the checkpoint key layout.
        self.model = smp.Unet(**unet_config)

    def forward(self, x):
        """Run the wrapped U-Net on ``x`` and return its output unchanged."""
        output = self.model(x)
        return output
def get_spline_window(size: int, power: int = 2) -> np.ndarray:
    """Build a square 2-D Hann-based weighting window for smooth blending.

    Args:
        size: Side length of the window in pixels.
        power: Exponent applied element-wise to the 2-D window.

    Returns:
        A ``(size, size)`` float32 array: the outer product of a 1-D Hann
        window with itself, raised to ``power``.
    """
    taper = np.hanning(size)
    # Separable window: column vector times row vector gives the outer product.
    surface = taper[:, np.newaxis] * taper[np.newaxis, :]
    weights = surface ** power
    return weights.astype(np.float32)
def apply_physical_rules(
    pred: np.ndarray,
    image: np.ndarray,
    merge_clouds: bool = False,
    saturation_threshold: float = 0.4,
) -> np.ndarray:
    """Post-process a class map with brightness and nodata rules.

    Two rules are applied to a copy of ``pred`` (the input is not modified):

    1. Pixels that are bright in both band 0 and band 1 are forced to the
       cloud class.
    2. Pixels whose value is zero in every band are treated as nodata and
       forced to class 0. This rule runs last, so it wins over rule 1.

    Args:
        pred: 2-D class map indexed as ``[row, col]``.
        image: Band-first array indexed as ``[band, row, col]`` with at least
            two bands; its spatial shape must match ``pred``.
        merge_clouds: If True, saturated pixels get label 1 (merged cloud
            class); otherwise label 2 (thick cloud).
        saturation_threshold: Brightness cut-off for band 0. Band 1 uses 80%
            of this value. Defaults to 0.4, the value previously hard-coded.

    Returns:
        A new array with the same shape and dtype as ``pred``.
    """
    result = pred.copy()

    # Nodata: every band is exactly zero at that pixel.
    nodata_mask = np.all(image == 0, axis=0)

    # Saturated clouds: both of the first two bands exceed their cut-offs.
    bright_b0 = image[0] > saturation_threshold
    bright_b1 = image[1] > saturation_threshold * 0.80
    saturated_mask = bright_b0 & bright_b1

    # 1 = cloud in the merged (3-class) scheme, 2 = thick cloud otherwise.
    result[saturated_mask] = 1 if merge_clouds else 2

    # Nodata is written last so it overrides the saturation rule.
    result[nodata_mask] = 0
    return result
def compiled_model(
    model_dir: Path,
    stac_item=None,
    device: str = "cpu",
    merge_clouds: bool = False,
    **kwargs
) -> nn.Module:
    """
    Load the checkpointed model for inference.

    Args:
        model_dir: Directory containing the .ckpt file (``Path`` or ``str``).
            If several checkpoints are present, the first one in sorted
            filename order is used so the choice is reproducible.
        stac_item: STAC item metadata (optional, not used here).
        device: Device string passed to the checkpoint loader and ``.to()``,
            e.g. 'cpu' or 'cuda'.
        merge_clouds: If True, output 3 classes (clear, cloud, shadow).
            If False, output 4 classes (clear, thin, thick, shadow).
            The flag is stored on the returned model as ``model.merge_clouds``.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        Loaded model in eval mode with gradients disabled on all parameters.

    Raises:
        FileNotFoundError: If ``model_dir`` contains no ``*.ckpt`` file.
    """
    # Accept plain strings as well as Path objects.
    model_dir = Path(model_dir)

    # glob() order is filesystem-dependent; sort to pick deterministically.
    ckpt_files = sorted(model_dir.glob("*.ckpt"))
    if not ckpt_files:
        raise FileNotFoundError(f"No .ckpt file found in {model_dir}")
    ckpt_path = ckpt_files[0]

    model = MSSSegmentationModel.load_from_checkpoint(
        ckpt_path,
        map_location=device
    )
    model.eval()
    model.to(device)

    # Inference only: freeze every parameter.
    for param in model.parameters():
        param.requires_grad = False

    model.merge_clouds = merge_clouds

    print(f"✅ Model loaded from {ckpt_path.name}")
    print(f" Device: {device}")
    print(f" Classes: {'3 (merged)' if merge_clouds else '4 (original)'}")
    return model
def _merge_cloud_probs(probs: np.ndarray) -> np.ndarray:
    """Collapse 4 channels (clear, thin, thick, shadow) into 3 (clear, cloud, shadow)."""
    return np.stack([probs[0], probs[1] + probs[2], probs[3]])


def _tile_padding(length: int, chunk_size: int, step: int) -> int:
    """Return the padding that lets ``chunk_size`` tiles at stride ``step`` cover ``length`` exactly."""
    if length <= chunk_size:
        # The axis must be at least one full tile long.
        return chunk_size - length
    return (step - (length - chunk_size) % step) % step


def predict_large(
    image: np.ndarray,
    model: nn.Module,
    chunk_size: int = 512,
    overlap: int = None,
    batch_size: int = 1,
    device: str = "cpu",
    merge_clouds: bool = False,
    apply_rules: bool = False,
    max_direct_size: int = 1024,
    pad_multiple: int = 32,
    **kwargs
) -> np.ndarray:
    """
    Predict a class map for a band-first image of any size.

    Images whose longest side is at most ``max_direct_size`` go through the
    model in a single pass. Larger images are processed as overlapping tiles
    whose softmax outputs are blended with a Hann-based window.

    The number of model classes is read from ``model.hparams['num_classes']``
    when available and defaults to 4 otherwise.

    Args:
        image: Array indexed as ``[band, row, col]``.
        model: Network called as ``model(tensor)`` on a ``(N, C, H, W)`` batch
            and returning per-class scores on axis 1.
        chunk_size: Tile side length for the sliding-window path.
        overlap: Overlap between neighbouring tiles in pixels. Defaults to
            ``chunk_size // 2``. Must satisfy ``0 <= overlap < chunk_size``.
        batch_size: Number of tiles per forward pass.
        device: Device string used for the model and the input tensors.
        merge_clouds: If True and the model has 4 classes, thin and thick
            cloud probabilities are summed so the output has 3 classes. Also
            enabled when the model carries a truthy ``merge_clouds``
            attribute, as set by ``compiled_model``.
        apply_rules: If True, post-process with ``apply_physical_rules``.
        max_direct_size: Largest side length handled in a single pass.
        pad_multiple: In the single-pass path, the image is reflect-padded on
            the bottom/right to a multiple of this value and the result is
            cropped back. The default 32 corresponds to an encoder depth of 5.
            Use 1 (or 0/None) to disable.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        ``(H, W)`` uint8 class map.

    Raises:
        ValueError: If the sliding-window path is needed and ``overlap`` is
            not in ``[0, chunk_size)``.
    """
    model.eval()
    model.to(device)

    # Plain nn.Modules have no ``hparams``; fall back to the 4-class default.
    hparams = getattr(model, "hparams", None) or {}
    num_classes = hparams.get('num_classes', 4)
    is_3class_model = (num_classes == 3)

    # Honour the flag stored on the model at load time as well.
    merge_clouds = bool(merge_clouds or getattr(model, "merge_clouds", False))
    merge_probs = merge_clouds and not is_3class_model
    merged_labels = is_3class_model or merge_clouds

    _, H, W = image.shape

    # === DIRECT INFERENCE FOR SMALL IMAGES ===
    if max(H, W) <= max_direct_size:
        extra_h = extra_w = 0
        if pad_multiple and pad_multiple > 1:
            extra_h = (-H) % pad_multiple
            extra_w = (-W) % pad_multiple
        model_input = image
        if extra_h or extra_w:
            model_input = np.pad(
                image,
                ((0, 0), (0, extra_h), (0, extra_w)),
                mode="reflect"
            )

        with torch.no_grad():
            batch = torch.from_numpy(model_input).unsqueeze(0).float().to(device)
            # Index the batch axis explicitly instead of squeeze(), which
            # would also drop a spatial axis of length 1.
            logits = model(batch)[0, :, :H, :W]
            if merge_probs:
                probs = torch.softmax(logits, dim=0).cpu().numpy()
                pred = np.argmax(_merge_cloud_probs(probs), axis=0)
            else:
                pred = logits.argmax(0).cpu().numpy()
        pred = pred.astype(np.uint8)

        if apply_rules:
            pred = apply_physical_rules(pred, image, merge_clouds=merged_labels)
        return pred

    # === SLIDING WINDOW FOR LARGE IMAGES ===
    if overlap is None:
        overlap = chunk_size // 2
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size, "
            f"got overlap={overlap}, chunk_size={chunk_size}"
        )
    step = chunk_size - overlap

    pad_h = _tile_padding(H, chunk_size, step)
    pad_w = _tile_padding(W, chunk_size, step)
    pad_top = pad_h // 2
    pad_bottom = pad_h - pad_top
    pad_left = pad_w // 2
    pad_right = pad_w - pad_left

    image_padded = np.pad(
        image,
        ((0, 0), (pad_top, pad_bottom), (pad_left, pad_right)),
        mode="reflect"
    )
    _, H_pad, W_pad = image_padded.shape

    # Accumulators for the weighted probabilities and the weights themselves.
    probs_sum = np.zeros((num_classes, H_pad, W_pad), dtype=np.float32)
    weight_sum = np.zeros((H_pad, W_pad), dtype=np.float32)

    # A Hann window is exactly zero on its border. Without a floor, border
    # pixels of the padded image (which are real image pixels whenever no
    # padding was needed on that side) would get zero total weight and always
    # come out as class 0.
    window = np.maximum(get_spline_window(chunk_size, power=2), 1e-8)

    coords = [
        (r, c)
        for r in range(0, H_pad - chunk_size + 1, step)
        for c in range(0, W_pad - chunk_size + 1, step)
    ]

    with torch.no_grad():
        for i in range(0, len(coords), batch_size):
            batch_coords = coords[i:i + batch_size]
            tiles = np.stack([
                image_padded[:, r:r + chunk_size, c:c + chunk_size]
                for r, c in batch_coords
            ])
            tiles_tensor = torch.from_numpy(tiles).float().to(device)
            logits = model(tiles_tensor)
            probs = torch.softmax(logits, dim=1).cpu().numpy()
            for j, (r, c) in enumerate(batch_coords):
                probs_sum[:, r:r + chunk_size, c:c + chunk_size] += probs[j] * window
                weight_sum[r:r + chunk_size, c:c + chunk_size] += window

    # Every pixel is covered by at least one tile with strictly positive
    # weight, so the division is safe.
    probs_final = probs_sum / weight_sum
    probs_final = probs_final[:, pad_top:pad_top + H, pad_left:pad_left + W]

    if merge_probs:
        probs_final = _merge_cloud_probs(probs_final)
    pred = np.argmax(probs_final, axis=0).astype(np.uint8)

    if apply_rules:
        pred = apply_physical_rules(pred, image, merge_clouds=merged_labels)
    return pred
def example_data(model_dir: Path, **kwargs):
    """Return an example array for testing.

    Loads ``example_mss.npy`` from ``model_dir`` when that file exists.
    Otherwise prints a notice and returns random float32 data of shape
    ``(4, 512, 512)`` scaled by 0.5.

    Args:
        model_dir: Directory expected to hold ``example_mss.npy``.
        **kwargs: Accepted for interface compatibility and ignored.
    """
    sample_file = model_dir / "example_mss.npy"
    if sample_file.exists():
        return np.load(sample_file)

    print("⚠️ No example data found, generating synthetic")
    noise = np.random.rand(4, 512, 512)
    return noise.astype(np.float32) * 0.5
def display_results(
    model_dir: Path,
    image: np.ndarray,
    prediction: np.ndarray,
    stac_item=None,
    merge_clouds: bool = None,
    **kwargs
):
    """Show the input composite next to the predicted class map.

    Args:
        model_dir: Model directory (not used by this function).
        image: Array indexed as ``[band, row, col]`` with at least 3 bands.
        prediction: 2-D class map to display.
        stac_item: STAC item metadata (optional, not used here).
        merge_clouds: Which legend to use. True selects the 3-class legend
            (Clear, Cloud, Shadow), False the 4-class legend (Clear, Thin
            Cloud, Thick Cloud, Shadow). If None (default), the legend is
            guessed from the data: 3-class when ``prediction.max() <= 2``.
            The guess is ambiguous for a 4-class prediction that contains no
            shadow pixels, so pass the flag explicitly when it is known.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        None. Prints a notice and returns early if matplotlib is unavailable.
    """
    try:
        import matplotlib.pyplot as plt
        from matplotlib.colors import ListedColormap
    except ImportError:
        print("⚠️ matplotlib not installed, skipping visualization")
        return

    if merge_clouds is None:
        # Fallback heuristic: labels 0..2 only are assumed to be 3-class output.
        merge_clouds = prediction.max() <= 2

    if merge_clouds:
        colors = ['#2E7D32', '#FFFFFF', '#424242']
        labels = ['Clear', 'Cloud', 'Shadow']
    else:
        colors = ['#2E7D32', '#B3E5FC', '#FFFFFF', '#424242']
        labels = ['Clear', 'Thin Cloud', 'Thick Cloud', 'Shadow']
    cmap = ListedColormap(colors)

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Three-band composite built from bands 1, 0, 2 (in that display order),
    # brightened by a factor of 3 and clipped to [0, 1].
    rgb = np.stack([image[1], image[0], image[2]], axis=-1)
    rgb = np.clip(rgb * 3, 0, 1)
    axes[0].imshow(rgb)
    axes[0].set_title("MSS RGB Composite")
    axes[0].axis('off')

    # Prediction; vmin/vmax pin each label to its own colormap entry.
    im = axes[1].imshow(prediction, cmap=cmap, vmin=0, vmax=len(labels)-1)
    axes[1].set_title("Cloud Detection")
    axes[1].axis('off')

    # Colorbar with one tick per class label.
    cbar = plt.colorbar(im, ax=axes[1], ticks=range(len(labels)))
    cbar.ax.set_yticklabels(labels)

    plt.tight_layout()
    plt.show()