|
|
import torch
|
|
|
from utils.paths import Paths
|
|
|
from models.tacotron import Tacotron
|
|
|
|
|
|
|
|
|
def get_checkpoint_paths(checkpoint_type: str, paths: "Paths"):
    """
    Returns the correct checkpointing paths
    depending on whether model is Vocoder or TTS.

    Args:
        checkpoint_type: Either 'voc' or 'tts'.
        paths: Paths object.

    Returns:
        A tuple ``(weights_path, optim_path, checkpoint_path)`` taken from the
        ``tts_*`` or ``voc_*`` attributes of ``paths`` respectively.

    Raises:
        NotImplementedError: If ``checkpoint_type`` is neither 'tts' nor 'voc'.
    """
    # Use `==`, not `is`: string identity depends on interning, so a value
    # built at runtime (e.g. parsed from argv or a config file) would not match.
    if checkpoint_type == 'tts':
        return paths.tts_latest_weights, paths.tts_latest_optim, paths.tts_checkpoints
    if checkpoint_type == 'voc':
        return paths.voc_latest_weights, paths.voc_latest_optim, paths.voc_checkpoints
    raise NotImplementedError(
        f"Unknown checkpoint_type {checkpoint_type!r}; expected 'tts' or 'voc'")
|
|
|
|
|
|
|
|
|
def save_checkpoint(checkpoint_type: str, paths: Paths, model, optimizer, *,
                    name=None, is_silent=False):
    """Saves the training session to disk.

    Args:
        checkpoint_type: Either 'voc' or 'tts'; selects which set of paths in
            `paths` is used (resolved via `get_checkpoint_paths`).
        paths: Provides information about the different paths to use.
        model: A `Tacotron` or `WaveRNN` model to save the parameters and
            buffers from. It must provide a `save(path)` method.
        optimizer: An optimizer to save the state of (momentum, etc).
        name: If provided, will additionally save a named checkpoint with the
            given name inside the checkpoint directory. Note that regardless
            of whether this is provided or not, this function will always
            update the files specified in `paths` that give the location of
            the latest weights and optimizer state. Saving a named checkpoint
            happens in addition to this update.
        is_silent: If `True`, no progress messages are printed.

    Raises:
        FileNotFoundError: If a checkpoint is in an inconsistent state, i.e.
            only one of its weights/optimizer files already exists.
        NotImplementedError: If `checkpoint_type` is not recognised.
    """
    def log(msg):
        if not is_silent:
            print(msg)

    def helper(path_dict, is_named):
        s = 'named' if is_named else 'latest'
        num_exist = sum(p.exists() for p in path_dict.values())

        # A checkpoint is only consistent if all or none of its files exist.
        if 0 < num_exist < len(path_dict):
            raise FileNotFoundError(
                f'We expected either both or no files in the {s} checkpoint to '
                'exist, but instead we got exactly one!')

        if num_exist == 0:
            log(f'Creating {s} checkpoint...')
            for p in path_dict.values():
                p.parent.mkdir(parents=True, exist_ok=True)
        else:
            log(f'Saving to existing {s} checkpoint...')

        log(f'Saving {s} weights: {path_dict["w"]}')
        model.save(path_dict['w'])
        log(f'Saving {s} optimizer state: {path_dict["o"]}')
        torch.save(optimizer.state_dict(), path_dict['o'])

    weights_path, optim_path, checkpoint_path = \
        get_checkpoint_paths(checkpoint_type, paths)

    # The "latest" checkpoint is always refreshed.
    latest_paths = {'w': weights_path, 'o': optim_path}
    helper(latest_paths, False)

    if name:
        named_paths = {
            'w': checkpoint_path/f'{name}_weights.pyt',
            'o': checkpoint_path/f'{name}_optim.pyt',
        }
        helper(named_paths, True)
|
|
|
|
|
|
|
|
|
def restore_checkpoint(checkpoint_type: str, paths: Paths, model, optimizer, *,
                       name=None, create_if_missing=False):
    """Restores from a training session saved to disk.

    NOTE: The optimizer's state is placed on the same device as its model
    parameters. Therefore, be sure you have done `model.to(device)` before
    calling this method.

    Args:
        checkpoint_type: Either 'voc' or 'tts'; selects which set of paths in
            `paths` is used (resolved via `get_checkpoint_paths`).
        paths: Provides information about the different paths to use.
        model: A `Tacotron` or `WaveRNN` model to load the parameters and
            buffers into. It must provide a `load(path)` method.
        optimizer: An optimizer to restore the state of (momentum, etc).
        name: If provided, will restore from a checkpoint with the given name.
            Otherwise, will restore from the latest weights and optimizer state
            as specified in `paths`.
        create_if_missing: If `True`, will create the checkpoint if it doesn't
            yet exist, as well as update the files specified in `paths` that
            give the location of the current latest weights and optimizer state.
            If `False` and the checkpoint doesn't exist, will raise a
            `FileNotFoundError`.

    Raises:
        FileNotFoundError: If the checkpoint is incomplete or missing and
            `create_if_missing` is `False`.
        NotImplementedError: If `checkpoint_type` is not recognised.
    """
    weights_path, optim_path, checkpoint_path = \
        get_checkpoint_paths(checkpoint_type, paths)

    if name:
        path_dict = {
            'w': checkpoint_path/f'{name}_weights.pyt',
            'o': checkpoint_path/f'{name}_optim.pyt',
        }
        s = 'named'
    else:
        path_dict = {
            'w': weights_path,
            'o': optim_path,
        }
        s = 'latest'

    num_exist = sum(p.exists() for p in path_dict.values())
    if num_exist == 2:
        print(f'Restoring from {s} checkpoint...')
        print(f'Loading {s} weights: {path_dict["w"]}')
        model.load(path_dict['w'])
        print(f'Loading {s} optimizer state: {path_dict["o"]}')
        # Deserialize onto the CPU first. `Optimizer.load_state_dict` then
        # moves each state tensor to its parameter's device. This also works
        # when the checkpoint was written on a GPU that is not available now.
        optimizer.load_state_dict(torch.load(path_dict['o'], map_location='cpu'))
    elif create_if_missing:
        save_checkpoint(checkpoint_type, paths, model, optimizer, name=name, is_silent=False)
    else:
        raise FileNotFoundError(f'The {s} checkpoint could not be found!')