afshin-dini's picture
Update plugin manager to work with many plot categories
8c69abd
"""This module manages plugins for the AI Dashboard application."""
from typing import Any, Dict, Iterator, List, Optional

from .plot_plugin.bar_plot import BarPlotPlugin
from .plot_plugin.box_plot import BoxPlotPlugin
from .plot_plugin.geo_scatter import GeoPlotPlugin
from .plot_plugin.grouped_bar_plot import GroupedBarPlotPlugin
from .plot_plugin.heatmap_plot import CorrelationHeatmapPlugin
from .plot_plugin.hexbin_plot import HexbinPlotPlugin
from .plot_plugin.histogram_plot import HistogramPlotPlugin
from .plot_plugin.line_plot import LinePlotPlugin
from .plot_plugin.regression_line import RegressionPlotPlugin
from .plot_plugin.scatter_3d import Scatter3DPlotPlugin
from .plot_plugin.scatter_matrix import ScatterMatrixPlugin
from .plot_plugin.scatter_plot import ScatterPlotPlugin
from .plot_plugin.tsne import PCATSNEPlotPlugin
from .plot_plugin.violin_plot import ViolinPlotPlugin
from .plot_plugin.word_cloud import WordCloudPlotPlugin
from .table_plugin.correlation_matrix import CorrelationMatrixTablePlugin
from .table_plugin.summary_table import SummaryStatisticsTablePlugin
from .table_plugin.table_sample import SampleTablePlugin


class PluginManager:
    """Manages plot and table plugins for the AI Dashboard.

    Plot plugins are organised into named categories (``plot_groups``);
    table plugins are kept in a flat list (``tables``). Individual plugins
    are looked up by their ``name`` attribute.
    """

    def __init__(self, dataframe: Optional[Any] = None) -> None:
        """Initialize the PluginManager with available plot and table plugins.

        Args:
            dataframe: Data handed to every plugin's constructor. May be
                ``None`` and supplied later via ``set_dataframe``.
        """
        self.df = dataframe
        # Category name -> plot plugins in that category. Dict insertion
        # order is the order reported by get_categories().
        self.plot_groups: Dict[str, List[Any]] = {
            "Basic Plots": [
                ScatterPlotPlugin(dataframe),
                BarPlotPlugin(dataframe),
                LinePlotPlugin(dataframe),
                GroupedBarPlotPlugin(dataframe),
            ],
            "Statistical": [
                HistogramPlotPlugin(dataframe),
                BoxPlotPlugin(dataframe),
                ViolinPlotPlugin(dataframe),
                CorrelationHeatmapPlugin(dataframe),
            ],
            "Dimensionality Reduction": [
                ScatterMatrixPlugin(dataframe),
                PCATSNEPlotPlugin(dataframe),
            ],
            "Geospatial": [
                GeoPlotPlugin(dataframe),
            ],
            "Advanced Visualizations": [
                HexbinPlotPlugin(dataframe),
                Scatter3DPlotPlugin(dataframe),
                RegressionPlotPlugin(dataframe),
                WordCloudPlotPlugin(dataframe),
            ],
        }
        self.tables: List[Any] = [
            SampleTablePlugin(dataframe),
            CorrelationMatrixTablePlugin(dataframe),
            SummaryStatisticsTablePlugin(dataframe),
        ]

    def _iter_plots(self) -> Iterator[Any]:
        """Yield every plot plugin across all categories, in category order."""
        for group in self.plot_groups.values():
            yield from group

    def set_dataframe(self, df: Any) -> None:
        """Set the dataframe on the manager and on every plot and table plugin.

        Args:
            df: The new data; assigned to ``self.df`` and to each plugin's
                ``dataframe`` attribute.
        """
        self.df = df
        for plugin in self._iter_plots():
            plugin.dataframe = df
        for table in self.tables:
            table.dataframe = df

    def get_categories(self) -> List[str]:
        """Get the list of plot categories, in registration order."""
        return list(self.plot_groups)

    def get_plots_in_category(self, category: str) -> List[str]:
        """Get the list of plot names in a given category.

        Raises:
            KeyError: If ``category`` is not a registered plot category.
        """
        try:
            group = self.plot_groups[category]
        except KeyError:
            raise KeyError(f"Plot category {category} not found.") from None
        return [plugin.name for plugin in group]

    def get_plot(self, name: str) -> Any:
        """Get a plot plugin by its name.

        If several plugins share a name, the first one in category order
        is returned.

        Raises:
            KeyError: If no plot plugin has the given ``name``.
        """
        for plugin in self._iter_plots():
            if plugin.name == name:
                return plugin
        raise KeyError(f"Plot {name} not found.")

    def get_table_names(self) -> List[str]:
        """Get the list of table names."""
        return [table.name for table in self.tables]

    def get_table(self, name: str) -> Any:
        """Get a table plugin by its name.

        Raises:
            KeyError: If no table plugin has the given ``name`` (previously
                a bare ``StopIteration`` leaked out of ``next``).
        """
        for table in self.tables:
            if table.name == name:
                return table
        raise KeyError(f"Table {name} not found.")