# Author: Himanshu Gangwar — initial commit (eff8aa5)
import json
from typing import TypedDict, Annotated, List, Union, Any
import matplotlib
import matplotlib.pyplot as plt
import os
import pandas as pd
import uuid
from langchain_groq import ChatGroq
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import StateGraph, END
from app.core.config import settings
from app.db.database import get_db_schema, engine
from app.services.pdf_generator import generate_pdf_report
from app.services.analytics import run_advanced_analytics
matplotlib.use('Agg') # Use non-interactive backend
# Define State
class AgentState(TypedDict, total=False):
    """Shared state dict passed between the LangGraph pipeline nodes.

    ``total=False``: every key is optional; nodes read with ``.get`` and
    return partial updates that LangGraph merges into this state.
    """
    query: str                  # latest natural-language question from the user
    history: List[dict]         # prior turns; dicts carry "question"/"insights" keys
    schema: str                 # rendered DB schema fed to the SQL-generation prompt
    sql_query: str              # SQL produced by the LLM
    data: Any                   # Pandas DataFrame as dict or list (DataFrame.to_dict records)
    visualization_path: str     # saved chart image path, when one was rendered
    visualization_summary: str  # explanation of the chosen chart
    trend_analysis: dict
    anomaly_analysis: dict
    forecast_analysis: dict
    statistical_tests: dict
    insights: str               # LLM-written narrative insights
    report_path: str            # generated PDF report path
    error: str                  # first failure message; downstream nodes short-circuit on it
def _format_history(history: List[dict]) -> str:
if not history:
return "None"
rendered = []
for turn in history[-5:]:
question = turn.get("question", "")
answer = turn.get("insights", "")
rendered.append(f"User: {question}\nAgent: {answer}")
return "\n---\n".join(rendered)
# LLM Setup
def get_llm():
    """Return a configured ChatGroq client, or None when no API key is set.

    Callers treat a None result as "LLM unavailable" and degrade gracefully.
    """
    if not settings.GROQ_API_KEY:
        # No key configured: signal unavailability instead of raising here.
        return None
    return ChatGroq(
        temperature=0,
        model_name="openai/gpt-oss-120b",
        api_key=settings.GROQ_API_KEY,
    )
def _summarize_dataframe(df: pd.DataFrame) -> List[dict]:
summary = []
for col in df.columns:
series = df[col]
summary.append(
{
"name": col,
"dtype": str(series.dtype),
"numeric": pd.api.types.is_numeric_dtype(series),
"unique_count": int(series.nunique()),
"sample_values": series.dropna().astype(str).unique().tolist()[:3],
}
)
return summary
def _fallback_chart_plan(df: pd.DataFrame) -> dict:
numeric_cols = list(df.select_dtypes(include=["number", "bool"]).columns)
categorical_cols = list(df.select_dtypes(include=["object", "category"]).columns)
if categorical_cols and numeric_cols:
return {
"chart_type": "bar",
"x_field": categorical_cols[0],
"y_field": numeric_cols[0],
"aggregation": "sum",
"top_n": 10,
"explanation": f"Bar chart of {numeric_cols[0]} by {categorical_cols[0]}"
}
if len(numeric_cols) >= 2:
return {
"chart_type": "line",
"x_field": numeric_cols[0],
"y_field": numeric_cols[1],
"aggregation": "none",
"top_n": 50,
"explanation": f"Line plot of {numeric_cols[1]} vs {numeric_cols[0]}"
}
return {
"chart_type": "table",
"x_field": None,
"y_field": None,
"aggregation": "none",
"top_n": 0,
"explanation": "No suitable numeric/categorical combination for chart"
}
def _parse_json_response(text: str) -> dict:
text = text.strip()
if text.startswith("```"):
text = text.strip("`")
if text.startswith("json"):
text = text[4:]
start = text.find("{")
end = text.rfind("}")
if start != -1 and end != -1:
candidate = text[start:end + 1]
return json.loads(candidate)
return json.loads(text)
_ALLOWED_CHART_TYPES = {"bar", "line", "area", "scatter", "pie", "table"}


def _is_valid_plan(plan: dict, df: pd.DataFrame) -> bool:
    """Return True when an LLM plan names a renderable chart over real columns."""
    if not isinstance(plan, dict):
        return False
    if str(plan.get("chart_type", "")).lower() not in _ALLOWED_CHART_TYPES:
        return False
    for key in ("x_field", "y_field"):
        field = plan.get(key)
        # None is allowed (table charts); any named field must exist in df.
        if field is not None and field not in df.columns:
            return False
    return True


def _suggest_chart_plan(df: pd.DataFrame, query: str) -> dict:
    """Ask the LLM for a chart plan; fall back to a heuristic plan on failure.

    Returns a dict with keys: chart_type, x_field, y_field, aggregation,
    top_n, explanation. Fix over the original: a response that parsed as JSON
    but referenced nonexistent columns (or an unknown chart type) used to
    replace the safe fallback and later made rendering fail silently — such
    plans are now validated and rejected.
    """
    plan = _fallback_chart_plan(df)
    llm = get_llm()
    if not llm:
        return plan

    columns_meta = _summarize_dataframe(df)
    sample_rows = df.head(5).to_dict(orient="records")
    template = """
You are an analytics visualization planner. Based on the user's question, the column metadata, and sample rows, choose the most appropriate chart to highlight the insight.
Allowed chart_type values: bar, line, area, scatter, pie, table.
aggregation can be sum, mean, avg, average, count, or none. Use count when only frequency matters.
Return ONLY valid JSON with keys: chart_type, x_field, y_field (nullable), aggregation, top_n (int), explanation.
Make sure fields exist in the dataset and chart type matches their dtypes (categorical for x axis on bar/pie, numeric for y).
Pick at most top 12 categories when using bar/pie.
Columns: {columns}
Sample rows: {sample}
User question: {query}
"""
    prompt = ChatPromptTemplate.from_template(template)
    chain = prompt | llm | StrOutputParser()
    try:
        response = chain.invoke({
            "columns": json.dumps(columns_meta, ensure_ascii=False),
            "sample": json.dumps(sample_rows, ensure_ascii=False),
            "query": query,
        })
        candidate = _parse_json_response(response)
        if _is_valid_plan(candidate, df):
            plan = candidate
        else:
            # Invalid LLM plan: keep the heuristic fallback.
            plan.setdefault("explanation", "Heuristic visualization applied")
    except Exception:
        # LLM/parse failure: keep the heuristic fallback.
        plan.setdefault("explanation", "Heuristic visualization applied")
    return plan
def _aggregate_for_chart(df: pd.DataFrame, x_field: str, y_field: str, aggregation: str) -> pd.DataFrame:
if not x_field or x_field not in df.columns:
return pd.DataFrame()
agg = (aggregation or "sum").lower()
if agg in ("sum", "total", "mean", "avg", "average"):
target_col = y_field if y_field in df.columns else None
if not target_col:
numeric_cols = df.select_dtypes(include=["number", "bool"]).columns
target_col = numeric_cols[0] if len(numeric_cols) else None
if not target_col or not pd.api.types.is_numeric_dtype(df[target_col]):
return pd.DataFrame()
agg_fn = "mean" if agg in ("mean", "avg", "average") else "sum"
grouped = df.groupby(x_field)[target_col].agg(agg_fn).reset_index()
return grouped.rename(columns={target_col: "value"})
if agg == "count":
grouped = df.groupby(x_field).size().reset_index(name="value")
return grouped
if y_field and y_field in df.columns and pd.api.types.is_numeric_dtype(df[y_field]):
subset = df[[x_field, y_field]].copy()
subset = subset.rename(columns={y_field: "value"})
return subset
return pd.DataFrame()
def _render_chart(path: str, df: pd.DataFrame, plan: dict) -> str:
    """Render the planned chart to ``path``; return the path, or "" on failure.

    Fix over the original: ``plt.figure()`` was opened before the early
    ``return ""`` paths and never closed there, leaking Agg figures (and
    memory) on every unrenderable plan. The figure is now created only once
    the plot data is known to be drawable, and always closed via finally.
    """
    chart_type = (plan.get("chart_type") or "bar").lower()
    x_field = plan.get("x_field")
    y_field = plan.get("y_field")
    agg = plan.get("aggregation")
    top_n = int(plan.get("top_n") or 12)

    if chart_type == "scatter" and x_field and y_field:
        # Scatter needs two existing numeric columns; otherwise unrenderable.
        if not (x_field in df.columns and y_field in df.columns
                and pd.api.types.is_numeric_dtype(df[x_field])
                and pd.api.types.is_numeric_dtype(df[y_field])):
            return ""
        plot_df = df[[x_field, y_field]].dropna().head(top_n)
        if plot_df.empty:
            return ""
        fig = plt.figure(figsize=(10, 6))
        try:
            plt.scatter(plot_df[x_field], plot_df[y_field], color="#5cd4f4")
            plt.xlabel(x_field)
            plt.ylabel(y_field)
            plt.title(plan.get("explanation", f"{y_field} vs {x_field}"))
            plt.tight_layout()
            plt.savefig(path, bbox_inches="tight")
        finally:
            plt.close(fig)
        return path

    if not x_field:
        return ""
    plot_df = _aggregate_for_chart(df, x_field, y_field, agg)
    if plot_df.empty:
        return ""
    plot_df = plot_df.sort_values("value", ascending=False)
    if top_n > 0:
        plot_df = plot_df.head(top_n)

    fig = plt.figure(figsize=(10, 6))
    try:
        if chart_type == "pie":
            plot_df.set_index(x_field)["value"].plot(kind="pie", autopct="%1.1f%%")
            plt.ylabel("")
        elif chart_type == "line":
            plt.plot(plot_df[x_field], plot_df["value"], marker="o")
        elif chart_type == "area":
            plt.fill_between(plot_df[x_field], plot_df["value"], alpha=0.4)
            plt.plot(plot_df[x_field], plot_df["value"], color="#7a83ff")
        else:
            plt.bar(plot_df[x_field], plot_df["value"], color="#7a83ff")
        # Rotate category labels so long names don't overlap.
        plt.xticks(rotation=45, ha="right")
        plt.xlabel(x_field)
        plt.ylabel(plan.get("y_field") or "Value")
        plt.title(plan.get("explanation", "Visualization"))
        plt.tight_layout()
        plt.savefig(path, bbox_inches="tight")
    finally:
        plt.close(fig)
    return path
# Nodes
def get_schema_node(state: AgentState):
    """Load the database schema into state for the SQL-generation prompt."""
    return {"schema": get_db_schema()}
def generate_sql_node(state: AgentState):
    """Translate the user's natural-language question into a SQLite query."""
    llm = get_llm()
    if not llm:
        return {"error": "LLM not configured"}
    template = """
You are a SQL expert. Convert the following natural language query into a SQL query for SQLite.
Schema:
{schema}
Recent conversation:
{history}
Current Query: {query}
Return ONLY the SQL query, nothing else. Do not wrap in markdown code blocks.
"""
    chain = ChatPromptTemplate.from_template(template) | llm | StrOutputParser()
    try:
        raw = chain.invoke({
            "schema": state["schema"],
            "history": _format_history(state.get("history", [])),
            "query": state["query"],
        })
        # Models sometimes wrap output in markdown fences despite instructions.
        sql = raw.replace("```sql", "").replace("```", "").strip()
        return {"sql_query": sql}
    except Exception as exc:
        return {"error": str(exc)}
def execute_sql_node(state: AgentState):
    """Run the generated SQL against the app engine; store rows as record dicts."""
    if state.get("error"):
        # A prior node already failed; propagate the state untouched.
        return state
    try:
        frame = pd.read_sql(state["sql_query"], engine)
        return {"data": frame.to_dict(orient="records")}
    except Exception as exc:
        return {"error": f"SQL Execution failed: {str(exc)}"}
def generate_visualization_node(state: AgentState):
    """Plan and render a chart for the current result set under backend/static."""
    if state.get("error") or not state.get("data"):
        return state
    frame = pd.DataFrame(state["data"])
    if frame.empty:
        return {"visualization_path": None, "visualization_summary": "No data to visualize."}
    plan = _suggest_chart_plan(frame, state.get("query", ""))
    target = os.path.join("backend", "static", f"chart_{uuid.uuid4()}.png")
    os.makedirs(os.path.dirname(target), exist_ok=True)
    rendered = _render_chart(target, frame, plan)
    summary = plan.get("explanation")
    if not rendered:
        # Rendering declined (unchartable plan/data); keep the explanation.
        return {"visualization_path": None, "visualization_summary": summary}
    return {"visualization_path": rendered, "visualization_summary": summary}
def advanced_analytics_node(state: AgentState):
    """Attach trend/anomaly/forecast/statistical diagnostics for the result set."""
    if state.get("error") or not state.get("data"):
        return state
    frame = pd.DataFrame(state["data"])
    if frame.empty:
        return {
            "trend_analysis": None,
            "anomaly_analysis": None,
            "forecast_analysis": None,
            "statistical_tests": None,
        }
    results = run_advanced_analytics(frame)
    return {
        "trend_analysis": results.get("trend"),
        "anomaly_analysis": results.get("anomaly"),
        "forecast_analysis": results.get("forecast"),
        "statistical_tests": results.get("statistical_tests"),
    }
def generate_insights_node(state: AgentState):
    """Ask the LLM for narrative insights grounded in the data and diagnostics.

    Fix over the original: ``state["data"]`` was read outside the try block
    and raised an uncaught KeyError when the data key was absent without an
    error being set; it now defaults to an empty list.
    """
    if state.get("error"):
        return state
    llm = get_llm()
    if not llm:
        return {"insights": "LLM not configured"}
    # Truncate the serialized rows so huge result sets don't blow the prompt budget.
    data_summary = str(state.get("data", []))[:2000]
    template = """
You are an analytics copilot. Use the latest query, the conversation history, the data sample, and the derived diagnostics (trends, anomalies, forecasts, and statistical tests) to provide incremental insights. If the user repeats a question, reference earlier answers instead of restating everything.
History:
{history}
Current Query: {query}
Data Sample: {data}
Trend Analysis: {trend}
Anomaly Analysis: {anomaly}
Forecast Analysis: {forecast}
Statistical Tests: {stats}
Provide 3-5 concise bullet insights plus a short summary paragraph. Mention forecasts and statistical significance when available.
"""
    prompt = ChatPromptTemplate.from_template(template)
    chain = prompt | llm | StrOutputParser()
    try:
        insights = chain.invoke({
            "query": state["query"],
            "history": _format_history(state.get("history", [])),
            "data": data_summary,
            "trend": json.dumps(state.get("trend_analysis") or {}, ensure_ascii=False),
            "anomaly": json.dumps(state.get("anomaly_analysis") or {}, ensure_ascii=False),
            "forecast": json.dumps(state.get("forecast_analysis") or {}, ensure_ascii=False),
            "stats": json.dumps(state.get("statistical_tests") or {}, ensure_ascii=False),
        })
        return {"insights": insights}
    except Exception as e:
        return {"insights": f"Failed to generate insights: {str(e)}"}
def build_report_node(state: AgentState):
    """Assemble the final PDF report from everything gathered in state."""
    if state.get("error"):
        return state
    path = os.path.join("backend", "static", f"report_{uuid.uuid4()}.pdf")
    try:
        generate_pdf_report(
            report_path=path,
            title="Autonomous Data Analyst Report",
            query=state.get("query", ""),
            sql_query=state.get("sql_query", ""),
            insights=state.get("insights", "No insights generated."),
            chart_image_path=state.get("visualization_path"),
            chart_summary=state.get("visualization_summary"),
            trend_analysis=state.get("trend_analysis"),
            anomaly_analysis=state.get("anomaly_analysis"),
            forecast_analysis=state.get("forecast_analysis"),
            statistical_tests=state.get("statistical_tests"),
            data_sample=state.get("data"),
        )
    except Exception as exc:
        return {"error": f"Report generation failed: {str(exc)}"}
    return {"report_path": path}
def create_agent_graph():
    """Compile the linear pipeline: schema → SQL → execute → visualize → analytics → insights → report."""
    workflow = StateGraph(AgentState)
    # Ordered (name, callable) pairs; edges simply chain consecutive entries.
    pipeline = [
        ("get_schema", get_schema_node),
        ("generate_sql", generate_sql_node),
        ("execute_sql", execute_sql_node),
        ("visualize", generate_visualization_node),
        ("advanced_analytics", advanced_analytics_node),
        ("generate_insights", generate_insights_node),
        ("build_report", build_report_node),
    ]
    for name, fn in pipeline:
        workflow.add_node(name, fn)
    workflow.set_entry_point(pipeline[0][0])
    for (current, _), (nxt, _) in zip(pipeline, pipeline[1:]):
        workflow.add_edge(current, nxt)
    workflow.add_edge(pipeline[-1][0], END)
    return workflow.compile()