# sen2sr / predictor.py
# RhodWeo's picture
# Add SEN2SRLite RGBN x4 with WEO standard interface
# 83a44e8 verified
"""
predictor.py
============
Inference wrapper for WEO-SAS/sen2sr (SEN2SRLite RGBN x4).
Super-resolves 4-band Sentinel-2 RGBN imagery from 10 m to 2.5 m (4x).

Usage
-----
    predictor = SEN2SRPredictor("./sen2sr")

    # Array inference: (4, H, W) float32 in [0, 1] -> (4, H*4, W*4) float32
    sr = predictor.predict(image)

    # GeoTIFF pipeline (reads Sentinel-2 DN, writes SR GeoTIFF at 2.5 m)
    predictor.predict_tif("s2_scene.tif", "s2_sr.tif", bands=[0, 1, 2, 3])

Requirements
------------
torch, numpy, rasterio, safetensors, sen2sr (pip install sen2sr)
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import List, Optional
import numpy as np
import torch
import rasterio
class SEN2SRPredictor:
    """
    SEN2SRLite RGBN x4 predictor.

    Reads the model settings from ``config.json`` inside ``local_dir`` and
    offers array-level (``predict``) and GeoTIFF-level (``predict_tif``)
    super-resolution.

    Parameters
    ----------
    local_dir : local path to a downloaded WEO-SAS/sen2sr model repo
    device    : torch device, or a device string such as "cpu" / "cuda"
                (auto-detected if None)
    model     : pre-built srmodel callable; bypasses weight loading
                (used by sen2sr_pt.py)
    """

    def __init__(
        self,
        local_dir: str,
        device: Optional[torch.device] = None,
        model=None,
    ):
        """
        Load ``config.json`` and either adopt ``model`` or build it from weights.

        Raises
        ------
        FileNotFoundError : ``config.json`` is missing from ``local_dir``
        KeyError          : a required key is missing from ``config.json``
        ImportError       : weights must be loaded but sen2sr / safetensors
                            is not installed
        """
        local_dir = Path(local_dir)
        # Explicit encoding so parsing does not depend on the platform locale.
        with open(local_dir / "config.json", encoding="utf-8") as f:
            cfg = json.load(f)

        self.local_dir = local_dir
        self.in_channels = cfg["in_channels"]
        self.out_channels = cfg["out_channels"]
        self.scaling_factor = cfg["scaling_factor"]
        self.patch_size = cfg["patch_size"]
        self.overlap = cfg["overlap"]
        # p_low / p_high are read from the config but not used by this class.
        self.p_low = cfg["p_low"]
        self.p_high = cfg["p_high"]
        self.normalization_factor = cfg["normalization_factor"]
        self.description = cfg.get("description", "")

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # Normalise so that self.device is always a torch.device, even when
        # the caller passed a plain string.
        self.device = torch.device(device)

        if model is not None:
            self.model = model
        else:
            self._load_model(local_dir, cfg)

    # ------------------------------------------------------------------
    # Model loading (only used when model= is not injected)
    # ------------------------------------------------------------------
    def _load_model(self, local_dir: Path, cfg: dict) -> None:
        """
        Build the SR model from the weight files named in ``cfg``.

        Instantiates ``CNNSR`` from the config values, loads its state dict
        from ``cfg["weights_file"]``, moves it to ``self.device`` in eval mode
        with gradients disabled on every parameter, loads the hard-constraint
        weights from ``cfg["hard_constraint_file"]`` and stores the combined
        ``srmodel(...)`` result in ``self.model``.

        Raises
        ------
        ImportError : sen2sr or safetensors cannot be imported
        """
        try:
            import safetensors.torch
            from sen2sr.models.opensr_baseline.cnn import CNNSR
            from sen2sr.models.tricks import HardConstraint
            from sen2sr.nonreference import srmodel
        except ImportError as exc:
            raise ImportError(
                "sen2sr and safetensors are required. "
                "Install: pip install sen2sr safetensors"
            ) from exc

        device = self.device

        weights = safetensors.torch.load_file(local_dir / cfg["weights_file"])
        sr_model = CNNSR(
            cfg["in_channels"],
            cfg["out_channels"],
            cfg["feature_channels"],
            cfg["scaling_factor"],
            cfg["bias"],
            cfg["train_mode"],
            cfg["num_blocks"],
        )
        sr_model.load_state_dict(weights)
        sr_model.to(device).eval()
        # Inference only: freeze every parameter.
        for p in sr_model.parameters():
            p.requires_grad = False

        hc_weights = safetensors.torch.load_file(
            local_dir / cfg["hard_constraint_file"]
        )
        hard_constraint = HardConstraint(
            low_pass_mask=hc_weights["weights"].to(device), device=device
        )
        self.model = srmodel(sr_model, hard_constraint, device)

    # ------------------------------------------------------------------
    # Array inference
    # ------------------------------------------------------------------
    def predict(self, image: np.ndarray) -> np.ndarray:
        """
        Run super-resolution on a (C, H, W) image.

        Images that fit in a single patch (H and W both <= patch_size) are
        passed to the model directly. Larger images are delegated to
        ``sen2sr.predict_large`` with the configured ``overlap``; sen2sr is
        only imported on that path.

        Parameters
        ----------
        image : (C, H, W) array, values in [0, 1]; converted to float32
                C must equal in_channels (4 for RGBN)

        Returns
        -------
        The model output as a NumPy array; for the bundled model this is
        (C, H*4, W*4) float32 in the same radiometric range as the input.

        Raises
        ------
        ValueError  : ``image`` is not 3-D or has the wrong channel count
        ImportError : tiling is required but sen2sr is not installed
        """
        image = np.asarray(image)
        if image.ndim != 3 or image.shape[0] != self.in_channels:
            raise ValueError(
                f"Expected ({self.in_channels}, H, W), got {image.shape}"
            )

        # torch.from_numpy rejects negatively-strided views (e.g. flipped
        # arrays), so hand it a contiguous float32 array.
        X = torch.from_numpy(
            np.ascontiguousarray(image, dtype=np.float32)
        ).to(self.device)

        fits_single_patch = (
            image.shape[1] <= self.patch_size
            and image.shape[2] <= self.patch_size
        )
        if fits_single_patch:
            with torch.no_grad():
                out = self.model(X.unsqueeze(0)).squeeze(0)  # (C, H*sf, W*sf)
        else:
            # Imported lazily: the single-patch path does not need sen2sr.
            try:
                import sen2sr
            except ImportError as exc:
                raise ImportError("pip install sen2sr") from exc
            # no_grad here as well, so an injected model with trainable
            # parameters does not track autograd state across all tiles.
            with torch.no_grad():
                out = sen2sr.predict_large(
                    model=self.model,
                    X=X,
                    overlap=self.overlap,
                )

        return out.detach().cpu().numpy()

    # ------------------------------------------------------------------
    # GeoTIFF pipeline
    # ------------------------------------------------------------------
    def predict_tif(
        self,
        input_path: str,
        output_path: str,
        bands: Optional[List[int]] = None,
    ) -> None:
        """
        Full GeoTIFF super-resolution pipeline.

        Reads bands from the input raster, normalises Sentinel-2 DN to [0, 1]
        (divides by normalization_factor and clips if values suggest DN range,
        otherwise leaves as-is), runs SR, and writes a float32 LZW-compressed
        GeoTIFF whose geotransform pixel size is divided by scaling_factor.
        Parent directories of ``output_path`` are created if missing.

        Parameters
        ----------
        input_path  : path to input Sentinel-2 GeoTIFF
        output_path : output path for the SR GeoTIFF (2.5 m for 10 m input)
        bands       : 0-based band indices to read
                      (default: the first in_channels bands, i.e. [0, 1, 2, 3])
        """
        bands = bands or list(range(self.in_channels))

        with rasterio.open(input_path) as src:
            # rasterio band indices are 1-based.
            arr = src.read([b + 1 for b in bands]).astype(np.float32)
            profile = src.profile.copy()

        # Auto-normalise: if values look like raw Sentinel-2 DN (> 2.0) divide
        # by normalization_factor, otherwise assume already in [0, 1].
        # nanmax is used because a single NaN would make arr.max() NaN and
        # silently skip the normalisation.
        if np.nanmax(arr) > 2.0:
            arr = np.clip(arr / self.normalization_factor, 0.0, 1.0)

        print(
            f"SR inference  model=sen2sr  input={arr.shape}  "
            f"factor={self.scaling_factor}x  {input_path}"
        )
        sr = self.predict(arr)  # (C, H*sf, W*sf)
        print(
            f"Output shape {sr.shape}  "
            f"range [{sr.min():.4f}, {sr.max():.4f}]"
        )

        tf = profile["transform"]
        # Same origin, pixel size shrunk by the scaling factor.
        new_tf = tf * tf.scale(
            1.0 / self.scaling_factor, 1.0 / self.scaling_factor
        )

        out_profile = profile.copy()
        out_profile.update(
            # The output is documented as a GeoTIFF and uses GeoTIFF creation
            # options, so do not inherit a different driver from the source.
            driver="GTiff",
            count=sr.shape[0],
            height=sr.shape[1],
            width=sr.shape[2],
            dtype="float32",
            transform=new_tf,
            compress="lzw",
        )
        # Drop any photometric tag inherited from the source profile.
        out_profile.pop("photometric", None)

        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        with rasterio.open(output_path, "w", **out_profile) as dst:
            dst.write(sr)

        sr_res = abs(tf.a) / self.scaling_factor
        print(f"Written: {output_path}  (res={sr_res:.4f} m)")