afshin-dini commited on
Commit
d4a56d7
·
1 Parent(s): d453c55

Add geo scatter plots

Browse files
src/ai_dashboard/plot_plugin/geo_scatter.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Geo Scatter / Choropleth plugin for AI-Dashboard"""
2
+
3
+ from typing import Any, List
4
+ import logging
5
+
6
+ import plotly.express as px
7
+ from dash import html, dcc
8
+ from dash.dependencies import Input, Output
9
+
10
+ from ..base import BasePlotPlugin
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ class GeoPlotPlugin(BasePlotPlugin):
16
+ """Geo Scatter / Choropleth plugin"""
17
+
18
+ name = "Geo Plot"
19
+
20
+ def dropdown(self, axis: str, label: str, options: List[str]) -> Any:
21
+ """Create a dropdown control for the given axis."""
22
+ return html.Div(
23
+ [
24
+ html.Label(label),
25
+ dcc.Dropdown(
26
+ id={"type": "control", "plot": self.name, "axis": axis},
27
+ options=[{"label": c, "value": c} for c in options], # type: ignore
28
+ value=options[0],
29
+ clearable=False,
30
+ persistence=True,
31
+ ),
32
+ ],
33
+ style={"width": "220px", "marginBottom": "10px"},
34
+ )
35
+
36
+ def controls(self) -> Any:
37
+ """Render the control panel for the plot."""
38
+ cats = self.categorical_columns()
39
+ nums = self.numeric_columns()
40
+ mode_options = ["Scatter Geo", "Choropleth"]
41
+
42
+ return html.Div(
43
+ [
44
+ self.dropdown("mode", "Mode", mode_options),
45
+ self.dropdown("location", "Location Column", cats),
46
+ self.dropdown("value", "Value / Color", nums),
47
+ ],
48
+ style={"display": "flex", "flexDirection": "column"},
49
+ )
50
+
51
+ def render(self, **kwargs: Any) -> Any:
52
+ """Render the plot based on current control settings."""
53
+ mode = kwargs.get("mode_axis", "Scatter Geo")
54
+ location = kwargs.get("location_axis")
55
+ value = kwargs.get("value_axis")
56
+
57
+ if mode == "Choropleth":
58
+ fig = px.choropleth(
59
+ self.dataframe,
60
+ locations=location,
61
+ color=value,
62
+ locationmode="country names",
63
+ )
64
+ else:
65
+ fig = px.scatter_geo(
66
+ self.dataframe,
67
+ locations=location,
68
+ color=value,
69
+ locationmode="country names",
70
+ )
71
+
72
+ return dcc.Graph(figure=fig)
73
+
74
+ def register_callbacks(self, app: Any) -> None:
75
+ """Register Dash callbacks for interactivity."""
76
+
77
+ @app.callback( # type: ignore
78
+ Output({"type": "plot-output", "plot": self.name}, "children"),
79
+ Input({"type": "control", "plot": self.name, "axis": "mode"}, "value"),
80
+ Input({"type": "control", "plot": self.name, "axis": "location"}, "value"),
81
+ Input({"type": "control", "plot": self.name, "axis": "value"}, "value"),
82
+ )
83
+ def update(mode_axis: str, location_axis: str, value_axis: str) -> Any:
84
+ """Update the plot based on control inputs."""
85
+ return self.render(
86
+ mode_axis=mode_axis,
87
+ location_axis=location_axis,
88
+ value_axis=value_axis,
89
+ )