Spaces:
Sleeping
Sleeping
"""
Export a MattingRefine model to TorchScript.

Example:
    python export_torchscript.py \
        --model-backbone resnet50 \
        --model-checkpoint "PATH_TO_CHECKPOINT" \
        --precision float32 \
        --output "torchscript.pth"
"""
| import argparse | |
| import torch | |
| from torch import nn | |
| from model import MattingRefine | |
# --------------- Arguments ---------------
# Command-line interface for the export. Every option except --precision is
# mandatory; --precision falls back to 'float32'.
parser = argparse.ArgumentParser(description='Export TorchScript')
parser.add_argument(
    '--model-backbone',
    required=True,
    type=str,
    choices=['resnet101', 'resnet50', 'mobilenetv2'],
)
# Path of the state dict that is loaded into the wrapped MattingRefine.
parser.add_argument(
    '--model-checkpoint',
    required=True,
    type=str,
)
# 'float16' converts the model with .half() before scripting.
parser.add_argument(
    '--precision',
    choices=['float32', 'float16'],
    default='float32',
    type=str,
)
# Destination file for the saved TorchScript module.
parser.add_argument(
    '--output',
    required=True,
    type=str,
)
args = parser.parse_args()
# --------------- Utils ---------------
class MattingRefine_TorchScriptWrapper(nn.Module):
    """
    Expose MattingRefine's configurable settings as top-level attributes.

    The wrapper copies the backbone scale and refiner settings onto itself so
    that, once the saved TorchScript model is loaded, a user can change them
    directly on the loaded module. Every forward() call pushes the current
    top-level values back into the wrapped model before running it.

    Example:
        model = torch.jit.load('torchscript.pth')
        model.backbone_scale = 0.25
        model.refine_mode = 'sampling'
        model.refine_sample_pixels = 80_000
        pha, fgr = model(src, bgr)[:2]
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        # All constructor arguments are passed straight through to MattingRefine.
        self.model = MattingRefine(*args, **kwargs)

        # Seed the top-level attributes with the inner model's initial settings.
        # (This runs eagerly, so a local alias for the refiner is fine here.)
        refiner = self.model.refiner
        self.backbone_scale = self.model.backbone_scale
        self.refine_mode = refiner.mode
        self.refine_sample_pixels = refiner.sample_pixels
        self.refine_threshold = refiner.threshold
        self.refine_prevent_oversampling = refiner.prevent_oversampling

    def forward(self, src, bgr):
        # Copy the (possibly user-modified) top-level settings into the wrapped
        # model so this call runs with them. Kept as explicit attribute stores
        # through self.model so the method stays compilable by torch.jit.script.
        self.model.backbone_scale = self.backbone_scale
        self.model.refiner.mode = self.refine_mode
        self.model.refiner.sample_pixels = self.refine_sample_pixels
        self.model.refiner.threshold = self.refine_threshold
        self.model.refiner.prevent_oversampling = self.refine_prevent_oversampling
        return self.model(src, bgr)

    def load_state_dict(self, *args, **kwargs):
        # Delegate to the wrapped model: state-dict keys are therefore resolved
        # relative to MattingRefine (no 'model.' prefix from this wrapper).
        inner = self.model
        return inner.load_state_dict(*args, **kwargs)
# --------------- Main ---------------
# Build the wrapper for the requested backbone and switch it to eval mode.
model = MattingRefine_TorchScriptWrapper(args.model_backbone)
model.eval()

# Load the checkpoint onto the CPU and copy its weights into the wrapped model.
checkpoint_state = torch.load(args.model_checkpoint, map_location='cpu')
model.load_state_dict(checkpoint_state)

# Disable gradient tracking on every parameter.
model.requires_grad_(False)

# Optionally cast to half precision before scripting.
if args.precision == 'float16':
    model = model.half()

# Compile with TorchScript and write the result to disk.
model = torch.jit.script(model)
model.save(args.output)