| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import collections |
| | from functools import partial |
| | from itertools import repeat |
| | from typing import Callable |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch import Tensor |
| |
|
| |
|
| | |
| | def _ntuple(n): |
| | def parse(x): |
| | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): |
| | return tuple(x) |
| | return tuple(repeat(x, n)) |
| |
|
| | return parse |
| |
|
| |
|
def exists(val):
    """Return True when *val* is anything other than None."""
    return not (val is None)
| |
|
| |
|
def default(val, d):
    """Return *val* unless it is None, in which case return the fallback *d*."""
    if val is None:
        return d
    return val
| |
|
| |
|
# Normalize a scalar or iterable into a pair, e.g. 3 -> (3, 3), (1, 2) -> (1, 2).
to_2tuple = _ntuple(2)
| |
|
| |
|
class ResidualBlock(nn.Module):
    """Residual block: two conv layers with a skip connection.

    When ``stride != 1`` the skip path is downsampled with a strided
    1x1 conv (plus norm) so it matches the main path's shape.

    Args:
        in_planes: number of input channels.
        planes: number of output channels. NOTE(review): GroupNorm uses
            ``planes // 8`` groups, which assumes planes is a multiple
            of 8 — confirm with callers.
        norm_fn: one of "group", "batch", "instance", "none".
        stride: stride of the first conv (and of the downsample path).
        kernel_size: conv kernel size. Padding is derived as
            ``kernel_size // 2`` ("same" padding for odd kernels).
            Bug fix: padding was hard-coded to 1, which broke the
            residual add ``x + y`` for any kernel_size != 3.

    Raises:
        NotImplementedError: if ``norm_fn`` is not one of the names above.
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
        super().__init__()

        # "Same" padding for odd kernel sizes; previously fixed at 1,
        # which caused a spatial-shape mismatch whenever kernel_size != 3.
        padding = kernel_size // 2

        self.conv1 = nn.Conv2d(
            in_planes,
            planes,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
            padding_mode="zeros",
        )
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=kernel_size,
            padding=padding,
            padding_mode="zeros",
        )
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if stride != 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if stride != 1:
                self.norm3 = nn.BatchNorm2d(planes)
        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if stride != 1:
                self.norm3 = nn.InstanceNorm2d(planes)
        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if stride != 1:
                self.norm3 = nn.Sequential()
        else:
            raise NotImplementedError(f"unknown norm_fn: {norm_fn!r}")

        if stride == 1:
            self.downsample = None
        else:
            # Bring the skip path to the main path's channels/resolution.
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
                self.norm3,
            )

    def forward(self, x):
        """Apply conv-norm-relu twice, then add the (possibly downsampled) input."""
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)
| |
|
| |
|
class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks.

    Pipeline: fc1 -> act -> dropout -> norm -> fc2 -> dropout.

    Bug fix: ``norm_layer`` was accepted but silently ignored. It is now
    applied between the activation/dropout and the second projection,
    matching the timm reference implementation this block mirrors. The
    default ``None`` maps to ``nn.Identity()``, so existing behavior and
    checkpoints are unchanged (Identity has no parameters).

    Args:
        in_features: input feature dimension.
        hidden_features: hidden dimension (defaults to in_features).
        out_features: output dimension (defaults to in_features).
        act_layer: activation module constructor.
        norm_layer: optional norm constructor applied to hidden features.
        bias: bool, or 2-tuple of bools for (fc1, fc2).
        drop: dropout prob, or 2-tuple for the two dropout layers.
        use_conv: if True, use 1x1 Conv2d (position-wise linear on NCHW)
            instead of nn.Linear.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # bias/drop may be given once (applied to both layers) or per-layer.
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        # Fix: previously norm_layer was never used.
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
| |
|
| |
|
class AttnBlock(nn.Module):
    """Self-attention block with a pre-norm MLP residual.

    NOTE(review): the attention residual is added to the *normalized*
    input (``x`` is rebound to ``norm1(x)`` before the add), not to the
    raw input — this matches the original implementation and is kept.

    Args:
        hidden_size: token embedding dimension.
        num_heads: number of attention heads.
        attn_class: attention module constructor; defaults to
            nn.MultiheadAttention and is constructed with
            ``batch_first=True``. Must accept ``attn_mask`` in forward.
        mlp_ratio: hidden/embed ratio of the feed-forward MLP.
        **block_kwargs: forwarded to ``attn_class``.
    """

    def __init__(
        self,
        hidden_size,
        num_heads,
        attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
        mlp_ratio=4.0,
        **block_kwargs
    ):
        super().__init__()

        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)

        self.attn = attn_class(
            embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
        )

        mlp_hidden_dim = int(hidden_size * mlp_ratio)

        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)

    def forward(self, x, mask=None):
        """Run self-attention then MLP, each with a residual connection.

        Bug fix: ``mask`` was accepted but silently ignored; it is now
        forwarded to the attention module as ``attn_mask``, consistent
        with CrossAttnBlock below.
        """
        x = self.norm1(x)

        attn_output, _ = self.attn(x, x, x, attn_mask=mask)

        x = x + attn_output
        x = x + self.mlp(self.norm2(x))
        return x
| |
|
| |
|
class CrossAttnBlock(nn.Module):
    """Cross-attention block: queries from ``x``, keys/values from ``context``.

    Both streams are LayerNorm-ed first; the attention output is added
    to the normalized query stream, followed by a LayerNorm + MLP
    residual.

    Args:
        hidden_size: query embedding dimension.
        context_dim: key/value embedding dimension.
        num_heads: number of attention heads.
        mlp_ratio: hidden/embed ratio of the feed-forward MLP.
        eps: LayerNorm epsilon.
        **block_kwargs: forwarded to nn.MultiheadAttention.
    """

    def __init__(
        self,
        hidden_size,
        context_dim,
        num_heads=1,
        mlp_ratio=4.0,
        eps=1e-5,
        **block_kwargs
    ):
        super().__init__()

        self.norm1 = nn.LayerNorm(hidden_size, eps=eps)
        self.norm_context = nn.LayerNorm(context_dim, eps=eps)
        self.norm2 = nn.LayerNorm(hidden_size, eps=eps)

        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_size,
            kdim=context_dim,
            vdim=context_dim,
            num_heads=num_heads,
            batch_first=True,
            **block_kwargs
        )

        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            drop=0,
        )

    def forward(self, x, context, mask=None):
        """Attend from x to context; ``mask`` is passed as attn_mask."""
        queries = self.norm1(x)
        keys_values = self.norm_context(context)

        attn_out, _ = self.cross_attn(
            queries, keys_values, keys_values, attn_mask=mask
        )

        out = queries + attn_out
        out = out + self.mlp(self.norm2(out))
        return out
| |
|