# GANFilling / load.py
# csaybar's picture
# Upload 5 files
# ea2435e verified
import importlib.util
import pathlib
import matplotlib.pyplot as plt
import safetensors.torch
import torch
def load_model_module(model_path: pathlib.Path):
    """Import the Python file at ``model_path`` and return it as a module.

    The file is executed under the fixed module name ``"model"``. The module
    object is only returned to the caller; this function does not add it to
    ``sys.modules``.

    NOTE: executing the file runs its top-level code, so only point this at
    files you trust.

    Args:
        model_path: Location of the Python source file. ``str`` and other
            path-like values are accepted and converted to ``pathlib.Path``.

    Returns:
        The freshly executed module object.

    Raises:
        ImportError: If no import spec or loader can be created for the path
            (for example, the file name has no importable suffix).
    """
    model_path = pathlib.Path(model_path).resolve()
    spec = importlib.util.spec_from_file_location("model", model_path)
    # spec_from_file_location returns None when no loader matches the file
    # suffix; fail with a clear message instead of an AttributeError later.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create an import spec for {model_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
class GeneratorNormal(torch.nn.Module):
    """Wrap a model with input sanitising and axis reordering.

    ``forward`` clamps the input to ``[0, 1]``, replaces NaNs with ``1.0``,
    moves axis 1 to the last position, calls the wrapped model, takes element
    ``0`` of what the model returns, and moves the last axis back to position 1.
    """

    def __init__(self, model):
        """Keep a reference to ``model``, the module ``forward`` delegates to."""
        super().__init__()
        self.model = model

    def forward(self, X):
        """Run the wrapped model on a sanitised, re-ordered copy of ``X``.

        Args:
            X: A 5-D tensor. Presumably laid out as (batch, time, channel,
                height, width) -- TODO confirm against the wrapped model.

        Returns:
            Element ``0`` of the wrapped model's output, permuted so that its
            last axis becomes axis 1.
        """
        # Clamp first (NaN survives clamping), then turn the NaNs into 1.0.
        cleaned = X.clamp(0, 1).nan_to_num(nan=1.0)
        # Axis 1 goes last; ``contiguous`` materialises the new memory layout.
        model_input = cleaned.permute(0, 2, 3, 4, 1).contiguous()
        first_output = self.model(model_input)[0]
        # Inverse of the permutation above: the last axis returns to position 1.
        return first_output.permute(0, 4, 1, 2, 3)
def example_data(path: pathlib.Path, *args, **kwargs):
    """Load the example tensor stored alongside the model.

    Reads ``example_data.safetensor`` from ``path``, takes the tensor saved
    under the key ``"example_data"``, prepends a length-1 leading axis and
    converts it with ``.float()``.

    Args:
        path: Directory containing ``example_data.safetensor``.
        *args: Accepted and ignored.
        **kwargs: Accepted and ignored.

    Returns:
        The example tensor with one extra leading axis, as a float tensor.
    """
    tensors = safetensors.torch.load_file(path / "example_data.safetensor")
    # Indexing with ``None`` inserts the new leading axis.
    batched = tensors["example_data"][None]
    return batched.float()
def trainable_model(path, device="cpu", *args, **kwargs):
    """Build the ``Generator`` defined in ``model.py`` and load its weights.

    The weights are read from ``model.safetensor`` and the ``Generator`` class
    from ``model.py``, both inside ``path``. The generator is created with 4
    input and 4 output channels. Parameters keep their gradient flags, so the
    returned model can be trained further.

    Args:
        path: Directory holding ``model.py`` and ``model.safetensor``. ``str``
            and other path-like values are accepted.
        device: Device passed to the ``Generator`` constructor and the device
            the returned model is moved to.
        *args: Accepted and ignored.
        **kwargs: Accepted and ignored.

    Returns:
        The ``Generator`` instance with the saved weights, on ``device``.
    """
    path = pathlib.Path(path)
    weights = safetensors.torch.load_file(path / "model.safetensor")
    generator_cls = load_model_module(path / "model.py").Generator
    model = generator_cls(device=device, inputChannels=4, outputChannels=4)
    model.load_state_dict(weights)
    # The weights file is loaded without a target device, so explicitly move
    # the model to ``device``, as compiled_model does.
    return model.to(device)
def compiled_model(path, device="cpu", *args, **kwargs):
    """Build a frozen, inference-ready generator wrapped in ``GeneratorNormal``.

    Loads the weights from ``model.safetensor`` and the ``Generator`` class from
    ``model.py`` (both inside ``path``), switches the generator to eval mode,
    disables gradients on every parameter, moves it to ``device`` and wraps it
    in ``GeneratorNormal``. Despite the name, no compilation step happens here.

    Args:
        path: Directory holding ``model.py`` and ``model.safetensor``.
        device: Device passed to the ``Generator`` constructor and the device
            the generator is moved to.
        *args: Accepted and ignored.
        **kwargs: Accepted and ignored.

    Returns:
        A ``GeneratorNormal`` wrapping the frozen generator.
    """
    state = safetensors.torch.load_file(path / "model.safetensor")
    generator_cls = load_model_module(path / "model.py").Generator
    net = generator_cls(device=device, inputChannels=4, outputChannels=4)
    net.load_state_dict(state)

    # Inference configuration: eval mode and no gradient tracking.
    net = net.eval()
    for weight in net.parameters():
        weight.requires_grad = False

    return GeneratorNormal(net.to(device))
def display_results(path: pathlib.Path, device: str = "cpu", *args, **kwargs):
    """Plot the example time series above the model's gap-filled output.

    Loads the frozen model and the example data from ``path``, runs the model
    once, and draws one column per timestep: the original frames in the top row
    and the gap-filled frames in the bottom row. Channels 2, 1, 0 are shown as
    RGB with a 3x brightness gain.

    Args:
        path: Directory holding the model files and the example data.
        device: Device on which the model is run.
        *args: Accepted and ignored.
        **kwargs: Accepted and ignored.

    Returns:
        The matplotlib figure containing the 2 x T grid of images.
    """
    # Load model
    model = compiled_model(path, device)

    # Load data
    s2_ts = example_data(path)

    # Run model (inference only, so autograd bookkeeping is switched off)
    with torch.no_grad():
        gap_filled = model(s2_ts.to(device))

    # Convert to CPU and detach for plotting
    s2_ts = s2_ts.squeeze(0).detach().cpu()  # [T, C, H, W]
    gap_filled = gap_filled.squeeze(0).detach().cpu()

    num_timesteps = s2_ts.shape[0]
    rgb_indices = [2, 1, 0]  # Assuming RGB is BGR in channel order (4 bands)
    brightness = 3  # display gain applied to both rows

    # squeeze=False keeps ``axs`` two-dimensional even for a single timestep,
    # so ``axs[row, t]`` is always valid.
    fig, axs = plt.subplots(
        2, num_timesteps, figsize=(3 * num_timesteps, 6), squeeze=False
    )

    rows = ((s2_ts, "Original"), (gap_filled, "Filled"))
    for t in range(num_timesteps):
        for row, (series, label) in enumerate(rows):
            # Apply the gain before clamping so imshow receives values in
            # [0, 1]; the rendered image matches clamping first and letting
            # imshow clip, but without its clipping warning.
            rgb = (series[t, rgb_indices] * brightness).clamp(0, 1)
            ax = axs[row, t]
            ax.imshow(rgb.permute(1, 2, 0).numpy())
            ax.axis("off")
            ax.set_title(f"{label} t={t}")

    plt.tight_layout()
    return fig