| | from io import BytesIO
|
| | import pickle
|
| | import time
|
| | import torch
|
| | from tqdm import tqdm
|
| | from collections import OrderedDict
|
| |
|
| |
|
def load_inputs(path, device, is_half=False):
    """Load a mapping of example tensors from *path* and place it on *device*.

    The file is first mapped onto the CPU with ``torch.load``. Every value is
    then moved to *device*, and its floating-point precision is aligned with
    *is_half*:

    - float32 values become float16 when *is_half* is true;
    - float16 values become float32 when *is_half* is false;
    - values of any other dtype (e.g. integer tensors) keep their dtype.

    The loaded mapping is updated in place and returned.

    Security note: ``weights_only=False`` lets ``torch.load`` unpickle
    arbitrary objects, so only pass paths to files you trust.
    """
    tensors = torch.load(path, map_location=torch.device("cpu"), weights_only=False)

    # Only one direction of conversion is ever applied per call.
    if is_half:
        src_dtype, dst_dtype = torch.float32, torch.float16
    else:
        src_dtype, dst_dtype = torch.float16, torch.float32

    for name in tensors:
        moved = tensors[name].to(device)
        if moved.dtype == src_dtype:
            moved = moved.to(dst_dtype)
        tensors[name] = moved
    return tensors
|
| |
|
| |
|
def benchmark(
    model, inputs_path, device=torch.device("cpu"), epoch=1000, is_half=False
):
    """Time ``model(**inputs)`` over *epoch* iterations and print the mean latency.

    Args:
        model: callable accepting the keyword inputs stored at *inputs_path*.
        inputs_path: file of example inputs, loaded with ``load_inputs``.
        device: device the inputs are moved to (``torch.device`` or string).
        epoch: number of timed calls. With ``epoch <= 0`` nothing is run and
            the reported average is 0.0 instead of dividing by zero.
        is_half: convert float inputs to float16 (else float32).

    Prints one line of the form ``num_epoch: N | avg time(ms): X``.
    """
    parm = load_inputs(inputs_path, device, is_half)
    # CUDA work is launched asynchronously. Without synchronizing, the timer
    # would mostly measure kernel-launch overhead instead of execution time.
    sync_cuda = torch.device(device).type == "cuda"
    total_ts = 0.0
    for _ in tqdm(range(epoch)):
        if sync_cuda:
            torch.cuda.synchronize(device)
        start_time = time.perf_counter()
        model(**parm)
        if sync_cuda:
            torch.cuda.synchronize(device)
        total_ts += time.perf_counter() - start_time
    avg_ms = (total_ts * 1000) / epoch if epoch > 0 else 0.0
    print(f"num_epoch: {epoch} | avg time(ms): {avg_ms}")
|
| |
|
| |
|
def jit_warm_up(model, inputs_path, device=torch.device("cpu"), epoch=5, is_half=False):
    """Run ``benchmark`` for a small number of iterations (5 by default).

    The name suggests this is meant to exercise a freshly compiled TorchScript
    model before it is timed for real. Every argument is forwarded to
    ``benchmark`` unchanged, and nothing is returned.
    """
    benchmark(
        model=model,
        inputs_path=inputs_path,
        device=device,
        epoch=epoch,
        is_half=is_half,
    )
|
| |
|
| |
|
def to_jit_model(
    model_path,
    model_type: str,
    mode: str = "trace",
    inputs_path: str = None,
    device=torch.device("cpu"),
    is_half=False,
):
    """Load a model of *model_type* and compile it with TorchScript.

    Args:
        model_path: checkpoint path passed to the type-specific loader.
        model_type: "synthesizer", "rmvpe" or "hubert" (case-insensitive).
        mode: "trace" (requires *inputs_path*) or "script".
        inputs_path: file of example keyword inputs, loaded with
            ``load_inputs``. Required when ``mode == "trace"``.
        device: device the loaders and the compiled module are placed on.
        is_half: convert both modules to float16 if true, float32 otherwise.

    Returns:
        ``(model, model_jit)``: the eager module and its TorchScript version.

    Raises:
        ValueError: unknown *model_type*, unknown *mode*, or trace mode
            without *inputs_path*. All of these are checked before any
            weights are loaded.
    """
    kind = model_type.lower()
    if kind not in ("synthesizer", "rmvpe", "hubert"):
        raise ValueError(f"No model type named {model_type}")
    if mode not in ("trace", "script"):
        raise ValueError(f"Unsupported JIT mode {mode!r}; expected 'trace' or 'script'")
    # Tracing needs example inputs. The previous `assert not inputs_path` was
    # inverted and made trace mode unusable.
    if mode == "trace" and not inputs_path:
        raise ValueError("mode='trace' requires inputs_path")

    if kind == "synthesizer":
        from .get_synthesizer import get_synthesizer

        model, _ = get_synthesizer(model_path, device)
        model.forward = model.infer
    elif kind == "rmvpe":
        from .get_rmvpe import get_rmvpe

        model = get_rmvpe(model_path, device)
    else:  # "hubert"
        from .get_hubert import get_hubert_model

        model = get_hubert_model(model_path, device)
        model.forward = model.infer

    model = model.eval()
    model = model.half() if is_half else model.float()
    if mode == "trace":
        inputs = load_inputs(inputs_path, device, is_half)
        model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs)
    else:
        model_jit = torch.jit.script(model)
    model_jit.to(device)
    model_jit = model_jit.half() if is_half else model_jit.float()
    return (model, model_jit)
|
| |
|
| |
|
def export(
    model: torch.nn.Module,
    mode: str = "trace",
    inputs: dict = None,
    device=torch.device("cpu"),
    is_half: bool = False,
) -> dict:
    """Compile *model* with TorchScript and serialise it to bytes.

    Note: ``model.half()`` / ``model.float()`` convert the caller's module in
    place. Arguments are therefore validated before any conversion happens.

    Args:
        model: module to compile.
        mode: "trace" (requires *inputs*) or "script".
        inputs: example keyword inputs for ``torch.jit.trace``.
        device: device the compiled module is moved to before saving.
        is_half: compile in float16 if true, float32 otherwise.

    Returns:
        An ``OrderedDict`` with ``"model"`` (bytes written by
        ``torch.jit.save``) and ``"is_half"``.

    Raises:
        ValueError: unknown *mode*, or trace mode with ``inputs is None``.
    """
    if mode not in ("trace", "script"):
        raise ValueError(f"Unsupported JIT mode {mode!r}; expected 'trace' or 'script'")
    if mode == "trace" and inputs is None:
        raise ValueError("mode='trace' requires example inputs")

    model = model.half() if is_half else model.float()
    model.eval()
    if mode == "trace":
        model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs)
    else:
        model_jit = torch.jit.script(model)
    model_jit.to(device)
    model_jit = model_jit.half() if is_half else model_jit.float()

    buffer = BytesIO()
    torch.jit.save(model_jit, buffer)
    # Drop the compiled module early; only its serialized bytes are kept.
    del model_jit
    cpt = OrderedDict()
    cpt["model"] = buffer.getvalue()
    cpt["is_half"] = is_half
    return cpt
|
| |
|
| |
|
def load(path: str):
    """Unpickle and return the object stored at *path*.

    This is the inverse of ``save``. Security note: unpickling can execute
    arbitrary code, so only load files from a trusted source.
    """
    with open(path, mode="rb") as handle:
        obj = pickle.Unpickler(handle).load()
    return obj
|
| |
|
| |
|
def save(ckpt: dict, save_path: str):
    """Pickle *ckpt* into *save_path*, replacing any existing file.

    Uses pickle's default protocol. The file can be read back with ``load``.
    """
    with open(save_path, mode="wb") as handle:
        pickle.Pickler(handle).dump(ckpt)
|
| |
|
| |
|
def rmvpe_jit_export(
    model_path: str,
    mode: str = "script",
    inputs_path: str = None,
    save_path: str = None,
    device=torch.device("cpu"),
    is_half=False,
):
    """Export an RMVPE checkpoint as a pickled TorchScript bundle.

    Args:
        model_path: checkpoint passed to ``get_rmvpe``.
        mode: "script" or "trace" (trace requires *inputs_path*).
        inputs_path: example inputs for tracing, loaded with ``load_inputs``.
        save_path: output file. Defaults to *model_path* with its extension
            replaced by ".jit" (or ".half.jit" when *is_half*).
        device: target device. A bare "cuda" is pinned to "cuda:0".
        is_half: export in float16 if true, float32 otherwise.

    Returns:
        The dict produced by ``export`` plus a ``"device"`` string. The same
        dict is pickled to *save_path*.

    Raises:
        ValueError: unknown *mode*, or trace mode without *inputs_path*.
    """
    import os

    if mode not in ("trace", "script"):
        raise ValueError(f"Unsupported JIT mode {mode!r}; expected 'trace' or 'script'")
    if mode == "trace" and not inputs_path:
        raise ValueError("mode='trace' requires inputs_path")

    if not save_path:
        # str.rstrip(".pth") strips any trailing '.', 'p', 't', 'h' characters
        # (e.g. "depth.pth" -> "de"). Drop only the real extension instead.
        save_path = os.path.splitext(model_path)[0]
        save_path += ".half.jit" if is_half else ".jit"
    # Presumably pinned so the recorded device string is always "cuda:N".
    if "cuda" in str(device) and ":" not in str(device):
        device = torch.device("cuda:0")
    from .get_rmvpe import get_rmvpe

    model = get_rmvpe(model_path, device)
    inputs = None
    if mode == "trace":
        inputs = load_inputs(inputs_path, device, is_half)
    ckpt = export(model, mode, inputs, device, is_half)
    ckpt["device"] = str(device)
    save(ckpt, save_path)
    return ckpt
|
| |
|
| |
|
def synthesizer_jit_export(
    model_path: str,
    mode: str = "script",
    inputs_path: str = None,
    save_path: str = None,
    device=torch.device("cpu"),
    is_half=False,
):
    """Export a synthesizer checkpoint as a pickled TorchScript bundle.

    The original checkpoint dict returned by ``get_synthesizer`` is reused.
    Its ``"weight"`` entry is removed, and ``"model"`` (TorchScript bytes) and
    ``"device"`` are added.

    Args:
        model_path: checkpoint passed to ``get_synthesizer``.
        mode: "script" or "trace" (trace requires *inputs_path*).
        inputs_path: example inputs for tracing, loaded with ``load_inputs``.
        save_path: output file. Defaults to *model_path* with its extension
            replaced by ".jit" (or ".half.jit" when *is_half*).
        device: target device. A bare "cuda" is pinned to "cuda:0".
        is_half: export in float16 if true, float32 otherwise.

    Returns:
        The updated checkpoint dict, which is also pickled to *save_path*.

    Raises:
        ValueError: unknown *mode*, or trace mode without *inputs_path*.
    """
    import os

    if mode not in ("trace", "script"):
        raise ValueError(f"Unsupported JIT mode {mode!r}; expected 'trace' or 'script'")
    if mode == "trace" and not inputs_path:
        raise ValueError("mode='trace' requires inputs_path")

    if not save_path:
        # str.rstrip(".pth") strips any trailing '.', 'p', 't', 'h' characters
        # (e.g. "depth.pth" -> "de"). Drop only the real extension instead.
        save_path = os.path.splitext(model_path)[0]
        save_path += ".half.jit" if is_half else ".jit"
    # Presumably pinned so the recorded device string is always "cuda:N".
    if "cuda" in str(device) and ":" not in str(device):
        device = torch.device("cuda:0")
    from .get_synthesizer import get_synthesizer

    model, cpt = get_synthesizer(model_path, device)
    assert isinstance(cpt, dict)
    model.forward = model.infer
    inputs = None
    if mode == "trace":
        inputs = load_inputs(inputs_path, device, is_half)
    ckpt = export(model, mode, inputs, device, is_half)
    cpt.pop("weight")
    cpt["model"] = ckpt["model"]
    # Stored as a string, matching rmvpe_jit_export, so consumers can compare
    # it against str(device) regardless of which exporter produced the file.
    cpt["device"] = str(device)
    save(cpt, save_path)
    return cpt
|
| |
|