# afshin-dini's picture
# Update the margins
# 3eaf500
"""PCA / t-SNE plot plugin for AI-Dashboard"""
# pylint: disable=R0801
from typing import Any, List
import logging
import plotly.express as px
from dash import html, dcc
from dash.dependencies import Input, Output
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from ..base import BasePlotPlugin
logger = logging.getLogger(__name__)
class PCATSNEPlotPlugin(BasePlotPlugin):
    """Dimensionality reduction plugin (PCA / t-SNE).

    Projects the user-selected numeric columns of ``self.dataframe`` onto two
    dimensions with either PCA or t-SNE and shows the result as a scatter
    plot, optionally colored by another column.
    """

    name = "PCA / t-SNE Plot"

    def dropdown(
        self, id_suffix: str, label: str, options: List[str], multi: bool = False
    ) -> Any:
        """Create a labelled dropdown control bound to this plugin.

        Args:
            id_suffix: Value of the ``axis`` key in the component's
                pattern-matching id; ``register_callbacks`` listens on
                ``"method"``, ``"cols"`` and ``"color"``.
            label: Text shown above the dropdown.
            options: Option values (each value is also used as its label).
            multi: Allow selecting several values. When true the first three
                options are pre-selected; otherwise only the first one is.

        Returns:
            An ``html.Div`` wrapping the label and the ``dcc.Dropdown``.
        """
        if multi:
            default: Any = options[:3]
        else:
            # An empty option list (e.g. a dataframe without numeric or
            # categorical columns) would make options[0] raise IndexError.
            default = options[0] if options else None
        return html.Div(
            [
                html.Label(label),
                dcc.Dropdown(
                    id={"type": "control", "plot": self.name, "axis": id_suffix},
                    options=[{"label": c, "value": c} for c in options],  # type: ignore
                    value=default,
                    clearable=False,
                    multi=multi,
                    persistence=True,
                    persistence_type="memory",
                    style={"color": "#000"},
                ),
            ],
            style={"width": "130px"},
        )

    def controls(self) -> Any:
        """Build the method, numeric-column and color-by dropdowns, stacked vertically."""
        nums = self.numeric_columns()
        cats = self.categorical_columns()
        method_options = ["PCA", "t-SNE"]
        return html.Div(
            [
                self.dropdown("method", "Method", method_options),
                self.dropdown("cols", "Numeric Columns", nums, multi=True),
                self.dropdown("color", "Color By", nums + cats),
            ],
            style={"display": "flex", "flexDirection": "column"},
        )

    @staticmethod
    def _message(text: str) -> Any:
        """Return a red text notice shown in place of the plot."""
        return html.Div(text, style={"color": "red"})

    def render(self, **kwargs: Any) -> Any:
        """Compute a 2-D embedding of the selected columns and plot it.

        Keyword Args:
            method_axis: ``"t-SNE"`` selects t-SNE; any other value
                (default ``"PCA"``) selects PCA.
            cols_axis: Numeric column name(s) to embed.
            color_axis: Optional column used to color the points.

        Returns:
            A ``dcc.Graph`` holding the scatter plot, or a red ``html.Div``
            message when the current selection cannot be embedded.
        """
        method = kwargs.get("method_axis", "PCA")
        cols = kwargs.get("cols_axis", [])
        color = kwargs.get("color_axis", None)

        # Accept a single column name as well as a list of names.
        if isinstance(cols, str):
            cols = [cols]
        if not cols:
            return self._message("Select at least one numeric column.")

        # PCA and t-SNE both reject NaN input; drop incomplete rows so the
        # remaining rows (and their color values) stay aligned with the
        # embedding.
        frame = self.dataframe.dropna(subset=cols)
        n_samples = len(frame)

        if method == "t-SNE":
            label, x_name, y_name = "t-SNE", "dim1", "dim2"
            if n_samples < 2:
                return self._message(
                    "t-SNE needs at least two rows without missing values."
                )
            # scikit-learn requires perplexity < n_samples; 30 is its default.
            reducer = TSNE(
                n_components=2,
                init="random",
                learning_rate="auto",
                perplexity=min(30.0, float(n_samples - 1)),
            )
        else:
            label, x_name, y_name = "PCA", "PC1", "PC2"
            # PCA(n_components=2) needs n_components <= min(n_samples, n_features).
            if len(cols) < 2 or n_samples < 2:
                return self._message(
                    "PCA needs at least two numeric columns and two rows "
                    "without missing values."
                )
            reducer = PCA(n_components=2)

        try:
            emb = reducer.fit_transform(frame[cols].to_numpy())
        except ValueError as err:
            # e.g. infinite values; report instead of failing the callback.
            logger.exception("%s failed for columns %s", label, cols)
            return self._message(f"{label} failed: {err}")

        comp_df = frame.assign(**{x_name: emb[:, 0], y_name: emb[:, 1]})
        fig = px.scatter(comp_df, x=x_name, y=y_name, color=color)
        return dcc.Graph(figure=fig)

    def register_callbacks(self, app: Any) -> None:
        """Register the Dash callback that re-renders the plot.

        The callback fires whenever the method, columns or color dropdown
        created by ``controls`` changes, and writes the result of ``render``
        into the ``plot-output`` component for this plugin.
        """

        @app.callback(  # type: ignore
            Output({"type": "plot-output", "plot": self.name}, "children"),
            Input({"type": "control", "plot": self.name, "axis": "method"}, "value"),
            Input({"type": "control", "plot": self.name, "axis": "cols"}, "value"),
            Input({"type": "control", "plot": self.name, "axis": "color"}, "value"),
        )
        def update(method_axis: str, cols_axis: List[str], color_axis: str) -> Any:
            return self.render(
                method_axis=method_axis, cols_axis=cols_axis, color_axis=color_axis
            )