Spaces:
Sleeping
Sleeping
| from typing import TypedDict , Annotated , List , Optional | |
| from langgraph.graph.message import add_messages | |
| from langchain_core.messages import SystemMessage , HumanMessage | |
| from langchain_openai import ChatOpenAI | |
| from src.retrieval import retrieve | |
| import os | |
| from dotenv import load_dotenv | |
| from langgraph.graph import StateGraph, START ,END | |
| from pydantic import BaseModel , Field | |
| import datetime | |
| from langchain_community.utilities import SQLDatabase | |
# Load variables from a local .env file into the process environment so that
# OPENROUTER_API_KEY can be read when the chat model client is constructed.
load_dotenv()
class State(TypedDict):
    """Graph state shared by every node of the text-to-SQL workflow.

    Each node reads the keys it needs and returns a partial dict that
    LangGraph merges back into this state.
    """

    # Database URL passed to retrieve() and to SQLDatabase.from_uri().
    connection_url: str
    # Identifier of the requesting user, forwarded to retrieve().
    user_id: str
    # Conversation so far; the add_messages reducer merges returned messages
    # into the list instead of replacing it.
    messages: Annotated[List, add_messages]
    # Schema context returned by retrieve(); embedded in the SQL prompt.
    scheme: str
    # Most recently generated SQL text.
    sql_query: str
    # Output of the last successful SQL execution.
    query_result: str
    # Error text from the last failed execution, or None.
    error: Optional[str]
    # Failed executions since the last success; routing retries while < 3.
    retry: int
    # Text of the reply produced by answer_node.
    final_result: str
# Chat model shared by the SQL generator and the answer writer.
# OpenRouter exposes an OpenAI-compatible endpoint, so ChatOpenAI is pointed
# at its base URL and authenticated with the OpenRouter key.
# temperature=0 asks for the least random output the provider offers.
llm = ChatOpenAI(
    temperature=0,
    model="openai/gpt-4o-mini",
    openai_api_base="https://openrouter.ai/api/v1",
    openai_api_key=os.getenv("OPENROUTER_API_KEY"),
)
class sql_query(BaseModel):
    # Structured-output schema for generate_node: the model must return its
    # SQL in this single required field.
    # NOTE(review): comments are used instead of a class docstring because a
    # docstring would likely be forwarded to the model as the schema
    # description, changing the prompt.
    generated_sql_query: Annotated[
        str,
        Field(
            description="The raw, valid executable SQL query text. Contain absolutely NO markdown wrapping, code blocks, or conversational formatting."
        ),
    ]
def retrieve_node(state: State):
    """Fetch schema context for the newest user message and open a new turn.

    Calls ``retrieve(user_id, query, connection_url)`` with the text of the
    last message and stores the result under ``scheme``.

    It also clears ``error`` and resets ``retry`` to 0. This node is the graph's
    only entry point (START -> retrieve_node), so it runs exactly once per turn.
    If state is persisted between invocations (e.g. by a checkpointer), those
    fields would otherwise carry over from the previous turn. The new request
    would then be sent into error-correction mode against an unrelated query,
    with a retry budget that might already be used up.

    Returns:
        Partial state update with ``scheme``, ``error`` and ``retry``.
    """
    messages = state.get("messages")
    latest_request = messages[-1].content
    scheme = retrieve(state.get("user_id"), latest_request, state.get("connection_url"))
    return {"scheme": scheme, "error": None, "retry": 0}
def generate_node(state: State):
    """Ask the LLM for a single SELECT query answering the latest user message.

    Builds a system prompt from the retrieved schema (``scheme``), the earlier
    conversation, and - only when the previous attempt failed - the failing
    ``sql_query`` together with its ``error``.

    Returns:
        ``{"sql_query": <generated SQL>, "error": None}``.
    """
    messages = state.get("messages")
    scheme = state.get("scheme")
    error = state.get("error")
    wrong_query = state.get("sql_query")
    llm_with_structured_output = llm.with_structured_output(sql_query)

    history_messages = messages[:-1]
    current_query_string = messages[-1].content
    current_date = datetime.datetime.now().strftime("%Y-%m-%d")

    if history_messages:
        history_text = "\n".join(
            f"{msg.type.capitalize()}: {msg.content}" for msg in history_messages
        )
    else:
        history_text = "This is the first user request. No history exists."

    # Emit the correction section (with its own header) only on a retry, so a
    # first attempt is not framed as fixing a broken query.
    if error and wrong_query:
        error_context = f"""=== 🚨 ERROR CORRECTION MODE 🚨 ===
Your previous attempt to answer this request failed.
[PREVIOUS BROKEN QUERY]:
{wrong_query}
[DATABASE ERROR MESSAGE]:
{error}
INSTRUCTION: Analyze the error message and the schema carefully. Fix the syntax, column names, or logic, and generate a CORRECTED query.
"""
    else:
        error_context = ""

    system_prompt = SystemMessage(content=f"""
You are an expert Data Analyst and SQL Engineer.
Your task is to generate ONE valid SELECT query for the latest user request.
=== DATABASE SCHEMA & DIALECT ===
{scheme}
=== CONVERSATION HISTORY ===
{history_text}
{error_context}
=== CRITICAL RULES ===
1. Use ONLY tables and columns that exist in the schema.
2. Never hallucinate columns, joins, or tables.
3. Generate only SELECT queries. No INSERT, UPDATE, DELETE, DROP, TRUNCATE, ALTER.
4. Use the exact SQL dialect implied by the schema metadata.
5. For any output columns, ALWAYS use clear aliases.
Example:
- customer_id AS customer_id
- customer_name AS customer_name
- SUM(amount) AS total_amount
6. When the user asks for a person/customer/company/product/entity, return BOTH:
- the readable name field if it exists
- the matching ID field
7. If a name exists in another table, join to fetch it.
8. If no readable name exists, return the best human-readable identifier available, and the ID.
9. For aggregate queries, include a label column when possible so the answer layer can explain the result.
10. If fixing an error, preserve the original user intent and correct only the broken parts.
=== PRIORITY RULE FOR ID VS NAME ===
- Priority 1: name + id together, if possible
- Priority 2: name only, if name exists but id cannot be included
- Priority 3: id only, only if no readable name exists
=== OUTPUT FORMAT REQUIREMENT ===
Return a SQL query whose selected columns are self-explanatory.
Do not rely on positional meaning like column 1, column 2 without aliases.
=== CURRENT DATE ===
Today's date is {current_date}.
""")

    final_msg = [
        system_prompt,
        HumanMessage(content=f"LATEST USER REQUEST:\n{current_query_string}"),
    ]
    response = llm_with_structured_output.invoke(final_msg)
    return {"sql_query": response.generated_sql_query, "error": None}
def execute_node(state: State):
    """Execute the generated SQL against the user's database.

    Returns a partial state update:
      * success: ``query_result`` (the value returned by ``SQLDatabase.run``),
        ``error`` cleared and ``retry`` reset to 0;
      * failure (driver/SQL error, or a query refused by the guard below):
        the ``error`` text and ``retry`` incremented by one, which ``routing``
        uses to decide whether to regenerate the query.
    """
    import re

    url = state.get("connection_url")
    sql_query = state.get("sql_query")
    retry = state.get("retry", 0)

    statement = (sql_query or "").strip()
    # Unwrap a ```sql fenced block the model may emit despite the schema hint.
    fenced = re.fullmatch(r"```[\w-]*\s*(.*?)\s*```", statement, flags=re.DOTALL)
    if fenced:
        statement = fenced.group(1)
    # Drop trailing terminators so one "SELECT ...;" is not treated as several
    # statements by the check below.
    statement = statement.rstrip().rstrip(";").rstrip()

    if not statement:
        return {"error": "No SQL query was generated.", "retry": retry + 1}

    # Defence in depth: this SQL is LLM output steered by user text, so refuse
    # anything but a single SELECT/WITH statement before connecting. Leading
    # whitespace, parentheses and SQL comments are skipped to find the first
    # keyword. The check is coarse (e.g. a ';' inside a string literal is
    # rejected); a read-only database account is still the real safeguard.
    leading = re.match(
        r"(?:\s|\(|--[^\n]*\n|/\*(?:[^*]|\*(?!/))*\*/)*([A-Za-z]+)", statement
    )
    if leading is None or leading.group(1).lower() not in ("select", "with"):
        return {
            "error": "Refused to execute: only read-only SELECT queries are allowed.",
            "retry": retry + 1,
        }
    if ";" in statement:
        return {
            "error": "Refused to execute: only a single SQL statement is allowed.",
            "retry": retry + 1,
        }

    try:
        db = SQLDatabase.from_uri(url)
        result = db.run(statement)
    except Exception as e:
        # Any connection/SQL error is fed back to generate_node for a fix.
        return {"error": str(e), "retry": retry + 1}
    return {"query_result": result, "error": None, "retry": 0}
def routing(state: State):
    """Choose the node that follows ``execute_node``.

    Returns ``"generate_node"`` (try again) when an ``error`` is set and
    ``retry`` is still below 3; otherwise ``"answer_node"``.
    """
    if not state.get("error"):
        return "answer_node"
    return "generate_node" if state.get("retry", 0) < 3 else "answer_node"
def answer_node(state: State):
    """Turn the execution outcome into a natural-language reply for the user.

    Reads the conversation ``messages``, the last ``sql_query``, and either
    ``query_result`` or, when execution ultimately failed, ``error``.

    Returns:
        ``{"messages": [reply], "final_result": reply.content}``. The reply is
        merged into the conversation by the ``add_messages`` reducer.
    """
    messages = state.get("messages")
    sql_query = state.get("sql_query", "")
    error = state.get("error")
    history_messages = messages[:-1]
    user_query = messages[-1].content

    if history_messages:
        history_text = "\n".join(
            f"{msg.type.capitalize()}: {msg.content}" for msg in history_messages
        )
    else:
        history_text = "This is the first user request. No history exists."

    if error:
        # routing() only sends an errored state here once it stops retrying,
        # so no query succeeded for this request. Any query_result still in
        # state comes from an earlier execution and must not be presented.
        query_result = "QUERY FAILED: no data could be retrieved for this request."
    else:
        # Treat a missing or empty result as "no records" rather than passing
        # an empty section the model might fill in by guessing.
        query_result = state.get("query_result") or "No records found."

    system_prompt = f"""
You are a helpful Data Analyst communicating directly with a user.
=== CONVERSATION HISTORY ===
{history_text}
=== EXECUTION CONTEXT ===
SQL QUERY USED:
{sql_query}
RAW DATABASE RESULT:
{query_result}
=== INSTRUCTIONS ===
1. Use ONLY the returned data.
2. Interpret the result using the SQL query and its selected aliases.
3. If the query selected columns like customer_id, customer_name, total_amount, use those exact labels in the final response.
4. If the result is positional, map values to the SQL SELECT order.
5. Never invent a name or ID.
6. For who/which questions:
- prefer name + id
- if name is missing, give the id and clearly say no readable name was returned
7. If the result contains an ID and a value like total_amount, explain them clearly.
8. Do not mention SQL or the database in the final answer.
9. Give a clean, professional response.
10. If the result says the query failed, tell the user the information could not be retrieved and do not guess an answer.
"""
    final_msg = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=f"LATEST USER REQUEST:\n{user_query}"),
    ]
    response = llm.invoke(final_msg)
    return {"messages": [response], "final_result": response.content}
def _build_workflow():
    """Assemble the (uncompiled) text-to-SQL graph.

    Topology:
        START -> retrieve_node -> generate_node -> execute_node
        execute_node -> generate_node | answer_node   (chosen by routing)
        answer_node -> END
    """
    graph = StateGraph(State)
    for node_name, node_fn in (
        ("retrieve_node", retrieve_node),
        ("generate_node", generate_node),
        ("execute_node", execute_node),
        ("answer_node", answer_node),
    ):
        graph.add_node(node_name, node_fn)

    graph.add_edge(START, "retrieve_node")
    graph.add_edge("retrieve_node", "generate_node")
    graph.add_edge("generate_node", "execute_node")
    # routing() returns one of these keys; each maps to the node of that name.
    graph.add_conditional_edges(
        "execute_node",
        routing,
        {"answer_node": "answer_node", "generate_node": "generate_node"},
    )
    graph.add_edge("answer_node", END)
    return graph


workflow = _build_workflow()