| import torch, os |
| from torch import nn, Tensor |
| import numpy as np |
| from .unet import PlainConvUNet |
| from .segairway import SegAirwayModel |
| import torch.nn.functional as F |
|
|
| """ |
| from : https://github.com/Project-MONAI/GenerativeModels/blob/main/generative/losses/perceptual.py |
| return only bottleneck layer ?? |
| """ |
class MedicalNetPerceptualSimilarity(nn.Module):
    """
    Perceptual distance based on the 3D MedicalNet backbones pretrained by Chen et al.,
    "Med3D: Transfer Learning for 3D Medical Image Analysis". Weights are fetched via
    torch Hub from "Warvito/MedicalNet-models" and kept frozen on CUDA in float16.

    Args:
        net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``}
            Name of the pretrained backbone. Defaults to ``"medicalnet_resnet10_23datasets"``.
        verbose: if False, silence torch Hub's load messages.
    """

    def __init__(self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False) -> None:
        super().__init__()
        # torch Hub rejects forked repos by default; bypass the check for this mirror.
        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
        self.model = torch.hub.load("Warvito/MedicalNet-models", model=net, verbose=verbose)
        self.model = self.model.to(device='cuda', dtype=torch.float16)
        self.eval()

        # Pure feature extractor: never trained, so freeze every parameter.
        for param in self.parameters():
            param.requires_grad = False

        self.criterion = nn.L1Loss()

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute the perceptual distance between `input` and `target`.

        Both tensors are intensity-normalised, passed through the frozen MedicalNet
        to extract features, and the features are unit-normalised across channels.
        The squared feature difference is summed over channels, averaged over the
        spatial dimensions, and finally averaged over the batch.

        Args:
            input: 3D tensor with shape BCDHW; only channel 0 is compared.
            target: 3D tensor with shape BCDHW.
        """
        # Zero-mean / unit-std normalisation; `input` is sliced to its first channel.
        normed_in = medicalnet_intensity_normalisation(input[:, 0:1])
        normed_tgt = medicalnet_intensity_normalisation(target)

        # Deep features from the frozen backbone, then unit-normalised per channel.
        raw_in = self.model.forward(normed_in)
        raw_tgt = self.model.forward(normed_tgt)
        unit_in = normalize_tensor(raw_in)
        unit_tgt = normalize_tensor(raw_tgt)

        # Squared difference, channel sum, spatial mean, batch mean.
        diff_sq = (unit_in - unit_tgt) ** 2
        per_sample = spatial_average_3d(diff_sq.sum(dim=1, keepdim=True), keepdim=True)
        return per_sample.mean()
|
|
|
|
def spatial_average_3d(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor:
    """Average *x* over its three trailing spatial dimensions (D, H, W of a BCDHW tensor)."""
    spatial_dims = [2, 3, 4]
    return x.mean(spatial_dims, keepdim=keepdim)
|
|
def normalize_tensor(x: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
    """Scale *x* to (nearly) unit L2 norm along the channel dimension (dim 1).

    `eps` keeps the division finite when a channel vector is all zeros.
    """
    channel_norm = (x ** 2).sum(dim=1, keepdim=True).sqrt()
    return x / (channel_norm + eps)
|
|
def medicalnet_intensity_normalisation(volume):
    """Normalise *volume* to zero mean and unit standard deviation (whole-volume statistics).

    Based on https://github.com/Tencent/MedicalNet/blob/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b/datasets/brains18.py#L133
    """
    return (volume - volume.mean()) / volume.std()
|
|
|
|
class MedicalNetL1(nn.Module):
    """L1 distance between channel-normalised MedicalNet features of input and target.

    Same frozen MedicalNet backbone as `MedicalNetPerceptualSimilarity`, but the
    feature difference is reduced with `nn.L1Loss` instead of a squared distance.
    """

    def __init__(self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False) -> None:
        super().__init__()
        # Allow torch Hub to load from this mirror repo without the fork check.
        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
        self.model = torch.hub.load("Warvito/MedicalNet-models", model=net, verbose=verbose)
        self.model = self.model.to(device='cuda', dtype=torch.float16)
        self.eval()

        # Frozen feature extractor: no gradients through the backbone weights.
        for param in self.parameters():
            param.requires_grad = False

        self.criterion = nn.L1Loss()

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the mean absolute difference of normalised features (BCDHW inputs; channel 0 of `input`)."""
        normed_in = medicalnet_intensity_normalisation(input[:, 0:1])
        normed_tgt = medicalnet_intensity_normalisation(target)

        out_in = self.model.forward(normed_in)
        out_tgt = self.model.forward(normed_tgt)

        return self.criterion(normalize_tensor(out_in), normalize_tensor(out_tgt))
|
|
|
|
class UNet_layers(nn.Module):
    """Perceptual loss from a frozen, pretrained segmentation UNet.

    The backbone selected by `net` is run on both the prediction `x` and the
    reference `y`; the loss is the sum of L1 distances between the feature maps
    it returns for the selected `layers`.

    Args:
        net: key into the hard-coded `model_params` registry below.
        layers: indices of the backbone's returned feature maps to compare;
            defaults to all layers of the chosen backbone.
            NOTE(review): mutable default argument -- safe here because it is
            only read, never mutated, but `layers=None` would be cleaner.
    """
    def __init__(self, net: str = "", layers=[]):
        super().__init__()
        # Registry of pretrained checkpoints plus the architecture
        # hyper-parameters needed to rebuild each network before loading.
        model_params = {
            "TotalSeg_vessels": {
                "weights_path": "/data2/alonguefosse/checkpoints/TotalSeg_vessels.pth",
                "strides": [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 1, 2]],
                "num_classes": 3,
                "model_type": "PlainConvUNet"
            },
            "TotalSeg_V2": {
                "weights_path": "/data2/alonguefosse/checkpoints/TotalSeg_V2.pth",
                "strides": [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]],
                "num_classes": 8,
                "model_type": "PlainConvUNet"
            },
            "TotalSeg_pelvis_V2": {
                "weights_path": "/data2/alonguefosse/checkpoints/TotalSeg_pelvis_V2.pth",
                "strides": [[1, 1, 1], [1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]],
                "kernels" : [[1, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]],
                "num_classes": 38,
                "model_type": "PlainConvUNet"
            },
            "Imene8": {
                "weights_path": "/data2/alonguefosse/checkpoints/nnUNet_Imene8_best.pth",
                "strides": [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [1, 2, 2]],
                "num_classes": 9,
                "model_type": "PlainConvUNet",
            },
            "NaviAirway": {
                "weights_path" : "/data2/alonguefosse/checkpoints/naviairway_semi_supervise.pkl",
                "model_type": "NaviAirway"
            },
        }
        params = model_params[net]  # raises KeyError for an unknown `net`
        # Per-stage kernel sizes unless the registry entry overrides them.
        kernel = params.get("kernels", [[3, 3, 3]] * 6)
        if params["model_type"] == "PlainConvUNet":
            if layers!=[]:
                self.layers = layers
            else:
                self.layers = [0,1,2,3,4,5,6,7,8]
            self.stages = 5
            model = PlainConvUNet(input_channels=1, n_stages=6, features_per_stage=[32, 64, 128, 256, 320, 320],
                                  conv_op=nn.Conv3d, kernel_sizes=kernel, strides=params["strides"],
                                  num_classes=params["num_classes"], deep_supervision=False, n_conv_per_stage=[2] * 6,
                                  n_conv_per_stage_decoder=[2] * 5, conv_bias=True, norm_op=nn.InstanceNorm3d,
                                  norm_op_kwargs={'eps': 1e-5, 'affine': True}, nonlin=nn.LeakyReLU,
                                  nonlin_kwargs={'inplace': True})
        elif params["model_type"] == "NaviAirway":
            if layers!=[]:
                self.layers = layers
            else:
                self.layers = [0,1,2,3,4,5,6]
            self.stages = 4
            model = SegAirwayModel(in_channels=1, out_channels=2)
        # NOTE(review): any other `model_type` leaves `model` / `self.layers`
        # unbound and fails below with a confusing NameError.

        if not os.path.exists(params["weights_path"]):
            raise FileNotFoundError(f'Error: Checkpoint not found at {params["weights_path"]}')
        checkpoint = torch.load(params["weights_path"], map_location='cuda')
        # Checkpoints from different trainers store weights under different keys.
        model_state_dict = checkpoint.get('state_dict', checkpoint.get('network_weights', checkpoint.get('model_state_dict')))
        model.load_state_dict(model_state_dict, strict=False)
        print(f"loaded model : {params['weights_path']}")
        model.eval()

        # Frozen feature extractor: no gradients through the backbone.
        for param in model.parameters():
            param.requires_grad = False
        self.model = model
        self.model = self.model.to(device='cuda', dtype=torch.float16)

        self.L1 = nn.L1Loss()
        self.net = net
        self.print_perceptual_layers = False
        self.debug = False
        print("layers : ", self.layers)

    def forward(self, x, y):
        """Return the summed per-layer L1 feature distance between `x` and `y`.

        Only channel 0 of `x` is fed to the backbone; gradients flow through
        `x` only, since the `y` features are detached.

        todo : check if normalization of input tensors is needed
        """
        # For PlainConvUNet backbones, pad dim -2 (H of BCDHW) by 8 on both
        # sides -- presumably to make H divisible by the downsampling factor;
        # TODO confirm against the expected patch size.
        if self.stages==5:
            padding = (0, 0, 8, 8, 0, 0)
            x = F.pad(x, padding, mode='constant', value=0)
            y = F.pad(y, padding, mode='constant', value=0)

        emb_x = self.model(x[:,0:1])
        emb_y = self.model(y)

        if self.debug:
            # Dump inputs and every feature map to ./embs, then abort the run.
            torch.save(x, "embs/x_ep50")
            torch.save(y, "embs/y_ep50")
            for i in range(len(emb_y)):
                torch.save(emb_y[i], f"embs/emb_y_{i}_ep50")
                torch.save(emb_x[i], f"embs/emb_x_{i}_ep50")
            assert(0)

        sum_loss = 0
        layer_losses = []
        for i in self.layers:
            layer_loss = self.L1(emb_x[i], emb_y[i].detach())
            sum_loss += layer_loss
            layer_losses.append((i, layer_loss.item()))

            if self.print_perceptual_layers:
                print(f"task loss", i, " |", emb_x[i].shape)
                print(layer_loss)

        # Append the per-layer losses to a log file on every forward call.
        with open(f'losses_{self.net}.txt', 'a') as file:
            for i, loss in layer_losses:
                file.write(f"Layer {i}: Loss = {loss}\n")
            file.write(f"-------------------\n")

        return sum_loss
| |
|
|
|
|
|
|
class L1_UNet_layers(nn.Module):
    """Combined loss: UNet perceptual feature distance plus a weighted voxel-wise MAE.

    Same frozen-backbone perceptual term as `UNet_layers`, with an additional
    `mae_weight * L1(x, y)` image-space term added to the returned loss.

    Args:
        net: key into the hard-coded `model_params` registry below.
        layers: feature-map indices to compare; defaults to all layers.
            NOTE(review): mutable default argument -- only read, never mutated.
        mae_weight: multiplier applied to the image-space MAE term.
    """
    def __init__(self, net: str = "", layers=[], mae_weight=10):
        super().__init__()
        # Registry of pretrained checkpoints plus the architecture
        # hyper-parameters needed to rebuild each network before loading.
        model_params = {
            "TotalSeg_vessels": {
                "weights_path": "/data2/alonguefosse/checkpoints/TotalSeg_vessels.pth",
                "strides": [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 1, 2]],
                "num_classes": 3,
                "model_type": "PlainConvUNet"
            },
            "TotalSeg_V2": {
                "weights_path": "/data2/alonguefosse/checkpoints/TotalSeg_V2.pth",
                "strides": [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]],
                "num_classes": 8,
                "model_type": "PlainConvUNet"
            },
            "TotalSeg_pelvis_V2": {
                "weights_path": "/data2/alonguefosse/checkpoints/TotalSeg_pelvis_V2.pth",
                "strides": [[1, 1, 1], [1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]],
                "kernels" : [[1, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]],
                "num_classes": 38,
                "model_type": "PlainConvUNet"
            },
            "Imene8": {
                "weights_path": "/data2/alonguefosse/checkpoints/nnUNet_Imene8_best.pth",
                "strides": [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [1, 2, 2]],
                "num_classes": 9,
                "model_type": "PlainConvUNet",
            },
            "NaviAirway": {
                "weights_path" : "/data2/alonguefosse/checkpoints/naviairway_semi_supervise.pkl",
                "model_type": "NaviAirway"
            },
        }
        params = model_params[net]  # raises KeyError for an unknown `net`
        # Per-stage kernel sizes unless the registry entry overrides them.
        kernel = params.get("kernels", [[3, 3, 3]] * 6)
        if params["model_type"] == "PlainConvUNet":
            if layers!=[]:
                self.layers = layers
            else:
                self.layers = [0,1,2,3,4,5,6,7,8]
            self.stages = 5
            model = PlainConvUNet(input_channels=1, n_stages=6, features_per_stage=[32, 64, 128, 256, 320, 320],
                                  conv_op=nn.Conv3d, kernel_sizes=kernel, strides=params["strides"],
                                  num_classes=params["num_classes"], deep_supervision=False, n_conv_per_stage=[2] * 6,
                                  n_conv_per_stage_decoder=[2] * 5, conv_bias=True, norm_op=nn.InstanceNorm3d,
                                  norm_op_kwargs={'eps': 1e-5, 'affine': True}, nonlin=nn.LeakyReLU,
                                  nonlin_kwargs={'inplace': True})
        elif params["model_type"] == "NaviAirway":
            if layers!=[]:
                self.layers = layers
            else:
                self.layers = [0,1,2,3,4,5,6]
            self.stages = 4
            model = SegAirwayModel(in_channels=1, out_channels=2)
        # NOTE(review): any other `model_type` leaves `model` / `self.layers`
        # unbound and fails below with a confusing NameError.

        if not os.path.exists(params["weights_path"]):
            raise FileNotFoundError(f'Error: Checkpoint not found at {params["weights_path"]}')
        checkpoint = torch.load(params["weights_path"], map_location='cuda')
        # Checkpoints from different trainers store weights under different keys.
        model_state_dict = checkpoint.get('state_dict', checkpoint.get('network_weights', checkpoint.get('model_state_dict')))
        model.load_state_dict(model_state_dict, strict=False)
        print(f"loaded model : {params['weights_path']}")
        model.eval()

        # Frozen feature extractor: no gradients through the backbone.
        for param in model.parameters():
            param.requires_grad = False
        self.model = model
        self.model = self.model.to(device='cuda', dtype=torch.float16)

        self.L1 = nn.L1Loss()
        self.net = net
        self.print_perceptual_layers = False
        self.debug = False
        self.mae_weight=mae_weight
        print("layers : ", self.layers)
        print("mae_weight :", mae_weight)

    def forward(self, x, y):
        """Return perceptual feature loss plus `mae_weight`-scaled voxel MAE.

        todo : check if normalization of input tensors is needed
        """
        # For PlainConvUNet backbones, pad dim -2 (H of BCDHW) by 8 on both
        # sides; note the MAE below is then computed on the PADDED tensors.
        if self.stages==5:
            padding = (0, 0, 8, 8, 0, 0)
            x = F.pad(x, padding, mode='constant', value=0)
            y = F.pad(y, padding, mode='constant', value=0)

        emb_x = self.model(x[:,0:1])
        emb_y = self.model(y)

        if self.debug:
            # Dump inputs and every feature map to ./embs, then abort the run.
            torch.save(x, "embs/x_ep50")
            torch.save(y, "embs/y_ep50")
            for i in range(len(emb_y)):
                torch.save(emb_y[i], f"embs/emb_y_{i}_ep50")
                torch.save(emb_x[i], f"embs/emb_x_{i}_ep50")
            assert(0)

        sum_loss = 0
        layer_losses = []
        for i in self.layers:
            layer_loss = self.L1(emb_x[i], emb_y[i].detach())
            sum_loss += layer_loss
            layer_losses.append((i, layer_loss.item()))

            if self.print_perceptual_layers:
                print(f"task loss", i, " |", emb_x[i].shape)
                print(layer_loss)

        # Append the per-layer losses to a log file on every forward call.
        with open(f'losses_{self.net}.txt', 'a') as file:
            for i, loss in layer_losses:
                file.write(f"Layer {i}: Loss = {loss}\n")
            file.write(f"-------------------\n")

        # NOTE(review): the MAE is computed on CPU and moved back to CUDA,
        # which forces a device sync each step -- presumably to avoid fp16
        # accumulation error; confirm before simplifying to a GPU-side L1.
        mae_loss = self.L1(x[:, 0:1].cpu(), y.cpu()) * self.mae_weight
        mae_loss = mae_loss.cuda()

        # Log the two loss components side by side.
        with open(f'losses_airway_mae.txt', 'a') as file2:
            file2.write(f"airway : {sum_loss:.3f}")
            file2.write(f" | mae : {mae_loss:.3f} \n")

        return sum_loss + mae_loss
| |
|
|
|
|
|
|
|
|
| class UNet_layers2(nn.Module): |
| def __init__(self, net1 = "", net2 = "", w1 = 1.0, w2 = 1.0): |
| super().__init__() |
| model_params = { |
| "TotalSeg_vessels": { |
| "weights_path": "/data2/alonguefosse/checkpoints/TotalSeg_vessels.pth", |
| "strides": [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 1, 2]], |
| "num_classes": 3, |
| "model_type": "PlainConvUNet" |
| }, |
| "TotalSeg_V2": { |
| "weights_path": "/data2/alonguefosse/checkpoints/TotalSeg_V2.pth", |
| "strides": [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]], |
| "num_classes": 8, |
| "model_type": "PlainConvUNet" |
| }, |
| "TotalSeg_pelvis_V2": { |
| "weights_path": "/data2/alonguefosse/checkpoints/TotalSeg_pelvis_V2.pth", |
| "strides": [[1, 1, 1], [1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]], |
| "kernels" : [[1, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]], |
| "num_classes": 38, |
| "model_type": "PlainConvUNet" |
| }, |
| "Imene8": { |
| "weights_path": "/data2/alonguefosse/checkpoints/nnUNet_Imene8_best.pth", |
| "strides": [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [1, 2, 2]], |
| "num_classes": 9, |
| "model_type": "PlainConvUNet", |
| }, |
| "NaviAirway": { |
| "weights_path" : "/data2/alonguefosse/checkpoints/naviairway_semi_supervise.pkl", |
| "model_type": "NaviAirway" |
| }, |
| } |
| params = model_params[net1] |
| kernel = params.get("kernels", [[3, 3, 3]] * 6) |
| if params["model_type"] == "PlainConvUNet": |
| self.layers1 = [0,1,2,3,4,5,6,7,8] |
| self.stages1 = 5 |
|
|
| model1 = PlainConvUNet(input_channels=1, n_stages=6, features_per_stage=[32, 64, 128, 256, 320, 320], |
| conv_op=nn.Conv3d, kernel_sizes=kernel, strides=params["strides"], |
| num_classes=params["num_classes"], deep_supervision=False, n_conv_per_stage=[2] * 6, |
| n_conv_per_stage_decoder=[2] * 5, conv_bias=True, norm_op=nn.InstanceNorm3d, |
| norm_op_kwargs={'eps': 1e-5, 'affine': True}, nonlin=nn.LeakyReLU, |
| nonlin_kwargs={'inplace': True}) |
| elif params["model_type"] == "NaviAirway": |
| self.layers1 = [0,1,2,3,4,5,6] |
| self.stages1 = 4 |
|
|
| model1 = SegAirwayModel(in_channels=1, out_channels=2) |
|
|
| if not os.path.exists(params["weights_path"]): |
| raise FileNotFoundError(f'Error: Checkpoint not found at {params["weights_path"]}') |
| checkpoint = torch.load(params["weights_path"], map_location='cuda') |
| model_state_dict = checkpoint.get('state_dict', checkpoint.get('network_weights', checkpoint.get('model_state_dict'))) |
| model1.load_state_dict(model_state_dict, strict=False) |
| print(f"loaded model1 : {params['weights_path']}") |
| model1.eval() |
| for param in model1.parameters(): |
| param.requires_grad = False |
| self.model1 = model1 |
| self.model1 = self.model1.to(device='cuda', dtype=torch.float16) |
|
|
| if net2!="": |
| params = model_params[net2] |
| kernel = params.get("kernels", [[3, 3, 3]] * 6) |
| if params["model_type"] == "PlainConvUNet": |
| self.layers2 = [0,1,2,3,4,5,6,7,8] |
| self.stages2 = 5 |
|
|
| model2 = PlainConvUNet(input_channels=1, n_stages=6, features_per_stage=[32, 64, 128, 256, 320, 320], |
| conv_op=nn.Conv3d, kernel_sizes=kernel, strides=params["strides"], |
| num_classes=params["num_classes"], deep_supervision=False, n_conv_per_stage=[2] * 6, |
| n_conv_per_stage_decoder=[2] * 5, conv_bias=True, norm_op=nn.InstanceNorm3d, |
| norm_op_kwargs={'eps': 1e-5, 'affine': True}, nonlin=nn.LeakyReLU, |
| nonlin_kwargs={'inplace': True}) |
| elif params["model_type"] == "NaviAirway": |
| self.layers2 = [0,1,2,3,4,5,6] |
| self.stages2 = 4 |
| model2 = SegAirwayModel(in_channels=1, out_channels=2) |
|
|
| if not os.path.exists(params["weights_path"]): |
| raise FileNotFoundError(f'Error: Checkpoint not found at {params["weights_path"]}') |
| checkpoint = torch.load(params["weights_path"], map_location='cuda') |
| model_state_dict = checkpoint.get('state_dict', checkpoint.get('network_weights', checkpoint.get('model_state_dict'))) |
| model2.load_state_dict(model_state_dict, strict=False) |
| print(f"loaded model2 : {params['weights_path']}") |
| model2.eval() |
| for param in model2.parameters(): |
| param.requires_grad = False |
| self.model2 = model2 |
| self.model2 = self.model2.to(device='cuda', dtype=torch.float16) |
| else: |
| self.model2 = None |
|
|
| self.L1 = nn.L1Loss() |
| self.net1 = net1 |
| self.net2 = net2 |
| self.w1 = w1 |
| self.w2 = w2 |
| self.print_perceptual_layers = False |
| self.debug = False |
|
|
| def forward(self, x, y): |
| """ |
| todo : check if normalization is needed |
| since 100% of models are trained on CT = no normalization needed ? |
| """ |
| emb_x1 = self.model1(x[:,0:1]) |
| emb_y1 = self.model1(y) |
|
|
| if self.stages2==5: |
| padding = (0, 0, 8, 8, 0, 0) |
| x = F.pad(x, padding, mode='constant', value=0) |
| y = F.pad(y, padding, mode='constant', value=0) |
|
|
| emb_x2 = self.model2(x[:,0:1]) |
| emb_y2 = self.model2(y) |
|
|
| sum_loss1 = 0 |
| sum_loss2 = 0 |
| total_loss = 0 |
| layer_losses1 = [] |
| layer_losses2 = [] |
| for i in self.layers1: |
| layer_loss1 = self.L1(emb_x1[i], emb_y1[i].detach()) |
| sum_loss1 += layer_loss1 |
| layer_losses1.append((i, layer_loss1.item())) |
|
|
| for i in self.layers2: |
| layer_loss2 = self.L1(emb_x2[i], emb_y2[i].detach()) |
| sum_loss2 += layer_loss2 |
| layer_losses2.append((i, layer_loss2.item())) |
|
|
| with open(f'losses1_{self.net1}.txt', 'a') as file: |
| for i, loss1 in layer_losses1: |
| file.write(f"Layer {i}: {self.net1} = {loss1} \n") |
| file.write(f"-------------------\n") |
|
|
| with open(f'losses2_{self.net2}.txt', 'a') as file: |
| for i, loss2 in layer_losses2: |
| file.write(f"Layer {i}: {self.net2} = {loss2} \n") |
| file.write(f"-------------------\n") |
|
|
| total_loss = sum_loss1 * self.w1 + sum_loss2 * self.w2 |
|
|
| print(self.net1, sum_loss1*self.w1) |
| print(self.net2, sum_loss2*self.w2) |
| print("----") |
| return sum_loss1 + sum_loss2 |