import time

import cv2
import numpy as np
import onnx
import torch
import torchvision.transforms.functional as tf
from torchvision import transforms

from .engine import execute_onnx_model
from .model import PHNet
from .stylematte import StyleMatte
from .util import log  # note: inference_img is defined in this module below


class Inference:
    def __init__(self, **kwargs):
        self.rank = 0
        # Configuration (enc_sizes, skips, grid_counts, init_weights,
        # init_value, checkpoint, device, image_size) is injected via kwargs.
        self.__dict__.update(kwargs)
        self.model = PHNet(
            enc_sizes=self.enc_sizes,
            skips=self.skips,
            grid_count=self.grid_counts,
            init_weights=self.init_weights,
            init_value=self.init_value,
        )
        state = torch.load(self.checkpoint.harmonizer,
                           map_location=self.device)
        self.model.load_state_dict(state, strict=True)
        self.model.eval()

    def harmonize(self, composite, mask):
        # Ensure both tensors are 4-D (N, C, H, W); the original used a
        # single `if` for composite, which misses inputs more than one
        # dimension short, so both now use `while`.
        while len(composite.shape) < 4:
            composite = composite.unsqueeze(0)
        while len(mask.shape) < 4:
            mask = mask.unsqueeze(0)
        composite = tf.resize(composite, [self.image_size, self.image_size])
        mask = tf.resize(mask, [self.image_size, self.image_size])
        log(composite.shape, mask.shape)
        with torch.no_grad():
            harmonized = self.model(composite, mask)  # ['harmonized']
        # Blend: keep background pixels, replace only the masked region.
        result = harmonized * mask + composite * (1 - mask)
        return result

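# Hypothetical usage sketch. The config names below (enc_sizes, skips,
# grid_counts, checkpoint paths, image_size) are assumptions about the
# external config, not values taken from this repo:
#
#   cfg = dict(enc_sizes=[...], skips=[...], grid_counts=[...],
#              init_weights=True, init_value=0.5,
#              checkpoint=checkpoint_cfg,  # must expose a .harmonizer path
#              device="cpu", image_size=512)
#   harmonizer = Inference(**cfg)
#   out = harmonizer.harmonize(composite, mask)  # float tensors in [0, 1]
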
class Matting:
    def __init__(self, **kwargs):
        self.rank = 0
        # Configuration (onnx flag, checkpoint paths, device) comes in via
        # kwargs.
        self.__dict__.update(kwargs)
        if self.onnx:
            # ONNX path: load the serialized graph; execution happens later
            # in execute_onnx_model.
            self.model = onnx.load(self.checkpoint.matting_onnx)
        else:
            self.model = StyleMatte().to(self.device)
            state = torch.load(self.checkpoint.matting,
                               map_location=self.device)
            self.model.load_state_dict(state, strict=True)
            self.model.eval()

    def extract(self, inp):
        # Predict a single-channel alpha matte, then multiply it into the
        # image to obtain the extracted foreground.
        mask = inference_img(self.model, inp, self.device, self.onnx)
        inp_np = np.array(inp)
        fg = mask[:, :, None] * inp_np
        return [mask, fg]

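# Hypothetical usage sketch (file names and the SimpleNamespace config shape
# are assumptions, not taken from this repo):
#
#   from types import SimpleNamespace
#   ckpt = SimpleNamespace(matting="stylematte.pth",
#                          matting_onnx="stylematte.onnx")
#   matting = Matting(onnx=False, device="cpu", checkpoint=ckpt)
#   img = cv2.cvtColor(cv2.imread("photo.jpg"), cv2.COLOR_BGR2RGB)
#   alpha, fg = matting.extract(img)  # alpha: (H, W), fg: (H, W, 3)
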
def inference_img(model, img, device='cpu', onnx=True):
    beg = time.time()
    h, w, _ = img.shape
    # Pad the top/left edges so both dimensions are multiples of 8 (the
    # network downsamples by 8). The original computed `8 - dim % 8`, which
    # adds 8 spurious pixels when a dimension is already aligned; the
    # `% 8` below makes the padding zero in that case.
    if h % 8 != 0 or w % 8 != 0:
        pad_h = (8 - h % 8) % 8
        pad_w = (8 - w % 8) % 8
        img = cv2.copyMakeBorder(img, pad_h, 0, pad_w, 0, cv2.BORDER_REFLECT)
    tensor_img = torch.from_numpy(img).permute(2, 0, 1).to(device)
    input_t = tensor_img / 255.0
    # ImageNet normalization, matching the StyleMatte training setup.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    input_t = normalize(input_t)
    input_t = input_t.unsqueeze(0).float()
    end_p = time.time()
    if onnx:
        out = execute_onnx_model(input_t, model)
    else:
        with torch.no_grad():
            out = model(input_t).cpu().numpy()
    end = time.time()
    log(f"Inference time: {end - beg}, processing time: {end_p - beg}")
    # Crop the top/left padding back off and return the single-channel matte.
    result = out[0][:, -h:, -w:]
    return result[0]
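
# Worked example of the pad-and-crop round trip above (a sketch; the shapes
# are illustrative, not from the repo):
#
#   h, w = 517, 701             # 517 % 8 == 5, 701 % 8 == 5
#   pad_h = (8 - h % 8) % 8     # 3
#   pad_w = (8 - w % 8) % 8     # 3
#   # copyMakeBorder -> (520, 704, 3); the crop out[0][:, -517:, -701:]
#   # removes exactly the reflected top/left padding, restoring (517, 701).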