# Page header captured with the file (not part of the module): Cuervo-x's picture
# Commit message: Update printArch.py
# Commit: 1c73e9b (verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
class UNet(nn.Module):
    """Encoder-decoder convolutional network with skip connections.

    The encoder has four stages (64, 128, 256, 512 channels), each followed
    by 2x2 max pooling before the next stage. A center block sits at the
    lowest resolution. The decoder has four stages; each one upsamples by 2
    with a transposed convolution, concatenates the matching encoder
    feature map (cropped to the upsampled size) along the channel axis, and
    applies further convolutions. The last layer is a 1x1 convolution that
    produces ``out_channels`` channels; no activation is applied after it.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()

        def conv_unit(n_in, n_out):
            # 3x3 conv (padding=1 keeps H x W) -> batch norm -> in-place ReLU.
            layers = [
                nn.Conv2d(n_in, n_out, 3, padding=1),
                nn.BatchNorm2d(n_out),
                nn.ReLU(inplace=True),
            ]
            return nn.Sequential(*layers)

        def stage(n_in, n_mid, n_out):
            # Two consecutive conv units: n_in -> n_mid -> n_out channels.
            return nn.Sequential(conv_unit(n_in, n_mid), conv_unit(n_mid, n_out))

        # NOTE: attribute names and Sequential nesting determine the
        # parameter names in state_dict(), so they are part of the contract.

        # Encoder.
        self.enc1 = stage(in_channels, 64, 64)
        self.enc2 = stage(64, 128, 128)
        self.enc3 = stage(128, 256, 256)
        self.enc4 = stage(256, 512, 512)
        self.pool = nn.MaxPool2d(2, 2)

        # Lowest-resolution block.
        self.center = stage(512, 1024, 512)

        # Decoder: each dec* takes twice the channels produced by its up*
        # layer because the encoder skip is concatenated before it.
        self.up4 = nn.ConvTranspose2d(512, 512, 2, stride=2)
        self.dec4 = stage(1024, 512, 256)
        self.up3 = nn.ConvTranspose2d(256, 256, 2, stride=2)
        self.dec3 = stage(512, 256, 128)
        self.up2 = nn.ConvTranspose2d(128, 128, 2, stride=2)
        self.dec2 = stage(256, 128, 64)
        self.up1 = nn.ConvTranspose2d(64, 64, 2, stride=2)
        self.dec1 = nn.Sequential(
            conv_unit(128, 64),
            nn.Conv2d(64, out_channels, 1),
        )

    def crop_to_match(self, x, target):
        """Return ``x`` cut down to the height and width of ``target``.

        The crop keeps the top-left region: rows ``[0, h)`` and columns
        ``[0, w)``, where ``h`` and ``w`` are the last two sizes of the 4-D
        ``target``. Batch and channel dimensions of ``x`` are untouched.
        """
        _batch, _chan, rows, cols = target.size()
        return x[:, :, :rows, :cols]

    def forward(self, x):
        """Run the network on a 4-D input and return the final feature map.

        When the input height/width is not a multiple of 16 the pooling
        steps round down, so the upsampled maps (and therefore the output)
        are smaller than the input; encoder features are cropped to fit.
        """
        # Encoder path; every stage output is kept for its skip connection.
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool(skip1))
        skip3 = self.enc3(self.pool(skip2))
        skip4 = self.enc4(self.pool(skip3))

        feat = self.center(self.pool(skip4))

        # Decoder path, deepest stage first.
        decoder_plan = (
            (self.up4, self.dec4, skip4),
            (self.up3, self.dec3, skip3),
            (self.up2, self.dec2, skip2),
            (self.up1, self.dec1, skip1),
        )
        for upsample, refine, skip in decoder_plan:
            feat = upsample(feat)
            cropped = self.crop_to_match(skip, feat)
            feat = refine(torch.cat([feat, cropped], dim=1))
        return feat
# =========================================================
# DeepLabv3
# =========================================================
def get_deeplabv3(in_channels=1, out_channels=1):
    """Build a torchvision DeepLabv3-ResNet50 with custom channel counts.

    The model is created with ``weights=None`` and ``weights_backbone=None``,
    so no pretrained weights are requested. The backbone's first
    convolution is then replaced so the network accepts ``in_channels``
    input channels.

    Args:
        in_channels: Number of channels of the input images.
        out_channels: Number of output classes/channels; passed to
            torchvision as ``num_classes``.

    Returns:
        The configured segmentation model.
    """
    model = models.segmentation.deeplabv3_resnet50(
        weights=None,
        weights_backbone=None,
        num_classes=out_channels,
    )

    # Swap the stem convolution so the backbone takes `in_channels` inputs.
    # The remaining hyper-parameters (64 filters, 7x7, stride 2, padding 3,
    # no bias) are the ones this function has always used for the stem.
    model.backbone.conv1 = nn.Conv2d(
        in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
    )

    # Auxiliary classifier, if the model has one: replace the layer at
    # index 4 with a fresh 1x1 conv that emits `out_channels` channels.
    # The structure is checked explicitly instead of catching exceptions,
    # so an unexpected layout (missing head, too few layers, or a
    # non-convolution at index 4) is skipped rather than hidden or crashing.
    aux_head = getattr(model, "aux_classifier", None)
    if isinstance(aux_head, nn.Sequential) and len(aux_head) > 4:
        last_layer = aux_head[4]
        if isinstance(last_layer, nn.Conv2d):
            aux_head[4] = nn.Conv2d(last_layer.in_channels, out_channels, 1)

    return model