# Author: allantacuelwvsu — "added evaluations" (commit 69535bd)
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import ImageFile
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Configuration
BATCH_SIZE = 64  # mini-batch size for both train and validation loaders
EPOCHS = 100  # total number of training epochs
IMG_SIZE = 48  # target square image side length, in pixels
MODEL_PATH = "data/models/expression_predictor_cnn.pth"  # where trained weights are written
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")
# Tolerate truncated/corrupt image files during decoding instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Data transforms.
# Training uses light augmentation (random flip / rotation / translation);
# validation uses ONLY the deterministic preprocessing so reported metrics are
# stable and comparable across epochs — evaluating through random augmentation
# was a bug. Normalize((0.5,), (0.5,)) maps [0, 1] pixels to [-1, 1].
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),  # use the shared constant, not a hard-coded 48
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
val_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
# Datasets and loaders (ImageFolder infers labels from subdirectory names)
train_dataset = datasets.ImageFolder("data/train", transform=transform)
val_dataset = datasets.ImageFolder("data/validation", transform=val_transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=0)  # no shuffle for eval
# Class names
CLASSES = train_dataset.classes
NUM_CLASSES = len(CLASSES)
print(f"Classes: {CLASSES}")
# CNN Model
# CNN Model
class ExpressionCNN(nn.Module):
    """Four-stage convolutional classifier for 1-channel face images.

    Each stage is Conv -> ReLU -> BatchNorm; the first three downsample with
    MaxPool, the last collapses spatial dims with AdaptiveAvgPool(1, 1) so the
    head is a single 256 -> num_classes linear layer (input size agnostic).
    """

    def __init__(self, num_classes=None):
        """Build the network.

        num_classes: number of output classes. Defaults to the module-level
        NUM_CLASSES (derived from the training dataset) so existing
        no-argument callers keep working; passing it explicitly makes the
        class usable without the global.
        """
        super().__init__()
        if num_classes is None:
            num_classes = NUM_CLASSES  # backward-compatible fallback to the dataset-derived global
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(), nn.BatchNorm2d(32), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.BatchNorm2d(64), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.BatchNorm2d(128), nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(), nn.BatchNorm2d(256), nn.AdaptiveAvgPool2d((1, 1))
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        """x: (B, 1, H, W) float tensor -> (B, num_classes) raw logits."""
        x = self.conv(x)
        x = self.fc(x)
        return x
# Model, loss, and optimizer: Adam at 1e-3, halving the learning rate every 5 epochs.
model = ExpressionCNN().to(DEVICE)
criterion = nn.CrossEntropyLoss()  # expects raw logits + integer class labels
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
# Training loop
# Per-epoch average losses; consumed by the loss plot at the end of the script.
train_loss_log = []
val_loss_log = []
for epoch in range(EPOCHS):
    print(f"\nStarting Epoch {epoch+1}/{EPOCHS}")

    # --- Train one epoch ---
    model.train()
    epoch_loss_sum = 0.0
    for batch_imgs, batch_labels in train_loader:
        batch_imgs = batch_imgs.to(DEVICE)
        batch_labels = batch_labels.to(DEVICE)
        optimizer.zero_grad()
        logits = model(batch_imgs)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()
        epoch_loss_sum += batch_loss.item()
    scheduler.step()  # decay the learning rate once per epoch
    train_loss = epoch_loss_sum / len(train_loader)
    train_loss_log.append(train_loss)

    # --- Validate (no gradients, eval-mode BatchNorm) ---
    model.eval()
    val_loss_sum, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for batch_imgs, batch_labels in val_loader:
            batch_imgs = batch_imgs.to(DEVICE)
            batch_labels = batch_labels.to(DEVICE)
            logits = model(batch_imgs)
            val_loss_sum += criterion(logits, batch_labels).item()
            preds = logits.argmax(dim=1)
            correct += (preds == batch_labels).sum().item()
            total += batch_labels.size(0)
    val_loss = val_loss_sum / len(val_loader)
    val_loss_log.append(val_loss)
    accuracy = correct / total * 100
    print(f"[{epoch+1}/{EPOCHS}] Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {accuracy:.2f}%")
# Save model — create the target directory first so torch.save cannot fail
# with FileNotFoundError on a fresh checkout where data/models/ doesn't exist.
Path(MODEL_PATH).parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)
print(f"✅ Model saved to {MODEL_PATH}")
# Plot train vs. validation loss per epoch and write it to disk.
plt.figure()
plt.plot(train_loss_log, label="Train")
plt.plot(val_loss_log, label="Validation")
plt.title("Loss Curve")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.grid()
plt.savefig("training_loss_plot.png")
plt.close()  # release the figure; nothing is shown interactively