File size: 4,123 Bytes
707cbac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel attention.

    Globally average-pools every channel, feeds the pooled vector through a
    bottleneck MLP (Linear -> ReLU -> Linear -> Sigmoid) and rescales each
    input channel by the resulting gate.

    Args:
        channels: Number of channels of the input feature map.
        reduction: Bottleneck reduction factor. The hidden width is
            ``channels // reduction``, clamped to at least 1 so that
            ``channels < reduction`` does not yield zero-width layers.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        # For channels < reduction the integer division is 0, which would
        # build degenerate zero-width Linear layers; clamp to 1 instead.
        # For channels >= reduction the width (and state_dict) is unchanged.
        hidden = max(1, channels // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` with each channel multiplied by its learned gate.

        Args:
            x: 4-D tensor ``(B, C, H, W)`` with ``C == channels``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)  # squeeze: (B, C)
        y = self.fc(y).view(b, c, 1, 1)  # excitation: gates in (0, 1)
        return x * y.expand_as(x)
class ResBlock(nn.Module):
    """Two-convolution residual block with squeeze-and-excitation gating.

    Main path: 3x3 conv (stride ``stride``) -> BN -> ReLU -> 3x3 conv -> BN
    -> SE. Shortcut: the input itself, or ``downsample(x)`` when a module is
    supplied. The two paths are added and passed through a final ReLU, so the
    shortcut must match the main path's output shape.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input to build the shortcut.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()

        def conv3x3(src, dst, step=1):
            # padding=1 keeps the spatial size for step=1; no bias since a
            # BatchNorm follows every convolution.
            return nn.Conv2d(
                src, dst, kernel_size=3, stride=step, padding=1, bias=False
            )

        # Keep this construction order: it fixes the order of RNG draws made
        # during parameter initialisation and the order of state_dict keys.
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.se = SEBlock(out_channels)
        self.downsample = downsample

    def forward(self, x):
        """Apply the residual block to ``x`` and return the activated sum."""
        shortcut = x if self.downsample is None else self.downsample(x)

        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.se(self.bn2(self.conv2(residual)))

        residual += shortcut
        return self.relu(residual)
class ResNet(nn.Module, PyTorchModelHubMixin):
    """Small four-stage SE-ResNet producing one sigmoid output per sample.

    Architecture: 3x3 stem conv (3 -> 16 channels) -> four stages of
    ``ResBlock`` with strides 1, 2, 2, 2 -> global average pool -> dropout ->
    ``Linear(channels[3], 1)`` -> sigmoid. Because the sigmoid is applied
    here, the output is already a value in (0, 1); pair it with a loss that
    expects probabilities (e.g. ``nn.BCELoss``), not logits.

    Args:
        layers: Number of residual blocks in each of the four stages; every
            entry must be >= 1.
        channels: Output channel width of each of the four stages.
        dropout_rate: Dropout probability applied before the final Linear.

    Raises:
        ValueError: If ``layers`` or ``channels`` does not have exactly four
            entries, or if any stage is given fewer than one block.
    """

    # Tuple defaults: a list default is one shared mutable object across
    # every call. Lists are still accepted by callers.
    def __init__(
        self, layers=(2, 2, 2, 2), channels=(16, 24, 48, 96), dropout_rate=0.5
    ):
        super().__init__()
        if len(layers) != 4 or len(channels) != 4:
            raise ValueError(
                "layers and channels must each have exactly 4 entries, got "
                f"{len(layers)} and {len(channels)}"
            )
        if any(n < 1 for n in layers):
            # _make_layer always builds the first block, so 0 would silently
            # become 1 block.
            raise ValueError(
                f"every stage needs at least 1 block, got {list(layers)}"
            )

        # Running channel count consumed/updated by _make_layer.
        self.in_channels = 16
        # Stem (fixed at 16 channels; layer1 gets a 1x1 projection shortcut
        # if channels[0] != 16).
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        # Stages
        self.layer1 = self._make_layer(channels[0], layers[0], stride=1)
        self.layer2 = self._make_layer(channels[1], layers[1], stride=2)
        self.layer3 = self._make_layer(channels[2], layers[2], stride=2)
        self.layer4 = self._make_layer(channels[3], layers[3], stride=2)
        self.dropout = nn.Dropout(p=dropout_rate)
        # Final classification head.
        # Example shapes for an input of (3, 80, 101) with default channels:
        # layer1: (16, 80, 101)  (stride 1)
        # layer2: (24, 40, 51)   (stride 2)
        # layer3: (48, 20, 26)   (stride 2)
        # layer4: (96, 10, 13)   (stride 2)
        # Adaptive pooling makes the head independent of the input H, W.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(channels[3], 1)
        self.sigmoid = nn.Sigmoid()

    def _make_layer(self, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` ResBlocks ending at ``out_channels``.

        Only the first block may stride or change the channel count; it gets
        a 1x1 conv + BN projection shortcut when it does. Updates
        ``self.in_channels`` to ``out_channels`` for the next stage.
        """
        downsample = None
        if stride != 1 or self.in_channels != out_channels:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels),
            )
        layers = []
        layers.append(ResBlock(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels
        for _ in range(1, blocks):
            layers.append(ResBlock(self.in_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a ``(B, 3, H, W)`` batch to ``(B, 1)`` sigmoid outputs."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)  # (B, channels[3], 1, 1)
        x = torch.flatten(x, 1)  # (B, channels[3])
        x = self.dropout(x)
        x = self.fc(x)  # (B, 1)
        x = self.sigmoid(x)
        return x
if __name__ == "__main__":
    # Script entry point: print a layer-by-layer summary of the default model
    # for a single (3, 80, 101) input. torchinfo is only needed here.
    from torchinfo import summary

    demo_input_shape = (1, 3, 80, 101)
    summary(ResNet(), input_size=demo_input_shape)
|