File size: 2,229 Bytes
69535bd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from utils.model_loader import load_model
# ----- CONFIG -----
# Trained weights handed to load_model().
MODEL_PATH = "data/models/expression_predictor_cnn.pth"
# Evaluation images, read with ImageFolder (one sub-directory per class).
TEST_DIR = "data/validation"
# Images per inference batch.
BATCH_SIZE = 32
# Display names for class indices 0..6, in index order. These only label the
# report and plots; ImageFolder numbers classes by sorted sub-directory name,
# so this order is assumed to match that sorted order.
CLASSES = [
    "Angry",
    "Disgust",
    "Scared",
    "Happy",
    "Neutral",
    "Sad",
    "Surprised",
]
# Load model onto the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_model(MODEL_PATH, device)
# Evaluate in inference mode so any dropout / batch-norm layers use their
# eval behaviour; this is a harmless no-op if load_model already did it.
model.eval()
# Transforms: every evaluation image is converted to a single channel,
# resized to 48x48, turned into a tensor, then normalized with mean 0.5 and
# std 0.5 (mapping ToTensor's [0, 1] range onto [-1, 1]).
_preprocessing_steps = [
    transforms.Grayscale(),
    transforms.Resize((48, 48)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
]
transform = transforms.Compose(_preprocessing_steps)
# Load test data.
# ImageFolder derives integer labels from the sorted sub-directory names of
# TEST_DIR, while CLASSES is a separate list of display names for those
# indices. Fail fast if the counts disagree; otherwise the report and plots
# below would be mislabeled or crash.
test_data = ImageFolder(TEST_DIR, transform=transform)
if len(test_data.classes) != len(CLASSES):
    raise ValueError(
        f"{TEST_DIR!r} contains {len(test_data.classes)} class folders "
        f"{test_data.classes}, but CLASSES lists {len(CLASSES)} names"
    )
# Keep dataset order so predictions stay aligned with their labels.
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)
# Predict: run the model over the whole test set without tracking gradients,
# collecting (in loader order) the highest-scoring class index per image and
# the matching ground-truth label.
all_preds = []
all_labels = []
with torch.no_grad():
    for images, labels in test_loader:
        logits = model(images.to(device))
        preds = logits.argmax(dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.numpy())
# Classification report.
# Pass the full label set explicitly. Without it, sklearn infers labels from the
# data, and a class missing from both all_labels and all_preds makes the label
# count disagree with target_names (ValueError). zero_division=0 reports 0.0
# for classes that are never predicted instead of emitting warnings.
label_ids = list(range(len(CLASSES)))
report = classification_report(
    all_labels,
    all_preds,
    labels=label_ids,
    target_names=CLASSES,
    output_dict=True,
    zero_division=0,
)
report_df = pd.DataFrame(report).transpose()
print("\nClassification Report:")
print(report_df)
# Confusion Matrix.
# Fix the label set so the matrix is always len(CLASSES) x len(CLASSES) and its
# rows/columns line up with the tick labels, even when a class is absent from
# both the ground truth and the predictions.
cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(CLASSES))))
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', xticklabels=CLASSES, yticklabels=CLASSES, cmap="Blues")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix")
plt.tight_layout()
plt.savefig("confusion_matrix.png")
# Free the figure once saved so it is not kept open alongside the next plot.
plt.close()
print("\n✅ Saved confusion matrix as 'confusion_matrix.png'")
# Save classification report heatmap (Precision, Recall, F1-score).
# Plot only the score columns: "support" is a raw count on a different scale
# from the [0, 1] scores. Drop the "accuracy" row because pandas broadcasts its
# single scalar across every column, so it is not a per-metric value.
# errors="ignore" keeps this working if that row or column is absent.
scores_df = report_df.drop(index="accuracy", columns="support", errors="ignore")
plt.figure(figsize=(12, 6))
sns.heatmap(scores_df, annot=True, fmt=".2f", cmap="YlGnBu")
plt.title("Classification Report (Precision, Recall, F1-score)")
plt.tight_layout()
plt.savefig("classification_report_matrix.png")
# Free the figure once saved.
plt.close()
print("✅ Saved classification_report_matrix.png")