File size: 3,776 Bytes
52109c6 69535bd 52109c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import ImageFile
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Configuration
# Training hyper-parameters and output location for the trained weights.
BATCH_SIZE = 64
EPOCHS = 100
IMG_SIZE = 48
MODEL_PATH = "data/models/expression_predictor_cnn.pth"

# Run on the GPU when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

# Let PIL decode image files whose data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Data transforms
# Training images get light random augmentation (flip, small rotation,
# small translation). Validation images only get the deterministic
# preprocessing, so validation loss/accuracy are comparable between epochs
# and are not perturbed by random augmentation.
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
val_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Datasets and loaders
# ImageFolder derives one class per sub-directory of the given root.
train_dataset = datasets.ImageFolder("data/train", transform=transform)
val_dataset = datasets.ImageFolder("data/validation", transform=val_transform)
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0
)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=0)

# Class names
CLASSES = train_dataset.classes
NUM_CLASSES = len(CLASSES)
print(f"Classes: {CLASSES}")
# CNN Model
class ExpressionCNN(nn.Module):
    """Small convolutional classifier for single-channel images.

    The feature extractor is four 3x3 convolution stages (32, 64, 128 and
    256 output channels), each followed by ReLU and batch normalisation.
    The first three stages halve the spatial size with max pooling; the
    last stage collapses it to 1x1 with adaptive average pooling, so the
    linear head always receives 256 features.

    Args:
        num_classes: Number of output logits. When omitted, the module-level
            ``NUM_CLASSES`` is used, so ``ExpressionCNN()`` behaves exactly
            as before. Passing it explicitly lets the model be built without
            loading the training dataset first.
    """

    def __init__(self, num_classes=None):
        super().__init__()
        if num_classes is None:
            # Backward-compatible default: class count derived from the
            # training dataset at module level.
            num_classes = NUM_CLASSES
        # Layer order and the attribute names ``conv`` / ``fc`` determine
        # the state_dict keys, so they are kept unchanged.
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return the class logits for a batch of single-channel images."""
        x = self.conv(x)
        x = self.fc(x)
        return x
# Network, loss function, optimiser and learning-rate schedule.
model = ExpressionCNN().to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Halve the learning rate every 5 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
# Training loop
# Mean per-batch loss of every epoch, kept for the loss-curve plot.
train_loss_log = []
val_loss_log = []
for epoch in range(EPOCHS):
    print(f"\nStarting Epoch {epoch+1}/{EPOCHS}")

    # ---- optimisation pass over the training set ----
    model.train()
    loss_sum = 0.0
    for batch_x, batch_y in train_loader:
        batch_x = batch_x.to(DEVICE)
        batch_y = batch_y.to(DEVICE)
        optimizer.zero_grad()
        batch_loss = criterion(model(batch_x), batch_y)
        batch_loss.backward()
        optimizer.step()
        loss_sum += batch_loss.item()
    # The learning-rate schedule advances once per epoch.
    scheduler.step()
    train_loss = loss_sum / len(train_loader)
    train_loss_log.append(train_loss)

    # ---- Validation ----
    model.eval()
    val_loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_x, batch_y in val_loader:
            batch_x = batch_x.to(DEVICE)
            batch_y = batch_y.to(DEVICE)
            logits = model(batch_x)
            val_loss_sum += criterion(logits, batch_y).item()
            # Predicted class = index of the largest logit in each row.
            guesses = logits.max(dim=1).indices
            n_correct += (guesses == batch_y).sum().item()
            n_seen += batch_y.size(0)
    val_loss = val_loss_sum / len(val_loader)
    val_loss_log.append(val_loss)
    accuracy = n_correct / n_seen * 100
    print(f"[{epoch+1}/{EPOCHS}] Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {accuracy:.2f}%")
# Save model
# Create the target directory first: torch.save does not create missing
# parent directories, and failing here would discard the whole training run.
model_dir = os.path.dirname(MODEL_PATH)
if model_dir:
    os.makedirs(model_dir, exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)
print(f"✅ Model saved to {MODEL_PATH}")
# Plot loss
# Draw the per-epoch training and validation loss on one set of axes and
# write the figure to disk.
for curve, curve_label in (
    (train_loss_log, "Train"),
    (val_loss_log, "Validation"),
):
    plt.plot(curve, label=curve_label)
plt.title("Loss Curve")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.grid()
plt.savefig("training_loss_plot.png")