Spaces:
Running
Running
File size: 4,300 Bytes
9eecab5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 | from utils.logger import logger
import matplotlib.pyplot as plt
import plotext as plt_terminal
import os
class VisualizationAgent:
    """Answer natural-language plotting requests against datasets in a registry.

    Each supported chart (histogram, bar chart) is rendered twice: once in the
    terminal via ``plotext`` and once as a PNG written under ``output_dir``.
    """

    # Bin count shared by the terminal and PNG histograms so the two agree.
    HIST_BINS = 20
    # Bar charts are refused when a column has more distinct values than this.
    MAX_BAR_CATEGORIES = 50
    # Characters that are invalid in file names on common platforms.
    _UNSAFE_FILENAME_CHARS = '<>:"/\\|?*'

    def __init__(self, registry, output_dir="output"):
        """Store the registry and make sure the PNG output directory exists.

        Args:
            registry: Object exposing ``list_datasets()`` and
                ``load_dataframe(name)``.
            output_dir: Directory PNG files are written to (default ``"output"``).
        """
        self.registry = registry
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

    # ------------------------------------------------------------------
    # Query parsing
    # ------------------------------------------------------------------
    @staticmethod
    def _longest_match(query, names):
        """Return the longest name occurring (case-insensitively) in *query*.

        Preferring the longest match stops a short name such as ``"age"`` from
        shadowing a longer one such as ``"page_views"`` that contains it. Ties
        keep the earliest name in *names*. Returns ``None`` when nothing matches.
        """
        q = query.lower()
        matches = [name for name in names if str(name).lower() in q]
        if not matches:
            return None
        return max(matches, key=lambda name: len(str(name)))

    def _detect_dataset(self, query, datasets):
        """Return the dataset named in *query*, falling back to ``datasets[0]``."""
        match = self._longest_match(query, datasets)
        if match is not None:
            return match
        logger.info("Dataset not specified, using default dataset.")
        return datasets[0]

    def _detect_column(self, query, columns):
        """Return the column named in *query*, or ``None`` if none is mentioned."""
        return self._longest_match(query, columns)

    @staticmethod
    def _chart_kind(query, names):
        """Return ``"hist"``, ``"bar"`` or ``None`` for *query*.

        The given dataset/column *names* are blanked out first so that, e.g.,
        a column called ``"history"`` does not by itself select a histogram.
        """
        intent = query.lower()
        for name in names:
            intent = intent.replace(str(name).lower(), " ")
        # "histogram" contains "hist" and "bar chart" contains "bar", so one
        # substring test covers both spellings; histogram wins if both appear.
        if "hist" in intent:
            return "hist"
        if "bar" in intent:
            return "bar"
        return None

    # ------------------------------------------------------------------
    # Output helpers
    # ------------------------------------------------------------------
    @classmethod
    def _safe_filename_part(cls, name):
        """Replace characters that cannot appear in a file name with ``_``."""
        return "".join(
            "_" if ch in cls._UNSAFE_FILENAME_CHARS or not ch.isprintable() else ch
            for ch in str(name)
        )

    def _png_path(self, dataset, column, kind):
        """Build the PNG path ``<output_dir>/<dataset>_<column>_<kind>.png``."""
        filename = (
            f"{self._safe_filename_part(dataset)}_"
            f"{self._safe_filename_part(column)}_{kind}.png"
        )
        return os.path.join(self.output_dir, filename)

    @staticmethod
    def _show_in_terminal(draw, title, xlabel, ylabel):
        """Render a plotext chart: *draw()* adds the data, then label and show."""
        plt_terminal.clear_figure()
        draw()
        plt_terminal.title(title)
        plt_terminal.xlabel(xlabel)
        plt_terminal.ylabel(ylabel)
        plt_terminal.show()

    @staticmethod
    def _save_png(filepath, draw, title, xlabel, ylabel):
        """Draw a matplotlib chart with *draw(ax)*, label it and save it.

        The figure is closed in ``finally`` so a failure while drawing or
        saving does not leave an open figure behind.
        """
        fig, ax = plt.subplots()
        try:
            draw(ax)
            ax.set_title(title)
            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)
            fig.savefig(filepath)
        finally:
            plt.close(fig)

    # ------------------------------------------------------------------
    # Charts
    # ------------------------------------------------------------------
    def _plot_histogram(self, df, dataset, column):
        """Render a histogram of *column* in the terminal and save it as PNG."""
        from pandas.api.types import is_bool_dtype, is_numeric_dtype

        series = df[column].dropna()
        if series.empty:
            logger.warning(f"Column '{column}' has no values. Skipping histogram.")
            return f"Column '{column}' has no values to plot."
        # Histograms need numeric data; booleans are excluded because numpy's
        # histogram binning does not support them.
        if not is_numeric_dtype(series) or is_bool_dtype(series):
            logger.warning(f"Column '{column}' is not numeric. Skipping histogram.")
            return f"Column '{column}' is not numeric; a histogram needs numeric data."

        logger.info(f"Generating histogram for {column} in {dataset}")
        title = f"Histogram of {column}"
        label = str(column)
        values = series.tolist()

        self._show_in_terminal(
            lambda: plt_terminal.hist(values, bins=self.HIST_BINS),
            title, label, "Frequency",
        )

        filepath = self._png_path(dataset, column, "hist")
        self._save_png(
            filepath,
            lambda ax: series.hist(ax=ax, bins=self.HIST_BINS),
            title, label, "Frequency",
        )
        logger.info(f"Histogram saved → {filepath}")
        return f"Histogram generated in terminal. PNG saved to {filepath}"

    def _plot_bar_chart(self, df, dataset, column):
        """Render value counts of *column* in the terminal and save as PNG."""
        unique_values = df[column].nunique()
        if unique_values > self.MAX_BAR_CATEGORIES:
            logger.warning(
                f"Column '{column}' has {unique_values} unique values. Skipping bar chart."
            )
            return f"Column '{column}' has {unique_values} unique values. Too many to visualize meaningfully."

        counts = df[column].value_counts()
        if counts.empty:
            logger.warning(f"Column '{column}' has no values. Skipping bar chart.")
            return f"Column '{column}' has no values to plot."

        logger.info(f"Generating bar chart for {column} in {dataset}")
        title = f"Bar Chart of {column}"
        label = str(column)

        self._show_in_terminal(
            lambda: plt_terminal.bar(
                counts.index.astype(str).tolist(),
                counts.values.tolist(),
            ),
            title, label, "Count",
        )

        filepath = self._png_path(dataset, column, "bar")
        self._save_png(
            filepath,
            lambda ax: counts.plot(kind="bar", ax=ax),
            title, label, "Count",
        )
        logger.info(f"Bar chart saved → {filepath}")
        return f"Bar chart generated in terminal. PNG saved to {filepath}"

    # ------------------------------------------------------------------
    # Entry point
    # ------------------------------------------------------------------
    def handle(self, query):
        """Answer a visualization request such as ``"histogram of age"``.

        The dataset and column are picked by (longest) name match against the
        query; the chart type by the words "hist"/"histogram" or "bar"/"bar
        chart". Failures are logged and reported as a status string; this
        method does not raise for data or plotting errors.

        Returns:
            A human-readable status message.
        """
        q = query.lower()
        try:
            datasets = self.registry.list_datasets()
            if not datasets:
                logger.warning("VisualizationAgent called with no datasets loaded.")
                return "No datasets available."
            dataset = self._detect_dataset(q, datasets)
            df = self.registry.load_dataframe(dataset)
            columns = df.columns.tolist()
        except Exception as e:
            logger.error(f"Failed loading dataset in VisualizationAgent | {e}")
            return "Failed to load dataset."

        try:
            column = self._detect_column(q, columns)
            if column is None:
                logger.warning("Column not detected for visualization.")
                return "Column not found in dataset."

            kind = self._chart_kind(q, [column, dataset])
            if kind == "hist":
                return self._plot_histogram(df, dataset, column)
            if kind == "bar":
                return self._plot_bar_chart(df, dataset, column)
            return "Visualization query not understood."
        except Exception as e:
            logger.error(f"Visualization failed | Query: {query} | Error: {e}")
            return "Visualization agent error."