# Env_mixer/src/model/enhancer.py
# Uploaded by Inmental via huggingface_hub (commit 4c62147, verified).
import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms.functional as tf
from .filter import Filter
from .backbone import EfficientBackboneCommon
from .module import CascadeArgumentRegressor, FilterPerformer
class Enhancer(nn.Module):
    """Predict per-filter arguments for an image and apply the filters.

    An ``efficientnet-b0`` backbone runs on a resized copy of the input.
    ``CascadeArgumentRegressor`` regresses arguments from the deepest feature
    map, and the image adjustment itself is delegated to ``FilterPerformer``.
    """

    def __init__(self):
        super().__init__()
        # Fixed resolution the backbone sees in predict_arguments.
        # NOTE(review): the regressor may depend on the resulting feature-map
        # size; confirm before making this configurable.
        self.input_size = (256, 256)
        # One regressed argument per entry. The performer presumably pairs
        # arguments with filters by position, so keep this order stable.
        self.filter_types = [
            Filter.BRIGHTNESS,
            Filter.CONTRAST,
            Filter.SATURATION,
            Filter.HIGHLIGHT,
            Filter.SHADOW,
        ]
        self.backbone = EfficientBackboneCommon.from_name('efficientnet-b0')
        # NOTE(review): 1280 presumably matches the channel count of the
        # deepest efficientnet-b0 feature map; the meaning of 160 and 1 is
        # defined by CascadeArgumentRegressor. Confirm against its signature.
        self.regressor = CascadeArgumentRegressor(1280, 160, 1, len(self.filter_types))
        self.performer = FilterPerformer(self.filter_types)

    def predict_arguments(self, x, mask):
        """Regress filter arguments for an image batch.

        Args:
            x: image tensor. It must be 4-D because it is resized with 2-D
                bilinear interpolation; assumed (B, C, H, W), so confirm the
                channel count with the backbone.
            mask: unused here. Presumably accepted so the signature mirrors
                ``restore_image``; verify with callers before removing.

        Returns:
            Whatever ``self.regressor`` produces from the deepest backbone
            feature map. ``restore_image`` expects it to contain
            ``len(self.filter_types)`` entries.
        """
        x = F.interpolate(x, size=self.input_size, mode='bilinear', align_corners=False)
        # The backbone yields several feature maps; only the deepest one
        # (last in the returned sequence) feeds the regressor.
        *_, enc32x = self.backbone(x)
        return self.regressor(enc32x)

    def restore_image(self, x, mask, arguments):
        """Apply ``self.filter_types`` to ``x`` with the given arguments.

        Args:
            x: image batch, forwarded unchanged to ``FilterPerformer.restore``.
            mask: forwarded unchanged to ``FilterPerformer.restore``.
            arguments: sequence with exactly one entry per filter. Each entry
                is clamped to [-1, 1] and reshaped to (N, 1, 1, 1), presumably
                so that one value per sample broadcasts over the image
                (confirm with FilterPerformer).

        Returns:
            The result of ``self.performer.restore``.

        Raises:
            ValueError: if ``len(arguments)`` differs from the number of filters.
        """
        if len(arguments) != len(self.filter_types):
            # Raise explicitly instead of asserting so that the check also
            # runs under ``python -O``.
            raise ValueError(
                f'expected {len(self.filter_types)} filter arguments, got {len(arguments)}'
            )
        arguments = [torch.clamp(arg, -1, 1).view(-1, 1, 1, 1) for arg in arguments]
        return self.performer.restore(x, mask, arguments)