import importlib.util
import pathlib

import matplotlib.pyplot as plt
import safetensors.torch
import torch


def load_model_module(model_path: pathlib.Path):
    """Import the Python source file at ``model_path`` as a module and return it.

    The file is executed, so only point this at trusted code.

    Raises:
        ImportError: if no import spec/loader can be created for ``model_path``.
    """
    model_path = model_path.resolve()
    spec = importlib.util.spec_from_file_location("model", model_path)
    # spec_from_file_location returns None for paths it cannot handle (e.g. an
    # unrecognised suffix); fail with a clear message rather than an
    # AttributeError on ``None.loader``.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create an import spec for {model_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


class GeneratorNormal(torch.nn.Module):
    """Inference wrapper that sanitises the input and fixes the axis layout.

    The wrapped model receives the input with axis 1 moved to the last
    position, and the first item of its output has its last axis moved back to
    position 1, so callers see the same 5-D layout going in and coming out.
    """

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Run the wrapped model on a 5-D tensor ``X``.

        Values are clamped to [0, 1], then NaNs (which survive ``torch.clamp``)
        are replaced with 1.0. NOTE(review): ``display_results`` treats the
        layout as [B, T, C, H, W] — confirm against the model's contract.
        """
        X = torch.clamp(X, 0, 1)
        X = torch.nan_to_num(X, nan=1.0)
        # Move axis 1 to the end: [d0, d1, d2, d3, d4] -> [d0, d2, d3, d4, d1].
        X = X.permute(0, 2, 3, 4, 1).contiguous()
        # The model presumably returns a sequence whose first item is the
        # prediction (TODO confirm); undo the permute above on that item.
        return self.model(X)[0].permute(0, 4, 1, 2, 3)


def example_data(path: pathlib.Path, *args, **kwargs) -> torch.Tensor:
    """Load the ``example_data`` tensor from ``path / "example_data.safetensor"``.

    A leading axis of size 1 is prepended and the result is cast to float32.
    Extra positional/keyword arguments are accepted and ignored.
    """
    tensors = safetensors.torch.load_file(path / "example_data.safetensor")
    return tensors["example_data"][None].float()


def _load_generator(path: pathlib.Path, device: str) -> torch.nn.Module:
    """Build ``Generator`` from ``path / "model.py"`` and load ``model.safetensor``."""
    weights = safetensors.torch.load_file(path / "model.safetensor")
    model = load_model_module(path / "model.py").Generator(
        device=device, inputChannels=4, outputChannels=4
    )
    model.load_state_dict(weights)
    return model


def trainable_model(path: pathlib.Path, device: str = "cpu", *args, **kwargs):
    """Return the Generator with weights loaded, ready for further training.

    Unlike ``compiled_model`` it is not switched to eval mode, its parameters
    keep their ``requires_grad`` state, and it is not wrapped. Extra arguments
    are accepted and ignored.
    """
    return _load_generator(path, device)


def compiled_model(path: pathlib.Path, device: str = "cpu", *args, **kwargs):
    """Return a frozen, eval-mode Generator on ``device`` wrapped in ``GeneratorNormal``.

    Extra positional/keyword arguments are accepted and ignored.
    """
    model = _load_generator(path, device).eval()
    # Freeze every parameter: this wrapper is for inference only.
    model.requires_grad_(False)
    return GeneratorNormal(model.to(device))


def _to_display_rgb(frame: torch.Tensor, rgb_indices, gain: float = 3.0):
    """Turn one channels-first frame into a brightened HWC numpy image in [0, 1].

    ``gain`` is applied *before* clipping so matplotlib receives valid float RGB
    data (it would otherwise clip values above 1 itself and warn).
    """
    return (frame[rgb_indices] * gain).clamp(0, 1).permute(1, 2, 0).numpy()


def display_results(path: pathlib.Path, device: str = "cpu", *args, **kwargs):
    """Run the compiled model on the bundled example data and plot the result.

    Returns a matplotlib figure with the input frames on the top row and the
    gap-filled frames on the bottom row, one column per time step. Extra
    positional/keyword arguments are accepted and ignored.
    """
    model = compiled_model(path, device)
    s2_ts = example_data(path)

    # Parameters are already frozen; no_grad also guarantees that no autograd
    # graph is recorded for this pure-inference call.
    with torch.no_grad():
        gap_filled = model(s2_ts.to(device))

    # Drop the leading size-1 axis and move to CPU for plotting.
    # Assumed layout after squeeze: [T, C, H, W] — TODO confirm.
    s2_ts = s2_ts.squeeze(0).detach().cpu()
    gap_filled = gap_filled.squeeze(0).detach().cpu()

    num_timesteps = s2_ts.shape[0]
    # Assumes the 4 bands are ordered B, G, R, NIR — TODO confirm;
    # [2, 1, 0] then picks R, G, B for display.
    rgb_indices = [2, 1, 0]

    # squeeze=False keeps ``axs`` 2-D even when there is a single time step,
    # so ``axs[row, t]`` indexing always works.
    fig, axs = plt.subplots(
        2, num_timesteps, figsize=(3 * num_timesteps, 6), squeeze=False
    )
    rows = (("Original", s2_ts), ("Filled", gap_filled))
    for t in range(num_timesteps):
        for row, (label, series) in enumerate(rows):
            ax = axs[row, t]
            ax.imshow(_to_display_rgb(series[t], rgb_indices))
            ax.axis("off")
            ax.set_title(f"{label} t={t}")
    fig.tight_layout()
    return fig