# Emotion-engine / train.py
# Author: Niranjan-ninja
# Commit: Deploy: AIIO Expression Analyzer - clean build (no venv)
# Revision: 3ee866e
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import numpy as np
import cv2
from skimage.feature import hog
import kagglehub
from core_model import DCNN_BiLSTM_DAM
class FER2013HOGDataset(Dataset):
    """Wrap an image dataset so every sample comes back as a HOG-visualisation tensor.

    Each item of the wrapped dataset is assumed to be a ``(PIL image, label)``
    pair, e.g. a ``torchvision.datasets.ImageFolder`` created without a
    transform. The HOG image is computed on the fly for every access.

    Args:
        image_folder_dataset: indexable dataset yielding ``(PIL image, label)``.
        apply_augmentations: when truthy, apply random flip / rotation / colour
            jitter to the PIL image *before* the HOG step.
    """

    def __init__(self, image_folder_dataset, apply_augmentations=False):
        self.dataset = image_folder_dataset
        self.apply_augmentations = apply_augmentations
        # Augmentations act on the raw PIL image, so they must run before HOG.
        self.aug_transform = (
            transforms.Compose([
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(15),
                transforms.ColorJitter(brightness=0.2, contrast=0.2),
            ])
            if apply_augmentations
            else None
        )

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def _to_hog_tensor(pil_image):
        """Grayscale -> 64x64 resize -> HOG visualisation -> max-normalised (1, 64, 64) float32 tensor."""
        grayscale = np.array(pil_image.convert('L'))
        small = cv2.resize(grayscale, (64, 64))
        # Only the rendered HOG image is used; the feature vector is discarded.
        _, hog_map = hog(
            small,
            orientations=8,
            pixels_per_cell=(8, 8),
            cells_per_block=(2, 2),
            visualize=True,
        )
        # Epsilon keeps an all-zero HOG map from dividing by zero.
        scaled = hog_map / (hog_map.max() + 1e-5)
        # Prepend a channel axis -> [1, 64, 64].
        return torch.tensor(scaled, dtype=torch.float32)[None, ...]

    def __getitem__(self, idx):
        pil_image, label = self.dataset[idx]
        if self.aug_transform is not None:
            pil_image = self.aug_transform(pil_image)
        return self._to_hog_tensor(pil_image), label
def _train_one_epoch(model, loader, criterion, optimizer, device, log_every=100):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Prints the loss of every ``log_every``-th batch. Returns 0.0 for an empty
    loader instead of dividing by zero.
    """
    model.train()
    running_loss = 0.0
    num_batches = len(loader)
    for i, (images, labels) in enumerate(loader):
        images, labels = images.to(device), labels.to(device)
        loss = criterion(model(images), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Read the scalar once and reuse it for both the sum and the log line.
        batch_loss = loss.item()
        running_loss += batch_loss
        if i % log_every == 0:
            print(f" Batch {i}/{num_batches} Loss: {batch_loss:.4f}")
    return running_loss / num_batches if num_batches else 0.0


@torch.no_grad()
def _evaluate(model, loader, device):
    """Return top-1 accuracy (percent) of ``model`` over ``loader``; 0.0 if it is empty."""
    model.eval()
    correct = 0
    total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        predicted = model(images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    return 100.0 * correct / total if total else 0.0


def run_fer2013_training(*, epochs=12, batch_size=64, learning_rate=0.001,
                         num_workers=4, output_dir="./models"):
    """Download FER-2013, train DCNN_BiLSTM_DAM on HOG images, checkpoint the best epoch.

    All arguments are keyword-only; the defaults reproduce the original setup.

    Args:
        epochs: number of training epochs; also the cosine-annealing period (T_max).
        batch_size: mini-batch size for both the train and validation loaders.
        learning_rate: initial AdamW learning rate.
        num_workers: DataLoader worker processes per loader.
        output_dir: directory that receives ``best_model_dcnn_dam.pth``
            (created if missing).

    Returns:
        The best validation accuracy (percent) seen across all epochs.

    Raises:
        ValueError: if the train and test folders yield different class-to-index maps.

    NOTE: the Kaggle "test" split doubles as the validation set used to pick the
    checkpoint, so the reported best accuracy is optimistic rather than a
    held-out test score.
    """
    print("\n--- Initializing FER-2013 HOG DCNN-BiLSTM-DAM Engine ---")
    # 1. Download/locate FER-2013 via KaggleHub.
    print("Fetching FER-2013 Dataset from Kaggle...")
    dataset_path = kagglehub.dataset_download("msambare/fer2013")
    print(f"Dataset located at: {dataset_path}")

    # ImageFolder indexes class folders in sorted name order:
    # 'angry': 0, 'disgust': 1, 'fear': 2, 'happy': 3, 'neutral': 4, 'sad': 5, 'surprise': 6
    raw_train = datasets.ImageFolder(os.path.join(dataset_path, "train"))
    raw_test = datasets.ImageFolder(os.path.join(dataset_path, "test"))
    # Each split derives its indices independently, so a class folder missing
    # from one split would silently shift every later label. Refuse to train then.
    if raw_train.class_to_idx != raw_test.class_to_idx:
        raise ValueError(
            "Train/test class mismatch: "
            f"{raw_train.class_to_idx} vs {raw_test.class_to_idx}"
        )
    print(f"Classes correctly mapped: {raw_train.classes}")

    # Wrap with on-the-fly HOG; only the training split is augmented.
    train_dataset = FER2013HOGDataset(raw_train, apply_augmentations=True)
    val_dataset = FER2013HOGDataset(raw_test, apply_augmentations=False)

    device = torch.device(
        'cuda' if torch.cuda.is_available()
        else 'mps' if torch.backends.mps.is_available()
        else 'cpu'
    )
    print(f"Hardware Backbone: {device}")

    # Pinned host memory is only useful for host -> CUDA copies.
    loader_kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": device.type == "cuda",
    }
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    model = DCNN_BiLSTM_DAM(num_classes=7).to(device)
    # Label smoothing & AdamW for robustness.
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    print(f"Training on {len(train_dataset)} images. Validating on {len(val_dataset)} images.")
    os.makedirs(output_dir, exist_ok=True)
    checkpoint_path = os.path.join(output_dir, "best_model_dcnn_dam.pth")

    best_val_acc = 0.0
    for epoch in range(1, epochs + 1):
        epoch_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_acc = _evaluate(model, val_loader, device)
        # Capture the LR used during this epoch before the scheduler advances it.
        current_lr = optimizer.param_groups[0]['lr']
        scheduler.step()
        print(f"--> Epoch [{epoch}/{epochs}] | LR: {current_lr:.5f} | Avg Loss: {epoch_loss:.4f} | Validation Acc: {val_acc:.2f}%")
        # `>=` keeps the latest epoch on ties and guarantees a checkpoint is
        # written on the first epoch even at 0% accuracy.
        if val_acc >= best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)
            print(f"🌟 New best model saved! Validation Acc: {best_val_acc:.2f}%")

    print(f"\n✅ Full FER-2013 Training Complete. Max Acc Reached: {best_val_acc:.2f}%")
    return best_val_acc
if __name__ == "__main__":
    # Run the full training pipeline only when executed as a script,
    # not when this module is imported (e.g. by DataLoader worker processes).
    run_fer2013_training()