"""Evaluate the expression-classifier CNN on a held-out image folder.

Loads the trained model from ``MODEL_PATH``, runs it over every image found
under ``TEST_DIR``, prints a per-class classification report, and writes two
figures to the current working directory:

* ``confusion_matrix.png`` -- heatmap of true vs. predicted labels.
* ``classification_report_matrix.png`` -- heatmap of per-class
  precision / recall / F1-score plus the macro and weighted averages.
"""

import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

from utils.model_loader import load_model

# ----- CONFIG -----
MODEL_PATH = 'data/models/expression_predictor_cnn.pth'
TEST_DIR = 'data/validation'
BATCH_SIZE = 32
# Display names, indexed by the integer label the model predicts.
# NOTE(review): ImageFolder assigns integer labels by *sorted sub-folder name*,
# so this list must follow that same order (and the order used in training).
# Confirm against the folder names under TEST_DIR.
CLASSES = ['Angry', 'Disgust', 'Scared', 'Happy', 'Neutral', 'Sad', 'Surprised']
# Explicit label ids passed to sklearn so every class always gets a row/column,
# even if it is absent from both the ground truth and the predictions.
LABEL_IDS = list(range(len(CLASSES)))

# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_model(MODEL_PATH, device)
# Put dropout / batch-norm layers into inference mode; a no-op if load_model
# already did so.
model.eval()

# Transforms
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((48, 48)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Load test data
test_data = ImageFolder(TEST_DIR, transform=transform)
# A folder/name count mismatch would silently attach the wrong names to every
# metric below, so fail loudly instead.
if len(test_data.classes) != len(CLASSES):
    raise ValueError(
        f"{TEST_DIR!r} contains {len(test_data.classes)} class folders "
        f"{test_data.classes}, but CLASSES lists {len(CLASSES)} names."
    )
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

# Predict
all_preds = []
all_labels = []
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images.to(device))
        # Index of the highest-scoring output per sample is the predicted label.
        preds = outputs.argmax(dim=1)
        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(labels.tolist())

# Classification report
report = classification_report(
    all_labels,
    all_preds,
    labels=LABEL_IDS,
    target_names=CLASSES,
    output_dict=True,
    zero_division=0,
)
report_df = pd.DataFrame(report).transpose()
print("\nClassification Report:")
print(report_df)

# Confusion Matrix (fixed labels -> always len(CLASSES) x len(CLASSES), so the
# tick labels line up with the matrix).
cm = confusion_matrix(all_labels, all_preds, labels=LABEL_IDS)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', xticklabels=CLASSES, yticklabels=CLASSES, cmap="Blues")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix")
plt.tight_layout()
plt.savefig("confusion_matrix.png")
plt.close()
print("\n✅ Saved confusion matrix as 'confusion_matrix.png'")

# Save classification report heatmap.
# Rows/columns are selected by name: the "accuracy" entry is a single scalar
# that pandas broadcasts across every column, and "support" is a raw count on a
# different scale that would swamp the 0-1 colour map, so both are excluded.
score_rows = CLASSES + ['macro avg', 'weighted avg']
score_cols = ['precision', 'recall', 'f1-score']
plt.figure(figsize=(12, 6))
sns.heatmap(report_df.loc[score_rows, score_cols], annot=True, fmt=".2f", cmap="YlGnBu")
plt.title("Classification Report (Precision, Recall, F1-score)")
plt.tight_layout()
plt.savefig("classification_report_matrix.png")
plt.close()
print("✅ Saved classification_report_matrix.png")