# AlzheimersXEquity / evaluate.py
# noah34's picture
# Create evaluate.py
# 50e0989 verified
from datasets import load_dataset
from transformers import pipeline
from sklearn.metrics import (
accuracy_score, precision_score, recall_score, f1_score,
confusion_matrix, ConfusionMatrixDisplay
)
import matplotlib.pyplot as plt
# Class names, indexed by the integer "label" field of each dataset example.
# NOTE(review): the order is presumed to match the dataset's own label ids —
# confirm against the dataset card.
LABEL_NAMES = [
    "Mild_Demented",
    "Moderate_Demented",
    "Non_Demented",
    "Very_Mild_Demented",
]

# Fetch only the held-out test split; nothing here trains a model.
print("Loading dataset...")
dataset = load_dataset(path="Falah/Alzheimer_MRI", split="test")

# Build the image-classification pipeline around the fine-tuned checkpoint.
print("Loading model...")
clf = pipeline(
    task="image-classification",
    model="Thamer/resnet-fine_tuned",
)
# Run inference on every test image, collecting the ground-truth and the
# predicted class names in two parallel lists.
# NOTE(review): predictions are compared to LABEL_NAMES as plain strings, so
# this assumes the model emits labels with exactly the same spelling —
# confirm against the model's label config.
y_true, y_pred = [], []
total = len(dataset)  # loop-invariant, only used by the progress message
for i, example in enumerate(dataset):
    # Convert to 3-channel RGB before classifying (presumably a PIL image,
    # since it exposes .convert — verify against the dataset features).
    img = example["image"].convert("RGB")
    preds = clf(img)
    # Keep only the label of the highest-scoring candidate.
    top = max(preds, key=lambda x: x["score"])["label"]
    true = LABEL_NAMES[example["label"]]
    y_true.append(true)
    y_pred.append(top)
    # Progress report every 100 images (including the very first one).
    if i % 100 == 0:
        print(f" {i}/{total} done...")
# Report overall accuracy, then the macro-averaged precision / recall / F1.
# zero_division=0 makes an undefined per-class score count as 0.
print("\n--- Results ---")
print(f"Accuracy: {accuracy_score(y_true, y_pred):.2%}")
for metric_name, metric_fn in (
    ("Precision", precision_score),
    ("Recall", recall_score),
    ("F1", f1_score),
):
    metric_value = metric_fn(y_true, y_pred, average="macro", zero_division=0)
    print(f"{metric_name}: {metric_value:.2%}")
# Build the confusion matrix with rows/columns in LABEL_NAMES order and
# write it to disk as a PNG.
cm = confusion_matrix(y_true, y_pred, labels=LABEL_NAMES)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=LABEL_NAMES)

fig, ax = plt.subplots(figsize=(7, 7))
disp.plot(ax=ax, cmap="Blues", colorbar=False, values_format="d")
ax.set_title("Confusion Matrix — ResNet50")
# Slant the x tick labels so the long class names do not overlap.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
fig.tight_layout()
fig.savefig("confusion_matrix.png", dpi=200)
print("\nSaved confusion_matrix.png")