"""This module manages plugins for the AI Dashboard application."""

from typing import Any, Dict, Iterator, List, Optional

from .plot_plugin.bar_plot import BarPlotPlugin
from .plot_plugin.histogram_plot import HistogramPlotPlugin
from .plot_plugin.scatter_plot import ScatterPlotPlugin
from .table_plugin.table_sample import SampleTablePlugin
from .plot_plugin.line_plot import LinePlotPlugin
from .plot_plugin.box_plot import BoxPlotPlugin
from .plot_plugin.violin_plot import ViolinPlotPlugin
from .plot_plugin.scatter_matrix import ScatterMatrixPlugin
from .plot_plugin.heatmap_plot import CorrelationHeatmapPlugin
from .plot_plugin.grouped_bar_plot import GroupedBarPlotPlugin
from .plot_plugin.hexbin_plot import HexbinPlotPlugin
from .plot_plugin.tsne import PCATSNEPlotPlugin
from .plot_plugin.scatter_3d import Scatter3DPlotPlugin
from .plot_plugin.geo_scatter import GeoPlotPlugin
from .plot_plugin.word_cloud import WordCloudPlotPlugin
from .plot_plugin.regression_line import RegressionPlotPlugin
from .table_plugin.correlation_matrix import CorrelationMatrixTablePlugin
from .table_plugin.summary_table import SummaryStatisticsTablePlugin


class PluginManager:
    """Manages plot and table plugins for the AI Dashboard.

    Plot plugins are grouped into named categories (``plot_groups``); table
    plugins are kept in a flat list (``tables``). Every plugin is constructed
    with the same dataframe, and :meth:`set_dataframe` rebinds it on all of
    them at once.
    """

    def __init__(self, dataframe: Optional[Any] = None) -> None:
        """Initialize the PluginManager with available plot and table plugins.

        Args:
            dataframe: Data handed to every plugin's constructor. May be
                ``None`` and supplied later via :meth:`set_dataframe`.
        """
        self.df = dataframe
        # Explicit annotations: the lists mix unrelated plugin classes, so
        # typing them as ``Any`` lets callers read ``.name`` / ``.dataframe``
        # without ``# type: ignore``.
        self.plot_groups: Dict[str, List[Any]] = {
            "Basic Plots": [
                ScatterPlotPlugin(dataframe),
                BarPlotPlugin(dataframe),
                LinePlotPlugin(dataframe),
                GroupedBarPlotPlugin(dataframe),
            ],
            "Statistical": [
                HistogramPlotPlugin(dataframe),
                BoxPlotPlugin(dataframe),
                ViolinPlotPlugin(dataframe),
                CorrelationHeatmapPlugin(dataframe),
            ],
            "Dimensionality Reduction": [
                ScatterMatrixPlugin(dataframe),
                PCATSNEPlotPlugin(dataframe),
            ],
            "Geospatial": [
                GeoPlotPlugin(dataframe),
            ],
            "Advanced Visualizations": [
                HexbinPlotPlugin(dataframe),
                Scatter3DPlotPlugin(dataframe),
                RegressionPlotPlugin(dataframe),
                WordCloudPlotPlugin(dataframe),
            ],
        }
        self.tables: List[Any] = [
            SampleTablePlugin(dataframe),
            CorrelationMatrixTablePlugin(dataframe),
            SummaryStatisticsTablePlugin(dataframe),
        ]

    def _iter_plots(self) -> Iterator[Any]:
        """Yield every plot plugin, in category order and then list order."""
        for group in self.plot_groups.values():
            yield from group

    def set_dataframe(self, df: Any) -> None:
        """Set the dataframe on the manager and on every plot and table plugin.

        Args:
            df: The new data; assigned to each plugin's ``dataframe`` attribute.
        """
        self.df = df
        for plugin in self._iter_plots():
            plugin.dataframe = df
        for table in self.tables:
            table.dataframe = df

    def get_categories(self) -> List[str]:
        """Get the list of plot categories, in definition order."""
        return list(self.plot_groups.keys())

    def get_plots_in_category(self, category: str) -> List[str]:
        """Get the list of plot names in a given category.

        Raises:
            KeyError: If ``category`` is not a known category.
        """
        return [plugin.name for plugin in self.plot_groups[category]]

    def get_plot(self, name: str) -> Any:
        """Get a plot plugin by its name.

        If several plugins share a name, the first one in category order wins.

        Raises:
            KeyError: If no plot plugin has the given name.
        """
        for plugin in self._iter_plots():
            if plugin.name == name:
                return plugin
        raise KeyError(f"Plot {name} not found.")

    def get_table_names(self) -> List[str]:
        """Get the list of table names."""
        return [table.name for table in self.tables]

    def get_table(self, name: str) -> Any:
        """Get a table plugin by its name.

        Raises:
            KeyError: If no table plugin has the given name. The previous
                bare ``next(...)`` leaked ``StopIteration`` here, which is
                inconsistent with :meth:`get_plot` and becomes a
                ``RuntimeError`` if raised inside a generator (PEP 479).
        """
        for table in self.tables:
            if table.name == name:
                return table
        raise KeyError(f"Table {name} not found.")