segspace_app / model.py
functionNormally
Fix spatial size mismatch for patch sizes larger than image dimensions
15f6af6
import torch
import torch.nn as nn
import torch.nn.functional as F
from config import NUM_CHANNELS, NUM_CLASSES
class DoubleConv(nn.Module):
    """Two stacked ``Conv3x3 -> BatchNorm -> ReLU`` stages.

    The 3x3 convolutions use ``padding=1``, so the spatial size is preserved;
    only the channel count changes from ``in_ch`` to ``out_ch`` (in the first
    convolution; the second maps ``out_ch`` to ``out_ch``).

    Args:
        in_ch: Number of channels of the incoming feature map.
        out_ch: Number of channels produced by both convolutions.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        stages = []
        # First stage consumes in_ch channels, second consumes out_ch; the
        # resulting layer order is conv, bn, relu, conv, bn, relu.
        for stage_in in (in_ch, out_ch):
            stages.extend(
                (
                    nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1),
                    nn.BatchNorm2d(out_ch),
                    nn.ReLU(inplace=True),
                )
            )
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv/BN/ReLU stages to ``x``."""
        return self.net(x)
class SmallUNet(nn.Module):
    """Three-level U-Net producing per-pixel class logits.

    The encoder halves the spatial resolution three times (``MaxPool2d(2)``
    before the bottleneck), so the network natively requires height and width
    to be multiples of 8. ``forward`` accepts any size by zero-padding the
    bottom/right edges up to the next multiple of 8 and cropping the logits
    back to the original size afterwards.

    Args:
        in_channels: Number of input channels (default from ``config``).
        num_classes: Number of output logit channels (default from ``config``).
        base_channels: Width of the first encoder stage; each deeper stage
            doubles it, up to ``8 * base_channels`` at the bottleneck.
    """

    # Total spatial down-sampling of the encoder: three 2x max-pools.
    _DOWNSAMPLE = 2 ** 3

    def __init__(self, in_channels: int = NUM_CHANNELS, num_classes: int = NUM_CLASSES, base_channels: int = 16):
        super().__init__()
        # Encoder: each stage doubles the channel count, each pool halves H/W.
        self.enc1 = DoubleConv(in_channels, base_channels)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = DoubleConv(base_channels, base_channels * 2)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = DoubleConv(base_channels * 2, base_channels * 4)
        self.pool3 = nn.MaxPool2d(2)
        self.bottleneck = DoubleConv(base_channels * 4, base_channels * 8)
        # Decoder: each transposed conv doubles H/W; the following DoubleConv
        # takes the up-sampled map concatenated with the matching skip
        # connection, hence twice the channels on its input.
        self.up3 = nn.ConvTranspose2d(base_channels * 8, base_channels * 4, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(base_channels * 8, base_channels * 4)
        self.up2 = nn.ConvTranspose2d(base_channels * 4, base_channels * 2, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(base_channels * 4, base_channels * 2)
        self.up1 = nn.ConvTranspose2d(base_channels * 2, base_channels, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(base_channels * 2, base_channels)
        # 1x1 conv mapping features to per-class logits.
        self.head = nn.Conv2d(base_channels, num_classes, kernel_size=1)

    def forward(self, x):
        """Return logits of shape ``(N, num_classes, H, W)`` for a 4-D input ``x``.

        Inputs whose H/W are not multiples of 8 (including sizes smaller
        than 8) are zero-padded on the bottom/right, processed, and cropped
        back, so output pixel ``(i, j)`` stays aligned with input pixel
        ``(i, j)``.
        """
        H, W = x.shape[2], x.shape[3]
        factor = self._DOWNSAMPLE
        pad_h = (-H) % factor
        pad_w = (-W) % factor
        if pad_h or pad_w:
            # Pad only bottom/right so the top-left origin is unchanged.
            # Constant (zero) padding works for any pad size; reflect padding
            # would fail when the pad exceeds the input size.
            x = F.pad(x, (0, pad_w, 0, pad_h))

        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        b = self.bottleneck(self.pool3(e3))

        # With H/W now multiples of 8, every up-sampled map matches its skip
        # connection exactly, so no cropping is needed before concatenation.
        d3 = self.dec3(torch.cat([self.up3(b), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        out = self.head(d1)

        # Drop the padded region so the output matches the input size.
        return out[:, :, :H, :W]