| """ |
| Source url: https://github.com/OPHoperHPO/image-background-remove-tool |
| Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. |
| License: Apache License 2.0 |
| """ |
| import pathlib |
| from typing import Union, List, Tuple |
|
|
| import PIL |
| import cv2 |
| import numpy as np |
| import torch |
| from PIL import Image |
|
|
| from carvekit.ml.arch.fba_matting.models import FBA |
| from carvekit.ml.arch.fba_matting.transforms import ( |
| trimap_transform, |
| groupnorm_normalise_image, |
| ) |
| from carvekit.ml.files.models_loc import fba_pretrained |
| from carvekit.utils.image_utils import convert_image, load_image |
| from carvekit.utils.models_utils import get_precision_autocast, cast_network |
| from carvekit.utils.pool_utils import batch_generator, thread_pool_processing |
|
|
| __all__ = ["FBAMatting"] |
|
|
|
|
| class FBAMatting(FBA): |
| """ |
| FBA Matting Neural Network to improve edges on image. |
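
    Example (a minimal sketch; the file paths are hypothetical and the trimap is
    a grayscale map marking known background (0), known foreground (255) and the
    unknown region to refine):

        matting = FBAMatting(device="cpu", batch_size=1)
        # hypothetical input files
        mask = matting(["./image.png"], ["./trimap.png"])[0]
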
| """ |
|
|
    def __init__(
        self,
        device="cpu",
        input_tensor_size: Union[List[int], int] = 2048,
        batch_size: int = 2,
        encoder="resnet50_GN_WS",
        load_pretrained: bool = True,
        fp16: bool = False,
    ):
        """
        Initialize the FBAMatting model.

        Args:
            device: processing device
            input_tensor_size: input image size, either a single side length or a [width, height] pair
            batch_size: number of images processed by the neural network in one run
            encoder: neural network encoder head
            load_pretrained: whether to load pretrained weights
            fp16: use half precision

        """
        super(FBAMatting, self).__init__(encoder=encoder)
        self.fp16 = fp16
        self.device = device
        self.batch_size = batch_size
        if isinstance(input_tensor_size, list):
            self.input_image_size = input_tensor_size[:2]
        else:
            self.input_image_size = (input_tensor_size, input_tensor_size)
        self.to(device)
        if load_pretrained:
            self.load_state_dict(torch.load(fba_pretrained(), map_location=self.device))
        self.eval()

    def data_preprocessing(
        self, data: Union[PIL.Image.Image, np.ndarray]
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """
        Transform an input image or trimap into the tensors expected by the neural network.

        Args:
            data: input image (RGB) or trimap (grayscale "L" mode)

        Returns:
            A tuple with the image or trimap tensor and the second network input:
            the group-norm normalised image for RGB inputs, or the trimap_transform
            encoding for grayscale trimaps.

        """
        resized = data.copy()
        if self.batch_size == 1:
            # Keep the aspect ratio when a single image is processed per run
            # (resample=3 is bicubic).
            resized.thumbnail(self.input_image_size, resample=3)
        else:
            # Resize to a fixed size so that several images can be stacked into one batch.
            resized = resized.resize(self.input_image_size, resample=3)

        image = np.array(resized, dtype=np.float64)
        image = image / 255.0  # scale pixel values to [0, 1]
        if resized.mode == "RGB":
            image = image[:, :, ::-1]  # RGB -> BGR channel order
        elif resized.mode == "L":
            # Encode the grayscale trimap as two channels:
            # channel 0 marks known background, channel 1 marks known foreground,
            # and the unknown region stays zero in both channels.
            image2 = np.copy(image)
            h, w = image2.shape
            image = np.zeros((h, w, 2))
            image[image2 == 1, 1] = 1
            image[image2 == 0, 0] = 1
        else:
            raise ValueError("Incorrect color mode for image")
        # Rescale so that both spatial dimensions are multiples of 8 (e.g. 1365 -> 1368).
        h, w = image.shape[:2]
        h1 = int(np.ceil(1.0 * h / 8) * 8)
        w1 = int(np.ceil(1.0 * w / 8) * 8)
        x_scale = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_LANCZOS4)
        image_tensor = torch.from_numpy(x_scale).permute(2, 0, 1)[None, :, :, :].float()
        if resized.mode == "RGB":
            # (1, 3, H, W) image tensor and its group-norm normalised copy.
            return image_tensor, groupnorm_normalise_image(
                image_tensor.clone(), format="nchw"
            )
        else:
            # (1, 2, H, W) trimap tensor and its transformed encoding.
            return (
                image_tensor,
                torch.from_numpy(trimap_transform(x_scale))
                .permute(2, 0, 1)[None, :, :, :]
                .float(),
            )

    @staticmethod
    def data_postprocessing(
        data: torch.Tensor, trimap: PIL.Image.Image
    ) -> PIL.Image.Image:
        """
        Transforms the neural network output into a segmentation mask suitable
        for the other components of this framework.

        Args:
            data: output data from the neural network
            trimap: map of the area that needs to be refined

        Returns:
            Segmentation mask as a PIL Image instance

        """
        if trimap.mode != "L":
            raise ValueError("Incorrect color mode for trimap")
        pred = data.numpy().transpose((1, 2, 0))
        # Resize the prediction back to the trimap resolution and keep only
        # the first channel, which holds the predicted alpha matte.
        pred = cv2.resize(pred, trimap.size, interpolation=cv2.INTER_LANCZOS4)[:, :, 0]

        # Remove false predictions outside the trimap and restore the already known areas.
        trimap_arr = np.array(trimap.copy())
        pred[trimap_arr[:, :] == 0] = 0
        pred[trimap_arr[:, :] == 255] = 1
        # Suppress low-confidence alpha values.
        pred[pred < 0.3] = 0
        return Image.fromarray(pred * 255).convert("L")

    def __call__(
        self,
        images: List[Union[str, pathlib.Path, PIL.Image.Image]],
        trimaps: List[Union[str, pathlib.Path, PIL.Image.Image]],
    ) -> List[PIL.Image.Image]:
        """
        Passes input images through the neural network and returns segmentation masks
        as PIL.Image.Image instances.

        Args:
            images: input images
            trimaps: maps of the areas that need to be refined, one per image

        Returns:
            Segmentation masks for the input images, as PIL.Image.Image instances

        """
        if len(images) != len(trimaps):
            raise ValueError(
                "Lists of images and trimaps must have equal length!"
            )

        collect_masks = []
        autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
        with autocast:
            cast_network(self, dtype)
            for idx_batch in batch_generator(range(len(images)), self.batch_size):
                # Load the current batch of images and trimaps in parallel.
                inpt_images = thread_pool_processing(
                    lambda x: convert_image(load_image(images[x])), idx_batch
                )
                inpt_trimaps = thread_pool_processing(
                    lambda x: convert_image(load_image(trimaps[x]), mode="L"), idx_batch
                )

                # Preprocess each image and trimap into (tensor, second network input) pairs.
                inpt_img_batches = thread_pool_processing(
                    self.data_preprocessing, inpt_images
                )
                inpt_trimaps_batches = thread_pool_processing(
                    self.data_preprocessing, inpt_trimaps
                )

                # Stack the per-image tensors into batch tensors.
                inpt_img_batches_transformed = torch.vstack(
                    [i[1] for i in inpt_img_batches]
                )
                inpt_img_batches = torch.vstack([i[0] for i in inpt_img_batches])

                inpt_trimaps_transformed = torch.vstack(
                    [i[1] for i in inpt_trimaps_batches]
                )
                inpt_trimaps_batches = torch.vstack(
                    [i[0] for i in inpt_trimaps_batches]
                )

                with torch.no_grad():
                    # Move all inputs to the processing device and run inference.
                    inpt_img_batches = inpt_img_batches.to(self.device)
                    inpt_trimaps_batches = inpt_trimaps_batches.to(self.device)
                    inpt_img_batches_transformed = inpt_img_batches_transformed.to(
                        self.device
                    )
                    inpt_trimaps_transformed = inpt_trimaps_transformed.to(self.device)

                    output = super(FBAMatting, self).__call__(
                        inpt_img_batches,
                        inpt_trimaps_batches,
                        inpt_img_batches_transformed,
                        inpt_trimaps_transformed,
                    )
                    # Move the result to the CPU and free the device tensors
                    # before postprocessing.
                    output_cpu = output.cpu()
                    del (
                        inpt_img_batches,
                        inpt_trimaps_batches,
                        inpt_img_batches_transformed,
                        inpt_trimaps_transformed,
                        output,
                    )
                masks = thread_pool_processing(
                    lambda x: self.data_postprocessing(output_cpu[x], inpt_trimaps[x]),
                    range(len(inpt_images)),
                )
                collect_masks += masks
        return collect_masks
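

# The usage sketch below is illustrative only: the image and trimap paths are
# hypothetical, and the trimap is assumed to be a grayscale file in which 0 marks
# known background, 255 marks known foreground, and intermediate values mark the
# region to refine. It constructs the network, refines one mask, and applies it
# as an alpha channel.
if __name__ == "__main__":
    # A minimal sketch, assuming "./image.png" and "./trimap.png" exist.
    matting = FBAMatting(device="cpu", batch_size=1, input_tensor_size=2048)
    mask = matting(["./image.png"], ["./trimap.png"])[0]
    # Apply the refined mask as an alpha channel to cut out the foreground.
    image = Image.open("./image.png").convert("RGB")
    cutout = image.copy()
    cutout.putalpha(mask.resize(image.size))
    cutout.save("./cutout.png")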