| | import datetime |
| | import json |
| | import os |
| | import pickle as pickle_tts |
| | import shutil |
| | from typing import Any, Callable, Dict, Union |
| |
|
| | import fsspec |
| | import torch |
| | from coqpit import Coqpit |
| |
|
| | from TTS.utils.generic_utils import get_user_data_dir |
| |
|
| |
|
class RenamingUnpickler(pickle_tts.Unpickler):
    """Unpickler that resolves classes from the renamed package.

    Checkpoints pickled under the old ``mozilla_voice_tts`` package name
    are re-pointed at the current ``TTS`` package before class lookup.
    """

    def find_class(self, module, name):
        # Rewrite the legacy package prefix, then defer to the stock lookup.
        remapped_module = module.replace("mozilla_voice_tts", "TTS")
        return super().find_class(remapped_module, name)
| |
|
| |
|
class AttrDict(dict):
    """Dictionary whose keys are also readable/writable as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Alias the attribute store to the mapping itself so that
        # ``d.key`` and ``d["key"]`` always see the same data.
        self.__dict__ = self
| |
|
| |
|
def copy_model_files(config: Coqpit, out_path, new_fields=None):
    """Write the run config (plus any extra fields) into the training folder.

    Also mirrors the audio stats file as ``scale_stats.npy`` next to the
    config when the config references one and it is not already present.

    Args:
        config (Coqpit): Coqpit config defining the training run.
        out_path (str): output path to copy the file.
        new_fields (dict): new fields to be added or edited
            in the config file.
    """
    if new_fields:
        config.update(new_fields, allow_new=True)

    # Dump the (possibly augmented) config as JSON next to the checkpoints.
    config_dst = os.path.join(out_path, "config.json")
    with fsspec.open(config_dst, "w", encoding="utf8") as config_file:
        json.dump(config.to_dict(), config_file, indent=4)

    # Mirror the feature-normalization stats file unless it already exists.
    stats_src = config.audio.stats_path
    if stats_src is not None:
        stats_dst = os.path.join(out_path, "scale_stats.npy")
        fs = fsspec.get_mapper(stats_dst).fs
        if not fs.exists(stats_dst):
            with fsspec.open(stats_src, "rb") as src, fsspec.open(stats_dst, "wb") as dst:
                shutil.copyfileobj(src, dst)
| |
|
| |
|
def load_fsspec(
    path: str,
    map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
    cache: bool = True,
    **kwargs,
) -> Any:
    """Like torch.load but can load from other locations (e.g. s3:// , gs://).

    Args:
        path: Any path or url supported by fsspec.
        map_location: torch.device or str.
        cache: If True, cache a remote file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to True.
        **kwargs: Keyword arguments forwarded to torch.load.

    Returns:
        Object stored in path.
    """
    # Local paths never need the fsspec file cache; only cache remote URLs.
    if cache and not (os.path.isdir(path) or os.path.isfile(path)):
        cache_dir = str(get_user_data_dir("tts_cache"))
        with fsspec.open(f"filecache::{path}", filecache={"cache_storage": cache_dir}, mode="rb") as f:
            return torch.load(f, map_location=map_location, **kwargs)
    with fsspec.open(path, "rb") as f:
        return torch.load(f, map_location=map_location, **kwargs)
| |
|
| |
|
def load_checkpoint(
    model, checkpoint_path, use_cuda=False, eval=False, cache=False
):  # pylint: disable=redefined-builtin
    """Load a checkpoint and apply its weights to ``model``.

    Args:
        model: Model instance whose ``load_state_dict`` receives the weights.
        checkpoint_path: Any path or url accepted by ``load_fsspec``.
        use_cuda (bool): If True, move the model to GPU after loading.
        eval (bool): If True, switch the model to evaluation mode.
        cache (bool): Forwarded to ``load_fsspec`` to cache remote files.

    Returns:
        Tuple of ``(model, state)`` where ``state`` is the raw checkpoint dict.
    """
    try:
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
    except ModuleNotFoundError:
        # Legacy checkpoints reference the renamed `mozilla_voice_tts` package;
        # retry with an unpickler that rewrites the module path. Restore the
        # stock Unpickler afterwards so the monkeypatch does not leak into
        # unrelated uses of the pickle module in this process.
        original_unpickler = pickle_tts.Unpickler
        pickle_tts.Unpickler = RenamingUnpickler
        try:
            state = load_fsspec(
                checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts, cache=cache
            )
        finally:
            pickle_tts.Unpickler = original_unpickler
    model.load_state_dict(state["model"])
    if use_cuda:
        model.cuda()
    if eval:
        model.eval()
    return model, state
| |
|
| |
|
def save_fsspec(state: Any, path: str, **kwargs):
    """Serialize ``state`` with torch.save to any fsspec-supported location.

    Args:
        state: State object to save
        path: Any path or url supported by fsspec.
        **kwargs: Keyword arguments forwarded to torch.save.
    """
    with fsspec.open(path, "wb") as target:
        torch.save(state, target, **kwargs)
| |
|
| |
|
def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs):
    """Assemble a checkpoint dict and write it to ``output_path``.

    Handles DataParallel-style wrapped models, single or multiple
    optimizers/scalers, and Coqpit configs (converted to plain dicts).
    Extra keyword arguments are merged into the saved state and may
    override the default keys.
    """
    # Unwrap DataParallel/DistributedDataParallel before taking the state dict.
    model_state = model.module.state_dict() if hasattr(model, "module") else model.state_dict()

    if isinstance(optimizer, list):
        optimizer_state = [optim.state_dict() for optim in optimizer]
    elif optimizer.__class__.__name__ == "CapacitronOptimizer":
        # Capacitron bundles two optimizers; persist both of their states.
        optimizer_state = [optimizer.primary_optimizer.state_dict(), optimizer.secondary_optimizer.state_dict()]
    else:
        optimizer_state = optimizer.state_dict() if optimizer is not None else None

    if isinstance(scaler, list):
        scaler_state = [item.state_dict() for item in scaler]
    else:
        scaler_state = scaler.state_dict() if scaler is not None else None

    if isinstance(config, Coqpit):
        config = config.to_dict()

    state = {
        "config": config,
        "model": model_state,
        "optimizer": optimizer_state,
        "scaler": scaler_state,
        "step": current_step,
        "epoch": epoch,
        "date": datetime.date.today().strftime("%B %d, %Y"),
    }
    # kwargs intentionally override default keys, matching dict.update semantics.
    state.update(kwargs)
    save_fsspec(state, output_path)
| |
|
| |
|
def save_checkpoint(
    config,
    model,
    optimizer,
    scaler,
    current_step,
    epoch,
    output_folder,
    **kwargs,
):
    """Write a step-numbered training checkpoint into ``output_folder``."""
    checkpoint_path = os.path.join(output_folder, f"checkpoint_{current_step}.pth")
    print(f"\n > CHECKPOINT : {checkpoint_path}")
    save_model(
        config,
        model,
        optimizer,
        scaler,
        current_step,
        epoch,
        checkpoint_path,
        **kwargs,
    )
| |
|
| |
|
def save_best_model(
    current_loss,
    best_loss,
    config,
    model,
    optimizer,
    scaler,
    current_step,
    epoch,
    out_path,
    keep_all_best=False,
    keep_after=10000,
    **kwargs,
):
    """Save the model whenever ``current_loss`` improves on ``best_loss``.

    Writes ``best_model_<step>.pth``, refreshes the stable
    ``best_model.pth`` shortcut, and prunes older best checkpoints unless
    ``keep_all_best`` is set and the step has passed ``keep_after``.

    Returns:
        The best loss seen so far (``current_loss`` when a save happened,
        otherwise ``best_loss`` unchanged).
    """
    if current_loss >= best_loss:
        # No improvement; keep the previous best untouched.
        return best_loss

    best_model_name = f"best_model_{current_step}.pth"
    checkpoint_path = os.path.join(out_path, best_model_name)
    print(" > BEST MODEL : {}".format(checkpoint_path))
    save_model(
        config,
        model,
        optimizer,
        scaler,
        current_step,
        epoch,
        checkpoint_path,
        model_loss=current_loss,
        **kwargs,
    )
    fs = fsspec.get_mapper(out_path).fs
    # Drop stale best checkpoints unless asked to keep them all — and even
    # then only start keeping them after `keep_after` steps.
    if not keep_all_best or current_step < keep_after:
        for old_name in fs.glob(os.path.join(out_path, "best_model*.pth")):
            if os.path.basename(old_name) != best_model_name:
                fs.rm(old_name)
    # Refresh the stable `best_model.pth` shortcut to point at the new best.
    fs.copy(checkpoint_path, os.path.join(out_path, "best_model.pth"))
    return current_loss
| |
|