GouravSinghThakur
Initial commit: Intelligent AutoML Studio with 14 algorithms (7 clf + 7 reg)
94d2494 | """ | |
| src.visualisations.classification – Classification-specific charts. | |
| """ | |
| from __future__ import annotations | |
| from typing import Dict, Optional | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import pandas as pd | |
| import plotly.graph_objects as go | |
| import seaborn as sns | |
| from sklearn.metrics import confusion_matrix, roc_curve, auc | |
| from sklearn.pipeline import Pipeline | |
| from src import config | |
| from src.visualisations.common import _base_layout | |
def roc_curves_chart(
    fitted_models: Dict[str, Pipeline],
    X_test: pd.DataFrame,
    y_test: pd.Series,
) -> Optional[go.Figure]:
    """Return a Plotly ROC-curve figure (binary classification only).

    Args:
        fitted_models: Mapping of display name to fitted pipeline. Entries
            that do not expose ``predict_proba`` are skipped.
        X_test: Held-out features passed to each pipeline's ``predict_proba``.
        y_test: True labels for ``X_test``.

    Returns:
        A figure holding the dashed chance diagonal plus one ROC trace per
        model that could be scored, or ``None`` when ``y_test`` does not
        contain exactly two distinct values.
    """
    # Function-scope import so this block does not depend on edits to the
    # module's import section.
    import logging

    logger = logging.getLogger(__name__)

    if len(np.unique(y_test)) != 2:
        return None

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=[0, 1], y=[0, 1], mode="lines",
        line=dict(dash="dash", color="gray"),
        name="Random Classifier", showlegend=True,
    ))

    for i, (name, pipeline) in enumerate(fitted_models.items()):
        if not hasattr(pipeline, "predict_proba"):
            continue
        try:
            y_prob = pipeline.predict_proba(X_test)[:, 1]
            # Column 1 of predict_proba scores classes_[1]. Telling roc_curve
            # so explicitly lets targets that are not encoded as {0, 1} or
            # {-1, 1} (e.g. string labels) be scored instead of raising.
            classes = getattr(pipeline, "classes_", None)
            if classes is not None and len(classes) == 2:
                pos_label = classes[1]
            else:
                pos_label = None  # fall back to roc_curve's default handling
            fpr, tpr, _ = roc_curve(y_test, y_prob, pos_label=pos_label)
            roc_auc = auc(fpr, tpr)
            fig.add_trace(go.Scatter(
                x=fpr, y=tpr, mode="lines",
                name=f"{name} (AUC = {roc_auc:.3f})",
                line=dict(color=config.COLOR_PALETTE[i % len(config.COLOR_PALETTE)], width=2),
            ))
        except Exception:
            # Best effort: one failing model must not break the whole chart,
            # but leave a trace so the omission can be diagnosed.
            logger.warning("Skipping ROC curve for model %r", name, exc_info=True)

    fig.update_layout(
        **_base_layout(title="🎯 ROC Curves"),
        xaxis=dict(title="False Positive Rate", range=[0, 1], gridcolor="#2A2E3F"),
        yaxis=dict(title="True Positive Rate", range=[0, 1.05], gridcolor="#2A2E3F"),
        legend=dict(x=0.6, y=0.1),
    )
    return fig
def confusion_matrices_chart(
    fitted_models: Dict[str, Pipeline],
    X_test: pd.DataFrame,
    y_test: pd.Series,
) -> plt.Figure:
    """Matplotlib figure with one confusion-matrix heatmap per model.

    Heatmaps are laid out on a grid at most two columns wide, in the
    iteration order of ``fitted_models``; unused grid cells are hidden.

    Args:
        fitted_models: Mapping of display name to fitted pipeline.
        X_test: Held-out features passed to each pipeline's ``predict``.
        y_test: True labels for ``X_test``.

    Returns:
        The dark-themed figure. With no models, a placeholder figure carrying
        a short message is returned instead of a grid.

    Note:
        The figure is created through ``pyplot``; closing it once rendered is
        left to the caller.
    """
    n = len(fitted_models)
    if n == 0:
        # Guard: with zero models n_cols would be 0 and the row computation
        # below would divide by zero.
        fig = plt.figure(figsize=(7, 5))
        fig.patch.set_facecolor("#0E1117")
        fig.text(
            0.5, 0.5, "No fitted models to display",
            ha="center", va="center", color="white",
        )
        return fig

    n_cols = min(2, n)
    n_rows = (n + n_cols - 1) // n_cols  # ceiling division
    # squeeze=False always yields a 2-D array of Axes, so the single-model
    # case needs no special handling.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(7 * n_cols, 5 * n_rows), squeeze=False,
    )
    fig.patch.set_facecolor("#0E1117")
    axes_flat = axes.ravel()

    for ax, (name, pipeline) in zip(axes_flat, fitted_models.items()):
        y_pred = pipeline.predict(X_test)
        cm = confusion_matrix(y_test, y_pred)
        sns.heatmap(
            cm, annot=True, fmt="d", cmap="Blues", ax=ax,
            linewidths=0.5, linecolor="#1E2130", annot_kws={"color": "white"},
        )
        ax.set_facecolor("#0E1117")
        ax.set_title(name, color="white", fontsize=13)
        ax.set_xlabel("Predicted", color="white")
        ax.set_ylabel("Actual", color="white")
        ax.tick_params(colors="white")

    # Hide the trailing empty cell when the model count is odd.
    for ax in axes_flat[n:]:
        ax.set_visible(False)

    # Lay out this figure explicitly rather than whichever figure pyplot
    # currently considers active.
    fig.tight_layout()
    return fig