Removed unused functions

train.py CHANGED

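This commit cuts train.py from 212 lines to 12: the data-loader setup, directory/logger preparation, warm start, checkpoint save/load, validation, the main training loop, and the CLI entry point are all removed, leaving only the load_model helper.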
@@ -1,212 +1,12 @@
-import os
-import time
-import argparse
-import math
-from numpy import finfo
-
 import torch
-from torch.utils.data import DataLoader
+from numpy import finfo
 
 from model import Tacotron2
-from data_utils import TextMelLoader, TextMelCollate
-from loss_function import Tacotron2Loss
-from logger import Tacotron2Logger
 from hparams import create_hparams
 
-
-def prepare_dataloaders(hparams):
-    # Get data, data loaders, and collate function ready
-    trainset = TextMelLoader(hparams.training_files, hparams)
-    valset = TextMelLoader(hparams.validation_files, hparams)
-    collate_fn = TextMelCollate(hparams.n_frames_per_step)
-
-    train_loader = DataLoader(trainset, num_workers=1, shuffle=True,
-                              batch_size=hparams.batch_size, collate_fn=collate_fn)
-    return train_loader, valset, collate_fn
-
-
-def prepare_directories_and_logger(output_directory, log_directory):
-    if not os.path.isdir(output_directory):
-        os.makedirs(output_directory)
-        os.chmod(output_directory, 0o775)
-    logger = Tacotron2Logger(os.path.join(output_directory, log_directory))
-    return logger
-
-
 def load_model(hparams):
     model = Tacotron2(hparams).float()
     if hparams.fp16_run:
         model.decoder.attention_layer.score_mask_value = finfo('float16').min
 
     return model
-
-
-def warm_start_model(checkpoint_path, model, ignore_layers):
-    assert os.path.isfile(checkpoint_path)
-    print("Warm starting model from checkpoint '{}'".format(checkpoint_path))
-    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
-    model_dict = checkpoint_dict['state_dict']
-    if len(ignore_layers) > 0:
-        model_dict = {k: v for k, v in model_dict.items()
-                      if k not in ignore_layers}
-        dummy_dict = model.state_dict()
-        dummy_dict.update(model_dict)
-        model_dict = dummy_dict
-    model.load_state_dict(model_dict)
-    return model
-
-
-def load_checkpoint(checkpoint_path, model, optimizer):
-    assert os.path.isfile(checkpoint_path)
-    print("Loading checkpoint '{}'".format(checkpoint_path))
-    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
-    model.load_state_dict(checkpoint_dict['state_dict'])
-    optimizer.load_state_dict(checkpoint_dict['optimizer'])
-    learning_rate = checkpoint_dict['learning_rate']
-    iteration = checkpoint_dict['iteration']
-    print("Loaded checkpoint '{}' from iteration {}".format(
-        checkpoint_path, iteration))
-    return model, optimizer, learning_rate, iteration
-
-
-def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
-    print("Saving model and optimizer state at iteration {} to {}".format(
-        iteration, filepath))
-    torch.save({'iteration': iteration,
-                'state_dict': model.state_dict(),
-                'optimizer': optimizer.state_dict(),
-                'learning_rate': learning_rate}, filepath)
-
-
-def validate(model, criterion, valset, iteration, batch_size,
-             collate_fn, logger):
-    """Handles all the validation scoring and printing"""
-    model.eval()
-    with torch.no_grad():
-        val_loader = DataLoader(valset, num_workers=1, shuffle=False,
-                                batch_size=batch_size, collate_fn=collate_fn)
-
-        val_loss = 0.0
-        for i, batch in enumerate(val_loader):
-            x, y = model.parse_batch(batch)
-            y_pred = model(x)
-            loss = criterion(y_pred, y)
-            reduced_val_loss = loss.item()
-            val_loss += reduced_val_loss
-        val_loss = val_loss / (i + 1)
-
-    model.train()
-    print("Validation loss {}: {:9f}".format(iteration, val_loss))
-    logger.log_validation(val_loss, model, y, y_pred, iteration)
-
-
-def train(output_directory, log_directory, checkpoint_path, warm_start,
-          hparams):
-    """Training and validation logging results to tensorboard and stdout
-
-    Params
-    ------
-    output_directory (string): directory to save checkpoints
-    log_directory (string): directory to save tensorboard logs
-    checkpoint_path (string): checkpoint path
-    hparams (object): comma-separated list of "name=value" pairs.
-    """
-    torch.manual_seed(hparams.seed)
-
-    model = load_model(hparams)
-    learning_rate = hparams.learning_rate
-    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
-                                 weight_decay=hparams.weight_decay)
-
-    if hparams.fp16_run:
-        from apex import amp
-        model, optimizer = amp.initialize(model, optimizer, opt_level='O2')
-
-    criterion = Tacotron2Loss()
-
-    logger = prepare_directories_and_logger(
-        output_directory, log_directory)
-
-    train_loader, valset, collate_fn = prepare_dataloaders(hparams)
-
-    # Load checkpoint if one exists
-    iteration = 0
-    if checkpoint_path is not None:
-        if warm_start:
-            model = warm_start_model(checkpoint_path, model, hparams.ignore_layers)
-        else:
-            model, optimizer, _learning_rate, iteration = load_checkpoint(
-                checkpoint_path, model, optimizer)
-            if hparams.use_saved_learning_rate:
-                learning_rate = _learning_rate
-            iteration += 1  # next iteration is iteration + 1
-
-    model.train()
-    is_overflow = False
-    # ================ MAIN TRAINING LOOP! ===================
-    for epoch in range(hparams.epochs):
-        print("Epoch: {}".format(epoch))
-        for i, batch in enumerate(train_loader):
-            start = time.perf_counter()
-            for param_group in optimizer.param_groups:
-                param_group['lr'] = learning_rate
-
-            model.zero_grad()
-            x, y = model.parse_batch(batch)
-            y_pred = model(x)
-
-            loss = criterion(y_pred, y)
-            reduced_loss = loss.item()
-            if hparams.fp16_run:
-                with amp.scale_loss(loss, optimizer) as scaled_loss:
-                    scaled_loss.backward()
-            else:
-                loss.backward()
-
-            grad_norm = torch.nn.utils.clip_grad_norm_(
-                model.parameters(), hparams.grad_clip_thresh)
-
-            optimizer.step()
-
-            if not is_overflow:
-                duration = time.perf_counter() - start
-                print("Train loss {} {:.6f} Grad Norm {:.6f} {:.2f}s/it".format(
-                    iteration, reduced_loss, grad_norm, duration))
-                logger.log_training(
-                    reduced_loss, grad_norm, learning_rate, duration, iteration)
-
-            if not is_overflow and (iteration % hparams.iters_per_checkpoint == 0):
-                validate(model, criterion, valset, iteration,
-                         hparams.batch_size, collate_fn, logger)
-                checkpoint_path = os.path.join(
-                    output_directory, "checkpoint_{}".format(iteration))
-                save_checkpoint(model, optimizer, learning_rate, iteration,
-                                checkpoint_path)
-
-            iteration += 1
-
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument('-o', '--output_directory', type=str,
-                        help='directory to save checkpoints')
-    parser.add_argument('-l', '--log_directory', type=str,
-                        help='directory to save tensorboard logs')
-    parser.add_argument('-c', '--checkpoint_path', type=str, default=None,
-                        required=False, help='checkpoint path')
-    parser.add_argument('--warm_start', action='store_true',
-                        help='load model weights only, ignore specified layers')
-    parser.add_argument('--hparams', type=str,
-                        required=False, help='comma-separated name=value pairs')

-    args = parser.parse_args()
-    hparams = create_hparams(args.hparams)
-
-    torch.backends.cudnn.enabled = hparams.cudnn_enabled
-    torch.backends.cudnn.benchmark = hparams.cudnn_benchmark
-
-    print("FP16 Run:", hparams.fp16_run)
-    print("Dynamic Loss Scaling:", hparams.dynamic_loss_scaling)
-
-    train(args.output_directory, args.log_directory, args.checkpoint_path,
-          args.warm_start, hparams)
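For reference, a minimal sketch of how the surviving code might be used. It assumes the NVIDIA Tacotron2 layout this file follows, where create_hparams accepts an optional "name=value" string (so calling it with no arguments should work) and checkpoints store weights under the 'state_dict' key, as the removed save_checkpoint/load_checkpoint did. The checkpoint filename below is hypothetical; with load_checkpoint gone, restoring weights is now the caller's job.

```python
import torch

from hparams import create_hparams
from train import load_model

# Build the bare Tacotron2 model; this is all train.py still does.
hparams = create_hparams()
model = load_model(hparams)

# load_checkpoint was removed, so weights must be restored manually. The
# removed save_checkpoint stored them under 'state_dict'; the filename here
# (matching its "checkpoint_{iteration}" pattern) is only an example.
checkpoint = torch.load("checkpoint_50000", map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])
model.eval()  # inference mode: disables dropout
```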