| from collections import OrderedDict |
| from text.symbols import symbols |
| import torch |
|
|
| from tools.log import logger |
| import utils |
| from models import SynthesizerTrn |
| import os |
|
|
|
|
def copyStateDict(state_dict):
    """Return an ordered copy of ``state_dict`` without a leading ``module`` key part.

    If the first key starts with ``"module"`` (presumably the prefix added by a
    wrapper such as ``torch.nn.DataParallel`` -- confirm against the
    checkpoints actually used), the first dot-separated component is dropped
    from *every* key. Otherwise keys are copied unchanged. Values are not
    copied; the result references the same objects.

    Args:
        state_dict: Mapping whose keys are dot-separated strings.

    Returns:
        An ``OrderedDict`` with the (possibly shortened) keys, in the same
        order as the input. An empty input yields an empty ``OrderedDict``.
    """
    # An empty mapping has no first key to inspect; fall back to "" so the
    # prefix check is simply False instead of raising IndexError.
    first_key = next(iter(state_dict), "")
    start_idx = 1 if first_key.startswith("module") else 0

    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # Re-join with "." (the same separator used for splitting) so the
        # remaining key components keep their original dotted form.
        name = ".".join(k.split(".")[start_idx:])
        new_state_dict[name] = v
    return new_state_dict
|
|
|
|
def removeOptimizer(config: str, input_model: str, ishalf: bool, output_model: str):
    """Save a slimmed copy of a checkpoint with a freshly initialised optimizer.

    Steps performed:
      1. Read hyper-parameters from ``config`` and build a ``SynthesizerTrn``.
      2. Create an ``AdamW`` optimizer over that network's parameters; it is
         never stepped, only its ``state_dict()`` is written out.
      3. Load ``input_model`` onto the CPU and pass it through
         ``copyStateDict``.
      4. Keep every entry of the checkpoint's ``"model"`` mapping whose key
         does not contain ``"enc_q"``, optionally calling ``.half()`` on each
         kept value.
      5. Save the result to ``output_model`` with ``iteration`` set to 0 and
         ``learning_rate`` set to 0.0001.

    Args:
        config: Path to the hyper-parameter file.
        input_model: Path of the checkpoint to read.
        ishalf: When true, kept values are converted with ``.half()``.
        output_model: Path the new checkpoint is written to.
    """
    hparams = utils.get_hparams_from_file(config)

    generator = SynthesizerTrn(
        len(symbols),
        hparams.data.filter_length // 2 + 1,
        hparams.train.segment_size // hparams.data.hop_length,
        n_speakers=hparams.data.n_speakers,
        **hparams.model,
    )

    # Only used as a source of pristine optimizer state for the output file.
    fresh_optimizer = torch.optim.AdamW(
        generator.parameters(),
        hparams.train.learning_rate,
        betas=hparams.train.betas,
        eps=hparams.train.eps,
    )

    checkpoint = copyStateDict(torch.load(input_model, map_location="cpu"))
    weights = checkpoint["model"]

    # Drop every entry whose key mentions "enc_q"; keep the rest in order.
    kept_names = [name for name in weights if "enc_q" not in name]

    if ishalf:
        slim_weights = {name: weights[name].half() for name in kept_names}
    else:
        slim_weights = {name: weights[name] for name in kept_names}

    payload = {
        "model": slim_weights,
        "iteration": 0,
        "optimizer": fresh_optimizer.state_dict(),
        "learning_rate": 0.0001,
    }
    torch.save(payload, output_model)
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: strip a checkpoint and report where the
    # result was written.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", type=str, default="configs/config.json")
    # The input checkpoint has no sensible default; mark it required so a
    # missing flag produces an argparse usage error instead of a TypeError
    # from os.path.splitext(None) / torch.load(None) further down.
    parser.add_argument("-i", "--input", type=str, required=True)
    parser.add_argument("-o", "--output", type=str, default=None)
    parser.add_argument(
        "-hf", "--half", action="store_true", default=False, help="Save as FP16"
    )

    args = parser.parse_args()

    output = args.output

    if output is None:
        # No explicit output path: derive one next to the input by inserting
        # "_release" (and "_half" when --half is set) before the extension.
        filename, ext = os.path.splitext(args.input)
        half = "_half" if args.half else ""
        output = filename + "_release" + half + ext

    removeOptimizer(args.config, args.input, args.half, output)
    logger.info(f"压缩模型成功, 输出模型: {os.path.abspath(output)}")
|
|