import torch
import pathlib
import importlib.util
import safetensors.torch
import matplotlib.pyplot as plt
import math
from typing import Literal


def load_model_module(model_path: pathlib.Path):
    """Execute a Python source file and return it as a module object.

    The module is created under the name ``"model"`` via
    ``importlib.util.module_from_spec`` and is not inserted into
    ``sys.modules``, so every call yields an independent module.

    Args:
        model_path: Path (or path string) of the ``.py`` file to load.

    Returns:
        The executed module.

    Raises:
        ImportError: If no loader can be found for ``model_path``.
    """
    model_path = pathlib.Path(model_path).resolve()
    spec = importlib.util.spec_from_file_location("model", model_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load a Python module from {model_path}")
    model = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(model)
    return model


class EnsembleModel(torch.nn.Module):
    """Ensemble of six models whose outputs are reduced across models.

    Every sub-model receives the same input. Their outputs are concatenated
    along dim 1 and then handled according to ``mode``:

    * ``"max"`` / ``"mean"`` / ``"min"``: ``forward`` returns
      ``(output_probs, std_output)``. ``std_output`` is the across-model
      sample standard deviation divided by its maximum possible value for
      inputs in ``[0, 1]``.
    * ``"none"``: ``forward`` returns the raw concatenated tensor.

    NOTE(review): assumes each sub-model returns a (batch, 1, H, W) tensor of
    probabilities in [0, 1]. Confirm against model.py / c2r1km.py.
    """

    def __init__(self, model1, model2, model3, model4, model5, model6, mode="max"):
        super(EnsembleModel, self).__init__()
        if mode not in ["min", "mean", "max", "none"]:
            raise ValueError("Mode must be 'none', 'min', 'mean', or 'max'.")
        # Assigning each model to its own attribute registers it as a
        # submodule, so .to(), .eval() and state_dict() reach all of them.
        self.model1 = model1
        self.model2 = model2
        self.model3 = model3
        self.model4 = model4
        self.model5 = model5
        self.model6 = model6
        # Plain list used only for ordered iteration in forward().
        self.models = [model1, model2, model3, model4, model5, model6]
        self.mode = mode

    def forward(self, x):
        """Run every sub-model on ``x`` and combine the results.

        Args:
            x: Input tensor, passed unchanged to each sub-model.

        Returns:
            ``torch.cat(outputs, dim=1)`` when ``mode == "none"``. Otherwise a
            tuple ``(output_probs, std_output)``; each is reduced over dim 1
            and then passed through ``squeeze()``.

        Raises:
            ValueError: If ``self.mode`` was changed to an unsupported value.
        """
        outputs = [model(x) for model in self.models]
        # Concatenate once; reused by both the reduction and the spread.
        stacked = torch.cat(outputs, dim=1)

        if self.mode == "none":
            return stacked
        if self.mode == "max":
            # torch.max/min with `dim` return (values, indices); keep values.
            output_probs = torch.max(stacked, dim=1).values.squeeze()
        elif self.mode == "mean":
            # torch.mean returns a plain tensor. Indexing it with [0] would
            # select only the first batch item, so no indexing here.
            output_probs = torch.mean(stacked, dim=1).squeeze()
        elif self.mode == "min":
            output_probs = torch.min(stacked, dim=1).values.squeeze()
        else:
            raise ValueError("Mode must be 'none', 'min', 'mean', or 'max'.")

        # Across-model spread as a rough uncertainty measure.
        # torch.std defaults to the sample (N - 1) estimator. For N values in
        # [0, 1] its maximum is sqrt(0.25 * N / (N - 1)), reached (for even N)
        # when half the values are 0 and half are 1. Dividing by it maps the
        # spread into [0, 1].
        std_output = torch.std(stacked, dim=1).squeeze()
        n_models = len(outputs)
        std_max = math.sqrt(0.25 * n_models / (n_models - 1))
        std_output = std_output / std_max

        return output_probs, std_output


# MLSTAC API -----------------------------------------------------------------------

def example_data(path: pathlib.Path, device="cpu", *args, **kwargs):
    """Load the bundled example image as a float batch of one.

    Args:
        path: Directory containing ``example_data.safetensor``.
        device: Device the returned tensor is moved to.

    Returns:
        ``sample["image"]`` cast to float, with a leading batch dim added.
    """
    data_f = pathlib.Path(path) / "example_data.safetensor"
    sample = safetensors.torch.load_file(data_f)
    return sample["image"].float().unsqueeze(0).to(device)


def trainable_model(*args, **kwargs):
    """Training mode is not supported: print a notice and return None."""
    print("The model is not available in training mode.")
    return None


def _load_frozen(model: torch.nn.Module, weights_f: pathlib.Path, device) -> torch.nn.Module:
    """Load safetensors weights into ``model``, move it to ``device``, freeze it, set eval."""
    model.load_state_dict(safetensors.torch.load_file(weights_f))
    model = model.to(device)
    for param in model.parameters():
        param.requires_grad = False
    return model.eval()


def compiled_model(
    path,
    device: str = "cpu",
    mode: Literal["min", "mean", "max", "none"] = "max",
    *args,
    **kwargs,
):
    """Build the frozen six-model cloud ensemble from files under ``path``.

    Args:
        path: Directory holding ``model.py``, ``c2r1km.py`` and the six
            ``*.safetensor`` weight files.
        device: Device the sub-models are moved to.
        mode: Reduction passed through to ``EnsembleModel``.

    Returns:
        An ``EnsembleModel`` whose sub-models are in eval mode, with
        ``requires_grad`` disabled on every parameter.
    """
    path = pathlib.Path(path)
    # Execute model.py once and reuse it for the five classes taken from it.
    arch = load_model_module(path / "model.py")

    # Model 1 (DeepLabV3Branch + PixelWise)
    model1 = _load_frozen(
        arch.CombinedNet4(classes=1, benchmark=True, in_channels=4),
        path / "1dpwdeeplabv3.safetensor",
        device,
    )
    # Model 2 (SegformerBranch + PixelWise)
    model2 = _load_frozen(
        arch.CombinedNet(classes=1, benchmark=True),
        path / "1dpwseg.safetensor",
        device,
    )
    # Model 3 (UNetPlusPlusBranch + PixelWise)
    model3 = _load_frozen(
        arch.CombinedNet3(classes=1, benchmark=True),
        path / "1dpwunetpp.safetensor",
        device,
    )
    # Model 4 (UNetBranch)
    model4 = _load_frozen(
        arch.UNetBranch(classes=1, benchmark=True),
        path / "unet.safetensor",
        device,
    )
    # Model 5 (UNetPlusPlusBranch)
    model5 = _load_frozen(
        arch.UNetPlusPlusBranch(classes=1, benchmark=True),
        path / "unetpp.safetensor",
        device,
    )
    # Model 6 (C2R1KM), defined in its own source file.
    c2r1km = load_model_module(path / "c2r1km.py")
    model6 = _load_frozen(
        c2r1km.CloudMaskOne(
            hidden_layer_sizes=(128, 112),
            activation='relu',
            last_activation='sigmoid',
            dropout_rate=0.1,
            input_dim=40,
            batch_norm=False,
        ),
        path / "c2r1km.safetensor",
        device,
    )

    return EnsembleModel(model1, model2, model3, model4, model5, model6, mode=mode)


def display_results(path: pathlib.Path, device: str = "cpu", mode: Literal["min", "mean", "max"] = "max", *args, **kwargs):
    """Run the ensemble on the example image and plot the results.

    Args:
        path: Directory with the model files and ``example_data.safetensor``.
        device: Device to run the ensemble on.
        mode: Ensemble reduction. ``"none"`` is rejected because it yields
            raw outputs rather than a (probability, uncertainty) pair.

    Returns:
        A matplotlib Figure with three panels: input, cloud probability,
        uncertainty.

    Raises:
        ValueError: If ``mode == "none"``.
    """
    if mode == "none":
        raise ValueError(
            "display_results needs mode 'min', 'mean' or 'max'; "
            "'none' returns raw per-model outputs."
        )

    # Load model
    model = compiled_model(path, device, mode=mode)

    # Load data
    probav = example_data(path)

    # Run model (weights are frozen; no autograd graph is needed)
    with torch.no_grad():
        cloudprob, uncertainty = model(probav.float().to(device))

    # Display results.
    # NOTE(review): assumes bands 2, 1, 0 form an RGB-like composite - confirm.
    fig, ax = plt.subplots(1, 3, figsize=(12, 4))
    ax[0].imshow(probav[0, [2, 1, 0]].cpu().detach().numpy().transpose(1, 2, 0))
    ax[0].set_title("Input")
    ax[1].imshow(cloudprob.cpu().detach().numpy(), cmap="gray")
    ax[1].set_title("Cloud Probability")
    ax[2].imshow(uncertainty.cpu().detach().numpy(), cmap="gray")
    ax[2].set_title("Uncertainty")
    for a in ax:
        a.axis("off")
    fig.tight_layout()
    return fig