File size: 2,938 Bytes
94d2494
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""
src.visualisations.classification – Classification-specific charts.
"""

from __future__ import annotations

from typing import Dict, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import seaborn as sns
from sklearn.metrics import confusion_matrix, roc_curve, auc
from sklearn.pipeline import Pipeline

from src import config
from src.visualisations.common import _base_layout


def roc_curves_chart(
    fitted_models: Dict[str, Pipeline],
    X_test: pd.DataFrame,
    y_test: pd.Series,
) -> Optional[go.Figure]:
    """Return a Plotly ROC-curve figure (binary classification only)."""
    if len(np.unique(y_test)) != 2:
        return None

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=[0, 1], y=[0, 1], mode="lines",
        line=dict(dash="dash", color="gray"),
        name="Random Classifier", showlegend=True,
    ))

    for i, (name, pipeline) in enumerate(fitted_models.items()):
        if not hasattr(pipeline, "predict_proba"):
            continue
        try:
            y_prob = pipeline.predict_proba(X_test)[:, 1]
            fpr, tpr, _ = roc_curve(y_test, y_prob)
            roc_auc = auc(fpr, tpr)
            fig.add_trace(go.Scatter(
                x=fpr, y=tpr, mode="lines",
                name=f"{name} (AUC = {roc_auc:.3f})",
                line=dict(color=config.COLOR_PALETTE[i % len(config.COLOR_PALETTE)], width=2),
            ))
        except Exception:
            pass

    fig.update_layout(
        **_base_layout(title="🎯 ROC Curves"),
        xaxis=dict(title="False Positive Rate", range=[0, 1], gridcolor="#2A2E3F"),
        yaxis=dict(title="True Positive Rate", range=[0, 1.05], gridcolor="#2A2E3F"),
        legend=dict(x=0.6, y=0.1),
    )
    return fig


def confusion_matrices_chart(
    fitted_models: Dict[str, Pipeline],
    X_test: pd.DataFrame,
    y_test: pd.Series,
) -> plt.Figure:
    """Matplotlib figure with one confusion-matrix heatmap per model."""
    n = len(fitted_models)
    n_cols = min(2, n)
    n_rows = (n + n_cols - 1) // n_cols

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(7 * n_cols, 5 * n_rows))
    fig.patch.set_facecolor("#0E1117")

    axes_flat = np.array(axes).ravel() if n > 1 else [axes]

    for ax, (name, pipeline) in zip(axes_flat, fitted_models.items()):
        y_pred = pipeline.predict(X_test)
        cm = confusion_matrix(y_test, y_pred)
        sns.heatmap(
            cm, annot=True, fmt="d", cmap="Blues", ax=ax,
            linewidths=0.5, linecolor="#1E2130", annot_kws={"color": "white"},
        )
        ax.set_facecolor("#0E1117")
        ax.set_title(name, color="white", fontsize=13)
        ax.set_xlabel("Predicted", color="white")
        ax.set_ylabel("Actual", color="white")
        ax.tick_params(colors="white")

    for ax in axes_flat[n:]:
        ax.set_visible(False)

    plt.tight_layout()
    return fig