Spaces:
Sleeping
Sleeping
File size: 2,787 Bytes
9c6b905 279af50 9c6b905 279af50 9c6b905 279af50 9c6b905 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
from sklearn.metrics import classification_report
import mlflow
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, roc_auc_score
import numpy as np
import numpy as np
def log_metrics(y_true, y_pred, mode):
    """Log weighted-average precision, recall and F1 score to MLflow.

    Args:
        y_true: True labels (array-like).
        y_pred: Predicted labels (array-like).
        mode: Prefix for the metric keys; the metrics are logged as
            ``{mode}_precision``, ``{mode}_recall`` and ``{mode}_f1_score``.
    """
    # Build the report once and reuse it; the three metrics all come from the
    # same "weighted avg" row, so recomputing the report per metric is wasted work.
    weighted = classification_report(y_true, y_pred, output_dict=True)["weighted avg"]
    mlflow.log_metric(f"{mode}_precision", weighted["precision"])
    mlflow.log_metric(f"{mode}_recall", weighted["recall"])
    mlflow.log_metric(f"{mode}_f1_score", weighted["f1-score"])
def plot_roc_curve(y_true, y_pred_prob, mode, class_names=None):
    """
    Plot ROC curve for binary or multi-class classification.

    Args:
        y_true: True labels (array-like). In multi-class mode each label is
            compared against the integer column index of ``y_pred_prob``, so
            labels are expected to be encoded as 0..n_classes-1.
        y_pred_prob: Predicted probabilities (array-like). In binary mode this
            may be a 2-D array whose column 1 holds the positive-class score,
            or a 1-D array of positive-class scores. In multi-class mode it is
            indexed as ``[:, i]`` per class.
        mode: Mode of classification ('binary' or 'multi-class'). Any value
            other than 'binary' takes the one-vs-rest multi-class path.
        class_names: List of class names used in the legend for multi-class
            mode (optional). If None, the column indices of ``y_pred_prob``
            are used as names.
    """
    # Convert up front so plain Python lists behave like arrays: without this,
    # `y_true == i` on a list compares the whole list to an int (one bool)
    # instead of element-wise, and `[:, i]` indexing fails on nested lists.
    y_true = np.asarray(y_true)
    y_pred_prob = np.asarray(y_pred_prob)

    plt.figure(figsize=(10, 7))
    if mode == 'binary':
        # Accept either a probability matrix (use the positive column) or
        # an already-extracted 1-D vector of positive-class scores.
        positive_scores = y_pred_prob[:, 1] if y_pred_prob.ndim > 1 else y_pred_prob
        fpr, tpr, _ = roc_curve(y_true, positive_scores)
        auc_score = roc_auc_score(y_true, positive_scores)
        plt.plot(fpr, tpr, color='darkorange', lw=2, label=f"ROC curve (area = {auc_score:.2f})")
    else:
        if class_names is None:
            # Fall back to column indices rather than failing on enumerate(None).
            class_names = list(range(y_pred_prob.shape[1]))
        for i, class_name in enumerate(class_names):
            # One-vs-rest: samples of class i are positives, all others negatives.
            is_class_i = y_true == i
            fpr, tpr, _ = roc_curve(is_class_i, y_pred_prob[:, i])
            auc_score = roc_auc_score(is_class_i, y_pred_prob[:, i])
            plt.plot(fpr, tpr, lw=2, label=f"Class {class_name} (AUC = {auc_score:.2f})")
    # Diagonal reference line for a random classifier.
    plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"ROC Curve ({mode})")
    plt.legend(loc="lower right")
    plt.tight_layout()
    plt.show()
import seaborn as sns
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
def plot_confusion_matrix(y_true, y_pred, class_names, mode):
    """
    Draw an annotated confusion-matrix heatmap and display it.

    Args:
        y_true: Ground-truth labels (array-like).
        y_pred: Predicted labels (array-like).
        class_names: Names used for both the x (predicted) and y (actual)
            tick labels.
        mode: Mode of classification ('binary' or 'multi-class'); only used
            in the plot title.
    """
    # NOTE(review): confusion_matrix only includes labels that occur in y_true
    # or y_pred, so if a class is absent from both, the matrix has fewer rows
    # than class_names and the tick labels will not line up. Consider passing
    # labels= explicitly once the label encoding is confirmed.
    matrix = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(10, 7))

    # Integer counts written into each cell, blue colour scale.
    heatmap_options = dict(
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
    )
    sns.heatmap(matrix, **heatmap_options)

    plt.title(f"Confusion Matrix ({mode})")
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.tight_layout()
    plt.show()
|