# Spaces: Runtime error
import torch
import torch.nn as nn
import torch.nn.functional as F
class UNetInpaint(nn.Module):
    """U-Net encoder/decoder with skip connections, producing a sigmoid output.

    Architecture:
      * Four encoder stages (64 -> 128 -> 256 -> 512 channels). Each stage's
        output is kept as a skip connection and 2x2 max-pooled before the
        next stage.
      * A 1024-channel bottleneck.
      * Four decoder stages. Each upsamples 2x, concatenates the matching
        encoder features, and fuses them with a double conv block.
      * A 1x1 convolution to ``output_channels``, followed by a sigmoid, so
        every output value lies in [0, 1].

    Args:
        input_channels: Number of channels in the input tensor. The default
            of 4 is presumably an RGB image plus a 1-channel mask (NOTE(review):
            confirm against the code that builds the input).
        output_channels: Number of channels in the prediction (default 3).

    Spatial sizes:
        Input height and width do NOT need to be divisible by 16. Max-pooling
        floors odd sizes, so an upsampled decoder map can be one pixel smaller
        than its skip connection. In that case it is resized to the skip's
        size before concatenation. Each spatial side must still be at least 16
        so that the four 2x2 poolings leave at least one pixel.
    """

    def __init__(self, input_channels=4, output_channels=3):
        super().__init__()
        # Attribute names and registration order are unchanged so that
        # existing state_dict checkpoints keep loading.
        self.enc1 = self.conv_block(input_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)
        self.pool = nn.MaxPool2d(2, 2)
        self.bottleneck = self.conv_block(512, 1024)
        # Each dec* block takes 2x the upconv width: the upsampled map
        # concatenated with the same-width encoder skip.
        self.upconv4 = self.up_conv_block(1024, 512)
        self.dec4 = self.conv_block(1024, 512)
        self.upconv3 = self.up_conv_block(512, 256)
        self.dec3 = self.conv_block(512, 256)
        self.upconv2 = self.up_conv_block(256, 128)
        self.dec2 = self.conv_block(256, 128)
        self.upconv1 = self.up_conv_block(128, 64)
        self.dec1 = self.conv_block(128, 64)
        self.out_conv = nn.Conv2d(64, output_channels, 1)
        self.final_activation = nn.Sigmoid()

    def conv_block(self, in_channels, out_channels):
        """Return two (3x3 conv -> BatchNorm -> ReLU) stages.

        ``padding=1`` with a 3x3 kernel keeps the spatial size unchanged.
        Conv bias is disabled because the BatchNorm that follows re-centres
        the activations, which makes a bias redundant.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def up_conv_block(self, in_channels, out_channels):
        """Return a 2x bilinear upsample followed by 3x3 conv -> BatchNorm -> ReLU."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _match_size(x, ref):
        """Return ``x`` resized to ``ref``'s (H, W), or ``x`` itself if they already match."""
        if x.shape[-2:] == ref.shape[-2:]:
            return x
        # Only reached when pooling floored an odd size. The mismatch is then
        # a single pixel, so bilinear resizing is a negligible change.
        return F.interpolate(
            x, size=tuple(ref.shape[-2:]), mode='bilinear', align_corners=True
        )

    def _decode(self, x, skip, upconv, dec):
        """Run one decoder step: upsample ``x``, align it to ``skip``, concatenate on channels, fuse."""
        up = self._match_size(upconv(x), skip)
        return dec(torch.cat([up, skip], dim=1))

    def forward(self, x):
        """Predict the output for ``x`` (assumed N x input_channels x H x W).

        Returns:
            Tensor of shape (N, output_channels, H, W) with values in [0, 1].
        """
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        b = self.bottleneck(self.pool(e4))
        d4 = self._decode(b, e4, self.upconv4, self.dec4)
        d3 = self._decode(d4, e3, self.upconv3, self.dec3)
        d2 = self._decode(d3, e2, self.upconv2, self.dec2)
        d1 = self._decode(d2, e1, self.upconv1, self.dec1)
        return self.final_activation(self.out_conv(d1))