import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    # Nothing in this module uses ``models``; the name is kept in scope for
    # any external code that may rely on it, but torchvision is treated as
    # optional so the U-Net itself can be imported without it.
    from torchvision import models  # noqa: F401
except ImportError:
    models = None


class DoubleConv(nn.Module):
    """Two consecutive 3x3 Conv2d -> BatchNorm2d -> ReLU stages.

    Both convolutions use ``padding=1``, so the spatial size of the input is
    preserved. A ``Dropout2d`` layer is appended after the second ReLU only
    when ``dropout`` is greater than zero.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by both convolutions.
        dropout: Channel-dropout probability; ``0`` (or less) disables it.
    """

    def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        if dropout > 0:
            layers.append(nn.Dropout2d(p=dropout))
        # The attribute name ``net`` is part of the state_dict keys.
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the stacked layers to ``x`` and return the result."""
        return self.net(x)


def crop_to_match(enc_feat: torch.Tensor, dec_feat: torch.Tensor) -> torch.Tensor:
    """Center-crop encoder feature map to match size of decoder feature map.

    Only the last two dimensions (height, width) of ``enc_feat`` are cropped;
    batch and channel dimensions are returned untouched. When the size
    difference is odd, the extra row/column is removed from the bottom/right.

    Args:
        enc_feat: Encoder tensor, indexed as ``(N, C, H_enc, W_enc)``.
        dec_feat: 4-D decoder tensor whose ``(H, W)`` defines the target size.

    Returns:
        A view of ``enc_feat`` with spatial size ``(H, W)``.

    Raises:
        ValueError: If ``enc_feat`` is smaller than ``dec_feat`` in height or
            width (a crop cannot enlarge a tensor; without this check the
            negative offset would silently yield a wrong-sized slice).
    """
    _, _, H, W = dec_feat.shape
    enc_H, enc_W = enc_feat.shape[2], enc_feat.shape[3]
    if enc_H < H or enc_W < W:
        raise ValueError(
            f"encoder feature map ({enc_H}x{enc_W}) is smaller than "
            f"decoder feature map ({H}x{W}); cannot crop"
        )
    crop_top = (enc_H - H) // 2
    crop_left = (enc_W - W) // 2
    return enc_feat[:, :, crop_top:crop_top + H, crop_left:crop_left + W]


class UNet(nn.Module):
    """U-Net with four encoder stages, a bottleneck and four decoder stages.

    Each encoder stage is a :class:`DoubleConv`; stages are separated by a
    2x2 max-pool. Each decoder stage upsamples with a stride-2 transposed
    convolution, concatenates the center-cropped encoder feature map of the
    same level along the channel axis, and refines it with a
    :class:`DoubleConv`. A final 1x1 convolution maps to ``out_channels`` and
    the result is bilinearly resized to the spatial size of the input.

    Args:
        in_channels: Channels of the input image tensor.
        out_channels: Channels of the output map.
        dropout: Dropout probability forwarded to every :class:`DoubleConv`.
    """

    def __init__(self, in_channels: int = 1, out_channels: int = 1, dropout: float = 0.1):
        super().__init__()

        # Encoder
        self.enc1 = DoubleConv(in_channels, 64, dropout=dropout)
        self.enc2 = DoubleConv(64, 128, dropout=dropout)
        self.enc3 = DoubleConv(128, 256, dropout=dropout)
        self.enc4 = DoubleConv(256, 512, dropout=dropout)
        self.pool = nn.MaxPool2d(2)

        self.bottleneck = DoubleConv(512, 1024, dropout=dropout)

        # Decoder: each DoubleConv takes twice the upsampled channel count
        # because the skip connection is concatenated on the channel axis.
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = DoubleConv(1024, 512, dropout=dropout)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(512, 256, dropout=dropout)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(256, 128, dropout=dropout)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(128, 64, dropout=dropout)

        self.final = nn.Conv2d(64, out_channels, kernel_size=1)

    @staticmethod
    def _decode_stage(up: nn.Module, dec: nn.Module,
                      x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """Upsample ``x``, concatenate the cropped ``skip`` map, then refine."""
        upsampled = up(x)
        skip_cropped = crop_to_match(skip, upsampled)
        return dec(torch.cat([upsampled, skip_cropped], dim=1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on ``x`` and return a map of the same spatial size.

        Args:
            x: Input tensor laid out as ``(N, in_channels, H, W)``. The four
                2x poolings floor-divide the spatial size each time, so H and
                W need to be at least 16 for the bottleneck to be non-empty.

        Returns:
            Tensor of shape ``(N, out_channels, H, W)``.
        """
        input_size = x.shape[2:]  # (H, W)

        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        # Bottleneck
        b = self.bottleneck(self.pool(e4))

        # Decoder with cropping: when a pooled size was odd, the upsampled
        # map is smaller than the encoder map of that level, hence the crop.
        d4 = self._decode_stage(self.up4, self.dec4, b, e4)
        d3 = self._decode_stage(self.up3, self.dec3, d4, e3)
        d2 = self._decode_stage(self.up2, self.dec2, d3, e2)
        d1 = self._decode_stage(self.up1, self.dec1, d2, e1)

        out = self.final(d1)

        # Resize output back to the input's spatial size, whatever it is
        # (e.g. a 200x200 input is decoded at 192x192 and resized here).
        out = F.interpolate(out, size=input_size, mode="bilinear", align_corners=False)
        return out


# NOTE: alternative UNet variant, disabled (commented out) in the original
# file. Kept verbatim apart from restored indentation.
#
# # =========================================================
# # 1. U-Net
# # =========================================================
# class UNet(nn.Module):
#     def __init__(self, in_channels=1, out_channels=1):
#         super().__init__()
#
#         def CBR(in_c, out_c):
#             return nn.Sequential(
#                 nn.Conv2d(in_c, out_c, 3, padding=1),
#                 nn.BatchNorm2d(out_c),
#                 nn.ReLU(inplace=True)
#             )
#
#         self.enc1 = nn.Sequential(CBR(in_channels, 64), CBR(64, 64))
#         self.enc2 = nn.Sequential(CBR(64, 128), CBR(128, 128))
#         self.enc3 = nn.Sequential(CBR(128, 256), CBR(256, 256))
#         self.enc4 = nn.Sequential(CBR(256, 512), CBR(512, 512))
#         self.pool = nn.MaxPool2d(2, 2)
#         self.center = nn.Sequential(CBR(512, 1024), CBR(1024, 512))
#         self.up4 = nn.ConvTranspose2d(512, 512, 2, stride=2)
#         self.dec4 = nn.Sequential(CBR(1024, 512), CBR(512, 256))
#         self.up3 = nn.ConvTranspose2d(256, 256, 2, stride=2)
#         self.dec3 = nn.Sequential(CBR(512, 256), CBR(256, 128))
#         self.up2 = nn.ConvTranspose2d(128, 128, 2, stride=2)
#         self.dec2 = nn.Sequential(CBR(256, 128), CBR(128, 64))
#         self.up1 = nn.ConvTranspose2d(64, 64, 2, stride=2)
#         self.dec1 = nn.Sequential(CBR(128, 64), nn.Conv2d(64, out_channels, 1))
#
#     def crop_to_match(self, x, target):
#         _, _, h, w = target.size()
#         return x[:, :, :h, :w]
#
#     def forward(self, x):
#         e1 = self.enc1(x)
#         e2 = self.enc2(self.pool(e1))
#         e3 = self.enc3(self.pool(e2))
#         e4 = self.enc4(self.pool(e3))
#         c = self.center(self.pool(e4))
#         d4 = self.up4(c)
#         e4_c = self.crop_to_match(e4, d4)
#         d4 = torch.cat([d4, e4_c], dim=1)
#         d4 = self.dec4(d4)
#         d3 = self.up3(d4)
#         e3_c = self.crop_to_match(e3, d3)
#         d3 = torch.cat([d3, e3_c], dim=1)
#         d3 = self.dec3(d3)
#         d2 = self.up2(d3)
#         e2_c = self.crop_to_match(e2, d2)
#         d2 = torch.cat([d2, e2_c], dim=1)
#         d2 = self.dec2(d2)
#         d1 = self.up1(d2)
#         e1_c = self.crop_to_match(e1, d1)
#         d1 = torch.cat([d1, e1_c], dim=1)
#         out = self.dec1(d1)
#         out = F.interpolate(out, size=x.shape[2:], mode="bilinear", align_corners=False)
#         return out