# road-detection-model / model_def.py
# Uploaded by: daniel-crawford-dunedain
# Commit message: Initial Space upload
# Commit: 40e9b28 (verified)
# backend/utils/terrain_analyzer/road_detection_model/model_def.py
from typing import List, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
# === UNet Model Definition ===
class UNet(nn.Module):
    """U-Net encoder/decoder that outputs a per-pixel map in [0, 1].

    The encoder applies one double-conv block per entry of ``features`` and
    halves the spatial size with 2x2 max-pooling after each block. A
    bottleneck block doubles the deepest width. The decoder mirrors the
    encoder, deepest stage first. At each stage it:

    1. upsamples with a stride-2 transposed convolution;
    2. concatenates the matching encoder activation along dim 1;
    3. fuses the result with a double-conv block.

    A final 1x1 convolution maps to ``out_channels``, followed by a sigmoid.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output map.
        features: Encoder stage widths, shallowest first. Any non-empty
            sequence of ints works; stages do not have to double in width.

    Raises:
        ValueError: If ``features`` is empty.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 1,
        features: Sequence[int] = (64, 128, 256, 512),
    ):
        super().__init__()
        # Copy so callers' sequences are never touched and reversed() works
        # for any iterable input.
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        # NOTE: attribute names and submodule registration order define the
        # state_dict keys (e.g. "encoder.0.0.weight", "decoder.1.3.weight");
        # keep them stable so existing checkpoints still load.
        self.encoder = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.decoder = nn.ModuleList()

        # Encoder: one double-conv block per stage.
        channels = in_channels
        for feature in features:
            self.encoder.append(self._conv_block(channels, feature))
            channels = feature

        # Bottleneck doubles the deepest width.
        self.bottleneck = self._conv_block(features[-1], features[-1] * 2)

        # Decoder: flat list of (upsample, fuse) pairs, deepest stage first.
        # The upsampler's input width is whatever the previous stage produced:
        # the bottleneck output first, then each fused stage's width. (For
        # doubling feature lists this equals feature * 2, i.e. the original
        # layout, so parameter shapes are unchanged.)
        up_in = features[-1] * 2
        for feature in reversed(features):
            self.decoder.append(
                nn.ConvTranspose2d(up_in, feature, kernel_size=2, stride=2)
            )
            # The fuse block sees upsampled (feature) + skip (feature) channels.
            self.decoder.append(self._conv_block(feature * 2, feature))
            up_in = feature

        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a batch.

        Args:
            x: Input batch, expected as (N, in_channels, H, W). BatchNorm2d
                requires 4-D input.

        Returns:
            Tensor of shape (N, out_channels, H, W) with sigmoid-activated
            values in [0, 1]. Spatial size matches the input, including odd
            sizes.
        """
        skip_connections = []
        for down in self.encoder:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        upsamplers = self.decoder[0::2]
        fusers = self.decoder[1::2]
        for up, fuse, skip in zip(upsamplers, fusers, reversed(skip_connections)):
            x = up(x)
            # Max-pooling floors odd sizes, so the upsampled map can be a pixel
            # smaller than its skip connection; resize (default nearest mode).
            if x.shape[2:] != skip.shape[2:]:
                x = F.interpolate(x, size=skip.shape[2:])
            x = fuse(torch.cat((skip, x), dim=1))

        return torch.sigmoid(self.final_conv(x))

    @staticmethod
    def _conv_block(in_channels, out_channels):
        """Return (3x3 conv -> BatchNorm -> ReLU) x 2, preserving H and W.

        The convolutions use padding=1 and no bias, since the following
        BatchNorm supplies the affine shift.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
def build_model(**kwargs) -> "UNet":
    """Construct the UNet defined in this module.

    Keyword arguments are forwarded unchanged to ``UNet`` (``in_channels``,
    ``out_channels``, ``features``). Custom settings, e.g. values read from a
    config.json, can therefore be passed straight through. Called with no
    arguments, this returns ``UNet()`` with its default configuration, as
    before.
    """
    return UNet(**kwargs)