BFZD233
initial
f06f310
import os
import sys
import logging
import numpy as np
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
class OffsetConfidence(nn.Module):
    """Predict a single-channel confidence map from features and an offset memory.

    A feature tensor and the (negated) concatenated offsets are each projected
    to 16 channels with a 3x3 conv. The two projections are concatenated
    (offset branch first) and fused by a small conv stack down to 1 channel.
    Every conv is either 3x3 with ``padding=1`` or 1x1 with ``padding=0``,
    so the output keeps the inputs' spatial size. No final activation is
    applied, so the output is unbounded.

    Args:
        args: configuration object accessed by attribute. Reads
            ``detach_in_confidence`` (truthy -> inputs are detached from the
            autograd graph in ``forward``), ``offset_memory_size`` (number of
            offset entries; the offset branch expects ``2 * offset_memory_size``
            input channels) and, optionally, ``local_rank`` (logging only
            happens when it is missing or equal to 0).
    """

    def __init__(self, args):
        super(OffsetConfidence, self).__init__()
        self.detach = args.detach_in_confidence
        self.offset_memory_size = args.offset_memory_size
        # Feature branch: inputs must total 256 channels after concatenation.
        self.conv_fea = nn.Conv2d(256, 16, 3, padding=1)
        # Offset branch: 2 channels per memory entry.
        self.conv_offset = nn.Conv2d(2 * args.offset_memory_size, 16, 3, padding=1)
        self.fusion = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(32, 8, 3, padding=1)),
            ('relu1', nn.LeakyReLU(inplace=True)),
            ('conv2', nn.Conv2d(8, 2, 3, padding=1)),
            ('relu2', nn.LeakyReLU(inplace=True)),
            ('conv3', nn.Conv2d(2, 1, 1, padding=0)),
        ]))
        # getattr (instead of `"local_rank" in args`) works for argparse.Namespace
        # and for plain attribute containers that do not implement `__contains__`.
        if getattr(args, "local_rank", 0) == 0:
            logging.info("OffsetConfidence: detach: %s", args.detach_in_confidence)

    def forward(self, fea, offset_memory):
        """Compute the confidence map.

        Args:
            fea: a tensor, or a list/tuple of tensors concatenated along dim 1;
                must total 256 channels.
            offset_memory: iterable of tensors concatenated along dim 1; must
                total ``2 * offset_memory_size`` channels and share ``fea``'s
                spatial size.

        Returns:
            A 1-channel tensor with the same batch and spatial size as the inputs.
        """
        if isinstance(fea, (list, tuple)):
            fea = torch.cat(fea, dim=1)
        context = self.conv_fea(fea.detach() if self.detach else fea)
        offsets = torch.cat(
            [offset.detach() if self.detach else offset for offset in offset_memory],
            dim=1,
        )
        # NOTE(review): offsets are negated before projection; the intent is not
        # visible here (a learned conv could absorb the sign).
        confidence = self.conv_offset(-offsets)
        return self.fusion(torch.cat([confidence, context], dim=1))
class MBConvBlockSimple(nn.Module):
    """Inverted-bottleneck block: expand -> depthwise -> squeeze-excite -> project.

    Stages (every conv is followed by BatchNorm):
      1. optional 1x1 expansion to ``in_channels * expand_ratio`` channels,
         skipped when ``expand_ratio == 1``, then SiLU;
      2. depthwise ``kernel_size`` conv with ``stride`` and
         ``padding=kernel_size // 2``, then SiLU;
      3. optional squeeze-and-excitation gate, enabled when
         ``0 < se_ratio <= 1``, with bottleneck width
         ``max(1, int(in_channels * se_ratio))``;
      4. 1x1 projection to ``out_channels`` with no activation.
    The input is added back when ``stride == 1`` and
    ``in_channels == out_channels``.
    """

    def __init__(self, in_channels, out_channels, expand_ratio=1, kernel_size=3, stride=1, se_ratio=0.25):
        super().__init__()
        hidden = in_channels * expand_ratio
        self.expand_ratio = expand_ratio
        self.has_se = se_ratio is not None and 0 < se_ratio <= 1
        self.use_residual = (stride == 1 and in_channels == out_channels)

        # Sub-modules are registered in a fixed order; this keeps parameter order,
        # state_dict keys and seeded initialisation stable.
        if expand_ratio != 1:
            self.expand_conv = nn.Conv2d(in_channels, hidden, kernel_size=1, bias=False)
            self.bn0 = nn.BatchNorm2d(hidden)
        self.depthwise_conv = nn.Conv2d(
            hidden,
            hidden,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            groups=hidden,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(hidden)
        if self.has_se:
            squeezed = max(1, int(in_channels * se_ratio))
            self.se_reduce = nn.Conv2d(hidden, squeezed, kernel_size=1)
            self.se_expand = nn.Conv2d(squeezed, hidden, kernel_size=1)
        self.project_conv = nn.Conv2d(hidden, out_channels, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.swish = nn.SiLU(inplace=True)

    def _se_gate(self, h):
        """Return per-channel sigmoid weights computed from globally pooled ``h``."""
        pooled = F.adaptive_avg_pool2d(h, 1)
        return torch.sigmoid(self.se_expand(self.swish(self.se_reduce(pooled))))

    def forward(self, x):
        h = x
        if self.expand_ratio != 1:
            h = self.swish(self.bn0(self.expand_conv(h)))
        h = self.swish(self.bn1(self.depthwise_conv(h)))
        if self.has_se:
            h = h * self._se_gate(h)
        h = self.bn2(self.project_conv(h))
        return h + x if self.use_residual else h
class EfficientNetB1SimpleEncoder(nn.Module):
    """Truncated EfficientNet-B1-style encoder that returns a 4-level feature list.

    ``forward`` returns ``[f0, f1, f2, f3]``:
      * f0: 8 channels at input resolution (two 3x3 conv-BN-SiLU layers, ``pre_pro``);
      * f1: 16 channels at 1/2 resolution (stride-2 ``stem`` + stage 1);
      * f2: 24 channels at 1/4 resolution (stage 2);
      * f3: 40 channels at 1/8 resolution (stage 3).
    Each stride-2 step rounds odd spatial sizes up.

    Args:
        in_C: number of input channels.
    """

    def __init__(self, in_C=2):
        super().__init__()

        def conv_bn_silu(c_in, c_out, **conv_kwargs):
            # 3x3 conv followed by BatchNorm and in-place SiLU.
            return [
                nn.Conv2d(c_in, c_out, 3, **conv_kwargs),
                nn.BatchNorm2d(c_out),
                nn.SiLU(inplace=True),
            ]

        # Full-resolution pre-processing (not part of EfficientNet proper).
        self.pre_pro = nn.Sequential(
            *conv_bn_silu(in_C, 8, padding=1),
            *conv_bn_silu(8, 8, padding=1),
        )
        # Stem: first downsampling, to 1/2.
        self.stem = nn.Sequential(*conv_bn_silu(8, 32, stride=2, padding=1, bias=False))

        # (in_ch, out_ch, expand_ratio, kernel, stride of first block, repeats)
        stage_specs = (
            (32, 16, 1, 3, 1, 1),  # stage 1: stays at 1/2
            (16, 24, 6, 3, 2, 2),  # stage 2: 1/2 -> 1/4
            (24, 40, 6, 5, 2, 2),  # stage 3: 1/4 -> 1/8
        )
        self.blocks = nn.ModuleList()
        for c_in, c_out, expand, k, first_stride, repeats in stage_specs:
            # Only the first block of a stage changes channels / resolution.
            stage = [
                MBConvBlockSimple(
                    c_in if i == 0 else c_out,
                    c_out,
                    expand,
                    k,
                    stride=first_stride if i == 0 else 1,
                )
                for i in range(repeats)
            ]
            self.blocks.append(nn.Sequential(*stage))

    def forward(self, x):
        full_res = self.pre_pro(x)
        skips = [full_res]
        h = self.stem(full_res)
        for stage in self.blocks:
            h = stage(h)
            skips.append(h)
        return skips
class EfficientUNetSimple(nn.Module):
    """Small U-Net: ``EfficientNetB1SimpleEncoder`` plus a 3-level upsampling decoder.

    Starting from the deepest (40-channel, 1/8 resolution) encoder feature,
    each decoder level does three things:
      * reduces channels with a 1x1 conv;
      * doubles resolution with a kernel-2 / stride-2 transposed conv;
      * adds the matching encoder skip.
    A final 1x1 conv maps the 8-channel full-resolution map to
    ``num_classes`` channels, with no activation.

    Any input height/width is accepted. Upsampled maps are cropped to their
    skip's size before the addition, so the output matches the input's
    spatial size.

    Args:
        num_classes: number of output channels.
        in_C: number of input channels forwarded to the encoder. The default
            of 2 matches the previous hard-wired behaviour.
    """

    def __init__(self, num_classes=1, in_C=2):
        super(EfficientUNetSimple, self).__init__()
        # Encoder using EfficientNet-B1 with only three stages.
        self.encoder = EfficientNetB1SimpleEncoder(in_C=in_C)
        # Decoder: 1x1 channel reduction followed by 2x upsampling per level.
        self.upconv3 = nn.Conv2d(40, 24, kernel_size=1)
        self.up3 = nn.ConvTranspose2d(24, 24, kernel_size=2, stride=2)
        self.upconv2 = nn.Conv2d(24, 16, kernel_size=1)
        self.up2 = nn.ConvTranspose2d(16, 16, kernel_size=2, stride=2)
        self.upconv1 = nn.Conv2d(16, 8, kernel_size=1)
        self.up1 = nn.ConvTranspose2d(8, 8, kernel_size=2, stride=2)
        # Final 1x1 projection to the requested number of output channels.
        self.final_conv = nn.Conv2d(8, num_classes, kernel_size=1)

    @staticmethod
    def _crop_to(x, ref):
        """Crop ``x`` to the spatial (last two) dimensions of ``ref``.

        The encoder's stride-2 convs (3x3/pad 1 and 5x5/pad 2) produce
        ``ceil(n / 2)`` outputs. An exact 2x upsample is therefore either the
        skip's size or one row/column larger, and trimming the trailing edge
        realigns them. No-op when the sizes already match, e.g. when H and W
        are divisible by 8.
        """
        return x[..., :ref.shape[-2], :ref.shape[-1]]

    def forward(self, x):
        features = self.encoder(x)
        # Decoder with additive skip connections.
        x = self.up3(self.upconv3(features[-1]))  # 1/8 -> 1/4
        x = self._crop_to(x, features[-2]) + features[-2]
        x = self.up2(self.upconv2(x))  # 1/4 -> 1/2
        x = self._crop_to(x, features[-3]) + features[-3]
        x = self.up1(self.upconv1(x))  # 1/2 -> 1
        x = self._crop_to(x, features[-4]) + features[-4]
        return self.final_conv(x)