File size: 2,128 Bytes
a7d529a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# backend/utils.py
import os
import json
import joblib
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, roc_curve, auc

# Ensure that each of the given directories exists
def ensure_dirs(*paths):
    """Create every directory in ``paths``, including missing parents.

    Directories that already exist are left alone. Empty or ``None``
    entries are skipped: ``os.path.dirname`` of a bare filename is ``""``,
    and ``os.makedirs("")`` would raise ``FileNotFoundError``.

    Args:
        *paths: Directory paths to create.
    """
    for path in paths:
        if not path:
            # Nothing to create for "" / None (i.e. the current directory).
            continue
        os.makedirs(path, exist_ok=True)

# Save JSON report
def save_json(data, path):
    """Write ``data`` to ``path`` as JSON indented by four spaces.

    The document is serialized in memory before the file is opened, so a
    serialization error is raised while any existing file at ``path`` is
    still intact, rather than after it has been truncated and partly
    rewritten.

    Args:
        data: JSON-serializable object.
        path: Destination file; overwritten if it exists.

    Raises:
        TypeError: If ``data`` contains a value ``json`` cannot serialize.
    """
    text = json.dumps(data, indent=4)
    # Explicit encoding keeps the output independent of the platform locale.
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)

# Save model pickle
def save_model(model, path):
    """Persist ``model`` to ``path`` with ``joblib.dump``.

    Returns nothing; whatever ``joblib.dump`` returns is discarded.
    """
    joblib.dump(value=model, filename=path)

# Load model pickle
def load_model(path):
    """Return the object stored at ``path``, read back with ``joblib.load``.

    NOTE(review): joblib files are pickle-based, so loading a file from an
    untrusted source can execute arbitrary code -- confirm that callers
    only pass trusted paths.
    """
    restored = joblib.load(filename=path)
    return restored

# Plot Confusion Matrix
def plot_cm(y_true, y_pred, title="Confusion Matrix", save_path=None):
    """Draw the confusion matrix of ``y_pred`` vs ``y_true`` as a heatmap.

    Args:
        y_true: Ground-truth labels, passed to ``confusion_matrix``.
        y_pred: Predicted labels, passed to ``confusion_matrix``.
        title: Title drawn above the heatmap.
        save_path: Destination image path; when falsy nothing is written.

    The figure is closed before returning even if drawing or saving raises,
    so repeated calls cannot accumulate open figures.
    """
    cm = confusion_matrix(y_true, y_pred)
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", ax=ax)
        ax.set_title(title)
        ax.set_xlabel("Predicted")
        ax.set_ylabel("Actual")
        if save_path:
            fig.savefig(save_path, bbox_inches="tight")
    finally:
        # Close this specific figure, not whichever one happens to be current.
        plt.close(fig)

def plot_roc(y_true, y_proba, title="ROC Curve", save_path=None):
    """Draw the ROC curve for ``y_proba`` against ``y_true``.

    The curve is labelled with its area under the curve (two decimals) and
    drawn together with a dashed red diagonal from (0, 0) to (1, 1).

    Args:
        y_true: Ground-truth labels, passed to ``roc_curve``.
        y_proba: Scores for the positive class, passed to ``roc_curve``.
        title: Title drawn above the plot.
        save_path: Destination image path; when falsy nothing is written.

    The figure is closed before returning even if drawing or saving raises,
    so repeated calls cannot accumulate open figures.
    """
    fpr, tpr, _ = roc_curve(y_true, y_proba)
    roc_auc = auc(fpr, tpr)

    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        ax.plot(fpr, tpr, label=f"AUC = {roc_auc:.2f}")
        ax.plot([0, 1], [0, 1], "r--")
        ax.set_title(title)
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        ax.legend()
        if save_path:
            fig.savefig(save_path, bbox_inches="tight")
    finally:
        # Close this specific figure, not whichever one happens to be current.
        plt.close(fig)

def barplot_metric(df, metric, save_path, title):
    """Save a bar chart of ``metric`` per model to ``save_path``.

    Args:
        df: Data passed to ``sns.barplot``; must provide a ``"Model"`` column
            (used for the x axis) and a column named by ``metric``.
        metric: Name of the column plotted on the y axis.
        save_path: Destination image path.
        title: Title drawn above the chart.

    The figure is closed before returning even if drawing or saving raises
    (for example when ``metric`` is not a column of ``df``), so repeated
    calls cannot accumulate open figures.
    """
    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        sns.barplot(x="Model", y=metric, data=df, ax=ax)
        ax.set_title(title)
        # Tilt the model names so long labels do not overlap.
        plt.setp(ax.get_xticklabels(), rotation=45)
        fig.tight_layout()
        fig.savefig(save_path, bbox_inches='tight')
    finally:
        # Close this specific figure, not whichever one happens to be current.
        plt.close(fig)


def lineplot_curves(curves, ylabel, title, save_path):
    """Save a line plot with one labelled line per entry of ``curves``.

    Args:
        curves: Mapping of legend label to the sequence of values to plot;
            values are plotted against their index.
        ylabel: Label for the y axis.
        title: Title drawn above the plot.
        save_path: Destination image path.

    The legend is only drawn when ``curves`` is non-empty. The figure is
    closed before returning even if drawing or saving raises, so repeated
    calls cannot accumulate open figures.
    """
    fig, ax = plt.subplots()
    try:
        for label, values in curves.items():
            ax.plot(values, label=label)
        ax.set_xlabel("Iterations")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        if curves:
            # An empty legend would only produce a warning and a stray box.
            ax.legend()
        fig.savefig(save_path)
    finally:
        # Close this specific figure, not whichever one happens to be current.
        plt.close(fig)