import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig


# -------- building blocks --------


class CBR(nn.Module):
    """Conv2d (3x3, padding=1, no bias) -> BatchNorm2d -> ReLU.

    The 3x3 kernel with padding 1 and stride 1 keeps the spatial size unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # bias=False: the BatchNorm that follows has its own learnable shift.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bnorm1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu1(self.bnorm1(self.conv1(x)))


class NCBR(nn.Module):
    """A stack of ``N`` CBR layers with an optional long skip connection.

    The first layer maps ``in_channels -> out_channels``; the remaining ``N - 1``
    layers keep ``out_channels``. When ``skip`` is true, the output of the first
    layer is combined with the output of the last layer: summed when ``skipcat``
    is false, or concatenated along dim 1 when ``skipcat`` is true (the result
    then has ``2 * out_channels`` channels). ``skipcat`` has no effect unless
    ``skip`` is true.

    Raises:
        ValueError: if ``N`` is not greater than 1.
    """

    def __init__(self, in_channels, out_channels, N, skip=False, skipcat=False):
        super().__init__()
        # Raise rather than assert so the check is not stripped under ``python -O``.
        if not N > 1:
            raise ValueError(f"NCBR requires N > 1 layers, got N={N}")
        self.skip = skip
        self.skipcat = skipcat
        channels = [in_channels] + [out_channels] * N  # N + 1 entries
        self.layers = nn.ModuleList(
            CBR(c_in, c_out) for c_in, c_out in zip(channels[:-1], channels[1:])
        )

    def forward(self, x):
        x = self.layers[0](x)
        first = x  # output of the first layer, reused by the optional skip
        for layer in self.layers[1:]:
            x = layer(x)
        if self.skip:
            x = torch.cat([x, first], dim=1) if self.skipcat else x + first
        return x


class DownNCBR(nn.Module):
    """Downscaling with MaxPool2d(2), then an NCBR stack."""

    def __init__(self, in_channels, out_channels, N, skip=False, skipcat=False):
        super().__init__()
        self.maxpool = nn.MaxPool2d(2)
        self.ncbr = NCBR(in_channels, out_channels, N=N, skip=skip, skipcat=skipcat)

    def forward(self, x):
        return self.ncbr(self.maxpool(x))


class UpNCBR(nn.Module):
    """Upscaling with a stride-2 transposed conv, concatenation with a skip tensor, then NCBR.

    ``upconv`` maps ``in_channels -> in_channels // 2``, and the NCBR expects
    ``in_channels`` inputs. The skip tensor ``x2`` is therefore expected to carry
    ``in_channels - in_channels // 2`` channels.
    """

    def __init__(self, in_channels, out_channels, N, skip=False, skipcat=False):
        super().__init__()
        self.upconv = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.ncbr = NCBR(in_channels, out_channels, N=N, skip=skip, skipcat=skipcat)

    def forward(self, x1, x2):
        x1 = self.upconv(x1)
        # Tensors are NCHW (BatchNorm2d requires 4-D input): dims 2 and 3 are H and W.
        # Pad the upsampled map to the skip tensor's size. F.pad crops when a
        # difference is negative.
        diffY = x2.size(2) - x1.size(2)
        diffX = x2.size(3) - x1.size(3)
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.ncbr(x)


class OutConv(nn.Module):
    """Final projection conv; padding ``kernel_size // 2`` preserves H and W for odd kernels."""

    def __init__(self, in_channels, out_channels, kernel_size=1):
        super().__init__()
        padding = kernel_size // 2
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)

    def forward(self, x):
        return self.conv(x)


# -------- HF config class --------


class UNetConfig(PretrainedConfig):
    """
    Config for the custom UNet.
    This is what AutoConfig will load when trust_remote_code=True.

    Args:
        n_channels: number of input channels.
        n_classes: number of output channels produced by the final conv.
        N: number of CBR layers per NCBR stack (must be > 1).
        width: base channel multiplier.
        skip: enable the skip connection inside each NCBR stack.
        skipcat: concatenate (instead of add) inside-stack skips; only
            meaningful when ``skip`` is true.
        catorig: concatenate the raw input to the features before the output conv.
        outker: kernel size of the output conv.
    """

    model_type = "custom_unet"

    def __init__(
        self,
        n_channels=1,
        n_classes=1,
        N=2,
        width=32,
        skip=True,
        skipcat=False,
        catorig=False,
        outker=1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.N = N
        self.width = width
        self.skip = skip
        self.skipcat = skipcat
        self.catorig = catorig
        self.outker = outker
        # Default auto_map for AutoConfig/AutoModel. It is applied only when none
        # was supplied (e.g. via kwargs from a loaded config.json). Otherwise a
        # mapping written at save time, such as a repo-prefixed one, would be
        # silently replaced.
        if not getattr(self, "auto_map", None):
            self.auto_map = {
                "AutoConfig": "unet.UNetConfig",
                "AutoModel": "unet.UNet",
            }


# -------- HF model wrapper around the UNet architecture --------


class UNet(PreTrainedModel):
    """
    PreTrainedModel wrapper so AutoModel can construct and load this.

    Four DownNCBR stages followed by four UpNCBR stages, with encoder features
    concatenated into the decoder. Input is expected to be NCHW with
    ``config.n_channels`` channels. The output is the raw ``outc`` result with
    ``config.n_classes`` channels; no activation is applied.
    """

    config_class = UNetConfig

    def __init__(self, config: UNetConfig):
        super().__init__(config)
        N = config.N
        width = config.width
        skip = config.skip
        # NCBR only concatenates when skip is enabled. The width halving below must
        # use the same effective flag. Otherwise skip=False with skipcat=True halves
        # widths that are never doubled back, and the stages' channel counts no
        # longer line up.
        skipcat = bool(config.skip and config.skipcat)
        self.catorig = config.catorig
        outker = config.outker

        self.n_channels = config.n_channels
        self.n_classes = config.n_classes

        def stage_out(channels):
            # A concatenating NCBR doubles its out_channels, so request half.
            return channels // 2 if skipcat else channels

        # inc is built without skipcat, so it always emits 2 * width channels.
        self.inc = NCBR(self.n_channels, 2 * width, N=N, skip=skip)
        self.down1 = DownNCBR(2 * width, stage_out(4 * width), N=N, skip=skip, skipcat=skipcat)
        self.down2 = DownNCBR(4 * width, stage_out(8 * width), N=N, skip=skip, skipcat=skipcat)
        self.down3 = DownNCBR(8 * width, stage_out(16 * width), N=N, skip=skip, skipcat=skipcat)
        self.down4 = DownNCBR(16 * width, stage_out(32 * width), N=N, skip=skip, skipcat=skipcat)
        self.up1 = UpNCBR(32 * width, stage_out(16 * width), N=N, skip=skip, skipcat=skipcat)
        self.up2 = UpNCBR(16 * width, stage_out(8 * width), N=N, skip=skip, skipcat=skipcat)
        self.up3 = UpNCBR(8 * width, stage_out(4 * width), N=N, skip=skip, skipcat=skipcat)
        self.up4 = UpNCBR(4 * width, stage_out(2 * width), N=N, skip=skip, skipcat=skipcat)

        head_in = 2 * width + (self.n_channels if self.catorig else 0)
        self.outc = OutConv(head_in, self.n_classes, kernel_size=outker)

        # HF weight init hook.
        # NOTE(review): no `_init_weights` override is defined here, so the
        # resulting init depends on the installed transformers version's default
        # `PreTrainedModel._init_weights`. Confirm this is the intended conv init.
        self.post_init()

    def forward(self, x):
        orig = x
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        if self.catorig:
            # Give the output conv direct access to the raw input as well.
            x = torch.cat([x, orig], dim=1)
        logits = self.outc(x)
        return logits