File size: 2,717 Bytes
ea2435e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import importlib.util
import pathlib

import matplotlib.pyplot as plt
import safetensors.torch
import torch


def load_model_module(model_path: pathlib.Path):
    """Import the Python file at *model_path* as a standalone module.

    The file is executed under the module name ``"model"`` and the
    resulting module object is returned.
    """
    resolved = model_path.resolve()
    spec = importlib.util.spec_from_file_location("model", resolved)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


class GeneratorNormal(torch.nn.Module):
    """Adapter that sanitizes its input and reorders axes around ``model``.

    The wrapped model is called with axis 1 moved to the last position;
    its first output is returned with the original axis order restored.
    Inputs are clamped to [0, 1] and NaNs are replaced by 1.0 beforehand.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, X):
        # Clamp first: NaNs survive clamping and are mapped to 1.0 here.
        sanitized = torch.nan_to_num(torch.clamp(X, 0, 1), nan=1.0)
        # (B, d1, d2, d3, d4) -> (B, d2, d3, d4, d1): move axis 1 to the end.
        channels_last = sanitized.permute(0, 2, 3, 4, 1).contiguous()
        prediction = self.model(channels_last)[0]
        # Inverse permutation restores the caller's axis layout.
        return prediction.permute(0, 4, 1, 2, 3)


def example_data(path: pathlib.Path, *args, **kwargs):
    """Load the bundled example series from *path*.

    Reads ``example_data.safetensor`` and returns its ``example_data``
    tensor with a leading batch dimension added, cast to float32.
    """
    tensors = safetensors.torch.load_file(path / "example_data.safetensor")
    return tensors["example_data"].unsqueeze(0).float()


def trainable_model(path, device="cpu", *args, **kwargs):
    """Instantiate the Generator defined in ``path/model.py`` with its weights.

    The Generator is built with 4 input and 4 output channels and the state
    dict from ``path/model.safetensor`` is loaded into it before returning.
    """
    module = load_model_module(path / "model.py")
    generator = module.Generator(device=device, inputChannels=4, outputChannels=4)
    state = safetensors.torch.load_file(path / "model.safetensor")
    generator.load_state_dict(state)
    return generator


def compiled_model(path, device="cpu", *args, **kwargs):
    """Return an inference-ready, frozen ``GeneratorNormal`` wrapper.

    Delegates loading to :func:`trainable_model` (the original duplicated
    that function's body verbatim), then switches the model to eval mode
    and disables gradients on every parameter so it can be used for pure
    inference on *device*.
    """
    model = trainable_model(path, device)
    model.eval()
    # Freeze all weights: inference only, no autograd bookkeeping needed.
    for param in model.parameters():
        param.requires_grad_(False)
    return GeneratorNormal(model.to(device))


def display_results(path: pathlib.Path, device: str = "cpu", *args, **kwargs):
    """Render original vs. gap-filled timesteps as a two-row figure.

    Loads the frozen model and the example series from *path*, runs
    inference, and plots an RGB composite per timestep (originals on the
    top row, model output on the bottom). Returns the matplotlib figure.
    """
    generator = compiled_model(path, device)
    series = example_data(path)

    prediction = generator(series.to(device))

    # Drop the batch dimension and move to CPU for plotting.
    series = series.squeeze(0).detach().cpu()  # [T, C, H, W]
    prediction = prediction.squeeze(0).detach().cpu()

    timesteps = series.shape[0]
    # Assuming the 4-band channel order is B, G, R, ... — reversed for display.
    rgb = [2, 1, 0]

    fig, axs = plt.subplots(2, timesteps, figsize=(3 * timesteps, 6))

    for t in range(timesteps):
        for row, (source, label) in enumerate(
            ((series, "Original"), (prediction, "Filled"))
        ):
            image = source[t, rgb].permute(1, 2, 0).clamp(0, 1).numpy()
            # Brighten 3x for visibility; imshow clips values above 1.
            axs[row, t].imshow(image * 3)
            axs[row, t].axis("off")
            axs[row, t].set_title(f"{label} t={t}")

    plt.tight_layout()
    return fig