import torch
from torch import nn
class DepthWiseConv2d(nn.Module):
    """Depthwise convolution followed by a 1x1 (pointwise) convolution.

    The first ``nn.Conv2d`` keeps the channel count at ``dim_in`` and is built
    with ``groups=dim_in``, so every input channel gets its own filter. The
    second ``nn.Conv2d`` has ``kernel_size=1`` and maps ``dim_in`` channels to
    ``dim_out`` channels.

    Args:
        dim_in: Number of input channels.
        dim_out: Number of channels produced by the 1x1 convolution.
        kernel_size: Kernel size of the depthwise convolution; passed through
            to ``nn.Conv2d`` unchanged.
        padding: Padding of the depthwise convolution; passed through to
            ``nn.Conv2d`` unchanged.
        stride: Stride of the depthwise convolution; passed through to
            ``nn.Conv2d`` unchanged.
        bias: Whether the two convolutions carry a bias term. The same flag is
            applied to both layers.
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        kernel_size,
        padding=0,
        stride=1,
        bias: bool = True,
    ) -> None:
        super().__init__()
        # Depthwise stage: groups == in_channels == out_channels, so there is
        # no mixing across channels here.
        depthwise = nn.Conv2d(
            dim_in,
            dim_in,
            kernel_size=kernel_size,
            padding=padding,
            groups=dim_in,
            stride=stride,
            bias=bias,
        )
        # Pointwise stage: 1x1 kernel changes the channel count to dim_out.
        pointwise = nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=bias)
        # Kept as ``self.net`` (a Sequential) so parameter names stay
        # ``net.0.*`` / ``net.1.*`` for existing checkpoints.
        self.net = nn.Sequential(depthwise, pointwise)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the depthwise and then the pointwise convolution."""
        return self.net(x)