sen2sr / load.py
RhodWeo's picture
Add SEN2SRLite RGBN x4 with WEO standard interface
83a44e8 verified
import pathlib
import safetensors.torch
import matplotlib.pyplot as plt
from sen2sr.models.opensr_baseline.cnn import CNNSR
from sen2sr.models.tricks import HardConstraint
from sen2sr.nonreference import srmodel
# MLSTAC API -----------------------------------------------------------------------
def example_data(path: pathlib.Path, *args, **kwargs):
    """Load the example tensors shipped alongside the model files.

    Reads ``example_data.safetensor`` from the directory ``path`` and returns
    the entries stored under the ``"lr"`` and ``"hr"`` keys, in that order.
    Extra positional and keyword arguments are accepted and ignored.
    """
    tensors = safetensors.torch.load_file(path / "example_data.safetensor")
    return tensors["lr"], tensors["hr"]
def trainable_model(path, device: str = "cpu", *args, **kwargs):
    """Build the super-resolution model from the weight files under ``path``.

    The network is returned as constructed and loaded: this function does not
    call ``eval()`` and does not change ``requires_grad`` on any parameter.

    Args:
        path: Directory containing ``model.safetensor`` and
            ``hard_constraint.safetensor``. May be a path object or a plain
            string.
        device: Device the network and the hard-constraint mask are moved to.
        *args, **kwargs: Accepted and ignored.

    Returns:
        The object produced by ``srmodel(sr_model, hard_constraint, device)``.
    """
    # ``path`` is unannotated, so callers may pass a plain string; coerce it
    # here because ``str / str`` would raise a TypeError. Existing path
    # objects (including subclasses) are used unchanged.
    model_dir = path if isinstance(path, pathlib.PurePath) else pathlib.Path(path)

    # Load model parameters
    sr_model_weights = safetensors.torch.load_file(model_dir / "model.safetensor")
    # NOTE(review): the positional CNNSR arguments must match the architecture
    # the checkpoint was saved with; their meaning is defined in CNNSR itself.
    sr_model = CNNSR(4, 4, 24, 4, True, False, 6)
    sr_model.load_state_dict(sr_model_weights)
    sr_model.to(device)

    # Load HardConstraint
    hard_constraint_weights = safetensors.torch.load_file(
        model_dir / "hard_constraint.safetensor"
    )
    hard_constraint = HardConstraint(
        low_pass_mask=hard_constraint_weights["weights"].to(device),
        device=device,
    )

    return srmodel(sr_model, hard_constraint, device)
def compiled_model(path, device: str = "cpu", *args, **kwargs):
    """Build the super-resolution model prepared for inference.

    Same loading steps as for the trainable variant, but the network is
    switched to evaluation mode and all of its parameters are frozen
    (``requires_grad = False``) before it is wrapped.

    Args:
        path: Directory containing ``model.safetensor`` and
            ``hard_constraint.safetensor``. May be a path object or a plain
            string.
        device: Device the network and the hard-constraint mask are moved to.
        *args, **kwargs: Accepted and ignored.

    Returns:
        The object produced by ``srmodel(sr_model, hard_constraint, device)``.
    """
    # ``path`` is unannotated, so callers may pass a plain string; coerce it
    # here because ``str / str`` would raise a TypeError. Existing path
    # objects (including subclasses) are used unchanged.
    model_dir = path if isinstance(path, pathlib.PurePath) else pathlib.Path(path)

    # Load model parameters
    sr_model_weights = safetensors.torch.load_file(model_dir / "model.safetensor")
    # NOTE(review): the positional CNNSR arguments must match the architecture
    # the checkpoint was saved with; their meaning is defined in CNNSR itself.
    sr_model = CNNSR(4, 4, 24, 4, True, False, 6)
    sr_model.load_state_dict(sr_model_weights)
    sr_model = sr_model.eval()
    sr_model.to(device)

    # Freeze every parameter of the network in one call (equivalent to
    # setting ``requires_grad = False`` on each parameter individually).
    sr_model.requires_grad_(False)

    # Load HardConstraint
    hard_constraint_weights = safetensors.torch.load_file(
        model_dir / "hard_constraint.safetensor"
    )
    hard_constraint = HardConstraint(
        low_pass_mask=hard_constraint_weights["weights"].to(device),
        device=device,
    )

    return srmodel(sr_model, hard_constraint, device)
def display_results(path: pathlib.Path, device: str = "cpu", *args, **kwargs):
    """Run the model on the example data and plot input, output and reference.

    Loads the inference model and the example tensors from ``path``, runs the
    model on the low-resolution input, and draws three panels side by side:
    the input ("Sentinel-2"), the model output ("Super-Resolved") and the
    high-resolution reference ("True HR"). Only the first sample and its
    first three channels are shown, cropped to a fixed window.

    Args:
        path: Directory holding the model weights and the example data.
        device: Device the model runs on.
        *args, **kwargs: Accepted and ignored.

    Returns:
        The matplotlib figure containing the three panels.
    """
    scale = 4  # the high-resolution crop window is the low-resolution one x4
    gain = 3  # brightness multiplier applied to every panel

    def _preview(batch, window):
        # First sample, first three channels, converted to channels-last.
        # detach() keeps .numpy() from raising if the tensor is attached to
        # an autograd graph; it is a no-op otherwise.
        image = batch[0, 0:3].detach().cpu().numpy().transpose(1, 2, 0)
        image = image[window, window] * gain
        # imshow clips float RGB data to [0, 1] anyway (and warns about it);
        # clipping here produces the same picture without the warning.
        # Non-float data is left as is.
        return image.clip(0, 1) if image.dtype.kind == "f" else image

    # Load model
    model = compiled_model(path, device)

    # Load data
    lr, hr = example_data(path)

    # Run model
    SuperX = model(lr.to(device))

    # Display results
    lr_slice = slice(16, 32 + 80)
    hr_slice = slice(lr_slice.start * scale, lr_slice.stop * scale)

    fig, ax = plt.subplots(1, 3, figsize=(12, 4))
    ax[0].imshow(_preview(lr, lr_slice))
    ax[0].set_title("Sentinel-2")
    ax[1].imshow(_preview(SuperX, hr_slice))
    ax[1].set_title("Super-Resolved")
    ax[2].imshow(_preview(hr, hr_slice))
    ax[2].set_title("True HR")
    for a in ax:
        a.axis("off")
    fig.tight_layout()
    return fig