import math

import torch.nn as nn


class GFM_Module(nn.Module):
    """Two-stage convolution block that returns both stage outputs.

    A 1x1 convolution (``primary_conv``) maps the input to
    ``out_channels // ratio`` channels; a grouped 3x3 convolution
    (``cheap_operation``) then derives the remaining
    ``out_channels - out_channels // ratio`` channels from that result.
    Each convolution is followed by batch normalisation and ReLU.

    NOTE(review): the layout resembles the "Ghost module" from GhostNet,
    except that the two feature maps are returned separately rather than
    concatenated -- confirm against the caller.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Combined number of channels of both outputs.
        ratio: Divisor that determines how many of ``out_channels`` are
            produced by the 1x1 convolution.

    Raises:
        ValueError: If ``out_channels // ratio`` is smaller than 1, i.e.
            the primary convolution would have no output channels.
    """

    def __init__(self, in_channels: int, out_channels: int, ratio: int = 2):
        super().__init__()

        init_channels = out_channels // ratio
        if init_channels < 1:
            raise ValueError(
                "out_channels // ratio must be at least 1, got "
                f"out_channels={out_channels}, ratio={ratio}"
            )
        new_channels = out_channels - init_channels

        # nn.Conv2d requires both its in_channels and out_channels to be
        # divisible by `groups`.  The greatest common divisor equals
        # init_channels (one group per input channel) whenever new_channels
        # is a multiple of init_channels, and otherwise falls back to the
        # largest group count that is still valid (e.g. odd out_channels
        # with ratio=2) instead of failing at construction time.
        groups = math.gcd(init_channels, new_channels)

        self.primary_conv = nn.Sequential(
            nn.Conv2d(in_channels, init_channels, 1, bias=False),
            nn.BatchNorm2d(init_channels),
            nn.ReLU(inplace=True),
        )

        self.cheap_operation = nn.Sequential(
            # kernel 3, stride 1, padding 1: spatial size is preserved.
            nn.Conv2d(
                init_channels,
                new_channels,
                3,
                1,
                1,
                groups=groups,
                bias=False,
            ),
            nn.BatchNorm2d(new_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run both stages and return their outputs as a tuple.

        Args:
            x: Input accepted by ``nn.Conv2d`` with ``in_channels`` channels.

        Returns:
            ``(x1, x2)`` where ``x1`` is the output of ``primary_conv`` and
            ``x2`` is the output of ``cheap_operation`` applied to ``x1``.
            Both keep the spatial size of ``x``.
        """
        x1 = self.primary_conv(x)
        x2 = self.cheap_operation(x1)
        return x1, x2