| | import pickle |
| | from io import BytesIO |
| | from collections import OrderedDict |
| | import os |
| |
|
| | import torch |
| |
|
| |
|
def load_pickle(path: str):
    """Read the file at ``path`` and return the object unpickled from it.

    NOTE: unpickling can execute arbitrary code; only call this on trusted
    files (e.g. ones this program wrote itself).
    """
    with open(path, "rb") as source:
        unpickler = pickle.Unpickler(source)
        return unpickler.load()
| |
|
| |
|
def save_pickle(ckpt: dict, save_path: str):
    """Pickle ``ckpt`` into ``save_path``, truncating any existing file there."""
    with open(save_path, "wb") as sink:
        writer = pickle.Pickler(sink)
        writer.dump(ckpt)
| |
|
| |
|
def load_inputs(path: torch.serialization.FILE_LIKE, device: str, is_half=False):
    """Load a saved mapping of tensors and place each one on ``device``.

    The file is first loaded onto the CPU, then every entry is moved to
    ``device``. Floating-point entries are aligned with the requested
    precision: float32 becomes float16 when ``is_half`` is true, and float16
    becomes float32 otherwise. Other dtypes are left alone. The loaded mapping
    is updated in place and returned.

    Every value must support ``.to()`` and ``.dtype``; in practice that means
    the file should contain only tensors.
    """
    tensors = torch.load(path, map_location=torch.device("cpu"))
    for name in tensors:
        tensors[name] = tensors[name].to(device)
        current = tensors[name]
        if is_half:
            if current.dtype == torch.float32:
                tensors[name] = current.half()
        elif current.dtype == torch.float16:
            tensors[name] = current.float()
    return tensors
| |
|
| |
|
def export_jit_model(
    model: torch.nn.Module,
    mode: str = "trace",
    inputs: dict = None,
    device=torch.device("cpu"),
    is_half: bool = False,
) -> dict:
    """Compile ``model`` to TorchScript and return it serialized in a checkpoint dict.

    Args:
        model: Module to export. Before compilation it is cast in place to
            fp16 (``is_half``) or fp32 and switched to eval mode.
        mode: ``"trace"`` compiles with ``torch.jit.trace``, using ``inputs``
            as example keyword arguments. ``"script"`` compiles with
            ``torch.jit.script``.
        inputs: Example keyword inputs for tracing. Required when
            ``mode == "trace"`` and ignored otherwise.
        device: Device the compiled module is moved to before it is saved.
        is_half: Export in half precision when true, single precision otherwise.

    Returns:
        An ``OrderedDict`` with two entries. ``"model"`` holds the bytes
        produced by ``torch.jit.save``, and ``"is_half"`` holds the flag.

    Raises:
        ValueError: If ``mode`` is not ``"trace"`` or ``"script"``, or if
            ``mode == "trace"`` and ``inputs`` is None.
    """
    # Validate before touching the model, so a bad call leaves it unmodified.
    # Previously an unknown mode fell through to an UnboundLocalError, and the
    # missing-inputs check was an ``assert``, which is stripped under ``-O``.
    if mode not in ("trace", "script"):
        raise ValueError(
            f"unknown export mode {mode!r}; expected 'trace' or 'script'"
        )
    if mode == "trace" and inputs is None:
        raise ValueError("mode='trace' requires example `inputs`")

    model = model.half() if is_half else model.float()
    model.eval()
    if mode == "trace":
        model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs)
    else:
        model_jit = torch.jit.script(model)
    model_jit.to(device)
    model_jit = model_jit.half() if is_half else model_jit.float()

    # Serialize to memory so the checkpoint can carry the module as raw bytes.
    buffer = BytesIO()
    torch.jit.save(model_jit, buffer)
    del model_jit
    cpt = OrderedDict()
    cpt["model"] = buffer.getvalue()
    cpt["is_half"] = is_half
    return cpt
| |
|
| |
|
def get_jit_model(model_path: str, is_half: bool, device: str, exporter):
    """Return a TorchScript checkpoint for ``model_path``, reusing a cache file when valid.

    The cache path is ``model_path`` with its file extension replaced by
    ``.half.jit`` (when ``is_half``) or ``.jit``. An existing cache is reused
    only if its ``"device"`` entry equals ``str(device)``. Otherwise ``exporter``
    is called in ``"script"`` mode and its result is returned. That covers three
    cases: no cache file, no ``"device"`` entry, or a different device.

    Args:
        model_path: Path of the source model checkpoint.
        is_half: Whether the half-precision cache variant is wanted.
        device: Target device. It is compared against the cached entry via
            ``str(device)``.
        exporter: Callable invoked with keyword arguments ``model_path``,
            ``mode``, ``inputs_path``, ``save_path``, ``device`` and
            ``is_half``. Its return value becomes the checkpoint. Presumably it
            also writes the cache to ``save_path``; verify against the exporter
            in use.

    Returns:
        The checkpoint, either loaded from the cache or produced by ``exporter``.
    """
    # splitext removes only the real extension. str.rstrip(".pth") strips any
    # run of trailing '.', 'p', 't', 'h' characters, so "depth.pth" -> "de".
    stem, _ = os.path.splitext(model_path)
    jit_model_path = stem + (".half.jit" if is_half else ".jit")

    ckpt = None
    if os.path.exists(jit_model_path):
        cached = load_pickle(jit_model_path)
        # .get(): a cache lacking a "device" entry is stale, not a KeyError crash.
        if cached.get("device") == str(device):
            ckpt = cached

    if ckpt is None:
        ckpt = exporter(
            model_path=model_path,
            mode="script",
            inputs_path=None,
            save_path=jit_model_path,
            device=device,
            is_half=is_half,
        )

    return ckpt
| |
|