"""Evaluation helpers: accuracy, classification report and confusion matrix."""

from sklearn.metrics import classification_report, confusion_matrix
import torch
import matplotlib.pyplot as plt
import seaborn as sns


def calculate_accuracy(y_pred, y_true) -> float:
    """Return the fraction of samples whose arg-max prediction equals the label.

    Args:
        y_pred: Tensor of per-class scores. The predicted class is the index
            of the maximum along dimension 1 (presumably ``(batch, classes)``
            -- confirm against the caller).
        y_true: Ground-truth class indices, compared element-wise with the
            arg-max predictions.

    Returns:
        ``correct / len(y_true)`` as a float, or ``0.0`` when ``y_true`` is
        empty.
    """
    total = len(y_true)
    # An empty batch has no meaningful accuracy; report 0.0 rather than
    # letting the division below raise ZeroDivisionError.
    if total == 0:
        return 0.0

    preds = torch.argmax(y_pred, dim=1)
    correct = (preds == y_true).sum().item()
    return correct / total


def get_classification_report(y_true, y_pred, class_names):
    """Return scikit-learn's classification report for the given labels.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        class_names: Display names forwarded as ``target_names``.

    Returns:
        Whatever ``sklearn.metrics.classification_report`` returns for these
        arguments.
    """
    return classification_report(y_true, y_pred, target_names=class_names)


def get_confusion_matrix(y_true, y_pred, class_names):
    """Return the confusion matrix of ``y_true`` against ``y_pred``.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        class_names: Accepted for signature parity with the other helpers;
            it is not used when computing the matrix.

    Returns:
        The result of ``sklearn.metrics.confusion_matrix(y_true, y_pred)``.
    """
    return confusion_matrix(y_true, y_pred)


def plot_confusion_matrix(cm, class_names):
    """Draw ``cm`` as an annotated heatmap and display it.

    Args:
        cm: Confusion matrix to plot. Cells are annotated with the ``"d"``
            format, so integer counts are expected.
        class_names: Tick labels used on both axes.
    """
    plt.figure(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
    )
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.title("Confusion Matrix")
    plt.tight_layout()
    plt.show()