File size: 2,715 Bytes
83a44e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import pathlib
import safetensors.torch
import matplotlib.pyplot as plt
from sen2sr.models.opensr_baseline.cnn import CNNSR
from sen2sr.models.tricks import HardConstraint
from sen2sr.nonreference import srmodel


# MLSTAC API -----------------------------------------------------------------------
def example_data(path: pathlib.Path, *args, **kwargs):
    """Load the bundled example low-/high-resolution pair.

    Args:
        path: Directory containing ``example_data.safetensor``. A plain
            ``str`` is also accepted and converted to ``pathlib.Path``.
        *args, **kwargs: Accepted for MLSTAC API compatibility; ignored.

    Returns:
        A ``(lr, hr)`` tuple: the tensors stored under the ``"lr"`` and
        ``"hr"`` keys of the safetensors file. NOTE(review): layout is
        presumably (batch, bands, H, W) given how ``display_results``
        indexes them — confirm against the data file.
    """
    # Coerce so callers passing a str path still work with the `/` operator.
    data_f = pathlib.Path(path) / "example_data.safetensor"
    sample = safetensors.torch.load_file(data_f)
    return sample["lr"], sample["hr"]

def trainable_model(path, device: str = "cpu", *args, **kwargs):
    """Build the super-resolution model with its parameters left trainable.

    Args:
        path: Directory containing ``model.safetensor`` and
            ``hard_constraint.safetensor``. ``str`` or ``pathlib.Path``.
        device: Torch device the network and hard-constraint mask are moved to.
        *args, **kwargs: Accepted for MLSTAC API compatibility; ignored.

    Returns:
        ``srmodel`` wrapping the CNNSR network and its HardConstraint.
    """
    # Coerce so callers passing a str path still work with the `/` operator.
    path = pathlib.Path(path)

    # Load network weights, then move the network to the target device.
    sr_model_weights = safetensors.torch.load_file(path / "model.safetensor")
    # NOTE(review): these positional hyper-parameters must match the
    # architecture that produced model.safetensor; their meaning is defined
    # by CNNSR's signature — confirm there before changing any of them.
    sr_model = CNNSR(4, 4, 24, 4, True, False, 6)
    sr_model.load_state_dict(sr_model_weights)
    sr_model.to(device)

    # The hard constraint's mask is stored under the "weights" key.
    hard_constraint_weights = safetensors.torch.load_file(
        path / "hard_constraint.safetensor"
    )
    hard_constraint = HardConstraint(
        low_pass_mask=hard_constraint_weights["weights"].to(device),
        device=device,
    )

    return srmodel(sr_model, hard_constraint, device)


def compiled_model(path, device: str = "cpu", *args, **kwargs):
    """Build the super-resolution model frozen for inference.

    Same as ``trainable_model`` except that the network is switched to eval
    mode and all of its parameters have ``requires_grad`` disabled.

    Args:
        path: Directory containing ``model.safetensor`` and
            ``hard_constraint.safetensor``. ``str`` or ``pathlib.Path``.
        device: Torch device the network and hard-constraint mask are moved to.
        *args, **kwargs: Accepted for MLSTAC API compatibility; ignored.

    Returns:
        ``srmodel`` wrapping the frozen CNNSR network and its HardConstraint.
    """
    # Coerce so callers passing a str path still work with the `/` operator.
    path = pathlib.Path(path)

    # Load network weights.
    sr_model_weights = safetensors.torch.load_file(path / "model.safetensor")
    # NOTE(review): these positional hyper-parameters must match the
    # architecture that produced model.safetensor; their meaning is defined
    # by CNNSR's signature — confirm there before changing any of them.
    sr_model = CNNSR(4, 4, 24, 4, True, False, 6)
    sr_model.load_state_dict(sr_model_weights)

    # Freeze for inference: eval mode, no gradients on any parameter.
    sr_model.eval()
    sr_model.to(device)
    sr_model.requires_grad_(False)

    # The hard constraint's mask is stored under the "weights" key.
    hard_constraint_weights = safetensors.torch.load_file(
        path / "hard_constraint.safetensor"
    )
    hard_constraint = HardConstraint(
        low_pass_mask=hard_constraint_weights["weights"].to(device),
        device=device,
    )

    return srmodel(sr_model, hard_constraint, device)


def display_results(path: pathlib.Path, device: str = "cpu", *args, **kwargs):
    """Super-resolve the bundled example and plot it next to LR and HR.

    Args:
        path: Directory with the model weights and ``example_data.safetensor``.
        device: Torch device used to run the model.
        *args, **kwargs: Accepted for MLSTAC API compatibility; ignored.

    Returns:
        A matplotlib Figure with three panels: the Sentinel-2 input, the
        super-resolved output and the true high-resolution image, each
        cropped to the same ground area and shown as brightened RGB.
    """
    import torch  # only needed here, to disable autograd during inference

    model = compiled_model(path, device)
    lr, hr = example_data(path)

    # Inference only: avoid building an autograd graph, and guarantee the
    # output does not require grad so `.numpy()` below cannot fail.
    with torch.no_grad():
        super_x = model(lr.to(device))

    def _to_rgb(tensor):
        """First sample's first three channels as an (H, W, 3) numpy array."""
        return tensor[0, 0:3].detach().cpu().numpy().transpose(1, 2, 0)

    scale = 4  # LR -> SR/HR upscaling factor assumed by the crop below
    gain = 3   # brightness boost so the imagery is visible when plotted
    lr_slice = slice(16, 112)
    hr_slice = slice(lr_slice.start * scale, lr_slice.stop * scale)

    panels = (
        (_to_rgb(lr)[lr_slice, lr_slice], "Sentinel-2"),
        (_to_rgb(super_x)[hr_slice, hr_slice], "Super-Resolved"),
        (_to_rgb(hr)[hr_slice, hr_slice], "True HR"),
    )

    fig, ax = plt.subplots(1, 3, figsize=(12, 4))
    for axis, (image, title) in zip(ax, panels):
        image = image * gain
        # imshow clips float RGB to [0, 1] anyway but warns each time;
        # clipping explicitly gives the same picture without the warning.
        if image.dtype.kind == "f":
            image = image.clip(0, 1)
        axis.imshow(image)
        axis.set_title(title)
        axis.axis("off")
    fig.tight_layout()
    return fig