"""PCA / t-SNE plot plugin for AI-Dashboard""" # pylint: disable=R0801 from typing import Any, List import logging import plotly.express as px from dash import html, dcc from dash.dependencies import Input, Output from sklearn.decomposition import PCA from sklearn.manifold import TSNE from ..base import BasePlotPlugin logger = logging.getLogger(__name__) class PCATSNEPlotPlugin(BasePlotPlugin): """Dimensionality reduction plugin (PCA / t-SNE).""" name = "PCA / t-SNE Plot" def dropdown( self, id_suffix: str, label: str, options: List[str], multi: bool = False ) -> Any: """Create a dropdown control.""" return html.Div( [ html.Label(label), dcc.Dropdown( id={"type": "control", "plot": self.name, "axis": id_suffix}, options=[{"label": c, "value": c} for c in options], # type: ignore value=(options[:3] if multi else options[0]), clearable=False, multi=multi, persistence=True, persistence_type="memory", style={"color": "#000"}, ), ], style={"width": "130px"}, ) def controls(self) -> Any: nums = self.numeric_columns() cats = self.categorical_columns() method_options = ["PCA", "t-SNE"] return html.Div( [ self.dropdown("method", "Method", method_options), self.dropdown("cols", "Numeric Columns", nums, multi=True), self.dropdown("color", "Color By", nums + cats), ], style={"display": "flex", "flexDirection": "column"}, ) def render(self, **kwargs: Any) -> Any: method = kwargs.get("method_axis", "PCA") cols = kwargs.get("cols_axis", []) color = kwargs.get("color_axis", None) if not cols: return html.Div( "Select at least one numeric column.", style={"color": "red"} ) data_value = self.dataframe[cols].values if method == "t-SNE": emb = TSNE( n_components=2, init="random", learning_rate="auto" ).fit_transform(data_value) comp_df = self.dataframe.copy() comp_df["dim1"] = emb[:, 0] comp_df["dim2"] = emb[:, 1] fig = px.scatter(comp_df, x="dim1", y="dim2", color=color) else: pca = PCA(n_components=2) emb = pca.fit_transform(data_value) comp_df = self.dataframe.copy() comp_df["PC1"] = emb[:, 0] comp_df["PC2"] = emb[:, 1] fig = px.scatter(comp_df, x="PC1", y="PC2", color=color) return dcc.Graph(figure=fig) def register_callbacks(self, app: Any) -> None: @app.callback( # type: ignore Output({"type": "plot-output", "plot": self.name}, "children"), Input({"type": "control", "plot": self.name, "axis": "method"}, "value"), Input({"type": "control", "plot": self.name, "axis": "cols"}, "value"), Input({"type": "control", "plot": self.name, "axis": "color"}, "value"), ) def update(method_axis: str, cols_axis: List[str], color_axis: str) -> Any: return self.render( method_axis=method_axis, cols_axis=cols_axis, color_axis=color_axis )