# Timerns's picture
# Upload folder using huggingface_hub
# 984cdba verified
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.resnet import resnet50, ResNet50_Weights
def resnet50_multispectral(in_channels, use_pretrained=True):
    """Build a torchvision ResNet-50 whose first convolution takes ``in_channels`` bands.

    The stock ``conv1`` is replaced by a new 7x7, stride-2, padding-3,
    bias-free convolution with 64 output channels and ``in_channels`` input
    channels. When ``use_pretrained`` is true the new filters are seeded from
    the ImageNet ones:

    * ``in_channels >= 3``: the first three input channels receive the
      pretrained RGB filters; every additional channel receives the mean of
      the RGB filters.
    * ``in_channels < 3``: the first ``in_channels`` pretrained filters are
      copied.

    Args:
        in_channels: Number of input channels of the returned network.
        use_pretrained: Load ``ResNet50_Weights.IMAGENET1K_V1`` and transfer
            the stem filters as described above. When false the whole network,
            including the new stem, keeps its default initialisation.

    Returns:
        The modified ``resnet50`` module.
    """
    weights = ResNet50_Weights.IMAGENET1K_V1 if use_pretrained else None
    model = resnet50(weights=weights)

    old_conv = model.conv1
    model.conv1 = nn.Conv2d(
        in_channels,
        64,
        kernel_size=7,
        stride=2,
        padding=3,
        bias=False,
    )

    if use_pretrained:
        pretrained_weight = old_conv.weight
        # Weight surgery must not be recorded by autograd.
        with torch.no_grad():
            if in_channels >= 3:
                model.conv1.weight[:, :3] = pretrained_weight
                if in_channels > 3:
                    # The single-channel mean broadcasts over all extra bands.
                    model.conv1.weight[:, 3:] = pretrained_weight.mean(dim=1, keepdim=True)
            else:
                # Copy in place: assigning a plain tensor to ``conv1.weight``
                # is rejected by nn.Module (it only accepts nn.Parameter or
                # None for a registered parameter) and would raise TypeError.
                model.conv1.weight.copy_(pretrained_weight[:, :in_channels])

    return model
class ASPP(nn.Module):
    """Atrous spatial pyramid pooling head.

    Five parallel branches read the same feature map: a 1x1 convolution,
    three 3x3 convolutions with dilation 6, 12 and 18 (padding equal to the
    dilation, so the spatial size is kept), and a global-average-pooled 1x1
    convolution that is upsampled back to the input size. The branch outputs
    are concatenated along the channel axis and fused by a 1x1 convolution.
    All convolutions are bias-free and no normalisation or activation is
    applied inside this module.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by every branch and by the fused output.
    """

    def __init__(self, in_channels, out_channels=256):
        super().__init__()

        def dilated_branch(rate):
            # padding == dilation keeps H and W unchanged for a 3x3 kernel.
            return nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                padding=rate,
                dilation=rate,
                bias=False,
            )

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.conv2 = dilated_branch(6)
        self.conv3 = dilated_branch(12)
        self.conv4 = dilated_branch(18)

        # Image-level context: squeeze to 1x1, then project.
        self.global_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
        )

        # Fuses the five concatenated branches back to ``out_channels``.
        self.project = nn.Conv2d(5 * out_channels, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        """Return the fused multi-scale features, same spatial size as ``x``."""
        spatial_size = x.shape[2:]

        branch_outputs = [
            branch(x)
            for branch in (self.conv1, self.conv2, self.conv3, self.conv4)
        ]

        # Broadcast the pooled 1x1 descriptor back over the feature map.
        pooled = F.interpolate(
            self.global_pool(x),
            size=tuple(spatial_size),
            mode="bilinear",
            align_corners=False,
        )
        branch_outputs.append(pooled)

        return self.project(torch.cat(branch_outputs, dim=1))
class Decoder(nn.Module):
    """Fuse upsampled high-level features with low-level features and classify.

    The low-level feature map is reduced to 48 channels by a bias-free 1x1
    convolution. The high-level map is bilinearly resized to the low-level
    spatial size, concatenated with it along the channel axis, and passed
    through a 3x3 conv -> BatchNorm -> ReLU -> 1x1 conv classifier.

    Args:
        low_level_channels: Channels of the low-level feature map.
        num_classes: Number of output channels (one per class).
        high_level_channels: Channels of the high-level feature map passed as
            ``x`` to ``forward``. Defaults to 256, which gives the fusion
            convolution 256 + 48 = 304 input channels.
    """

    # Width of the projected low-level features.
    _LOW_LEVEL_PROJ_CHANNELS = 48

    def __init__(self, low_level_channels, num_classes, high_level_channels=256):
        super().__init__()
        self.low_proj = nn.Conv2d(
            low_level_channels, self._LOW_LEVEL_PROJ_CHANNELS, 1, bias=False
        )
        self.output = nn.Sequential(
            # Input width follows the concatenation performed in forward().
            nn.Conv2d(
                high_level_channels + self._LOW_LEVEL_PROJ_CHANNELS,
                256,
                3,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, 1),
        )

    def forward(self, x, low_level):
        """Return class scores at the spatial size of ``low_level``.

        Args:
            x: High-level features with ``high_level_channels`` channels.
            low_level: Low-level features with ``low_level_channels`` channels.
        """
        low_level = self.low_proj(low_level)
        # Bring the coarse features up to the low-level resolution before fusing.
        x = F.interpolate(x, size=low_level.shape[2:], mode="bilinear", align_corners=False)
        x = torch.cat([x, low_level], dim=1)
        return self.output(x)
class model(nn.Module):
    """DeepLabV3+-style segmentation network on a multispectral ResNet-50 encoder.

    The encoder is the ResNet-50 returned by ``resnet50_multispectral``, split
    into a stem (``layer0``) and the four residual stages. ``layer1`` output
    is kept as the low-level feature map, ``layer4`` output goes through
    ``ASPP``, and ``Decoder`` fuses both. The result is bilinearly resized to
    the input's spatial size.

    Args:
        in_channels: Number of channels of the input image.
        num_classes: Number of output channels (one per class).
        freeze_encoder: If true, set ``requires_grad = False`` on every
            backbone parameter. This only disables gradients; the encoder is
            not put into eval mode here.
        use_pretrained: Forwarded to ``resnet50_multispectral``. Pass ``False``
            to skip loading ImageNet weights, e.g. when a checkpoint is loaded
            right after construction.
    """

    def __init__(
        self,
        in_channels,
        num_classes,
        freeze_encoder=False,
        use_pretrained=True,
    ):
        super().__init__()
        backbone = resnet50_multispectral(in_channels, use_pretrained=use_pretrained)

        # Stem: conv -> BN -> ReLU -> max-pool.
        self.layer0 = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
        )

        if freeze_encoder:
            # The stem and stages registered on this module share these
            # parameter objects, so freezing the backbone freezes them too.
            for param in backbone.parameters():
                param.requires_grad = False

        self.layer1 = backbone.layer1  # low-level features
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4  # high-level features

        # 2048 / 256 are the ResNet-50 channel widths after layer4 / layer1.
        self.aspp = ASPP(2048)
        self.decoder = Decoder(256, num_classes)

    def forward(self, x):
        """Return per-pixel class scores with the same height and width as ``x``."""
        input_size = x.shape[2:]

        x = self.layer0(x)
        low_level = self.layer1(x)
        x = self.layer2(low_level)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.aspp(x)
        x = self.decoder(x, low_level)

        # The decoder works at the low-level resolution; restore the input size.
        x = F.interpolate(x, size=input_size, mode="bilinear", align_corners=False)
        return x