# agents/visualization_agent.py
"""
Visualization Agent for Lexis
──────────────────────────────────────────────────────────────────
Flow:
1. Receive query + filename from Flask route
2. Load the dataset (CSV or Excel) from data/datasets/
3. Build a schema summary (columns, dtypes, sample rows)
4. Ask the LLM to produce a Plotly figure as valid JSON
5. Validate & sanitize the JSON (no exec, no eval)
6. Return the Plotly figure dict + a plain-English summary
The frontend renders the figure dict using Plotly.js.
──────────────────────────────────────────────────────────────────
"""
import os
import json
import re
import traceback
import pandas as pd
from langchain.chat_models import init_chat_model
from dotenv import load_dotenv
load_dotenv()

# ── Config ─────────────────────────────────────────────────────────
# Relative path: resolved against the process working directory, so the
# app presumably has to be started from the project root.
# NOTE(review): confirm this matches where the upload route saves datasets.
DATASETS_DIR = os.path.join("data", "datasets")

# Re-use the same LLM used by AnswerGenerator.
# Import from config so it's always in sync; fall back to a default model
# identifier when the project settings module cannot be imported.
try:
    from config.settings import GENERATION_MODEL_NAME
except ImportError:
    GENERATION_MODEL_NAME = "groq:llama-3.3-70b-versatile"

_MAX_SAMPLE_ROWS = 5  # rows shown to LLM for context
_MAX_UNIQUE_VALS = 20  # max unique values shown per column
# ── Dataset loader ─────────────────────────────────────────────────
def load_dataset(filename: str, datasets_dir: "str | None" = None) -> pd.DataFrame:
    """Load a CSV or Excel file from the datasets directory.

    Args:
        filename: File name relative to the datasets directory. It arrives
            from the Flask route, so it is treated as untrusted: names that
            resolve outside the datasets directory (``../x.csv``, absolute
            paths) are rejected.
        datasets_dir: Directory to load from; defaults to ``DATASETS_DIR``.

    Returns:
        The loaded DataFrame.

    Raises:
        ValueError: If ``filename`` escapes the datasets directory, or its
            extension is not .csv, .xlsx or .xls.
        FileNotFoundError: If no regular file with that name exists.
    """
    base_dir = DATASETS_DIR if datasets_dir is None else datasets_dir
    base = os.path.abspath(base_dir)
    path = os.path.abspath(os.path.join(base, filename))
    # Path-traversal guard: os.path.join drops `base` for an absolute
    # filename and abspath collapses "..", so compare the normalised paths.
    if path == base or os.path.commonpath([base, path]) != base:
        raise ValueError(f"Invalid dataset filename: {filename!r}")
    # isfile (not exists) so a directory named "x.csv" is reported as missing.
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Dataset '{filename}' not found in {base_dir}/")

    ext = os.path.splitext(path)[1].lower().lstrip(".")
    if ext == "csv":
        return pd.read_csv(path)
    if ext in ("xlsx", "xls"):
        return pd.read_excel(path)
    raise ValueError(f"Unsupported file type: .{ext} (only CSV and Excel are supported)")
# ── Schema builder ─────────────────────────────────────────────────
def build_schema_summary(
    df: pd.DataFrame,
    *,
    max_sample_rows: "int | None" = None,
    max_unique_vals: "int | None" = None,
) -> str:
    """Produce a compact, LLM-readable description of a DataFrame.

    The summary contains:
      - the shape
      - each column's name, dtype and null count
      - min/max/mean for numeric columns
      - min/max and distinct count for datetime columns
      - unique values (or a count plus examples) for all other columns
      - the first rows as a markdown table, or a plain-text table when the
        optional ``tabulate`` package (needed by ``to_markdown``) is missing

    Args:
        df: The dataset to describe.
        max_sample_rows: Rows in the sample table; defaults to
            ``_MAX_SAMPLE_ROWS``.
        max_unique_vals: Non-numeric columns with at most this many distinct
            values have all of them listed; defaults to ``_MAX_UNIQUE_VALS``.

    Returns:
        A multi-line string.
    """
    n_sample = _MAX_SAMPLE_ROWS if max_sample_rows is None else max_sample_rows
    n_unique = _MAX_UNIQUE_VALS if max_unique_vals is None else max_unique_vals

    lines = [
        f"Shape: {df.shape[0]} rows × {df.shape[1]} columns\n",
        "Columns:",
    ]
    # Positional access: df[col] returns a DataFrame when a name is duplicated.
    for idx, col in enumerate(df.columns):
        series = df.iloc[:, idx]
        n_null = int(series.isna().sum())
        if pd.api.types.is_numeric_dtype(series):
            info = (
                f"numeric | min={series.min():.4g}, "
                f"max={series.max():.4g}, mean={series.mean():.4g}"
            )
        elif pd.api.types.is_datetime64_any_dtype(series):
            info = (
                f"datetime | min={series.min()}, max={series.max()}, "
                f"{series.nunique()} unique values"
            )
        else:
            uniq = series.dropna().unique()
            if len(uniq) <= n_unique:
                info = f"categorical | unique values: {list(uniq)}"
            else:
                info = f"categorical | {len(uniq)} unique values, e.g. {list(uniq[:5])}"
        lines.append(f"  • {col!r} [{series.dtype}] nulls={n_null} — {info}")

    sample = df.head(n_sample)
    try:
        table, fmt = sample.to_markdown(index=False), "markdown"
    except ImportError:
        # to_markdown raises ImportError when `tabulate` is not installed.
        table, fmt = sample.to_string(index=False), "plain text"
    lines.append(f"\nFirst {n_sample} rows ({fmt}):")
    lines.append(table)
    return "\n".join(lines)
# ββ System prompt ββββββββββββββββββββββββββββββββββββββββββββββββββ
_SYSTEM_PROMPT = """You are a data visualization expert.
You will receive:
1. A dataset schema (columns, dtypes, sample rows)
2. A user request describing a chart they want
Your job is to output a SINGLE valid JSON object that represents a Plotly figure.
The JSON must be a Plotly figure dict with two top-level keys:
- "data" β list of trace dicts (e.g. go.Bar, go.Scatter, go.Pie, go.Histogram, etc.)
- "layout" β layout dict (title, xaxis, yaxis, etc.)
STRICT RULES:
1. Output ONLY the raw JSON β no markdown, no backticks, no explanation before or after.
2. Do NOT use Python code or executable code anywhere in your response.
3. Use exact column names from the schema β do not invent column names.
4. For aggregations (e.g., "average salary by age band"), compute the aggregation
by producing the aggregated x and y arrays DIRECTLY in the JSON using the raw values
that would result from that computation. Do NOT use formulas.
5. Choose the most appropriate chart type based on the request.
6. Always include a descriptive title in layout.title.text
7. Always include axis labels: layout.xaxis.title.text and layout.yaxis.title.text
(skip yaxis label for pie charts)
8. Use a clean, professional color scheme.
9. The JSON must be parseable by json.loads() β no trailing commas, no comments.
Example valid output structure:
{
"data": [
{
"type": "bar",
"x": ["A", "B", "C"],
"y": [10, 25, 15],
"marker": {"color": "#818cf8"}
}
],
"layout": {
"title": {"text": "My Chart"},
"xaxis": {"title": {"text": "Category"}},
"yaxis": {"title": {"text": "Value"}},
"plot_bgcolor": "rgba(0,0,0,0)",
"paper_bgcolor": "rgba(0,0,0,0)",
"font": {"color": "#f2f2f2"}
}
}
"""
# ── LLM caller ─────────────────────────────────────────────────────
def _call_llm(schema_summary: str, user_query: str, df: pd.DataFrame) -> dict:
    """Ask the LLM for a Plotly figure and parse its JSON reply.

    Args:
        schema_summary: Text description of ``df`` (columns, dtypes, samples).
        user_query: The user's natural-language chart request.
        df: The full dataset. When it fits the budget in
            ``_inline_dataset_csv`` it is sent to the model as CSV, so plotted
            (e.g. aggregated) values can come from the real data rather than
            being guessed from the sample rows in the schema.
            NOTE(review): this sends raw rows of small datasets to the model
            provider (the schema summary already sends sample rows).

    Returns:
        The parsed JSON value. It is expected to be a Plotly figure dict, but
        that is not checked here.

    Raises:
        ValueError: If the reply contains no parseable JSON.
    """
    llm = init_chat_model(GENERATION_MODEL_NAME)

    parts = [f"Dataset schema:\n{schema_summary}"]
    data_csv = _inline_dataset_csv(df)
    if data_csv:
        parts.append(
            f"Full dataset ({len(df)} rows, CSV) — compute every plotted value "
            f"from this exact data:\n{data_csv}"
        )
    parts.append(f"User request:\n{user_query}")
    parts.append("Remember: output ONLY raw JSON. No markdown, no explanation.")

    messages = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {"role": "user", "content": "\n\n".join(parts)},
    ]
    response = llm.invoke(messages)
    return _parse_json_reply(_message_text(response.content))


def _inline_dataset_csv(df: pd.DataFrame, max_rows: int = 200, max_chars: int = 20_000) -> str:
    """Return ``df`` as CSV text if it fits the row and character budget, else ''."""
    if df.empty or len(df) > max_rows:
        return ""
    csv_text = df.to_csv(index=False)
    return csv_text if len(csv_text) <= max_chars else ""


def _message_text(content) -> str:
    """Flatten a chat message's ``content`` (a string or a list of parts) to text."""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        # Some chat providers return content blocks such as
        # {"type": "text", "text": "..."} instead of a plain string.
        return "".join(
            part.get("text", "") if isinstance(part, dict) else str(part)
            for part in content
        )
    return str(content)


def _parse_json_reply(raw: str):
    """Parse JSON from an LLM reply, tolerating code fences and surrounding prose."""
    text = raw.strip()
    # Strip any accidental markdown code fences the LLM added.
    text = re.sub(r"^```(?:json)?\s*", "", text)
    text = re.sub(r"\s*```$", "", text).strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError as err:
        # Retry on the outermost {...} span in case the object was wrapped in
        # a sentence despite the instructions; report the original error.
        start, end = text.find("{"), text.rfind("}")
        if 0 <= start < end:
            try:
                return json.loads(text[start:end + 1])
            except json.JSONDecodeError:
                pass
        raise ValueError(
            f"LLM returned invalid JSON: {err}\n\nRaw response (first 500 chars):\n{text[:500]}"
        ) from err
# ββ Figure validator βββββββββββββββββββββββββββββββββββββββββββββββ
def _validate_figure(figure_dict: dict) -> dict:
"""
Basic sanity checks on the Plotly figure dict.
Adds transparent background so it blends with the UI.
"""
if not isinstance(figure_dict, dict):
raise ValueError("Figure must be a JSON object (dict)")
if "data" not in figure_dict:
raise ValueError("Figure JSON missing required key: 'data'")
if not isinstance(figure_dict["data"], list):
raise ValueError("'data' must be a list of trace objects")
# Ensure layout exists
figure_dict.setdefault("layout", {})
layout = figure_dict["layout"]
# Transparent backgrounds so chart blends into the dark/light UI
layout.setdefault("plot_bgcolor", "rgba(0,0,0,0)")
layout.setdefault("paper_bgcolor", "rgba(0,0,0,0)")
layout.setdefault("font", {}).setdefault("color", "#f2f2f2")
# Clean margins
layout.setdefault("margin", {"t": 60, "r": 20, "b": 60, "l": 60})
# Responsive
figure_dict.setdefault("config", {
"responsive": True,
"displayModeBar": True,
"modeBarButtonsToRemove": ["toImage"],
})
return figure_dict
# ── Summary generator ──────────────────────────────────────────────
def _generate_summary(user_query: str, figure_dict: dict, df: pd.DataFrame) -> str:
    """Generate a short plain-English summary of what the chart shows.

    Makes one LLM call for a conversational 1-2 sentence interpretation.
    The model is given the request, the chart title, the dataset's row count
    and column names, and a truncated JSON preview of the plotted traces, so
    any trend it mentions is based on the plotted values.

    Returns:
        The model's reply text, stripped of surrounding whitespace.
    """
    llm = init_chat_model(GENERATION_MODEL_NAME)

    # Pull chart title if available; Plotly allows either {"text": ...} or a
    # bare string, and the LLM may also emit null.
    layout = figure_dict.get("layout")
    title_obj = layout.get("title") if isinstance(layout, dict) else None
    if isinstance(title_obj, dict):
        title = title_obj.get("text") or "the chart"
    elif isinstance(title_obj, str) and title_obj:
        title = title_obj
    else:
        title = "the chart"

    preview = _preview_traces(figure_dict.get("data"))

    messages = [
        {
            "role": "system",
            "content": (
                "You are a helpful data analyst. "
                "Write 1-2 concise, conversational sentences describing what the chart shows. "
                "Do not mention Plotly or technical details. "
                "Be insightful — mention the key trend or takeaway if obvious."
            ),
        },
        {
            "role": "user",
            "content": (
                f"The user asked: '{user_query}'\n"
                f"A chart titled '{title}' was generated from a dataset "
                f"with {df.shape[0]} rows and columns: {list(df.columns)}.\n"
                f"Plotted traces (long arrays truncated): {preview}\n"
                "Write a short summary of what this chart shows."
            ),
        },
    ]
    response = llm.invoke(messages)
    content = response.content
    if isinstance(content, list):
        # Some chat providers return a list of content blocks, not a string.
        content = "".join(
            part.get("text", "") if isinstance(part, dict) else str(part)
            for part in content
        )
    return str(content).strip()


def _preview_traces(traces, max_points: int = 30, max_chars: int = 4000) -> str:
    """Return compact JSON of the trace dicts, truncating each list and the total length."""
    if not isinstance(traces, list):
        return "[]"
    compact = [
        {key: (value[:max_points] if isinstance(value, list) else value)
         for key, value in trace.items()}
        for trace in traces
        if isinstance(trace, dict)
    ]
    return json.dumps(compact)[:max_chars]
# ── Main entry point ───────────────────────────────────────────────
def run_visualization_agent(query: str, filename: str) -> dict:
    """Main entry point called by the Flask route.

    Args:
        query: User's natural-language chart request.
        filename: Dataset filename (must exist in data/datasets/).

    Returns:
        On success, a dict with keys:
            success  : True
            figure   : Plotly figure dict (for Plotly.js on frontend)
            summary  : Plain-English description of the chart
            filename : Echo of the filename used
            rows     : Number of rows in the dataset
            columns  : List of dataset column names
        On failure, a dict with keys:
            success  : False
            error    : Error message
            filename : Echo of the filename requested
            detail   : Traceback text (only for unexpected exceptions)

    Any ``Exception`` raised along the way is converted into a failure dict
    rather than propagated.
    """
    try:
        df = load_dataset(filename)                    # 1. Load dataset
        schema = build_schema_summary(df)              # 2. Schema for the LLM
        raw_figure = _call_llm(schema, query, df)      # 3. LLM -> Plotly JSON
        figure = _validate_figure(raw_figure)          # 4. Validate + UI theme
        summary = _generate_summary(query, figure, df) # 5. Short text summary
    except (FileNotFoundError, ValueError) as exc:
        # Anticipated failures, e.g. a missing/unsupported dataset, invalid
        # JSON from the LLM, or a malformed figure.
        return {"success": False, "error": str(exc), "filename": filename}
    except Exception as exc:
        # NOTE(review): `detail` carries a full traceback; if this dict is
        # returned to the browser unchanged, consider logging it server-side.
        return {
            "success": False,
            "error": f"Unexpected error: {exc}",
            "detail": traceback.format_exc(),
            "filename": filename,
        }

    return {
        "success": True,
        "figure": figure,
        "summary": summary,
        "filename": filename,
        "rows": df.shape[0],
        "columns": list(df.columns),
    }