| import torch | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
| import timm | |
class CNNMedium(nn.Module):
    """Small three-stage convolutional classifier.

    The feature extractor is three ``Conv2d(kernel_size=3) -> MaxPool2d(2, 2)
    -> LeakyReLU`` stages followed by a flatten. The head is a two-layer MLP.

    The first head layer expects 60 input features: 15 channels on a 2x2
    spatial map. A 32x32 input produces exactly that (spatial sizes
    32 -> 30 -> 15 -> 13 -> 6 -> 4 -> 2). Inputs that do not reduce to a 2x2
    map fail with a shape mismatch in ``head``.

    Args:
        in_channels: Number of channels in the input images. Defaults to 3.
        num_classes: Number of output logits produced by ``head``. Defaults to 10.
    """

    def __init__(self, in_channels: int = 3, num_classes: int = 10) -> None:
        super().__init__()
        # Layer order is kept identical to the original so the Sequential
        # indices, and therefore the state_dict keys (``module.0.weight``,
        # ``head.2.weight``, ...), do not change.
        self.module = nn.Sequential(
            nn.Conv2d(in_channels, 16, 3),
            nn.MaxPool2d(2, 2),
            nn.LeakyReLU(),
            nn.Conv2d(16, 32, 3),
            nn.MaxPool2d(2, 2),
            nn.LeakyReLU(),
            nn.Conv2d(32, 15, 3),
            nn.MaxPool2d(2, 2),
            nn.LeakyReLU(),
            nn.Flatten(start_dim=1),
        )
        # 60 = 15 channels * 2 * 2 spatial positions after the third pooling stage.
        self.head = nn.Sequential(
            nn.Linear(60, 20),
            nn.LeakyReLU(),
            nn.Linear(20, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batched image tensor ``x`` to logits of shape ``(batch, num_classes)``.

        ``x`` must be batched, i.e. ``(batch, in_channels, H, W)``. ``Flatten``
        starts at dim 1, so an unbatched 3-D input would not produce the
        60 features the head expects.
        """
        features = self.module(x)
        return self.head(features)
def Model():
    """Build a ``CNNMedium`` and return it together with its classification head.

    Returns:
        A ``(model, head)`` tuple. ``head`` is the very same ``nn.Sequential``
        object stored at ``model.head``, not a copy.
    """
    network = CNNMedium()
    classifier_head = network.head
    return network, classifier_head
if __name__ == "__main__":
    # Smoke test: push a dummy batch of four 3x32x32 images through the
    # network, then report the output shape, the architecture and the
    # total parameter count.
    model, _ = Model()
    x = torch.ones([4, 3, 32, 32])
    # Only the forward result is inspected, so skip building an autograd graph.
    with torch.no_grad():
        y = model(x)
    print(y.shape)
    print(model)
    num_param = sum(p.numel() for p in model.parameters())
    print("num_param:", num_param)