import time
import copy

import torch

from data_loaders import load_dataset, get_dataset_sizes, get_dataloaders


def train_model(model, criterion, optimizer, scheduler, num_epochs, batch_size, device):
    """Train ``model`` and return it loaded with its best validation weights.

    Every epoch runs a "train" phase followed by a "validation" phase over the
    loaders returned by ``get_dataloaders``. Gradients are computed and the
    optimizer is stepped only during the "train" phase. The learning-rate
    scheduler is stepped once per epoch, after the "train" phase. The weights
    that achieve the highest validation accuracy are snapshotted and loaded
    back into the model before it is returned.

    Args:
        model: Module called as ``model(inputs)``; its output is passed to
            ``criterion`` and arg-maxed along dimension 1 for predictions.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor. The epoch loss weights each batch loss by the batch
            length, so this presumably expects a mean-reduced loss — confirm
            against the caller.
        optimizer: Optimizer providing ``zero_grad()`` and ``step()``.
        scheduler: Learning-rate scheduler providing ``step()``.
        num_epochs: Number of epochs to run.
        batch_size: Batch size forwarded to ``get_dataloaders``.
        device: Device the input and label batches are moved to.

    Returns:
        The same ``model`` object, with the best validation weights loaded.
    """
    since = time.time()

    # Load data
    data_set = load_dataset()
    dataset_sizes = get_dataset_sizes(data_set)
    dataloaders = get_dataloaders(data_set, batch_size)

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    # Infinity (not an arbitrary large number) so that any finite validation
    # loss, however large, is recorded as an improvement.
    best_loss = float("inf")

    print("Training started:")
    time_elapsed = time.time() - since
    print("Data Loading Completed in {:.0f}m {:.0f}s".format(time_elapsed // 60, time_elapsed % 60))

    for epoch in range(num_epochs):
        # Each epoch has a training and a validation phase
        for phase in ("train", "validation"):
            is_train = phase == "train"
            if is_train:
                model.train()
            else:
                model.eval()

            loader = dataloaders[phase]
            # Total number of iterations, used only for the progress display.
            try:
                n_batches = len(loader)
            except TypeError:
                # Loader has no length: fall back to ceiling division so a
                # trailing partial batch is counted exactly once.
                n_batches = -(-dataset_sizes[phase] // batch_size)

            running_loss = 0.0
            running_corrects = 0

            for it, (inputs, labels) in enumerate(loader, start=1):
                since_batch = time.time()
                batch_size_ = len(inputs)  # the last batch may be smaller
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Track gradients and take an optimization step only when training
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # Accumulate a sample-weighted loss so the epoch average is
                # not skewed by a smaller final batch.
                running_loss += loss.item() * batch_size_
                running_corrects += torch.sum(preds == labels).item()

                # "\r" keeps the progress report on a single console line
                print(
                    "Phase: {} Epoch: {}/{} Iter: {}/{} Batch time: {:.4f}".format(
                        phase,
                        epoch + 1,
                        num_epochs,
                        it,
                        n_batches,
                        time.time() - since_batch,
                    ),
                    end="\r",
                    flush=True,
                )

            # Print epoch results
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            print(
                "Phase: {} Epoch: {}/{} Loss: {:.4f} Acc: {:.4f} ".format(
                    "train" if is_train else "validation ",
                    epoch + 1,
                    num_epochs,
                    epoch_loss,
                    epoch_acc,
                )
            )

            if is_train:
                # Update learning rate once per epoch, after the training phase
                scheduler.step()
            else:
                # Keep the weights of the most accurate validation epoch so far
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                if epoch_loss < best_loss:
                    best_loss = epoch_loss

    # Restore the best weights and print final results
    model.load_state_dict(best_model_wts)
    time_elapsed = time.time() - since
    print("Training Completed in {:.0f}m {:.0f}s".format(time_elapsed // 60, time_elapsed % 60))
    print("Best test loss: {:.4f} | Best test accuracy: {:.4f}".format(best_loss, best_acc))
    return model