|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import os |
|
|
import re |
|
|
import warnings |
|
|
from collections import OrderedDict |
|
|
from collections.abc import Callable, Sequence |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from monai.apps.utils import download_url |
|
|
from monai.networks.blocks import UpSample |
|
|
from monai.networks.layers.factories import Conv, Dropout |
|
|
from monai.networks.layers.utils import get_act_layer, get_norm_layer |
|
|
from monai.utils.enums import HoVerNetBranch, HoVerNetMode, InterpolateMode, UpsampleMode |
|
|
from monai.utils.module import export, look_up_option |
|
|
|
|
|
# Public names exported by ``import *``: the network class plus each of its case-variant aliases.
__all__ = ["HoVerNet", "Hovernet", "HoVernet", "HoverNet"]
|
|
|
|
|
|
|
|
class _DenseLayerDecoder(nn.Module):
    """Pre-activation dense layer used inside the HoVerNet decoder blocks.

    The layer runs ``norm -> act -> 1x1 conv -> norm -> act -> grouped conv (-> dropout)`` on its input and
    concatenates the result with the input (center-cropped when the convolutions shrank the map) along dim 1.
    """

    def __init__(
        self,
        num_features: int,
        in_channels: int,
        out_channels: int,
        dropout_prob: float = 0.0,
        act: str | tuple = ("relu", {"inplace": True}),
        norm: str | tuple = "batch",
        kernel_size: int = 3,
        padding: int = 0,
    ) -> None:
        """
        Args:
            num_features: number of internal channels used for the layer
            in_channels: number of the input channels.
            out_channels: number of the output channels.
            dropout_prob: dropout rate after each dense layer.
            act: activation type and arguments. Defaults to relu.
            norm: feature normalization type and arguments. Defaults to batch norm.
            kernel_size: size of the kernel for >1 convolutions (dependent on mode)
            padding: padding value for >1 convolutions.
        """
        super().__init__()

        conv = Conv[Conv.CONV, 2]
        drop = Dropout[Dropout.DROPOUT, 2]

        # Sub-module names become state-dict keys. Construction order fixes the execution order and
        # the order of random weight initialisation, so both are kept stable.
        named_modules = [
            ("preact_bna/bn", get_norm_layer(name=norm, spatial_dims=2, channels=in_channels)),
            ("preact_bna/relu", get_act_layer(name=act)),
            ("conv1", conv(in_channels, num_features, kernel_size=1, bias=False)),
            ("conv1/norm", get_norm_layer(name=norm, spatial_dims=2, channels=num_features)),
            ("conv1/relu2", get_act_layer(name=act)),
            (
                "conv2",
                conv(num_features, out_channels, kernel_size=kernel_size, padding=padding, groups=4, bias=False),
            ),
        ]
        if dropout_prob > 0:
            named_modules.append(("dropout", drop(dropout_prob)))

        self.layers = nn.Sequential()
        for module_name, module in named_modules:
            self.layers.add_module(module_name, module)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        new_features = self.layers(x)
        if new_features.shape[-1] == x.shape[-1]:
            return torch.cat([x, new_features], 1)

        # Unpadded convolutions shrink the map: crop the input symmetrically before concatenating.
        margin = (x.shape[-1] - new_features.shape[-1]) // 2
        cropped = x[:, :, margin:-margin, margin:-margin]
        return torch.cat([cropped, new_features], 1)
|
|
|
|
|
|
|
|
class _DecoderBlock(nn.Sequential):
    """Decoder block: a channel-reducing conv, a stack of dense decoder layers, a norm/act pair and a 1x1 conv."""

    def __init__(
        self,
        layers: int,
        num_features: int,
        in_channels: int,
        out_channels: int,
        dropout_prob: float = 0.0,
        act: str | tuple = ("relu", {"inplace": True}),
        norm: str | tuple = "batch",
        kernel_size: int = 3,
        same_padding: bool = False,
    ) -> None:
        """
        Args:
            layers: number of layers in the block.
            num_features: number of internal features used.
            in_channels: number of the input channel.
            out_channels: number of the output channel.
            dropout_prob: dropout rate after each dense layer.
            act: activation type and arguments. Defaults to relu.
            norm: feature normalization type and arguments. Defaults to batch norm.
            kernel_size: size of the kernel for >1 convolutions (dependent on mode)
            same_padding: whether to do padding for >1 convolutions to ensure
                the output size is the same as the input size.
        """
        super().__init__()

        conv = Conv[Conv.CONV, 2]
        pad = kernel_size // 2 if same_padding else 0
        reduced = in_channels // 4

        self.add_module("conva", conv(in_channels, reduced, kernel_size=kernel_size, padding=pad, bias=False))

        # Every dense layer concatenates ``out_channels`` new maps onto its input, so the width grows per layer.
        channels = reduced
        for index in range(1, layers + 1):
            self.add_module(
                f"denselayerdecoder{index}",
                _DenseLayerDecoder(
                    num_features,
                    channels,
                    out_channels,
                    dropout_prob,
                    act=act,
                    norm=norm,
                    kernel_size=kernel_size,
                    padding=pad,
                ),
            )
            channels += out_channels

        self.add_module("bna_block", _Transition(channels, act=act, norm=norm))
        self.add_module("convf", conv(channels, channels, kernel_size=1, bias=False))
|
|
|
|
|
|
|
|
class _DenseLayer(nn.Sequential):
    """Pre-activation bottleneck unit (1x1 -> kxk -> 1x1 convolutions) used by the HoVerNet encoder."""

    def __init__(
        self,
        num_features: int,
        in_channels: int,
        out_channels: int,
        dropout_prob: float = 0.0,
        act: str | tuple = ("relu", {"inplace": True}),
        norm: str | tuple = "batch",
        drop_first_norm_relu: int = 0,
        kernel_size: int = 3,
    ) -> None:
        """Dense Convolutional Block.

        References:
            Huang, Gao, et al. "Densely connected convolutional networks."
            Proceedings of the IEEE conference on computer vision and
            pattern recognition. 2017.

        Args:
            num_features: number of internal channels used for the layer
            in_channels: number of the input channels.
            out_channels: number of the output channels.
            dropout_prob: dropout rate after each dense layer.
            act: activation type and arguments. Defaults to relu.
            norm: feature normalization type and arguments. Defaults to batch norm.
            drop_first_norm_relu: omits the first norm/relu for the first layer. When set and
                ``in_channels != 64``, the kxk convolution is strided (stride 2, padding 2).
            kernel_size: size of the kernel for >1 convolutions (dependent on mode)
        """
        super().__init__()

        self.layers = nn.Sequential()
        conv = Conv[Conv.CONV, 2]
        drop = Dropout[Dropout.DROPOUT, 2]
        add = self.layers.add_module

        if not drop_first_norm_relu:
            add("preact/bn", get_norm_layer(name=norm, spatial_dims=2, channels=in_channels))
            add("preact/relu", get_act_layer(name=act))

        add("conv1", conv(in_channels, num_features, kernel_size=1, padding=0, bias=False))
        add("conv1/bn", get_norm_layer(name=norm, spatial_dims=2, channels=num_features))
        add("conv1/relu", get_act_layer(name=act))

        # The strided variant downsamples spatially; otherwise the kxk conv uses padding 1.
        strided = bool(drop_first_norm_relu) and in_channels != 64
        conv2_options = {"stride": 2, "padding": 2} if strided else {"padding": 1}
        add("conv2", conv(num_features, num_features, kernel_size=kernel_size, bias=False, **conv2_options))

        add("conv2/bn", get_norm_layer(name=norm, spatial_dims=2, channels=num_features))
        add("conv2/relu", get_act_layer(name=act))
        add("conv3", conv(num_features, out_channels, kernel_size=1, padding=0, bias=False))

        if dropout_prob > 0:
            add("dropout", drop(dropout_prob))
|
|
|
|
|
|
|
|
class _Transition(nn.Sequential):
    """A normalization layer named ``bn`` followed by an activation layer named ``relu``."""

    def __init__(
        self, in_channels: int, act: str | tuple = ("relu", {"inplace": True}), norm: str | tuple = "batch"
    ) -> None:
        """
        Args:
            in_channels: number of the input channel.
            act: activation type and arguments. Defaults to relu.
            norm: feature normalization type and arguments. Defaults to batch norm.
        """
        super().__init__(
            OrderedDict(
                [
                    ("bn", get_norm_layer(name=norm, spatial_dims=2, channels=in_channels)),
                    ("relu", get_act_layer(name=act)),
                ]
            )
        )
|
|
|
|
|
|
|
|
class _ResidualBlock(nn.Module): |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
layers: int, |
|
|
num_features: int, |
|
|
in_channels: int, |
|
|
out_channels: int, |
|
|
dropout_prob: float = 0.0, |
|
|
act: str | tuple = ("relu", {"inplace": True}), |
|
|
norm: str | tuple = "batch", |
|
|
freeze_dense_layer: bool = False, |
|
|
freeze_block: bool = False, |
|
|
) -> None: |
|
|
"""Residual block. |
|
|
|
|
|
References: |
|
|
He, Kaiming, et al. "Deep residual learning for image |
|
|
recognition." Proceedings of the IEEE conference on computer |
|
|
vision and pattern recognition. 2016. |
|
|
|
|
|
Args: |
|
|
layers: number of layers in the block. |
|
|
num_features: number of internal features used. |
|
|
in_channels: number of the input channel. |
|
|
out_channels: number of the output channel. |
|
|
dropout_prob: dropout rate after each dense layer. |
|
|
act: activation type and arguments. Defaults to relu. |
|
|
norm: feature normalization type and arguments. Defaults to batch norm. |
|
|
freeze_dense_layer: whether to freeze all dense layers within the block. |
|
|
freeze_block: whether to freeze the whole block. |
|
|
|
|
|
""" |
|
|
super().__init__() |
|
|
|
|
|
self.layers = nn.Sequential() |
|
|
conv_type: Callable = Conv[Conv.CONV, 2] |
|
|
|
|
|
if in_channels == 64: |
|
|
self.shortcut = conv_type(in_channels, out_channels, kernel_size=1, bias=False) |
|
|
else: |
|
|
self.shortcut = conv_type(in_channels, out_channels, kernel_size=1, stride=2, padding=1, bias=False) |
|
|
|
|
|
layer = _DenseLayer( |
|
|
num_features, in_channels, out_channels, dropout_prob, act=act, norm=norm, drop_first_norm_relu=True |
|
|
) |
|
|
self.layers.add_module("denselayer_0", layer) |
|
|
|
|
|
for i in range(1, layers): |
|
|
layer = _DenseLayer(num_features, out_channels, out_channels, dropout_prob, act=act, norm=norm) |
|
|
self.layers.add_module(f"denselayer_{i}", layer) |
|
|
|
|
|
self.bna_block = _Transition(out_channels, act=act, norm=norm) |
|
|
|
|
|
if freeze_dense_layer: |
|
|
self.layers.requires_grad_(False) |
|
|
if freeze_block: |
|
|
self.requires_grad_(False) |
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
sc = self.shortcut(x) |
|
|
|
|
|
if self.shortcut.stride == (2, 2): |
|
|
sc = sc[:, :, :-1, :-1] |
|
|
|
|
|
for layer in self.layers: |
|
|
x = layer.forward(x) |
|
|
if x.shape[-2:] != sc.shape[-2:]: |
|
|
x = x[:, :, :-1, :-1] |
|
|
|
|
|
x = x + sc |
|
|
sc = x |
|
|
|
|
|
x = self.bna_block(x) |
|
|
|
|
|
return x |
|
|
|
|
|
|
|
|
class _DecoderBranch(nn.ModuleList): |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
decode_config: Sequence[int] = (8, 4), |
|
|
act: str | tuple = ("relu", {"inplace": True}), |
|
|
norm: str | tuple = "batch", |
|
|
dropout_prob: float = 0.0, |
|
|
out_channels: int = 2, |
|
|
kernel_size: int = 3, |
|
|
same_padding: bool = False, |
|
|
) -> None: |
|
|
""" |
|
|
Args: |
|
|
decode_config: number of layers for each block. |
|
|
act: activation type and arguments. Defaults to relu. |
|
|
norm: feature normalization type and arguments. Defaults to batch norm. |
|
|
dropout_prob: dropout rate after each dense layer. |
|
|
out_channels: number of the output channel. |
|
|
kernel_size: size of the kernel for >1 convolutions (dependent on mode) |
|
|
same_padding: whether to do padding for >1 convolutions to ensure |
|
|
the output size is the same as the input size. |
|
|
""" |
|
|
super().__init__() |
|
|
conv_type: Callable = Conv[Conv.CONV, 2] |
|
|
|
|
|
|
|
|
_in_channels = 1024 |
|
|
_num_features = 128 |
|
|
_out_channels = 32 |
|
|
|
|
|
self.decoder_blocks = nn.Sequential() |
|
|
for i, num_layers in enumerate(decode_config): |
|
|
block = _DecoderBlock( |
|
|
layers=num_layers, |
|
|
num_features=_num_features, |
|
|
in_channels=_in_channels, |
|
|
out_channels=_out_channels, |
|
|
dropout_prob=dropout_prob, |
|
|
act=act, |
|
|
norm=norm, |
|
|
kernel_size=kernel_size, |
|
|
same_padding=same_padding, |
|
|
) |
|
|
self.decoder_blocks.add_module(f"decoderblock{i + 1}", block) |
|
|
_in_channels = 512 |
|
|
|
|
|
|
|
|
self.output_features = nn.Sequential() |
|
|
_i = len(decode_config) |
|
|
_pad_size = (kernel_size - 1) // 2 |
|
|
_seq_block = nn.Sequential( |
|
|
OrderedDict( |
|
|
[("conva", conv_type(256, 64, kernel_size=kernel_size, stride=1, bias=False, padding=_pad_size))] |
|
|
) |
|
|
) |
|
|
|
|
|
self.output_features.add_module(f"decoderblock{_i + 1}", _seq_block) |
|
|
|
|
|
_seq_block = nn.Sequential( |
|
|
OrderedDict( |
|
|
[ |
|
|
("bn", get_norm_layer(name=norm, spatial_dims=2, channels=64)), |
|
|
("relu", get_act_layer(name=act)), |
|
|
("conv", conv_type(64, out_channels, kernel_size=1, stride=1)), |
|
|
] |
|
|
) |
|
|
) |
|
|
|
|
|
self.output_features.add_module(f"decoderblock{_i + 2}", _seq_block) |
|
|
|
|
|
self.upsample = UpSample( |
|
|
2, scale_factor=2, mode=UpsampleMode.NONTRAINABLE, interp_mode=InterpolateMode.BILINEAR, bias=False |
|
|
) |
|
|
|
|
|
def forward(self, xin: torch.Tensor, short_cuts: list[torch.Tensor]) -> torch.Tensor: |
|
|
block_number = len(short_cuts) - 1 |
|
|
x = xin + short_cuts[block_number] |
|
|
|
|
|
for block in self.decoder_blocks: |
|
|
x = block(x) |
|
|
x = self.upsample(x) |
|
|
block_number -= 1 |
|
|
trim = (short_cuts[block_number].shape[-1] - x.shape[-1]) // 2 |
|
|
if trim > 0: |
|
|
x += short_cuts[block_number][:, :, trim:-trim, trim:-trim] |
|
|
|
|
|
for block in self.output_features: |
|
|
x = block(x) |
|
|
|
|
|
return x |
|
|
|
|
|
|
|
|
@export("monai.networks.nets")
class HoVerNet(nn.Module):
    """HoVerNet model

    References:
      Graham, Simon et al. Hover-net: Simultaneous segmentation
      and classification of nuclei in multi-tissue histology images,
      Medical Image Analysis 2019

      https://github.com/vqdang/hover_net
      https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html

    This network is non-deterministic since it uses `torch.nn.Upsample` with ``UpsampleMode.NONTRAINABLE`` mode which
    is implemented with torch.nn.functional.interpolate(). Please check the link below for more details:
    https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms

    Args:
        mode: use original implementation (`HoVerNetMode.ORIGINAL` or "original") or
          a faster implementation (`HoVerNetMode.FAST` or "fast"). Defaults to `HoVerNetMode.FAST`.
        in_channels: number of the input channel.
        np_out_channels: number of the output channel of the nucleus prediction branch.
        out_classes: number of the nuclear type classes. A value > 0 enables the type prediction branch;
          1 and values above 128 are rejected.
        act: activation type and arguments. Defaults to relu.
        norm: feature normalization type and arguments. Defaults to batch norm.
        decoder_padding: whether to do padding on convolution layers in the decoders. In the conic branch
            of the referred repository, the architecture is changed to do padding on convolution layers in order to
            get the same output size as the input, and this changed version is used on CoNIC challenge.
            Please note that to get consistent output size, `HoVerNetMode.FAST` mode should be employed.
        dropout_prob: dropout rate after each dense layer, in the range 0.0 to 1.0.
        pretrained_url: if specified, the pretrained encoder weights are downloaded from this url and loaded.
            There are two supported forms of weights:
            1. preact-resnet50 weights coming from the referred hover_net
            repository, each user is responsible for checking the content of model/datasets and the applicable licenses
            and determining if suitable for the intended use. please check the following link for more details:
            https://github.com/vqdang/hover_net#data-format
            2. standard resnet50 weights of torchvision. Please check the following link for more details:
            https://pytorch.org/vision/main/_modules/torchvision/models/resnet.html#ResNet50_Weights
        adapt_standard_resnet: if the pretrained weights of the encoder follow the original format (preact-resnet50), this
            value should be `False`. If using the pretrained weights that follow torchvision's standard resnet50 format,
            this value should be `True`.
        pretrained_state_dict_key: this arg is used when `pretrained_url` is provided and `adapt_standard_resnet` is True.
            It is used to extract the expected state dict.
        freeze_encoder: whether to freeze the encoder of the network.

    Raises:
        ValueError: when ``out_classes`` is 1 or greater than 128, or ``dropout_prob`` is outside [0.0, 1.0].
    """

    Mode = HoVerNetMode
    Branch = HoVerNetBranch

    def __init__(
        self,
        mode: HoVerNetMode | str = HoVerNetMode.FAST,
        in_channels: int = 3,
        np_out_channels: int = 2,
        out_classes: int = 0,
        act: str | tuple = ("relu", {"inplace": True}),
        norm: str | tuple = "batch",
        decoder_padding: bool = False,
        dropout_prob: float = 0.0,
        pretrained_url: str | None = None,
        adapt_standard_resnet: bool = False,
        pretrained_state_dict_key: str | None = None,
        freeze_encoder: bool = False,
    ) -> None:
        super().__init__()

        if isinstance(mode, str):
            mode = mode.upper()
        self.mode = look_up_option(mode, HoVerNetMode)

        if self.mode == HoVerNetMode.ORIGINAL and decoder_padding is True:
            warnings.warn(
                "'decoder_padding=True' only works when mode is 'FAST', otherwise the output size may not equal to the input."
            )

        if out_classes > 128:
            raise ValueError("Number of nuclear types classes exceeds maximum (128)")
        elif out_classes == 1:
            raise ValueError("Number of nuclear type classes should either be 0 or >1")

        # the chained comparison also rejects NaN
        if not 0.0 <= dropout_prob <= 1.0:
            raise ValueError("Dropout can only be in the range 0.0 to 1.0")

        # number of filters in the first convolution layer.
        _init_features: int = 64
        # number of layers in each pooling block.
        _block_config: Sequence[int] = (3, 4, 6, 3)

        # FAST: the 7x7 stem conv is padded (size preserved) and decoders use 3x3 kernels.
        # ORIGINAL: the stem conv is unpadded and decoders use 5x5 kernels.
        if self.mode == HoVerNetMode.FAST:
            _ksize = 3
            _pad = 3
        else:
            _ksize = 5
            _pad = 0

        conv_type: type[nn.Conv2d] = Conv[Conv.CONV, 2]

        self.conv0 = nn.Sequential(
            OrderedDict(
                [
                    ("conv", conv_type(in_channels, _init_features, kernel_size=7, stride=1, padding=_pad, bias=False)),
                    ("bn", get_norm_layer(name=norm, spatial_dims=2, channels=_init_features)),
                    ("relu", get_act_layer(name=act)),
                ]
            )
        )

        _in_channels = _init_features
        _out_channels = 256
        _num_features = _init_features

        self.res_blocks = nn.Sequential()

        for i, num_layers in enumerate(_block_config):
            # When freezing the encoder, only the dense layers of the first block are frozen (its shortcut and
            # transition stay trainable); every later block is frozen entirely.
            freeze_dense_layer = False
            freeze_block = False
            if freeze_encoder:
                if i == 0:
                    freeze_dense_layer = True
                else:
                    freeze_block = True
            block = _ResidualBlock(
                layers=num_layers,
                num_features=_num_features,
                in_channels=_in_channels,
                out_channels=_out_channels,
                dropout_prob=dropout_prob,
                act=act,
                norm=norm,
                freeze_dense_layer=freeze_dense_layer,
                freeze_block=freeze_block,
            )
            self.res_blocks.add_module(f"d{i}", block)

            _in_channels = _out_channels
            _out_channels *= 2
            _num_features *= 2

        # bottleneck convolution
        self.bottleneck = nn.Sequential()
        self.bottleneck.add_module(
            "conv_bottleneck", conv_type(_in_channels, _num_features, kernel_size=1, stride=1, padding=0, bias=False)
        )
        self.upsample = UpSample(
            2, scale_factor=2, mode=UpsampleMode.NONTRAINABLE, interp_mode=InterpolateMode.BILINEAR, bias=False
        )

        # decode branches
        self.nucleus_prediction = _DecoderBranch(
            kernel_size=_ksize, same_padding=decoder_padding, out_channels=np_out_channels
        )
        self.horizontal_vertical = _DecoderBranch(kernel_size=_ksize, same_padding=decoder_padding)
        self.type_prediction: _DecoderBranch | None = (
            _DecoderBranch(out_channels=out_classes, kernel_size=_ksize, same_padding=decoder_padding)
            if out_classes > 0
            else None
        )

        for m in self.modules():
            if isinstance(m, conv_type):
                nn.init.kaiming_normal_(torch.as_tensor(m.weight))
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(torch.as_tensor(m.weight), 1)
                nn.init.constant_(torch.as_tensor(m.bias), 0)

        if pretrained_url is not None:
            if adapt_standard_resnet:
                weights = _remap_standard_resnet_model(pretrained_url, state_dict_key=pretrained_state_dict_key)
            else:
                weights = _remap_preact_resnet_model(pretrained_url)
            _load_pretrained_encoder(self, weights)

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """
        Args:
            x: input batch. Its last two dims must be 270 x 270 in ORIGINAL mode and 256 x 256 in FAST mode.

        Returns:
            a dict keyed by ``HoVerNetBranch`` values that holds the NP and HV outputs, plus NC when the
            type prediction branch exists.

        Raises:
            ValueError: when the spatial size of ``x`` does not match the mode.
        """
        if self.mode == HoVerNetMode.ORIGINAL:
            if x.shape[-1] != 270 or x.shape[-2] != 270:
                raise ValueError("Input size should be 270 x 270 when using HoVerNetMode.ORIGINAL")
        else:
            if x.shape[-1] != 256 or x.shape[-2] != 256:
                raise ValueError("Input size should be 256 x 256 when using HoVerNetMode.FAST")

        x = self.conv0(x)
        short_cuts = []

        for i, block in enumerate(self.res_blocks):
            # Call the module rather than ``block.forward`` so that registered forward hooks run.
            x = block(x)

            # the outputs of the first three residual blocks are used as decoder skip connections
            if i <= 2:
                short_cuts.append(x)

        x = self.bottleneck(x)
        x = self.upsample(x)

        output = {
            HoVerNetBranch.NP.value: self.nucleus_prediction(x, short_cuts),
            HoVerNetBranch.HV.value: self.horizontal_vertical(x, short_cuts),
        }
        if self.type_prediction is not None:
            output[HoVerNetBranch.NC.value] = self.type_prediction(x, short_cuts)

        return output
|
|
|
|
|
|
|
|
def _load_pretrained_encoder(model: nn.Module, state_dict: OrderedDict | dict): |
|
|
model_dict = model.state_dict() |
|
|
state_dict = { |
|
|
k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape) |
|
|
} |
|
|
|
|
|
model_dict.update(state_dict) |
|
|
model.load_state_dict(model_dict) |
|
|
if len(state_dict.keys()) == 0: |
|
|
warnings.warn( |
|
|
"no key will be updated. Please check if 'pretrained_url' or `pretrained_state_dict_key` is correct." |
|
|
) |
|
|
else: |
|
|
print(f"{len(state_dict)} out of {len(model_dict)} keys are updated with pretrained weights.") |
|
|
|
|
|
|
|
|
def _remap_preact_resnet_keys(state_dict: dict) -> dict:
    """Rename hover_net preact-resnet50 keys to this module's encoder naming, in place.

    Keys containing ``"upsample2x"`` are removed. Keys that match none of the patterns are kept unchanged.

    Args:
        state_dict: state dict taken from the ``"desc"`` entry of a hover_net checkpoint.

    Returns:
        the same ``state_dict`` object, with renamed keys.
    """
    pattern_conv0 = re.compile(r"^(conv0\.\/)(.+)$")
    pattern_block = re.compile(r"^(d\d+)\.(.+)$")
    pattern_layer = re.compile(r"^(.+\.d\d+)\.units\.(\d+)(.+)$")
    pattern_bna = re.compile(r"^(.+\.d\d+)\.blk_bna\.(.+)")

    for key in list(state_dict.keys()):
        # Drop these entries outright. Checking first also avoids deleting the key a second time
        # when it additionally matches a renaming pattern.
        if "upsample2x" in key:
            del state_dict[key]
            continue
        new_key = None
        if pattern_conv0.match(key):
            new_key = pattern_conv0.sub(r"conv0.conv\2", key)
        elif pattern_block.match(key):
            new_key = pattern_block.sub(r"res_blocks.\1.\2", key)
            if pattern_layer.match(new_key):
                new_key = pattern_layer.sub(r"\1.layers.denselayer_\2.layers\3", new_key)
            elif pattern_bna.match(new_key):
                new_key = pattern_bna.sub(r"\1.bna_block.\2", new_key)
        if new_key:
            state_dict[new_key] = state_dict.pop(key)

    return state_dict


def _remap_preact_resnet_model(model_url: str):
    """Download hover_net preact-resnet50 weights and remap their keys to the HoVerNet encoder naming.

    Args:
        model_url: url of the checkpoint. The file is stored as ``preact-resnet50.pth`` in torch hub's dir.

    Returns:
        the remapped state dict.
    """
    # download the pretrained weights into torch hub's default dir
    # NOTE(review): the local file name does not depend on ``model_url``; if ``download_url`` reuses an
    # existing file, a different url may not trigger a new download — confirm.
    weights_dir = os.path.join(torch.hub.get_dir(), "preact-resnet50.pth")
    download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False)
    # NOTE(review): torch.load unpickles the downloaded file; only use trusted ``model_url`` values.
    map_location = None if torch.cuda.is_available() else torch.device("cpu")
    state_dict = torch.load(weights_dir, map_location=map_location)["desc"]
    return _remap_preact_resnet_keys(state_dict)
|
|
|
|
|
|
|
|
def _remap_standard_resnet_model(model_url: str, state_dict_key: str | None = None): |
|
|
pattern_conv0 = re.compile(r"^conv1\.(.+)$") |
|
|
pattern_bn1 = re.compile(r"^bn1\.(.+)$") |
|
|
pattern_block = re.compile(r"^layer(\d+)\.(\d+)\.(.+)$") |
|
|
|
|
|
pattern_block_bn3 = re.compile(r"^(res_blocks.d\d+\.layers\.denselayer_)(\d+)\.layers\.bn3\.(.+)$") |
|
|
|
|
|
pattern_block_bn = re.compile(r"^(res_blocks.d\d+\.layers\.denselayer_\d+\.layers)\.bn(\d+)\.(.+)$") |
|
|
pattern_downsample0 = re.compile(r"^(res_blocks.d\d+).+\.downsample\.0\.(.+)") |
|
|
pattern_downsample1 = re.compile(r"^(res_blocks.d\d+).+\.downsample\.1\.(.+)") |
|
|
|
|
|
weights_dir = os.path.join(torch.hub.get_dir(), "resnet50.pth") |
|
|
download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False) |
|
|
state_dict = torch.load(weights_dir, map_location=None if torch.cuda.is_available() else torch.device("cpu")) |
|
|
if state_dict_key is not None: |
|
|
state_dict = state_dict[state_dict_key] |
|
|
|
|
|
for key in list(state_dict.keys()): |
|
|
new_key = None |
|
|
if pattern_conv0.match(key): |
|
|
new_key = re.sub(pattern_conv0, r"conv0.conv.\1", key) |
|
|
elif pattern_bn1.match(key): |
|
|
new_key = re.sub(pattern_bn1, r"conv0.bn.\1", key) |
|
|
elif pattern_block.match(key): |
|
|
new_key = re.sub( |
|
|
pattern_block, |
|
|
lambda s: "res_blocks.d" |
|
|
+ str(int(s.group(1)) - 1) |
|
|
+ ".layers.denselayer_" |
|
|
+ s.group(2) |
|
|
+ ".layers." |
|
|
+ s.group(3), |
|
|
key, |
|
|
) |
|
|
if pattern_block_bn3.match(new_key): |
|
|
new_key = re.sub( |
|
|
pattern_block_bn3, |
|
|
lambda s: s.group(1) + str(int(s.group(2)) + 1) + ".layers.preact/bn." + s.group(3), |
|
|
new_key, |
|
|
) |
|
|
elif pattern_block_bn.match(new_key): |
|
|
new_key = re.sub(pattern_block_bn, r"\1.conv\2/bn.\3", new_key) |
|
|
elif pattern_downsample0.match(new_key): |
|
|
new_key = re.sub(pattern_downsample0, r"\1.shortcut.\2", new_key) |
|
|
elif pattern_downsample1.match(new_key): |
|
|
new_key = re.sub(pattern_downsample1, r"\1.bna_block.bn.\2", new_key) |
|
|
if new_key: |
|
|
state_dict[new_key] = state_dict[key] |
|
|
del state_dict[key] |
|
|
|
|
|
return state_dict |
|
|
|
|
|
|
|
|
# Case-variant aliases of HoVerNet, kept for backward compatibility.
Hovernet = HoVerNet
HoVernet = HoVerNet
HoverNet = HoVerNet
|
|
|