# neecat's picture
# Upload 5088 files
# 84d0c9e verified
from sklearn.metrics import classification_report, confusion_matrix
import torch
import matplotlib.pyplot as plt
import seaborn as sns
def calculate_accuracy(y_pred, y_true):
    """Return the fraction of samples whose predicted class matches the target.

    The predicted class of each sample is the index of the largest value
    along dimension 1 of ``y_pred``.

    Args:
        y_pred: Tensor of per-class scores; must have a dimension 1 to take
            the argmax over (presumably shaped (batch, num_classes) --
            confirm against the caller).
        y_true: Tensor of target class indices compared element-wise with
            the predictions (presumably shaped (batch,) -- confirm).

    Returns:
        The accuracy as a float in [0, 1]. An empty ``y_true`` yields 0.0
        instead of raising ``ZeroDivisionError``.
    """
    total = len(y_true)
    # Guard the division: an empty batch has no correct predictions.
    if total == 0:
        return 0.0

    preds = torch.argmax(y_pred, dim=1)
    correct = (preds == y_true).sum().item()
    return correct / total
def get_classification_report(y_true, y_pred, class_names):
    """Build scikit-learn's classification report for a set of predictions.

    Args:
        y_true: Ground-truth labels, forwarded unchanged to
            ``classification_report``.
        y_pred: Predicted labels, forwarded unchanged to
            ``classification_report``.
        class_names: Display names passed as ``target_names``.

    Returns:
        Whatever ``sklearn.metrics.classification_report`` returns for these
        arguments (a formatted text report under its default settings).
    """
    return classification_report(
        y_true,
        y_pred,
        target_names=class_names,
    )
def get_confusion_matrix(y_true, y_pred, class_names):
    """Compute the confusion matrix for a set of predictions.

    Args:
        y_true: Ground-truth labels, forwarded to ``confusion_matrix``.
        y_pred: Predicted labels, forwarded to ``confusion_matrix``.
        class_names: Accepted for signature parity with the other helpers
            in this module; it is not used by this function.

    Returns:
        The result of ``sklearn.metrics.confusion_matrix(y_true, y_pred)``.
    """
    # NOTE(review): class_names is ignored, so the matrix only covers labels
    # that actually occur in y_true / y_pred -- confirm this is intended.
    return confusion_matrix(y_true, y_pred)
def plot_confusion_matrix(cm, class_names):
    """Draw a confusion matrix as an annotated heatmap and display it.

    Args:
        cm: The matrix to draw; cells are annotated with integer formatting
            (``fmt="d"``).
        class_names: Labels used for the ticks on both axes.

    Returns:
        None. A new 8x6 figure is created and shown via ``plt.show()``.
    """
    _, ax = plt.subplots(figsize=(8, 6))

    heatmap_options = dict(
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
    )
    sns.heatmap(cm, ax=ax, **heatmap_options)

    ax.set_xlabel("Predicted")
    ax.set_ylabel("Actual")
    ax.set_title("Confusion Matrix")

    plt.tight_layout()
    plt.show()