Spaces:
Running
Running
| from utils.logger import logger | |
| import matplotlib.pyplot as plt | |
| import plotext as plt_terminal | |
| import os | |
class VisualizationAgent:
    """Answer free-text chart requests against datasets held by a registry.

    A query is scanned for a dataset name, a column name and a chart keyword
    ("hist" or "bar"). The chart is drawn in the terminal with plotext and
    also saved as a PNG under ``output_dir``.
    """

    # Number of bins used for both the terminal and the PNG histogram.
    HISTOGRAM_BINS = 20
    # Columns with more distinct values than this are refused for bar charts.
    MAX_BAR_CATEGORIES = 50

    def __init__(self, registry, output_dir="output"):
        """Store the registry and make sure the PNG output folder exists.

        Args:
            registry: Object providing ``list_datasets()`` and
                ``load_dataframe(name)``.
            output_dir: Folder that receives the generated PNG files.
        """
        self.registry = registry
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

    @staticmethod
    def _longest_mention(text, names):
        """Return the longest entry of ``names`` occurring in ``text``, else None.

        Matching is a case-insensitive substring test. Preferring the longest
        hit keeps a short name ("age") from shadowing a longer one that
        contains it ("age_group"). Names are compared through ``str()`` so
        non-string labels (e.g. integer column names) do not raise. Blank
        names are ignored because an empty string is a substring of anything.
        """
        lowered = text.lower()
        mentioned = [
            name for name in names
            if str(name) and str(name).lower() in lowered
        ]
        if not mentioned:
            return None
        # max() keeps the first of several equally long candidates.
        return max(mentioned, key=lambda name: len(str(name)))

    @staticmethod
    def _safe_name(name):
        """Return ``name`` as text with path separators replaced by ``_``."""
        return str(name).replace("/", "_").replace("\\", "_")

    def _output_path(self, dataset, column, suffix):
        """Build the PNG path ``<output_dir>/<dataset>_<column>_<suffix>.png``."""
        filename = (
            f"{self._safe_name(dataset)}_{self._safe_name(column)}_{suffix}.png"
        )
        return f"{self.output_dir}/{filename}"

    def _detect_dataset(self, query, datasets):
        """Pick the dataset named in ``query``; fall back to the first one.

        Args:
            query: The user's request text.
            datasets: Non-empty sequence of available dataset names.

        Returns:
            The longest dataset name mentioned in the query, or
            ``datasets[0]`` when none is mentioned.
        """
        mentioned = self._longest_mention(query, datasets)
        if mentioned is not None:
            return mentioned
        logger.info("Dataset not specified, using default dataset.")
        return datasets[0]

    def _detect_column(self, query, columns):
        """Return the longest column label mentioned in ``query``, or None."""
        return self._longest_mention(query, columns)

    def _histogram(self, df, dataset, column):
        """Draw a histogram of ``df[column]`` in the terminal and save a PNG."""
        series = df[column].dropna()
        if series.empty:
            logger.warning(
                f"Column '{column}' has no non-null values. Skipping histogram."
            )
            return f"Column '{column}' has no data to visualize."

        logger.info(f"Generating histogram for {column} in {dataset}")

        # Terminal plot
        plt_terminal.clear_figure()
        plt_terminal.hist(series.values, bins=self.HISTOGRAM_BINS)
        plt_terminal.title(f"Histogram of {column}")
        plt_terminal.xlabel(column)
        plt_terminal.ylabel("Frequency")
        plt_terminal.show()

        # Save PNG. The figure is closed in ``finally`` so a failing plot or
        # save does not leave an open figure behind.
        filepath = self._output_path(dataset, column, "hist")
        fig = plt.figure()
        try:
            axes = fig.add_subplot(111)
            # Same bin count as the terminal rendering so both views agree.
            series.hist(ax=axes, bins=self.HISTOGRAM_BINS)
            axes.set_title(f"Histogram of {column}")
            axes.set_xlabel(column)
            axes.set_ylabel("Frequency")
            fig.savefig(filepath)
        finally:
            plt.close(fig)

        logger.info(f"Histogram saved → {filepath}")
        return f"Histogram generated in terminal. PNG saved to {filepath}"

    def _bar_chart(self, df, dataset, column):
        """Draw a value-count bar chart of ``df[column]`` and save a PNG."""
        unique_values = df[column].nunique()
        if unique_values > self.MAX_BAR_CATEGORIES:
            logger.warning(
                f"Column '{column}' has {unique_values} unique values. Skipping bar chart."
            )
            return f"Column '{column}' has {unique_values} unique values. Too many to visualize meaningfully."
        if unique_values == 0:
            logger.warning(
                f"Column '{column}' has no non-null values. Skipping bar chart."
            )
            return f"Column '{column}' has no data to visualize."

        logger.info(f"Generating bar chart for {column} in {dataset}")
        counts = df[column].value_counts()

        # Terminal plot
        plt_terminal.clear_figure()
        plt_terminal.bar(
            counts.index.astype(str).tolist(),
            counts.values.tolist(),
        )
        plt_terminal.title(f"Bar Chart of {column}")
        plt_terminal.xlabel(column)
        plt_terminal.ylabel("Count")
        plt_terminal.show()

        # Save PNG; the figure is always closed, even when plotting fails.
        filepath = self._output_path(dataset, column, "bar")
        fig = plt.figure()
        try:
            axes = fig.add_subplot(111)
            counts.plot(kind="bar", ax=axes)
            axes.set_title(f"Bar Chart of {column}")
            axes.set_xlabel(column)
            axes.set_ylabel("Count")
            fig.savefig(filepath)
        finally:
            plt.close(fig)

        logger.info(f"Bar chart saved → {filepath}")
        return f"Bar chart generated in terminal. PNG saved to {filepath}"

    def handle(self, query):
        """Resolve dataset, column and chart type from ``query`` and plot it.

        Args:
            query: Free-text request, e.g. "histogram of age in customers".

        Returns:
            A human-readable status message. Failures are logged and reported
            as a message; this method does not raise for them.
        """
        q = query.lower()

        try:
            datasets = self.registry.list_datasets()
            if not datasets:
                logger.warning("VisualizationAgent called with no datasets loaded.")
                return "No datasets available."
            dataset = self._detect_dataset(q, datasets)
            df = self.registry.load_dataframe(dataset)
            columns = df.columns.tolist()
        except Exception as e:
            logger.error(f"Failed loading dataset in VisualizationAgent | {e}")
            return "Failed to load dataset."

        try:
            column = self._detect_column(q, columns)
            if column is None:
                logger.warning("Column not detected for visualization.")
                return "Column not found in dataset."

            # "hist" also matches "histogram"; "bar" also matches "bar chart".
            if "hist" in q:
                return self._histogram(df, dataset, column)
            if "bar" in q:
                return self._bar_chart(df, dataset, column)

            return "Visualization query not understood."
        except Exception as e:
            logger.error(f"Visualization failed | Query: {query} | Error: {e}")
            return "Visualization agent error."