File size: 2,787 Bytes
9c6b905
 
279af50
 
 
 
9c6b905
 
 
 
279af50
 
 
 
 
9c6b905
 
 
 
 
 
279af50
9c6b905
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from sklearn.metrics import classification_report
import mlflow
    
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, roc_auc_score
import numpy as np
 
import numpy as np


def log_metrics(y_true, y_pred, mode):
    """Log weighted-average precision, recall and F1 score to MLflow.

    Args:
        y_true: True labels (array-like).
        y_pred: Predicted labels (array-like).
        mode: Prefix for the logged metric names; the metrics are recorded as
            ``{mode}_precision``, ``{mode}_recall`` and ``{mode}_f1_score``.
    """
    # Build the report once and reuse it: all three metrics live in the same
    # 'weighted avg' entry, so recomputing the report per metric is wasted work.
    weighted_avg = classification_report(y_true, y_pred, output_dict=True)['weighted avg']

    mlflow.log_metric(f"{mode}_precision", weighted_avg['precision'])
    mlflow.log_metric(f"{mode}_recall", weighted_avg['recall'])
    mlflow.log_metric(f"{mode}_f1_score", weighted_avg['f1-score'])




def plot_roc_curve(y_true, y_pred_prob, mode, class_names=None):
    """
    Plot ROC curve(s) for binary or multi-class classification and show them.

    Args:
        y_true: True labels (array-like). In multi-class mode each label is
            compared against the column index ``i`` of ``y_pred_prob``, so
            labels are assumed to be integer-encoded in column order.
        y_pred_prob: Predicted probabilities (array-like), one column per
            class. In binary mode column 1 is used as the positive-class
            score; a 1-D array is also accepted and used as the
            positive-class score directly.
        mode: ``'binary'`` plots a single curve; any other value plots one
            one-vs-rest curve per class.
        class_names: List of class names used in the multi-class legend
            (optional). When omitted, the column indices of ``y_pred_prob``
            are used as names.
    """
    # Coerce up front so plain Python lists work: a list has no `[:, i]`
    # indexing, and `list == i` evaluates to a single bool, not a mask.
    y_true = np.asarray(y_true)
    y_pred_prob = np.asarray(y_pred_prob)

    plt.figure(figsize=(10, 7))

    if mode == 'binary':
        # Accept either a (n_samples, 2) probability matrix or a 1-D vector of
        # positive-class scores.
        positive_scores = y_pred_prob if y_pred_prob.ndim == 1 else y_pred_prob[:, 1]
        fpr, tpr, _ = roc_curve(y_true, positive_scores)
        auc_score = roc_auc_score(y_true, positive_scores)
        plt.plot(fpr, tpr, color='darkorange', lw=2, label=f"ROC curve (area = {auc_score:.2f})")
    else:
        if class_names is None:
            # The signature makes class_names optional, so fall back to the
            # column indices instead of failing on enumerate(None).
            class_names = range(y_pred_prob.shape[1])
        for i, class_name in enumerate(class_names):
            # One-vs-rest: class i is the positive class, everything else negative.
            is_class = y_true == i
            class_scores = y_pred_prob[:, i]
            fpr, tpr, _ = roc_curve(is_class, class_scores)
            auc_score = roc_auc_score(is_class, class_scores)
            plt.plot(fpr, tpr, lw=2, label=f"Class {class_name} (AUC = {auc_score:.2f})")

    # Diagonal reference line (TPR == FPR).
    plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"ROC Curve ({mode})")
    plt.legend(loc="lower right")
    plt.tight_layout()
    plt.show()



import seaborn as sns
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt

def plot_confusion_matrix(y_true, y_pred, class_names, mode):
    """
    Draw the confusion matrix of the predictions as an annotated heatmap.

    Args:
        y_true: True labels (array-like).
        y_pred: Predicted labels (array-like).
        class_names: List of class names, used as tick labels on both axes.
        mode: Mode of classification ('binary' or 'multi-class'); it only
            appears in the figure title.
    """
    matrix = confusion_matrix(y_true, y_pred)

    # Integer-formatted cell annotations, with the same labels on both axes.
    heatmap_options = {
        "annot": True,
        "fmt": "d",
        "cmap": "Blues",
        "xticklabels": class_names,
        "yticklabels": class_names,
    }

    plt.figure(figsize=(10, 7))
    sns.heatmap(matrix, **heatmap_options)

    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.title(f"Confusion Matrix ({mode})")

    plt.tight_layout()
    plt.show()