# defectArch.py — uploaded by Cuervo-x (Hugging Face commit 9363120, verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by batch norm and ReLU.

    When ``dropout`` > 0, a spatial :class:`nn.Dropout2d` layer is appended
    after the second activation.
    """

    def __init__(self, in_channels, out_channels, dropout=0.0):
        super(DoubleConv, self).__init__()
        stages = []
        # The first conv maps in_channels -> out_channels; the second keeps
        # the channel count fixed. Building both in one loop keeps the
        # Conv/BN/ReLU pattern in a single place.
        for c_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(c_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        if dropout > 0:
            stages.append(nn.Dropout2d(p=dropout))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)
def crop_to_match(enc_feat, dec_feat):
    """Center-crop ``enc_feat`` spatially to ``dec_feat``'s height and width.

    Assumes the encoder map is at least as large as the decoder map in both
    spatial dimensions; batch and channel dimensions are left untouched.
    """
    target_h, target_w = dec_feat.shape[2:]
    # Split the surplus evenly, biasing toward the top-left on odd remainders.
    off_h = (enc_feat.shape[2] - target_h) // 2
    off_w = (enc_feat.shape[3] - target_w) // 2
    return enc_feat[:, :, off_h:off_h + target_h, off_w:off_w + target_w]
class UNet(nn.Module):
    """Four-level U-Net with skip connections and a final 1x1 projection.

    The prediction is bilinearly resized back to the input's spatial size at
    the end of ``forward``, so inputs whose sides are not multiples of 16
    (e.g. 200x200) still produce an output of the original resolution.
    """

    def __init__(self, in_channels=1, out_channels=1, dropout=0.1):
        super(UNet, self).__init__()
        # Contracting path: channel width doubles at every level.
        self.enc1 = DoubleConv(in_channels, 64, dropout=dropout)
        self.enc2 = DoubleConv(64, 128, dropout=dropout)
        self.enc3 = DoubleConv(128, 256, dropout=dropout)
        self.enc4 = DoubleConv(256, 512, dropout=dropout)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = DoubleConv(512, 1024, dropout=dropout)
        # Expanding path: transpose-conv upsample, then a DoubleConv that
        # fuses the upsampled features with the matching encoder skip
        # (hence the doubled input channel count on each dec block).
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = DoubleConv(1024, 512, dropout=dropout)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(512, 256, dropout=dropout)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(256, 128, dropout=dropout)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(128, 64, dropout=dropout)
        self.final = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        target_hw = x.shape[2:]  # remember (H, W) for the final resize
        # Contracting path; keep each level's output for the skip links.
        feat = self.enc1(x)
        skips = [feat]
        for encoder in (self.enc2, self.enc3, self.enc4):
            feat = encoder(self.pool(feat))
            skips.append(feat)
        feat = self.bottleneck(self.pool(feat))
        # Expanding path, deepest level first: pair each upsample with its
        # decoder block and the matching (center-cropped) skip tensor.
        stages = zip(
            (self.up4, self.up3, self.up2, self.up1),
            (self.dec4, self.dec3, self.dec2, self.dec1),
            reversed(skips),
        )
        for upsample, decoder, skip in stages:
            feat = upsample(feat)
            feat = decoder(torch.cat([feat, crop_to_match(skip, feat)], dim=1))
        logits = self.final(feat)
        # Bilinear resize guarantees output resolution == input resolution.
        return F.interpolate(logits, size=target_hw, mode="bilinear", align_corners=False)
# # =========================================================
# # 1. U-Net (legacy implementation, superseded by the class above;
# #    kept commented out for reference)
# # =========================================================
# class UNet(nn.Module):
# def __init__(self, in_channels=1, out_channels=1):
# super().__init__()
# def CBR(in_c, out_c):
# return nn.Sequential(
# nn.Conv2d(in_c, out_c, 3, padding=1),
# nn.BatchNorm2d(out_c),
# nn.ReLU(inplace=True)
# )
# self.enc1 = nn.Sequential(CBR(in_channels, 64), CBR(64, 64))
# self.enc2 = nn.Sequential(CBR(64, 128), CBR(128, 128))
# self.enc3 = nn.Sequential(CBR(128, 256), CBR(256, 256))
# self.enc4 = nn.Sequential(CBR(256, 512), CBR(512, 512))
# self.pool = nn.MaxPool2d(2, 2)
# self.center = nn.Sequential(CBR(512, 1024), CBR(1024, 512))
# self.up4 = nn.ConvTranspose2d(512, 512, 2, stride=2)
# self.dec4 = nn.Sequential(CBR(1024, 512), CBR(512, 256))
# self.up3 = nn.ConvTranspose2d(256, 256, 2, stride=2)
# self.dec3 = nn.Sequential(CBR(512, 256), CBR(256, 128))
# self.up2 = nn.ConvTranspose2d(128, 128, 2, stride=2)
# self.dec2 = nn.Sequential(CBR(256, 128), CBR(128, 64))
# self.up1 = nn.ConvTranspose2d(64, 64, 2, stride=2)
# self.dec1 = nn.Sequential(CBR(128, 64), nn.Conv2d(64, out_channels, 1))
# def crop_to_match(self, x, target):
# _, _, h, w = target.size()
# return x[:, :, :h, :w]
# def forward(self, x):
# e1 = self.enc1(x)
# e2 = self.enc2(self.pool(e1))
# e3 = self.enc3(self.pool(e2))
# e4 = self.enc4(self.pool(e3))
# c = self.center(self.pool(e4))
# d4 = self.up4(c)
# e4_c = self.crop_to_match(e4, d4)
# d4 = torch.cat([d4, e4_c], dim=1)
# d4 = self.dec4(d4)
# d3 = self.up3(d4)
# e3_c = self.crop_to_match(e3, d3)
# d3 = torch.cat([d3, e3_c], dim=1)
# d3 = self.dec3(d3)
# d2 = self.up2(d3)
# e2_c = self.crop_to_match(e2, d2)
# d2 = torch.cat([d2, e2_c], dim=1)
# d2 = self.dec2(d2)
# d1 = self.up1(d2)
# e1_c = self.crop_to_match(e1, d1)
# d1 = torch.cat([d1, e1_c], dim=1)
# out = self.dec1(d1)
# out = F.interpolate(out, size=x.shape[2:], mode="bilinear", align_corners=False)
# return out