import re
from collections import OrderedDict
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from torch import Tensor
from torch.nn.utils import spectral_norm


class _DenseLayer(nn.Module):
    def __init__(
        self, num_input_features: int, growth_rate: int, bn_size: int, drop_rate: float, memory_efficient: bool = False
    ) -> None:
        super().__init__()
        # Bottleneck: a 1x1 conv squeezes the concatenated inputs down to
        # bn_size * growth_rate channels. Spectral normalization stands in for
        # the BatchNorm layers of the reference DenseNet.
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = spectral_norm(
            nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)
        )

        # A 3x3 conv then produces the growth_rate new feature maps that this
        # layer contributes to the dense block.
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = spectral_norm(
            nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)
        )

        self.drop_rate = float(drop_rate)
        self.memory_efficient = memory_efficient

    def bn_function(self, inputs: List[Tensor]) -> Tensor:
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = self.conv1(self.relu1(concated_features))
        return bottleneck_output

    def any_requires_grad(self, input: List[Tensor]) -> bool:
        for tensor in input:
            if tensor.requires_grad:
                return True
        return False

    @torch.jit.unused
    def call_checkpoint_bottleneck(self, input: List[Tensor]) -> Tensor:
        # Checkpoint helper that forward() expects: recompute the bottleneck
        # activations during the backward pass instead of storing them,
        # trading compute for memory.
        def closure(*inputs):
            return self.bn_function(inputs)

        return cp.checkpoint(closure, *input)

    def forward(self, input: Tensor) -> Tensor:
        if isinstance(input, Tensor):
            prev_features = [input]
        else:
            prev_features = input

        if self.memory_efficient and self.any_requires_grad(prev_features):
            if torch.jit.is_scripting():
                raise Exception("Memory Efficient not supported in JIT")
            bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
        else:
            bottleneck_output = self.bn_function(prev_features)

        new_features = self.conv2(self.relu2(bottleneck_output))
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return new_features
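
# A quick usage sketch (hypothetical shapes, not from the original module):
# a layer built with num_input_features=64 and growth_rate=32 maps a list of
# feature tensors totalling 64 channels to 32 new feature maps at the same
# spatial size:
#
#     layer = _DenseLayer(64, growth_rate=32, bn_size=4, drop_rate=0.0)
#     y = layer([torch.randn(1, 64, 8, 8)])  # y.shape == (1, 32, 8, 8)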


class _DenseBlock(nn.ModuleDict):
    _version = 2

    def __init__(
        self,
        num_layers: int,
        num_input_features: int,
        bn_size: int,
        growth_rate: int,
        drop_rate: float,
        memory_efficient: bool = False,
    ) -> None:
        super().__init__()
        # Layer i receives the block input plus the outputs of the i layers
        # before it, so its input width grows by growth_rate per layer.
        for i in range(num_layers):
            layer = _DenseLayer(
                num_input_features + i * growth_rate,
                growth_rate=growth_rate,
                bn_size=bn_size,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient,
            )
            self.add_module("denselayer%d" % (i + 1), layer)

    def forward(self, init_features: Tensor) -> Tensor:
        features = [init_features]
        for name, layer in self.items():
            new_features = layer(features)
            features.append(new_features)
        return torch.cat(features, 1)
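
# For example (hypothetical numbers): with num_input_features=64,
# growth_rate=32, and num_layers=6, the block's output carries
# 64 + 6 * 32 = 256 channels.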


class _Transition(nn.Sequential):
    def __init__(self, num_input_features: int, num_output_features: int) -> None:
        super().__init__()
        # Between dense blocks: normalize, compress channels with a 1x1 conv,
        # and halve the spatial resolution with average pooling.
        self.norm = nn.InstanceNorm2d(num_input_features)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
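
# Minimal smoke test (assumed values, not part of the original module): one
# dense block of 4 layers on a 64-channel input, followed by a transition that
# halves the channel count and the spatial resolution.
if __name__ == "__main__":
    block = _DenseBlock(num_layers=4, num_input_features=64, bn_size=4, growth_rate=32, drop_rate=0.0)
    trans = _Transition(num_input_features=64 + 4 * 32, num_output_features=(64 + 4 * 32) // 2)
    x = torch.randn(2, 64, 32, 32)
    out = trans(block(x))
    print(out.shape)  # torch.Size([2, 96, 16, 16])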