import math
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
def weights_init(init_type='gaussian', std=0.02):
    """Return an initializer closure for use with ``nn.Module.apply``."""
    def init_fun(m):
        classname = m.__class__.__name__
        if (classname.find('Conv') == 0 or classname.find(
                'Linear') == 0) and hasattr(m, 'weight'):
            if init_type == 'gaussian':
                nn.init.normal_(m.weight, 0.0, std)
            elif init_type == 'xavier':
                nn.init.xavier_normal_(m.weight, gain=math.sqrt(2))
            elif init_type == 'kaiming':
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                nn.init.orthogonal_(m.weight, gain=math.sqrt(2))
            elif init_type == 'default':
                pass  # keep PyTorch's default initialization
            else:
                raise NotImplementedError(
                    'Unsupported initialization: {}'.format(init_type))
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
    return init_fun
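
# Usage sketch (the two-layer net below is purely illustrative): since
# `weights_init` returns a closure, it plugs straight into ``nn.Module.apply``:
#
#     net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
#     net.apply(weights_init('kaiming'))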

def freeze(module):
    for param in module.parameters():
        param.requires_grad = False


def unfreeze(module):
    for param in module.parameters():
        param.requires_grad = True
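
# Usage sketch (hypothetical fine-tuning setup): freeze the encoder so only
# the decoder receives gradient updates, then thaw it afterwards. Frozen
# parameters are also skipped by `get_optimizer` below.
#
#     freeze(model.encoder)
#     # ... train the decoder only ...
#     unfreeze(model.encoder)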

def get_optimizer(opt, model):
    lr = float(opt['hyper_params']['lr'])
    beta1 = float(opt['model']['beta1'])
    weight_decay = float(opt['model']['weight_decay'])
    opt_name = opt['model']['optimizer']

    optim_params = []
    # apply weight decay to weights only; biases are exempt
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue  # skip frozen weights
        if key.endswith('bias'):
            optim_params += [{'params': value, 'weight_decay': 0.0}]
        else:
            optim_params += [{'params': value,
                              'weight_decay': weight_decay}]

    if opt_name == 'Adam':
        return optim.Adam(optim_params,
                          lr=lr,
                          betas=(beta1, 0.999),
                          eps=1e-5)

    err = '{} not implemented yet'.format(opt_name)
    logging.error(err)
    raise NotImplementedError(err)
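
# Example config (hypothetical values) in the shape `get_optimizer` expects:
#
#     opt = {'hyper_params': {'lr': 1e-4},
#            'model': {'beta1': 0.9, 'weight_decay': 4e-5, 'optimizer': 'Adam'}}
#     optimizer = get_optimizer(opt, model)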

def get_activation(activation):
    if activation is None:
        return nn.Identity()
    act_func = {
        'relu': nn.ReLU(),
        'sigmoid': nn.Sigmoid(),
        'tanh': nn.Tanh(),
        'prelu': nn.PReLU(),
        'leaky': nn.LeakyReLU(0.2),
        'gelu': nn.GELU(),
    }
    if activation not in act_func:
        err = 'activation {} is not implemented yet'.format(activation)
        logging.error(err)
        raise NotImplementedError(err)
    return act_func[activation]

def get_norm(out_channels, norm_type='Instance'):
    norm_set = ['Instance', 'Batch', 'Group']
    if norm_type not in norm_set:
        err = 'Normalization {} has not been implemented yet'.format(norm_type)
        logging.error(err)
        raise ValueError(err)
    if norm_type == 'Instance':
        return nn.InstanceNorm2d(out_channels, affine=True)
    if norm_type == 'Batch':
        return nn.BatchNorm2d(out_channels)
    # Group: use 32 groups when the channel count allows it, otherwise fall
    # back to a single group (normalizing over all channels jointly)
    groups = 32 if out_channels >= 32 else 1
    return nn.GroupNorm(groups, out_channels)
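
# For reference, the group fallback means, e.g.:
#     get_norm(64, 'Group')  -> nn.GroupNorm(32, 64)
#     get_norm(16, 'Group')  -> nn.GroupNorm(1, 16)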

def get_layer_info(out_channels, activation_func='relu'):
    # note: the norm type is hard-coded to GroupNorm here
    activation = get_activation(activation_func)
    norm_layer = get_norm(out_channels, 'Group')
    return norm_layer, activation

class Conv(nn.Module):
    """ (convolution => GroupNorm => activation), with an optional residual
    connection when input and output shapes match """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 bias=True,
                 activation='leaky',
                 resnet=True):
        super().__init__()
        norm_layer, act_func = get_layer_info(out_channels, activation)
        # the residual add is only shape-safe when the channel count and
        # spatial resolution are unchanged
        self.resnet = resnet and in_channels == out_channels and stride == 1
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, stride=stride,
                      kernel_size=kernel_size, padding=padding, bias=bias),
            norm_layer,
            act_func)

    def forward(self, x):
        res = self.conv(x)
        if self.resnet:
            res = res + x
        return res
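
# Shape sketch: with the default kernel_size=3 / padding=1,
#     Conv(32, 64, stride=2)  maps (N, 32, H, W) -> (N, 64, H/2, W/2)
#     Conv(64, 64)            maps (N, 64, H, W) -> (N, 64, H, W)  (+ residual)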

class Up(nn.Module):
    """ Upscaling (bilinear, x2) then conv """

    def __init__(self, in_channels, out_channels, activation='relu', resnet=True):
        super().__init__()
        self.up_layer = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.up = Conv(in_channels, out_channels, activation=activation, resnet=resnet)

    def forward(self, x):
        x = self.up_layer(x)
        return self.up(x)

class DConv(nn.Module):
    """ Double Conv layer """

    def __init__(self, in_channels, out_channels, activation='relu', resnet=True):
        super().__init__()
        self.conv1 = Conv(in_channels, out_channels, activation=activation, resnet=resnet)
        self.conv2 = Conv(out_channels, out_channels, activation=activation, resnet=resnet)

    def forward(self, x):
        return self.conv2(self.conv1(x))

class Encoder(nn.Module):
    def __init__(self, in_channels=3, mid_act='leaky', resnet=True):
        super(Encoder, self).__init__()
        # in_conv produces (32 - in_channels) maps so that concatenating the
        # input back on in forward() yields exactly 32 channels
        self.in_conv = Conv(in_channels, 32 - in_channels, stride=1, activation=mid_act, resnet=resnet)
        self.down_32_64 = Conv(32, 64, stride=2, activation=mid_act, resnet=resnet)
        self.down_64_64_1 = Conv(64, 64, activation=mid_act, resnet=resnet)
        self.down_64_128 = Conv(64, 128, stride=2, activation=mid_act, resnet=resnet)
        self.down_128_128_1 = Conv(128, 128, activation=mid_act, resnet=resnet)
        self.down_128_256 = Conv(128, 256, stride=2, activation=mid_act, resnet=resnet)
        self.down_256_256_1 = Conv(256, 256, activation=mid_act, resnet=resnet)
        self.down_256_512 = Conv(256, 512, stride=2, activation=mid_act, resnet=resnet)
        self.down_512_512_1 = Conv(512, 512, activation=mid_act, resnet=resnet)
        self.down_512_512_2 = Conv(512, 512, activation=mid_act, resnet=resnet)
        self.down_512_512_3 = Conv(512, 512, activation=mid_act, resnet=resnet)

    def forward(self, x):
        # shape comments assume a 256 x 256 input
        x1 = self.in_conv(x)
        x1 = torch.cat((x, x1), dim=1)   # 32 x 256 x 256
        x2 = self.down_32_64(x1)         # 64 x 128 x 128
        x3 = self.down_64_64_1(x2)       # 64 x 128 x 128
        x4 = self.down_64_128(x3)        # 128 x 64 x 64
        x5 = self.down_128_128_1(x4)     # 128 x 64 x 64
        x6 = self.down_128_256(x5)       # 256 x 32 x 32
        x7 = self.down_256_256_1(x6)     # 256 x 32 x 32
        x8 = self.down_256_512(x7)       # 512 x 16 x 16
        x9 = self.down_512_512_1(x8)     # 512 x 16 x 16
        x10 = self.down_512_512_2(x9)    # 512 x 16 x 16
        x11 = self.down_512_512_3(x10)   # 512 x 16 x 16
        return x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1

class Decoder(nn.Module):
    """ Up Stream Sequence """

    def __init__(self,
                 out_channels=3,
                 mid_act='relu',
                 out_act='sigmoid',
                 resnet=True):
        super(Decoder, self).__init__()
        input_channel = 512

        self.up_16_16_1 = Conv(input_channel, 256, activation=mid_act, resnet=resnet)
        self.up_16_16_2 = Conv(768, 512, activation=mid_act, resnet=resnet)
        self.up_16_16_3 = Conv(1024, 512, activation=mid_act, resnet=resnet)
        self.up_16_32 = Up(1024, 256, activation=mid_act, resnet=resnet)
        self.up_32_32_1 = Conv(512, 256, activation=mid_act, resnet=resnet)
        self.up_32_64 = Up(512, 128, activation=mid_act, resnet=resnet)
        self.up_64_64_1 = Conv(256, 128, activation=mid_act, resnet=resnet)
        self.up_64_128 = Up(256, 64, activation=mid_act, resnet=resnet)
        self.up_128_128_1 = Conv(128, 64, activation=mid_act, resnet=resnet)
        self.up_128_256 = Up(128, 32, activation=mid_act, resnet=resnet)
        self.out_conv = Conv(64, out_channels, activation=out_act)

    def forward(self, x, ibl):
        x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1 = x

        # broadcast the 512-d light code over the bottleneck resolution
        h, w = x10.shape[2:]
        y = ibl.view(-1, 512, 1, 1).repeat(1, 1, h, w)

        y = self.up_16_16_1(y)           # 256 x 16 x 16
        y = torch.cat((x10, y), dim=1)   # 768 x 16 x 16
        y = self.up_16_16_2(y)           # 512 x 16 x 16
        y = torch.cat((x9, y), dim=1)    # 1024 x 16 x 16
        y = self.up_16_16_3(y)           # 512 x 16 x 16
        y = torch.cat((x8, y), dim=1)    # 1024 x 16 x 16
        y = self.up_16_32(y)             # 256 x 32 x 32
        y = torch.cat((x7, y), dim=1)    # 512 x 32 x 32
        y = self.up_32_32_1(y)           # 256 x 32 x 32
        y = torch.cat((x6, y), dim=1)    # 512 x 32 x 32
        y = self.up_32_64(y)             # 128 x 64 x 64
        y = torch.cat((x5, y), dim=1)    # 256 x 64 x 64
        y = self.up_64_64_1(y)           # 128 x 64 x 64
        y = torch.cat((x4, y), dim=1)    # 256 x 64 x 64
        y = self.up_64_128(y)            # 64 x 128 x 128
        y = torch.cat((x3, y), dim=1)    # 128 x 128 x 128
        y = self.up_128_128_1(y)         # 64 x 128 x 128
        y = torch.cat((x2, y), dim=1)    # 128 x 128 x 128
        y = self.up_128_256(y)           # 32 x 256 x 256
        y = torch.cat((x1, y), dim=1)    # 64 x 256 x 256
        y = self.out_conv(y)             # out_channels x 256 x 256
        return y

class SSN_Model(nn.Module):
    """ Implementation of Relighting Net """

    def __init__(self,
                 in_channels=3,
                 out_channels=3,
                 mid_act='leaky',
                 out_act='sigmoid',
                 resnet=True):
        super(SSN_Model, self).__init__()
        self.out_act = out_act
        self.encoder = Encoder(in_channels, mid_act=mid_act, resnet=resnet)
        self.decoder = Decoder(out_channels, mid_act=mid_act, out_act=out_act, resnet=resnet)

        # initialize all conv/linear weights with a small gaussian
        init_func = weights_init('gaussian', std=1e-3)
        self.encoder.apply(init_func)
        self.decoder.apply(init_func)

    def forward(self, x, ibl):
        """
        Input: source image `x` and target light `ibl`
        Output: the predicted relit image
        """
        latent = self.encoder(x)
        pred = self.decoder(latent, ibl)
        if self.out_act == 'sigmoid':
            # rescale the [0, 1] sigmoid output to a [0, 30] intensity range
            pred = pred * 30.0
        return pred
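
# Note: Decoder.forward reshapes `ibl` with view(-1, 512, 1, 1), so any light
# representation carrying 512 values per sample works; the demo below uses a
# 32 x 16 map (32 * 16 = 512).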

if __name__ == '__main__':
    x = torch.randn(5, 1, 256, 256)
    ibl = torch.randn(5, 1, 32, 16)

    model = SSN_Model(1, 1)
    y = model(x, ibl)
    print('Output: ', y.shape)