import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from PIL import ImageFile

# Configuration
BATCH_SIZE = 64
EPOCHS = 100
IMG_SIZE = 48  # images are resized to IMG_SIZE x IMG_SIZE grayscale
MODEL_PATH = "data/models/expression_predictor_cnn.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Tolerate truncated/partially-written image files instead of crashing mid-epoch.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Data transforms.
# FIX: the original applied random augmentation (flip/rotation/affine) to the
# validation set too, which makes validation loss/accuracy noisy and biased.
# Augmentation now applies only to training; validation is deterministic.
train_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),  # was hard-coded (48, 48); use the constant
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # map [0,1] -> [-1,1]
])
val_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Datasets and loaders (ImageFolder expects one subdirectory per class)
train_dataset = datasets.ImageFolder("data/train", transform=train_transform)
val_dataset = datasets.ImageFolder("data/validation", transform=val_transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=0)

# Class names (derived from the training directory layout)
CLASSES = train_dataset.classes
NUM_CLASSES = len(CLASSES)
print(f"Classes: {CLASSES}")


class ExpressionCNN(nn.Module):
    """Small 4-stage CNN for facial-expression classification.

    Input : (N, 1, 48, 48) normalized grayscale images.
    Output: (N, NUM_CLASSES) raw logits (CrossEntropyLoss applies softmax).
    """

    def __init__(self):
        super(ExpressionCNN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(), nn.BatchNorm2d(32), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.BatchNorm2d(64), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.BatchNorm2d(128), nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(), nn.BatchNorm2d(256),
            # Global average pooling makes the classifier head independent of
            # the remaining spatial size.
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256, NUM_CLASSES),
        )

    def forward(self, x):
        x = self.conv(x)
        x = self.fc(x)
        return x


model = ExpressionCNN().to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Halve the learning rate every 5 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
# Training loop
import os  # needed below to create the model output directory

train_loss_log = []
val_loss_log = []

for epoch in range(EPOCHS):
    print(f"\nStarting Epoch {epoch+1}/{EPOCHS}")

    # --- Train ---
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    scheduler.step()  # per-epoch LR decay (StepLR)

    train_loss = running_loss / len(train_loader)  # mean loss per batch
    train_loss_log.append(train_loss)

    # --- Validation (no gradients, eval mode freezes BatchNorm stats) ---
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    val_loss /= len(val_loader)
    val_loss_log.append(val_loss)
    accuracy = correct / total * 100
    print(f"[{epoch+1}/{EPOCHS}] Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {accuracy:.2f}%")

# Save model
# FIX: torch.save raises FileNotFoundError if the destination directory
# ("data/models") does not exist yet — create it first.
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)
print(f"✅ Model saved to {MODEL_PATH}")

# Plot loss
plt.plot(train_loss_log, label="Train")
plt.plot(val_loss_log, label="Validation")
plt.title("Loss Curve")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.grid()
plt.savefig("training_loss_plot.png")