| |
| |
| |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.cuda.amp import autocast |
|
|
| from src.efficientvit.models.nn.act import build_act |
| from src.efficientvit.models.nn.norm import build_norm |
| from src.efficientvit.models.utils import (get_same_padding, list_sum, resize, |
| val2list, val2tuple) |
|
|
# Public API of this module: the names bound by ``from <this module> import *``.
# Kept as a literal list so static analysers and IDEs can read it without running code.
__all__ = [
    "ConvLayer",
    "UpSampleLayer",
    "LinearLayer",
    "IdentityLayer",
    "DSConv",
    "MBConv",
    "FusedMBConv",
    "ResBlock",
    "LiteMLA",
    "EfficientViTBlock",
    "ResidualBlock",
    "DAGBlock",
    "OpSequential",
]
|
|
|
|
| |
| |
| |
|
|
|
|
class ConvLayer(nn.Module):
    """2-D convolution with optional input dropout and optional norm / activation.

    Applied in order: ``nn.Dropout2d`` (only when ``dropout > 0``), a square
    ``nn.Conv2d``, then the modules returned by ``build_norm`` and ``build_act``.
    Each of the last two is skipped when the built object is falsy (e.g. ``None``).

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Side length of the square kernel.
        stride: Stride used on both spatial axes.
        dilation: Dilation used on both spatial axes; also multiplies the padding.
        groups: Number of convolution groups.
        use_bias: Whether the convolution has a bias term.
        dropout: ``Dropout2d`` probability applied to the input; no dropout if <= 0.
        norm: Name handed to ``build_norm`` (with ``num_features=out_channels``).
        act_func: Name handed to ``build_act``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        dilation=1,
        groups=1,
        use_bias=False,
        dropout=0,
        norm="bn2d",
        act_func="relu",
    ):
        super(ConvLayer, self).__init__()

        # Padding from ``get_same_padding`` for the plain kernel, scaled by the dilation.
        same_pad = get_same_padding(kernel_size) * dilation

        if dropout > 0:
            self.dropout = nn.Dropout2d(dropout, inplace=False)
        else:
            self.dropout = None

        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(kernel_size,) * 2,
            stride=(stride,) * 2,
            padding=same_pad,
            dilation=(dilation,) * 2,
            groups=groups,
            bias=use_bias,
        )
        self.norm = build_norm(norm, num_features=out_channels)
        self.act = build_act(act_func)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.dropout is not None:
            x = self.dropout(x)
        y = self.conv(x)
        # Norm then activation, each only if present (truthy).
        for post in (self.norm, self.act):
            if post:
                y = post(y)
        return y
|
|
|
|
class UpSampleLayer(nn.Module):
    """Resize a feature map to a fixed spatial size or by a scale factor.

    Args:
        mode: Interpolation mode forwarded to ``resize`` (default ``"bicubic"``).
        size: Target spatial size. An int is expanded to two entries via
            ``val2list``. When given, ``factor`` is ignored (stored as ``None``).
        factor: Scale factor used only when ``size`` is ``None``.
        align_corners: Forwarded to ``resize``.
    """

    def __init__(
        self,
        mode="bicubic",
        size: int or tuple[int, int] or list[int] or None = None,
        factor=2,
        align_corners=False,
    ):
        super(UpSampleLayer, self).__init__()
        self.mode = mode
        self.size = val2list(size, 2) if size is not None else None
        self.factor = None if self.size is not None else factor
        self.align_corners = align_corners

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` unchanged when no resize is needed, else the resized tensor."""
        # ``self.size`` is built by ``val2list`` (presumably a list), while the
        # shape slice is turned into a tuple. In Python ``(4, 4) == [4, 4]`` is
        # False, so normalise both sides to tuples; otherwise this no-op shortcut
        # could never fire.
        if self.size is not None and tuple(x.shape[-2:]) == tuple(self.size):
            return x
        if self.factor == 1:
            return x
        return resize(x, self.size, self.factor, self.mode, self.align_corners)
|
|
|
|
class LinearLayer(nn.Module):
    """Fully connected layer with optional dropout, norm and activation.

    Inputs with more than two dimensions are flattened from dim 1 onward before
    the ``nn.Linear``. Dropout (when ``dropout > 0``) runs before the linear map.
    The modules from ``build_norm`` / ``build_act`` run after it, each only when truthy.

    Args:
        in_features: Input width expected by ``nn.Linear`` (after flattening).
        out_features: Output width.
        use_bias: Whether ``nn.Linear`` has a bias.
        dropout: ``nn.Dropout`` probability; disabled when <= 0.
        norm: Name handed to ``build_norm`` (with ``num_features=out_features``).
        act_func: Name handed to ``build_act``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        use_bias=True,
        dropout=0,
        norm=None,
        act_func=None,
    ):
        super(LinearLayer, self).__init__()

        if dropout > 0:
            self.dropout = nn.Dropout(dropout, inplace=False)
        else:
            self.dropout = None
        self.linear = nn.Linear(in_features, out_features, use_bias)
        self.norm = build_norm(norm, num_features=out_features)
        self.act = build_act(act_func)

    def _try_squeeze(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten everything after the batch dim when ``x`` has more than 2 dims."""
        return torch.flatten(x, start_dim=1) if x.dim() > 2 else x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self._try_squeeze(x)
        if self.dropout:
            h = self.dropout(h)
        h = self.linear(h)
        for post in (self.norm, self.act):
            if post:
                h = post(h)
        return h
|
|
|
|
class IdentityLayer(nn.Module):
    """No-op module: ``forward`` returns the very same tensor object it receives."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        passthrough = x
        return passthrough
|
|
|
|
| |
| |
| |
|
|
|
|
class DSConv(nn.Module):
    """Depthwise-separable convolution: depthwise ``ConvLayer`` then 1x1 ``ConvLayer``.

    ``use_bias``, ``norm`` and ``act_func`` are expanded with ``val2tuple(..., 2)``.
    Entry 0 configures the depthwise conv and entry 1 the pointwise conv.

    Args:
        in_channels: Channels in (and through the depthwise stage).
        out_channels: Channels produced by the pointwise stage.
        kernel_size: Kernel size of the depthwise conv.
        stride: Stride of the depthwise conv.
        use_bias: Bias flag(s) for the two stages.
        norm: Norm name(s) for the two stages.
        act_func: Activation name(s) for the two stages.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        use_bias=False,
        norm=("bn2d", "bn2d"),
        act_func=("relu6", None),
    ):
        super(DSConv, self).__init__()

        biases = val2tuple(use_bias, 2)
        norms = val2tuple(norm, 2)
        acts = val2tuple(act_func, 2)

        # groups == channels makes this a per-channel (depthwise) convolution.
        self.depth_conv = ConvLayer(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            groups=in_channels,
            use_bias=biases[0],
            norm=norms[0],
            act_func=acts[0],
        )
        self.point_conv = ConvLayer(
            in_channels,
            out_channels,
            kernel_size=1,
            use_bias=biases[1],
            norm=norms[1],
            act_func=acts[1],
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for stage in (self.depth_conv, self.point_conv):
            x = stage(x)
        return x
|
|
|
|
class MBConv(nn.Module):
    """Inverted-residual style block: 1x1 expand, depthwise conv, 1x1 project.

    No shortcut is added here. ``use_bias``, ``norm`` and ``act_func`` are expanded
    with ``val2tuple(..., 3)``; entries 0/1/2 configure the expand, depthwise and
    project stages respectively.

    Args:
        in_channels: Input channels.
        out_channels: Output channels of the projection.
        kernel_size: Kernel size of the depthwise stage.
        stride: Stride of the depthwise stage (the expand stage uses stride 1).
        mid_channels: Hidden width; when falsy it is
            ``round(in_channels * expand_ratio)``.
        expand_ratio: Multiplier used when ``mid_channels`` is not given.
        use_bias: Bias flag(s) for the three stages.
        norm: Norm name(s) for the three stages.
        act_func: Activation name(s) for the three stages.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        mid_channels=None,
        expand_ratio=6,
        use_bias=False,
        norm=("bn2d", "bn2d", "bn2d"),
        act_func=("relu6", "relu6", None),
    ):
        super(MBConv, self).__init__()

        biases = val2tuple(use_bias, 3)
        norms = val2tuple(norm, 3)
        acts = val2tuple(act_func, 3)
        hidden = mid_channels or round(in_channels * expand_ratio)

        # Stage 1: 1x1 expansion to the hidden width.
        self.inverted_conv = ConvLayer(
            in_channels,
            hidden,
            kernel_size=1,
            stride=1,
            use_bias=biases[0],
            norm=norms[0],
            act_func=acts[0],
        )
        # Stage 2: depthwise spatial conv (groups == hidden channels).
        self.depth_conv = ConvLayer(
            hidden,
            hidden,
            kernel_size=kernel_size,
            stride=stride,
            groups=hidden,
            use_bias=biases[1],
            norm=norms[1],
            act_func=acts[1],
        )
        # Stage 3: 1x1 projection to the output width.
        self.point_conv = ConvLayer(
            hidden,
            out_channels,
            kernel_size=1,
            use_bias=biases[2],
            norm=norms[2],
            act_func=acts[2],
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for stage in (self.inverted_conv, self.depth_conv, self.point_conv):
            x = stage(x)
        return x
|
|
|
|
class FusedMBConv(nn.Module):
    """Fused MBConv: one spatial ``ConvLayer`` (expand + spatial) then a 1x1 projection.

    ``use_bias``, ``norm`` and ``act_func`` are expanded with ``val2tuple(..., 2)``.
    Entry 0 configures the spatial conv and entry 1 the pointwise conv.

    Args:
        in_channels: Input channels.
        out_channels: Output channels of the projection.
        kernel_size: Kernel size of the spatial conv.
        stride: Stride of the spatial conv.
        mid_channels: Hidden width; when falsy it is
            ``round(in_channels * expand_ratio)``.
        expand_ratio: Multiplier used when ``mid_channels`` is not given.
        groups: Groups of the spatial conv.
        use_bias: Bias flag(s) for the two stages.
        norm: Norm name(s) for the two stages.
        act_func: Activation name(s) for the two stages.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        mid_channels=None,
        expand_ratio=6,
        groups=1,
        use_bias=False,
        norm=("bn2d", "bn2d"),
        act_func=("relu6", None),
    ):
        super().__init__()
        biases = val2tuple(use_bias, 2)
        norms = val2tuple(norm, 2)
        acts = val2tuple(act_func, 2)

        hidden = mid_channels or round(in_channels * expand_ratio)

        self.spatial_conv = ConvLayer(
            in_channels,
            hidden,
            kernel_size=kernel_size,
            stride=stride,
            groups=groups,
            use_bias=biases[0],
            norm=norms[0],
            act_func=acts[0],
        )
        self.point_conv = ConvLayer(
            hidden,
            out_channels,
            kernel_size=1,
            use_bias=biases[1],
            norm=norms[1],
            act_func=acts[1],
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for stage in (self.spatial_conv, self.point_conv):
            x = stage(x)
        return x
|
|
|
|
class ResBlock(nn.Module):
    """Two stacked ``kernel_size`` convolutions (no shortcut is added here).

    The first conv carries the stride. The second always uses stride 1.
    ``use_bias``, ``norm`` and ``act_func`` are expanded with
    ``val2tuple(..., 2)``; entry 0 is for ``conv1`` and entry 1 for ``conv2``.

    Args:
        in_channels: Input channels.
        out_channels: Output channels of ``conv2``.
        kernel_size: Kernel size of both convs.
        stride: Stride of ``conv1``.
        mid_channels: Width between the convs; when falsy it is
            ``round(in_channels * expand_ratio)``.
        expand_ratio: Multiplier used when ``mid_channels`` is not given.
        use_bias: Bias flag(s) for the two convs.
        norm: Norm name(s) for the two convs.
        act_func: Activation name(s) for the two convs.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        mid_channels=None,
        expand_ratio=1,
        use_bias=False,
        norm=("bn2d", "bn2d"),
        act_func=("relu6", None),
    ):
        super().__init__()
        biases = val2tuple(use_bias, 2)
        norms = val2tuple(norm, 2)
        acts = val2tuple(act_func, 2)

        hidden = mid_channels or round(in_channels * expand_ratio)

        self.conv1 = ConvLayer(
            in_channels,
            hidden,
            kernel_size=kernel_size,
            stride=stride,
            use_bias=biases[0],
            norm=norms[0],
            act_func=acts[0],
        )
        self.conv2 = ConvLayer(
            hidden,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            use_bias=biases[1],
            norm=norms[1],
            act_func=acts[1],
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for stage in (self.conv1, self.conv2):
            x = stage(x)
        return x
|
|
|
|
class LiteMLA(nn.Module):
    r"""Lightweight multi-scale linear attention.

    Pipeline:

    1. ``self.qkv``: 1x1 ``ConvLayer`` mapping ``in_channels`` to ``3 * heads * dim``.
    2. ``self.aggreg``: one branch per entry of ``scales``. Each branch is a depthwise
       ``scale x scale`` conv followed by a 1x1 conv with ``3 * heads`` groups, and
       keeps ``3 * heads * dim`` channels. The branch outputs are concatenated
       (on channels) after the raw qkv.
    3. ``relu_linear_att``: kernelized linear attention over the concatenation.
    4. ``self.proj``: 1x1 ``ConvLayer`` from ``heads * dim * (1 + len(scales))`` to
       ``out_channels``.

    Args:
        in_channels: Input channels.
        out_channels: Output channels of the projection.
        heads: Number of heads; when falsy it is
            ``int(in_channels // dim * heads_ratio)``.
        heads_ratio: Multiplier used only when ``heads`` is not given.
        dim: Channel width of each of q, k and v per head.
        use_bias: Bias flag(s) expanded with ``val2tuple(..., 2)``. Entry 0 is used by
            the qkv conv and both aggregation convs; entry 1 by the projection.
        norm: Norm name(s); entry 0 for qkv, entry 1 for proj.
        act_func: Activation name(s); entry 0 for qkv, entry 1 for proj.
        kernel_func: Activation name passed to ``build_act``; applied to q and k.
        scales: Kernel sizes of the multi-scale aggregation branches.
        eps: Added to the attention denominator before dividing.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        heads: int or None = None,
        heads_ratio: float = 1.0,
        dim=8,
        use_bias=False,
        norm=(None, "bn2d"),
        act_func=(None, None),
        kernel_func="relu",
        scales: tuple[int, ...] = (5,),
        eps=1.0e-15,
    ):
        super(LiteMLA, self).__init__()
        self.eps = eps
        heads = heads or int(in_channels // dim * heads_ratio)

        total_dim = heads * dim

        use_bias = val2tuple(use_bias, 2)
        norm = val2tuple(norm, 2)
        act_func = val2tuple(act_func, 2)

        self.dim = dim
        self.qkv = ConvLayer(
            in_channels,
            3 * total_dim,
            1,
            use_bias=use_bias[0],
            norm=norm[0],
            act_func=act_func[0],
        )
        self.aggreg = nn.ModuleList(
            [
                nn.Sequential(
                    # Depthwise spatial aggregation at this scale.
                    nn.Conv2d(
                        3 * total_dim,
                        3 * total_dim,
                        scale,
                        padding=get_same_padding(scale),
                        groups=3 * total_dim,
                        bias=use_bias[0],
                    ),
                    # Grouped 1x1 mixing: 3 * heads groups of ``dim`` channels each.
                    nn.Conv2d(
                        3 * total_dim,
                        3 * total_dim,
                        1,
                        groups=3 * heads,
                        bias=use_bias[0],
                    ),
                )
                for scale in scales
            ]
        )
        self.kernel_func = build_act(kernel_func, inplace=False)

        self.proj = ConvLayer(
            total_dim * (1 + len(scales)),
            out_channels,
            1,
            use_bias=use_bias[1],
            norm=norm[1],
            act_func=act_func[1],
        )

    @autocast(enabled=False)
    def relu_linear_att(self, qkv: torch.Tensor) -> torch.Tensor:
        """Apply kernelized linear attention to a stacked q/k/v tensor.

        Args:
            qkv: ``(B, C, H, W)`` tensor whose channels form consecutive groups of
                ``3 * self.dim`` (q, k, v slices of width ``self.dim``). ``C`` must
                be divisible by ``3 * self.dim`` for the reshape below.

        Returns:
            ``(B, C // 3, H, W)`` tensor with the same dtype as ``qkv``.
        """
        B, _, H, W = list(qkv.size())
        in_dtype = qkv.dtype

        # Compute in fp32 for fp16 input. Note fp16 cannot represent ``eps`` (1e-15
        # is below its smallest subnormal), so the denominator guard would vanish.
        if in_dtype == torch.float16:
            qkv = qkv.float()

        # (B, C, H, W) -> (B, groups, 3*dim, H*W) -> (B, groups, H*W, 3*dim)
        qkv = torch.reshape(qkv, (B, -1, 3 * self.dim, H * W))
        qkv = torch.transpose(qkv, -1, -2)
        q = qkv[..., 0 : self.dim]
        k = qkv[..., self.dim : 2 * self.dim]
        v = qkv[..., 2 * self.dim :]

        # Feature map applied to queries and keys.
        q = self.kernel_func(q)
        k = self.kernel_func(k)
        trans_k = k.transpose(-1, -2)

        # Padding v with a ones column makes the last output column equal to
        # q @ (k^T @ 1), the normaliser, computed by the same two matmuls.
        v = F.pad(v, (0, 1), mode="constant", value=1)
        kv = torch.matmul(trans_k, v)
        out = torch.matmul(q, kv)
        out = out[..., :-1] / (out[..., -1:] + self.eps)

        # (B, groups, H*W, dim) -> (B, groups*dim, H, W)
        out = torch.transpose(out, -1, -2)
        out = torch.reshape(out, (B, -1, H, W))
        # Hand back the caller's dtype. Without this, a pure-fp16 model
        # (``model.half()``, no autocast) would feed float32 into ``self.proj``'s
        # fp16 weights and fail. Under CUDA autocast the proj conv runs in fp16
        # anyway, so the values it receives are unchanged.
        return out.to(in_dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        qkv = self.qkv(x)
        # Raw qkv followed by one aggregated copy per scale, stacked on channels.
        multi_scale_qkv = torch.cat([qkv] + [op(qkv) for op in self.aggreg], dim=1)

        out = self.relu_linear_att(multi_scale_qkv)
        return self.proj(out)
|
|
|
|
class EfficientViTBlock(nn.Module):
    """EfficientViT block: residual ``LiteMLA`` followed by a residual ``MBConv``.

    Both residuals use ``IdentityLayer`` as the shortcut and keep ``in_channels``
    channels. The ``MBConv`` has biased, norm-free expand/depthwise stages that use
    ``act_func``, and a bias-free projection normalised by ``norm`` with no activation.

    Args:
        in_channels: Channels in and out of the block.
        heads_ratio: Forwarded to ``LiteMLA``.
        dim: Per-head width forwarded to ``LiteMLA``.
        expand_ratio: Hidden-width multiplier of the ``MBConv``.
        scales: Aggregation kernel sizes forwarded to ``LiteMLA``.
        norm: Norm name for LiteMLA's projection and MBConv's last stage.
        act_func: Activation name for MBConv's first two stages.
    """

    def __init__(
        self,
        in_channels: int,
        heads_ratio: float = 1.0,
        dim=32,
        expand_ratio: float = 4,
        scales=(5,),
        norm="bn2d",
        act_func="hswish",
    ):
        super(EfficientViTBlock, self).__init__()
        # Construction order (attention first, then MBConv) matches parameter
        # initialisation order, so keep it.
        attention = LiteMLA(
            in_channels=in_channels,
            out_channels=in_channels,
            heads_ratio=heads_ratio,
            dim=dim,
            norm=(None, norm),
            scales=scales,
        )
        self.context_module = ResidualBlock(attention, IdentityLayer())

        feed_forward = MBConv(
            in_channels=in_channels,
            out_channels=in_channels,
            expand_ratio=expand_ratio,
            use_bias=(True, True, False),
            norm=(None, None, norm),
            act_func=(act_func, act_func, None),
        )
        self.local_module = ResidualBlock(feed_forward, IdentityLayer())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for stage in (self.context_module, self.local_module):
            x = stage(x)
        return x
|
|
|
|
| |
| |
| |
|
|
|
|
class ResidualBlock(nn.Module):
    """Generic residual wrapper: ``post_act(main(pre_norm(x)) + shortcut(x))``.

    Each part is optional:

    * ``main is None``: the input passes through unchanged (before ``post_act``).
    * ``shortcut is None``: no addition is performed.
    * ``pre_norm is None``: ``main`` sees the raw input.
    * ``post_act`` is applied only when ``build_act(post_act)`` is truthy.

    Args:
        main: Main branch module, or ``None``.
        shortcut: Shortcut branch module, or ``None``.
        post_act: Activation name handed to ``build_act``.
        pre_norm: Module applied to the input of ``main`` only, or ``None``.
    """

    def __init__(
        self,
        main: nn.Module or None,
        shortcut: nn.Module or None,
        post_act=None,
        pre_norm: nn.Module or None = None,
    ):
        super(ResidualBlock, self).__init__()

        self.pre_norm = pre_norm
        self.main = main
        self.shortcut = shortcut
        self.post_act = build_act(post_act)

    def forward_main(self, x: torch.Tensor) -> torch.Tensor:
        """Run the main branch, normalising its input first when ``pre_norm`` is set."""
        main_in = x if self.pre_norm is None else self.pre_norm(x)
        return self.main(main_in)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.main is None:
            out = x
        else:
            out = self.forward_main(x)
            if self.shortcut is not None:
                out = out + self.shortcut(x)
        if self.post_act:
            out = self.post_act(out)
        return out
|
|
|
|
class DAGBlock(nn.Module):
    """Merge several named features, transform the result, and write named outputs.

    ``forward`` runs in four steps:

    1. For every ``(key, op)`` in ``inputs``, apply ``op`` to ``feature_dict[key]``.
    2. Merge the results. ``merge="add"`` sums them with ``list_sum``;
       ``merge="cat"`` concatenates them along dim 1.
    3. Apply ``post_input`` (if not ``None``), then ``middle``.
    4. Store ``op(feat)`` under each key of ``outputs``.

    The given ``feature_dict`` is mutated in place and returned.

    Args:
        inputs: Mapping from input feature name to the module applied to it.
        merge: ``"add"`` or ``"cat"``; any other value raises at forward time.
        post_input: Optional module applied to the merged feature.
        middle: Module applied after merging (and ``post_input``).
        outputs: Mapping from output feature name to the module producing it.

    Raises:
        NotImplementedError: From ``forward`` when ``merge`` is not supported.
    """

    def __init__(
        self,
        inputs: dict[str, nn.Module],
        merge: str,
        post_input: nn.Module or None,
        middle: nn.Module,
        outputs: dict[str, nn.Module],
    ):
        super(DAGBlock, self).__init__()

        self.input_keys = list(inputs.keys())
        self.input_ops = nn.ModuleList(list(inputs.values()))
        self.merge = merge
        self.post_input = post_input

        self.middle = middle

        self.output_keys = list(outputs.keys())
        self.output_ops = nn.ModuleList(list(outputs.values()))

    def forward(self, feature_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        feat = [
            op(feature_dict[key]) for key, op in zip(self.input_keys, self.input_ops)
        ]
        if self.merge == "add":
            feat = list_sum(feat)
        elif self.merge == "cat":
            feat = torch.cat(feat, dim=1)
        else:
            # Name the offending mode so misconfigured blocks are easy to diagnose.
            raise NotImplementedError(
                f"DAGBlock merge mode {self.merge!r} is not supported; "
                "expected 'add' or 'cat'"
            )
        if self.post_input is not None:
            feat = self.post_input(feat)
        feat = self.middle(feat)
        for key, op in zip(self.output_keys, self.output_ops):
            feature_dict[key] = op(feat)
        return feature_dict
|
|
|
|
class OpSequential(nn.Module):
    """Sequential container that silently drops ``None`` entries.

    Args:
        op_list: Modules to chain in order; ``None`` items are skipped.
    """

    def __init__(self, op_list: list[nn.Module or None]):
        super(OpSequential, self).__init__()
        self.op_list = nn.ModuleList([op for op in op_list if op is not None])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = x
        for layer in self.op_list:
            out = layer(out)
        return out
|
|