Spaces:
Paused
Paused
| import json | |
| from typing import TypedDict, Annotated, List, Union, Any | |
| import matplotlib | |
| import matplotlib.pyplot as plt | |
| import os | |
| import pandas as pd | |
| import uuid | |
| from langchain_groq import ChatGroq | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langgraph.graph import StateGraph, END | |
| from app.core.config import settings | |
| from app.db.database import get_db_schema, engine | |
| from app.services.pdf_generator import generate_pdf_report | |
| from app.services.analytics import run_advanced_analytics | |
| matplotlib.use('Agg') # Use non-interactive backend | |
| # Define State | |
class AgentState(TypedDict, total=False):
    """State dictionary shared by every node of the agent graph.

    ``total=False`` makes every key optional, so a node may return only the
    keys it produces. Keys that nodes explicitly set to ``None`` (no chart
    rendered, empty result set) are annotated as nullable.
    """

    query: str  # natural-language question from the user
    history: List[dict]  # earlier turns; "question" / "insights" keys are read
    schema: str  # value returned by get_db_schema()
    sql_query: str  # SQL text produced by the LLM
    data: Any  # query result rows, stored via DataFrame.to_dict(orient="records")
    visualization_path: Union[str, None]  # None when no chart was rendered
    visualization_summary: Union[str, None]  # the chart plan's explanation, if any
    trend_analysis: Union[dict, None]  # None when the result set is empty
    anomaly_analysis: Union[dict, None]  # None when the result set is empty
    forecast_analysis: Union[dict, None]  # None when the result set is empty
    statistical_tests: Union[dict, None]  # None when the result set is empty
    insights: str  # LLM narrative, or a failure message
    report_path: str  # location of the generated PDF report
    error: str  # set by a failing node; later nodes return early when present
| def _format_history(history: List[dict]) -> str: | |
| if not history: | |
| return "None" | |
| rendered = [] | |
| for turn in history[-5:]: | |
| question = turn.get("question", "") | |
| answer = turn.get("insights", "") | |
| rendered.append(f"User: {question}\nAgent: {answer}") | |
| return "\n---\n".join(rendered) | |
| # LLM Setup | |
def get_llm():
    """Return a Groq chat model, or ``None`` when no API key is configured.

    Callers are expected to handle the ``None`` case themselves; no mock or
    fallback model is created here.
    """
    if settings.GROQ_API_KEY:
        return ChatGroq(
            temperature=0,
            model_name="openai/gpt-oss-120b",
            api_key=settings.GROQ_API_KEY,
        )
    return None
| def _summarize_dataframe(df: pd.DataFrame) -> List[dict]: | |
| summary = [] | |
| for col in df.columns: | |
| series = df[col] | |
| summary.append( | |
| { | |
| "name": col, | |
| "dtype": str(series.dtype), | |
| "numeric": pd.api.types.is_numeric_dtype(series), | |
| "unique_count": int(series.nunique()), | |
| "sample_values": series.dropna().astype(str).unique().tolist()[:3], | |
| } | |
| ) | |
| return summary | |
| def _fallback_chart_plan(df: pd.DataFrame) -> dict: | |
| numeric_cols = list(df.select_dtypes(include=["number", "bool"]).columns) | |
| categorical_cols = list(df.select_dtypes(include=["object", "category"]).columns) | |
| if categorical_cols and numeric_cols: | |
| return { | |
| "chart_type": "bar", | |
| "x_field": categorical_cols[0], | |
| "y_field": numeric_cols[0], | |
| "aggregation": "sum", | |
| "top_n": 10, | |
| "explanation": f"Bar chart of {numeric_cols[0]} by {categorical_cols[0]}" | |
| } | |
| if len(numeric_cols) >= 2: | |
| return { | |
| "chart_type": "line", | |
| "x_field": numeric_cols[0], | |
| "y_field": numeric_cols[1], | |
| "aggregation": "none", | |
| "top_n": 50, | |
| "explanation": f"Line plot of {numeric_cols[1]} vs {numeric_cols[0]}" | |
| } | |
| return { | |
| "chart_type": "table", | |
| "x_field": None, | |
| "y_field": None, | |
| "aggregation": "none", | |
| "top_n": 0, | |
| "explanation": "No suitable numeric/categorical combination for chart" | |
| } | |
| def _parse_json_response(text: str) -> dict: | |
| text = text.strip() | |
| if text.startswith("```"): | |
| text = text.strip("`") | |
| if text.startswith("json"): | |
| text = text[4:] | |
| start = text.find("{") | |
| end = text.rfind("}") | |
| if start != -1 and end != -1: | |
| candidate = text[start:end + 1] | |
| return json.loads(candidate) | |
| return json.loads(text) | |
def _suggest_chart_plan(df: pd.DataFrame, query: str) -> dict:
    """Ask the LLM for a chart plan, falling back to a dtype-based heuristic.

    The heuristic plan from ``_fallback_chart_plan`` is computed first and is
    returned unchanged when no LLM is configured. Otherwise the model receives
    column metadata, the first five rows, and the user's question; its JSON
    reply replaces the heuristic plan only if it decodes to a dict.

    Args:
        df: Query result to visualize.
        query: The user's natural-language question.

    Returns:
        A plan dict (``chart_type``, ``x_field``, ``y_field``, ``aggregation``,
        ``top_n``, ``explanation``). Field names coming from the LLM are not
        validated against *df* here.
    """
    plan = _fallback_chart_plan(df)
    llm = get_llm()
    if not llm:
        return plan
    columns_meta = _summarize_dataframe(df)
    sample_rows = df.head(5).to_dict(orient="records")
    template = """
    You are an analytics visualization planner. Based on the user's question, the column metadata, and sample rows, choose the most appropriate chart to highlight the insight.
    Allowed chart_type values: bar, line, area, scatter, pie, table.
    aggregation can be sum, mean, avg, average, count, or none. Use count when only frequency matters.
    Return ONLY valid JSON with keys: chart_type, x_field, y_field (nullable), aggregation, top_n (int), explanation.
    Make sure fields exist in the dataset and chart type matches their dtypes (categorical for x axis on bar/pie, numeric for y).
    Pick at most top 12 categories when using bar/pie.
    Columns: {columns}
    Sample rows: {sample}
    User question: {query}
    """
    prompt = ChatPromptTemplate.from_template(template)
    chain = prompt | llm | StrOutputParser()
    try:
        response = chain.invoke({
            # default=str is deliberate: the values are only shown to the model,
            # and timestamps/decimals would otherwise raise and silently force
            # the heuristic plan.
            "columns": json.dumps(columns_meta, ensure_ascii=False, default=str),
            "sample": json.dumps(sample_rows, ensure_ascii=False, default=str),
            "query": query,
        })
        candidate = _parse_json_response(response)
        if not isinstance(candidate, dict):
            # Downstream code calls plan.get(...); reject lists and scalars.
            raise ValueError("Chart plan must be a JSON object")
        plan = candidate
    except Exception:
        # Best effort: any LLM or parsing failure keeps the heuristic plan.
        plan.setdefault("explanation", "Heuristic visualization applied")
    return plan
| def _aggregate_for_chart(df: pd.DataFrame, x_field: str, y_field: str, aggregation: str) -> pd.DataFrame: | |
| if not x_field or x_field not in df.columns: | |
| return pd.DataFrame() | |
| agg = (aggregation or "sum").lower() | |
| if agg in ("sum", "total", "mean", "avg", "average"): | |
| target_col = y_field if y_field in df.columns else None | |
| if not target_col: | |
| numeric_cols = df.select_dtypes(include=["number", "bool"]).columns | |
| target_col = numeric_cols[0] if len(numeric_cols) else None | |
| if not target_col or not pd.api.types.is_numeric_dtype(df[target_col]): | |
| return pd.DataFrame() | |
| agg_fn = "mean" if agg in ("mean", "avg", "average") else "sum" | |
| grouped = df.groupby(x_field)[target_col].agg(agg_fn).reset_index() | |
| return grouped.rename(columns={target_col: "value"}) | |
| if agg == "count": | |
| grouped = df.groupby(x_field).size().reset_index(name="value") | |
| return grouped | |
| if y_field and y_field in df.columns and pd.api.types.is_numeric_dtype(df[y_field]): | |
| subset = df[[x_field, y_field]].copy() | |
| subset = subset.rename(columns={y_field: "value"}) | |
| return subset | |
| return pd.DataFrame() | |
def _render_chart(path: str, df: pd.DataFrame, plan: dict) -> str:
    """Draw the chart described by *plan* and save it to *path*.

    Args:
        path: Destination image file.
        df: Data to plot.
        plan: Chart plan with ``chart_type``, ``x_field``, ``y_field``,
            ``aggregation``, ``top_n`` and ``explanation`` keys (all optional).

    Returns:
        *path* when an image was written, or ``""`` when the plan cannot be
        rendered (missing/non-numeric fields, or no rows left to plot).

    The figure is always closed before returning, including on the early
    ``""`` exits, so unrendered plans do not accumulate open figures.
    """
    chart_type = (plan.get("chart_type") or "bar").lower()
    x_field = plan.get("x_field")
    y_field = plan.get("y_field")
    agg = plan.get("aggregation")
    try:
        top_n = int(plan.get("top_n") or 12)
    except (TypeError, ValueError):
        # The plan may come from an LLM; fall back to the default cap.
        top_n = 12

    fig = plt.figure(figsize=(10, 6))
    try:
        if chart_type == "scatter" and x_field and y_field:
            both_numeric = (
                x_field in df.columns
                and y_field in df.columns
                and pd.api.types.is_numeric_dtype(df[x_field])
                and pd.api.types.is_numeric_dtype(df[y_field])
            )
            if not both_numeric:
                return ""
            plot_df = df[[x_field, y_field]].dropna().head(top_n)
            if plot_df.empty:
                return ""
            plt.scatter(plot_df[x_field], plot_df[y_field], color="#5cd4f4")
            plt.xlabel(x_field)
            plt.ylabel(y_field)
            plt.title(plan.get("explanation", f"{y_field} vs {x_field}"))
            plt.tight_layout()
            plt.savefig(path, bbox_inches="tight")
            return path

        if not x_field:
            return ""
        plot_df = _aggregate_for_chart(df, x_field, y_field, agg)
        if plot_df.empty:
            return ""

        # Largest values first, then keep only the top_n rows.
        plot_df = plot_df.sort_values("value", ascending=False)
        if top_n > 0:
            plot_df = plot_df.head(top_n)

        is_pie = chart_type == "pie"
        if is_pie:
            plot_df.set_index(x_field)["value"].plot(kind="pie", autopct="%1.1f%%")
            plt.ylabel("")
        elif chart_type == "line":
            plt.plot(plot_df[x_field], plot_df["value"], marker="o")
        elif chart_type == "area":
            plt.fill_between(plot_df[x_field], plot_df["value"], alpha=0.4)
            plt.plot(plot_df[x_field], plot_df["value"], color="#7a83ff")
        else:
            plt.bar(plot_df[x_field], plot_df["value"], color="#7a83ff")

        plt.xticks(rotation=45, ha="right")
        plt.xlabel(x_field)
        if not is_pie:
            # The pie branch blanks the y label; do not overwrite it here.
            plt.ylabel(plan.get("y_field") or "Value")
        plt.title(plan.get("explanation", "Visualization"))
        plt.tight_layout()
        plt.savefig(path, bbox_inches="tight")
        return path
    finally:
        plt.close(fig)
| # Nodes | |
def get_schema_node(state: AgentState):
    """Graph node: store the result of ``get_db_schema()`` under ``"schema"``."""
    return {"schema": get_db_schema()}
def generate_sql_node(state: AgentState):
    """Graph node: have the LLM translate the user's question into SQL.

    Returns ``{"sql_query": ...}`` on success, or ``{"error": ...}`` when no
    LLM is configured or the call fails. Markdown code fences are stripped
    from the model output.
    """
    model = get_llm()
    if not model:
        return {"error": "LLM not configured"}

    template = """
    You are a SQL expert. Convert the following natural language query into a SQL query for SQLite.
    Schema:
    {schema}
    Recent conversation:
    {history}
    Current Query: {query}
    Return ONLY the SQL query, nothing else. Do not wrap in markdown code blocks.
    """
    pipeline = ChatPromptTemplate.from_template(template) | model | StrOutputParser()

    try:
        prompt_inputs = {
            "schema": state["schema"],
            "history": _format_history(state.get("history", [])),
            "query": state["query"],
        }
        raw_sql = pipeline.invoke(prompt_inputs)
        # The model sometimes wraps its answer in fences despite the prompt.
        for fence in ("```sql", "```"):
            raw_sql = raw_sql.replace(fence, "")
        return {"sql_query": raw_sql.strip()}
    except Exception as exc:
        return {"error": str(exc)}
def execute_sql_node(state: AgentState):
    """Graph node: run the generated SQL and store the rows under ``"data"``.

    The SQL text is produced by an LLM, so it is treated as untrusted: only
    statements that begin with ``SELECT`` or ``WITH`` (after optional leading
    comments) are executed.

    Returns:
        ``state`` unchanged when an earlier node reported an error;
        ``{"data": [...row dicts...]}`` on success;
        ``{"error": ...}`` when the statement is rejected or execution fails.
    """
    if state.get("error"):
        return state

    import re  # local import: keeps this block self-contained

    sql_query = state.get("sql_query") or ""
    # NOTE(review): best-effort guard only. It does not replace connecting
    # with read-only database credentials, and it does not inspect the body
    # of a WITH statement.
    read_only = re.match(
        r"\s*(?:(?:--[^\n]*|/\*.*?\*/)\s*)*(?:select|with)\b",
        sql_query,
        re.IGNORECASE | re.DOTALL,
    )
    if not read_only:
        return {"error": "SQL Execution failed: only read-only SELECT queries are allowed"}

    try:
        df = pd.read_sql(sql_query, engine)
        return {"data": df.to_dict(orient="records")}
    except Exception as e:
        return {"error": f"SQL Execution failed: {str(e)}"}
def generate_visualization_node(state: AgentState):
    """Graph node: plan and render a chart for the query result.

    Returns ``state`` unchanged when an error is present or there is no data.
    Otherwise returns ``visualization_path`` (``None`` when nothing was
    rendered) and ``visualization_summary`` (the plan's explanation, or a
    fixed message for an empty frame).
    """
    if state.get("error") or not state.get("data"):
        return state

    frame = pd.DataFrame(state["data"])
    if frame.empty:
        return {"visualization_path": None, "visualization_summary": "No data to visualize."}

    chart_plan = _suggest_chart_plan(frame, state.get("query", ""))

    # Each chart gets a unique file name under backend/static.
    static_dir = os.path.join("backend", "static")
    target = os.path.join(static_dir, f"chart_{uuid.uuid4()}.png")
    os.makedirs(static_dir, exist_ok=True)

    rendered = _render_chart(target, frame, chart_plan)
    return {
        "visualization_path": rendered if rendered else None,
        "visualization_summary": chart_plan.get("explanation"),
    }
def advanced_analytics_node(state: AgentState):
    """Graph node: run ``run_advanced_analytics`` on the query result.

    Returns ``state`` unchanged when an error is present or there is no data.
    For an empty frame all four analysis keys are set to ``None``; otherwise
    they are copied from the analytics result's ``trend``, ``anomaly``,
    ``forecast`` and ``statistical_tests`` entries.
    """
    if state.get("error") or not state.get("data"):
        return state

    # State key -> key in the analytics result.
    result_keys = {
        "trend_analysis": "trend",
        "anomaly_analysis": "anomaly",
        "forecast_analysis": "forecast",
        "statistical_tests": "statistical_tests",
    }

    frame = pd.DataFrame(state["data"])
    if frame.empty:
        return {state_key: None for state_key in result_keys}

    analytics = run_advanced_analytics(frame)
    return {
        state_key: analytics.get(source_key)
        for state_key, source_key in result_keys.items()
    }
def generate_insights_node(state: AgentState):
    """Graph node: ask the LLM for narrative insights about the result.

    Returns ``state`` unchanged when an error is present. Otherwise returns
    ``{"insights": ...}``, which holds the model's text, or a message when no
    LLM is configured or the call fails (failures here do not set ``error``).
    """
    if state.get("error"):
        return state

    model = get_llm()
    if not model:
        return {"insights": "LLM not configured"}

    # Cap the data sample so the prompt stays bounded.
    data_summary = str(state["data"])[:2000]

    template = """
    You are an analytics copilot. Use the latest query, the conversation history, the data sample, and the derived diagnostics (trends, anomalies, forecasts, and statistical tests) to provide incremental insights. If the user repeats a question, reference earlier answers instead of restating everything.
    History:
    {history}
    Current Query: {query}
    Data Sample: {data}
    Trend Analysis: {trend}
    Anomaly Analysis: {anomaly}
    Forecast Analysis: {forecast}
    Statistical Tests: {stats}
    Provide 3-5 concise bullet insights plus a short summary paragraph. Mention forecasts and statistical significance when available.
    """
    pipeline = ChatPromptTemplate.from_template(template) | model | StrOutputParser()

    def diagnostics_json(state_key: str) -> str:
        # Missing or None diagnostics are rendered as an empty JSON object.
        return json.dumps(state.get(state_key) or {}, ensure_ascii=False)

    try:
        prompt_inputs = {
            "query": state["query"],
            "history": _format_history(state.get("history", [])),
            "data": data_summary,
            "trend": diagnostics_json("trend_analysis"),
            "anomaly": diagnostics_json("anomaly_analysis"),
            "forecast": diagnostics_json("forecast_analysis"),
            "stats": diagnostics_json("statistical_tests"),
        }
        return {"insights": pipeline.invoke(prompt_inputs)}
    except Exception as exc:
        return {"insights": f"Failed to generate insights: {str(exc)}"}
def build_report_node(state: AgentState):
    """Graph node: write the PDF report and return its path.

    Returns:
        ``state`` unchanged when an earlier node reported an error;
        ``{"report_path": ...}`` on success;
        ``{"error": ...}`` when report generation fails.
    """
    if state.get("error"):
        return state
    filename = f"report_{uuid.uuid4()}.pdf"
    path = os.path.join("backend", "static", filename)
    try:
        # The visualization node only creates this directory when it actually
        # renders a chart, so make sure it exists before writing the report.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        generate_pdf_report(
            report_path=path,
            title="Autonomous Data Analyst Report",
            query=state.get("query", ""),
            sql_query=state.get("sql_query", ""),
            insights=state.get("insights", "No insights generated."),
            chart_image_path=state.get("visualization_path"),
            chart_summary=state.get("visualization_summary"),
            trend_analysis=state.get("trend_analysis"),
            anomaly_analysis=state.get("anomaly_analysis"),
            forecast_analysis=state.get("forecast_analysis"),
            statistical_tests=state.get("statistical_tests"),
            data_sample=state.get("data"),
        )
        return {"report_path": path}
    except Exception as e:
        return {"error": f"Report generation failed: {str(e)}"}
def create_agent_graph():
    """Assemble and compile the linear agent workflow.

    The nodes run strictly in the order listed below, starting at
    ``get_schema`` and finishing at ``END`` after ``build_report``.
    """
    pipeline = [
        ("get_schema", get_schema_node),
        ("generate_sql", generate_sql_node),
        ("execute_sql", execute_sql_node),
        ("visualize", generate_visualization_node),
        ("advanced_analytics", advanced_analytics_node),
        ("generate_insights", generate_insights_node),
        ("build_report", build_report_node),
    ]

    workflow = StateGraph(AgentState)
    for step_name, handler in pipeline:
        workflow.add_node(step_name, handler)

    step_names = [step_name for step_name, _ in pipeline]
    workflow.set_entry_point(step_names[0])
    # Chain every step to its successor; the last step leads to END.
    for source, target in zip(step_names, step_names[1:] + [END]):
        workflow.add_edge(source, target)

    return workflow.compile()