| import torch | |
| from torch import Tensor | |
| from torch import nn | |
| from torch.nn import functional as F | |
| from typing import Optional, List | |
| from .mobilenetv3 import MobileNetV3LargeEncoder | |
| from .resnet import ResNet50Encoder | |
| from .lraspp import LRASPP | |
| from .decoder import RecurrentDecoder, Projection | |
| from .fast_guided_filter import FastGuidedFilterRefiner | |
| from .deep_guided_filter import DeepGuidedFilterRefiner | |
class MattingNetwork(nn.Module):
    """Matting network: backbone encoder + LRASPP + recurrent decoder + refiner.

    The backbone/LRASPP/decoder widths are selected by ``variant`` and the
    refinement module by ``refiner``. ``forward`` runs either the matting pass
    (foreground + alpha) or a segmentation pass (``segmentation_pass=True``).
    """

    def __init__(self,
                 variant: str = 'mobilenetv3',
                 refiner: str = 'deep_guided_filter',
                 pretrained_backbone: bool = False):
        """Build the network.

        Args:
            variant: ``'mobilenetv3'`` or ``'resnet50'``. Selects the backbone
                encoder and the matching LRASPP / decoder channel widths.
            refiner: ``'deep_guided_filter'`` or ``'fast_guided_filter'``.
                The refiner is only invoked by ``forward`` when
                ``downsample_ratio != 1``.
            pretrained_backbone: Passed straight to the backbone constructor.

        Raises:
            ValueError: If ``variant`` or ``refiner`` is not a supported name.
        """
        super().__init__()
        # Explicit checks rather than ``assert``: asserts are removed under
        # ``python -O``, which would silently fall through to the ResNet50 /
        # fast-guided-filter branches below for a misspelled option.
        if variant not in ('mobilenetv3', 'resnet50'):
            raise ValueError(
                f"variant must be 'mobilenetv3' or 'resnet50', got {variant!r}")
        if refiner not in ('fast_guided_filter', 'deep_guided_filter'):
            raise ValueError(
                "refiner must be 'fast_guided_filter' or 'deep_guided_filter', "
                f"got {refiner!r}")

        if variant == 'mobilenetv3':
            self.backbone = MobileNetV3LargeEncoder(pretrained_backbone)
            self.aspp = LRASPP(960, 128)
            self.decoder = RecurrentDecoder([16, 24, 40, 128], [80, 40, 32, 16])
        else:
            self.backbone = ResNet50Encoder(pretrained_backbone)
            self.aspp = LRASPP(2048, 256)
            self.decoder = RecurrentDecoder([64, 256, 512, 256], [128, 64, 32, 16])

        # Matting head: 4 output channels, split in forward() into a
        # 3-channel foreground residual and a 1-channel alpha.
        self.project_mat = Projection(16, 4)
        # Segmentation head: a single output channel.
        self.project_seg = Projection(16, 1)

        if refiner == 'deep_guided_filter':
            self.refiner = DeepGuidedFilterRefiner()
        else:
            self.refiner = FastGuidedFilterRefiner()

    def forward(self,
                src: Tensor,
                r1: Optional[Tensor] = None,
                r2: Optional[Tensor] = None,
                r3: Optional[Tensor] = None,
                r4: Optional[Tensor] = None,
                downsample_ratio: float = 1,
                segmentation_pass: bool = False):
        """Run the network on ``src``.

        Args:
            src: Input tensor. A 3-channel foreground residual (split on
                dim -3) is added to it, so it presumably has 3 channels on
                dim -3; 5-D input is treated as (B, T, ...) when resized.
                The final clamp to [0, 1] suggests values in [0, 1] —
                TODO confirm.
            r1, r2, r3, r4: Optional tensors forwarded to the decoder;
                presumably the recurrent states returned (as ``rec``) by the
                previous call — verify against callers.
            downsample_ratio: When != 1, the backbone and decoder run on
                ``src`` bilinearly resized by this factor, and (matting pass
                only) the refiner produces outputs that are added to the
                full-resolution ``src``.
            segmentation_pass: If True, return the segmentation head's output
                instead of foreground/alpha.

        Returns:
            ``[fgr, pha, *rec]`` in the matting pass (``fgr`` and ``pha``
            clamped to [0, 1]) or ``[seg, *rec]`` in the segmentation pass,
            where ``rec`` is every decoder output after the first.
        """
        if downsample_ratio != 1:
            src_sm = self._interpolate(src, scale_factor=downsample_ratio)
        else:
            src_sm = src

        f1, f2, f3, f4 = self.backbone(src_sm)
        f4 = self.aspp(f4)
        hid, *rec = self.decoder(src_sm, f1, f2, f3, f4, r1, r2, r3, r4)

        if not segmentation_pass:
            fgr_residual, pha = self.project_mat(hid).split([3, 1], dim=-3)
            if downsample_ratio != 1:
                # Predictions were made on src_sm; the refiner's outputs are
                # combined with the full-resolution src just below.
                fgr_residual, pha = self.refiner(src, src_sm, fgr_residual, pha, hid)
            # The network predicts a residual on top of the input.
            fgr = fgr_residual + src
            fgr = fgr.clamp(0., 1.)
            pha = pha.clamp(0., 1.)
            return [fgr, pha, *rec]
        else:
            seg = self.project_seg(hid)
            return [seg, *rec]

    def _interpolate(self, x: Tensor, scale_factor: float) -> Tensor:
        """Bilinearly resize the trailing two (spatial) dims of ``x``.

        A 5-D input is flattened to 4-D over its first two dims (B, T) for
        ``F.interpolate`` and restored afterwards; any other input is handed
        to ``F.interpolate`` as-is.
        """
        if x.ndim == 5:
            B, T = x.shape[:2]
            x = F.interpolate(x.flatten(0, 1), scale_factor=scale_factor,
                              mode='bilinear', align_corners=False,
                              # Use the given scale_factor directly instead of
                              # recomputing it from the rounded output size.
                              recompute_scale_factor=False)
            x = x.unflatten(0, (B, T))
        else:
            x = F.interpolate(x, scale_factor=scale_factor,
                              mode='bilinear', align_corners=False,
                              recompute_scale_factor=False)
        return x