Spaces:
Sleeping
Sleeping
| """PCA / t-SNE plot plugin for AI-Dashboard""" | |
| # pylint: disable=R0801 | |
| from typing import Any, List | |
| import logging | |
| import plotly.express as px | |
| from dash import html, dcc | |
| from dash.dependencies import Input, Output | |
| from sklearn.decomposition import PCA | |
| from sklearn.manifold import TSNE | |
| from ..base import BasePlotPlugin | |
# Module-scoped logger, named after this module so handlers/filters can
# target this plugin's records specifically.
logger: logging.Logger = logging.getLogger(name=__name__)
class PCATSNEPlotPlugin(BasePlotPlugin):
    """Dimensionality reduction plugin (PCA / t-SNE).

    Projects user-selected numeric columns of ``self.dataframe`` onto two
    dimensions with either PCA or t-SNE and renders the result as a Plotly
    scatter plot, optionally colored by another column.
    """

    name = "PCA / t-SNE Plot"

    def dropdown(
        self, id_suffix: str, label: str, options: List[str], multi: bool = False
    ) -> Any:
        """Create a labelled dropdown control.

        Args:
            id_suffix: Value stored under the ``"axis"`` key of the
                component's pattern-matching id.
            label: Text shown above the dropdown.
            options: Values offered in the dropdown (used as label and value).
            multi: Allow several selections. When true the first three
                options are preselected; otherwise the first option is.

        Returns:
            An ``html.Div`` wrapping the label and the ``dcc.Dropdown``.
        """
        if multi:
            default = options[:3]
        else:
            # An empty option list (e.g. a frame with no usable columns) used
            # to raise IndexError here; fall back to no selection instead.
            default = options[0] if options else None
        return html.Div(
            [
                html.Label(label),
                dcc.Dropdown(
                    id={"type": "control", "plot": self.name, "axis": id_suffix},
                    options=[{"label": c, "value": c} for c in options],  # type: ignore
                    value=default,
                    clearable=False,
                    multi=multi,
                    persistence=True,
                    persistence_type="memory",
                    style={"color": "#000"},
                ),
            ],
            style={"width": "130px"},
        )

    def controls(self) -> Any:
        """Build the control panel: method, numeric columns and color column.

        Returns:
            An ``html.Div`` stacking the three dropdowns vertically.
        """
        nums = self.numeric_columns()
        cats = self.categorical_columns()
        method_options = ["PCA", "t-SNE"]
        return html.Div(
            [
                self.dropdown("method", "Method", method_options),
                self.dropdown("cols", "Numeric Columns", nums, multi=True),
                self.dropdown("color", "Color By", nums + cats),
            ],
            style={"display": "flex", "flexDirection": "column"},
        )

    def render(self, **kwargs: Any) -> Any:
        """Render a 2-D embedding of the selected numeric columns.

        Keyword Args:
            method_axis: ``"t-SNE"`` selects t-SNE; any other value uses PCA.
            cols_axis: Names of the numeric columns to embed.
            color_axis: Column used to color the points, or ``None``.

        Returns:
            A ``dcc.Graph`` holding the scatter plot, or an ``html.Div`` with
            an explanatory message when the selection cannot be embedded.
        """
        method = kwargs.get("method_axis", "PCA")
        cols = kwargs.get("cols_axis", [])
        color = kwargs.get("color_axis", None)
        if not cols:
            return html.Div(
                "Select at least one numeric column.", style={"color": "red"}
            )

        # PCA and t-SNE both reject NaN input, so only embed rows that are
        # complete in the selected columns. Other columns are kept so they
        # remain available for coloring.
        subset = self.dataframe.dropna(subset=cols)
        n_samples = len(subset)
        if n_samples < 2:
            return html.Div(
                "Not enough complete rows to compute an embedding.",
                style={"color": "red"},
            )
        if method != "t-SNE" and len(cols) < 2:
            # PCA(n_components=2) requires at least two input features.
            return html.Div(
                "PCA needs at least two numeric columns.", style={"color": "red"}
            )

        data_value = subset[cols].values
        try:
            if method == "t-SNE":
                emb = TSNE(
                    n_components=2,
                    init="random",
                    learning_rate="auto",
                    # t-SNE requires perplexity < n_samples; 30 is sklearn's
                    # default, clamped for small datasets.
                    perplexity=float(min(30, n_samples - 1)),
                    # Fixed seed so re-rendering the same selection yields the
                    # same layout instead of a new random embedding.
                    random_state=0,
                ).fit_transform(data_value)
                x_label, y_label = "dim1", "dim2"
            else:
                emb = PCA(n_components=2).fit_transform(data_value)
                x_label, y_label = "PC1", "PC2"
        except ValueError as err:
            # e.g. infinite values, which dropna() does not remove.
            logger.exception("%s embedding failed", method)
            return html.Div(f"{method} failed: {err}", style={"color": "red"})

        comp_df = subset.copy()
        comp_df[x_label] = emb[:, 0]
        comp_df[y_label] = emb[:, 1]
        fig = px.scatter(comp_df, x=x_label, y=y_label, color=color)
        return dcc.Graph(figure=fig)

    def register_callbacks(self, app: Any) -> None:
        """Define the update handler for this plugin's controls.

        NOTE(review): ``update`` is defined but never passed to
        ``app.callback`` here, so this method registers nothing by itself.
        Presumably controls are dispatched centrally through the
        pattern-matching ids built in ``dropdown`` -- confirm before relying
        on this hook.
        """

        def update(method_axis: str, cols_axis: List[str], color_axis: str) -> Any:
            return self.render(
                method_axis=method_axis, cols_axis=cols_axis, color_axis=color_axis
            )