File size: 5,137 Bytes
2e62044
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import torch
from utils.paths import Paths
from models.tacotron import Tacotron


def get_checkpoint_paths(checkpoint_type: str, paths: 'Paths'):
    """Return the checkpointing paths for either the TTS model or the vocoder.

    Args:
        checkpoint_type: Either 'voc' (vocoder) or 'tts'.
        paths: Paths object; the ``{tts,voc}_latest_weights``,
            ``{tts,voc}_latest_optim`` and ``{tts,voc}_checkpoints``
            attributes are read from it.

    Returns:
        A ``(weights_path, optim_path, checkpoint_path)`` tuple.

    Raises:
        NotImplementedError: If ``checkpoint_type`` is neither 'tts' nor 'voc'.
    """
    # Use ``==`` rather than ``is``: whether two equal strings are the same
    # object is an interning detail, so ``is`` can reject a valid value that
    # was built at runtime (e.g. parsed from the command line).
    if checkpoint_type == 'tts':
        weights_path = paths.tts_latest_weights
        optim_path = paths.tts_latest_optim
        checkpoint_path = paths.tts_checkpoints
    elif checkpoint_type == 'voc':
        weights_path = paths.voc_latest_weights
        optim_path = paths.voc_latest_optim
        checkpoint_path = paths.voc_checkpoints
    else:
        raise NotImplementedError(
            f"Unknown checkpoint_type {checkpoint_type!r}; expected 'tts' or 'voc'")

    return weights_path, optim_path, checkpoint_path


def save_checkpoint(checkpoint_type: str, paths: Paths, model, optimizer, *,
                    name=None, is_silent=False):
    """Write the model weights and optimizer state to disk.

    The "latest" checkpoint files for ``checkpoint_type`` are always written.
    When ``name`` is given, a named checkpoint (``<name>_weights.pyt`` and
    ``<name>_optim.pyt`` inside the checkpoint directory) is written in
    addition to that update.

    Args:
        checkpoint_type: Either 'voc' or 'tts'; selects which set of paths
            from ``paths`` is used.
        paths: Provides information about the different paths to use.
        model: A `Tacotron` or `WaveRNN` model; its ``save`` method is called
            with the weights path.
        optimizer: An optimizer whose ``state_dict()`` (momentum, etc.) is
            saved with ``torch.save``.
        name: If provided, also write a checkpoint with the given name.
        is_silent: If True, suppress the progress messages.

    Raises:
        FileNotFoundError: If exactly one of a checkpoint's two files
            (weights / optimizer) already exists, i.e. the checkpoint looks
            broken.
    """
    def report(msg):
        if not is_silent:
            print(msg)

    def write(files, label):
        present = [path.exists() for path in files.values()]
        if any(present) and not all(present):
            # Only one file of the pair is on disk: the checkpoint is broken.
            raise FileNotFoundError(
                f'We expected either both or no files in the {label} checkpoint to '
                'exist, but instead we got exactly one!')

        if any(present):
            report(f'Saving to existing {label} checkpoint...')
        else:
            report(f'Creating {label} checkpoint...')
            for path in files.values():
                path.parent.mkdir(parents=True, exist_ok=True)

        report(f'Saving {label} weights: {files["w"]}')
        model.save(files['w'])
        report(f'Saving {label} optimizer state: {files["o"]}')
        torch.save(optimizer.state_dict(), files['o'])

    weights_path, optim_path, checkpoint_path = get_checkpoint_paths(checkpoint_type, paths)

    write({'w': weights_path, 'o': optim_path}, 'latest')

    if name:
        write({'w': checkpoint_path / f'{name}_weights.pyt',
               'o': checkpoint_path / f'{name}_optim.pyt'}, 'named')


def restore_checkpoint(checkpoint_type: str, paths: Paths, model, optimizer, *,
                       name=None, create_if_missing=False):
    """Restore model weights and optimizer state from a checkpoint on disk.

    NOTE: The optimizer state is deserialized onto the CPU and then passed to
    ``optimizer.load_state_dict``, which places it on the same device as the
    corresponding model parameters. Therefore, be sure you have done
    ``model.to(device)`` before calling this function.

    Args:
        checkpoint_type: Either 'voc' or 'tts'; selects which set of paths
            from ``paths`` is used.
        paths: Provides information about the different paths to use.
        model: A `Tacotron` or `WaveRNN` model; its ``load`` method is called
            with the weights path.
        optimizer: An optimizer whose state (momentum, etc.) is restored.
        name: If provided, restore from the checkpoint with the given name.
            Otherwise, restore from the latest weights and optimizer state
            as specified in `paths`.
        create_if_missing: If `True` and the checkpoint doesn't exist, create
            it from the current model and optimizer via `save_checkpoint`,
            which also updates the latest files specified in `paths`.
            If `False` and the checkpoint doesn't exist, raise
            `FileNotFoundError`.

    Raises:
        FileNotFoundError: If the checkpoint (or one of its two files) is
            missing and ``create_if_missing`` is False.
    """
    weights_path, optim_path, checkpoint_path = \
        get_checkpoint_paths(checkpoint_type, paths)

    if name:
        path_dict = {
            'w': checkpoint_path/f'{name}_weights.pyt',
            'o': checkpoint_path/f'{name}_optim.pyt',
        }
        s = 'named'
    else:
        path_dict = {
            'w': weights_path,
            'o': optim_path
        }
        s = 'latest'

    num_exist = sum(p.exists() for p in path_dict.values())
    if num_exist == 2:
        # Checkpoint exists
        print(f'Restoring from {s} checkpoint...')
        print(f'Loading {s} weights: {path_dict["w"]}')
        model.load(path_dict['w'])
        print(f'Loading {s} optimizer state: {path_dict["o"]}')
        # Without map_location, tensors saved from a GPU are deserialized back
        # onto that exact GPU, which fails on CPU-only machines or different
        # device layouts. Load to CPU; load_state_dict then moves each state
        # tensor to its parameter's device.
        optimizer.load_state_dict(torch.load(path_dict['o'], map_location='cpu'))
    elif create_if_missing:
        save_checkpoint(checkpoint_type, paths, model, optimizer, name=name, is_silent=False)
    else:
        raise FileNotFoundError(f'The {s} checkpoint could not be found!')