"""Rewrite a checkpoint's ``normfeat`` statistics into an expanded grid form.

The source checkpoint stores the ``normfeat`` running sums (``sum_x``,
``sum_x2``, ``sum_target_x2``) together with a ``counts`` tensor. This script
divides each sum by ``counts``, broadcasts the quotient onto a ``(16, 32)``
grid, and resets ``counts`` to ones. As a result, ``sum / counts`` evaluates
to the same quotient for every grid entry.

NOTE(review): the output directory name ("sepnorm") suggests the target model
keeps a separate normalizer per grid entry. Confirm the grid shape against
that model's buffer shapes.
"""
import argparse
from pathlib import Path

import torch

NORM_PREFIX = 'normfeat.'
SUM_KEYS = ('sum_x', 'sum_x2', 'sum_target_x2')
DEFAULT_SHAPE = (16, 32)

DEFAULT_SRC_CKPT = 'saved/train_mulan_v3_48k_everything3/latest/pytorch_model_2.bin'
DEFAULT_TGT_CKPT = 'saved/train_mulan_v3_48k_everything3_sepnorm/src_pytorch_model_2.bin'


def convert_to_sepnorm(state_dict, shape=DEFAULT_SHAPE, prefix=NORM_PREFIX):
    """Return a copy of *state_dict* with the normalizer statistics expanded to *shape*.

    Each ``{prefix}{key}`` entry for ``key`` in ``SUM_KEYS`` is replaced by
    ``ones(shape) * sum / counts``. The expression uses ordinary broadcasting,
    so the stored sum may be a scalar or any tensor broadcastable against
    *shape*. The ``{prefix}counts`` entry is replaced by ones with its
    original shape and dtype. All other entries are carried over unchanged,
    and the input mapping is not modified.

    Args:
        state_dict: Mapping loaded from the checkpoint. It must contain
            ``{prefix}counts`` and every ``{prefix}{key}`` for ``key`` in
            ``SUM_KEYS``.
        shape: Shape of the grid the statistics are broadcast onto.
        prefix: Key prefix of the normalizer buffers.

    Returns:
        A shallow copy of *state_dict*, of the same mapping type, with the
        four normalizer entries replaced.

    Raises:
        KeyError: If any of the expected normalizer keys is missing.
        ValueError: If ``counts`` contains a zero. Without this check the
            result would silently contain inf/NaN statistics.
    """
    counts = state_dict[prefix + 'counts']
    if (counts == 0).any():
        raise ValueError(
            f"'{prefix}counts' contains zeros; refusing to write inf/NaN statistics"
        )

    # .copy() keeps the mapping type (e.g. OrderedDict) and leaves the caller's dict intact.
    converted = state_dict.copy()
    for key in SUM_KEYS:
        name = prefix + key
        total = state_dict[name]
        # Broadcast sum / counts onto the grid. Because counts is reset to 1
        # below, sum / counts on the new tensors reproduces this value per entry.
        converted[name] = (
            torch.ones(shape, dtype=total.dtype, device=total.device) * total / counts
        )
    converted[prefix + 'counts'] = torch.ones_like(counts)
    return converted


def main(argv=None):
    """Load the source checkpoint, convert its normalizer stats, and save the result."""
    parser = argparse.ArgumentParser(
        description='Expand normfeat statistics of a checkpoint onto a fixed grid.'
    )
    parser.add_argument('--src-ckpt', default=DEFAULT_SRC_CKPT,
                        help='checkpoint file to read')
    parser.add_argument('--tgt-ckpt', default=DEFAULT_TGT_CKPT,
                        help='path to write the converted checkpoint')
    args = parser.parse_args(argv)

    # torch.load unpickles the file: only point this at trusted checkpoints.
    ckpt = torch.load(args.src_ckpt, map_location='cpu')
    ckpt = convert_to_sepnorm(ckpt)

    # The target directory may not exist yet; torch.save would raise otherwise.
    Path(args.tgt_ckpt).parent.mkdir(parents=True, exist_ok=True)
    torch.save(ckpt, args.tgt_ckpt)


if __name__ == "__main__":
    main()