Spaces:
Paused
Paused
| import random | |
| import torch | |
| import torch.nn.functional as F | |
def save_states(fname, model, optimizer, n_iter, epoch, net_config, config):
    """Write a training checkpoint with ``torch.save``.

    The checkpoint is a dict with these keys:
        'model'  -- ``model.state_dict()``
        'optim'  -- ``optimizer.state_dict()``
        'epoch'  -- ``epoch``
        'iter'   -- ``n_iter``
        'config' -- ``net_config``

    It is written to ``f'{config.checkpoint_dir}/{fname}'``. No directory is
    created here, so ``config.checkpoint_dir`` presumably must already exist.
    """
    checkpoint = dict(
        model=model.state_dict(),
        optim=optimizer.state_dict(),
        epoch=epoch,
        iter=n_iter,
        config=net_config,
    )
    destination = f'{config.checkpoint_dir}/{fname}'
    torch.save(checkpoint, destination)
def save_states_gan(fname, model, model_d, optimizer, optimizer_d,
                    n_iter, epoch, net_config, config):
    """Write a GAN training checkpoint (generator + discriminator).

    The checkpoint is a dict with these keys:
        'model'   -- ``model.state_dict()``
        'model_d' -- ``model_d.state_dict()``
        'optim'   -- ``optimizer.state_dict()``
        'optim_d' -- ``optimizer_d.state_dict()``
        'epoch'   -- ``epoch``
        'iter'    -- ``n_iter``
        'config'  -- ``net_config``

    It is written with ``torch.save`` to
    ``f'{config.checkpoint_dir}/{fname}'``. No directory is created here.
    """
    checkpoint = dict(
        model=model.state_dict(),
        model_d=model_d.state_dict(),
        optim=optimizer.state_dict(),
        optim_d=optimizer_d.state_dict(),
        epoch=epoch,
        iter=n_iter,
        config=net_config,
    )
    destination = f'{config.checkpoint_dir}/{fname}'
    torch.save(checkpoint, destination)
def batch_to_device(batch, device):
    """Move each tensor of a 5-element batch to ``device``.

    ``batch`` must unpack into exactly five items (text_padded,
    input_lengths, mel_padded, gate_padded, output_lengths); otherwise
    unpacking raises ``ValueError``. Each item is moved with
    ``.to(device, non_blocking=True)`` in that order, and the results are
    returned as a tuple in the same order.
    """
    # Explicit unpack enforces the exact 5-element shape of the batch.
    text, in_lens, mel, gate, out_lens = batch
    return tuple(
        tensor.to(device, non_blocking=True)
        for tensor in (text, in_lens, mel, gate, out_lens)
    )
def validate(model, test_loader, writer, device, n_iter):
    """Evaluate ``model`` on ``test_loader`` and log the results to ``writer``.

    For each batch the loss is
    ``mse(mel_out, mel) + mse(mel_out_postnet, mel) + bce_logits(gate_pred, gate)``.
    Batch losses are weighted by batch size to form the returned mean.
    Afterwards, one randomly chosen item from the *last* batch is run through
    ``model.infer`` and logged via ``writer.add_sample``. The mean loss is
    logged as the scalar ``'loss/val_loss'``.

    Evaluation runs under ``torch.no_grad()``. The model is put in eval mode
    for the duration and switched back to train mode afterwards, even if an
    error occurs.

    Args:
        model: called as ``model(text, input_lengths, mel, output_lengths)``
            and expected to return
            ``(mel_out, mel_out_postnet, gate_pred, alignments)``. It must
            also expose ``infer(text, input_lengths)``, whose first return
            value is indexed with ``[0]``.
        test_loader: iterable of batches accepted by ``batch_to_device``.
        writer: logger exposing ``add_sample`` and ``add_scalar``.
        device: device the batch tensors are moved to.
        n_iter: global step used when logging.

    Returns:
        float: batch-size-weighted mean validation loss.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    loss_sum = 0.0
    n_test_sum = 0
    model.eval()
    try:
        # Evaluation never calls backward(); without no_grad every forward
        # pass below would build and keep an autograd graph.
        with torch.no_grad():
            for batch in test_loader:
                text_padded, input_lengths, mel_padded, gate_padded, \
                    output_lengths = batch_to_device(batch, device)
                y_pred = model(text_padded, input_lengths,
                               mel_padded, output_lengths)
                mel_out, mel_out_postnet, gate_pred, alignments = y_pred
                mel_loss = F.mse_loss(mel_out, mel_padded) + \
                    F.mse_loss(mel_out_postnet, mel_padded)
                gate_loss = F.binary_cross_entropy_with_logits(
                    gate_pred, gate_padded)
                loss = mel_loss + gate_loss
                batch_size = mel_padded.size(0)
                loss_sum += batch_size * loss.item()
                n_test_sum += batch_size

            if n_test_sum == 0:
                # Otherwise this would fail below with ZeroDivisionError and
                # then NameError on the (never assigned) loop variables.
                raise ValueError('test_loader yielded no batches')

            val_loss = loss_sum / n_test_sum

            # The logged sample comes from the last batch seen in the loop.
            idx = random.randint(0, mel_padded.size(0) - 1)
            mel_infer, *_ = model.infer(
                text_padded[idx:idx + 1], input_lengths[idx:idx + 1])
            writer.add_sample(
                alignments[idx, :, :input_lengths[idx].item()],
                mel_out[idx], mel_padded[idx], mel_infer[0],
                output_lengths[idx], n_iter)
            writer.add_scalar('loss/val_loss', val_loss, n_iter)
    finally:
        # Always hand the model back in training mode, as the original did.
        model.train()
    return val_loss