# Evaluation script for the Fulla model.
# Recovered from a Spaces web page (status: Sleeping; revision eacbcc2; file size: 2,143 bytes).
from pathlib import Path

import matplotlib.pyplot as plt
import seaborn as sns
import torch
from sklearn.metrics import confusion_matrix, f1_score

from model import create_fulla_model
from utils import create_dataloaders
# Main evaluation script: runs the trained Fulla model over the test split,
# reports accuracy and weighted F1, and saves a confusion-matrix heatmap.

# Both paths are relative to the current working directory ("../"), so the
# script presumably runs from a sub-directory of the project (e.g. src/) — confirm.
MODEL_PATH = Path("..") / "fulla_model.pth"
CONFUSION_MATRIX_PATH = Path("..") / "confusion_matrix.png"


def _collect_predictions(model, loader, device):
    """Run ``model`` over every batch of ``loader`` and gather labels.

    Args:
        model: A model already moved to ``device`` and put in eval mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the images are moved to before the forward pass.

    Returns:
        ``(y_true, y_pred)``: two equal-length lists of Python ints holding the
        ground-truth labels and the argmax over dim 1 of the model outputs.
    """
    y_true, y_pred = [], []
    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images.to(device))
            predicted = outputs.argmax(dim=1)
            # Labels are only compared on the host, so they never need to
            # visit the accelerator.
            y_true.extend(labels.tolist())
            y_pred.extend(predicted.cpu().tolist())
    return y_true, y_pred


def _accuracy_percent(y_true, y_pred):
    """Return the percentage of positions where ``y_true`` and ``y_pred`` agree.

    Raises:
        ValueError: If the sequences differ in length or are empty (accuracy
            is undefined for an empty test set).
    """
    if len(y_true) != len(y_pred):
        raise ValueError(
            f"Label/prediction length mismatch: {len(y_true)} vs {len(y_pred)}"
        )
    if not y_true:
        raise ValueError("Cannot compute accuracy on an empty test set.")
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return 100.0 * correct / len(y_true)


def _save_confusion_matrix(y_true, y_pred, path):
    """Plot the confusion matrix of ``y_true`` vs ``y_pred`` and save it to ``path``."""
    conf_matrix = confusion_matrix(y_true, y_pred)
    # Large canvas because the dataset has many classes (102 per the original note).
    fig, ax = plt.subplots(figsize=(20, 20))
    # Per-cell annotations are unreadable with this many classes.
    sns.heatmap(conf_matrix, annot=False, ax=ax)
    ax.set_xlabel("Predicted Label")
    ax.set_ylabel("True Label")
    ax.set_title("Confusion Matrix")
    fig.savefig(path)
    plt.close(fig)  # release the figure instead of keeping it in pyplot's registry


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Only the test split is needed; the train/validation loaders are discarded.
    _, _, test_loader = create_dataloaders()

    model = create_fulla_model()
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only host;
    # the model is moved to `device` right after.
    # NOTE(review): consider weights_only=True (torch >= 1.13) since this file
    # only holds a state_dict; torch.load otherwise unpickles arbitrary objects.
    model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
    model.to(device)
    model.eval()  # evaluation mode (affects layers such as dropout/batch-norm, if present)

    y_true, y_pred = _collect_predictions(model, test_loader, device)
    print("Finished inference on the test set.")

    accuracy = _accuracy_percent(y_true, y_pred)
    print(f"\nFinal Test Accuracy: {accuracy:.2f}%")

    # 'weighted' averages per-class F1 by class support, accounting for any
    # imbalance in the number of samples per class.
    f1 = f1_score(y_true, y_pred, average="weighted")
    print(f"Final F1 Score: {f1:.4f}")

    print("\nGenerating confusion matrix...")
    _save_confusion_matrix(y_true, y_pred, CONFUSION_MATRIX_PATH)
    # Report the real location; the file is written one directory up, not to cwd.
    print(f"Confusion matrix saved to {CONFUSION_MATRIX_PATH.resolve()}")