Spaces:
Sleeping
Sleeping
| ''' | |
| ConvNet backbones for CIFAR, as used in the MEMO implementation. | |
| Reference: | |
| https://github.com/wangkiw/ICLR23-MEMO/blob/main/convs/conv_cifar.py | |
| ''' | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| # for cifar | |
def conv_block(in_channels, out_channels):
    """Build one Conv -> BatchNorm -> ReLU -> MaxPool stage.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution.

    Returns:
        An ``nn.Sequential`` holding, in order, a 3x3 convolution with
        padding 1, batch normalisation over ``out_channels``, a ReLU and a
        2x2 max-pool.
    """
    layers = [
        # 3x3 kernel with padding 1 and the default stride of 1 keeps the
        # spatial size unchanged.
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        # The pool is the only layer that shrinks the map (height and width
        # are halved).
        nn.MaxPool2d(kernel_size=2),
    ]
    return nn.Sequential(*layers)
class ConvNet2(nn.Module):
    """Two-stage convolutional feature extractor.

    The encoder stacks two ``conv_block`` stages (``x_dim -> hid_dim ->
    z_dim`` channels); the result is average-pooled with an 8x8 window and
    flattened to one feature vector per sample.

    Args:
        x_dim: Number of input image channels.
        hid_dim: Number of channels after the first block.
        z_dim: Number of channels after the second block.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, x_dim=3, hid_dim=64, z_dim=64):
        super().__init__()
        # Width of the flattened feature vector. Tied to ``z_dim`` rather
        # than a literal 64 so a non-default ``z_dim`` is reported
        # correctly. This equals the real feature width only when the pooled
        # map is 1x1, i.e. for 32x32 inputs (32 -> 16 -> 8 through the two
        # max-pools, then the 8x8 average pool).
        self.out_dim = z_dim
        self.avgpool = nn.AvgPool2d(8)
        self.encoder = nn.Sequential(
            conv_block(x_dim, hid_dim),
            conv_block(hid_dim, z_dim),
        )

    def forward(self, x):
        """Encode a batch and return its flattened features.

        Args:
            x: Input batch; assumed to be ``(batch, x_dim, 32, 32)`` --
                confirm against the caller.

        Returns:
            A dict with the single key ``"features"`` mapping to a 2-D
            tensor whose first dimension is the batch size.
        """
        x = self.encoder(x)
        x = self.avgpool(x)
        # Collapse channel and spatial dimensions into one axis per sample.
        features = x.view(x.shape[0], -1)
        return {
            "features": features
        }
class GeneralizedConvNet2(nn.Module):
    """First convolutional stage only; returns an unpooled feature map.

    Args:
        x_dim: Number of input image channels.
        hid_dim: Number of channels produced by the block.
        z_dim: Accepted but not used by this class.
    """

    def __init__(self, x_dim=3, hid_dim=64, z_dim=64):
        super().__init__()
        first_stage = conv_block(x_dim, hid_dim)
        # Kept inside a Sequential so parameter names stay ``encoder.0.*``.
        self.encoder = nn.Sequential(first_stage)

    def forward(self, x):
        """Run the input through the single encoder block.

        Returns:
            The feature map produced by the encoder, with no global pooling
            or flattening applied.
        """
        return self.encoder(x)
class SpecializedConvNet2(nn.Module):
    """Second convolutional stage plus pooling; yields flat feature vectors.

    Consumes a feature map with ``hid_dim`` channels, applies one
    ``conv_block`` to reach ``z_dim`` channels, average-pools with an 8x8
    window and flattens.

    Args:
        hid_dim: Number of channels of the incoming feature map.
        z_dim: Number of channels produced by the block.
    """

    def __init__(self, hid_dim=64, z_dim=64):
        super().__init__()
        # Width of the flattened output. Tied to ``z_dim`` rather than a
        # literal 64 so a non-default ``z_dim`` is reported correctly. This
        # equals the real width only when the pooled map is 1x1, i.e. when
        # the incoming map is 16x16 (halved to 8x8 by the block's max-pool).
        self.feature_dim = z_dim
        self.avgpool = nn.AvgPool2d(8)
        # Attribute name kept as-is: it determines the state-dict keys.
        self.AdaptiveBlock = conv_block(hid_dim, z_dim)

    def forward(self, x):
        """Map a base feature map to one flat feature vector per sample.

        Args:
            x: Feature map with ``hid_dim`` channels; presumably the output
                of ``GeneralizedConvNet2`` -- confirm against the caller.

        Returns:
            A 2-D tensor whose first dimension is the batch size.
        """
        base_features = self.AdaptiveBlock(x)
        pooled = self.avgpool(base_features)
        # Collapse channel and spatial dimensions into one axis per sample.
        features = pooled.view(pooled.size(0), -1)
        return features
def conv2():
    """Return a ``ConvNet2`` built with its default arguments."""
    model = ConvNet2()
    return model
def get_conv_a2fc():
    """Return the two halves of the network as separate modules.

    Returns:
        A ``(basenet, adaptivenet)`` tuple: a default
        ``GeneralizedConvNet2`` followed by a default
        ``SpecializedConvNet2``.
    """
    # The tuple is evaluated left to right, so the base network is still
    # constructed (and initialised) before the adaptive one.
    return GeneralizedConvNet2(), SpecializedConvNet2()
if __name__ == '__main__':
    # Smoke check: print the parameter count of the split model (base +
    # adaptive halves) and of the monolithic ConvNet2.
    base_net, adaptive_net = get_conv_a2fc()
    base_params = sum(p.numel() for p in base_net.parameters())
    adaptive_params = sum(p.numel() for p in adaptive_net.parameters())
    print(f"conv :{base_params + adaptive_params}")

    # Bind the instance to its own name: assigning it to ``conv2`` would
    # overwrite the module-level factory function of the same name.
    full_net = conv2()
    full_params = sum(p.numel() for p in full_net.parameters())
    print(f"conv2 :{full_params}")