File size: 3,476 Bytes
d453c55
 
3eaf500
d453c55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3eaf500
 
d453c55
 
3eaf500
d453c55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""PCA / t-SNE plot plugin for AI-Dashboard"""

# pylint: disable=R0801
from typing import Any, List
import logging

import plotly.express as px
from dash import html, dcc
from dash.dependencies import Input, Output
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

from ..base import BasePlotPlugin

logger = logging.getLogger(__name__)


class PCATSNEPlotPlugin(BasePlotPlugin):
    """Dimensionality reduction plugin (PCA / t-SNE).

    Projects the selected numeric columns of ``self.dataframe`` onto two
    dimensions with either PCA or t-SNE and shows the result as a Plotly
    scatter plot, optionally coloured by another column.
    """

    name = "PCA / t-SNE Plot"

    def dropdown(
        self, id_suffix: str, label: str, options: List[str], multi: bool = False
    ) -> Any:
        """Create a labelled dropdown control.

        Args:
            id_suffix: Stored under the ``"axis"`` key of the component's
                pattern-matching id.
            label: Text shown above the dropdown.
            options: Values offered to the user (used as label and value).
            multi: Whether several values may be selected at once.

        Returns:
            An ``html.Div`` wrapping the label and the ``dcc.Dropdown``.
            The initial selection is the first three options for a
            multi-select, the first option otherwise, or ``None`` when a
            single-select has no options at all.
        """
        if multi:
            initial: Any = options[:3]
        else:
            # An empty option list (e.g. no numeric columns in the data) must
            # not crash layout construction with an IndexError.
            initial = options[0] if options else None

        return html.Div(
            [
                html.Label(label),
                dcc.Dropdown(
                    id={"type": "control", "plot": self.name, "axis": id_suffix},
                    options=[{"label": c, "value": c} for c in options],  # type: ignore
                    value=initial,
                    clearable=False,
                    multi=multi,
                    persistence=True,
                    persistence_type="memory",
                    style={"color": "#000"},
                ),
            ],
            style={"width": "130px"},
        )

    def controls(self) -> Any:
        """Build the control column: method, numeric columns and colour.

        The column lists come from ``self.numeric_columns()`` and
        ``self.categorical_columns()``, which are not defined in this class
        (presumably provided by ``BasePlotPlugin`` - verify there).
        """
        nums = self.numeric_columns()
        cats = self.categorical_columns()
        method_options = ["PCA", "t-SNE"]

        return html.Div(
            [
                self.dropdown("method", "Method", method_options),
                self.dropdown("cols", "Numeric Columns", nums, multi=True),
                self.dropdown("color", "Color By", nums + cats),
            ],
            style={"display": "flex", "flexDirection": "column"},
        )

    @staticmethod
    def _message(text: str) -> Any:
        """Return *text* as a red inline message for the plot area."""
        return html.Div(text, style={"color": "red"})

    def render(self, **kwargs: Any) -> Any:
        """Compute the 2-D embedding and return it as a scatter plot.

        Keyword Args:
            method_axis: ``"t-SNE"`` selects t-SNE; any other value selects
                PCA (the default).
            cols_axis: Names of the columns fed to the reducer.
            color_axis: Optional column name used to colour the points.

        Returns:
            A ``dcc.Graph`` with the scatter plot, or a red ``html.Div``
            describing why no plot could be produced.

        Rows with a missing value in any selected column are left out of
        the plot, because neither reducer accepts NaN input.
        """
        method = kwargs.get("method_axis", "PCA")
        cols = kwargs.get("cols_axis") or []
        color = kwargs.get("color_axis", None)

        # A lone column name would otherwise be treated as a single label and
        # yield 1-D data.
        if isinstance(cols, str):
            cols = [cols]

        if not cols:
            return self._message("Select at least one numeric column.")

        use_tsne = method == "t-SNE"

        # PCA cannot extract two components from fewer than two features.
        if not use_tsne and len(cols) < 2:
            return self._message("PCA needs at least two numeric columns.")

        # Persisted dropdown values may name columns that no longer exist.
        missing = [c for c in cols if c not in self.dataframe.columns]
        if missing:
            return self._message(f"Unknown column(s): {', '.join(map(str, missing))}")

        # dropna returns a new frame, so adding columns below leaves
        # self.dataframe untouched.
        comp_df = self.dataframe.dropna(subset=cols)
        n_samples = len(comp_df)
        if n_samples < 2:
            return self._message("Not enough complete rows to compute an embedding.")

        data_value = comp_df[cols].values

        if use_tsne:
            # scikit-learn requires perplexity < n_samples; keep the library
            # default of 30 whenever the data is large enough for it.
            reducer: Any = TSNE(
                n_components=2,
                init="random",
                learning_rate="auto",
                perplexity=min(30.0, float(n_samples - 1)),
            )
            x_name, y_name = "dim1", "dim2"
        else:
            reducer = PCA(n_components=2)
            x_name, y_name = "PC1", "PC2"

        try:
            emb = reducer.fit_transform(data_value)
        except ValueError as err:
            # scikit-learn reports invalid input (e.g. infinities or
            # non-numeric data) as ValueError; show it instead of failing
            # the whole callback.
            logger.exception("%s embedding failed for columns %s", method, cols)
            return self._message(f"Could not compute embedding: {err}")

        comp_df[x_name] = emb[:, 0]
        comp_df[y_name] = emb[:, 1]

        # Ignore a colour selection that does not name an existing column.
        if color is not None and color not in comp_df.columns:
            color = None

        fig = px.scatter(comp_df, x=x_name, y=y_name, color=color)
        return dcc.Graph(figure=fig)

    def register_callbacks(self, app: Any) -> None:
        """Register the callback that re-renders the plot on control changes.

        The callback listens to the ``method``, ``cols`` and ``color``
        dropdowns of this plugin and writes the result of :meth:`render`
        into the ``children`` of the component whose id is
        ``{"type": "plot-output", "plot": self.name}``.
        """

        @app.callback(  # type: ignore
            Output({"type": "plot-output", "plot": self.name}, "children"),
            Input({"type": "control", "plot": self.name, "axis": "method"}, "value"),
            Input({"type": "control", "plot": self.name, "axis": "cols"}, "value"),
            Input({"type": "control", "plot": self.name, "axis": "color"}, "value"),
        )
        def update(method_axis: str, cols_axis: List[str], color_axis: str) -> Any:
            return self.render(
                method_axis=method_axis, cols_axis=cols_axis, color_axis=color_axis
            )