|
|
import importlib.util |
|
|
import pathlib |
|
|
|
|
|
import matplotlib.pyplot as plt |
|
|
import safetensors.torch |
|
|
import torch |
|
|
|
|
|
|
|
|
def load_model_module(model_path: pathlib.Path):
    """Execute a Python source file and return it as a module object.

    The file is loaded under the module name ``"model"``. This function does
    not add the module to ``sys.modules``.

    Args:
        model_path: Path to the ``.py`` file to load (``str`` or path-like
            values are accepted and converted to ``pathlib.Path``).

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec or loader can be created for the file
            (e.g. the suffix is not a recognised Python source suffix).
    """
    model_path = pathlib.Path(model_path).resolve()

    spec = importlib.util.spec_from_file_location("model", model_path)
    # spec_from_file_location returns None when no loader matches the file;
    # report that clearly instead of failing later with AttributeError.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load a Python module from {model_path}")

    model = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(model)
    return model
|
|
|
|
|
|
|
|
class GeneratorNormal(torch.nn.Module):
    """Wrap a generator with input sanitising and axis reordering.

    The wrapper clamps inputs to [0, 1], replaces NaNs with 1.0, moves axis 1
    of the 5-D input to the last position, calls the wrapped model, and moves
    the last axis of the model's first output back to position 1.
    """

    def __init__(self, model):
        super().__init__()
        # Wrapped network; its return value is indexed with [0] in forward.
        self.model = model

    def forward(self, X):
        """Run the wrapped model on a sanitised, axis-reordered copy of ``X``.

        Args:
            X: 5-D tensor. Presumably laid out as (batch, time, channel,
                height, width); TODO confirm against the data producer.

        Returns:
            The first element of the wrapped model's output, with its last
            axis moved to position 1.
        """
        # Clamp first: NaNs survive clamping and are then mapped to 1.0.
        sanitised = torch.nan_to_num(X.clamp(0, 1), nan=1.0)
        # Move axis 1 to the end before handing the tensor to the model.
        model_input = sanitised.permute(0, 2, 3, 4, 1).contiguous()
        primary_output = self.model(model_input)[0]
        # Move the trailing axis back to position 1.
        return primary_output.permute(0, 4, 1, 2, 3)
|
|
|
|
|
|
|
|
def example_data(path: pathlib.Path, *args, **kwargs):
    """Load the bundled example tensor and prepend a batch axis.

    Args:
        path: Directory containing ``example_data.safetensor``; ``str`` or
            path-like values are accepted.
        *args, **kwargs: Accepted for interface compatibility; unused.

    Returns:
        The ``"example_data"`` tensor from the file with a new leading axis of
        size 1, converted to float32.
    """
    data_file = pathlib.Path(path) / "example_data.safetensor"
    tensors = safetensors.torch.load_file(data_file)
    return tensors["example_data"].unsqueeze(0).float()
|
|
|
|
|
|
|
|
def trainable_model(path, device="cpu", *args, **kwargs):
    """Build the ``Generator`` defined in ``path/model.py`` and load its weights.

    The model is not switched to eval mode and its parameters are not frozen,
    so it can be fine-tuned.

    Args:
        path: Directory containing ``model.py`` and ``model.safetensor``;
            ``str`` or path-like values are accepted.
        device: Device passed to the ``Generator`` constructor and used as the
            final placement of the returned model.
        *args, **kwargs: Accepted for interface compatibility; unused.

    Returns:
        The ``Generator`` instance (4 input / 4 output channels) with weights
        loaded, placed on ``device``.
    """
    path = pathlib.Path(path)
    weights = safetensors.torch.load_file(path / "model.safetensor")
    model = load_model_module(path / "model.py").Generator(
        device=device, inputChannels=4, outputChannels=4
    )
    model.load_state_dict(weights)
    # Move explicitly, as compiled_model does, so ``device`` is honoured even
    # if the Generator constructor does not place its own parameters.
    return model.to(device)
|
|
|
|
|
|
|
|
def compiled_model(path, device="cpu", *args, **kwargs):
    """Return a frozen, eval-mode generator wrapped in ``GeneratorNormal``.

    Args:
        path: Directory containing ``model.py`` and ``model.safetensor``;
            ``str`` or path-like values are accepted.
        device: Device on which the model is placed.
        *args, **kwargs: Accepted for interface compatibility; unused.

    Returns:
        A ``GeneratorNormal`` wrapping the loaded model, which is in eval mode
        with ``requires_grad=False`` on every parameter.
    """
    # Reuse the shared loading path instead of duplicating it here.
    model = trainable_model(pathlib.Path(path), device)
    model.eval()
    # Freeze all weights so no gradients are tracked for them at inference.
    model.requires_grad_(False)
    return GeneratorNormal(model.to(device))
|
|
|
|
|
|
|
|
def display_results(path: pathlib.Path, device: str = "cpu", *args, **kwargs):
    """Plot the example series next to the model's gap-filled output.

    Args:
        path: Directory holding the model files and ``example_data.safetensor``.
        device: Device on which inference runs.
        *args, **kwargs: Accepted for interface compatibility; unused.

    Returns:
        A matplotlib Figure with two rows ("Original" and "Filled") and one
        column per time step (the first axis after removing the batch axis).
    """
    model = compiled_model(path, device)
    s2_ts = example_data(path)

    # One-off forward pass; no autograd graph is needed.
    with torch.no_grad():
        gap_filled = model(s2_ts.to(device))

    s2_ts = s2_ts.squeeze(0).cpu()
    gap_filled = gap_filled.squeeze(0).cpu()

    num_timesteps = s2_ts.shape[0]
    # Channels 2, 1, 0 are shown as red, green, blue; presumably the band order
    # is (blue, green, red, ...) — TODO confirm against the data source.
    rgb_indices = [2, 1, 0]
    # Display gain. Apply it BEFORE clipping so imshow receives valid [0, 1]
    # RGB floats and does not emit clipping warnings.
    brightness = 3
    rows = (("Original", s2_ts), ("Filled", gap_filled))

    # squeeze=False keeps ``axs`` 2-D even for a single time step; otherwise
    # plt.subplots returns a 1-D array and ``axs[row, t]`` raises IndexError.
    fig, axs = plt.subplots(
        2, num_timesteps, figsize=(3 * num_timesteps, 6), squeeze=False
    )

    for t in range(num_timesteps):
        for row, (label, series) in enumerate(rows):
            rgb = (series[t, rgb_indices] * brightness).clamp(0, 1)
            ax = axs[row, t]
            ax.imshow(rgb.permute(1, 2, 0).numpy())
            ax.axis("off")
            ax.set_title(f"{label} t={t}")

    plt.tight_layout()
    return fig
|
|
|