niobures's picture
RNNoise (models)
2e62044 verified
import torch
from utils.paths import Paths
from models.tacotron import Tacotron
def get_checkpoint_paths(checkpoint_type: str, paths: Paths):
    """Return the checkpointing paths for the requested model type.

    Args:
        checkpoint_type: Either 'voc' (vocoder) or 'tts'.
        paths: Paths object exposing the `tts_*` / `voc_*` attributes read
            below.

    Returns:
        A `(weights_path, optim_path, checkpoint_path)` tuple taken from
        `paths`: the latest weights file, the latest optimizer-state file
        and the checkpoints path for the given model type.

    Raises:
        NotImplementedError: If `checkpoint_type` is neither 'tts' nor 'voc'.
    """
    # Compare with `==`, not `is`: two equal strings are not guaranteed to be
    # the same object, so an identity test rejects valid values that were
    # built at runtime (e.g. parsed from the command line).
    if checkpoint_type == 'tts':
        weights_path = paths.tts_latest_weights
        optim_path = paths.tts_latest_optim
        checkpoint_path = paths.tts_checkpoints
    elif checkpoint_type == 'voc':
        weights_path = paths.voc_latest_weights
        optim_path = paths.voc_latest_optim
        checkpoint_path = paths.voc_checkpoints
    else:
        raise NotImplementedError(
            f'Unknown checkpoint type {checkpoint_type!r}; '
            "expected 'tts' or 'voc'.")

    return weights_path, optim_path, checkpoint_path
def save_checkpoint(checkpoint_type: str, paths: Paths, model, optimizer, *,
                    name=None, is_silent=False):
    """Writes the current training session (weights + optimizer) to disk.

    The "latest" weights and optimizer files given by `paths` are always
    refreshed. When `name` is supplied, a named checkpoint is written in
    addition to that update.

    Args:
        checkpoint_type: Either 'voc' or 'tts'; forwarded to
            `get_checkpoint_paths` to select the set of paths to use.
        paths: Provides information about the different paths to use.
        model: A `Tacotron` or `WaveRNN` model; its `save(path)` method is
            called to write the weights.
        optimizer: An optimizer whose `state_dict()` (momentum, etc.) is
            saved.
        name: If provided, also writes `<name>_weights.pyt` and
            `<name>_optim.pyt` under the checkpoints path.
        is_silent: If true, no progress messages are printed.

    Raises:
        FileNotFoundError: If exactly one of the two files of a checkpoint
            already exists (the checkpoint is considered broken).
    """
    def announce(message):
        # Single place that honours `is_silent`.
        if not is_silent:
            print(message)

    def write_pair(label, weights_file, optim_file):
        # A checkpoint is a pair of files: either both exist or neither does.
        present = [f for f in (weights_file, optim_file) if f.exists()]
        if len(present) == 1:
            raise FileNotFoundError(
                f'We expected either both or no files in the {label} checkpoint to '
                'exist, but instead we got exactly one!')

        if present:
            announce(f'Saving to existing {label} checkpoint...')
        else:
            announce(f'Creating {label} checkpoint...')
            weights_file.parent.mkdir(parents=True, exist_ok=True)
            optim_file.parent.mkdir(parents=True, exist_ok=True)

        announce(f'Saving {label} weights: {weights_file}')
        model.save(weights_file)
        announce(f'Saving {label} optimizer state: {optim_file}')
        torch.save(optimizer.state_dict(), optim_file)

    latest_weights, latest_optim, named_dir = get_checkpoint_paths(
        checkpoint_type, paths)

    # The latest checkpoint is updated unconditionally.
    write_pair('latest', latest_weights, latest_optim)

    if not name:
        return

    write_pair(
        'named',
        named_dir/f'{name}_weights.pyt',
        named_dir/f'{name}_optim.pyt',
    )
def restore_checkpoint(checkpoint_type: str, paths: Paths, model, optimizer, *,
                       name=None, create_if_missing=False):
    """Restores from a training session saved to disk.

    NOTE: The optimizer's state is placed on the same device as its model
    parameters. Therefore, be sure you have done `model.to(device)` before
    calling this method.

    Args:
        checkpoint_type: Either 'voc' or 'tts'; forwarded to
            `get_checkpoint_paths` to select the set of paths to use.
        paths: Provides information about the different paths to use.
        model: A `Tacotron` or `WaveRNN` model to load the weights into
            (via its `load(path)` method).
        optimizer: An optimizer to load the saved state (momentum, etc.)
            into.
        name: If provided, will restore from a checkpoint with the given name.
            Otherwise, will restore from the latest weights and optimizer state
            as specified in `paths`.
        create_if_missing: If `True`, will create the checkpoint if it doesn't
            yet exist, as well as update the files specified in `paths` that
            give the location of the current latest weights and optimizer state.
            If `False` and the checkpoint doesn't exist, will raise a
            `FileNotFoundError`.

    Raises:
        FileNotFoundError: If the two files of the checkpoint are not both
            present and `create_if_missing` is `False`. With
            `create_if_missing=True`, errors raised by `save_checkpoint`
            propagate unchanged.
    """
    weights_path, optim_path, checkpoint_path = \
        get_checkpoint_paths(checkpoint_type, paths)

    if name:
        path_dict = {
            'w': checkpoint_path/f'{name}_weights.pyt',
            'o': checkpoint_path/f'{name}_optim.pyt',
        }
        s = 'named'
    else:
        path_dict = {
            'w': weights_path,
            'o': optim_path
        }
        s = 'latest'

    num_exist = sum(p.exists() for p in path_dict.values())
    if num_exist == 2:
        # Checkpoint exists
        print(f'Restoring from {s} checkpoint...')
        print(f'Loading {s} weights: {path_dict["w"]}')
        model.load(path_dict['w'])
        print(f'Loading {s} optimizer state: {path_dict["o"]}')
        # Deserialize onto the CPU first so that a state saved on a GPU can
        # still be restored where that device is unavailable;
        # `load_state_dict` then places the state on the device of the
        # model parameters (see NOTE above).
        optimizer.load_state_dict(
            torch.load(path_dict['o'], map_location='cpu'))
    elif create_if_missing:
        save_checkpoint(checkpoint_type, paths, model, optimizer, name=name, is_silent=False)
    else:
        raise FileNotFoundError(f'The {s} checkpoint could not be found!')