| |
| |
|
|
| import json |
| from typing import TypedDict, Annotated, Literal |
| from langgraph.graph import StateGraph, END |
| from langchain_core.messages import HumanMessage, SystemMessage, AIMessage |
| from src.llm import get_llm |
| from src.tools.agent_tools import ALL_TOOLS, sql_query, semantic_search, get_dataset_info |
| from src.data_platform.duckdb_store import run_query as duckdb_run_query |
| from src.retrieval.hybrid import hybrid_search |
| from src.config import DATASET_DESCRIPTION |
| from src.logger import logger |
|
|
|
|
| |
| |
| |
| |
|
|
def _run_sql(query: str) -> str:
    """Run a SQL query through the DuckDB store and render the rows as text.

    Args:
        query: SQL text handed unchanged to ``duckdb_run_query``.

    Returns:
        The result rendered with ``to_string(index=False)``. Results longer
        than 50 rows are cut to the first 50 and followed by a note giving the
        total row count. Empty results and failures are reported as short
        human-readable messages instead of raising.
    """
    logger.info(f"Direct call: sql_query({query[:100]}...)")
    try:
        # Presumably a pandas DataFrame (uses .empty / .head / .to_string).
        frame = duckdb_run_query(query)
        if frame.empty:
            return "Query returned no results. Check your filters."

        total_rows = len(frame)
        if total_rows <= 50:
            return frame.to_string(index=False)

        # Cap the output so one large result cannot flood the LLM context.
        preview = frame.head(50).to_string(index=False)
        return f"{preview}\n\n... showing 50 of {total_rows} total rows"
    except ValueError as blocked:
        # NOTE(review): ValueError is reported as a *blocked* query, presumably
        # because the store rejects disallowed statements that way; confirm.
        return f"Query blocked: {blocked}"
    except Exception as failure:
        return f"SQL error: {failure}. Double-check the table/column names."
|
|
|
|
def _run_search(query: str) -> str:
    """Run a hybrid search and format the hits as a text report.

    Args:
        query: Free-text query passed to ``hybrid_search``.

    Returns:
        A report that starts with the search confidence and strategy, followed
        by one numbered entry per hit (score when present, source, text).
        A warning line is appended when confidence is ``"LOW"`` or ``"NONE"``.
        If there are no hits, or the search raises, a short message is
        returned instead; this function never raises.
    """
    logger.info(f"Direct call: semantic_search({query[:100]}...)")
    try:
        results = hybrid_search(query)
        confidence = results["confidence"]
        hits = results["results"]
        if not hits:
            return "No relevant articles found for this query."

        parts = [
            f"Search confidence: {confidence}",
            f"Strategy: {results['strategy']}",
            "---",
        ]
        for i, doc in enumerate(hits, 1):
            # Prefer the rerank score; fall back to the RRF score if that is
            # the only one attached to the hit.
            score_info = ""
            if "rerank_score" in doc:
                score_info = f" (relevance: {doc['rerank_score']:.3f})"
            elif "rrf_score" in doc:
                score_info = f" (RRF: {doc['rrf_score']:.4f})"
            source = doc.get("metadata", {}).get("source", "unknown")
            parts.append(f"[{i}]{score_info} ({source}): {doc['text']}")

        if confidence in ("LOW", "NONE"):
            # The warning sign and em dash replace mis-decoded characters that
            # previously showed up as a stray Greek beta in this message.
            parts.append("\n⚠ Low confidence — passages may not be directly relevant.")
        return "\n\n".join(parts)
    except Exception as e:
        logger.error(f"Search failed: {e}")
        return f"Search error: {e}"
|
|
|
|
| |
class AgentState(TypedDict):
    """State dictionary shared by every node of the agent graph."""

    # --- request -----------------------------------------------------------
    query: str            # the user's question, set once by run_query

    # --- intermediate artefacts ---------------------------------------------
    plan: str             # raw planner LLM output (expected to hold a JSON array)
    retrieved_data: str   # concatenated per-step tool results from the retriever
    draft_answer: str     # analyst answer; overwritten by the retry node
    critique: str         # critic verdict text

    # --- outcome ------------------------------------------------------------
    final_answer: str     # copy of the draft chosen by the finalize node
    confidence: str       # label assigned by the finalize node
    sources: list[str]    # initialised empty; not written by the nodes in this module
    retry_count: int      # number of revisions performed so far
    error: str            # initialised empty; not written by the nodes in this module
|
|
|
|
| |
|
|
# System prompt for the planner LLM. This is an f-string, so DATASET_DESCRIPTION
# is interpolated once, when the module is imported. The arrows and the em dash
# are literal UTF-8 characters; they replace mis-decoded bytes that previously
# appeared as a stray Greek beta in the prompt text.
PLANNER_PROMPT = f"""You are a financial data analyst planner. Your job is to look at a user's question
and create a step-by-step plan for answering it.

You have access to:
{DATASET_DESCRIPTION}

For each step in your plan, specify which tool to use:
- sql_query: for numerical lookups, aggregations, comparisons, filtering
- semantic_search: for opinions, trends, explanations, analyst sentiment
- get_dataset_info: if you need to check what data is available

Think about:
1. Does this question need structured data (numbers)? → SQL
2. Does it need unstructured insights (opinions, reasons)? → semantic search
3. Does it need both? → plan multiple steps
4. Is it a multi-hop question? → break it into sub-questions

Output your plan as a JSON array of steps, each with "tool" and "input" fields.
If the question is outside the dataset's coverage, say so in your plan.
Be specific with SQL — use the actual column names from the schema."""
|
|
# System prompt for the analyst LLM (used for both the first draft and the
# revision). The em dash is a literal UTF-8 character; it replaces mis-decoded
# bytes that previously appeared as a stray Greek beta in the prompt text.
ANALYST_PROMPT = """You are a financial analyst. You receive a user's question along with
data retrieved from a financial database and news articles.

Your job:
1. Combine the structured data (SQL results) and unstructured data (article passages)
   into a clear, well-organized answer
2. Cite specific numbers from the SQL data when available — always include currency symbols
   (e.g. $59.6 billion, not just 59,619 million). Revenue and financial figures are in USD.
3. Reference the article sources when using qualitative information
4. If the retrieved data doesn't fully answer the question, say what's missing
5. Don't make up numbers or facts that aren't in the provided data

Note: All monetary values in the database are in millions USD (column names ending in _mn).
So revenue_mn = 59619.85 means $59,619.85 million = $59.6 billion.

Keep your tone professional but readable. Use bullet points for comparisons.
Include a "Sources" section at the end listing where each piece of info came from."""
|
|
# System prompt for the critic LLM. The verdict format requested here
# ("APPROVED" or "REVISE: ...") is what the routing after the critic node keys
# on. The em dash is a literal UTF-8 character; it replaces mis-decoded bytes
# that previously appeared as a stray Greek beta in the prompt text.
CRITIC_PROMPT = """You are a quick quality checker for financial analysis answers.

Only reject an answer if it has one of these SERIOUS problems:
1. WRONG NUMBERS: The answer states a number that directly contradicts the retrieved data.
2. HALLUCINATION: The answer invents facts not present in the retrieved data.
3. WRONG TOPIC: The answer doesn't address the user's question at all.

Minor style issues, missing disclaimers, or incomplete sourcing are NOT reasons to reject.
When the retrieved data is limited, the answer should work with what's available — don't reject just because the data is sparse.

Respond with exactly one of:
- APPROVED (if there are no serious problems listed above)
- REVISE: [one sentence describing the specific factual error to fix]

Default to APPROVED unless there is a clear factual error."""
|
|
|
|
| |
|
|
def planner_node(state: AgentState) -> dict:
    """Ask the LLM for a retrieval plan for the user's question.

    The system message is ``PLANNER_PROMPT``; the human message carries the
    dataset description, the schema text from ``get_schema_info()`` and the
    question itself.

    Args:
        state: Current agent state; only ``state["query"]`` is read.

    Returns:
        ``{"plan": <raw LLM response content>}``.
    """
    logger.info("Planner: analyzing query")
    model = get_llm()

    # Kept as a function-local import, exactly where the original had it.
    from src.data_platform.duckdb_store import get_schema_info

    schema_text = "\n\n".join((DATASET_DESCRIPTION, get_schema_info()))
    question_block = f"Schema info:\n{schema_text}\n\nUser question: {state['query']}"

    reply = model.invoke(
        [
            SystemMessage(content=PLANNER_PROMPT),
            HumanMessage(content=question_block),
        ]
    )
    plan_text = reply.content
    logger.info(f"Planner output: {plan_text[:200]}...")

    return {"plan": plan_text}
|
|
|
|
def _parse_plan_steps(plan_text: str, query: str) -> list[dict]:
    """Extract the list of step dicts from the planner output, with fallbacks."""
    import re  # function-local, as in the original block

    search_only = [{"tool": "semantic_search", "input": query}]

    # Greedy on purpose: spans from the first "[" to the last "]" so a JSON
    # array surrounded by prose (or code fences) is still picked up.
    json_match = re.search(r"\[.*\]", plan_text, re.DOTALL)
    if json_match is None:
        # No array at all: search the articles and pull a small table sample.
        return search_only + [
            {"tool": "sql_query", "input": "SELECT * FROM companies LIMIT 5"},
        ]

    try:
        parsed = json.loads(json_match.group())
    except json.JSONDecodeError:
        logger.warning("Could not parse plan as JSON, falling back to direct search")
        return search_only

    # The LLM may emit bare strings or nested arrays; only dict entries carry
    # the "tool"/"input" fields the executor needs.
    steps = [entry for entry in parsed if isinstance(entry, dict)]
    if len(steps) != len(parsed):
        logger.warning(f"Ignoring {len(parsed) - len(steps)} plan entries that are not objects")
    if not steps:
        logger.warning("Plan contained no usable steps, falling back to direct search")
        return search_only
    return steps


def _coerce_tool_input(raw, query: str) -> str:
    """Flatten whatever the planner put in "input" into a plain string."""
    if isinstance(raw, dict):
        # Planners sometimes wrap the argument, e.g. {"query": "..."}.
        raw = (
            raw.get("query")
            or raw.get("input")
            or raw.get("sql")
            or next(iter(raw.values()), query)
        )
    if isinstance(raw, (list, tuple)):
        raw = " ".join(str(item) for item in raw)
    if raw is None:
        # A JSON null would otherwise be sent to the tool as the text "None".
        return query
    if not isinstance(raw, str):
        raw = str(raw)
    # A blank input is useless to every tool; fall back to the question.
    return raw if raw.strip() else query


def retriever_node(state: AgentState) -> dict:
    """Execute the planner's steps and collect their textual results.

    The plan is parsed as a JSON array of ``{"tool", "input"}`` objects. If no
    array is found, a semantic search plus a small sample SQL query is run; if
    the array is invalid JSON or holds no usable steps, only a semantic search
    on the user's question is run. Each step is dispatched to ``_run_sql``,
    ``_run_search`` or the static dataset description; unknown tools are
    reported in the output rather than raising.

    Args:
        state: Current agent state; reads ``state["plan"]`` and
            ``state["query"]``.

    Returns:
        ``{"retrieved_data": <per-step result sections joined by blank lines>}``.
    """
    logger.info("Retriever: executing plan")

    query = state["query"]
    steps = _parse_plan_steps(state["plan"], query)

    all_results = []
    for step_no, step in enumerate(steps, 1):
        tool_name = step.get("tool", "semantic_search")
        tool_input = _coerce_tool_input(step.get("input", query), query)

        logger.info(f" Step {step_no}: {tool_name}({tool_input[:80]}...)")

        try:
            if tool_name == "sql_query":
                result = _run_sql(tool_input)
            elif tool_name == "semantic_search":
                result = _run_search(tool_input)
            elif tool_name == "get_dataset_info":
                result = DATASET_DESCRIPTION
            else:
                result = f"Unknown tool: {tool_name}"

            all_results.append(f"--- Step {step_no} ({tool_name}) ---\n{result}")
        except Exception as e:
            # One failing step must not discard the results already gathered.
            logger.error(f" Step {step_no} failed: {e}")
            all_results.append(f"--- Step {step_no} ({tool_name}) --- ERROR: {e}")

    retrieved = "\n\n".join(all_results)
    logger.info(f"Retriever: collected {len(all_results)} results")

    return {"retrieved_data": retrieved}
|
|
|
|
def analyst_node(state: AgentState) -> dict:
    """Draft an answer from the question, the plan and the retrieved data.

    Args:
        state: Current agent state; reads ``query``, ``plan`` and
            ``retrieved_data``.

    Returns:
        ``{"draft_answer": <LLM response content>}``.
    """
    logger.info("Analyst: synthesizing answer")
    model = get_llm()

    # Three sections separated by blank lines, in this fixed order.
    briefing = "\n\n".join(
        [
            f"User question: {state['query']}",
            f"Execution plan:\n{state['plan']}",
            f"Retrieved data:\n{state['retrieved_data']}",
        ]
    )
    conversation = [
        SystemMessage(content=ANALYST_PROMPT),
        HumanMessage(content=briefing),
    ]

    answer_text = model.invoke(conversation).content
    logger.info(f"Analyst: produced {len(answer_text)} char answer")

    return {"draft_answer": answer_text}
|
|
|
|
def critic_node(state: AgentState) -> dict:
    """Have the LLM review the draft answer against the retrieved data.

    Args:
        state: Current agent state; reads ``query``, ``retrieved_data`` and
            ``draft_answer``.

    Returns:
        ``{"critique": <LLM response content, whitespace-stripped>}``.
    """
    logger.info("Critic: reviewing answer")
    model = get_llm()

    # Three sections separated by blank lines, in this fixed order.
    review_request = "\n\n".join(
        [
            f"User question: {state['query']}",
            f"Retrieved data:\n{state['retrieved_data']}",
            f"Draft answer:\n{state['draft_answer']}",
        ]
    )
    conversation = [
        SystemMessage(content=CRITIC_PROMPT),
        HumanMessage(content=review_request),
    ]

    verdict_text = model.invoke(conversation).content.strip()
    logger.info(f"Critic verdict: {verdict_text[:100]}...")

    return {"critique": verdict_text}
|
|
|
|
def decide_after_critic(state: AgentState) -> Literal["finalize", "retry"]:
    """Route to ``finalize`` or ``retry`` based on the critic's verdict.

    The verdict is the first keyword found in the critique, compared
    case-insensitively: ``APPROVED`` finalizes, while ``REVISE`` or
    ``NOT APPROVED`` request a revision. A plain substring test is not enough
    here: it would accept "REVISE: fix X before this can be approved" as an
    approval.

    Args:
        state: Current agent state; reads ``critique`` and ``retry_count``.

    Returns:
        ``"finalize"`` when the answer is approved or one revision has already
        been made; otherwise ``"retry"``.
    """
    import re  # function-local, like the other uses of ``re`` in this module

    critique = state.get("critique", "") or ""
    retry_count = state.get("retry_count", 0)

    # The leftmost keyword wins. "NOT APPROVED" is listed first so that the
    # negated form is recognised as a whole rather than as "APPROVED".
    verdict = re.search(r"\b(NOT\s+APPROVED|APPROVED|REVISE)\b", critique.upper())
    if verdict is not None and verdict.group(1) == "APPROVED":
        return "finalize"

    # Allow a single revision, then ship whatever we have instead of looping.
    if retry_count >= 1:
        logger.warning("Max retries reached, finalizing with current answer")
        return "finalize"

    return "retry"
|
|
|
|
def retry_node(state: AgentState) -> dict:
    """Ask the analyst LLM to revise its draft using the critic's feedback.

    Args:
        state: Current agent state; reads ``query``, ``retrieved_data``,
            ``draft_answer``, ``critique`` and ``retry_count``.

    Returns:
        A dict with the revised ``draft_answer`` and ``retry_count``
        incremented by one.
    """
    attempts_so_far = state.get("retry_count", 0)
    logger.info(f"Retry: incorporating feedback (attempt {attempts_so_far + 1})")
    model = get_llm()

    # Sections are separated by blank lines; the instruction comes last.
    revision_request = "\n\n".join(
        [
            f"User question: {state['query']}",
            f"Retrieved data:\n{state['retrieved_data']}",
            f"Your previous answer:\n{state['draft_answer']}",
            f"Reviewer feedback:\n{state['critique']}",
            "Please revise your answer based on the feedback above.",
        ]
    )
    conversation = [
        SystemMessage(content=ANALYST_PROMPT),
        HumanMessage(content=revision_request),
    ]

    revised = model.invoke(conversation)
    return {
        "draft_answer": revised.content,
        "retry_count": attempts_so_far + 1,
    }
|
|
|
|
def finalize_node(state: AgentState) -> dict:
    """Promote the current draft to the final answer and label its confidence.

    Confidence is ``"HIGH"`` when the critic's verdict is an approval and
    ``"MEDIUM"`` otherwise. The verdict is the first keyword found in the
    critique (``APPROVED``, ``REVISE`` or ``NOT APPROVED``), so a critique such
    as "REVISE: ... before this can be approved" is not counted as an approval.

    Args:
        state: Current agent state; reads ``draft_answer`` and ``critique``.

    Returns:
        ``{"final_answer": ..., "confidence": "HIGH" | "MEDIUM"}``. A missing
        or empty draft is replaced by a short apology message.
    """
    import re  # function-local, like the other uses of ``re`` in this module

    # ``or`` rather than a .get() default: the state is seeded with an empty
    # string, so a default alone would never trigger.
    answer = state.get("draft_answer") or "I wasn't able to generate an answer."

    critique = state.get("critique", "") or ""
    # The leftmost keyword wins. "NOT APPROVED" is listed first so that the
    # negated form is recognised as a whole rather than as "APPROVED".
    verdict = re.search(r"\b(NOT\s+APPROVED|APPROVED|REVISE)\b", critique.upper())
    approved = verdict is not None and verdict.group(1) == "APPROVED"

    # Every non-approved outcome is MEDIUM. The original had separate
    # "retried" and "not retried" branches that produced the same label.
    confidence = "HIGH" if approved else "MEDIUM"

    return {
        "final_answer": answer,
        "confidence": confidence,
    }
|
|
|
|
def build_graph() -> StateGraph:
    """Assemble and compile the agent workflow.

    Flow: planner -> retriever -> analyst -> critic, then either finalize
    (which ends the run) or retry, which loops back to the critic.

    Returns:
        The object produced by ``StateGraph.compile()``.
        NOTE(review): the annotation says ``StateGraph`` but the compiled
        graph is what is returned; confirm the intended return type.
    """
    workflow = StateGraph(AgentState)

    # Registration order matches the order in which the nodes normally run.
    node_table = (
        ("planner", planner_node),
        ("retriever", retriever_node),
        ("analyst", analyst_node),
        ("critic", critic_node),
        ("retry", retry_node),
        ("finalize", finalize_node),
    )
    for node_name, handler in node_table:
        workflow.add_node(node_name, handler)

    # Linear backbone of the pipeline.
    workflow.set_entry_point("planner")
    backbone = (
        ("planner", "retriever"),
        ("retriever", "analyst"),
        ("analyst", "critic"),
    )
    for upstream, downstream in backbone:
        workflow.add_edge(upstream, downstream)

    # The critic's verdict picks the branch; router output maps to node names.
    workflow.add_conditional_edges(
        "critic",
        decide_after_critic,
        {"finalize": "finalize", "retry": "retry"},
    )

    # A revised draft is reviewed again; finalize terminates the run.
    workflow.add_edge("retry", "critic")
    workflow.add_edge("finalize", END)

    return workflow.compile()
|
|
|
|
| |
# Module-level cache for the compiled agent graph; filled in on first use.
agent_graph = None


def get_agent():
    """Return the compiled agent graph, building it on the first call.

    The result of ``build_graph()`` is stored in the module-level
    ``agent_graph`` variable and reused by every later call.
    """
    global agent_graph
    if agent_graph is not None:
        return agent_graph
    agent_graph = build_graph()
    return agent_graph
|
|
|
|
def run_query(query: str) -> dict:
    """Run the full agent pipeline for one question and summarise the outcome.

    Args:
        query: The user's question.

    Returns:
        A dict with the keys ``answer``, ``confidence``, ``plan``,
        ``critique``, ``retries``, ``sql_queries`` and ``retrieved_data``.
        ``sql_queries`` lists the ``sql_query`` step inputs from the plan that
        are strings starting with SELECT or WITH. Extraction is best-effort: an
        unparseable plan yields an empty list and never fails the call.
    """
    import re  # function-local, as in the original block

    agent = get_agent()

    initial_state = {
        "query": query,
        "plan": "",
        "retrieved_data": "",
        "draft_answer": "",
        "critique": "",
        "final_answer": "",
        "confidence": "",
        "sources": [],
        "retry_count": 0,
        "error": "",
    }

    logger.info(f"Running agent pipeline for: {query[:100]}...")
    result = agent.invoke(initial_state)

    retrieved = result.get("retrieved_data", "")
    plan_text = result.get("plan", "")

    # Recover the SQL statements from the plan's JSON array; the retrieved text
    # only contains their results, not the statements themselves.
    try:
        json_match = re.search(r"\[.*\]", plan_text, re.DOTALL)
        steps = json.loads(json_match.group()) if json_match else []
    except (ValueError, TypeError) as exc:
        # json.JSONDecodeError is a ValueError; TypeError covers a non-string
        # plan. Neither should cost the caller the answer.
        logger.debug(f"Could not extract SQL queries from plan: {exc}")
        steps = []

    sql_queries = []
    for step in steps:
        # Skip malformed entries individually so one bad element does not hide
        # the valid queries that follow it.
        if not isinstance(step, dict) or step.get("tool") != "sql_query":
            continue
        candidate = step.get("input", "")
        if isinstance(candidate, str) and candidate.strip().upper().startswith(("SELECT", "WITH")):
            sql_queries.append(candidate.strip())

    return {
        "answer": result.get("final_answer", "No answer generated."),
        "confidence": result.get("confidence", "UNKNOWN"),
        "plan": plan_text,
        "critique": result.get("critique", ""),
        "retries": result.get("retry_count", 0),
        "sql_queries": sql_queries,
        "retrieved_data": retrieved,
    }
|
|