import torch
import torch.nn as nn


class GAP(nn.Module):
    """Global average pooling over the two trailing (spatial) dimensions.

    Widely used in ResNet, Inception, DenseNet, etc. to collapse each
    channel's feature map to a single averaged value.

    Args:
        flatten: If True, merge every dimension after the first into one,
            so a ``(N, C, H, W)`` input yields ``(N, C)``. Defaults to False,
            which keeps the pooled ``(N, C, 1, 1)`` shape (the original
            behavior).
    """

    def __init__(self, *, flatten: bool = False) -> None:
        super().__init__()
        self.flatten = flatten
        # Output spatial size is fixed at 1x1 whatever the input H and W are.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Average-pool ``x`` to a 1x1 spatial size, optionally flattening."""
        x = self.avgpool(x)
        if self.flatten:
            # Keep dim 0 and merge the rest; unlike ``view`` this does not
            # require a contiguous tensor.
            x = x.flatten(1)
        return x


class Identity(nn.Module):
    """Pass-through module that returns its input unchanged.

    Handy as a drop-in placeholder when removing a layer from a model.
    Behaves like ``torch.nn.Identity``.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` as-is (the very same object)."""
        return x