# face_eMEOWtion / train.py
# Author: Tanishq
# (Hugging Face upload metadata: "Upload 7 files", commit 2972f68, verified)
import time
import numpy as np
import torch
from PIL import Image
from matplotlib import pyplot as plt
from torch import nn as nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import config
from dataset import EmotionDataset
from model import EmotionModel
from utils import load_checkpoint, save_checkpoint
def train_fn(model, loader, opt, criterion, epoch):
    """Run one training epoch over `loader`.

    Args:
        model: network being trained (caller is expected to have moved it
            to config.DEVICE already).
        loader: DataLoader yielding (image, label) batches; labels are
            one-hot encoded (both sides are argmax'd for accuracy below).
        opt: optimizer stepping the model parameters.
        criterion: loss function applied to (logits, labels).
        epoch: current epoch index, displayed in the progress bar.
    """
    loop = tqdm(loader, leave=True)
    model.train()
    epoch_loss = 0.0
    # Accumulate accuracy across the whole epoch. (Previously these were
    # reset every batch, so the bar mixed per-batch accuracy with a
    # running-average loss.)
    total_acc, total_count = 0, 0
    for idx, (image, label) in enumerate(loop):
        image = image.to(config.DEVICE)
        label = label.to(config.DEVICE)
        opt.zero_grad()
        predicted_label = model(image)
        loss = criterion(predicted_label, label)
        epoch_loss += loss.item()
        loss.backward()
        # Clip gradient norm to stabilize training.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        opt.step()
        # Labels are one-hot, hence argmax on both predictions and targets.
        total_acc += (predicted_label.argmax(1) == label.argmax(1)).sum().item()
        total_count += label.size(0)
        loop.set_postfix({
            "epoch": epoch,
            # Divide by batches seen so far (idx + 1), not len(loader),
            # so the displayed loss is a true running average mid-epoch.
            "loss": epoch_loss / (idx + 1),
            "accuracy": total_acc / total_count,
        })
def main():
    """Train EmotionModel for config.NUM_EPOCHS epochs.

    Builds model/optimizer/data loaders from `config`, optionally resumes
    from a checkpoint, trains with `train_fn`, validates with `evaluate`,
    steps an LR scheduler on the validation loss, and checkpoints after
    each epoch when config.SAVE_MODEL is set.
    """
    model = EmotionModel().to(config.DEVICE)
    opt = optim.Adam(model.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999))
    criterion = nn.CrossEntropyLoss()
    # Resume from a saved checkpoint when requested (mirrors test(), which
    # already honors this flag).
    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT, model, opt, config.LEARNING_RATE,
        )
    train_dataset = EmotionDataset(root_dir=config.TRAIN_DIR)
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
    )
    val_dataset = EmotionDataset(root_dir=config.VAL_DIR)
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=False,  # metric computation does not depend on sample order
        num_workers=config.NUM_WORKERS,
    )
    # Reduce the learning rate when validation loss plateaus for 2 epochs.
    # NOTE: the deprecated `verbose=True` flag was dropped; inspect
    # opt.param_groups[0]["lr"] if LR logging is needed.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, patience=2)
    separator = "+" + "-" * 19 + "+" + "-" * 15 + "+" + "-" * 20 + "+" + "-" * 24 + "+"
    for epoch in range(config.NUM_EPOCHS):
        epoch_start_time = time.time()
        train_fn(
            model, train_loader, opt, criterion, epoch
        )
        accu_val, loss_val = evaluate(model, criterion, val_loader)
        scheduler.step(loss_val)  # plateau scheduler consumes the metric
        print(separator)
        print(
            "| end of epoch: {:3d} | time: {:6.2f}s | val_loss: {:8.3f} | "
            "val_accuracy: {:8.3f} |".format(
                epoch, time.time() - epoch_start_time, loss_val, accu_val
            )
        )
        print(separator)
        if config.SAVE_MODEL:
            save_checkpoint(model, opt, filename=config.CHECKPOINT)
def test(image_path="images/validation/angry/245.jpg"):
    """Predict the emotion class of a single image and display it.

    Args:
        image_path: path to an image file (converted to grayscale before
            inference). Defaults to the original hard-coded sample so the
            no-argument call behaves exactly as before.
    """
    model = EmotionModel().to(config.DEVICE)
    opt = optim.Adam(model.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999))
    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT, model, opt, config.LEARNING_RATE,
        )
    model.eval()
    # Print the label-name -> index mapping so the argmax printed below
    # can be interpreted by the user.
    val_dataset = EmotionDataset(root_dir=config.VAL_DIR)
    print(val_dataset.class_to_idx)
    image = np.array(Image.open(image_path).convert('L'))
    plt.imshow(image)
    # config.transform is albumentations-style: takes image=..., returns a dict.
    image = config.transform(image=image)["image"]
    image = image.to(config.DEVICE)
    image = torch.unsqueeze(image, dim=0)  # add batch dimension
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        score = model(image)
    print(torch.argmax(score))
    plt.show()
def evaluate(model, criterion, dataloader):
    """Evaluate `model` on `dataloader` without gradient tracking.

    Args:
        model: network to evaluate; switched to eval mode here.
        criterion: loss function applied to (logits, labels).
        dataloader: yields (inputs, labels) batches; labels are one-hot
            encoded (argmax'd for the accuracy comparison).

    Returns:
        (accuracy, average_loss): fraction of correct predictions over all
        samples, and mean loss over all batches.

    Raises:
        ValueError: if `dataloader` yields no batches (the division below
        would otherwise fail with an opaque ZeroDivisionError).
    """
    model.eval()
    total_correct = 0
    total_samples = 0
    total_loss = 0.0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(config.DEVICE), labels.to(config.DEVICE)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            # One-hot labels: compare predicted class index to label argmax.
            total_correct += (predicted == labels.argmax(1)).sum().item()
            total_samples += labels.size(0)
    if total_samples == 0:
        raise ValueError("evaluate() received an empty dataloader")
    accuracy = total_correct / total_samples
    average_loss = total_loss / len(dataloader)
    return accuracy, average_loss
# Script entry point: runs single-image inference; call main() to train instead.
if __name__ == "__main__":
    test()