# Spaces:
# Build error
# Build error
import torch
from torch import nn, optim

from loss import GANLoss
class UnetBlock(nn.Module):
    """One recursive level of the U-Net generator.

    The level wraps ``submodule`` (the next, deeper level) between a
    stride-2 ``Conv2d`` and a stride-2 ``ConvTranspose2d``. Except at the
    outermost level, the block's input is concatenated with its output
    along dim 1 as a skip connection.

    Args:
        nf: channels produced by the up-convolution; also the expected input
            channel count when ``input_c`` is omitted.
        ni: channels produced by the down-convolution. A ``submodule`` must
            return ``ni * 2`` channels, because that is what the following
            up-convolution consumes.
        submodule: the deeper level; not used when ``innermost`` is True.
        input_c: input channel count; defaults to ``nf``.
        dropout: append ``nn.Dropout(0.5)`` after the up path
            (only honoured for middle levels).
        innermost: build the bottom level (no submodule, no down BatchNorm).
        outermost: build the top level (no skip concat, Tanh on the output).
    """

    def __init__(self, nf, ni, submodule=None, input_c=None, dropout=False,
                 innermost=False, outermost=False):
        super().__init__()
        self.outermost = outermost
        in_channels = nf if input_c is None else input_c

        # The down conv is created before any up conv (as in the original
        # ordering) so the RNG draws of the default initialisers are unchanged.
        down_conv = nn.Conv2d(in_channels, ni, kernel_size=4,
                              stride=2, padding=1, bias=False)
        down_act = nn.LeakyReLU(0.2, True)
        down_bn = nn.BatchNorm2d(ni)
        up_act = nn.ReLU(True)
        up_bn = nn.BatchNorm2d(nf)

        if outermost:
            up_conv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                         stride=2, padding=1)
            layers = [down_conv, submodule, up_act, up_conv, nn.Tanh()]
        elif innermost:
            up_conv = nn.ConvTranspose2d(ni, nf, kernel_size=4,
                                         stride=2, padding=1, bias=False)
            layers = [down_act, down_conv, up_act, up_conv, up_bn]
        else:
            up_conv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                         stride=2, padding=1, bias=False)
            layers = [down_act, down_conv, down_bn, submodule,
                      up_act, up_conv, up_bn]
            if dropout:
                layers.append(nn.Dropout(0.5))

        # Layer indices inside this Sequential define the state_dict keys.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run the level; non-outermost levels prepend a skip of the input."""
        if self.outermost:
            return self.model(x)
        # NOTE: for non-outermost levels the first layer is the in-place
        # LeakyReLU, so ``x`` is overwritten while self.model(x) runs. The skip
        # half of the concatenation is therefore LeakyReLU(x), not the raw
        # input. Kept deliberately: changing it would alter existing weights.
        return torch.cat([x, self.model(x)], 1)
class Unet(nn.Module):
    """U-Net generator assembled inside-out from nested ``UnetBlock`` levels.

    Structure, from the bottom up:
      * one innermost level at width ``8 * num_filters``;
      * ``n_down - 5`` extra full-width levels with dropout (none when
        ``n_down <= 5``, so the network always has at least 5 levels);
      * three levels that halve the width each time;
      * the outermost level mapping ``input_c`` -> ``output_c`` channels.

    Args:
        input_c: channels of the input image.
        output_c: channels of the generated output.
        n_down: total number of down-sampling levels (see the note above).
        num_filters: base channel width of the outermost level.
    """

    def __init__(self, input_c=1, output_c=2, n_down=8, num_filters=64):
        super().__init__()
        widest = num_filters * 8

        core = UnetBlock(widest, widest, innermost=True)
        for _ in range(n_down - 5):
            core = UnetBlock(widest, widest, submodule=core, dropout=True)

        # Widths 8f -> 4f -> 2f -> f while moving outwards.
        for width in (widest, widest // 2, widest // 4):
            core = UnetBlock(width // 2, width, submodule=core)

        self.model = UnetBlock(output_c, widest // 8, input_c=input_c,
                               submodule=core, outermost=True)

    def forward(self, x):
        """Generate the output image from ``x``."""
        return self.model(x)
class PatchDiscriminator(nn.Module):
    """Convolutional discriminator that outputs a 1-channel spatial score map.

    Layout:
      * one conv stage without BatchNorm;
      * ``n_down`` conv stages with BatchNorm that double the channel count,
        all stride 2 except the last, which uses stride 1;
      * a final conv to 1 channel with neither BatchNorm nor activation.

    Args:
        input_c: channels of the image being scored.
        num_filters: channel width of the first stage.
        n_down: number of normalised, channel-doubling stages.
    """

    def __init__(self, input_c, num_filters=64, n_down=3):
        super().__init__()
        stages = [self.get_layers(input_c, num_filters, norm=False)]
        for level in range(n_down):
            last_level = level == n_down - 1
            stages.append(
                self.get_layers(
                    num_filters * 2 ** level,
                    num_filters * 2 ** (level + 1),
                    s=1 if last_level else 2,
                )
            )
        # Output head: raw scores, hence no normalisation and no activation.
        stages.append(
            self.get_layers(num_filters * 2 ** n_down, 1, s=1,
                            norm=False, act=False)
        )
        self.model = nn.Sequential(*stages)

    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True):
        """Build Conv2d -> [BatchNorm2d] -> [LeakyReLU(0.2)] as a Sequential.

        The conv has a bias only when no BatchNorm follows it.
        """
        stage = [nn.Conv2d(ni, nf, k, s, p, bias=not norm)]
        if norm:
            stage.append(nn.BatchNorm2d(nf))
        if act:
            stage.append(nn.LeakyReLU(0.2, True))
        return nn.Sequential(*stage)

    def forward(self, x):
        """Score ``x``; returns the 1-channel map from the final conv."""
        return self.model(x)
def init_weights(net, init='norm', gain=0.02):
    """Initialise conv and BatchNorm2d parameters of ``net`` in place.

    Every submodule whose class name contains ``'Conv'`` (including transposed
    convs) gets its weight re-drawn according to ``init`` and its bias, if
    any, zeroed. Every ``BatchNorm2d`` gets weight ~ N(1, gain) and bias 0.

    Args:
        net: module whose submodules are visited via ``net.apply``.
        init: conv weight scheme, one of ``'norm'`` (N(0, gain)),
            ``'xavier'`` (Xavier-normal with ``gain``) or ``'kaiming'``
            (Kaiming-normal, ``a=0``, ``mode='fan_in'``; ``gain`` is unused).
        gain: std for ``'norm'``, gain for ``'xavier'``, and the std of the
            BatchNorm2d weight around 1.

    Returns:
        ``net`` itself, so calls can be chained.

    Raises:
        ValueError: if ``init`` is not a supported scheme. Previously an
            unknown scheme silently left conv weights untouched while still
            printing that initialisation had happened.
    """
    if init not in ('norm', 'xavier', 'kaiming'):
        raise ValueError(
            f"unknown init scheme {init!r}; expected 'norm', 'xavier' or 'kaiming'"
        )

    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and 'Conv' in classname:
            # nn.init functions already run under no_grad, so the parameter
            # can be passed directly (no `.data` needed).
            if init == 'norm':
                nn.init.normal_(m.weight, mean=0.0, std=gain)
            elif init == 'xavier':
                nn.init.xavier_normal_(m.weight, gain=gain)
            else:  # 'kaiming' (validated above)
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
            if getattr(m, 'bias', None) is not None:
                nn.init.constant_(m.bias, 0.0)
        elif 'BatchNorm2d' in classname:
            # BatchNorm2d(affine=False) has weight/bias set to None; skip them
            # instead of crashing.
            if m.weight is not None:
                nn.init.normal_(m.weight, 1., gain)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.)

    net.apply(init_func)
    print(f"model initialized with {init} initialization")
    return net
def init_model(model, device, init='norm', gain=0.02):
    """Move ``model`` to ``device`` and initialise its weights via ``init_weights``.

    Args:
        model: the network to prepare.
        device: target device passed to ``model.to``.
        init: initialisation scheme forwarded to ``init_weights``
            (the default keeps the previous fixed behaviour).
        gain: scale forwarded to ``init_weights``.

    Returns:
        The moved and initialised model.
    """
    model = model.to(device)
    return init_weights(model, init=init, gain=gain)
class MainModel(nn.Module):
    """Conditional-GAN trainer pairing a generator with a discriminator.

    ``net_G`` maps the ``L`` input to a 2-channel ``ab`` prediction.
    ``net_D`` scores the channel-wise concatenation of ``L`` with either the
    real or the generated ``ab``. The generator loss is the adversarial term
    plus ``lambda_L1`` times the L1 distance to the real ``ab``. (``L``/``ab``
    presumably denote Lab-colour channels; confirm against the data loader.)

    Args:
        net_G: optional pre-built generator. When given, it is only moved to
            the device and its weights are left untouched. Otherwise a fresh
            ``Unet`` is created and initialised.
        lr_G, lr_D: Adam learning rates for generator / discriminator.
        beta1, beta2: Adam betas shared by both optimisers.
        lambda_L1: weight of the L1 reconstruction term in the G loss.
    """

    def __init__(self, net_G=None, lr_G=2e-4, lr_D=2e-4,
                 beta1=0.5, beta2=0.999, lambda_L1=100.):
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lambda_L1 = lambda_L1
        if net_G is None:
            self.net_G = init_model(Unet(input_c=1, output_c=2, n_down=8, num_filters=64), self.device)
        else:
            self.net_G = net_G.to(self.device)
        self.net_D = init_model(PatchDiscriminator(input_c=3, n_down=3, num_filters=64), self.device)
        self.GANcriterion = GANLoss(gan_mode='vanilla').to(self.device)
        self.L1criterion = nn.L1Loss()
        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=(beta1, beta2))
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=(beta1, beta2))

    def set_requires_grad(self, model, requires_grad=True):
        """Toggle ``requires_grad`` on every parameter of ``model``."""
        for p in model.parameters():
            p.requires_grad = requires_grad

    def setup_input(self, data):
        """Store ``data['L']`` and ``data['ab']`` on the model's device."""
        self.L = data['L'].to(self.device)
        self.ab = data['ab'].to(self.device)

    def forward(self):
        """Generate ``self.fake_color`` from the stored ``self.L``."""
        self.fake_color = self.net_G(self.L)

    def backward_D(self):
        """Compute and backprop the discriminator loss (mean of fake/real terms)."""
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        # detach(): the D update must not send gradients into the generator.
        fake_preds = self.net_D(fake_image.detach())
        self.loss_D_fake = self.GANcriterion(fake_preds, False)
        real_image = torch.cat([self.L, self.ab], dim=1)
        real_preds = self.net_D(real_image)
        self.loss_D_real = self.GANcriterion(real_preds, True)
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Compute and backprop the generator loss (adversarial + weighted L1)."""
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image)
        self.loss_G_GAN = self.GANcriterion(fake_preds, True)
        self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize(self):
        """Run one training step: update D, then update G."""
        # Switch G to training mode *before* generating. Otherwise, after an
        # earlier .eval() call (e.g. for validation), fake_color would be
        # produced with dropout disabled and BatchNorm on running statistics,
        # and the G update below would be driven by that eval-mode output.
        self.net_G.train()
        self.forward()

        self.net_D.train()
        self.set_requires_grad(self.net_D, True)
        self.opt_D.zero_grad()
        self.backward_D()
        self.opt_D.step()

        # Freeze D so the G loss does not accumulate gradients into D.
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()
class UNetAuto(nn.Module):
    """Encoder/decoder U-Net with max-pool downsampling and skip connections.

    Args:
        in_channels: channels of the input image.
        out_channels: channels produced by the final 1x1 convolution.
        features: channel width of each encoder level, shallowest first. The
            decoder mirrors it, and the bottleneck uses ``2 * features[-1]``.
            The default is a tuple instead of a list, avoiding a shared mutable
            default; any sequence is accepted.

    Inputs whose height/width are not divisible by ``2 ** len(features)`` are
    supported: each upsampled map is resized to its skip connection's size.
    """

    def __init__(self, in_channels=1, out_channels=2, features=(64, 128, 256, 512)):
        super().__init__()
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Encoder: one double-conv block per level.
        for feature in features:
            self.encoder.append(self._block(in_channels, feature))
            in_channels = feature
        # Decoder (deepest level first): a 2x transposed conv halving the
        # channels, then a double-conv over [upsampled, skip].
        for feature in reversed(features):
            self.decoder.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.decoder.append(self._block(feature * 2, feature))
        # Bottleneck between encoder and decoder, then the final 1x1 projection.
        self.bottleneck = self._block(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Map ``x`` of shape (N, in_channels, H, W) to (N, out_channels, H, W)."""
        skip_connections = []
        for layer in self.encoder:
            x = layer(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # decoder holds alternating (upsample, fuse) modules.
        for upsample, fuse, skip in zip(self.decoder[::2], self.decoder[1::2],
                                        reversed(skip_connections)):
            x = upsample(x)
            if x.shape[-2:] != skip.shape[-2:]:
                # MaxPool2d floors odd sizes, so the 2x up-conv can come back
                # one pixel short; resize so torch.cat does not fail.
                x = nn.functional.interpolate(x, size=skip.shape[-2:])
            x = torch.cat((x, skip), dim=1)
            x = fuse(x)
        return self.final_conv(x)

    def _block(self, in_channels, out_channels):
        """Two (3x3 conv -> BatchNorm -> ReLU) stages preserving spatial size."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
class Autoencoder(nn.Module):
    """Thin ``nn.Module`` wrapper whose forward pass delegates to ``model``.

    Args:
        model: the wrapped network; stored as the submodule ``model``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        """Return ``self.model(x)`` unchanged."""
        wrapped = self.model
        return wrapped(x)