"""
predictor.py
============
Inference wrapper for WEO-SAS/sen2sr (SEN2SRLite RGBN x4).

Super-resolves 4-band Sentinel-2 RGBN imagery from 10 m to 2.5 m (4x).

Usage
-----
    predictor = SEN2SRPredictor("./sen2sr")

    # Array inference: (4, H, W) float32 in [0, 1] -> (4, H*4, W*4) float32
    sr = predictor.predict(image)

    # GeoTIFF pipeline (reads Sentinel-2 DN, writes SR GeoTIFF at 2.5 m)
    predictor.predict_tif("s2_scene.tif", "s2_sr.tif", bands=[0, 1, 2, 3])

Requirements
------------
torch, numpy, rasterio, safetensors, sen2sr (pip install sen2sr)
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import List, Optional
import numpy as np
import torch
import rasterio
class SEN2SRPredictor:
    """
    SEN2SRLite RGBN x4 predictor.

    Parameters
    ----------
    local_dir : local path to a downloaded WEO-SAS/sen2sr model repo; must
                contain ``config.json`` (and, unless ``model`` is given, the
                weight files named in it)
    device : torch device, or a device string such as ``"cuda:1"``; when None,
             CUDA is used if available, otherwise CPU
    model : pre-built srmodel callable; bypasses weight loading (used by sen2sr_pt.py)
    """

    def __init__(
        self,
        local_dir: str,
        device: Optional[torch.device] = None,
        model=None,
    ):
        local_dir = Path(local_dir)
        with open(local_dir / "config.json", encoding="utf-8") as f:
            cfg = json.load(f)

        self.local_dir = local_dir
        self.in_channels = cfg["in_channels"]
        self.out_channels = cfg["out_channels"]
        self.scaling_factor = cfg["scaling_factor"]
        self.patch_size = cfg["patch_size"]
        self.overlap = cfg["overlap"]
        self.p_low = cfg["p_low"]
        self.p_high = cfg["p_high"]
        self.normalization_factor = cfg["normalization_factor"]
        self.description = cfg.get("description", "")

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # torch.device() accepts both strings and existing torch.device objects,
        # so self.device is always a real torch.device.
        self.device = torch.device(device)

        if model is not None:
            self.model = model
        else:
            self._load_model(local_dir, cfg)

    # ------------------------------------------------------------------
    # Model loading (only used when model= is not injected)
    # ------------------------------------------------------------------
    def _load_model(self, local_dir: Path, cfg: dict) -> None:
        """
        Build the SEN2SRLite CNN and its hard-constraint layer from ``local_dir``.

        Sets ``self.model`` to ``sen2sr.nonreference.srmodel(...)`` wrapping the
        frozen CNN and the hard constraint, both placed on ``self.device``.

        Raises
        ------
        ImportError : if sen2sr or safetensors is not installed
        """
        try:
            import safetensors.torch
            from sen2sr.models.opensr_baseline.cnn import CNNSR
            from sen2sr.models.tricks import HardConstraint
            from sen2sr.nonreference import srmodel
        except ImportError as exc:
            raise ImportError(
                "sen2sr and safetensors are required. "
                "Install: pip install sen2sr safetensors"
            ) from exc

        device = self.device

        weights = safetensors.torch.load_file(local_dir / cfg["weights_file"])
        sr_model = CNNSR(
            cfg["in_channels"],
            cfg["out_channels"],
            cfg["feature_channels"],
            cfg["scaling_factor"],
            cfg["bias"],
            cfg["train_mode"],
            cfg["num_blocks"],
        )
        sr_model.load_state_dict(weights)
        # Inference only: eval mode and frozen parameters.
        sr_model.to(device).eval().requires_grad_(False)

        hc_weights = safetensors.torch.load_file(local_dir / cfg["hard_constraint_file"])
        hard_constraint = HardConstraint(
            low_pass_mask=hc_weights["weights"].to(device), device=device
        )
        self.model = srmodel(sr_model, hard_constraint, device)

    # ------------------------------------------------------------------
    # Array inference
    # ------------------------------------------------------------------
    def predict(self, image: np.ndarray) -> np.ndarray:
        """
        Run 4x super-resolution on a (C, H, W) image.

        Images no larger than ``patch_size`` in both spatial dimensions are
        processed in a single forward pass. Larger images go through
        ``sen2sr.predict_large`` with ``overlap`` so that tile boundaries are
        blended. ``sen2sr`` is only imported for that tiled path.

        Parameters
        ----------
        image : (C, H, W) array, values in [0, 1]. Any real dtype or memory
                layout is accepted; it is copied to contiguous float32 if needed.
                C must equal in_channels (4 for RGBN).

        Returns
        -------
        (C, H*sf, W*sf) float32 array, sf = scaling_factor

        Raises
        ------
        ValueError : if ``image`` is not (in_channels, H, W)
        ImportError : if the tiled path is needed and sen2sr is not installed
        """
        if image.ndim != 3 or image.shape[0] != self.in_channels:
            raise ValueError(
                f"Expected ({self.in_channels}, H, W), got {image.shape}"
            )

        # torch.from_numpy rejects negative strides (flipped views) and
        # non-native byte order; ascontiguousarray is a no-op for arrays that
        # are already contiguous float32.
        X = torch.from_numpy(np.ascontiguousarray(image, dtype=np.float32)).to(
            self.device
        )
        _, height, width = image.shape

        # no_grad on both paths: an injected model may have trainable
        # parameters, and .numpy() fails on tensors that require grad.
        with torch.no_grad():
            if height <= self.patch_size and width <= self.patch_size:
                out = self.model(X.unsqueeze(0)).squeeze(0)  # (C, H*sf, W*sf)
            else:
                try:
                    import sen2sr
                except ImportError as exc:
                    raise ImportError("pip install sen2sr") from exc
                out = sen2sr.predict_large(
                    model=self.model,
                    X=X,
                    overlap=self.overlap,
                )

        return out.detach().float().cpu().numpy()

    # ------------------------------------------------------------------
    # GeoTIFF pipeline
    # ------------------------------------------------------------------
    def _normalise(self, arr: np.ndarray) -> np.ndarray:
        """
        Scale raw Sentinel-2 DN to [0, 1]; return already-scaled data unchanged.

        Data is treated as DN when its largest non-NaN value exceeds 2.0. It is
        then divided by ``normalization_factor`` and clipped to [0, 1]. NaN
        pixels (e.g. float nodata) do not affect the decision and stay NaN.
        Empty or all-NaN input is returned as-is.
        """
        # `initial` makes this safe for empty / all-NaN arrays (no warning,
        # no ValueError); `where` skips NaN so it cannot poison the max.
        peak = np.max(arr, initial=-np.inf, where=~np.isnan(arr))
        if peak > 2.0:
            return np.clip(arr / self.normalization_factor, 0.0, 1.0)
        return arr

    def predict_tif(
        self,
        input_path: str,
        output_path: str,
        bands: Optional[List[int]] = None,
    ) -> None:
        """
        Full GeoTIFF super-resolution pipeline.

        Reads the requested bands from the input raster and normalises
        Sentinel-2 DN to [0, 1]. Data is divided by normalization_factor when
        its values suggest a DN range; otherwise it is left as-is. The method
        then runs 4x SR and writes a float32, LZW-compressed GeoTIFF whose
        geotransform pixel size is divided by scaling_factor. Parent
        directories of ``output_path`` are created as needed.

        Parameters
        ----------
        input_path : path to input Sentinel-2 raster (any rasterio-readable format)
        output_path : output path for the 2.5 m SR GeoTIFF
        bands : 0-based band indices to read (default: the first in_channels bands)
        """
        bands = bands or list(range(self.in_channels))
        with rasterio.open(input_path) as src:
            # rasterio band indices are 1-based.
            arr = src.read([b + 1 for b in bands]).astype(np.float32)
            profile = src.profile.copy()

        arr = self._normalise(arr)

        print(
            f"SR inference model=sen2sr input={arr.shape} "
            f"factor={self.scaling_factor}x {input_path}"
        )
        sr = self.predict(arr)  # (C, H*sf, W*sf)
        print(
            f"Output shape {sr.shape} "
            f"range [{sr.min():.4f}, {sr.max():.4f}]"
        )

        # Same origin; pixel size shrinks by the SR factor in both axes.
        tf = profile["transform"]
        new_tf = tf * tf.scale(1.0 / self.scaling_factor, 1.0 / self.scaling_factor)

        profile.update(
            # Force GeoTIFF output: the source may be JP2/VRT, whose drivers
            # do not support LZW compression or direct creation.
            driver="GTiff",
            count=sr.shape[0],
            height=sr.shape[1],
            width=sr.shape[2],
            dtype="float32",
            transform=new_tf,
            compress="lzw",
        )
        # The source photometric interpretation (e.g. RGB/YCbCr) is not valid
        # for a 4-band float32 output.
        profile.pop("photometric", None)

        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        with rasterio.open(output_path, "w", **profile) as dst:
            dst.write(sr)

        sr_res = abs(tf.a) / self.scaling_factor
        print(f"Written: {output_path} (res={sr_res:.4f} m)")