"""Torch Hub entry points for Robust Video Matting.

Load a model::

    model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3")
    model = torch.hub.load("PeterL1n/RobustVideoMatting", "resnet50")

Get the video converter API::

    convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
"""
# Package names declared for Torch Hub's hubconf dependency check.
dependencies = ["torch", "torchvision"]
| import torch | |
| from model import MattingNetwork | |
def mobilenetv3(pretrained: bool = True, progress: bool = True):
    """Build a ``MattingNetwork`` with the ``'mobilenetv3'`` variant.

    Args:
        pretrained: If true, download the v1.0.0 ``rvm_mobilenetv3.pth``
            release checkpoint and load it into the network.
        progress: Passed through as ``progress`` to
            ``torch.hub.load_state_dict_from_url``.

    Returns:
        The constructed network. Weights are loaded only when ``pretrained``.
    """
    net = MattingNetwork('mobilenetv3')
    if not pretrained:
        return net
    checkpoint_url = 'https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3.pth'
    # Map tensors to CPU so loading works on machines without the device the checkpoint was saved on.
    weights = torch.hub.load_state_dict_from_url(checkpoint_url, map_location='cpu', progress=progress)
    net.load_state_dict(weights)
    return net
def resnet50(pretrained: bool = True, progress: bool = True):
    """Build a ``MattingNetwork`` with the ``'resnet50'`` variant.

    Args:
        pretrained: If true, download the v1.0.0 ``rvm_resnet50.pth``
            release checkpoint and load it into the network.
        progress: Passed through as ``progress`` to
            ``torch.hub.load_state_dict_from_url``.

    Returns:
        The constructed network. Weights are loaded only when ``pretrained``.
    """
    network = MattingNetwork('resnet50')
    if pretrained:
        # Map tensors to CPU so loading works regardless of the device the checkpoint was saved on.
        state_dict = torch.hub.load_state_dict_from_url(
            'https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_resnet50.pth',
            map_location='cpu',
            progress=progress,
        )
        network.load_state_dict(state_dict)
    return network
def converter():
    """Return ``convert_video`` from the repository's ``inference`` module.

    Judging by the install hint below, ``inference`` seems to need extra
    packages (av, tqdm, pims) that are not in ``dependencies``. An import
    failure is raised rather than swallowed. Returning ``None`` would make
    ``torch.hub.load(..., "converter")`` succeed silently, and the caller
    would then fail later with ``'NoneType' object is not callable``.

    Returns:
        The ``convert_video`` callable.

    Raises:
        ModuleNotFoundError: If ``inference`` or one of its imports is
            missing. The message includes the install hint, and the
            original error is chained as ``__cause__``.
    """
    try:
        from inference import convert_video
    except ModuleNotFoundError as error:
        raise ModuleNotFoundError(
            f'{error}. Please run "pip install av tqdm pims"',
            name=error.name,
        ) from error
    return convert_video