| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
| import torchvision.transforms.functional as tf | |
| from .filter import Filter | |
| from .backbone import EfficientBackbone | |
| from .module import CascadeArgumentRegressor, FilterPerformer | |
class Harmonizer(nn.Module):
    """Predict and apply global filter arguments to a composite image.

    ``predict_arguments`` resizes the composite and its mask to
    ``input_size``, builds a mask-concatenated view and an
    inverted-mask-concatenated view, runs both through an
    EfficientNet-b0 backbone, and regresses one argument per entry of
    ``filter_types`` from the last backbone output. ``restore_image`` /
    ``adjust_image`` clamp those arguments and hand them to a
    ``FilterPerformer``.
    """

    def __init__(self) -> None:
        super().__init__()

        # Resolution at which arguments are predicted; both inputs are
        # bilinearly resized to it in `predict_arguments`.
        self.input_size = (256, 256)

        # Filters handled by the performer; the i-th argument drives the
        # i-th filter.
        self.filter_types = [
            Filter.TEMPERATURE,
            Filter.BRIGHTNESS,
            Filter.CONTRAST,
            Filter.SATURATION,
            Filter.HIGHLIGHT,
            Filter.SHADOW,
        ]

        # Per-filter scale applied by `adjust_image` after clamping to
        # [-1, 1]; parallel to `filter_types`.
        self.filter_argument_ranges = [
            0.3,
            0.5,
            0.5,
            0.6,
            0.4,
            0.4,
        ]

        self.backbone = EfficientBackbone.from_name('efficientnet-b0')
        self.regressor = CascadeArgumentRegressor(1280, 160, 1, len(self.filter_types))
        self.performer = FilterPerformer(self.filter_types)

        # Initialize every conv / norm layer, then swap in the pretrained
        # backbone so only the non-backbone layers keep this initialization.
        # Building a throwaway backbone first keeps `backbone` in the first
        # submodule slot (re-assigning an existing attribute keeps its slot)
        # and keeps the RNG draws of this loop unchanged -- do not reorder.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                self._init_conv(m)
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
                self._init_norm(m)

        self.backbone = EfficientBackbone.from_pretrained('efficientnet-b0')

    def forward(self, comp: torch.Tensor, mask: torch.Tensor):
        """Predict filter arguments for ``comp`` and apply them.

        Args:
            comp: composite image; assumed (B, C, H, W) -- TODO confirm C.
            mask: foreground mask, concatenated with ``comp`` along dim 1;
                presumably (B, 1, H, W) with values in [0, 1] -- confirm.

        Returns:
            The result of ``restore_image`` for the predicted arguments.
        """
        arguments = self.predict_arguments(comp, mask)
        return self.restore_image(comp, mask, arguments)

    def predict_arguments(self, comp: torch.Tensor, mask: torch.Tensor):
        """Regress the filter arguments from resized copies of the inputs.

        Args:
            comp: composite image (see ``forward``).
            mask: foreground mask (see ``forward``).

        Returns:
            Whatever the regressor returns for the last backbone output;
            ``restore_image`` / ``adjust_image`` expect one entry per filter.
        """
        comp = F.interpolate(comp, self.input_size, mode='bilinear', align_corners=False)
        mask = F.interpolate(mask, self.input_size, mode='bilinear', align_corners=False)

        # Two views of the same image: with the mask and with its inverse.
        fg = torch.cat((comp, mask), dim=1)
        bg = torch.cat((comp, 1 - mask), dim=1)

        # Only the last backbone output (named enc32x upstream) is regressed.
        *_, enc32x = self.backbone(fg, bg)
        return self.regressor(enc32x)

    def restore_image(self, comp, mask, arguments):
        """Apply ``arguments`` through ``FilterPerformer.restore``.

        Each argument is clamped to [-1, 1] -- without the per-filter scaling
        used by ``adjust_image`` -- and reshaped to (-1, 1, 1, 1), presumably
        so it broadcasts over (B, C, H, W) images.

        Raises:
            ValueError: if ``len(arguments) != len(self.filter_types)``.
        """
        self._check_arguments(arguments)
        arguments = [torch.clamp(arg, -1, 1).reshape(-1, 1, 1, 1) for arg in arguments]
        return self.performer.restore(comp, mask, arguments)

    def adjust_image(self, image, mask, arguments):
        """Apply ``arguments`` through ``FilterPerformer.adjust``.

        Each argument is clamped to [-1, 1], multiplied by the matching entry
        of ``filter_argument_ranges`` and reshaped to (-1, 1, 1, 1).

        Raises:
            ValueError: if ``len(arguments) != len(self.filter_types)``.
        """
        self._check_arguments(arguments)
        arguments = [
            (torch.clamp(arg, -1, 1) * scale).reshape(-1, 1, 1, 1)
            for arg, scale in zip(arguments, self.filter_argument_ranges)
        ]
        return self.performer.adjust(image, mask, arguments)

    def _check_arguments(self, arguments) -> None:
        """Raise ValueError unless there is exactly one argument per filter."""
        # An explicit raise (not `assert`) so the check survives `python -O`;
        # otherwise `zip` in `adjust_image` would silently drop filters.
        if len(arguments) != len(self.filter_types):
            raise ValueError(
                f'expected {len(self.filter_types)} filter arguments, '
                f'got {len(arguments)}'
            )

    def _init_conv(self, conv: nn.Conv2d) -> None:
        """Kaiming-uniform (ReLU, fan_in) weights and zero bias, if any."""
        nn.init.kaiming_uniform_(
            conv.weight, a=0, mode='fan_in', nonlinearity='relu')
        if conv.bias is not None:
            nn.init.constant_(conv.bias, 0)

    def _init_norm(self, bn: nn.Module) -> None:
        """Reset affine norm parameters to identity (weight 1, bias 0).

        Non-affine layers (e.g. the default ``nn.InstanceNorm2d``) have no
        parameters and are left untouched.
        """
        if bn.weight is not None:
            nn.init.constant_(bn.weight, 1)
        if bn.bias is not None:
            nn.init.constant_(bn.bias, 0)