# BirdGen / model_arch.py
# Author: triaNova — commit 7f31246 ("3rd")
import torch
from torch import nn
class ResBlock(nn.Module):
    """Two-convolution residual block with an optional stride-2 downsample.

    Main path: conv3x3(stride s) -> BatchNorm -> ReLU -> conv3x3 -> BatchNorm,
    where s = 2 if ``down_sample`` else 1. The input is added back through a
    shortcut (a 3x3 projection conv whenever the identity would not match the
    main path's shape), and the sum passes through a final ReLU.

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels.
        down_sample: if True, both paths use stride 2, so the spatial size is
            halved (rounding up, e.g. 7 -> 4).
    """

    def __init__(self, ch_in: int, ch_out: int, down_sample: bool = True):
        super().__init__()
        stride = 2 if down_sample else 1
        self.conv_1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, padding=1, stride=stride)
        self.batch_norm_1 = nn.BatchNorm2d(ch_out)
        self.relu_1 = nn.ReLU()
        self.conv_2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.batch_norm_2 = nn.BatchNorm2d(ch_out)
        # True when the identity shortcut cannot be added to the main path as-is:
        # either the spatial size shrinks or the channel count changes. Checking
        # only ``down_sample`` made e.g. ResBlock(64, 128, down_sample=False)
        # fail in forward with a shape mismatch.
        self.flag = down_sample or ch_in != ch_out
        if self.flag:
            # Shortcut projection with the same stride as conv_1 so both paths
            # produce the same spatial size. The attribute keeps the name
            # ``down_sample`` so existing state_dict keys continue to load.
            self.down_sample = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.relu_2 = nn.ReLU()

    def forward(self, x):
        """Run the block on a feature map with ``ch_in`` channels.

        Args:
            x: input tensor, presumably (N, ch_in, H, W) — it is fed straight
                to ``nn.Conv2d``.

        Returns:
            Tensor with ``ch_out`` channels, spatially halved (rounded up) when
            ``down_sample`` was requested; non-negative due to the final ReLU.
        """
        skip_x = x
        x = self.conv_1(x)
        x = self.batch_norm_1(x)
        x = self.relu_1(x)
        x = self.conv_2(x)
        x = self.batch_norm_2(x)
        if self.flag:
            skip_x = self.down_sample(skip_x)
        x = skip_x + x
        x = self.relu_2(x)
        return x
class ResNetModel(nn.Module):
    """Compact three-stage residual CNN classifier.

    Three downsampling ``ResBlock`` stages widen the channels
    ch_in -> 64 -> 128 -> 256 while halving the spatial size (rounding up) at
    each stage. Global average pooling then collapses the spatial dimensions,
    and a single linear layer maps the 256 pooled features to
    ``num_classes`` logits.

    Args:
        ch_in: number of channels in the input tensor.
        num_classes: number of output logits (default 5).
    """

    def __init__(self, ch_in, num_classes: int = 5):
        super().__init__()
        # Channel width at each stage boundary; consecutive pairs give each
        # stage's (in, out). Spatial example: 80x345 -> 40x173 -> 20x87 -> 10x44.
        widths = [ch_in, 64, 128, 256]
        stages = [
            ResBlock(width_in, width_out, down_sample=True)
            for width_in, width_out in zip(widths, widths[1:])
        ]
        self.res_blocks = nn.Sequential(*stages)
        self.GAP = nn.AdaptiveAvgPool2d((1, 1))
        self.linear_head = nn.Sequential(nn.Flatten(), nn.Linear(widths[-1], num_classes))

    def forward(self, x):
        """Return class logits for a batch of inputs.

        Args:
            x: input tensor, presumably (N, ch_in, H, W); ``nn.Flatten`` keeps
                dim 0 as the batch axis.

        Returns:
            Logits of shape (N, num_classes).
        """
        features = self.res_blocks(x)
        pooled = self.GAP(features)
        return self.linear_head(pooled)