Spaces:
Sleeping
Sleeping
| import time | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from matplotlib import pyplot as plt | |
| from torch import nn as nn, optim | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| import config | |
| from dataset import EmotionDataset | |
| from model import EmotionModel | |
| from utils import load_checkpoint, save_checkpoint | |
def train_fn(model, loader, opt, criterion, epoch, max_grad_norm=0.1, device=None):
    """Run one training epoch over ``loader`` and show running metrics in a tqdm bar.

    Args:
        model: Network to train; it is switched to training mode here.
        loader: Iterable of ``(image, label)`` batches.
        opt: Optimizer that updates ``model``'s parameters.
        criterion: Loss called as ``criterion(logits, label)``.
        epoch: Epoch index; only used for the progress-bar display.
        max_grad_norm: Threshold passed to ``clip_grad_norm_`` (default 0.1,
            the value previously hard-coded here).
        device: Device batches are moved to; defaults to ``config.DEVICE``.

    Returns:
        ``(mean_loss, accuracy)`` for the epoch: the mean of the per-batch
        losses and the fraction of correctly classified samples. Both are
        NaN if ``loader`` yields no batches.
    """
    if device is None:
        device = config.DEVICE
    loop = tqdm(loader, leave=True)
    model.train()
    epoch_loss = 0.0
    num_batches = 0
    # Accumulate over the whole epoch. These used to be reset inside the loop,
    # which made the displayed accuracy describe only the latest batch.
    total_acc, total_count = 0, 0
    for image, label in loop:
        image = image.to(device)
        label = label.to(device)

        opt.zero_grad()
        predicted_label = model(image)
        loss = criterion(predicted_label, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        opt.step()

        epoch_loss += loss.item()
        num_batches += 1
        # Labels are compared via argmax(1), so they carry a class dimension
        # (presumably one-hot) -- same convention as evaluate().
        total_acc += (predicted_label.argmax(1) == label.argmax(1)).sum().item()
        total_count += label.size(0)
        # Divide by batches seen so far (not len(loader)) so the running loss
        # is not understated early in the epoch.
        loop.set_postfix(
            {
                "epoch": epoch,
                "loss": epoch_loss / num_batches,
                "accuracy": total_acc / total_count,
            }
        )

    if num_batches == 0:
        return float("nan"), float("nan")
    return epoch_loss / num_batches, total_acc / total_count
def main():
    """Train ``EmotionModel`` and validate it after every epoch.

    For each of ``config.NUM_EPOCHS`` epochs, the loop does the following:
    runs ``train_fn`` over ``config.TRAIN_DIR``; runs ``evaluate`` over
    ``config.VAL_DIR``; steps ``ReduceLROnPlateau`` on the validation loss;
    prints a one-row summary table; and, when ``config.SAVE_MODEL`` is set,
    writes a checkpoint to ``config.CHECKPOINT``.
    """
    model = EmotionModel().to(config.DEVICE)
    opt = optim.Adam(model.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999))
    criterion = nn.CrossEntropyLoss()
    # Resuming from a checkpoint is disabled here; uncomment to enable it.
    # if config.LOAD_MODEL:
    #     load_checkpoint(config.CHECKPOINT, model, opt, config.LEARNING_RATE)

    train_loader = DataLoader(
        EmotionDataset(root_dir=config.TRAIN_DIR),
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
    )
    # Evaluation order is irrelevant, so don't shuffle: fixed batches keep the
    # reported validation metrics reproducible.
    val_loader = DataLoader(
        EmotionDataset(root_dir=config.VAL_DIR),
        batch_size=config.BATCH_SIZE,
        shuffle=False,
        num_workers=config.NUM_WORKERS,
    )

    # `verbose=True` is deprecated in recent PyTorch releases; learning-rate
    # reductions are reported explicitly in the loop below instead.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, patience=2)
    # Column widths match the four fields of the per-epoch summary row.
    border = "+" + "-" * 19 + "+" + "-" * 15 + "+" + "-" * 20 + "+" + "-" * 24 + "+"

    for epoch in range(config.NUM_EPOCHS):
        epoch_start_time = time.time()
        train_fn(model, train_loader, opt, criterion, epoch)
        accu_val, loss_val = evaluate(model, criterion, val_loader)

        lrs_before = [group["lr"] for group in opt.param_groups]
        scheduler.step(loss_val)
        lrs_after = [group["lr"] for group in opt.param_groups]
        if lrs_after != lrs_before:
            print(f"Epoch {epoch}: learning rate reduced to {lrs_after}")

        print(border)
        print(
            "| end of epoch: {:3d} | time: {:6.2f}s | val_loss: {:8.3f} | "
            "val_accuracy: {:8.3f} |".format(
                epoch, time.time() - epoch_start_time, loss_val, accu_val
            )
        )
        print(border)

        if config.SAVE_MODEL:
            save_checkpoint(model, opt, filename=config.CHECKPOINT)
def test(image_path="images/validation/angry/245.jpg"):
    """Classify a single image with ``EmotionModel`` and display it.

    Loads weights from ``config.CHECKPOINT`` when ``config.LOAD_MODEL`` is
    set. Prints the validation dataset's ``class_to_idx`` mapping and the
    predicted class index, then shows the grayscale input image.

    Args:
        image_path: Image to classify. It defaults to the sample that was
            previously hard-coded here.
    """
    model = EmotionModel().to(config.DEVICE)
    # load_checkpoint takes an optimizer (presumably to restore its state);
    # it is otherwise unused for inference.
    opt = optim.Adam(model.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999))
    if config.LOAD_MODEL:
        load_checkpoint(config.CHECKPOINT, model, opt, config.LEARNING_RATE)
    model.eval()

    val_dataset = EmotionDataset(root_dir=config.VAL_DIR)
    print(val_dataset.class_to_idx)

    # Close the file handle as soon as the pixels are copied out.
    with Image.open(image_path) as img:
        image = np.array(img.convert("L"))
    # A 2-D array is otherwise drawn with the default (false-colour) colormap.
    plt.imshow(image, cmap="gray")

    tensor = config.transform(image=image)["image"]
    tensor = torch.unsqueeze(tensor.to(config.DEVICE), dim=0)  # add batch dimension
    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        score = model(tensor)
    print(torch.argmax(score).item())
    plt.show()
def evaluate(model, criterion, dataloader, device=None):
    """Compute the accuracy and per-sample mean loss of ``model`` over ``dataloader``.

    Args:
        model: Network to evaluate. It runs in eval mode under
            ``torch.no_grad()``, and its previous train/eval mode is restored
            afterwards.
        criterion: Loss called as ``criterion(outputs, labels)``. It is
            assumed to use mean reduction (the ``nn.CrossEntropyLoss``
            default) -- TODO confirm for other criteria.
        dataloader: Iterable of ``(inputs, labels)`` batches. Labels are
            compared via ``argmax(1)``, so they must have a class dimension
            (presumably one-hot).
        device: Device batches are moved to; defaults to ``config.DEVICE``.

    Returns:
        ``(accuracy, average_loss)``, both averaged over samples.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    if device is None:
        device = config.DEVICE
    was_training = model.training
    model.eval()
    total_correct = 0
    total_samples = 0
    total_loss = 0.0
    try:
        with torch.no_grad():
            for inputs, labels in dataloader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                batch_size = labels.size(0)
                # criterion returns the batch mean; weight it by batch size so
                # a short final batch does not skew the overall average.
                total_loss += criterion(outputs, labels).item() * batch_size
                _, predicted = torch.max(outputs, 1)
                total_correct += (predicted == labels.argmax(1)).sum().item()
                total_samples += batch_size
    finally:
        model.train(was_training)

    if total_samples == 0:
        raise ValueError("evaluate() received a dataloader with no samples")
    return total_correct / total_samples, total_loss / total_samples
if __name__ == "__main__":
    # Script entry point: runs the single-image inference demo.
    # Call main() here instead to train the model.
    test()