import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms.functional as tf

from .filter import Filter
from .backbone import EfficientBackboneCommon
from .module import CascadeArgumentRegressor, FilterPerformer


class Enhancer(nn.Module):
    """Predict one argument per image filter and apply them via a performer.

    The module wires together three project components: a backbone that
    produces multi-scale features, a regressor that maps the deepest feature
    map to filter arguments, and a ``FilterPerformer`` that consumes those
    arguments together with an image and a mask.
    """

    def __init__(self):
        super().__init__()

        # Spatial size every input is resized to before the backbone runs.
        self.input_size = (256, 256)

        # The order here defines the order of the predicted arguments; the
        # same list is handed to the performer below.
        self.filter_types = [
            Filter.BRIGHTNESS,
            Filter.CONTRAST,
            Filter.SATURATION,
            Filter.HIGHLIGHT,
            Filter.SHADOW,
        ]

        self.backbone = EfficientBackboneCommon.from_name('efficientnet-b0')
        # NOTE(review): 1280 presumably matches the channel count of the
        # backbone's last feature map, and 160 / 1 are regressor-internal
        # sizes -- confirm against CascadeArgumentRegressor.
        self.regressor = CascadeArgumentRegressor(
            1280, 160, 1, len(self.filter_types)
        )
        self.performer = FilterPerformer(self.filter_types)

    def predict_arguments(self, x, mask):
        """Regress the filter arguments for an input batch.

        Args:
            x: Image tensor. It is resized to ``self.input_size`` with
                bilinear interpolation, so it must be a 4-D tensor with two
                trailing spatial dimensions.
            mask: Accepted for interface symmetry with ``restore_image``;
                it is not used by this method.

        Returns:
            Whatever ``self.regressor`` returns for the deepest backbone
            feature map (presumably one entry per filter type -- see the
            length check in ``restore_image``).
        """
        resized = F.interpolate(
            x, self.input_size, mode='bilinear', align_corners=False
        )
        # The backbone yields five feature maps; only the last one feeds the
        # regressor. Unpacking into five names keeps the arity check.
        _, _, _, _, enc32x = self.backbone(resized)
        return self.regressor(enc32x)

    def restore_image(self, x, mask, arguments):
        """Apply the given filter arguments to ``x`` through the performer.

        Args:
            x: Image input, forwarded unchanged to ``self.performer.restore``.
            mask: Mask input, forwarded unchanged to
                ``self.performer.restore``.
            arguments: Sized iterable of tensors, one per entry of
                ``self.filter_types`` and in the same order.

        Returns:
            The result of ``self.performer.restore(x, mask, arguments)``.

        Raises:
            AssertionError: If the number of arguments differs from the
                number of filter types.
        """
        assert len(arguments) == len(self.filter_types), (
            f"expected {len(self.filter_types)} filter arguments, "
            f"got {len(arguments)}"
        )
        # Limit every argument to [-1, 1] and reshape it to (-1, 1, 1, 1) so
        # it can broadcast against a 4-D tensor inside the performer.
        arguments = [
            torch.clamp(arg, -1, 1).view(-1, 1, 1, 1) for arg in arguments
        ]
        return self.performer.restore(x, mask, arguments)