import torch
from torch import nn


class ResBlock(nn.Module):
    """Two-convolution residual block.

    Computes ``relu_2(shortcut(x) + bn_2(conv_2(relu_1(bn_1(conv_1(x))))))``.

    Args:
        ch_in: Number of input channels.
        ch_out: Number of output channels.
        down_sample: If True, ``conv_1`` and the shortcut projection use
            stride 2. Spatial size then becomes ``ceil(H / 2)`` x ``ceil(W / 2)``
            (3x3 kernel, padding 1). Otherwise spatial size is preserved.

    The shortcut is a 3x3 convolution stored as ``self.down_sample`` whenever
    the main branch changes the tensor shape, i.e. when down-sampling or when
    ``ch_in != ch_out``. Otherwise the shortcut is the identity.
    """

    def __init__(self, ch_in, ch_out, down_sample=True):
        super().__init__()
        stride = 2 if down_sample else 1

        # Main branch: conv -> BN -> ReLU -> conv -> BN.
        self.conv_1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, padding=1, stride=stride)
        self.batch_norm_1 = nn.BatchNorm2d(ch_out)
        self.relu_1 = nn.ReLU()

        self.conv_2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.batch_norm_2 = nn.BatchNorm2d(ch_out)

        # The shortcut must match the main branch's output shape. A channel
        # change without down-sampling needs a projection too. Without one the
        # addition in forward() raises a shape error, or silently broadcasts
        # when ch_in == 1.
        self.flag = bool(down_sample) or ch_in != ch_out
        if self.flag:
            # The name `down_sample` is kept so existing state_dict keys still load.
            self.down_sample = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)

        self.relu_2 = nn.ReLU()

    def forward(self, x):
        """Apply the residual block; the output has ``ch_out`` channels."""
        out = self.relu_1(self.batch_norm_1(self.conv_1(x)))
        out = self.batch_norm_2(self.conv_2(out))

        shortcut = self.down_sample(x) if self.flag else x
        return self.relu_2(shortcut + out)


class ResNetModel(nn.Module):
    """Small ResNet-style classifier: stacked ResBlocks -> global average pool -> linear.

    Args:
        ch_in: Number of input image channels.
        num_classes: Number of output logits.
        widths: Output channel count of each ResBlock, in order. Every block
            is built with ``down_sample=True``. Defaults to ``(64, 128, 256)``,
            the original fixed architecture.

    Raises:
        ValueError: If ``widths`` is empty.
    """

    def __init__(self, ch_in, num_classes: int = 5, widths=(64, 128, 256)):
        super().__init__()
        widths = tuple(widths)
        if not widths:
            raise ValueError("widths must contain at least one channel count")

        # Chain the blocks: ch_in -> widths[0] -> widths[1] -> ... -> widths[-1].
        in_channels = (ch_in,) + widths[:-1]
        self.res_blocks = nn.Sequential(
            *(
                ResBlock(c_in, c_out, down_sample=True)
                for c_in, c_out in zip(in_channels, widths)
            )
        )

        # Pool every feature map to 1x1 so the head does not depend on input H/W.
        self.GAP = nn.AdaptiveAvgPool2d((1, 1))
        self.linear_head = nn.Sequential(
            nn.Flatten(),
            # in_features follows the last block's width rather than a duplicated literal.
            nn.Linear(widths[-1], num_classes),
        )

    def forward(self, x):
        """Return logits; for a batched input ``(N, ch_in, H, W)`` the result is ``(N, num_classes)``."""
        return self.linear_head(self.GAP(self.res_blocks(x)))