afshin-dini commited on
Commit
222e0c8
·
1 Parent(s): 44939e1

Add scatter plot

Browse files
src/ai_dashboard/plot_plugin/scatter_plot.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This is the scatter plot module for AI Dashboard."""
2
+
3
+ from typing import Any
4
+ import logging
5
+
6
+ import plotly.express as px
7
+ from dash import html, dcc
8
+ from dash.dependencies import Input, Output
9
+
10
+ from ..base import BasePlotPlugin
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ class ScatterPlotPlugin(BasePlotPlugin):
16
+ """Scatter plot plugin for AI Dashboard."""
17
+
18
+ name = "Scatter Plot"
19
+
20
+ def dropdown(self, id_suffix: str, label: str) -> Any:
21
+ """Generate a dropdown for plot controls."""
22
+ cols = self.dataframe.columns.tolist()
23
+ return html.Div(
24
+ [
25
+ html.Label(label),
26
+ dcc.Dropdown(
27
+ id={"type": "control", "plot": self.name, "axis": id_suffix},
28
+ options=[{"label": c, "value": c} for c in cols], # type: ignore
29
+ value=cols[0],
30
+ clearable=False,
31
+ persistence=True,
32
+ persistence_type="memory",
33
+ ),
34
+ ],
35
+ style={"width": "170px", "marginRight": "12px"},
36
+ )
37
+
38
+ def controls(self) -> Any:
39
+ """Connect the scatter plot plugin to the Dash app."""
40
+ return html.Div(
41
+ [
42
+ self.dropdown("x", "X-Axis"),
43
+ self.dropdown("y", "Y-Axis"),
44
+ self.dropdown("size", "Size"),
45
+ self.dropdown("color", "Color"),
46
+ self.dropdown("hover", "Hover"),
47
+ ],
48
+ style={"display": "flex", "flexWrap": "wrap", "gap": "12px"},
49
+ )
50
+
51
+ def render(self, **kwargs: Any) -> Any:
52
+ """Render the scatter plot."""
53
+ x_axis = kwargs.get("x_axis")
54
+ y_axis = kwargs.get("y_axis")
55
+ size_axis = kwargs.get("size_axis")
56
+ color_axis = kwargs.get("color_axis")
57
+ hover_axis = kwargs.get("hover_axis")
58
+ fig = px.scatter(
59
+ self.dataframe,
60
+ x=x_axis,
61
+ y=y_axis,
62
+ size=size_axis if size_axis in self.dataframe.columns else None,
63
+ color=color_axis if color_axis in self.dataframe.columns else None,
64
+ hover_name=hover_axis if hover_axis in self.dataframe.columns else None,
65
+ log_x=True,
66
+ size_max=60,
67
+ )
68
+ return dcc.Graph(figure=fig)
69
+
70
+ def register_callbacks(self, app: Any) -> None:
71
+ """Register the callbacks for the scatter plot plugin."""
72
+
73
+ @app.callback( # type: ignore
74
+ Output({"type": "plot-output", "plot": self.name}, "children"),
75
+ Input({"type": "control", "plot": self.name, "axis": "x"}, "value"),
76
+ Input({"type": "control", "plot": self.name, "axis": "y"}, "value"),
77
+ Input({"type": "control", "plot": self.name, "axis": "size"}, "value"),
78
+ Input({"type": "control", "plot": self.name, "axis": "color"}, "value"),
79
+ Input({"type": "control", "plot": self.name, "axis": "hover"}, "value"),
80
+ )
81
+ def update_scatter(
82
+ x_axis: str, y_axis: str, size_axis: str, color_axis: str, hover_axis: str
83
+ ) -> Any:
84
+ """Update the scatter plot based on user input."""
85
+ return self.render(
86
+ x_axis=x_axis,
87
+ y_axis=y_axis,
88
+ size_axis=size_axis,
89
+ color_axis=color_axis,
90
+ hover_axis=hover_axis,
91
+ )