# EDA_Explorer / agents / visualization_agent.py
# ProfessionalMario's picture
# Fresh deployment with LFS tracking
# 9eecab5
from utils.logger import logger
import matplotlib.pyplot as plt
import plotext as plt_terminal
import os
class VisualizationAgent:
    """Render simple charts for columns of registered datasets.

    A free-text query is scanned for a dataset name, a column name and a
    chart keyword ("hist" or "bar"). The chart is drawn in the terminal
    with plotext and also written as a PNG file under ``output_dir``.
    """

    # Bar charts with more distinct values than this are refused.
    MAX_BAR_CATEGORIES = 50
    # Number of bins used for the terminal histogram.
    HISTOGRAM_BINS = 20
    # Class-level fallback so the attribute exists even if __init__ is bypassed.
    output_dir = "output"

    def __init__(self, registry, output_dir="output"):
        """Store the registry and make sure the output directory exists.

        Args:
            registry: Object providing ``list_datasets()`` and
                ``load_dataframe(name)``.
            output_dir: Directory that receives the generated PNG files.
        """
        self.registry = registry
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)

    @staticmethod
    def _longest_match(query, names):
        """Return the longest entry of ``names`` contained in ``query``.

        Matching is case-insensitive and substring based. Preferring the
        longest hit keeps a short name (``age``) from shadowing a longer one
        that contains it (``age_group``). Empty names are skipped because an
        empty string is a substring of every query. Names are compared via
        ``str()`` so non-string labels do not raise. Returns ``None`` when
        nothing matches; on equal length the earliest entry wins.
        """
        q = str(query).lower()
        hits = [name for name in names if str(name) and str(name).lower() in q]
        return max(hits, key=lambda name: len(str(name)), default=None)

    @staticmethod
    def _safe_name(name):
        """Turn ``name`` into a single safe file-name component.

        Letters, digits, spaces and ``. - _`` are kept; everything else
        (notably path separators) becomes ``_`` so a name such as ``km/h``
        cannot point outside the output directory or break the save.
        """
        return "".join(
            ch if ch.isalnum() or ch in "._- " else "_" for ch in str(name)
        )

    def _output_path(self, dataset, column, suffix):
        """Build the PNG path for a chart, creating the directory if needed."""
        # Re-create the directory here: it may have been removed since __init__.
        os.makedirs(self.output_dir, exist_ok=True)
        stem = f"{self._safe_name(dataset)}_{self._safe_name(column)}_{suffix}"
        return f"{self.output_dir}/{stem}.png"

    def _detect_dataset(self, query, datasets):
        """Pick the dataset named in ``query``, or the first one as default."""
        match = self._longest_match(query, datasets)
        if match is not None:
            return match
        logger.info("Dataset not specified, using default dataset.")
        return datasets[0]

    def _detect_column(self, query, columns):
        """Return the column named in ``query``, or ``None`` if none is."""
        return self._longest_match(query, columns)

    def _render_histogram(self, df, dataset, column):
        """Draw a histogram of ``column`` in the terminal and save it as PNG."""
        series = df[column].dropna()
        if series.empty:
            logger.warning(
                f"Column '{column}' has no non-null values. Skipping histogram."
            )
            return f"Column '{column}' has no non-null values to visualize."

        logger.info(f"Generating histogram for {column} in {dataset}")

        # Terminal plot
        plt_terminal.clear_figure()
        plt_terminal.hist(series.values, bins=self.HISTOGRAM_BINS)
        plt_terminal.title(f"Histogram of {column}")
        plt_terminal.xlabel(column)
        plt_terminal.ylabel("Frequency")
        plt_terminal.show()

        # Save PNG; the figure is closed even when plotting or saving fails.
        filepath = self._output_path(dataset, column, "hist")
        fig, ax = plt.subplots()
        try:
            series.hist(ax=ax)
            ax.set_title(f"Histogram of {column}")
            ax.set_xlabel(column)
            ax.set_ylabel("Frequency")
            fig.savefig(filepath)
        finally:
            plt.close(fig)

        logger.info(f"Histogram saved → {filepath}")
        return f"Histogram generated in terminal. PNG saved to {filepath}"

    def _render_bar_chart(self, df, dataset, column):
        """Draw a value-count bar chart of ``column`` and save it as PNG."""
        unique_values = df[column].nunique()
        if unique_values > self.MAX_BAR_CATEGORIES:
            logger.warning(
                f"Column '{column}' has {unique_values} unique values. Skipping bar chart."
            )
            return f"Column '{column}' has {unique_values} unique values. Too many to visualize meaningfully."

        counts = df[column].value_counts()
        if counts.empty:
            logger.warning(
                f"Column '{column}' has no non-null values. Skipping bar chart."
            )
            return f"Column '{column}' has no non-null values to visualize."

        logger.info(f"Generating bar chart for {column} in {dataset}")

        # Terminal plot
        plt_terminal.clear_figure()
        plt_terminal.bar(
            counts.index.astype(str).tolist(),
            counts.values.tolist(),
        )
        plt_terminal.title(f"Bar Chart of {column}")
        plt_terminal.xlabel(column)
        plt_terminal.ylabel("Count")
        plt_terminal.show()

        # Save PNG; the figure is closed even when plotting or saving fails.
        filepath = self._output_path(dataset, column, "bar")
        fig, ax = plt.subplots()
        try:
            counts.plot(kind="bar", ax=ax)
            ax.set_title(f"Bar Chart of {column}")
            ax.set_xlabel(column)
            ax.set_ylabel("Count")
            fig.savefig(filepath)
        finally:
            plt.close(fig)

        logger.info(f"Bar chart saved → {filepath}")
        return f"Bar chart generated in terminal. PNG saved to {filepath}"

    def handle(self, query):
        """Answer a visualization query and return a status message.

        Args:
            query: Free-text request naming a column and a chart type, and
                optionally a dataset. ``None`` is treated as an empty query.

        Returns:
            A human-readable string describing the outcome. Failures are
            logged and reported as a message; this method does not raise
            for errors that occur while loading data or plotting.
        """
        q = "" if query is None else str(query).lower()

        try:
            datasets = self.registry.list_datasets()
            if not datasets:
                logger.warning("VisualizationAgent called with no datasets loaded.")
                return "No datasets available."
            dataset = self._detect_dataset(q, datasets)
            df = self.registry.load_dataframe(dataset)
            columns = df.columns.tolist()
        except Exception as e:
            logger.error(f"Failed loading dataset in VisualizationAgent | {e}")
            return "Failed to load dataset."

        try:
            column = self._detect_column(q, columns)
            if column is None:
                logger.warning("Column not detected for visualization.")
                return "Column not found in dataset."

            # "hist" also covers "histogram"; "bar" also covers "bar chart".
            if "hist" in q:
                return self._render_histogram(df, dataset, column)
            if "bar" in q:
                return self._render_bar_chart(df, dataset, column)

            return "Visualization query not understood."
        except Exception as e:
            logger.error(f"Visualization failed | Query: {query} | Error: {e}")
            return "Visualization agent error."