File size: 2,539 Bytes
31f1cfa
 
3eaf500
31f1cfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3eaf500
 
31f1cfa
 
3eaf500
31f1cfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""Word Cloud plot plugin for AI-Dashboard"""

# pylint: disable=R0801
from typing import Any, List
import logging
import io
import base64

from wordcloud import WordCloud
from dash import html, dcc
from dash.dependencies import Input, Output

from ..base import BasePlotPlugin

logger = logging.getLogger(__name__)


class WordCloudPlotPlugin(BasePlotPlugin):
    """Word Cloud plugin.

    Builds a PNG word cloud from the values of one column of
    ``self.dataframe`` and embeds it in the page as a base64 data URI.
    """

    name = "Word Cloud"

    def dropdown(self, axis: str, label: str, options: List[str]) -> Any:
        """Create a labelled dropdown control.

        Args:
            axis: Control key; becomes the ``axis`` field of the component id
                (``register_callbacks`` listens on ``axis="text"``).
            label: Text shown above the dropdown.
            options: Values offered as choices. The first one is pre-selected;
                when the list is empty nothing is selected.

        Returns:
            An ``html.Div`` wrapping the label and the ``dcc.Dropdown``.
        """
        return html.Div(
            [
                html.Label(label),
                dcc.Dropdown(
                    id={"type": "control", "plot": self.name, "axis": axis},
                    options=[{"label": c, "value": c} for c in options],  # type: ignore
                    # Guard: indexing an empty list raised IndexError when
                    # there were no columns to offer.
                    value=options[0] if options else None,
                    clearable=False,
                    persistence=True,
                    persistence_type="memory",
                    style={"color": "#000"},
                ),
            ],
            style={"width": "130px"},
        )

    def controls(self) -> Any:
        """Render the controls for the plot: a single text-column selector."""
        text_cols = self.categorical_columns()
        return html.Div(
            [self.dropdown("text", "Text Column", text_cols)],
            style={"display": "flex", "flexDirection": "column"},
        )

    def render(self, **kwargs: Any) -> Any:
        """Render the word cloud plot.

        Keyword Args:
            text_axis: Name of the dataframe column whose non-missing values
                are stringified and joined into the word-cloud text.

        Returns:
            An ``html.Img`` with the PNG embedded as a base64 data URI, or an
            ``html.Div`` message when no column is selected, the column does
            not exist, or the column yields no words to plot.
        """
        text_col = kwargs.get("text_axis")
        if text_col is None:
            return html.Div("Select a text column to build a word cloud.")

        try:
            column = self.dataframe[text_col]
        except KeyError:
            logger.warning("Word cloud: column %r not found", text_col)
            return html.Div(f"Column {text_col!r} not found.")

        # Drop missing values, then stringify so non-string columns can be joined.
        text_series = column.dropna().astype(str)
        full_text = " ".join(text_series.values)

        try:
            cloud = WordCloud(width=800, height=400, background_color="white")
            wc = cloud.generate(full_text)
        except ValueError:
            # WordCloud raises ValueError when no words survive tokenization
            # and stopword removal (e.g. an empty or all-missing column).
            logger.info("Word cloud: no words to plot in column %r", text_col)
            return html.Div("No words available to build a word cloud.")

        with io.BytesIO() as img_buffer:
            wc.to_image().save(img_buffer, format="PNG")
            encoded = base64.b64encode(img_buffer.getvalue()).decode("utf-8")

        return html.Img(
            src=f"data:image/png;base64,{encoded}",
            style={"maxWidth": "100%", "height": "auto"},
        )

    def register_callbacks(self, app: Any) -> None:
        """Register the callbacks for the plot.

        Connects the ``text`` dropdown built by ``controls`` to this plugin's
        ``plot-output`` container, re-rendering whenever the selection changes.
        """

        @app.callback(  # type: ignore
            Output({"type": "plot-output", "plot": self.name}, "children"),
            Input({"type": "control", "plot": self.name, "axis": "text"}, "value"),
        )
        def update(text_axis: str) -> Any:
            """Re-render the word cloud for the selected column (may be None)."""
            return self.render(text_axis=text_axis)