File size: 5,279 Bytes
94d2494 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 | """
src.visualisations.common — Shared chart builders (used by both classification and regression).
"""
from __future__ import annotations
from typing import Dict, List, Optional
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
from src import config
_PLOTLY_TEMPLATE = "plotly_dark"
_FONT_FAMILY = "Inter, sans-serif"
def _base_layout(**kwargs) -> dict:
return dict(
template=_PLOTLY_TEMPLATE,
font=dict(family=_FONT_FAMILY, size=13),
paper_bgcolor="rgba(0,0,0,0)",
plot_bgcolor="rgba(0,0,0,0)",
**kwargs,
)
def correlation_heatmap(data: pd.DataFrame) -> plt.Figure:
    """Return a Matplotlib figure with a styled correlation heatmap.

    Only numeric and boolean columns enter the correlation, so frames that
    also carry text columns do not make ``DataFrame.corr`` fail.  The upper
    triangle (diagonal included) is masked, and the figure is drawn on a
    ``#0E1117`` background with white text.

    Args:
        data: Input frame; columns that are neither numeric nor boolean are
            ignored (also for sizing the figure).

    Returns:
        The Matplotlib ``Figure`` holding the heatmap.

    Raises:
        ValueError: If *data* has no numeric or boolean columns.
    """
    # Compute the matrix once, on the columns corr() can actually handle.
    corr = data.select_dtypes(include=["number", "bool"]).corr()
    if corr.empty:
        raise ValueError("correlation_heatmap requires at least one numeric column")
    n_features = corr.shape[1]

    # Grow the canvas with the feature count so the annotations stay legible.
    fig, ax = plt.subplots(figsize=(max(8, n_features), max(6, n_features - 1)))
    fig.patch.set_facecolor("#0E1117")
    ax.set_facecolor("#0E1117")

    mask = np.triu(np.ones_like(corr, dtype=bool))
    sns.heatmap(
        corr, mask=mask, annot=True, fmt=".2f", cmap="coolwarm",
        center=0, ax=ax, annot_kws={"size": 9, "color": "white"},
        linewidths=0.5, linecolor="#1E2130",
    )

    ax.tick_params(colors="white")
    for spine in ax.spines.values():
        spine.set_edgecolor("#1E2130")
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", color="white")
    plt.setp(ax.get_yticklabels(), color="white")

    # The colour bar keeps Matplotlib's default tick-label colour, which is
    # hard to read on the dark background; recolour it to match the axes.
    colorbar = ax.collections[0].colorbar if ax.collections else None
    if colorbar is not None:
        colorbar.ax.tick_params(colors="white")

    ax.set_title("Feature Correlation Matrix", color="white", fontsize=14, pad=12)
    # Lay out this figure explicitly instead of whatever pyplot considers current.
    fig.tight_layout()
    return fig
def feature_distributions(data: pd.DataFrame, target_col: Optional[str] = None) -> go.Figure:
    """Build a grid of interactive histograms, one panel per numeric column.

    Args:
        data: Frame whose numeric columns are plotted.
        target_col: Optional column to leave out of the grid; only has an
            effect when it is truthy and names a numeric column.

    Returns:
        A Plotly figure with three panels per row and 250 px of height per row
        (at least one row, even when there is nothing to plot).
    """
    from plotly.subplots import make_subplots

    panels_per_row = 3
    columns = [
        name
        for name in data.select_dtypes(include="number").columns.tolist()
        if not (target_col and name == target_col)
    ]
    # Ceiling division, clamped so an empty column list still yields one row.
    row_count = max(1, -(-len(columns) // panels_per_row))

    fig = make_subplots(rows=row_count, cols=panels_per_row, subplot_titles=columns)
    palette = config.COLOR_PALETTE
    for position, name in enumerate(columns):
        grid_row, grid_col = divmod(position, panels_per_row)
        fig.add_trace(
            go.Histogram(
                x=data[name],
                name=name,
                marker_color=palette[position % len(palette)],
                opacity=0.85,
            ),
            row=grid_row + 1,
            col=grid_col + 1,
        )

    fig.update_layout(
        **_base_layout(title_text="Feature Distributions", showlegend=False),
        height=row_count * 250,
    )
    return fig
def model_comparison_chart(results_df: pd.DataFrame) -> go.Figure:
    """Return a grouped bar chart comparing every metric across models.

    Model names come from a ``"Model"`` column when present.  Otherwise they
    come from a single-level index, whatever that index is called.  Every
    other column except ``"CV Score"`` and ``"index"`` is treated as a metric.
    Scores that are missing or cannot be read as numbers are dropped before
    plotting.

    Args:
        results_df: One row per model, one column per metric.

    Returns:
        A Plotly bar figure with one bar group per model, coloured by metric.

    Raises:
        KeyError: From ``DataFrame.melt`` when no ``"Model"`` column can be
            derived (e.g. a multi-level index with no level named "Model").
    """
    if "Model" in results_df.columns:
        df = results_df.copy()
    elif results_df.index.nlevels == 1:
        # Name the index first: a plain reset_index() on an unnamed index
        # yields an "index" column, and melt(id_vars="Model") then fails.
        df = results_df.rename_axis(index="Model").reset_index()
    else:
        df = results_df.reset_index()

    metric_cols = [c for c in df.columns if c not in ("Model", "CV Score", "index")]
    melted = df.melt(id_vars="Model", value_vars=metric_cols, var_name="Metric", value_name="Score")
    # Coerce so stray text values become NaN and are dropped with the missing scores.
    melted["Score"] = pd.to_numeric(melted["Score"], errors="coerce")

    fig = px.bar(
        melted.dropna(subset=["Score"]),
        x="Model", y="Score", color="Metric", barmode="group",
        color_discrete_sequence=config.COLOR_PALETTE,
        title="π Model Performance Comparison",
        labels={"Score": "Score", "Model": ""},
    )
    fig.update_layout(
        **_base_layout(legend_title_text="Metric"),
        xaxis=dict(gridcolor="#2A2E3F"),
        # Style the numeric axis's grid like the other charts in this module.
        yaxis=dict(gridcolor="#2A2E3F"),
    )
    return fig
def feature_importance_chart(feature_names: List[str], importances: np.ndarray) -> go.Figure:
    """Return a horizontal bar chart of feature importances, largest on top.

    Args:
        feature_names: One label per feature.
        importances: One score per feature.  Any array-like is accepted; a
            2-D array with a single row or a single column (shape ``(1, n)``
            or ``(n, 1)``) is flattened first.

    Returns:
        A Plotly horizontal bar figure coloured by importance.

    Raises:
        ValueError: If the number of names and scores differ, or if
            *importances* is genuinely multi-dimensional.
    """
    scores = np.asarray(importances)
    # Accept row/column vectors; leave real 2-D input alone so pandas rejects it.
    if scores.ndim == 2 and 1 in scores.shape:
        scores = scores.reshape(-1)

    # Ascending sort: Plotly draws the first category of a horizontal bar
    # chart at the bottom, so the most important feature ends up on top.
    df = pd.DataFrame({"Feature": feature_names, "Importance": scores}).sort_values(
        "Importance", ascending=True
    )
    fig = px.bar(
        df, x="Importance", y="Feature", orientation="h",
        color="Importance", color_continuous_scale=["#6C63FF", "#FF6584"],
        title="π Feature Importances",
    )
    fig.update_layout(
        **_base_layout(coloraxis_showscale=False),
        xaxis=dict(gridcolor="#2A2E3F"),
        yaxis=dict(gridcolor="#2A2E3F"),
    )
    return fig
def target_distribution_chart(y: pd.Series, label_encoders: dict, target_col: str) -> go.Figure:
    """Draw a donut chart of how many samples fall into each target class.

    Args:
        y: Target values, one per sample.
        label_encoders: Mapping of column name to encoder.  When it has an
            entry for *target_col*, the class values are cast to ``int`` and
            passed through that encoder's ``inverse_transform``.
        target_col: Name of the target column, used for the encoder lookup.

    Returns:
        A Plotly pie figure with a 0.4 hole.
    """
    counts = y.value_counts().reset_index().set_axis(["Class", "Count"], axis=1)

    if target_col in label_encoders:
        encoder = label_encoders[target_col]
        class_codes = counts["Class"].astype(int)
        counts["Class"] = encoder.inverse_transform(class_codes)

    donut = px.pie(
        counts,
        names="Class",
        values="Count",
        color_discrete_sequence=config.COLOR_PALETTE,
        title="π― Target Class Distribution",
        hole=0.4,
    )
    donut.update_layout(**_base_layout())
    return donut
def target_histogram_chart(y: pd.Series, target_col: str) -> go.Figure:
    """Return a 40-bin histogram of a continuous target variable.

    Args:
        y: Target values.
        target_col: Display name of the target, used in the title and as the
            x-axis title.

    Returns:
        A Plotly histogram figure titled "Count" on the y-axis, without a legend.
    """
    fig = px.histogram(
        x=y, nbins=40, color_discrete_sequence=["#6C63FF"],
        title=f"π― Target Distribution β {target_col}",
        labels={"x": target_col, "y": "Count"},
    )
    fig.update_layout(
        **_base_layout(showlegend=False),
        # Set axis titles explicitly.  The ``labels`` mapping above may not
        # reach them: a named Series is titled by its own name, and a count
        # histogram's y-axis is keyed "count" rather than "y".
        xaxis=dict(gridcolor="#2A2E3F", title_text=target_col),
        yaxis=dict(gridcolor="#2A2E3F", title_text="Count"),
    )
    return fig
|