| import time | |
| import argparse | |
| import numpy as np | |
| import torch | |
| import tqdm | |
| from torch import optim | |
| from torch.utils.data import DataLoader | |
| from data_proc.cross_entropy_dataset import FBanksCrossEntropyDataset | |
| from models.cross_entropy_model import FBankCrossEntropyNetV2 | |
| from utils.pt_util import restore_objects, save_model, save_objects, restore_model | |
| from trainer.cross_entropy_train import train, test | |
def main(args):
    """Train FBankCrossEntropyNetV2 on filter-bank datasets, resuming from a checkpoint when present.

    Args:
        args: Parsed command-line namespace. Reads ``num_layers``, ``train_folder``,
            ``test_folder``, ``batch_size``, ``lr`` and ``epochs``.

    Side effects:
        Prints per-epoch metrics. Whenever test accuracy improves, writes the model
        and the training history (epoch, best accuracy, loss/accuracy lists) under
        ``saved_models_cross_entropy/<num_layers>/`` via ``save_model`` /
        ``save_objects``.
    """
    import multiprocessing

    model_path = f"saved_models_cross_entropy/{args.num_layers}/"

    # Derive CUDA usage from the hardware instead of hard-coding it: pinned host
    # memory only speeds up host->GPU copies, and requesting it without a GPU
    # makes DataLoader emit a warning.
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    print('using device', device)

    num_cpus = multiprocessing.cpu_count()
    print('num cpus:', num_cpus)
    # Worker processes are used on CPU and GPU alike (as before); only memory
    # pinning depends on CUDA being present.
    loader_kwargs = {'num_workers': num_cpus, 'pin_memory': use_cuda}

    train_dataset = FBanksCrossEntropyDataset(args.train_folder)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, **loader_kwargs)
    test_dataset = FBanksCrossEntropyDataset(args.test_folder)
    # Evaluation gains nothing from shuffling; a fixed order keeps test runs reproducible.
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, **loader_kwargs)

    model = FBankCrossEntropyNetV2(num_layers=args.num_layers, reduction='mean').to(device)
    model = restore_model(model, model_path)
    # The default tuple stands in for "nothing restored" (presumably what
    # restore_objects returns when no checkpoint exists yet - confirm in utils.pt_util).
    (last_epoch, max_accuracy, train_losses, test_losses,
     train_accuracies, test_accuracies) = restore_objects(model_path, (0, 0, [], [], [], []))
    # Checkpoints below are written only on a strict improvement over the running
    # max_accuracy (which starts at 0), so max_accuracy > 0 signals a restored run.
    # The restored checkpoint is the *best* epoch, not necessarily the last one trained.
    start = last_epoch + 1 if max_accuracy > 0 else 0

    # NOTE(review): optimizer state is not part of the checkpoint, so Adam's
    # moment estimates restart from scratch when resuming.
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    for epoch in range(start, args.epochs):
        # NOTE(review): 500 is forwarded to train() unchanged; presumably a logging interval - confirm.
        train_loss, train_accuracy = train(model, device, train_loader, optimizer, epoch, 500)
        test_loss, test_accuracy = test(model, device, test_loader)
        print('After epoch: {}, train_loss: {}, test loss is: {}, train_accuracy: {}, '
              'test_accuracy: {}'.format(epoch, train_loss, test_loss, train_accuracy, test_accuracy))

        train_losses.append(train_loss)
        test_losses.append(test_loss)
        train_accuracies.append(train_accuracy)
        test_accuracies.append(test_accuracy)

        # Persist only on improvement, so the checkpoint always holds the best model so far.
        if test_accuracy > max_accuracy:
            max_accuracy = test_accuracy
            save_model(model, epoch, model_path)
            save_objects((epoch, max_accuracy, train_losses, test_losses, train_accuracies, test_accuracies), epoch, model_path)
            print('saved epoch: {} as checkpoint'.format(epoch))
if __name__ == '__main__':
    # Script entry point: declare every CLI option from one table, then hand the
    # parsed namespace to main(). All options have defaults, so no flag is required.
    parser = argparse.ArgumentParser(description='FBank Cross Entropy Training Script')
    cli_options = (
        # (flag, type, default, help text)
        ('--num_layers', int, 2, 'Number of layers in the model'),
        ('--train_folder', str, 'fbanks_train', 'Training dataset folder'),
        ('--test_folder', str, 'fbanks_test', 'Testing dataset folder'),
        ('--epochs', int, 20, 'Number of epochs to train'),
        ('--batch_size', int, 64, 'Batch size for training'),
        ('--lr', float, 0.0005, 'Learning rate for the optimizer'),
    )
    for flag, value_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    args = parser.parse_args()
    main(args)