Spaces:
Running
Running
| # -*- coding: utf-8 -*- | |
| import torchvision.models | |
| from torch import nn | |
| import torch | |
| import torch.nn.functional as F | |
| from AV.models.layers import * | |
| from torchvision.models.convnext import convnext_tiny, ConvNeXt_Tiny_Weights | |
| import numpy as np | |
| import math | |
| from torchvision import models | |
| import copy | |
class PGNet(nn.Module):
    """ConvNeXt-Tiny encoder + DBlock decoder segmentation network.

    The encoder is always torchvision's ``convnext_tiny().features[:6]``
    (stem plus the stages tapped below). The ``resnet`` argument only gates
    the skip-connection decoder in :meth:`forward`; it does not change the
    backbone. Optional extras:

    * a "centerness" head: sigmoid maps predicted from the stage-1 features
      after a 2x bilinear upsample;
    * global-semantic fusion: the encoder output is fused (``PGFusion``) with
      a frozen copy of the encoder applied to a second input ``y``.

    Args:
        input_ch: Not used by this class (kept for interface compatibility).
        resnet: Decoder selector; the decoder runs only for
            ``'convnext_tiny'`` or ``'swin_t'``.
        num_classes: Output channels of the segmentation head.
        use_cuda: Not used by this class.
        pretrained: Initialise the encoder from
            ``ConvNeXt_Tiny_Weights.IMAGENET1K_V1``.
        centerness: Build and evaluate the centerness head.
        centerness_map_size: When ``centerness`` is True, only a first entry
            of 128 is supported.
        use_global_semantic: Enable fusion with the frozen encoder copy.

    Raises:
        ValueError: If ``centerness`` is True and
            ``centerness_map_size[0] != 128``. Previously no centerness layers
            were built and ``forward`` failed with AttributeError.
    """

    # Channels of the deepest kept encoder stage (stage 3, per the tap
    # comments in __init__); also the input width of ``self.up2``.
    _ENCODER_CHANNELS = 384

    def __init__(self, input_ch=3, resnet='convnext_tiny', num_classes=3, use_cuda=False, pretrained=True,
                 centerness=False, centerness_map_size=(128, 128), use_global_semantic=False):
        # NOTE: the default is a tuple; the old list default was a shared mutable object.
        super(PGNet, self).__init__()
        if centerness and centerness_map_size[0] != 128:
            raise ValueError(
                f"centerness head only supports centerness_map_size[0] == 128, got {centerness_map_size[0]}"
            )
        self.resnet = resnet
        base_model = convnext_tiny
        # Not read anywhere in this class; kept for attribute compatibility.
        self.use_high_semantic = False
        cut = 6  # keep features[0:6]: stem, stage 1, down, stage 2, down, stage 3
        if pretrained:
            layers = list(base_model(weights=ConvNeXt_Tiny_Weights.IMAGENET1K_V1).features)[:cut]
        else:
            layers = list(base_model().features)[:cut]
        base_layers = nn.Sequential(*layers)

        self.use_global_semantic = use_global_semantic
        if self.use_global_semantic:
            self.pg_fusion = PGFusion()
            # The copy is taken BEFORE the SaveFeatures taps below are attached,
            # so (assuming SaveFeatures uses forward hooks) running the frozen
            # copy does not overwrite the skip features of the trainable encoder.
            self.base_layers_global_momentum = copy.deepcopy(base_layers)
            set_requires_grad(self.base_layers_global_momentum, requires_grad=False)

        # Skip-feature taps, shallow -> deep. SaveFeatures is a project helper
        # that exposes the tapped module's output as ``.features``.
        self.stage = [
            SaveFeatures(base_layers[0][1]),  # stem, c=96
            SaveFeatures(base_layers[1][2]),  # stage 1, c=96
            SaveFeatures(base_layers[3][2]),  # stage 2, c=192
            SaveFeatures(base_layers[5][8]),  # stage 3, c=384
        ]

        # Decoder. Registration order is kept identical to the original so that
        # parameter order (optimizer state) and init RNG order are unchanged.
        self.up2 = DBlock(self._ENCODER_CHANNELS, 192)
        self.up3 = DBlock(192, 96)
        self.up4 = DBlock(96, 96)
        # Final prediction head (artery, vein and vessel per the original comments).
        self.seg_head = SegmentationHead(96, num_classes, 3, upsample=4)
        self.sn_unet = base_layers
        self.num_classes = num_classes
        # Not used by forward(); kept so checkpoints containing "bn_out.*"
        # keys still load with strict=True.
        self.bn_out = nn.BatchNorm2d(3)

        self.centerness = centerness
        if self.centerness:  # map size validated above
            self.cenBlock1 = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            )
            # Bottleneck branch added residually to the upsampled features.
            self.cenBlockMid = nn.Sequential(
                nn.Conv2d(96, 48, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(48),
                nn.Conv2d(48, 96, kernel_size=1, padding=0, bias=False),
            )
            self.cenBlockFinal = nn.Sequential(
                nn.BatchNorm2d(96),
                nn.ReLU(inplace=True),
                nn.Conv2d(96, 3, kernel_size=1, padding=0, bias=True),
                nn.Sigmoid(),
            )

    @staticmethod
    def _to_channels_first(x, channels=384):
        """Return ``x`` laid out as ``(B, C, H, W)``.

        * 3-D ``(B, L, C)`` token sequences are reshaped to a square
          ``(B, C, h, h)`` map; ``L`` must be a perfect square.
        * 4-D tensors whose last dim equals ``channels`` but whose dim 1 does
          not are treated as channels-last ``(B, H, W, C)`` and permuted.
        * Anything else, notably the NCHW ConvNeXt output (square or not),
          is returned unchanged.

        The previous heuristic permuted any 4-D tensor with ``H != W``, which
        scrambled NCHW features for non-square inputs.

        Raises:
            ValueError: For a 3-D input whose length is not a perfect square.
        """
        if x.dim() == 3:
            batch, length, chans = x.shape
            side = math.isqrt(length)
            if side * side != length:
                raise ValueError(f"cannot reshape {length} tokens into a square feature map")
            return x.view(batch, side, side, chans).permute(0, 3, 1, 2).contiguous()
        if x.dim() == 4 and x.shape[1] != channels and x.shape[-1] == channels:
            return x.permute(0, 3, 1, 2).contiguous()
        return x

    def forward(self, x, y=None):
        """Run encoder, optional fusion, decoder and optional centerness head.

        Args:
            x: Image batch for the encoder (presumably ``(B, 3, H, W)``; confirm).
            y: Input for the frozen global encoder. Required when
                ``use_global_semantic`` is True; ignored otherwise.

        Returns:
            ``(output, centerness_maps)``. ``output`` is the segmentation head
            result. ``centerness_maps`` is None unless ``centerness`` is set;
            otherwise it is a 3-channel sigmoid map at twice the spatial size
            of the stage-1 features.

        Raises:
            ValueError: If ``use_global_semantic`` is True and ``y`` is None.
        """
        x = self.sn_unet(x)  # also fills the SaveFeatures taps
        if self.use_global_semantic:
            if y is None:
                raise ValueError("use_global_semantic=True requires the global input `y`")
            global_rep = self.base_layers_global_momentum(y)
            x = self.pg_fusion(x, global_rep)
        x = self._to_channels_first(x, self._ENCODER_CHANNELS)

        if self.resnet in ('swin_t', 'convnext_tiny'):
            # Skips are consumed deep -> shallow; the deepest tap (stage 3)
            # is not used as a skip.
            stem, stage1, stage2, _ = self.stage
            x = self.up2(x, stage2.features)
            x = self.up3(x, stage1.features)
            x = self.up4(x, stem.features)
        output = self.seg_head(x)

        centerness_maps = None
        if self.centerness:
            upsampled = self.cenBlock1(self.stage[1].features)  # stage-1 features, c=96
            upsampled = upsampled + self.cenBlockMid(upsampled)  # residual refinement
            centerness_maps = self.cenBlockFinal(upsampled)
        return output, centerness_maps

    def forward_patch_rep(self, x):
        """Return the raw output of the trainable encoder for ``x``."""
        patch_rep = self.sn_unet(x)
        return patch_rep

    def forward_global_rep_momentum(self, x):
        """Return the frozen encoder copy's output for ``x``.

        Only available when the model was built with ``use_global_semantic=True``.
        """
        global_rep = self.base_layers_global_momentum(x)
        return global_rep

    def close(self):
        """Call ``remove()`` on every SaveFeatures tap (presumably detaching its hook)."""
        for sf in self.stage:
            sf.remove()
# Set requires_grad on networks, e.g. requires_grad=False to freeze them and
# skip gradient computation for their parameters.
def set_requires_grad(nets, requires_grad=False):
    """Assign ``requires_grad`` to every parameter of ``nets``.

    Args:
        nets: A single module, or a list/tuple of modules. ``None`` entries
            (or a bare ``None``) are skipped.
        requires_grad: Value written to each parameter's ``requires_grad``.
    """
    # Accept tuples as well as lists; previously a tuple was treated as a
    # single module and crashed on ``.parameters()``.
    if not isinstance(nets, (list, tuple)):
        nets = [nets]
    for net in nets:
        if net is None:
            continue
        for param in net.parameters():
            param.requires_grad = requires_grad
# Per-channel RGB normalization statistics (the standard torchvision ImageNet
# values), shaped (1, 3, 1, 1) so they broadcast over NCHW image batches.
# Unused within this module; presumably imported by callers (verify).
pretrained_mean = torch.tensor([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1)
pretrained_std = torch.tensor([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1)
if __name__ == '__main__':
    # Smoke test: build an untrained model with the centerness head, run a
    # random 2 x 3 x 256 x 256 batch through it and print both output shapes.
    net = PGNet(input_ch=3, resnet='convnext_tiny', centerness=True, pretrained=False, use_global_semantic=False)
    dummy = torch.randn(2, 3, 256, 256)
    seg_map, center_map = net(dummy)
    for result in (seg_map, center_map):
        print(result.shape)