Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| # | |
| # Copyright 2021-2022 Xiaomi Corporation | |
| # | |
| # See ../../../../LICENSE for clarification regarding multiple authors | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ | |
| Usage: | |
| This script loads checkpoints and averages them. | |
| python3 -m zipvoice.bin.generate_averaged_model \ | |
| --epoch 11 \ | |
| --avg 4 \ | |
| --model-name zipvoice \ | |
| --exp-dir exp/zipvoice | |
It will generate a file `epoch-11-avg-4.pt` in the given `exp_dir`.
| You can later load it by `torch.load("epoch-11-avg-4.pt")`. | |
| """ | |
| import argparse | |
| import json | |
| import logging | |
| from pathlib import Path | |
| import torch | |
| from zipvoice.models.zipvoice import ZipVoice | |
| from zipvoice.models.zipvoice_dialog import ZipVoiceDialog, ZipVoiceDialogStereo | |
| from zipvoice.models.zipvoice_distill import ZipVoiceDistill | |
| from zipvoice.tokenizer.tokenizer import SimpleTokenizer | |
| from zipvoice.utils.checkpoint import ( | |
| average_checkpoints_with_averaged_model, | |
| find_checkpoints, | |
| ) | |
| from zipvoice.utils.common import AttributeDict | |
def get_parser():
    """Return the command-line argument parser for the averaging script."""
    p = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    p.add_argument(
        "--epoch",
        type=int,
        default=11,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )
    p.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )
    p.add_argument(
        "--avg",
        type=int,
        default=4,
        help=(
            "Number of checkpoints to average. Automatically select "
            "consecutive checkpoints before the checkpoint specified by "
            "'--epoch' or --iter"
        ),
    )
    p.add_argument(
        "--exp-dir",
        type=str,
        default="exp/zipvoice",
        help="The experiment dir",
    )
    p.add_argument(
        "--model-name",
        type=str,
        default="zipvoice",
        choices=[
            "zipvoice",
            "zipvoice_distill",
            "zipvoice_dialog",
            "zipvoice_dialog_stereo",
        ],
        help="The model type to be averaged. ",
    )
    return p
def main():
    """Average checkpoints and save the result as a single checkpoint file.

    Reads ``model.json`` and ``tokens.txt`` from ``--exp-dir``, builds the
    model selected by ``--model-name``, averages ``--avg`` consecutive
    checkpoints ending at ``--epoch`` (or ``--iter``), and saves the
    averaged state dict to ``<exp-dir>/epoch-E-avg-A.pt``
    (or ``iter-I-avg-A.pt`` when ``--iter`` is positive).

    Raises:
        ValueError: if the model name is unknown, or if no / not enough
            iteration checkpoints are found for the requested ``--avg``.
    """
    parser = get_parser()
    args = parser.parse_args()

    params = AttributeDict()
    params.update(vars(args))
    params.exp_dir = Path(params.exp_dir)

    with open(params.exp_dir / "model.json", "r") as f:
        model_config = json.load(f)

    # Any tokenizer can be used here.
    # Use SimpleTokenizer for simplicity.
    tokenizer = SimpleTokenizer(token_file=params.exp_dir / "tokens.txt")
    if params.model_name in ["zipvoice", "zipvoice_distill"]:
        tokenizer_config = {
            "vocab_size": tokenizer.vocab_size,
            "pad_id": tokenizer.pad_id,
        }
    elif params.model_name in ["zipvoice_dialog", "zipvoice_dialog_stereo"]:
        tokenizer_config = {
            "vocab_size": tokenizer.vocab_size,
            "pad_id": tokenizer.pad_id,
            # Dialog models additionally need the speaker-turn token ids.
            "spk_a_id": tokenizer.spk_a_id,
            "spk_b_id": tokenizer.spk_b_id,
        }
    else:
        # Fail here rather than leave tokenizer_config unassigned (which
        # would surface as a confusing NameError below). argparse `choices`
        # normally prevents reaching this branch.
        raise ValueError(f"Unknown model name: {params.model_name}")

    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    logging.info("Script started")

    # Averaging is done on CPU; no GPU is required.
    params.device = torch.device("cpu")
    logging.info(f"Device: {params.device}")

    logging.info("About to create model")
    model_classes = {
        "zipvoice": ZipVoice,
        "zipvoice_distill": ZipVoiceDistill,
        "zipvoice_dialog": ZipVoiceDialog,
        "zipvoice_dialog_stereo": ZipVoiceDialogStereo,
    }
    model = model_classes[params.model_name](
        **model_config["model"],
        **tokenizer_config,
    )

    if params.iter > 0:
        # Most recent `avg + 1` iteration checkpoints; the averaging helper
        # excludes the start checkpoint, hence the +1.
        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
            : params.avg + 1
        ]
        if len(filenames) == 0:
            raise ValueError(
                f"No checkpoints found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        elif len(filenames) < params.avg + 1:
            raise ValueError(
                f"Not enough checkpoints ({len(filenames)}) found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        filename_start = filenames[-1]
        filename_end = filenames[0]
        logging.info(
            "Calculating the averaged model over iteration checkpoints"
            f" from {filename_start} (excluded) to {filename_end}"
        )
    else:
        assert params.avg > 0, params.avg
        start = params.epoch - params.avg
        assert start >= 1, start
        filename_start = f"{params.exp_dir}/epoch-{start}.pt"
        filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
        logging.info(
            f"Calculating the averaged model over epoch range from "
            f"{start} (excluded) to {params.epoch}"
        )

    # The load/average step is identical for both selection modes, so it is
    # hoisted out of the branches (the original duplicated it verbatim).
    model.to(params.device)
    model.load_state_dict(
        average_checkpoints_with_averaged_model(
            filename_start=filename_start,
            filename_end=filename_end,
            device=params.device,
        ),
        strict=True,
    )

    if params.iter > 0:
        filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt"
    else:
        filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
    # Fix: the original logged the literal "(unknown)" instead of the path.
    logging.info(f"Saving the averaged checkpoint to {filename}")
    torch.save({"model": model.state_dict()}, filename)

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    logging.info("Done!")
if __name__ == "__main__":
    # Timestamped log lines including the source file and line number.
    # force=True replaces any handlers installed earlier (e.g. by an
    # imported library), so this format always takes effect.
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)
    main()