import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from sklearn.metrics import classification_report, confusion_matrix
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder

from utils.model_loader import load_model

# ---- Configuration ----
# Checkpoint handed to load_model().
MODEL_PATH = 'data/models/expression_predictor_cnn.pth'
# Root directory handed to ImageFolder.
TEST_DIR = 'data/validation'
# Number of images per evaluation batch.
BATCH_SIZE = 32
# Display names used for the report rows and the confusion-matrix ticks.
# NOTE(review): position i must name the class whose label index is i in the
# dataset -- confirm against `test_data.classes`.
CLASSES = ['Angry', 'Disgust', 'Scared', 'Happy', 'Neutral', 'Sad', 'Surprised']

# ---- Device and model ----
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# load_model is a project helper; it receives the checkpoint path and the
# target device.
model = load_model(MODEL_PATH, device)

# ---- Preprocessing ----
# Each image is converted to grayscale, resized to 48x48, turned into a tensor
# and normalized with mean 0.5 / std 0.5.
# NOTE(review): this should mirror the preprocessing used when the model was
# trained -- confirm against the training script.
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((48, 48)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# ---- Dataset ----
# Images are read from TEST_DIR through ImageFolder with the preprocessing
# pipeline applied to every sample.
test_data = ImageFolder(TEST_DIR, transform=transform)
# Batches are served in dataset order (no shuffling).
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

# ---- Inference ----
# Predicted and true class indices accumulated over the whole loader.
all_preds = []
all_labels = []

# Put dropout / batch-norm layers into inference behaviour. load_model() may
# already do this; calling eval() again is harmless.
model.eval()

# Gradients are not needed for evaluation.
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        outputs = model(images)
        # The index of the largest score along dim 1 is the predicted class.
        preds = outputs.argmax(dim=1)
        all_preds.extend(preds.cpu().tolist())
        # Labels are never moved to `device`, so they are read directly.
        all_labels.extend(labels.tolist())

# ---- Classification report ----
# Pin the label set to every index in CLASSES so the report keeps one row per
# class, in CLASSES order, even when a class is absent from both the true
# labels and the predictions. Without `labels`, scikit-learn infers the label
# set from the data and rejects a target_names list of a different length.
report = classification_report(
    all_labels,
    all_preds,
    labels=list(range(len(CLASSES))),
    target_names=CLASSES,
    output_dict=True,
)
# One row per report entry, one column per metric.
report_df = pd.DataFrame(report).transpose()
print("\nClassification Report:")
print(report_df)

# ---- Confusion matrix ----
# Fix the matrix to len(CLASSES) x len(CLASSES) so the tick labels below always
# line up with the rows/columns, even if some class never occurs in the data.
cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(CLASSES))))
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', xticklabels=CLASSES, yticklabels=CLASSES, cmap="Blues")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix")
plt.tight_layout()
plt.savefig("confusion_matrix.png")
# The check-mark emoji was mis-decoded in the source (it showed up as a Greek
# beta followed by a line break, which split the string literal); restored.
print("\n✅ Saved confusion matrix as 'confusion_matrix.png'")

# ---- Classification report heatmap ----
plt.figure(figsize=(12, 6))
# The slice drops the last row and the last column of report_df before
# plotting.
# NOTE(review): the last column is presumably "support", yet the title below
# still mentions Support -- confirm which of the two is intended.
sns.heatmap(report_df.iloc[:-1, :-1], annot=True, fmt=".2f", cmap="YlGnBu")
plt.title("Classification Report (Precision, Recall, F1-score, Support)")
plt.tight_layout()
plt.savefig("classification_report_matrix.png")
# The check-mark emoji was mis-decoded in the source (it showed up as a Greek
# beta followed by a line break, which split the string literal); restored.
print("✅ Saved classification_report_matrix.png")