|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torchvision.models.resnet import resnet50, ResNet50_Weights |
|
|
|
|
|
def resnet50_multispectral(in_channels, use_pretrained=True):
    """Build a torchvision ResNet-50 whose stem accepts ``in_channels`` inputs.

    The backbone's 3-channel ``conv1`` is replaced by a fresh 7x7, stride-2,
    padding-3, bias-free convolution with ``in_channels`` inputs and 64
    outputs. When ``use_pretrained`` is true the backbone is loaded with
    ``ResNet50_Weights.IMAGENET1K_V1`` and the new stem is seeded from the
    pretrained RGB stem filters:

    * ``in_channels >= 3``: the first three input channels copy the RGB
      filters; every additional channel receives the mean of the RGB filters
      (averaged over the input-channel axis).
    * ``in_channels < 3``: the first ``in_channels`` RGB filter slices are
      copied.

    When ``use_pretrained`` is false no weights are downloaded and the stem
    keeps the default ``nn.Conv2d`` initialization.

    Args:
        in_channels: Number of input bands; must be >= 1.
        use_pretrained: Load ImageNet weights and transfer them into the stem.

    Returns:
        The modified torchvision ResNet-50 module.

    Raises:
        ValueError: If ``in_channels`` is smaller than 1.
    """
    if in_channels < 1:
        raise ValueError(f"in_channels must be >= 1, got {in_channels}")

    weights = ResNet50_Weights.IMAGENET1K_V1 if use_pretrained else None
    backbone = resnet50(weights=weights)

    old_conv = backbone.conv1
    backbone.conv1 = nn.Conv2d(
        in_channels,
        64,
        kernel_size=7,
        stride=2,
        padding=3,
        bias=False,
    )

    if use_pretrained:
        old_weight = old_conv.weight  # pretrained RGB stem filters
        new_weight = backbone.conv1.weight
        # Write into the existing Parameter in place. Assigning a plain tensor
        # to ``conv1.weight`` raises TypeError, because a registered parameter
        # can only be replaced by another nn.Parameter (or None).
        with torch.no_grad():
            if in_channels >= 3:
                new_weight[:, :3].copy_(old_weight)
                if in_channels > 3:
                    # copy_ broadcasts the single mean filter across all extra bands.
                    new_weight[:, 3:].copy_(old_weight.mean(dim=1, keepdim=True))
            else:
                new_weight.copy_(old_weight[:, :in_channels])

    return backbone
|
|
|
|
|
|
|
|
|
|
|
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head (DeepLabV3-style layout).

    Five parallel branches are applied to the input, concatenated along the
    channel axis, and fused by a 1x1 projection (``project``):

    * ``conv1``: 1x1 convolution.
    * ``conv2`` / ``conv3`` / ``conv4``: 3x3 dilated convolutions whose
      dilation (and equal padding, which preserves spatial size) come from
      ``atrous_rates``.
    * ``global_pool``: global average pooling followed by a 1x1 convolution,
      broadcast back to the input's spatial size.

    NOTE(review): no BatchNorm or non-linearity appears in this module. Every
    operation here (convolution, average pooling, broadcasting,
    concatenation) is linear, so the whole head is equivalent to a single
    linear operator. The reference DeepLabV3 ASPP applies BN + ReLU after each
    branch and after the projection. Adding those would change this module's
    parameters and state_dict keys, so they are intentionally not added here.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by every branch and by the projection.
        atrous_rates: Exactly three dilation rates for ``conv2``..``conv4``.
            The default ``(6, 12, 18)`` reproduces the original hard-coded
            configuration.

    Raises:
        ValueError: If ``atrous_rates`` does not contain exactly three rates.
    """

    def __init__(self, in_channels, out_channels=256, atrous_rates=(6, 12, 18)):
        super().__init__()

        rates = tuple(atrous_rates)
        if len(rates) != 3:
            raise ValueError(f"atrous_rates must contain exactly 3 rates, got {rates}")
        rate2, rate3, rate4 = rates

        self.conv1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        # For a 3x3 kernel, padding == dilation keeps the spatial size unchanged.
        self.conv2 = nn.Conv2d(in_channels, out_channels, 3, padding=rate2, dilation=rate2, bias=False)
        self.conv3 = nn.Conv2d(in_channels, out_channels, 3, padding=rate3, dilation=rate3, bias=False)
        self.conv4 = nn.Conv2d(in_channels, out_channels, 3, padding=rate4, dilation=rate4, bias=False)

        self.global_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
        )

        self.project = nn.Conv2d(out_channels * 5, out_channels, 1, bias=False)

    def forward(self, x):
        """Apply all five branches to ``x`` and fuse them.

        Args:
            x: Batched feature map laid out as (N, C, H, W).

        Returns:
            Tensor of shape (N, out_channels, H, W).
        """
        h, w = x.shape[-2:]

        branches = [self.conv1(x), self.conv2(x), self.conv3(x), self.conv4(x)]

        # global_pool yields a 1x1 map. expand() broadcasts it to (h, w)
        # exactly and without an interpolation kernel. Bilinear upsampling of
        # a 1x1 input produces the same constant map up to float rounding.
        pooled = self.global_pool(x).expand(-1, -1, h, w)
        branches.append(pooled)

        return self.project(torch.cat(branches, dim=1))
|
|
|
|
|
|
|
|
class Decoder(nn.Module):
    """DeepLabV3+-style decoder fusing high-level (ASPP) and low-level features.

    The low-level map is reduced to 48 channels by a 1x1 convolution
    (``low_proj``). The high-level map is bilinearly resized to the low-level
    spatial size, the two are concatenated, and ``output`` (3x3 conv ->
    BatchNorm -> ReLU -> 1x1 conv) produces ``num_classes`` logits at the
    low-level resolution.

    NOTE(review): ``low_proj`` is a bare convolution with no BN/ReLU, unlike
    the reference DeepLabV3+ decoder. This is left unchanged so the module's
    parameters stay the same.

    Args:
        low_level_channels: Channels of the low-level feature map.
        num_classes: Number of output logit channels.
        high_level_channels: Channels of the high-level map passed to
            ``forward``. The default of 256 reproduces the original hard-coded
            fused width of 304 (= 256 + 48).
    """

    # Width of the projected low-level features.
    _LOW_PROJ_CHANNELS = 48

    def __init__(self, low_level_channels, num_classes, high_level_channels=256):
        super().__init__()

        self.low_proj = nn.Conv2d(low_level_channels, self._LOW_PROJ_CHANNELS, 1, bias=False)

        fused_channels = high_level_channels + self._LOW_PROJ_CHANNELS
        self.output = nn.Sequential(
            nn.Conv2d(fused_channels, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, 1),
        )

    def forward(self, x, low_level):
        """Fuse ``x`` (high-level) with ``low_level`` and predict logits.

        Args:
            x: High-level features, (N, high_level_channels, h, w).
            low_level: Low-level features, (N, low_level_channels, H, W).

        Returns:
            Logits of shape (N, num_classes, H, W).
        """
        low_level = self.low_proj(low_level)
        x = F.interpolate(x, size=low_level.shape[2:], mode="bilinear", align_corners=False)
        return self.output(torch.cat([x, low_level], dim=1))
|
|
|
|
|
|
|
|
class model(nn.Module):
    """DeepLabV3+-style semantic segmentation network on a multispectral ResNet-50.

    Encoder: the stem from ``resnet50_multispectral`` (conv1, bn1, relu,
    maxpool) followed by ``layer1``..``layer4``. The ``layer1`` output is the
    decoder's low-level feature (256 channels). The ``layer4`` output (2048
    channels) feeds ``ASPP``. Decoder logits are bilinearly resized back to
    the input's spatial size.

    The lowercase class name is kept for backward compatibility with callers.

    Args:
        in_channels: Number of input bands.
        num_classes: Number of output logit channels.
        freeze_encoder: If true, set ``requires_grad=False`` on every backbone
            parameter, including the newly created input stem.
        use_pretrained: Forwarded to ``resnet50_multispectral``. Defaults to
            true (ImageNet weights, as before). Pass false to build without
            downloading weights.
    """

    def __init__(self, in_channels, num_classes, freeze_encoder=False, use_pretrained=True):
        super().__init__()

        backbone = resnet50_multispectral(in_channels, use_pretrained=use_pretrained)

        self.layer0 = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
        )

        if freeze_encoder:
            # NOTE(review): requires_grad=False stops gradient updates only.
            # BatchNorm running statistics in the encoder still update while
            # the module is in train() mode. Put those BN layers in eval()
            # if they must stay fixed.
            for param in backbone.parameters():
                param.requires_grad = False

        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

        self.aspp = ASPP(2048)
        self.decoder = Decoder(256, num_classes)

    def forward(self, x):
        """Return per-pixel logits of shape (N, num_classes, H, W) for input (N, in_channels, H, W)."""
        input_size = x.shape[2:]

        x = self.layer0(x)
        low_level = self.layer1(x)

        x = self.layer2(low_level)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.aspp(x)
        x = self.decoder(x, low_level)

        # Decoder output is at layer1 resolution; restore the input size.
        return F.interpolate(x, size=input_size, mode="bilinear", align_corners=False)
|
|
|