import torch
import torch.nn as nn
import torch.nn.functional as F
from models.layers import ChebConv, Pool, residualBlock
from torchvision.ops import roi_align
import numpy as np
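
# Hybrid CNN/GCN variational model: an SE-ResNeXt convolutional encoder maps
# the input image to a latent Gaussian (mu, log_var), and a Chebyshev
# graph-convolutional decoder produces per-node outputs, re-sampling encoder
# feature maps at the decoder's intermediate node-position predictions via
# roi_align skip connections (SkipBlock).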


class SEBlock(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)
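
# Shape sketch for SEBlock (illustrative sizes, not from the original code):
# for x of shape [B, C, H, W] the pooled descriptor is [B, C], the bottleneck
# MLP emits per-channel gates in (0, 1), and the gates are broadcast back
# over H and W:
#   se = SEBlock(channel=64)               # bottleneck: 64 -> 4 -> 64
#   y = se(torch.randn(2, 64, 32, 32))     # y.shape == (2, 64, 32, 32)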


class SEResNeXtBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, cardinality=32, base_width=4, reduction=16):
        super(SEResNeXtBlock, self).__init__()
        width = int(out_channels * (base_width / 64.)) * cardinality
        self.conv1 = nn.Conv2d(in_channels, width, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width, track_running_stats=False)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False)
        self.bn2 = nn.BatchNorm2d(width, track_running_stats=False)
        self.conv3 = nn.Conv2d(width, out_channels, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels, track_running_stats=False)
        self.se = SEBlock(out_channels, reduction)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels, track_running_stats=False)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out = self.se(out)
        # self.shortcut is an identity nn.Sequential when no projection is
        # needed, so it can be applied unconditionally.
        residual = self.shortcut(x)
        out += residual
        out = F.relu(out)
        return out
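
# Width arithmetic for the grouped convolution (worked example with this
# file's defaults): out_channels=64, base_width=4, cardinality=32 gives
# width = int(64 * 4 / 64) * 32 = 128, i.e. 32 groups of 4 channels each.
# Caveat: out_channels below 16 makes int(out_channels * 4 / 64) collapse to
# 0 and hence width = 0, so the smallest encoder filter count must be >= 16.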


class EncoderConv(nn.Module):
    def __init__(self, config):
        super(EncoderConv, self).__init__()
        self.latents = config['latents']
        self.c = config['initial_filters']
        maximum_amount_of_layers = int(np.log2(config['inputsize'])) - 2
        number_of_layers = min(maximum_amount_of_layers, 5)
        self.hw = config['inputsize'] // (2**number_of_layers)
        self.filters = [self.c * 2**i for i in range(number_of_layers)]
        self.maxpool = nn.MaxPool2d(2)
        if config['raster_as_input']:
            input_channels = len(config['organs'])
        else:
            input_channels = 1
        # Create downsampling layers dynamically
        self.dconv_down = nn.ModuleList()
        for i in range(len(self.filters)):
            in_channels = input_channels if i == 0 else self.filters[i - 1]
            out_channels = self.filters[i]
            self.dconv_down.append(SEResNeXtBlock(in_channels, out_channels))
        # Final convolutional layer
        self.dconv_final = SEResNeXtBlock(self.filters[-1], self.filters[-1])
        # Fully connected layers for mu and logvar
        final_conv_size = self.filters[-1] * self.hw * self.hw
        self.fc_mu = nn.Linear(in_features=final_conv_size, out_features=self.latents)
        self.fc_logvar = nn.Linear(in_features=final_conv_size, out_features=self.latents)

    def forward(self, x):
        conv_outputs = []
        for dconv in self.dconv_down:
            x = dconv(x)
            conv_outputs.append(x)
            x = self.maxpool(x)
        x = self.dconv_final(x)
        x = x.view(x.size(0), -1)  # flatten batch of multi-channel feature maps to a batch of feature vectors
        x_mu = self.fc_mu(x)
        x_logvar = self.fc_logvar(x)
        return x_mu, x_logvar, list(reversed(conv_outputs))
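
# Depth/size arithmetic (worked through for a hypothetical inputsize of 128):
# the encoder caps its depth at min(log2(128) - 2, 5) = 5 blocks, each
# followed by a 2x max-pool, so the final map is hw = 128 // 2**5 = 4 pixels
# per side and the vector feeding fc_mu/fc_logvar has
# filters[-1] * 4 * 4 elements.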


class SkipBlock(nn.Module):
    def __init__(self, in_filters, window):
        super(SkipBlock, self).__init__()
        self.window = window
        self.graphConv_pre = ChebConv(in_filters, 2, 1, bias=False)

    def lookup(self, pos, layer, output_size=(1, 1)):
        B = pos.shape[0]
        N = pos.shape[1]
        h = layer.shape[-1]
        # Scale node positions from [0, 1] to feature-map coordinates [0, h]
        pos = pos * h
        _x1 = (self.window[0] // 2) * 1.0
        _x2 = (self.window[0] // 2 + 1) * 1.0
        _y1 = (self.window[1] // 2) * 1.0
        _y2 = (self.window[1] // 2 + 1) * 1.0
        boxes = []
        for batch in range(0, B):
            x1 = pos[batch, :, 0].reshape(-1, 1) - _x1
            x2 = pos[batch, :, 0].reshape(-1, 1) + _x2
            y1 = pos[batch, :, 1].reshape(-1, 1) - _y1
            y2 = pos[batch, :, 1].reshape(-1, 1) + _y2
            aux = torch.cat([x1, y1, x2, y2], dim=1)
            boxes.append(aux)
        skip = roi_align(layer, boxes, output_size=output_size, aligned=True)
        vista = skip.view([B, N, -1])
        return vista

    def forward(self, x, adj, layer):
        pos = self.graphConv_pre(x, adj)
        pos = torch.clamp(pos, 0, 1)
        skip = self.lookup(pos, layer)
        return torch.cat([x, pos, skip], dim=2), pos
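
# Lookup geometry, worked through for the default window = (3, 3): each node
# position in [0, 1] is scaled to feature-map coordinates, a box spanning
# [pos - 1, pos + 2] on each axis (i.e. a 3x3 window) is cut out per node,
# and roi_align with output_size=(1, 1) average-pools it to one value per
# channel. Each node therefore receives one feature vector from the encoder
# map, concatenated onto its graph features along with the predicted 2-D
# position itself.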


class Hybrid(nn.Module):
    def __init__(self, config, downsample_matrices, upsample_matrices, adjacency_matrices):
        super(Hybrid, self).__init__()
        self.config = config
        hw = config['inputsize'] // 2**len(config['filters'])
        self.z = config['latents']
        self.encoder = EncoderConv(config)
        self.downsample_matrices = downsample_matrices
        self.upsample_matrices = upsample_matrices
        self.adjacency_matrices = adjacency_matrices
        self.n_nodes = config['n_nodes']
        self.filters = config['filters']
        self.K = 6
        self.window = (3, 3)
        # Decoder fully connected layer
        outshape = self.filters[-1] * self.n_nodes[-1]
        self.dec_lin = nn.Linear(self.z, outshape)
        # Dynamic block creation.
        # Estimate the number of feature maps added by each skip connection:
        # the channel count of the matching encoder layer plus 2 for the
        # predicted node positions.
        skip_values = [0] + [self.encoder.filters[-i] + 2 for i in range(1, len(upsample_matrices) + 2)]
        self.blocks = nn.ModuleList()
        for i in range(len(upsample_matrices) + 1):
            block = nn.ModuleList([
                ChebConv(self.filters[-(2*i+1)] + skip_values[i], self.filters[-(2*i+2)], self.K),
                nn.InstanceNorm1d(self.filters[-(2*i+2)]),
                ChebConv(self.filters[-(2*i+2)], self.filters[-(2*i+3)], self.K),
                nn.InstanceNorm1d(self.filters[-(2*i+3)]),
                SkipBlock(self.filters[-(2*i+3)], self.window)
            ])
            if i < len(upsample_matrices):  # the last block is not followed by an unpooling step
                block.append(Pool())
            self.blocks.append(block)
        # Final convolution layers
        self.final_conv1 = ChebConv(self.filters[1] + skip_values[-1], self.filters[1], self.K)
        self.final_norm1 = nn.InstanceNorm1d(self.filters[1])
        self.final_conv2 = ChebConv(self.filters[1], self.filters[0], self.K, bias=False)
        self.reset_parameters()

    def save_checkpoint(self, path, epoch, iterations, optimizer):
        torch.save({
            'epoch': epoch,
            'iterations': iterations,
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'config': self.config
        }, path)

    def load_checkpoint(self, path, device):
        checkpoint = torch.load(path, map_location=device)
        self.load_state_dict(checkpoint['model_state_dict'])
        return checkpoint

    def reset_parameters(self):
        nn.init.normal_(self.dec_lin.weight, 0, 0.5)

    def sampling(self, mu, log_var):
        std = torch.exp(0.5 * log_var).clamp(min=1e-6)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)
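
    # Reparameterization trick: rather than sampling z ~ N(mu, sigma^2)
    # directly, sampling() draws eps ~ N(0, I) and returns mu + sigma * eps,
    # which keeps the sample differentiable with respect to mu and log_var;
    # the clamp on std guards against underflow to exactly zero.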

    def forward(self, x):
        if x.dim() != 4:
            raise ValueError(f"Expected 4D input, got {x.dim()}D")
        self.mu, self.log_var, conv_outputs = self.encoder(x)
        z = self.sampling(self.mu, self.log_var) if self.training else self.mu
        x = F.relu(self.dec_lin(z))
        x = x.reshape(x.shape[0], -1, self.filters[-1])
        positions = []
        for i, block in enumerate(self.blocks):
            adj = self.adjacency_matrices[-(i + 2)]._indices()
            x = F.relu(block[1](block[0](x, adj)))
            x = F.relu(block[3](block[2](x, adj)))
            x, pos = block[4](x, adj, conv_outputs[i])
            positions.append(pos)
            if len(block) > 5:  # all but the last block end with an unpooling step
                x = block[5](x, self.upsample_matrices[-(i + 1)])
        # Final convolutions
        x = F.relu(self.final_norm1(self.final_conv1(x, self.adjacency_matrices[0]._indices())))
        x = self.final_conv2(x, self.adjacency_matrices[0]._indices())
        return x, positions[::-1]