Spaces:
Sleeping
Sleeping
| """This is the scatter plot module for AI Dashboard.""" | |
| # pylint: disable=R0801 | |
| from typing import Any, List | |
| import logging | |
| import plotly.express as px | |
| from dash import html, dcc | |
| from dash.dependencies import Input, Output | |
| from ..base import BasePlotPlugin | |
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class ScatterPlotPlugin(BasePlotPlugin):
    """Scatter plot plugin for AI Dashboard.

    Builds dropdown controls for the X, Y, size, color and hover columns and
    renders a Plotly Express scatter plot of ``self.dataframe`` from the
    selected columns.
    """

    name = "Scatter Plot"

    # Control axes, in the same order as update_scatter's parameters.
    _AXES = ("x", "y", "size", "color", "hover")

    def _control_id(self, axis: str) -> dict:
        """Return the pattern-matching component id of the ``axis`` dropdown."""
        return {"type": "control", "plot": self.name, "axis": axis}

    def dropdown(self, id_suffix: str, label: str, option: List[str]) -> Any:
        """Generate a labelled dropdown for one plot control.

        Args:
            id_suffix: Axis key stored under ``"axis"`` in the component id.
            label: Text shown above the dropdown.
            option: Column names offered as choices; the first one is
                preselected. May be empty, in which case nothing is selected.

        Returns:
            An ``html.Div`` holding the label and the ``dcc.Dropdown``.
        """
        return html.Div(
            [
                html.Label(label),
                dcc.Dropdown(
                    id=self._control_id(id_suffix),
                    options=[{"label": c, "value": c} for c in option],  # type: ignore
                    # option[0] on an empty list (e.g. no numeric columns)
                    # raised IndexError and broke the whole control panel.
                    value=option[0] if option else None,
                    clearable=False,
                    persistence=True,
                    persistence_type="memory",
                    style={"color": "#000"},
                ),
            ],
            style={"width": "130px"},
        )

    def controls(self) -> Any:
        """Build the row of dropdowns that configure the scatter plot.

        X, Y and size offer the numeric columns; color and hover offer all
        numeric and categorical columns.
        """
        nums = self.numeric_columns()
        cats = self.categorical_columns()
        all_cols = nums + cats
        return html.Div(
            [
                self.dropdown("x", "X-Axis", nums),
                self.dropdown("y", "Y-Axis", nums),
                self.dropdown("size", "Size", nums),
                self.dropdown("color", "Color", all_cols),
                self.dropdown("hover", "Hover", all_cols),
            ],
            style={"display": "flex", "flexWrap": "wrap"},
        )

    def _size_column(self, size_axis: Any) -> Any:
        """Return ``size_axis`` if it can drive marker sizes, else ``None``."""
        if size_axis not in self.dataframe.columns:
            return None
        try:
            has_negative = bool((self.dataframe[size_axis] < 0).any())
        except TypeError:
            # Non-numeric column: pass it through so plotly reports it as before.
            return size_axis
        if has_negative:
            # plotly requires marker sizes >= 0 and raises ValueError otherwise.
            logger.warning(
                "Column %r has negative values; drawing markers without size.",
                size_axis,
            )
            return None
        return size_axis

    def render(self, **kwargs: Any) -> Any:
        """Render the scatter plot.

        Keyword Args:
            x_axis: Column for the horizontal axis.
            y_axis: Column for the vertical axis.
            size_axis: Column driving marker size; ignored when it is not a
                column or contains negative values.
            color_axis: Column used for marker color; ignored when not a column.
            hover_axis: Column used as hover name; ignored when not a column.
            log_x: Whether the X axis is logarithmic (default ``True``).
            size_max: ``size_max`` passed to ``px.scatter`` (default ``60``).

        Returns:
            A ``dcc.Graph`` with the figure, or an ``html.Div`` message when
            ``x_axis`` or ``y_axis`` is not a column of the dataframe.
        """
        x_axis = kwargs.get("x_axis")
        y_axis = kwargs.get("y_axis")
        columns = self.dataframe.columns
        if x_axis not in columns or y_axis not in columns:
            # Happens e.g. when there are no numeric columns to choose from;
            # px.scatter would otherwise raise or fall back to wide-form data.
            logger.warning(
                "Scatter plot needs X and Y columns, got x=%r, y=%r", x_axis, y_axis
            )
            return html.Div("Select X and Y columns to draw the scatter plot.")

        color_axis = kwargs.get("color_axis")
        hover_axis = kwargs.get("hover_axis")
        fig = px.scatter(
            self.dataframe,
            x=x_axis,
            y=y_axis,
            size=self._size_column(kwargs.get("size_axis")),
            color=color_axis if color_axis in columns else None,
            hover_name=hover_axis if hover_axis in columns else None,
            # A log axis cannot place x <= 0; callers may turn it off.
            log_x=kwargs.get("log_x", True),
            size_max=kwargs.get("size_max", 60),
        )
        return dcc.Graph(figure=fig)

    def register_callbacks(self, app: Any, output_id: Any = None) -> None:
        """Register the callbacks for the scatter plot plugin.

        Args:
            app: Dash application whose ``callback`` decorator is used.
            output_id: Id of the container whose ``children`` should receive
                the rendered graph. When ``None`` (the default) no callback is
                registered, matching this method's previous behavior.
        """

        def update_scatter(
            x_axis: str, y_axis: str, size_axis: str, color_axis: str, hover_axis: str
        ) -> Any:
            """Update the scatter plot based on user input."""
            return self.render(
                x_axis=x_axis,
                y_axis=y_axis,
                size_axis=size_axis,
                color_axis=color_axis,
                hover_axis=hover_axis,
            )

        if output_id is None:
            # NOTE(review): update_scatter was previously defined and then
            # dropped without ever being passed to app.callback, so nothing was
            # registered from here. The layout's output container id is not
            # visible in this module - TODO confirm it and pass it as output_id.
            logger.debug("%s: no output_id given; callback not registered.", self.name)
            return

        # Inputs follow _AXES, which matches update_scatter's parameter order.
        app.callback(
            Output(output_id, "children"),
            [Input(self._control_id(axis), "value") for axis in self._AXES],
        )(update_scatter)