"""Geo Scatter / Choropleth plugin for AI-Dashboard""" # pylint: disable=R0801 from typing import Any, List import logging import plotly.express as px from dash import html, dcc from dash.dependencies import Input, Output from ..base import BasePlotPlugin logger = logging.getLogger(__name__) class GeoPlotPlugin(BasePlotPlugin): """Geo Scatter / Choropleth plugin""" name = "Geo Plot" def dropdown(self, axis: str, label: str, options: List[str]) -> Any: """Create a dropdown control for the given axis.""" return html.Div( [ html.Label(label), dcc.Dropdown( id={"type": "control", "plot": self.name, "axis": axis}, options=[{"label": c, "value": c} for c in options], # type: ignore value=options[0], clearable=False, persistence=True, persistence_type="memory", style={"color": "#000"}, ), ], style={"width": "130px"}, ) def controls(self) -> Any: """Render the control panel for the plot.""" cats = self.categorical_columns() nums = self.numeric_columns() mode_options = ["Scatter Geo", "Choropleth"] return html.Div( [ self.dropdown("mode", "Mode", mode_options), self.dropdown("location", "Location Column", cats), self.dropdown("value", "Value / Color", nums), ], style={"display": "flex", "flexDirection": "column"}, ) def render(self, **kwargs: Any) -> Any: """Render the plot based on current control settings.""" mode = kwargs.get("mode_axis", "Scatter Geo") location = kwargs.get("location_axis") value = kwargs.get("value_axis") if mode == "Choropleth": fig = px.choropleth( self.dataframe, locations=location, color=value, locationmode="country names", ) else: fig = px.scatter_geo( self.dataframe, locations=location, color=value, locationmode="country names", ) return dcc.Graph(figure=fig) def register_callbacks(self, app: Any) -> None: """Register Dash callbacks for interactivity.""" @app.callback( # type: ignore Output({"type": "plot-output", "plot": self.name}, "children"), Input({"type": "control", "plot": self.name, "axis": "mode"}, "value"), Input({"type": "control", "plot": self.name, "axis": "location"}, "value"), Input({"type": "control", "plot": self.name, "axis": "value"}, "value"), ) def update(mode_axis: str, location_axis: str, value_axis: str) -> Any: """Update the plot based on control inputs.""" return self.render( mode_axis=mode_axis, location_axis=location_axis, value_axis=value_axis, )