Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
class ResBlock(nn.Module):
    """Residual block: two conv -> batch-norm -> ReLU stages plus a skip connection.

    Both convolutions are 3x3 with stride 1 and padding 1 and keep the channel
    count, so the branch output has the same shape as the input and can be
    added to it element-wise.

    Args:
        channels: Number of input and output channels of both convolutions.
    """

    def __init__(self, channels: int) -> None:
        super().__init__()
        # Both stages are unpacked into ONE flat Sequential so the sub-module
        # indices stay 0-5 (conv, bn, relu, conv, bn, relu) and the
        # state_dict keys ("resblock.0.weight", "resblock.4.bias", ...) do
        # not change.
        self.resblock = nn.Sequential(
            *self._conv_bn_relu(channels),
            *self._conv_bn_relu(channels),
        )

    @staticmethod
    def _conv_bn_relu(channels: int) -> list:
        """Return fresh layers for one 3x3 conv -> batch-norm -> ReLU stage."""
        return [
            nn.Conv2d(
                in_channels=channels,
                out_channels=channels,
                kernel_size=3,
                stride=1,
                padding=1,
                # No conv bias: the BatchNorm2d that follows has its own shift.
                bias=False,
            ),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        ]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the output of the two-stage branch applied to ``x``."""
        return x + self.resblock(x)
class CustomResnet(nn.Module):
    """Small ResNet-style image classifier.

    Structure, in order:

    * ``prep``:   3 -> 64 channels, conv -> batch-norm -> ReLU.
    * ``layer1``: 64 -> 128, conv -> 2x2 max-pool -> batch-norm -> ReLU, then
      a ``ResBlock(128)``.
    * ``layer2``: 128 -> 256, conv -> 2x2 max-pool -> batch-norm -> ReLU.
    * ``layer3``: 256 -> 512, conv -> 2x2 max-pool -> batch-norm -> ReLU, then
      a ``ResBlock(512)``.
    * ``pool``:   4x4 max-pool.
    * ``fc``:     bias-free linear layer, 512 -> ``num_classes``.

    The three 2x2 pools and the final 4x4 pool reduce a 32x32 input to a
    single spatial position, which yields exactly the 512 features ``fc``
    expects.

    Args:
        num_classes: Number of output scores per sample. Defaults to 10.
    """

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        self.prep = nn.Sequential(
            nn.Conv2d(
                in_channels=3,
                out_channels=64,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        # Each stem is unpacked into a flat Sequential so sub-module indices
        # (conv=0, pool=1, bn=2, relu=3, resblock=4) and therefore the
        # state_dict keys are unchanged.
        self.layer1 = nn.Sequential(
            *self._conv_pool_bn_relu(64, 128),
            ResBlock(channels=128),
        )
        self.layer2 = nn.Sequential(*self._conv_pool_bn_relu(128, 256))
        self.layer3 = nn.Sequential(
            *self._conv_pool_bn_relu(256, 512),
            ResBlock(channels=512),
        )
        self.pool = nn.MaxPool2d(kernel_size=4)
        self.fc = nn.Linear(in_features=512, out_features=num_classes, bias=False)
        # Not applied in forward(), which returns raw scores. Kept as an
        # attribute so existing code that references ``model.softmax`` still
        # works.
        self.softmax = nn.Softmax(dim=-1)

    @staticmethod
    def _conv_pool_bn_relu(in_channels: int, out_channels: int) -> list:
        """Return fresh layers for one 3x3 conv -> 2x2 max-pool -> BN -> ReLU stem."""
        return [
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding=1,
                stride=1,
                bias=False,
            ),
            nn.MaxPool2d(kernel_size=2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the un-normalised class scores (logits) for a batch of images.

        Args:
            x: Batch of 3-channel images; 32x32 inputs produce the 512
                features per sample that ``fc`` requires.

        Returns:
            Tensor with one row per input sample and ``num_classes`` columns.
            Softmax is not applied.

        Raises:
            RuntimeError: If the pooled feature map does not flatten to 512
                features per sample (raised by the linear layer).
        """
        x = self.prep(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.pool(x)
        # Flatten everything except the batch dimension. The previous
        # ``view(-1, 512)`` silently moved leftover spatial positions into the
        # batch dimension when the pooled map was larger than 1x1.
        x = torch.flatten(x, 1)
        return self.fc(x)