# nano6626's picture
# Update unet.py
# 7da32d6 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
# -------- building blocks (unchanged from the original UNet implementation) --------
class CBR(nn.Module):
    """3x3 convolution -> BatchNorm -> ReLU, keeping the spatial size.

    The conv has stride 1 and padding 1, so H and W are unchanged. It carries
    no bias (presumably because the BatchNorm right after it supplies its own
    shift).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the conv and normalised by the BN.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable.
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            3,
            padding=1,
            bias=False,
        )
        self.bnorm1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU(inplace=True)

    def forward(self, x):
        features = self.conv1(x)
        normalised = self.bnorm1(features)
        # In-place ReLU overwrites `normalised`, a fresh tensor, never `x`.
        return self.relu1(normalised)
class NCBR(nn.Module):
    """A stack of ``N`` CBR layers with an optional first-to-last skip.

    The first layer maps ``in_channels -> out_channels``; every later layer
    keeps ``out_channels``. When ``skip`` is set, the first layer's output is
    combined with the last layer's output. It is added by default, or
    concatenated along dim 1 when ``skipcat`` is also set. In the
    concatenating case the module outputs ``2 * out_channels`` channels.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by every CBR layer.
        N: Number of CBR layers; must be greater than 1.
        skip: Enable the first-to-last skip connection.
        skipcat: With ``skip``, concatenate instead of adding.

    Raises:
        ValueError: If ``N <= 1``.
    """

    def __init__(self, in_channels, out_channels, N, skip=False, skipcat=False):
        super().__init__()
        # An explicit check (not `assert`) so it still fires under `python -O`.
        # With a single layer the "skip" would just combine a tensor with itself.
        if N <= 1:
            raise ValueError(f"NCBR requires N > 1 layers, got N={N}")
        self.skip = skip
        self.skipcat = skipcat
        channels = [in_channels] + [out_channels] * N  # len(channels) == N + 1
        # Layers are built in the same order as before, so the parameter
        # layout, state_dict keys (layers.0, layers.1, ...) and RNG use match.
        self.layers = nn.ModuleList(
            CBR(c_in, c_out) for c_in, c_out in zip(channels[:-1], channels[1:])
        )

    def forward(self, x):
        x = self.layers[0](x)
        first_out = x
        for layer in self.layers[1:]:
            x = layer(x)
        if self.skip:
            if self.skipcat:
                x = torch.cat([x, first_out], dim=1)
            else:
                x = x + first_out
        return x
class DownNCBR(nn.Module):
    """Encoder stage: a 2x2 max-pool followed by an NCBR stack.

    The pool uses PyTorch's default stride (equal to the kernel size) and
    floor rounding, so H and W are halved and rounded down before the
    convolutions run.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Per-layer channels of the NCBR stack.
        N: Number of CBR layers in the stack.
        skip: Forwarded to NCBR.
        skipcat: Forwarded to NCBR.
    """

    def __init__(self, in_channels, out_channels, N, skip=False, skipcat=False):
        super().__init__()
        self.maxpool = nn.MaxPool2d(kernel_size=2)
        self.ncbr = NCBR(
            in_channels,
            out_channels,
            N=N,
            skip=skip,
            skipcat=skipcat,
        )

    def forward(self, x):
        pooled = self.maxpool(x)
        return self.ncbr(pooled)
class UpNCBR(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, skip fusion, then NCBR.

    ``upconv`` maps ``in_channels -> in_channels // 2`` while doubling H and W.
    The upsampled map is concatenated with the encoder skip tensor, and the
    result is fed to an NCBR expecting ``in_channels`` channels. The skip
    tensor must therefore carry the remaining
    ``in_channels - in_channels // 2`` channels.

    Args:
        in_channels: Channels of the coarse decoder input.
        out_channels: Per-layer channels of the NCBR stack.
        N: Number of CBR layers in the stack.
        skip: Forwarded to NCBR.
        skipcat: Forwarded to NCBR.
    """

    def __init__(self, in_channels, out_channels, N, skip=False, skipcat=False):
        super().__init__()
        half = in_channels // 2
        self.upconv = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
        self.ncbr = NCBR(in_channels, out_channels, N=N, skip=skip, skipcat=skipcat)

    def forward(self, x1, x2):
        """Upsample ``x1`` and fuse it with the skip tensor ``x2``.

        Assumes 4-D (batch, channels, H, W) tensors. The padding below
        indexes dims 2 and 3, and the concatenation is on dim 1.
        """
        up = self.upconv(x1)
        # Pad the upsampled map to x2's spatial size. Negative differences
        # crop instead. For positive odd differences, the extra pixel goes
        # on the right/bottom.
        dy = x2.shape[2] - up.shape[2]
        dx = x2.shape[3] - up.shape[3]
        left = dx // 2
        top = dy // 2
        up = F.pad(up, [left, dx - left, top, dy - top])
        # Order matters for loaded weights: skip tensor first, then upsampled.
        fused = torch.cat([x2, up], dim=1)
        return self.ncbr(fused)
class OutConv(nn.Module):
    """Final projection from feature channels to per-pixel outputs.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Number of output maps (e.g. class logits).
        kernel_size: Square kernel size of the projection conv (default 1).

    The output keeps the input's H and W for every kernel size.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1):
        super().__init__()
        # "same" padding (stride 1) keeps H and W for any kernel size. For
        # odd kernels it equals the previous symmetric `kernel_size // 2`
        # padding. For even kernels the old padding grew the output by one
        # pixel per axis; "same" puts the extra pad on the right/bottom so
        # the sizes match. Parameter shapes and state_dict keys are unchanged.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding="same",
        )

    def forward(self, x):
        return self.conv(x)
# -------- HF config class --------
class UNetConfig(PretrainedConfig):
    """
    Config for the custom UNet. This is what AutoConfig will load
    when trust_remote_code=True.

    Args:
        n_channels: Number of channels in the input tensor.
        n_classes: Number of output channels produced by the final conv.
        N: Number of CBR (conv-BN-ReLU) layers in every NCBR stage; NCBR
            rejects values that are not greater than 1.
        width: Base channel width. The first stage has ``2 * width``
            channels and the bottleneck ``32 * width``.
        skip: Combine first- and last-layer outputs inside each NCBR stage.
        skipcat: With ``skip``, concatenate instead of adding.
        catorig: Concatenate the raw input onto the final features before
            the output conv.
        outker: Kernel size of the output conv.
        **kwargs: Forwarded to ``PretrainedConfig``. An ``auto_map`` entry,
            if supplied, is kept as-is. Otherwise a default pointing at this
            module's ``UNetConfig`` / ``UNet`` is used.
    """

    model_type = "custom_unet"

    def __init__(
        self,
        n_channels=1,
        n_classes=1,
        N=2,
        width=32,
        skip=True,
        skipcat=False,
        catorig=False,
        outker=1,
        **kwargs,
    ):
        # Take auto_map out before the base init. A value supplied by the
        # caller (or read back from a saved config.json) is then preserved
        # instead of being silently replaced by the default below.
        auto_map = kwargs.pop("auto_map", None)
        super().__init__(**kwargs)
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.N = N
        self.width = width
        self.skip = skip
        self.skipcat = skipcat
        self.catorig = catorig
        self.outker = outker
        # Default mapping for AutoConfig/AutoModel with trust_remote_code.
        # As an instance attribute, it is serialized with the config.
        if auto_map is None:
            auto_map = {
                "AutoConfig": "unet.UNetConfig",
                "AutoModel": "unet.UNet",
            }
        self.auto_map = auto_map
# -------- Hugging Face PreTrainedModel wrapper around the UNet architecture --------
class UNet(PreTrainedModel):
    """
    PreTrainedModel wrapper so AutoModel can construct and load this.

    Architecture:
    - ``inc``: the input NCBR stage, producing ``2 * width`` channels.
    - ``down1``..``down4``: encoder stages ending at ``32 * width`` channels.
    - ``up1``..``up4``: decoder stages, each fusing the matching encoder output.
    - ``outc``: an ``OutConv`` mapping ``2 * width`` channels to
      ``config.n_classes`` output maps. It also takes the raw input channels
      when ``config.catorig`` is set.

    The submodule attribute names define the state_dict keys. Renaming them
    would stop saved checkpoints from loading correctly.
    """

    config_class = UNetConfig

    def __init__(self, config: UNetConfig):
        super().__init__(config)
        n_channels = config.n_channels
        n_classes = config.n_classes
        N = config.N
        width = config.width
        skip = config.skip
        # NCBR only concatenates (doubling its output channels) when skip AND
        # skipcat are both set. The halved stage widths below must follow the
        # same rule. Keying them on config.skipcat alone made
        # skip=False, skipcat=True build stages with mismatched channel
        # counts: down2 received 2 * width channels but expected 4 * width.
        skipcat = bool(skip and config.skipcat)
        self.catorig = config.catorig
        outker = config.outker
        self.n_channels = n_channels
        self.n_classes = n_classes
        # With skipcat, each stage's conv width is halved. The concatenating
        # skip then doubles it back, so every stage still outputs the widths
        # the next stage expects.
        self.inc = NCBR(self.n_channels, 2 * width, N=N, skip=skip)
        self.down1 = DownNCBR(2 * width, 2 * width if skipcat else 4 * width, N=N, skip=skip, skipcat=skipcat)
        self.down2 = DownNCBR(4 * width, 4 * width if skipcat else 8 * width, N=N, skip=skip, skipcat=skipcat)
        self.down3 = DownNCBR(8 * width, 8 * width if skipcat else 16 * width, N=N, skip=skip, skipcat=skipcat)
        self.down4 = DownNCBR(16 * width, 16 * width if skipcat else 32 * width, N=N, skip=skip, skipcat=skipcat)
        self.up1 = UpNCBR(32 * width, 8 * width if skipcat else 16 * width, N=N, skip=skip, skipcat=skipcat)
        self.up2 = UpNCBR(16 * width, 4 * width if skipcat else 8 * width, N=N, skip=skip, skipcat=skipcat)
        self.up3 = UpNCBR(8 * width, 2 * width if skipcat else 4 * width, N=N, skip=skip, skipcat=skipcat)
        self.up4 = UpNCBR(4 * width, width if skipcat else 2 * width, N=N, skip=skip, skipcat=skipcat)
        if self.catorig:
            # The raw input is concatenated onto the final features in forward().
            self.outc = OutConv(2 * width + self.n_channels, self.n_classes, kernel_size=outker)
        else:
            self.outc = OutConv(2 * width, self.n_classes, kernel_size=outker)
        # HF weight init hook
        self.post_init()

    def forward(self, x):
        """Run the U-Net and return the output logits.

        Args:
            x: Input batch, assumed (batch, n_channels, H, W) — TODO confirm.
                Features are concatenated on dim 1. H and W must survive
                four 2x max-pools; the decoder stages pad or crop to match
                each skip tensor.

        Returns:
            The output of ``outc``, with ``n_classes`` channels on dim 1.
        """
        orig = x
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        if self.catorig:
            logits = self.outc(torch.cat([x, orig], dim=1))
        else:
            logits = self.outc(x)
        return logits