"""Train, validate and run single-image inference for the emotion classifier.

``main`` runs the training loop, ``evaluate`` scores a model on a data loader
and ``test`` classifies one image from disk. Running this file as a script
currently calls ``test``.
"""

import time

import numpy as np
import torch
from PIL import Image
from matplotlib import pyplot as plt
from torch import nn as nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm

import config
from dataset import EmotionDataset
from model import EmotionModel
from utils import load_checkpoint, save_checkpoint


def train_fn(model, loader, opt, criterion, epoch):
    """Train ``model`` for one pass over ``loader``.

    Progress is reported on a tqdm bar showing the running mean loss and the
    running accuracy over the batches processed so far in this epoch.

    Args:
        model: Network to optimise; it is switched to training mode.
        loader: Iterable of ``(image, label)`` batches. Labels are expected to
            be 2-D per-class targets (e.g. one-hot), because the class index
            is recovered with ``label.argmax(1)``.
        opt: Optimizer stepping ``model``'s parameters.
        criterion: Loss called as ``criterion(prediction, label)``.
        epoch: Epoch number, used only for the progress-bar display.
    """
    loop = tqdm(loader, leave=True)
    model.train()

    epoch_loss = 0.0
    # Accumulate across the whole epoch. Resetting these per batch would make
    # the reported accuracy reflect only the most recent batch.
    total_acc, total_count = 0, 0

    for idx, (image, label) in enumerate(loop):
        image = image.to(config.DEVICE)
        label = label.to(config.DEVICE)

        opt.zero_grad()
        predicted_label = model(image)
        loss = criterion(predicted_label, label)
        epoch_loss += loss.item()
        loss.backward()
        # Cap the global gradient norm at 0.1 before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        opt.step()

        total_acc += (predicted_label.argmax(1) == label.argmax(1)).sum().item()
        total_count += label.size(0)

        loop.set_postfix(
            {
                "epoch": epoch,
                # Mean over the batches seen so far; on the last batch this
                # equals epoch_loss / len(loader).
                "loss": epoch_loss / (idx + 1),
                "accuracy": total_acc / total_count,
            }
        )


def _make_loader(root_dir, shuffle):
    """Build a ``DataLoader`` over the ``EmotionDataset`` rooted at ``root_dir``."""
    dataset = EmotionDataset(root_dir=root_dir)
    return DataLoader(
        dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=shuffle,
        num_workers=config.NUM_WORKERS,
    )


def main():
    """Train the emotion model and validate it after every epoch.

    For ``config.NUM_EPOCHS`` epochs this trains on ``config.TRAIN_DIR``,
    evaluates on ``config.VAL_DIR``, feeds the validation loss to a
    ``ReduceLROnPlateau`` scheduler, prints a summary row, and saves a
    checkpoint to ``config.CHECKPOINT`` when ``config.SAVE_MODEL`` is set.
    """
    model = EmotionModel().to(config.DEVICE)
    opt = optim.Adam(
        model.parameters(),
        lr=config.LEARNING_RATE,
        betas=(0.5, 0.999),
    )
    criterion = nn.CrossEntropyLoss()

    # NOTE: resuming from a checkpoint is disabled here. To re-enable it, call
    # load_checkpoint(config.CHECKPOINT, model, opt, config.LEARNING_RATE)
    # when config.LOAD_MODEL is set, as test() does.

    train_loader = _make_loader(config.TRAIN_DIR, shuffle=True)
    # Accuracy and mean loss do not depend on sample order, so validation
    # batches are not shuffled.
    val_loader = _make_loader(config.VAL_DIR, shuffle=False)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        opt, patience=2, verbose=True
    )

    border = "+" + "-" * 19 + "+" + "-" * 15 + "+" + "-" * 20 + "+" + "-" * 24 + "+"

    for epoch in range(config.NUM_EPOCHS):
        epoch_start_time = time.time()
        train_fn(model, train_loader, opt, criterion, epoch)
        accu_val, loss_val = evaluate(model, criterion, val_loader)
        # The scheduler monitors validation loss to decide when to lower the LR.
        scheduler.step(loss_val)

        elapsed = time.time() - epoch_start_time
        print(border)
        print(
            f"| end of epoch: {epoch:3d} | time: {elapsed:6.2f}s | "
            f"val_loss: {loss_val:8.3f} | val_accuracy: {accu_val:8.3f} |"
        )
        print(border)

        # Checkpoint after every epoch so an interrupted run keeps its
        # most recent weights.
        if config.SAVE_MODEL:
            save_checkpoint(model, opt, filename=config.CHECKPOINT)


def test(image_path="images/validation/angry/245.jpg"):
    """Classify a single image and display it.

    Builds the model (loading ``config.CHECKPOINT`` when ``config.LOAD_MODEL``
    is set), prints the validation dataset's ``class_to_idx`` mapping, shows
    the greyscale image with matplotlib and prints the index of the highest
    scoring class. To score the whole validation set instead, use
    ``evaluate``.

    Args:
        image_path: Path of the image to classify. Defaults to the sample
            validation image this script has always used.
    """
    model = EmotionModel().to(config.DEVICE)
    # The optimizer is created only because load_checkpoint takes one.
    opt = optim.Adam(
        model.parameters(),
        lr=config.LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT,
            model,
            opt,
            config.LEARNING_RATE,
        )

    val_dataset = EmotionDataset(root_dir=config.VAL_DIR)
    model.eval()
    print(val_dataset.class_to_idx)

    # Open inside a context manager so the file handle is released as soon as
    # the pixel data has been copied into the array.
    with Image.open(image_path) as img:
        image = np.array(img.convert("L"))
    plt.imshow(image)

    image = config.transform(image=image)["image"]
    image = image.to(config.DEVICE)
    image = torch.unsqueeze(image, dim=0)  # add a leading batch dimension

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        score = model(image)
    print(torch.argmax(score))
    plt.show()


def evaluate(model, criterion, dataloader):
    """Compute accuracy and mean batch loss of ``model`` over ``dataloader``.

    The model is put in eval mode and gradients are disabled. It is not
    switched back to training mode afterwards.

    Args:
        model: Network to evaluate.
        criterion: Loss called as ``criterion(outputs, labels)``.
        dataloader: Iterable of ``(inputs, labels)`` batches that also
            supports ``len()``. Labels are expected to be 2-D per-class
            targets, because the class index is taken with
            ``labels.argmax(1)``.

    Returns:
        ``(accuracy, average_loss)``: the fraction of correctly classified
        samples, and the summed batch losses divided by the number of batches.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no samples.
    """
    model.eval()
    total_correct = 0
    total_samples = 0
    total_loss = 0.0

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(config.DEVICE)
            labels = labels.to(config.DEVICE)

            outputs = model(inputs)
            total_loss += criterion(outputs, labels).item()

            predicted = outputs.argmax(1)
            total_correct += (predicted == labels.argmax(1)).sum().item()
            total_samples += labels.size(0)

    accuracy = total_correct / total_samples
    average_loss = total_loss / len(dataloader)
    return accuracy, average_loss


if __name__ == "__main__":
    # Runs single-image inference; call main() instead to train.
    test()