# expression-recognition / evaluation.py
# Last commit: "added evaluations" (69535bd) by allantacuelwvsu
import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from utils.model_loader import load_model
# ----- CONFIG -----
# Checkpoint handed to the model loader.
MODEL_PATH = "data/models/expression_predictor_cnn.pth"
# Root folder of the images the model is evaluated on.
TEST_DIR = "data/validation"
# Number of images fed to the model per forward pass.
BATCH_SIZE = 32
# Display names for the class indices, in index order (index 0 -> 'Angry', ...).
CLASSES = [
    "Angry",
    "Disgust",
    "Scared",
    "Happy",
    "Neutral",
    "Sad",
    "Surprised",
]
# Load model
# Pick the compute device: CUDA when it is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Hand the checkpoint path and the chosen device to the project's loader.
model = load_model(MODEL_PATH, device)
# Transforms
# Preprocessing applied to every image before it is passed to the model.
_preprocessing_steps = [
    transforms.Grayscale(),                         # convert to grayscale
    transforms.Resize((48, 48)),                    # fixed 48x48 input size
    transforms.ToTensor(),                          # PIL image -> float tensor
    transforms.Normalize(mean=(0.5,), std=(0.5,)),  # (x - 0.5) / 0.5 per channel
]
transform = transforms.Compose(_preprocessing_steps)
# Load test data
# ImageFolder reads one sub-folder per class under TEST_DIR and numbers the
# classes by sorted folder name.
# NOTE(review): CLASSES is assumed to be listed in that same order - confirm
# against the folder names on disk.
test_data = ImageFolder(root=TEST_DIR, transform=transform)
# Keep the dataset order (no shuffling) while batching.
test_loader = DataLoader(
    dataset=test_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
)
# Predict
# Run the model over every batch of the evaluation loader and collect the
# predicted and the true class index of each sample as plain Python ints.
all_preds = []
all_labels = []

# Switch dropout / batch-norm layers to inference behaviour. Harmless if the
# loader already did this; without it the metrics can be noisy or wrong.
model.eval()

with torch.no_grad():  # gradients are not needed for evaluation
    for images, labels in test_loader:
        images = images.to(device)
        outputs = model(images)
        # Predicted class = index of the highest score along dim 1.
        preds = outputs.argmax(dim=1)
        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(labels.tolist())
# Classification report
# Pin the label set to one index per entry of CLASSES. Without `labels=`,
# scikit-learn infers the classes from the data, and a class that never occurs
# in either the labels or the predictions makes the `target_names` length
# mismatch and raises ValueError.
# Assumes every label index is < len(CLASSES); larger indices would be ignored.
report = classification_report(
    all_labels,
    all_preds,
    labels=list(range(len(CLASSES))),
    target_names=CLASSES,
    output_dict=True,
)
# Transpose so each row is one report entry and each column one metric.
report_df = pd.DataFrame(report).transpose()
print("\nClassification Report:")
print(report_df)
# Confusion Matrix
# Pin the label set so the matrix always has one row/column per entry of
# CLASSES, even when a class is missing from the data; otherwise the matrix
# shrinks and no longer lines up with the tick labels.
# Assumes every label index is < len(CLASSES); larger indices would be ignored.
cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(CLASSES))))
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', xticklabels=CLASSES, yticklabels=CLASSES, cmap="Blues")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix")
plt.tight_layout()
plt.savefig("confusion_matrix.png")
# Release the figure once it has been written to disk.
plt.close()
print("\n✅ Saved confusion matrix as 'confusion_matrix.png'")
# Save classification report heatmap (Precision, Recall, F1-score)
plt.figure(figsize=(12, 6))
# iloc[:-1, :-1] drops the last row and the last column of the report table.
# With scikit-learn's report layout these are presumably the 'weighted avg'
# row and the 'support' column (raw sample counts would swamp the 0-1 colour
# scale), so the title below names only the three metrics actually plotted.
sns.heatmap(report_df.iloc[:-1, :-1], annot=True, fmt=".2f", cmap="YlGnBu")
plt.title("Classification Report (Precision, Recall, F1-score)")
plt.tight_layout()
plt.savefig("classification_report_matrix.png")
# Release the figure once it has been written to disk.
plt.close()
print("✅ Saved classification_report_matrix.png")