# Source: hackathon-4 / expression_recognition.py
# Last commit 400e9a1 by care2achieve: "expression model train" (file size 14.5 kB)
# enhanced_expression_recognition.py
import os
import multiprocessing as mp
# Use the "spawn" start method for any child processes (e.g. DataLoader
# workers when NUM_WORKERS > 0); force=True overrides a method already set.
mp.set_start_method(method="spawn", force=True)
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.transforms as transforms
from collections import Counter
# Let float32 matmuls use faster reduced-precision internals (e.g. TF32)
# where the backend supports them.
torch.set_float32_matmul_precision("high")
# ============================================================
# CONFIG
# ============================================================
# Dataset layout, expected to appear after extracting the dataset zip
# into ./data.
DATASET_DIR = "./data/Expression_data"
TRAIN_DIR, TEST_DIR = (
    os.path.join(DATASET_DIR, sub_dir)
    for sub_dir in ("Facial_expression_train", "Facial_expression_test")
)

# Preprocessing and optimisation hyper-parameters.
IMAGE_SIZE = 72      # images are resized to IMAGE_SIZE x IMAGE_SIZE
BATCH_SIZE = 64
NUM_EPOCHS = 60
LEARNING_RATE = 1e-3
NUM_WORKERS = 0      # 0 = DataLoader loads batches in the main process
PATIENCE = 12        # epochs without validation improvement before stopping

# Destination of the best checkpoint.
MODEL_SAVE_PATH = "./models/expression_model.pth"
# ============================================================
# DOWNLOAD DATASET
# ============================================================
import urllib.request
import zipfile
def download_required_files():
    """Make sure the expression dataset is available under ``DATASET_DIR``.

    Creates ``./data`` and ``./models``. Unless ``DATASET_DIR`` already
    exists, downloads ``Expression_data.zip`` into ``./data`` (re-downloading
    it if an existing copy is not a valid zip archive) and extracts it there.

    Raises:
        urllib.error.URLError: if the download fails.
        zipfile.BadZipFile: if the downloaded archive cannot be opened.
    """
    dataset_url = (
        "https://cdn.talentsprint.com/"
        "aiml/Experiment_related_data/"
        "Expression_data.zip"
    )
    os.makedirs("./data", exist_ok=True)
    os.makedirs("./models", exist_ok=True)

    if os.path.exists(DATASET_DIR):
        # Already unpacked: the archive is not needed, even if it was deleted.
        # NOTE: an interrupted extraction also leaves DATASET_DIR behind;
        # delete it manually to force re-extraction.
        print("Dataset already extracted")
        return

    dataset_zip = "./data/Expression_data.zip"
    if os.path.exists(dataset_zip) and not zipfile.is_zipfile(dataset_zip):
        # Most likely left over from an interrupted download; start over.
        print("Existing dataset zip is invalid, re-downloading...")
        os.remove(dataset_zip)

    if not os.path.exists(dataset_zip):
        print("Downloading dataset...")
        partial_zip = dataset_zip + ".part"
        urllib.request.urlretrieve(dataset_url, partial_zip)
        # Publish under the final name only once the download has completed,
        # so a truncated file is never mistaken for a finished archive.
        os.replace(partial_zip, dataset_zip)
        print("Dataset downloaded")
    else:
        print("Dataset zip already exists")

    print("Extracting dataset...")
    with zipfile.ZipFile(dataset_zip, "r") as zip_ref:
        zip_ref.extractall("./data")
    print("Dataset extracted")
# ============================================================
# DEVICE SETUP
# ============================================================
def _pick_device():
    """Return ``(torch.device, message)`` preferring Apple MPS, then CUDA, then CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps"), "Using Apple Silicon GPU (MPS)"
    if torch.cuda.is_available():
        return torch.device("cuda"), "Using CUDA GPU"
    return torch.device("cpu"), "Using CPU"


# Module-level default device used by training and prediction code.
device, _device_message = _pick_device()
print(_device_message)
# ============================================================
# DATASET
# ============================================================
class ExpressionDataset(Dataset):
    """Grayscale image dataset laid out as one sub-folder per class.

    Expected layout: ``image_folder/<class_name>/<file>.{jpg,jpeg,png}``
    (extension match is case-insensitive). Every immediate sub-directory is a
    class; class indices follow the sorted sub-directory names, and files
    inside each class are visited in sorted order, so the sample order is the
    same on every machine. Files with other extensions are ignored.

    Attributes:
        classes: Sorted class (sub-directory) names.
        class_to_idx, idx_to_class: Name <-> index mappings.
        image_paths, labels: Parallel lists, one entry per image.
    """

    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, image_folder, transform=None):
        """Scan ``image_folder``; ``transform`` (optional) is applied in ``__getitem__``."""
        self.image_folder = image_folder
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.class_to_idx = {}
        self.idx_to_class = {}
        self.classes = sorted(
            d for d in os.listdir(image_folder)
            if os.path.isdir(os.path.join(image_folder, d))
        )
        for idx, class_name in enumerate(self.classes):
            self.class_to_idx[class_name] = idx
            self.idx_to_class[idx] = class_name
            class_dir = os.path.join(image_folder, class_name)
            # Sort so the sample order (and any index-based split built on it)
            # does not depend on the filesystem's os.listdir ordering.
            for file_name in sorted(os.listdir(class_dir)):
                if file_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.image_paths.append(os.path.join(class_dir, file_name))
                    self.labels.append(idx)

    def __len__(self):
        """Return the number of images found."""
        return len(self.image_paths)

    def __getitem__(self, index):
        """Return ``(image, label)``.

        The image is the grayscale ("L") PIL image, or ``transform(image)``
        when a transform was given.
        """
        image_path = self.image_paths[index]
        label = self.labels[index]
        # Close the file right after decoding instead of leaving the handle
        # open until garbage collection.
        with Image.open(image_path) as img:
            image = img.convert("L")
        if self.transform is not None:
            image = self.transform(image)
        return image, label
# ============================================================
# TRANSFORMS
# ============================================================
def _normalized_tensor_steps():
    """Return fresh ToTensor + Normalize(mean=0.5, std=0.5) steps ([0, 1] -> [-1, 1])."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]


# Training pipeline: resize, then light augmentation
# (random mirror, rotation within +/-2 degrees).
train_transforms = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(2),
        *_normalized_tensor_steps(),
    ]
)

# Validation / inference pipeline: deterministic resize only.
val_transforms = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        *_normalized_tensor_steps(),
    ]
)
# ============================================================
# RESIDUAL BLOCK
# ============================================================
class ResidualBlock(nn.Module):
    """Basic residual block: relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x)).

    ``conv1`` (3x3) applies ``stride``; ``conv2`` (3x3) keeps the resolution.
    The shortcut is the identity unless the stride or channel count changes,
    in which case a strided 1x1 convolution followed by BatchNorm projects the
    input to the output shape.

    Submodule attribute names form the ``state_dict`` keys and must stay
    stable for saved checkpoints to load.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Main path; convolutions are bias-free because BatchNorm follows them.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        shape_changes = stride != 1 or in_channels != out_channels
        if shape_changes:
            projection = [
                nn.Conv2d(
                    in_channels, out_channels,
                    kernel_size=1, stride=stride, bias=False,
                ),
                nn.BatchNorm2d(out_channels),
            ]
        else:
            projection = []
        # An empty Sequential passes its input through unchanged.
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        skip = self.shortcut(x)
        hidden = self.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return self.relu(hidden + skip)
# ============================================================
# MODEL
# ============================================================
class ExpressionCNN(nn.Module):
    """Single-channel expression classifier: conv stem, four residual stages, MLP head.

    Takes ``(batch, 1, H, W)`` images and returns raw logits of shape
    ``(batch, num_classes)``. Global average pooling makes the head
    independent of H and W, provided the input survives the four 2x
    max-poolings.

    The order of layers inside ``features`` and ``classifier`` defines the
    ``state_dict`` keys (e.g. ``features.4.conv1.weight``), so it must not
    change or saved checkpoints will no longer load.
    """

    def __init__(self, num_classes=7):
        super().__init__()
        widths = (32, 64, 128, 256, 512)

        layers = [
            nn.Conv2d(1, widths[0], kernel_size=3, padding=1),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            if stage > 0:
                # Halve the spatial resolution between residual stages.
                layers.append(nn.MaxPool2d(2))
            layers.append(ResidualBlock(c_in, c_out))
        layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.features = nn.Sequential(*layers)

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(widths[-1], 256), nn.ReLU(), nn.Dropout(0.4),
            nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))
# ============================================================
# PREDICTION
# ============================================================
def predict_expression(
    model,
    image_path,
    transform,
    idx_to_class
):
    """Classify the facial expression in a single image file.

    Args:
        model: Trained classifier returning one logit per class.
        image_path: Path of the image; it is converted to grayscale ("L").
        transform: Preprocessing applied to the PIL image; must return a
            tensor (e.g. ``val_transforms``).
        idx_to_class: Mapping from class index to class name.

    Returns:
        ``(predicted_class, confidence)``, where confidence is the softmax
        probability of the predicted class as a Python float.

    The forward pass runs in eval mode, so Dropout is off and BatchNorm uses
    its running statistics. Each submodule's previous train/eval mode is
    restored afterwards.
    """
    # Close the file as soon as the pixels are decoded.
    with Image.open(image_path) as img:
        image = img.convert("L")

    # Run on whichever device holds the model's weights (e.g. a checkpoint
    # loaded on CPU); fall back to the module default for parameter-less models.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device
    tensor = transform(image).unsqueeze(0).to(target_device)

    previous_modes = [(module, module.training) for module in model.modules()]
    model.eval()
    try:
        with torch.no_grad():
            outputs = model(tensor)
    finally:
        for module, was_training in previous_modes:
            module.training = was_training

    probabilities = torch.softmax(outputs, dim=1)
    confidence, predicted = torch.max(probabilities, 1)
    predicted_class = idx_to_class[predicted.item()]
    return predicted_class, confidence.item()
# ============================================================
# MAIN
# ============================================================
def _split_indices(num_samples, val_fraction=0.1, seed=42):
    """Return ``(train_indices, val_indices)`` for a seeded random hold-out split.

    A fixed seed keeps the validation set identical across runs, so the
    accuracy stored in checkpoints is comparable between runs. At least one
    sample is held out so validation accuracy is always defined.

    Raises:
        ValueError: if fewer than two samples are available.
    """
    if num_samples < 2:
        raise ValueError(
            f"Need at least 2 images for a train/validation split, found {num_samples}"
        )
    order = np.random.default_rng(seed).permutation(num_samples)
    val_size = max(1, int(val_fraction * num_samples))
    return order[val_size:], order[:val_size]


def _class_weights(labels, num_classes):
    """Return per-class loss weights ``sqrt(N / count_c)`` as a float32 tensor.

    The count of a class absent from ``labels`` is floored at 1 to avoid a
    ZeroDivisionError.
    """
    counts = Counter(labels)
    total = len(labels)
    return torch.tensor(
        [np.sqrt(total / max(counts[c], 1)) for c in range(num_classes)],
        dtype=torch.float32,
    )


def _train_one_epoch(model, loader, criterion, optimizer, epoch):
    """Run one optimisation pass over ``loader`` and return training accuracy in %.

    Accuracy is measured on the (augmented) batches, using each batch's
    predictions from before its optimizer step.
    """
    model.train()
    correct = 0
    seen = 0
    progress = tqdm(loader)
    for images, labels in progress:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        predicted = outputs.argmax(dim=1)
        seen += labels.size(0)
        correct += (predicted == labels).sum().item()
        progress.set_description(
            f"Epoch {epoch+1}/{NUM_EPOCHS} "
            f"Loss: {loss.item():.4f}"
        )
    return 100 * correct / seen


def _evaluate(model, loader):
    """Return the accuracy in % of ``model`` on ``loader`` (eval mode, no gradients)."""
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            predicted = model(images).argmax(dim=1)
            seen += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / seen


def main():
    """Download the data, train ExpressionCNN with early stopping, and keep the best checkpoint.

    10% of TRAIN_DIR (seeded split) is held out for validation. Whenever
    validation accuracy improves, the model weights and class-index mappings
    are written to MODEL_SAVE_PATH.
    """
    import copy

    download_required_files()
    print("\nLoading dataset...\n")
    full_train_dataset = ExpressionDataset(
        TRAIN_DIR,
        transform=train_transforms
    )
    # A shallow copy shares image_paths/labels with the training view, so an
    # index refers to the same file in both views; only the transform differs.
    full_val_dataset = copy.copy(full_train_dataset)
    full_val_dataset.transform = val_transforms

    train_indices, val_indices = _split_indices(len(full_train_dataset))
    train_dataset = Subset(full_train_dataset, train_indices)
    val_dataset = Subset(full_val_dataset, val_indices)
    print(f"Training images: {len(train_dataset)}")
    print(f"Validation images: {len(val_dataset)}")
    print(f"Classes: {full_train_dataset.classes}")

    # ========================================================
    # CLASS WEIGHTS (training split only, so held-out labels
    # do not influence the loss)
    # ========================================================
    num_classes = len(full_train_dataset.classes)
    train_labels = [full_train_dataset.labels[i] for i in train_indices]
    weights = _class_weights(train_labels, num_classes).to(device)

    # ========================================================
    # DATALOADERS
    # ========================================================
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS
    )

    # ========================================================
    # MODEL
    # ========================================================
    model = ExpressionCNN(num_classes=num_classes).to(device)
    print(f"\nModel Device: {next(model.parameters()).device}")
    criterion = nn.CrossEntropyLoss(weight=weights, label_smoothing=0.1)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    # mode="max": the monitored metric is validation accuracy (higher is better).
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="max",
        factor=0.5,
        patience=4
    )

    # ========================================================
    # TRAINING LOOP
    # ========================================================
    best_accuracy = 0.0
    epochs_without_improvement = 0
    print("\nStarting Training...\n")
    for epoch in range(NUM_EPOCHS):
        train_accuracy = _train_one_epoch(
            model, train_loader, criterion, optimizer, epoch
        )
        val_accuracy = _evaluate(model, val_loader)
        scheduler.step(val_accuracy)
        print(
            f"Epoch [{epoch+1}/{NUM_EPOCHS}] "
            f"Train Accuracy: "
            f"{train_accuracy:.2f}% | "
            f"Validation Accuracy: "
            f"{val_accuracy:.2f}%"
        )

        # Save the best model so far.
        if val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            epochs_without_improvement = 0
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "class_to_idx": full_train_dataset.class_to_idx,
                    "idx_to_class": full_train_dataset.idx_to_class,
                    "accuracy": best_accuracy,
                },
                MODEL_SAVE_PATH
            )
            print(
                f"Best model saved "
                f"with accuracy: "
                f"{best_accuracy:.2f}%"
            )
        else:
            epochs_without_improvement += 1

        # Early stopping.
        if epochs_without_improvement >= PATIENCE:
            print("\nEarly stopping triggered")
            break

    print("\nTraining Complete")
    print(
        f"Best Validation Accuracy: "
        f"{best_accuracy:.2f}%"
    )
if __name__ == "__main__":
    # Only has an effect in frozen Windows executables; a no-op otherwise.
    mp.freeze_support()
    main()