lord-reso committed
Commit 7071ffd · verified · 1 parent: cac86d7

Removed unused functions

Files changed (1): train.py (+1, -201)
train.py CHANGED
@@ -1,212 +1,12 @@
-import os
-import time
-import argparse
-import math
-from numpy import finfo
-
 import torch
-from torch.utils.data import DataLoader
+from numpy import finfo
 
 from model import Tacotron2
-from data_utils import TextMelLoader, TextMelCollate
-from loss_function import Tacotron2Loss
-from logger import Tacotron2Logger
 from hparams import create_hparams
 
-
-def prepare_dataloaders(hparams):
-    # Get data, data loaders, and collate function ready
-    trainset = TextMelLoader(hparams.training_files, hparams)
-    valset = TextMelLoader(hparams.validation_files, hparams)
-    collate_fn = TextMelCollate(hparams.n_frames_per_step)
-
-    train_loader = DataLoader(trainset, num_workers=1, shuffle=True,
-                              batch_size=hparams.batch_size, collate_fn=collate_fn)
-    return train_loader, valset, collate_fn
-
-
-def prepare_directories_and_logger(output_directory, log_directory):
-    if not os.path.isdir(output_directory):
-        os.makedirs(output_directory)
-        os.chmod(output_directory, 0o775)
-    logger = Tacotron2Logger(os.path.join(output_directory, log_directory))
-    return logger
-
-
 def load_model(hparams):
     model = Tacotron2(hparams).float()
     if hparams.fp16_run:
         model.decoder.attention_layer.score_mask_value = finfo('float16').min
 
     return model
-
-
-def warm_start_model(checkpoint_path, model, ignore_layers):
-    assert os.path.isfile(checkpoint_path)
-    print("Warm starting model from checkpoint '{}'".format(checkpoint_path))
-    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
-    model_dict = checkpoint_dict['state_dict']
-    if len(ignore_layers) > 0:
-        model_dict = {k: v for k, v in model_dict.items()
-                      if k not in ignore_layers}
-        dummy_dict = model.state_dict()
-        dummy_dict.update(model_dict)
-        model_dict = dummy_dict
-    model.load_state_dict(model_dict)
-    return model
-
-
-def load_checkpoint(checkpoint_path, model, optimizer):
-    assert os.path.isfile(checkpoint_path)
-    print("Loading checkpoint '{}'".format(checkpoint_path))
-    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
-    model.load_state_dict(checkpoint_dict['state_dict'])
-    optimizer.load_state_dict(checkpoint_dict['optimizer'])
-    learning_rate = checkpoint_dict['learning_rate']
-    iteration = checkpoint_dict['iteration']
-    print("Loaded checkpoint '{}' from iteration {}".format(
-        checkpoint_path, iteration))
-    return model, optimizer, learning_rate, iteration
-
-
-def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
-    print("Saving model and optimizer state at iteration {} to {}".format(
-        iteration, filepath))
-    torch.save({'iteration': iteration,
-                'state_dict': model.state_dict(),
-                'optimizer': optimizer.state_dict(),
-                'learning_rate': learning_rate}, filepath)
-
-
-def validate(model, criterion, valset, iteration, batch_size,
-             collate_fn, logger):
-    """Handles all the validation scoring and printing"""
-    model.eval()
-    with torch.no_grad():
-        val_loader = DataLoader(valset, num_workers=1, shuffle=False,
-                                batch_size=batch_size, collate_fn=collate_fn)
-
-        val_loss = 0.0
-        for i, batch in enumerate(val_loader):
-            x, y = model.parse_batch(batch)
-            y_pred = model(x)
-            loss = criterion(y_pred, y)
-            reduced_val_loss = loss.item()
-            val_loss += reduced_val_loss
-        val_loss = val_loss / (i + 1)
-
-    model.train()
-    print("Validation loss {}: {:9f}".format(iteration, val_loss))
-    logger.log_validation(val_loss, model, y, y_pred, iteration)
-
-
-def train(output_directory, log_directory, checkpoint_path, warm_start,
-          hparams):
-    """Training and validation logging results to tensorboard and stdout
-
-    Params
-    ------
-    output_directory (string): directory to save checkpoints
-    log_directory (string): directory to save tensorboard logs
-    checkpoint_path (string): checkpoint path
-    hparams (object): comma-separated list of "name=value" pairs.
-    """
-    torch.manual_seed(hparams.seed)
-
-    model = load_model(hparams)
-    learning_rate = hparams.learning_rate
-    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
-                                 weight_decay=hparams.weight_decay)
-
-    if hparams.fp16_run:
-        from apex import amp
-        model, optimizer = amp.initialize(model, optimizer, opt_level='O2')
-
-    criterion = Tacotron2Loss()
-
-    logger = prepare_directories_and_logger(
-        output_directory, log_directory)
-
-    train_loader, valset, collate_fn = prepare_dataloaders(hparams)
-
-    # Load checkpoint if one exists
-    iteration = 0
-    if checkpoint_path is not None:
-        if warm_start:
-            model = warm_start_model(checkpoint_path, model, hparams.ignore_layers)
-        else:
-            model, optimizer, _learning_rate, iteration = load_checkpoint(
-                checkpoint_path, model, optimizer)
-            if hparams.use_saved_learning_rate:
-                learning_rate = _learning_rate
-            iteration += 1  # next iteration is iteration + 1
-
-    model.train()
-    is_overflow = False
-    # ================ MAIN TRAINING LOOP! ===================
-    for epoch in range(hparams.epochs):
-        print("Epoch: {}".format(epoch))
-        for i, batch in enumerate(train_loader):
-            start = time.perf_counter()
-            for param_group in optimizer.param_groups:
-                param_group['lr'] = learning_rate
-
-            model.zero_grad()
-            x, y = model.parse_batch(batch)
-            y_pred = model(x)
-
-            loss = criterion(y_pred, y)
-            reduced_loss = loss.item()
-            if hparams.fp16_run:
-                with amp.scale_loss(loss, optimizer) as scaled_loss:
-                    scaled_loss.backward()
-            else:
-                loss.backward()
-
-            grad_norm = torch.nn.utils.clip_grad_norm_(
-                model.parameters(), hparams.grad_clip_thresh)
-
-            optimizer.step()
-
-            if not is_overflow:
-                duration = time.perf_counter() - start
-                print("Train loss {} {:.6f} Grad Norm {:.6f} {:.2f}s/it".format(
-                    iteration, reduced_loss, grad_norm, duration))
-                logger.log_training(
-                    reduced_loss, grad_norm, learning_rate, duration, iteration)
-
-            if not is_overflow and (iteration % hparams.iters_per_checkpoint == 0):
-                validate(model, criterion, valset, iteration,
-                         hparams.batch_size, collate_fn, logger)
-                checkpoint_path = os.path.join(
-                    output_directory, "checkpoint_{}".format(iteration))
-                save_checkpoint(model, optimizer, learning_rate, iteration,
-                                checkpoint_path)
-
-            iteration += 1
-
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument('-o', '--output_directory', type=str,
-                        help='directory to save checkpoints')
-    parser.add_argument('-l', '--log_directory', type=str,
-                        help='directory to save tensorboard logs')
-    parser.add_argument('-c', '--checkpoint_path', type=str, default=None,
-                        required=False, help='checkpoint path')
-    parser.add_argument('--warm_start', action='store_true',
-                        help='load model weights only, ignore specified layers')
-    parser.add_argument('--hparams', type=str,
-                        required=False, help='comma-separated name=value pairs')
-
-    args = parser.parse_args()
-    hparams = create_hparams(args.hparams)
-
-    torch.backends.cudnn.enabled = hparams.cudnn_enabled
-    torch.backends.cudnn.benchmark = hparams.cudnn_benchmark
-
-    print("FP16 Run:", hparams.fp16_run)
-    print("Dynamic Loss Scaling:", hparams.dynamic_loss_scaling)
-
-    train(args.output_directory, args.log_directory, args.checkpoint_path,
-          args.warm_start, hparams)
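
After this commit, train.py exposes only load_model; checkpoint restoration, validation, and the CLI entry point are gone, so callers must restore weights themselves. Below is a minimal sketch of that, assuming a checkpoint written by the removed save_checkpoint (which stored weights under the 'state_dict' key); the checkpoint filename is hypothetical, not part of this commit.

import torch

from hparams import create_hparams
from train import load_model

# Build the model the same way the remaining load_model does.
hparams = create_hparams()
model = load_model(hparams)

# 'checkpoint_10000' is a hypothetical name following the removed
# "checkpoint_{iteration}" convention; torch.load and load_state_dict
# are standard PyTorch calls, mirroring the removed load_checkpoint.
checkpoint = torch.load('checkpoint_10000', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
model.eval()  # switch to inference mode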