# -*- coding: utf-8 -*-
# Dense Block as defined in:
# Huang, Gao, Zhuang Liu, Laurens Van Der Maaten, and Kilian Q. Weinberger.
# "Densely connected convolutional networks." In Proceedings of the IEEE conference
# on computer vision and pattern recognition, pp. 4700-4708. 2017.
#
# Code Snippet adapted from HoverNet implementation (https://github.com/vqdang/hover_net)
#
# @ Fabian Hörst, fabian.hoerst@uk-essen.de
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen
import torch
import torch.nn as nn
from collections import OrderedDict


class DenseBlock(nn.Module):
    """Dense Block as defined in:

    Huang, Gao, Zhuang Liu, Laurens Van Der Maaten, and Kilian Q. Weinberger.
    "Densely connected convolutional networks." In Proceedings of the IEEE conference
    on computer vision and pattern recognition, pp. 4700-4708. 2017.

    Only performs `valid` convolution.
    """

    def __init__(self, in_ch, unit_ksize, unit_ch, unit_count, split=1):
        super(DenseBlock, self).__init__()
        assert len(unit_ksize) == len(unit_ch), "Unbalanced unit info"

        self.nr_unit = unit_count
        self.in_ch = in_ch
        self.unit_ch = unit_ch

        # ! For inference only, so init values for batchnorm may not match TensorFlow
        unit_in_ch = in_ch
        self.units = nn.ModuleList()
        for idx in range(unit_count):
            self.units.append(
                nn.Sequential(
                    OrderedDict(
                        [
                            ("preact_bna/bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)),
                            ("preact_bna/relu", nn.ReLU(inplace=True)),
                            (
                                "conv1",
                                nn.Conv2d(
                                    unit_in_ch,
                                    unit_ch[0],
                                    unit_ksize[0],
                                    stride=1,
                                    padding=0,
                                    bias=False,
                                ),
                            ),
                            ("conv1/bn", nn.BatchNorm2d(unit_ch[0], eps=1e-5)),
                            ("conv1/relu", nn.ReLU(inplace=True)),
                            # ('conv2/pool', TFSamepaddingLayer(ksize=unit_ksize[1], stride=1)),
                            (
                                "conv2",
                                nn.Conv2d(
                                    unit_ch[0],
                                    unit_ch[1],
                                    unit_ksize[1],
                                    groups=split,
                                    stride=1,
                                    padding=0,
                                    bias=False,
                                ),
                            ),
                        ]
                    )
                )
            )
            unit_in_ch += unit_ch[1]

        self.blk_bna = nn.Sequential(
            OrderedDict(
                [
                    ("bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)),
                    ("relu", nn.ReLU(inplace=True)),
                ]
            )
        )

    def out_ch(self):
        return self.in_ch + self.nr_unit * self.unit_ch[-1]

    def init_weights(self):
        """Kaiming (He) initialization for convolutional layers; constant
        initialization for normalization and linear layers."""
        for m in self.modules():
            classname = m.__class__.__name__
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            if "norm" in classname.lower():
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            if "linear" in classname.lower():
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, prev_feat):
        for idx in range(self.nr_unit):
            new_feat = self.units[idx](prev_feat)
            # valid convolutions shrink new_feat, so center-crop prev_feat to match
            prev_feat = crop_to_shape(prev_feat, new_feat)
            # dense connectivity: concatenate along the channel dimension
            prev_feat = torch.cat([prev_feat, new_feat], dim=1)
        prev_feat = self.blk_bna(prev_feat)
        return prev_feat
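

# Usage sketch (illustrative addition, not part of the original HoverNet code):
# a dense block with two (1x1 -> 3x3) units on a 64-channel input. The `valid`
# 3x3 convolutions shrink each spatial dim by 2 per unit, and out_ch() reports
# the concatenated channel count: 64 + 2 * 32 = 128.
def _demo_dense_block():
    block = DenseBlock(in_ch=64, unit_ksize=[1, 3], unit_ch=[32, 32], unit_count=2)
    block.init_weights()
    x = torch.randn(1, 64, 32, 32)
    with torch.no_grad():
        y = block(x)
    # two units, each shrinking H and W by 2: 32 -> 28
    assert y.shape == (1, block.out_ch(), 28, 28)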


# helper functions for cropping
def crop_op(x, cropping, data_format="NCHW"):
    """Center crop image.

    Args:
        x: input image
        cropping: the subtracted amount
        data_format: choose either `NCHW` or `NHWC`
    """
    crop_t = cropping[0] // 2
    crop_b = cropping[0] - crop_t
    crop_l = cropping[1] // 2
    crop_r = cropping[1] - crop_l
    # slice against explicit end indices rather than negative offsets, so that a
    # zero crop does not produce an empty tensor (x[..., 0:-0] would be empty)
    if data_format == "NCHW":
        x = x[:, :, crop_t : x.shape[2] - crop_b, crop_l : x.shape[3] - crop_r]
    else:
        x = x[:, crop_t : x.shape[1] - crop_b, crop_l : x.shape[2] - crop_r, :]
    return x
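

# Usage sketch (illustrative addition): center-cropping a 1x1x10x10 tensor by
# (4, 4) removes 2 pixels from each border and yields a 1x1x6x6 tensor.
def _demo_crop_op():
    x = torch.zeros(1, 1, 10, 10)
    assert crop_op(x, (4, 4)).shape == (1, 1, 6, 6)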


def crop_to_shape(x, y, data_format="NCHW"):
    """Center crop x so that x has the spatial shape of y. The spatial dims of y
    must not be larger than those of x.

    Args:
        x: input array
        y: array with the desired spatial shape.
    """
    x_shape = x.size()
    y_shape = y.size()
    if data_format == "NCHW":
        crop_shape = (x_shape[2] - y_shape[2], x_shape[3] - y_shape[3])
    else:
        crop_shape = (x_shape[1] - y_shape[1], x_shape[2] - y_shape[2])
    assert (
        crop_shape[0] >= 0 and crop_shape[1] >= 0
    ), "Ensure that the y spatial dimensions are not larger than the x spatial dimensions!"
    return crop_op(x, crop_shape, data_format)
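

# Usage sketch (illustrative addition): crop a 32x32 feature map to match a
# 28x28 reference, as done between units in DenseBlock.forward.
def _demo_crop_to_shape():
    x = torch.zeros(1, 64, 32, 32)
    y = torch.zeros(1, 32, 28, 28)
    assert crop_to_shape(x, y).shape == (1, 64, 28, 28)


if __name__ == "__main__":
    # smoke-test the illustrative sketches above
    _demo_dense_block()
    _demo_crop_op()
    _demo_crop_to_shape()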