import warnings

import torch
import torchvision

from .stylematte import StyleMatte
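

# Added overview comment: StyleMatteEngine wraps the StyleMatte portrait-matting
# network. It loads a pretrained checkpoint, preprocesses the input (resize so
# the long side is at most 1024, reflect-pad to a multiple of 8, ImageNet
# normalization), and returns an alpha matte and/or composited images.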
class StyleMatteEngine(torch.nn.Module):

    def __init__(self, device='cpu', human_matting_path='./pretrain_model/matting/stylematte_synth.pt'):
        super().__init__()
        self._device = device
        # remember the checkpoint path so the model can be (re)loaded lazily in forward()
        self._ckpt_path = human_matting_path
        # ImageNet statistics expected by the StyleMatte backbone
        self.normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self._init_models(self._ckpt_path)

    def _init_models(self, _ckpt_path):
        # load checkpoint on CPU first, then move to the target device
        state_dict = torch.load(_ckpt_path, map_location='cpu')
        # build model
        model = StyleMatte()
        model.load_state_dict(state_dict)
        self.model = model.to(self._device).eval()

    @torch.no_grad()
    def forward(self, input_image, return_type='matting', background_rgb=1.0):
        # rebuild the model lazily if it has not been initialized yet
        if not hasattr(self, 'model'):
            self._init_models(self._ckpt_path)
        if input_image.max() > 2.0:
            warnings.warn('Input image should be normalized to [0, 1].')
        _, ori_h, ori_w = input_image.shape
        input_image = input_image.to(self._device).float()
        image = input_image.clone()
        # resize so the longer side is at most 1024
        if max(ori_h, ori_w) > 1024:
            scale = 1024.0 / max(ori_h, ori_w)
            resized_h, resized_w = int(ori_h * scale), int(ori_w * scale)
            image = torchvision.transforms.functional.resize(image, (resized_h, resized_w), antialias=True)
        else:
            resized_h, resized_w = ori_h, ori_w
        # reflect-pad the left/top so both sides become multiples of 8
        if resized_h % 8 != 0 or resized_w % 8 != 0:
            image = torchvision.transforms.functional.pad(
                image, ((8 - resized_w % 8) % 8, (8 - resized_h % 8) % 8, 0, 0), padding_mode='reflect')
        # normalize and run the network
        image = self.normalize(image)[None]
        predict = self.model(image)[0]
        # undo padding: crop from the bottom-right back to the resized shape
        predict = predict[:, -resized_h:, -resized_w:]
        # undo resize
        if resized_h != ori_h or resized_w != ori_w:
            predict = torchvision.transforms.functional.resize(predict, (ori_h, ori_w), antialias=True)
        if return_type == 'alpha':
            return predict[0]
        elif return_type == 'matting':
            # standard compositing: out = alpha * foreground + (1 - alpha) * background
            predict = predict.expand(3, -1, -1)
            matting_image = input_image.clone()
            background_rgb = matting_image.new_ones(matting_image.shape) * background_rgb
            matting_image = matting_image * predict + (1 - predict) * background_rgb
            return matting_image, predict[0]
        elif return_type == 'all':
            # composite both ways: foreground over a flat background, and vice versa
            predict = predict.expand(3, -1, -1)
            background_rgb = input_image.new_ones(input_image.shape) * background_rgb
            foreground_image = input_image * predict + (1 - predict) * background_rgb
            background_image = input_image * (1 - predict) + predict * background_rgb
            return foreground_image, background_image
        else:
            raise NotImplementedError(f'Unknown return_type: {return_type}')
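

# A minimal usage sketch (an addition, not part of the original file). It
# assumes the default checkpoint exists on disk and that StyleMatte outputs a
# single-channel alpha matte; the random tensor merely stands in for a real
# CHW float image in [0, 1]. Because the module uses a relative import, run it
# with `python -m <your_package>.engine` (package name is hypothetical).
if __name__ == '__main__':
    engine = StyleMatteEngine(device='cuda' if torch.cuda.is_available() else 'cpu')
    dummy = torch.rand(3, 512, 384)  # placeholder input; replace with a real photo tensor
    matted, alpha = engine(dummy, return_type='matting', background_rgb=1.0)
    print(matted.shape, alpha.shape)  # expected: (3, 512, 384) and (512, 384)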