Lexis / agents /data_analysis_agent.py
aryachakraborty's picture
Upload 44 files
b9a1ba4 verified
# agents/data_analysis_agent.py
"""
Data Analysis Agent for Lexis
══════════════════════════════════════════════════════════════════════
Architecture — Two-pass LLM + real pandas execution:
PASS 1 (Plan)
LLM receives the dataset schema and user query.
It returns a JSON "analysis plan" — a list of named operations,
each specifying which pandas method to call and on which columns.
No code is exec'd from the LLM; we map operation names to
whitelisted pandas calls.
EXECUTE
The backend runs the whitelisted pandas operations and collects
the computed results (numbers, tables, ranked lists, etc.)
PASS 2 (Interpret)
LLM receives the user query + the actual computed results.
It writes a rich, structured analytical response in plain English,
with key findings, patterns and recommendations.
RETURN
{
"success": true,
"narrative": "...", # LLM's full analytical write-up
"sections": [...], # structured sections for the UI card
"stats_table": [...], # optional summary table rows
"filename": "...",
"rows": N, "columns": [...]
}
Safe by design:
- No eval(), no exec(), no arbitrary code from LLM
- All operations are whitelisted pandas method calls
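- The planning LLM sees only the schema summary (column statistics plus a
  few sample rows) and the interpreting LLM sees only computed results —
  the full dataset is never sent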
══════════════════════════════════════════════════════════════════════
"""
import os
import json
import re
import math
import traceback
from typing import Any
import numpy as np
import pandas as pd
from langchain.chat_models import init_chat_model
from dotenv import load_dotenv
from services.query_logging import record_llm_call
load_dotenv()
# ── Config ─────────────────────────────────────────────────────────
DATASETS_DIR = os.path.join("data", "datasets")
try:
from config.settings import GENERATION_MODEL_NAME
except ImportError:
GENERATION_MODEL_NAME = "groq:llama-3.3-70b-versatile"
_MAX_SAMPLE_ROWS = 6
_MAX_UNIQUE_VALS = 25
_MAX_RESULT_ROWS = 20 # cap table results sent to LLM
# ── Dataset loader (shared pattern with viz agent) ─────────────────
def load_dataset(filename: str) -> pd.DataFrame:
    """Load a CSV or Excel dataset from ``DATASETS_DIR`` into a DataFrame.

    Args:
        filename: Name of the dataset file, relative to ``DATASETS_DIR``.

    Returns:
        The parsed dataset.

    Raises:
        ValueError: If ``filename`` resolves outside ``DATASETS_DIR`` or has
            an extension other than csv / xlsx / xls.
        FileNotFoundError: If no such file exists in ``DATASETS_DIR``.
    """
    base_dir = os.path.abspath(DATASETS_DIR)
    path = os.path.abspath(os.path.join(base_dir, filename))

    # The filename originates from a request, so refuse "../" segments and
    # absolute paths that would escape the datasets directory.
    if os.path.commonpath([base_dir, path]) != base_dir:
        raise ValueError(f"Invalid dataset filename: {filename!r}")

    # isfile (not exists) so a directory is reported as missing instead of
    # failing later inside the pandas reader.
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Dataset '{filename}' not found in {DATASETS_DIR}/")

    ext = filename.rsplit(".", 1)[-1].lower()
    if ext == "csv":
        return pd.read_csv(path)
    if ext in ("xlsx", "xls"):
        return pd.read_excel(path)
    raise ValueError(f"Unsupported file type: .{ext}")
# ── Schema builder ─────────────────────────────────────────────────
def _format_stat(value: Any) -> str:
    """Format a summary statistic to 4 significant digits, else fall back to str()."""
    try:
        return f"{value:.4g}"
    except (TypeError, ValueError):
        # e.g. pd.NA from an all-missing nullable column has no float format
        return str(value)


def build_schema_summary(df: pd.DataFrame) -> str:
    """Describe *df* as plain text for the planning prompt.

    The summary contains the shape, one line per column (dtype, null count
    and either numeric statistics, a datetime range, or example categorical
    values), followed by the first ``_MAX_SAMPLE_ROWS`` rows as a table.

    Args:
        df: Dataset to describe.

    Returns:
        A newline-joined, human-readable summary string.
    """
    lines = [f"Shape: {df.shape[0]} rows × {df.shape[1]} columns\n", "Columns:"]

    for col in df.columns:
        series = df[col]
        dtype = str(series.dtype)
        n_null = int(series.isna().sum())

        if pd.api.types.is_numeric_dtype(series):
            info = (
                f"numeric | min={_format_stat(series.min())}, "
                f"max={_format_stat(series.max())}, "
                f"mean={_format_stat(series.mean())}, "
                f"std={_format_stat(series.std())}"
            )
        elif pd.api.types.is_datetime64_any_dtype(series):
            info = f"datetime | range: {series.min()} → {series.max()}"
        else:
            uniq = series.dropna().unique()
            if len(uniq) <= _MAX_UNIQUE_VALS:
                info = f"categorical | unique values: {list(uniq[:_MAX_UNIQUE_VALS])}"
            else:
                info = f"categorical | {len(uniq)} unique values, e.g. {list(uniq[:6])}"

        lines.append(f"  • {col!r} [{dtype}] nulls={n_null} — {info}")

    lines.append(f"\nSample rows ({_MAX_SAMPLE_ROWS} rows):")
    sample = df.head(_MAX_SAMPLE_ROWS)
    try:
        lines.append(sample.to_markdown(index=False))
    except ImportError:
        # to_markdown needs the optional "tabulate" package; plain text is
        # an adequate substitute for the prompt.
        lines.append(sample.to_string(index=False))
    return "\n".join(lines)
# ══════════════════════════════════════════════════════════════════
# PASS 1 — LLM Analysis Planner
# ══════════════════════════════════════════════════════════════════
_PLAN_SYSTEM_PROMPT = """You are a senior data analyst planning an analysis for a junior analyst who will execute it in pandas.
You will receive:
1. A dataset schema (columns, dtypes, sample rows)
2. A user's analytical question
Your job is to produce a JSON array of analysis "operations" β€” a structured plan of what to compute.
Each operation is an object with these fields:
- "op_id" : short unique snake_case identifier (e.g. "top_salary", "avg_profit_region")
- "op_type" : one of the whitelisted operations listed below
- "label" : human-readable description of what this computes (e.g. "Top 5 days by Sales")
- "col" : primary column name to operate on (must match schema exactly)
- "group_by" : column name to group by (null if not applicable)
- "n" : integer β€” for top_n / bottom_n operations (null otherwise)
- "col2" : secondary column for ratio operations (null if not applicable)
WHITELISTED op_types:
"top_n" β€” Top N rows by col (sorted desc), optionally grouped by group_by
"bottom_n" β€” Bottom N rows by col (sorted asc)
"mean_by_group" β€” Mean of col grouped by group_by
"sum_by_group" β€” Sum of col grouped by group_by
"count_by_group" β€” Count of rows grouped by group_by
"std_by_group" β€” Std deviation of col grouped by group_by (for volatility/stability)
"ratio_by_group" β€” Compute col/col2 ratio then mean by group_by (for margins)
"distribution" β€” Percentile summary of col (min, p25, median, p75, max, mean, std)
"outliers" β€” IQR-based outlier detection on col, with group_by column for labeling
"correlation" β€” Correlation matrix of all numeric columns
"overall_summary" β€” Full descriptive statistics of all numeric columns
"value_counts" β€” Frequency count of categorical col
"time_trend" β€” Group col by group_by (date column), compute sum/mean of col
RULES:
1. Output ONLY a raw JSON array β€” no markdown, no backticks, no explanation.
2. Use EXACT column names from the schema.
3. Choose only the operations actually needed to answer the question. Do not add unnecessary operations.
4. Maximum 6 operations per plan.
5. For complex questions (e.g. "which segment should be prioritized for growth"), include
multiple complementary operations (e.g. mean_by_group + std_by_group + sum_by_group).
Example output for "What is the average profit per region?":
[
{
"op_id": "avg_profit_region",
"op_type": "mean_by_group",
"label": "Average Profit by Region",
"col": "Profit",
"group_by": "Region",
"n": null,
"col2": null
}
]
"""
_MAX_PLAN_OPS = 6  # the planning prompt asks for at most 6 operations


def _plan_analysis(schema: str, query: str) -> list[dict]:
    """PASS 1 — ask the LLM for a structured analysis plan.

    Args:
        schema: Text summary of the dataset, embedded in the user message.
        query: The user's analytical question.

    Returns:
        Up to ``_MAX_PLAN_OPS`` operation dicts, in the order the model
        produced them. Entries that are not JSON objects are dropped.

    Raises:
        ValueError: If the reply is not valid JSON, is not a JSON array, or
            is a non-empty array containing no JSON objects.
    """
    llm = init_chat_model(GENERATION_MODEL_NAME)
    messages = [
        {"role": "system", "content": _PLAN_SYSTEM_PROMPT},
        {"role": "user", "content": f"Dataset schema:\n{schema}\n\nUser question:\n{query}\n\nOutput ONLY the raw JSON array."},
    ]
    response = llm.invoke(messages)
    raw = response.content.strip()
    record_llm_call(
        use_case="data_analysis_plan",
        output_text=raw,
        response=response,
        model_name=GENERATION_MODEL_NAME,
    )

    # Models sometimes wrap the JSON in a Markdown code fence despite the prompt.
    raw = re.sub(r"^```(?:json)?\s*", "", raw)
    raw = re.sub(r"\s*```$", "", raw)
    raw = raw.strip()

    try:
        plan = json.loads(raw)
    except json.JSONDecodeError as e:
        raise ValueError(f"LLM returned invalid JSON plan: {e}\nRaw: {raw[:400]}") from e
    if not isinstance(plan, list):
        raise ValueError("Plan must be a JSON array")

    # The executor reads each entry with dict.get(); anything else would
    # fail there with a far less helpful error.
    operations = [op for op in plan if isinstance(op, dict)]
    if plan and not operations:
        raise ValueError("Plan must be a JSON array of operation objects")

    # Enforce the limit the prompt requests so an over-eager reply cannot
    # balloon the amount of work or the size of the follow-up prompt.
    return operations[:_MAX_PLAN_OPS]
# ══════════════════════════════════════════════════════════════════
# EXECUTE — Whitelisted pandas operations
# ══════════════════════════════════════════════════════════════════
def _safe_val(v: Any) -> Any:
"""Convert numpy/pandas scalars to JSON-safe Python types."""
if isinstance(v, (np.integer,)): return int(v)
if isinstance(v, (np.floating,)):
if math.isnan(v) or math.isinf(v): return None
return round(float(v), 4)
if isinstance(v, float):
if math.isnan(v) or math.isinf(v): return None
return round(v, 4)
if isinstance(v, (np.bool_,)): return bool(v)
if pd.isna(v): return None
return v
def _df_to_records(df: pd.DataFrame, max_rows: int = _MAX_RESULT_ROWS) -> list[dict]:
    """Convert the first *max_rows* rows of *df* to JSON-safe dicts.

    Each cell is passed through ``_safe_val``, which rounds floats to four
    decimals and maps NaN / inf / missing values to ``None``. Integer
    columns stay integers (a count is emitted as ``12``, not ``12.0``).

    Args:
        df: Table to serialise.
        max_rows: Maximum number of leading rows to include.

    Returns:
        One dict per row, keyed by column name.
    """
    # No separate per-column rounding pass: it duplicated _safe_val and
    # forced every numeric column, integers included, to float.
    records = df.head(max_rows).to_dict(orient="records")
    return [{key: _safe_val(value) for key, value in row.items()} for row in records]
_DEFAULT_N = 10  # row/group limit used when the plan gives no usable "n"

# op_type -> pandas aggregation name for the simple group-by operations
_GROUP_STATS = {"mean_by_group": "mean", "sum_by_group": "sum", "std_by_group": "std"}


def _coerce_n(raw: Any) -> int:
    """Return *raw* as a positive int, or ``_DEFAULT_N`` if missing or unusable."""
    try:
        n = int(raw)
    except (TypeError, ValueError, OverflowError):
        return _DEFAULT_N
    return n if n > 0 else _DEFAULT_N


def _require_columns(df: pd.DataFrame, op_type: str, **fields: Any) -> None:
    """Raise ValueError naming the first plan field whose column is not in *df*."""
    for field, name in fields.items():
        if name is None or name not in df.columns:
            raise ValueError(f"{op_type}: '{field}' column {name!r} not found in dataset")


def _label_column(df: pd.DataFrame, exclude: list) -> Any:
    """Return the first object- or string-typed column not in *exclude*, else None."""
    for name, dtype in df.dtypes.items():
        if name in exclude:
            continue
        # Accept the dedicated pandas string dtype as well as plain object.
        if pd.api.types.is_object_dtype(dtype) or pd.api.types.is_string_dtype(dtype):
            return name
    return None


def _table_result(frame: pd.DataFrame, max_rows: int | None = None) -> dict:
    """Package *frame* as a ``table`` result payload."""
    data = _df_to_records(frame) if max_rows is None else _df_to_records(frame, max_rows)
    return {"result_type": "table", "data": data, "columns": list(frame.columns)}


def _extreme_rows(df: pd.DataFrame, p: dict, largest: bool) -> dict:
    """Top/bottom ``n`` rows by ``col``, prefixed with a label column when one exists."""
    col = p["col"]
    _require_columns(df, p["op_type"], col=col)
    label_col = _label_column(df, exclude=[col])
    keep = [col] if label_col is None else [label_col, col]
    ranked = df.nlargest(p["n"], col) if largest else df.nsmallest(p["n"], col)
    return _table_result(ranked[keep])


def _op_top_n(df: pd.DataFrame, p: dict) -> dict:
    """Top rows by ``col``; with ``group_by``, the top groups ranked by mean of ``col``."""
    col, group = p["col"], p["group"]
    if not group:
        return _extreme_rows(df, p, largest=True)
    _require_columns(df, "top_n", col=col, group_by=group)
    ranked = (
        df.groupby(group)[col]
        .mean()
        .reset_index()
        .sort_values(col, ascending=False)
        .head(p["n"])
    )
    return _table_result(ranked)


def _op_bottom_n(df: pd.DataFrame, p: dict) -> dict:
    """Bottom rows by ``col`` (``group_by`` is not used)."""
    return _extreme_rows(df, p, largest=False)


def _op_group_stat(df: pd.DataFrame, p: dict) -> dict:
    """Mean / sum / std of ``col`` per ``group_by`` value."""
    col, group = p["col"], p["group"]
    _require_columns(df, p["op_type"], col=col, group_by=group)
    how = _GROUP_STATS[p["op_type"]]
    stat = df.groupby(group)[col].agg(how).reset_index()
    if how == "std":
        # Least volatile groups first; the value column is renamed so the
        # output says it holds a spread rather than a level.
        stat = stat.sort_values(col, ascending=True)
        value_col = f"{col}_std"
        stat.columns = [group, value_col]
    else:
        stat = stat.sort_values(col, ascending=False)
        value_col = col
    stat[value_col] = stat[value_col].round(4)
    return _table_result(stat)


def _op_count_by_group(df: pd.DataFrame, p: dict) -> dict:
    """Row count per ``group_by`` value, largest first."""
    group = p["group"]
    _require_columns(df, "count_by_group", group_by=group)
    counts = (
        df.groupby(group)
        .size()
        .reset_index(name="count")
        .sort_values("count", ascending=False)
    )
    return _table_result(counts)


def _op_ratio_by_group(df: pd.DataFrame, p: dict) -> dict:
    """Mean of ``col / col2`` per ``group_by`` value, ignoring zero/missing denominators."""
    col, col2, group = p["col"], p["col2"], p["group"]
    if not col2:
        raise ValueError("ratio_by_group requires col2")
    _require_columns(df, "ratio_by_group", col=col, col2=col2, group_by=group)
    ratio_col = f"{col}_margin"
    # Zero denominators are masked to NaN so those rows drop out of the mean
    # instead of producing +/-inf.
    denominator = df[col2].where(df[col2] != 0)
    temp = df[[group]].copy()
    temp[ratio_col] = df[col] / denominator
    result = (
        temp.groupby(group)[ratio_col]
        .mean()
        .reset_index()
        .sort_values(ratio_col, ascending=False)
    )
    result[ratio_col] = result[ratio_col].round(4)
    return _table_result(result)


def _op_distribution(df: pd.DataFrame, p: dict) -> dict:
    """Count, min, quartiles, max, mean and std of the non-null values of ``col``."""
    col = p["col"]
    _require_columns(df, "distribution", col=col)
    s = df[col].dropna()
    dist = {
        "count": int(len(s)),
        "min": _safe_val(s.min()),
        "p25": _safe_val(s.quantile(0.25)),
        "median": _safe_val(s.median()),
        "p75": _safe_val(s.quantile(0.75)),
        "max": _safe_val(s.max()),
        "mean": _safe_val(s.mean()),
        "std": _safe_val(s.std()),
    }
    return {"result_type": "scalar_dict", "data": dist, "columns": list(dist.keys())}


def _op_outliers(df: pd.DataFrame, p: dict) -> dict:
    """Rows whose ``col`` lies outside [Q1 - 1.5*IQR, Q3 + 1.5*IQR], plus the bounds."""
    col, group = p["col"], p["group"]
    _require_columns(df, "outliers", col=col)
    q1 = df[col].quantile(0.25)
    q3 = df[col].quantile(0.75)
    iqr = q3 - q1
    low = q1 - 1.5 * iqr
    high = q3 + 1.5 * iqr
    is_outlier = (df[col] < low) | (df[col] > high)

    keep_cols = [col]
    if group and group != col and group in df.columns:
        keep_cols = [group, col]
    # Prefix a text column, when there is one, so rows are identifiable.
    label_col = _label_column(df, exclude=keep_cols)
    if label_col is not None:
        keep_cols = [label_col] + keep_cols

    outlier_df = (
        df.loc[is_outlier, keep_cols]
        .sort_values(col, ascending=False)
        .head(_MAX_RESULT_ROWS)
    )
    summary = {
        "total_outliers": int(is_outlier.sum()),
        "iqr_low_bound": _safe_val(low),
        "iqr_high_bound": _safe_val(high),
        "q1": _safe_val(q1),
        "q3": _safe_val(q3),
        "iqr": _safe_val(iqr),
    }
    return {
        "result_type": "outliers",
        "data": _df_to_records(outlier_df),
        "columns": keep_cols,
        "summary": summary,
    }


def _op_correlation(df: pd.DataFrame, p: dict) -> dict:
    """Correlation matrix of all numeric columns."""
    num_cols = df.select_dtypes(include="number").columns.tolist()
    corr = df[num_cols].corr().round(4)
    records = corr.reset_index().rename(columns={"index": "column"})
    return _table_result(records, 30)


def _op_overall_summary(df: pd.DataFrame, p: dict) -> dict:
    """``describe()`` statistics for all numeric columns, one row per column."""
    desc = df.describe(include="number").T.reset_index().rename(columns={"index": "column"})
    desc = desc.round(4)
    return _table_result(desc, 30)


def _op_value_counts(df: pd.DataFrame, p: dict) -> dict:
    """Frequency of each distinct value of ``col``."""
    col = p["col"]
    _require_columns(df, "value_counts", col=col)
    vc = df[col].value_counts().reset_index()
    vc.columns = [col, "count"]
    return _table_result(vc)


def _op_time_trend(df: pd.DataFrame, p: dict) -> dict:
    """Sum of ``col`` per calendar month of the date column ``group_by``."""
    col, group = p["col"], p["group"]
    _require_columns(df, "time_trend", group_by=group, col=col)
    temp = df.copy()
    temp[group] = pd.to_datetime(temp[group], errors="coerce")
    temp = temp.dropna(subset=[group])  # unparseable dates were coerced to NaT
    # Try monthly grouping first, fall back to daily
    try:
        temp["_period"] = temp[group].dt.to_period("M").astype(str)
    except Exception:
        temp["_period"] = temp[group].dt.strftime("%Y-%m-%d")
    result = (
        temp.groupby("_period")[col]
        .sum()
        .reset_index()
        .rename(columns={"_period": group})
        .sort_values(group)
    )
    result[col] = result[col].round(4)
    return _table_result(result, 36)


# Whitelist: the only op_type values that are ever executed.
_OP_HANDLERS = {
    "top_n": _op_top_n,
    "bottom_n": _op_bottom_n,
    "mean_by_group": _op_group_stat,
    "sum_by_group": _op_group_stat,
    "std_by_group": _op_group_stat,
    "count_by_group": _op_count_by_group,
    "ratio_by_group": _op_ratio_by_group,
    "distribution": _op_distribution,
    "outliers": _op_outliers,
    "correlation": _op_correlation,
    "overall_summary": _op_overall_summary,
    "value_counts": _op_value_counts,
    "time_trend": _op_time_trend,
}


def _execute_operation(op: dict, df: pd.DataFrame) -> dict:
    """Execute one whitelisted operation and return a result dict.

    Args:
        op: Plan entry with ``op_id``, ``op_type``, ``label``, ``col``,
            ``group_by``, ``n`` and ``col2`` keys (all optional here).
        df: Dataset to run the operation against; it is not modified.

    Returns:
        ``{op_id, label, op_type, result_type, ...}``. On success
        ``result_type`` is ``"table"``, ``"scalar_dict"`` or ``"outliers"``
        with ``data`` and ``columns`` (outliers also carry ``summary``).
        On failure ``result_type`` is ``"error"`` with an ``error`` message
        and, for handler exceptions, a ``detail`` traceback. This function
        never raises.
    """
    if not isinstance(op, dict):
        return {
            "op_id": "unnamed",
            "label": "unnamed",
            "op_type": "",
            "result_type": "error",
            "error": f"Operation must be a JSON object, got {type(op).__name__}",
        }

    op_id = op.get("op_id", "unnamed")
    op_type = op.get("op_type", "")
    label = op.get("label", op_id)
    base = {"op_id": op_id, "label": label, "op_type": op_type}

    # isinstance guard: an unhashable op_type (e.g. a list) must not raise here.
    handler = _OP_HANDLERS.get(op_type) if isinstance(op_type, str) else None
    if handler is None:
        return {**base, "result_type": "error", "error": f"Unknown op_type: '{op_type}'"}

    params = {
        "op_type": op_type,
        "col": op.get("col"),
        "group": op.get("group_by"),
        "n": _coerce_n(op.get("n")),
        "col2": op.get("col2"),
    }
    try:
        return {**base, **handler(df, params)}
    except Exception as e:
        # One failing operation must not abort the rest of the plan.
        return {**base, "result_type": "error", "error": str(e), "detail": traceback.format_exc()}
def execute_plan(plan: list[dict], df: pd.DataFrame) -> list[dict]:
    """Run every planned operation against *df* and collect the results in plan order."""
    outcomes = []
    for step in plan:
        outcomes.append(_execute_operation(step, df))
    return outcomes
# ══════════════════════════════════════════════════════════════════
# PASS 2 — LLM Interpreter
# ══════════════════════════════════════════════════════════════════
_INTERPRET_SYSTEM_PROMPT = """You are a senior business data analyst delivering insights to an executive audience.
You will receive:
1. The user's analytical question
2. Pre-computed results from a pandas analysis (actual numbers, tables, ranked lists)
Your job is to write a comprehensive, structured analytical response based ONLY on the provided results.
OUTPUT FORMAT β€” return a single JSON object with these keys:
{
"headline": "One crisp sentence summarising the single most important finding.",
"narrative": "3-6 paragraph detailed analytical write-up. Be specific β€” cite actual numbers from the results. Explain patterns, causes, and business implications. Write like a McKinsey analyst, not a chatbot.",
"key_findings": [
"Bullet point 1 β€” specific finding with a number",
"Bullet point 2 β€” specific finding with a number",
"Bullet point 3 β€” specific finding with a number"
],
"recommendation": "1-2 sentences of actionable recommendation based on the findings. If the question is purely factual (e.g. 'who has highest salary'), set this to null.",
"stats_table": [
{"label": "Metric name", "value": "formatted value", "note": "optional context"}
]
}
RULES:
1. Output ONLY raw JSON β€” no markdown, no backticks, no explanation outside the JSON.
2. Cite EXACT numbers from the computed results. Never invent numbers.
3. key_findings must have 3-6 items, each starting with a capital letter.
4. stats_table should capture the top 5-8 most important numeric findings as key-value pairs.
5. narrative must be substantive β€” minimum 100 words.
6. If a result contains an error, acknowledge it gracefully and work with the other results.
"""
def _interpret_results(query: str, results: list[dict], df: pd.DataFrame) -> dict:
    """PASS 2 — ask the LLM to turn computed results into a structured write-up.

    Args:
        query: The user's analytical question.
        results: Result dicts produced by the execution step.
        df: The analysed dataset; only its row count and column names are
            included in the prompt.

    Returns:
        The model's JSON object when it replies with one. Otherwise a
        fallback dict with the same keys (``headline``, ``narrative``,
        ``key_findings``, ``recommendation``, ``stats_table``) whose
        narrative is the first 2000 characters of the raw reply.
    """
    llm = init_chat_model(GENERATION_MODEL_NAME)

    # Failed operations carry a Python traceback under "detail". The model
    # only needs the short "error" message, and the traceback would eat into
    # the length budget below.
    prompt_results = [
        {key: value for key, value in item.items() if key != "detail"}
        if isinstance(item, dict) else item
        for item in results
    ]
    results_text = json.dumps(prompt_results, indent=2, default=str)
    # Trim if very long
    if len(results_text) > 8000:
        results_text = results_text[:8000] + "\n... [truncated for length]"

    user_message = (
        f"User question: {query}\n\n"
        f"Dataset: {df.shape[0]} rows, columns: {list(df.columns)}\n\n"
        f"Computed analysis results:\n{results_text}\n\n"
        "Write the structured analytical response as a JSON object. Output ONLY raw JSON."
    )
    messages = [
        {"role": "system", "content": _INTERPRET_SYSTEM_PROMPT},
        {"role": "user", "content": user_message},
    ]
    response = llm.invoke(messages)
    raw = response.content.strip()
    record_llm_call(
        use_case="data_analysis_interpretation",
        output_text=raw,
        response=response,
        model_name=GENERATION_MODEL_NAME,
    )

    # Models sometimes wrap the JSON in a Markdown code fence despite the prompt.
    raw = re.sub(r"^```(?:json)?\s*", "", raw)
    raw = re.sub(r"\s*```$", "", raw)
    raw = raw.strip()

    # Graceful fallback: surface the raw text as the narrative.
    fallback = {
        "headline": "Analysis complete.",
        "narrative": raw[:2000],
        "key_findings": [],
        "recommendation": None,
        "stats_table": [],
    }
    try:
        interpretation = json.loads(raw)
    except json.JSONDecodeError:
        return fallback
    # Valid JSON that is not an object (a bare list or string) would break
    # callers that read the result with .get().
    if not isinstance(interpretation, dict):
        return fallback
    return interpretation
# ══════════════════════════════════════════════════════════════════
# RESULT BUILDER — builds the final stats table the UI will render
# ══════════════════════════════════════════════════════════════════
def _build_primary_table(results: list[dict]) -> dict | None:
"""
Pick the most relevant result table to surface in the UI card.
Returns the first table-type result that has data.
"""
for r in results:
if r.get("result_type") in ("table", "outliers") and r.get("data"):
return {
"label": r["label"],
"columns": r["columns"],
"rows": r["data"][:15], # cap at 15 rows in UI
}
return None
# ══════════════════════════════════════════════════════════════════
# Main entry point
# ══════════════════════════════════════════════════════════════════
def run_data_analysis_agent(query: str, filename: str) -> dict:
    """Answer *query* about the dataset *filename* via plan -> execute -> interpret.

    Args:
        query: User's analytical question in natural language.
        filename: Dataset filename, resolved by ``load_dataset``.

    Returns:
        On success, a dict with ``success=True`` plus ``headline``,
        ``narrative``, ``key_findings``, ``recommendation``, ``stats_table``,
        ``primary_table`` ({label, columns, rows} or None), ``operations``
        (one {op_id, label, status} entry per executed operation),
        ``filename``, ``rows`` and ``columns``.

        On failure, ``{"success": False, "error": ...}``; unexpected
        exceptions additionally include a ``detail`` traceback. This
        function does not raise.
    """
    try:
        # Load -> describe -> plan -> execute -> interpret
        df = load_dataset(filename)
        schema = build_schema_summary(df)
        plan = _plan_analysis(schema, query)
        results = execute_plan(plan, df)
        interpretation = _interpret_results(query, results, df)
        primary_table = _build_primary_table(results)

        # Per-operation status summary for the caller.
        operations = []
        for outcome in results:
            failed = outcome.get("result_type") == "error"
            operations.append(
                {
                    "op_id": outcome["op_id"],
                    "label": outcome["label"],
                    "status": "error" if failed else "ok",
                }
            )

        return {
            "success": True,
            "headline": interpretation.get("headline", ""),
            "narrative": interpretation.get("narrative", ""),
            "key_findings": interpretation.get("key_findings", []),
            "recommendation": interpretation.get("recommendation"),
            "stats_table": interpretation.get("stats_table", []),
            "primary_table": primary_table,
            "operations": operations,
            "filename": filename,
            "rows": df.shape[0],
            "columns": list(df.columns),
        }
    except (FileNotFoundError, ValueError) as exc:
        # Expected failures (missing/unsupported dataset, unusable plan):
        # report the message only.
        return {"success": False, "error": str(exc)}
    except Exception as exc:
        return {
            "success": False,
            "error": f"Unexpected error: {str(exc)}",
            "detail": traceback.format_exc(),
        }