File size: 1,549 Bytes
4c62147 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms.functional as tf
from .filter import Filter
from .backbone import EfficientBackbone
from .module import CascadeArgumentRegressor, FilterPerformer
class Harmonizer(nn.Module):
    """Predict filter arguments for an image/mask pair and apply the filters.

    The module wires together three project components: an
    ``EfficientBackbone`` feature extractor, a ``CascadeArgumentRegressor``
    that is run on the last backbone output, and a ``FilterPerformer`` built
    from ``filter_types`` that performs the actual restoration.
    """

    def __init__(self):
        super().__init__()

        # Target size handed to ``F.interpolate`` before the backbone runs.
        self.input_size = (256, 256)

        # The length of this list is passed to the regressor and is the
        # number of arguments ``restore_image`` accepts.
        self.filter_types = [
            Filter.TEMPERATURE,
            Filter.BRIGHTNESS,
            Filter.CONTRAST,
            Filter.SATURATION,
            Filter.HIGHLIGHT,
            Filter.SHADOW,
        ]

        # Sub-modules are created in a fixed order: it determines their
        # registration order (parameter / state-dict ordering).
        self.backbone = EfficientBackbone.from_name('efficientnet-b0')
        # NOTE(review): 1280 / 160 / 1 are forwarded as-is; 1280 presumably
        # matches the channel width of the last backbone output - confirm
        # against CascadeArgumentRegressor.
        self.regressor = CascadeArgumentRegressor(1280, 160, 1, len(self.filter_types))
        self.performer = FilterPerformer(self.filter_types)

    def predict_arguments(self, comp: torch.Tensor, mask: torch.Tensor):
        """Run the backbone and regressor on a resized image/mask pair.

        Both inputs are bilinearly resampled to ``self.input_size`` and then
        combined into two backbone inputs: ``comp`` concatenated with
        ``mask`` and ``comp`` concatenated with ``1 - mask`` (channel dim 1).

        Args:
            comp: 4-D tensor (bilinear interpolation with a 2-D size needs
                N, C, H, W input). Presumably the composite image - confirm
                with callers.
            mask: 4-D tensor with the same batch size as ``comp``. Presumably
                a foreground mask in [0, 1], since ``1 - mask`` is used as
                its complement - confirm with callers.

        Returns:
            Whatever ``self.regressor`` returns for the last of the five
            backbone outputs.
        """
        comp = F.interpolate(comp, size=self.input_size, mode='bilinear', align_corners=False)
        mask = F.interpolate(mask, size=self.input_size, mode='bilinear', align_corners=False)

        fg = torch.cat((comp, mask), dim=1)
        bg = torch.cat((comp, (1 - mask)), dim=1)

        # The backbone yields five outputs; only the last one is consumed.
        # The explicit 5-way unpack keeps the arity check of the original.
        _, _, _, _, enc32x = self.backbone(fg, bg)

        return self.regressor(enc32x)

    def restore_image(self, comp, mask, arguments):
        """Clamp the filter arguments and delegate to the filter performer.

        Args:
            comp: Forwarded unchanged to ``self.performer.restore``.
            mask: Forwarded unchanged to ``self.performer.restore``.
            arguments: Sized sequence holding one tensor per entry of
                ``self.filter_types``.

        Returns:
            The result of ``self.performer.restore``.

        Raises:
            AssertionError: If the number of arguments does not match the
                number of configured filters.
        """
        expected = len(self.filter_types)
        if len(arguments) != expected:
            # Raised explicitly (not via ``assert``) so the check is not
            # stripped under ``python -O``; the exception type is kept so
            # existing callers observe the same error class.
            raise AssertionError(
                'expected {} filter arguments, got {}'.format(expected, len(arguments))
            )

        # Limit every argument to [-1, 1] and reshape it to (-1, 1, 1, 1).
        arguments = [torch.clamp(arg, -1, 1).view(-1, 1, 1, 1) for arg in arguments]
        return self.performer.restore(comp, mask, arguments)
|