'''
ConvNet backbones for ImageNet-scale inputs, used by the MEMO implementation.

Reference:
https://github.com/wangkiw/ICLR23-MEMO/blob/main/convs/conv_imagenet.py
'''
| import torch.nn as nn |
| import torch |
|
|
| |
def first_block(in_channels, out_channels):
    """Build the stem block: 7x7 stride-2 conv, batch norm, ReLU, 2x2 max-pool.

    The stride-2 conv and the max-pool together shrink height/width by
    roughly 4x (e.g. 224 -> 112 -> 56).

    Layer order matters: the returned ``nn.Sequential`` registers its
    parameters under positional keys (``0`` = conv, ``1`` = batch norm).
    Reordering the layers would change the state_dict layout.

    Args:
        in_channels: Channels of the incoming feature map / image.
        out_channels: Channels produced by the convolution.

    Returns:
        nn.Sequential: The four-layer stem.
    """
    stem_conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=7,
        stride=2,
        padding=3,
    )
    layers = [
        stem_conv,
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2),
    ]
    return nn.Sequential(*layers)
| |
def conv_block(in_channels, out_channels):
    """Build a 3x3 "same" conv + batch norm + ReLU + 2x2 max-pool block.

    The padded 3x3 convolution keeps height/width unchanged. The max-pool
    then halves them (flooring odd sizes).

    Layers are kept in positional order so the state_dict keys stay
    ``0.*`` (conv) and ``1.*`` (batch norm).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.

    Returns:
        nn.Sequential: The four-layer block.
    """
    return nn.Sequential(*(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),
    ))
|
|
class ConvNet(nn.Module):
    """Four-block convolutional feature extractor for ImageNet-scale images.

    Overall the stack downsamples height/width by 32: the stride-2 stem conv
    plus four 2x2 max-pools. A 224x224 input therefore reaches ``avgpool`` as
    a 7x7 map, and ``avgpool`` (fixed 7x7 window) reduces it to 1x1.

    Args:
        x_dim: Number of input channels.
        hid_dim: Channel width of ``block1`` through ``block3``.
        z_dim: Output channels of ``block4``, i.e. the feature dimension.
    """

    def __init__(self, x_dim=3, hid_dim=128, z_dim=512):
        super().__init__()
        self.block1 = first_block(x_dim, hid_dim)
        self.block2 = conv_block(hid_dim, hid_dim)
        self.block3 = conv_block(hid_dim, hid_dim)
        self.block4 = conv_block(hid_dim, z_dim)
        # Fixed 7x7 window: sized for the 7x7 map that 224x224 inputs produce.
        self.avgpool = nn.AvgPool2d(7)
        # The feature width is block4's channel count, so it must follow z_dim
        # rather than a fixed 512.
        self.out_dim = z_dim

    def forward(self, x):
        """Run the backbone and pool the result into a flat feature vector.

        Args:
            x: Image batch, presumably shaped (N, x_dim, H, W); TODO confirm.
                224x224 inputs pool to a 1x1 map.

        Returns:
            dict: ``{"features": tensor}``. The tensor is the pooled map
            flattened to (N, C*h*w), which is (N, out_dim) for 224x224 inputs.
        """
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)

        x = self.avgpool(x)
        # torch.flatten keeps dim 0 even for an empty batch, where
        # view(N, -1) would raise because -1 is ambiguous with 0 elements.
        features = torch.flatten(x, 1)

        return {
            "features": features
        }
|
|
class GeneralizedConvNet(nn.Module):
    """Shared front half of the split ConvNet: ``block1`` through ``block3``.

    The output is the raw ``block3`` feature map, with no pooling or
    flattening. For a 224x224 input, that map is 14x14 with ``hid_dim``
    channels.

    Args:
        x_dim: Number of input channels.
        hid_dim: Channel width of every block in this stage.
        z_dim: Accepted but not used by this class.
    """

    def __init__(self, x_dim=3, hid_dim=128, z_dim=512):
        super().__init__()
        # Registration order (block1, block2, block3) fixes the state_dict layout.
        self.block1 = first_block(x_dim, hid_dim)
        self.block2 = conv_block(hid_dim, hid_dim)
        self.block3 = conv_block(hid_dim, hid_dim)

    def forward(self, x):
        """Apply block1, block2 and block3 in sequence and return the map."""
        for stage in (self.block1, self.block2, self.block3):
            x = stage(x)
        return x
|
|
class SpecializedConvNet(nn.Module):
    """Adaptive back half of the split ConvNet: ``block4`` plus pooling.

    ``avgpool`` uses a fixed 7x7 window. It therefore expects ``block4`` to
    produce a 7x7 map, i.e. a 14x14 input, which is what GeneralizedConvNet
    yields for 224x224 images.

    Args:
        hid_dim: Number of input channels ``block4`` expects.
        z_dim: Output channels of ``block4`` and the size of each feature
            vector.
    """

    def __init__(self, hid_dim=128, z_dim=512):
        super().__init__()
        self.block4 = conv_block(hid_dim, z_dim)
        self.avgpool = nn.AvgPool2d(7)
        # The feature width is block4's channel count, so it must follow z_dim
        # rather than a fixed 512.
        self.feature_dim = z_dim

    def forward(self, x):
        """Apply block4, pool, and flatten to a (N, C*h*w) feature tensor.

        For a 14x14 input map the result is (N, feature_dim).
        """
        x = self.block4(x)
        x = self.avgpool(x)
        # torch.flatten keeps dim 0 even for an empty batch, where
        # view(N, -1) would raise because -1 is ambiguous with 0 elements.
        return torch.flatten(x, 1)
| |
def conv4():
    """Return a ConvNet built with its default constructor arguments."""
    return ConvNet()
|
|
def conv_a2fc_imagenet(*, x_dim=3, hid_dim=128, z_dim=512):
    """Build the split ConvNet as a (generalized, specialized) pair.

    The same ``hid_dim`` is passed to both halves. This guarantees the
    specialized ``block4`` expects exactly as many input channels as the
    generalized ``block3`` emits. Calling with no arguments builds both
    halves with their class defaults.

    Args:
        x_dim: Input image channels for the generalized half.
        hid_dim: Channel width shared at the boundary between the halves.
        z_dim: Output channels of the specialized ``block4``.

    Returns:
        tuple: ``(base, adaptive_net)``. ``base`` is a GeneralizedConvNet
        (block1-block3) and ``adaptive_net`` is a SpecializedConvNet
        (block4 + pooling).
    """
    _base = GeneralizedConvNet(x_dim=x_dim, hid_dim=hid_dim, z_dim=z_dim)
    _adaptive_net = SpecializedConvNet(hid_dim=hid_dim, z_dim=z_dim)
    return _base, _adaptive_net