File size: 4,660 Bytes
875bba7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel gate.

    Each channel of a 4-D input is reduced to a single value by global
    average pooling, passed through a two-layer bottleneck MLP ending in a
    sigmoid, and the resulting per-channel weight in (0, 1) rescales the
    input feature map.

    Args:
        channels: Number of channels of the input feature map.
        reduction: Divisor applied to ``channels`` to get the bottleneck
            width. The width is clamped to at least 1 so that a block with
            ``channels < reduction`` still has a trainable gate instead of
            an empty (zero-width) bottleneck.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        # Without the clamp, channels < reduction yields a 0-wide bottleneck:
        # no parameters, and a gate that is constantly sigmoid(0) = 0.5.
        hidden = max(channels // reduction, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the learned gate.

        Args:
            x: 4-D tensor; the first two dimensions are read as batch and
                channel, the remaining two are pooled away.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch, channels, _, _ = x.size()
        # Squeeze: one scalar per (sample, channel).
        squeezed = self.avg_pool(x).view(batch, channels)
        # Excite: per-channel weights, reshaped so they broadcast over H and W.
        gate = self.fc(squeezed).view(batch, channels, 1, 1)
        return x * gate
class SEResNeXtBlock(nn.Module):
    """Residual bottleneck block with a grouped 3x3 convolution and SE gating.

    The main path is 1x1 conv -> grouped 3x3 conv -> 1x1 conv, each followed
    by batch norm, with ReLU after the first two. The result is re-weighted
    by an ``SEBlock`` and added to the shortcut before a final ReLU.

    Args:
        in_channels: Channels of the incoming tensor.
        planes: Base width of the block; the output has
            ``planes * expansion`` channels.
        stride: Stride of the grouped 3x3 convolution.
        downsample: Optional module applied to the input to produce the
            shortcut; when ``None`` the input itself is the shortcut.
        groups: Number of groups of the 3x3 convolution.
        base_width: Width per group relative to a 64-plane reference.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 2

    def __init__(self, in_channels, planes, stride=1, downsample=None, groups=8, base_width=4):
        super().__init__()
        expanded = planes * self.expansion
        # ResNeXt width rule: floor(planes * base_width / 64) * groups.
        # For small ``planes`` that product is 0, so keep at least one
        # channel per group.
        width = max(int(planes * (base_width / 64.0)) * groups, groups)

        # 1x1 projection into the grouped bottleneck.
        self.conv1 = nn.Conv2d(in_channels, width, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width)
        # Grouped 3x3 convolution; the only place the stride is applied.
        self.conv2 = nn.Conv2d(
            width,
            width,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=groups,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(width)
        # 1x1 expansion to the block's output width.
        self.conv3 = nn.Conv2d(width, expanded, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(expanded)
        self.relu = nn.ReLU(inplace=True)
        self.se = SEBlock(expanded)
        self.downsample = downsample

    def forward(self, x):
        """Run the bottleneck path, apply SE gating and add the shortcut."""
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.se(self.bn3(self.conv3(out)))

        out += shortcut
        return self.relu(out)
class SEResNeXt(nn.Module, PyTorchModelHubMixin):
    """Four-stage SE-ResNeXt network with a single sigmoid output.

    A 3x3 convolutional stem (3 -> 32 channels) feeds four stages of
    ``SEResNeXtBlock``s. Stages two to four halve the spatial resolution.
    The features are globally average-pooled, passed through dropout and a
    linear layer with one output, and squashed by a sigmoid.

    Args:
        layers: Number of blocks in each of the four stages.
        planes: Base width of each of the four stages; a stage outputs
            ``planes[i] * SEResNeXtBlock.expansion`` channels.
        dropout_rate: Dropout probability applied before the final linear
            layer.
        groups: Number of groups of the 3x3 convolution in every block.
        base_width: Per-group width parameter forwarded to every block.

    The sequence defaults are tuples so they cannot be mutated and shared
    between instances; any sequence with at least four entries is accepted,
    since only indices 0-3 are read.
    """

    def __init__(
        self, layers=(2, 2, 2, 2), planes=(16, 32, 64, 128), dropout_rate=0.5, groups=8, base_width=4
    ):
        super().__init__()
        stem_channels = 32
        # Running channel count; updated by _make_layer as stages are built.
        self.in_channels = stem_channels
        self.groups = groups
        self.base_width = base_width

        # Stem: stride 1, so the input resolution is kept.
        self.conv1 = nn.Conv2d(3, stem_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(stem_channels)
        self.relu = nn.ReLU(inplace=True)

        # Stages: only the first keeps the resolution, the rest use stride 2.
        self.layer1 = self._make_layer(planes[0], layers[0], stride=1)
        self.layer2 = self._make_layer(planes[1], layers[1], stride=2)
        self.layer3 = self._make_layer(planes[2], layers[2], stride=2)
        self.layer4 = self._make_layer(planes[3], layers[3], stride=2)

        # Head: the last stage emits planes[3] * expansion channels.
        self.dropout = nn.Dropout(p=dropout_rate)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(planes[3] * SEResNeXtBlock.expansion, 1)
        self.sigmoid = nn.Sigmoid()

    def _make_layer(self, planes, blocks, stride=1):
        """Build one stage of ``blocks`` SE-ResNeXt blocks.

        Only the first block receives ``stride`` and, when the stride or the
        channel count changes, a 1x1 conv + batch-norm shortcut projection.
        ``self.in_channels`` is updated to the stage's output width.

        Args:
            planes: Base width of the stage.
            blocks: Number of blocks in the stage.
            stride: Stride of the first block.

        Returns:
            ``nn.Sequential`` containing the stage's blocks.
        """
        out_channels = planes * SEResNeXtBlock.expansion

        downsample = None
        if stride != 1 or self.in_channels != out_channels:
            # The shortcut must match the main path in resolution and width.
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels),
            )

        stage = [
            SEResNeXtBlock(self.in_channels, planes, stride, downsample, self.groups, self.base_width)
        ]
        self.in_channels = out_channels
        # Remaining blocks keep shape, so they use an identity shortcut.
        stage.extend(
            SEResNeXtBlock(self.in_channels, planes, groups=self.groups, base_width=self.base_width)
            for _ in range(1, blocks)
        )
        return nn.Sequential(*stage)

    def forward(self, x):
        """Map an image batch to one sigmoid score per sample.

        Args:
            x: Tensor with 3 channels in dimension 1, e.g. (B, 3, 80, 101).

        Returns:
            Tensor of shape (B, 1) with values in (0, 1).
        """
        # Stem
        x = self.relu(self.bn1(self.conv1(x)))

        # Stages
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Head
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc(x)
        return self.sigmoid(x)
if __name__ == "__main__":
    # torchinfo is only needed for this ad-hoc inspection, so it is imported
    # here rather than at module level.
    from torchinfo import summary

    # One sample with 3 channels and an 80x101 spatial size.
    input_shape = (1, 3, 80, 101)
    net = SEResNeXt()
    summary(net, input_shape)
|