Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| try: | |
| from inplace_abn import InPlaceABN | |
| except ImportError: | |
| InPlaceABN = None | |
class Conv2dReLU(nn.Sequential):
    """Three-stage sequential block: ``nn.Conv2d``, a normalization slot, an activation slot.

    Args:
        in_channels: Input channel count handed to ``nn.Conv2d``.
        out_channels: Output channel count handed to ``nn.Conv2d`` and to the
            normalization layer.
        kernel_size: Convolution kernel size.
        padding: Convolution padding.
        stride: Convolution stride.
        use_batchnorm: ``"inplace"`` selects ``InPlaceABN``; any other truthy
            value selects ``nn.BatchNorm2d``; a falsy value puts ``nn.Identity``
            in the normalization slot. The convolution only gets a bias term
            when this value is falsy.

    Raises:
        RuntimeError: If ``use_batchnorm == "inplace"`` while ``InPlaceABN`` is
            ``None``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding=0,
        stride=1,
        use_batchnorm=True,
    ):
        wants_inplace = use_batchnorm == "inplace"

        # Fail before any layer is built when the optional backend is missing.
        if wants_inplace and InPlaceABN is None:
            raise RuntimeError(
                "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
                + "To install see: https://github.com/mapillary/inplace_abn"
            )

        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not use_batchnorm,
        )

        if wants_inplace:
            # InPlaceABN is configured with its own activation, so the trailing
            # activation slot is filled with a pass-through module.
            norm = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
            activation = nn.Identity()
        else:
            norm = nn.BatchNorm2d(out_channels) if use_batchnorm else nn.Identity()
            activation = nn.ReLU(inplace=True)

        super().__init__(convolution, norm, activation)
class SCSEModule(nn.Module):
    """Gate a feature map with a channel-wise and a spatial sigmoid mask.

    ``cSE`` global-average-pools the input, passes it through a two-layer
    1x1-conv bottleneck and a sigmoid to get one gate value per channel.
    ``sSE`` applies a single 1x1 convolution to one output channel followed
    by a sigmoid to get one gate value per spatial position. ``forward``
    returns ``x * cSE(x) + x * sSE(x)``.

    Args:
        in_channels: Number of channels of the input feature map.
        reduction: Divisor applied to ``in_channels`` (floor division) to get
            the bottleneck width of the channel branch. A result of zero is
            promoted to one channel.
    """

    def __init__(self, in_channels, reduction=16):
        super().__init__()

        squeezed_channels = in_channels // reduction
        if squeezed_channels == 0:
            # in_channels < reduction would otherwise create a zero-width
            # bottleneck, leaving the channel gate unable to depend on the input.
            squeezed_channels = 1

        self.cSE = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, squeezed_channels, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed_channels, in_channels, 1),
            nn.Sigmoid(),
        )
        self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid())

    def forward(self, x):
        """Return the sum of the channel-gated and spatially-gated input."""
        return x * self.cSE(x) + x * self.sSE(x)
class ArgMax(nn.Module):
    """Module wrapper around ``torch.argmax``.

    Args:
        dim: Dimension passed to ``torch.argmax``. ``None`` (the default) is
            forwarded unchanged.
    """

    def __init__(self, dim=None):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Return ``torch.argmax`` of ``x`` along the configured dimension."""
        indices = torch.argmax(x, dim=self.dim)
        return indices
class Clamp(nn.Module):
    """Module wrapper around ``torch.clamp`` with fixed bounds.

    Args:
        min: Lower bound passed to ``torch.clamp`` (default ``0``).
        max: Upper bound passed to ``torch.clamp`` (default ``1``).
    """

    def __init__(self, min=0, max=1):
        super().__init__()
        self.min = min
        self.max = max

    def forward(self, x):
        """Return ``x`` clamped to the stored ``[min, max]`` bounds."""
        lower, upper = self.min, self.max
        return torch.clamp(x, lower, upper)
class Activation(nn.Module):
    """Build an activation module from a name or a callable and apply it in ``forward``.

    Args:
        name: ``None`` or one of ``"identity"``, ``"sigmoid"``, ``"relu"``,
            ``"softmax2d"``, ``"softmax"``, ``"logsoftmax"``, ``"tanh"``,
            ``"argmax"``, ``"argmax2d"``, ``"clamp"``; or a callable, which is
            invoked as ``name(**params)`` and whose result is used as the
            activation.
        **params: Keyword arguments forwarded to the selected constructor.
            They are not forwarded for ``"sigmoid"``, ``"relu"`` and
            ``"tanh"``. ``"softmax2d"`` and ``"argmax2d"`` already pass
            ``dim=1`` themselves.

    Raises:
        ValueError: If ``name`` is neither ``None``, a supported string, nor
            callable.
    """

    def __init__(self, name, **params):
        super().__init__()

        if name is None or name == "identity":
            self.activation = nn.Identity(**params)
        elif name == "sigmoid":
            self.activation = nn.Sigmoid()
        elif name == "relu":
            self.activation = nn.ReLU(inplace=True)
        elif name == "softmax2d":
            self.activation = nn.Softmax(dim=1, **params)
        elif name == "softmax":
            self.activation = nn.Softmax(**params)
        elif name == "logsoftmax":
            self.activation = nn.LogSoftmax(**params)
        elif name == "tanh":
            self.activation = nn.Tanh()
        elif name == "argmax":
            self.activation = ArgMax(**params)
        elif name == "argmax2d":
            self.activation = ArgMax(dim=1, **params)
        elif name == "clamp":
            self.activation = Clamp(**params)
        elif callable(name):
            self.activation = name(**params)
        else:
            # List every accepted name so the message matches the branches above.
            raise ValueError(
                "Activation should be callable/identity/sigmoid/relu/softmax2d/"
                "softmax/logsoftmax/tanh/argmax/argmax2d/clamp/None; "
                f"got {name}"
            )

    def forward(self, x):
        """Apply the configured activation to ``x``."""
        return self.activation(x)
class Attention(nn.Module):
    """Select an attention module by name and apply it in ``forward``.

    Args:
        name: ``None`` for a pass-through ``nn.Identity``, or ``"scse"`` for
            ``SCSEModule``.
        **params: Keyword arguments forwarded to the selected constructor.

    Raises:
        ValueError: If ``name`` is neither ``None`` nor ``"scse"``.
    """

    def __init__(self, name, **params):
        super().__init__()

        if name is None:
            selected = nn.Identity(**params)
        elif name == "scse":
            selected = SCSEModule(**params)
        else:
            message = "Attention {} is not implemented".format(name)
            raise ValueError(message)

        self.attention = selected

    def forward(self, x):
        """Run ``x`` through the selected attention module."""
        return self.attention(x)