import pathlib

import matplotlib.pyplot as plt
import safetensors.torch
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
import torch.nn.functional as F


class SegformerBranch(nn.Module):
    """Thin wrapper around ``smp.Segformer`` with a ``mobilenet_v2`` encoder.

    The encoder is created without pretrained weights
    (``encoder_weights=None``); parameters are expected to come from a
    state dict loaded afterwards.
    """

    def __init__(self, in_channels=4, classes=4):
        super().__init__()
        # NOTE: the attribute name is part of the state-dict key prefix.
        self.segformer = smp.Segformer(
            encoder_name="mobilenet_v2",
            encoder_weights=None,
            in_channels=in_channels,
            classes=classes,
        )

    def forward(self, x):
        """Return the Segformer output for ``x``."""
        return self.segformer(x)


class PixelWiseNet(nn.Module):
    """Per-pixel network made of three bias-free 1x1 convolutions.

    Every convolution uses ``kernel_size=1``, so each output pixel depends
    only on the input pixel at the same location. The first two convolutions
    are followed by batch normalisation and ReLU; the last one is linear.
    """

    def __init__(self, in_channels=4, out_channels=4, base_channels=32):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, base_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(base_channels)
        self.conv2 = nn.Conv2d(base_channels, base_channels, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(base_channels)
        self.conv3 = nn.Conv2d(base_channels, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        """Apply conv-BN-ReLU twice, then the final 1x1 projection."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        return self.conv3(x)


class CombinedNet(nn.Module):
    """Sum of a Segformer branch and a pixel-wise branch, fused by a 1x1 conv.

    Args:
        in_channels: Number of input channels fed to both branches.
        classes: Number of output channels of both branches and of the
            fusion convolution.
        base_channels: Hidden width of the pixel-wise branch.
        benchmark: When true, ``forward`` applies a sigmoid to the fused
            output; otherwise the raw fusion output is returned.
    """

    def __init__(self, in_channels=4, classes=4, base_channels=32, benchmark=False):
        super().__init__()
        self.seg_branch = SegformerBranch(in_channels=in_channels, classes=classes)
        self.pixel_branch = PixelWiseNet(
            in_channels=in_channels,
            out_channels=classes,
            base_channels=base_channels,
        )
        self.fusion_conv = nn.Conv2d(classes, classes, kernel_size=1, bias=False)
        self.benchmark = benchmark

    def forward(self, x):
        """Run both branches on ``x``, add their outputs and fuse them."""
        seg_out = self.seg_branch(x)
        pixel_out = self.pixel_branch(x)
        out = self.fusion_conv(seg_out + pixel_out)
        if self.benchmark:
            out = torch.sigmoid(out)
        return out


# MLSTAC API -----------------------------------------------------------------------


def _load_cloud_model(path, device, benchmark):
    """Build a single-class ``CombinedNet``, load its weights and place it on ``device``."""
    weights = safetensors.torch.load_file(path / "model.safetensor")
    model = CombinedNet(classes=1, benchmark=benchmark)
    model.load_state_dict(weights)
    return model.eval().to(device)


def example_data(path: pathlib.Path, device: str = "cpu", *args, **kwargs) -> torch.Tensor:
    """Load the bundled example image as a float tensor with a leading batch axis.

    Args:
        path: Directory containing ``example_data.safetensor``.
        device: Device the returned tensor is moved to.

    Returns:
        The ``"image"`` entry of the file, cast to float, with a new
        dimension inserted at position 0, on ``device``.
    """
    data_f = path / "example_data.safetensor"
    sample = safetensors.torch.load_file(data_f)
    return sample["image"].float().unsqueeze(0).to(device)


def trainable_model(path: pathlib.Path, device: str = "cpu", *args, **kwargs) -> CombinedNet:
    """Load the model with gradients enabled and without the output sigmoid.

    Args:
        path: Directory containing ``model.safetensor``.
        device: Device the model is moved to. Previously this argument was
            accepted but ignored, so the model always stayed on the CPU.

    Returns:
        A ``CombinedNet(classes=1)`` with the stored weights. The model is
        returned in eval mode; call ``.train()`` before fine-tuning.
    """
    return _load_cloud_model(path, device, benchmark=False)


def compiled_model(path: pathlib.Path, device: str = "cpu", *args, **kwargs) -> CombinedNet:
    """Load the inference model: sigmoid output, eval mode, frozen parameters.

    Args:
        path: Directory containing ``model.safetensor``.
        device: Device the model is moved to.

    Returns:
        A ``CombinedNet(classes=1, benchmark=True)`` on ``device`` whose
        parameters all have ``requires_grad`` set to ``False``.
    """
    cloud_model = _load_cloud_model(path, device, benchmark=True)
    # Disable gradients for every parameter in one call.
    cloud_model.requires_grad_(False)
    return cloud_model


def display_results(path: pathlib.Path, device: str = "cpu", *args, **kwargs) -> plt.Figure:
    """Run the inference model on the example image and plot input vs. output.

    Args:
        path: Directory containing ``model.safetensor`` and
            ``example_data.safetensor``.
        device: Device used for both the model and the example tensor.

    Returns:
        A matplotlib figure with two panels, titled "Input" and "Output".
    """
    model = compiled_model(path, device)
    # The input must live on the same device as the model; loading it on the
    # default CPU device made the forward pass fail for any other device.
    probav = example_data(path, device)

    with torch.no_grad():
        cloudprobs = model(probav).squeeze().cpu()

    fig, ax = plt.subplots(1, 2, figsize=(8, 4))
    # Channels 2, 1, 0 are displayed as the three image bands
    # (presumably R, G, B -- TODO confirm against the dataset's band order).
    composite = probav[0, [2, 1, 0]].cpu().detach().numpy().transpose(1, 2, 0)
    ax[0].imshow(composite)
    ax[0].set_title("Input")
    ax[1].imshow(cloudprobs.detach().numpy(), cmap="gray")
    ax[1].set_title("Output")
    for a in ax:
        a.axis("off")
    fig.tight_layout()
    return fig