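"""
Harmonizer: predicts white-box filter arguments (temperature, brightness,
contrast, saturation, highlight, shadow) from a composite image and its
foreground mask, then applies the corresponding filters to the composite.
"""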
import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms.functional as tf

from .filter import Filter
from .backbone import EfficientBackbone
from .module import CascadeArgumentRegressor, FilterPerformer


class Harmonizer(nn.Module):

    def __init__(self):
        super(Harmonizer, self).__init__()

        # filter arguments are always predicted from this low-resolution view
        self.input_size = (256, 256)

        # the filters applied by the model, in order
        self.filter_types = [
            Filter.TEMPERATURE,
            Filter.BRIGHTNESS,
            Filter.CONTRAST,
            Filter.SATURATION,
            Filter.HIGHLIGHT,
            Filter.SHADOW,
        ]
        # per-filter scales used by adjust_image to map the clamped
        # arguments from [-1, 1] into each filter's working range
        self.filter_argument_ranges = [
            0.3,  # temperature
            0.5,  # brightness
            0.5,  # contrast
            0.6,  # saturation
            0.4,  # highlight
            0.4,  # shadow
        ]

        self.backbone = EfficientBackbone.from_name('efficientnet-b0')
        # regress one argument per filter from the 1280-channel encoder output
        self.regressor = CascadeArgumentRegressor(1280, 160, 1, len(self.filter_types))
        self.performer = FilterPerformer(self.filter_types)

        # randomly initialize every conv / norm layer first, ...
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                self._init_conv(m)
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
                self._init_norm(m)

        # ... then swap in the pretrained backbone so the random
        # initialization above does not clobber its pretrained weights
        self.backbone = EfficientBackbone.from_pretrained('efficientnet-b0')

    def forward(self, comp, mask):
        arguments = self.predict_arguments(comp, mask)
        pred = self.restore_image(comp, mask, arguments)
        return pred

    def predict_arguments(self, comp, mask):
        # the regressor always operates on the fixed low-resolution view
        comp = F.interpolate(comp, self.input_size, mode='bilinear', align_corners=False)
        mask = F.interpolate(mask, self.input_size, mode='bilinear', align_corners=False)

        # 4-channel foreground / background views of the composite
        fg = torch.cat((comp, mask), dim=1)
        bg = torch.cat((comp, (1 - mask)), dim=1)

        enc2x, enc4x, enc8x, enc16x, enc32x = self.backbone(fg, bg)
        arguments = self.regressor(enc32x)
        return arguments

    def restore_image(self, comp, mask, arguments):
        assert len(arguments) == len(self.filter_types)
        # clamp the raw predictions to [-1, 1] and broadcast them per sample
        arguments = [torch.clamp(arg, -1, 1).view(-1, 1, 1, 1) for arg in arguments]
        return self.performer.restore(comp, mask, arguments)

    def adjust_image(self, image, mask, arguments):
        assert len(arguments) == len(self.filter_types)
        # unlike restore_image, scale each clamped argument by its filter's range
        arguments = [
            (torch.clamp(arg, -1, 1) * r).view(-1, 1, 1, 1)
            for arg, r in zip(arguments, self.filter_argument_ranges)
        ]
        return self.performer.adjust(image, mask, arguments)

    def _init_conv(self, conv):
        nn.init.kaiming_uniform_(
            conv.weight, a=0, mode='fan_in', nonlinearity='relu')
        if conv.bias is not None:
            nn.init.constant_(conv.bias, 0)

    def _init_norm(self, norm):
        # covers both BatchNorm2d and InstanceNorm2d layers
        if norm.weight is not None:
            nn.init.constant_(norm.weight, 1)
            nn.init.constant_(norm.bias, 0)
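

# ----------------------------------------------------------------------------
# Minimal usage sketch, assuming `comp` is a float composite image of shape
# (N, 3, H, W) and `mask` a single-channel foreground mask of shape
# (N, 1, H, W); the exact value ranges and preprocessing are defined by the
# surrounding repo, and the tensors below are stand-ins for illustration only.
# Because of the relative imports above, run this as part of the package
# (e.g. `python -m <package>.harmonizer`).
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    model = Harmonizer().eval()

    comp = torch.rand(1, 3, 512, 512)                   # stand-in composite
    mask = (torch.rand(1, 1, 512, 512) > 0.5).float()   # stand-in mask

    with torch.no_grad():
        # output structure (tensor or per-filter list) is determined
        # by FilterPerformer.restore
        pred = model(comp, mask)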