Spaces:
Sleeping
Sleeping
| from typing import TypedDict , Annotated , List , Optional | |
| from langgraph.graph.message import add_messages | |
| from langchain_core.messages import SystemMessage , HumanMessage | |
| from langchain_openai import ChatOpenAI | |
| from src.retrieval import retrieve | |
| import os | |
| from dotenv import load_dotenv | |
| from langgraph.graph import StateGraph, START ,END | |
| from pydantic import BaseModel , Field | |
| import datetime | |
| from langchain_community.utilities import SQLDatabase | |
| load_dotenv() | |
class State(TypedDict):
    """Shared state passed between the nodes of the text-to-SQL graph."""

    # Database connection string; read by retrieve_node and execute_node.
    connection_url: str
    # Caller identifier forwarded to retrieve() together with the question.
    user_id: str
    # Conversation so far; add_messages is the reducer, so node outputs are
    # merged into the list. The last entry is treated as the current request.
    messages: Annotated[List, add_messages]
    # Value returned by retrieve(); interpolated into the generation prompt
    # as the schema/dialect description.
    scheme: str
    # SQL text produced by generate_node and executed by execute_node.
    sql_query: str
    # Raw output of the executed query, written by execute_node on success.
    query_result: str
    # Message of the last execution failure; None once generate_node has
    # produced a new query or execution succeeded.
    error: Optional[str]
    # Count of consecutive failed executions; routing stops retrying at 3.
    retry: int
    # Text of the reply produced by answer_node.
    final_result: str
# Chat model shared by generate_node and answer_node. Requests go to
# OpenRouter's OpenAI-compatible endpoint, authenticated with the
# OPENROUTER_API_KEY environment variable; temperature 0 is requested.
llm = ChatOpenAI(
    temperature=0,
    model="openai/gpt-4o-mini",
    openai_api_base="https://openrouter.ai/api/v1",
    openai_api_key=os.environ.get("OPENROUTER_API_KEY"),
)
# Structured-output schema handed to llm.with_structured_output() in
# generate_node: the model must answer with exactly one field holding the SQL.
# NOTE: documented with comments rather than a class docstring on purpose,
# since a docstring would be added to the schema the model receives.
class sql_query(BaseModel):
    generated_sql_query: str = Field(
        ...,
        description="The raw, valid executable SQL query text. Contain absolutely NO markdown wrapping, code blocks, or conversational formatting.",
    )
def retrieve_node(state: State):
    """Entry node: fetch schema context for the latest user message.

    Reads ``messages``, ``user_id`` and ``connection_url`` from the state and
    passes the content of the last message to ``retrieve``.

    Returns a partial state update with:
      * ``scheme`` - whatever ``retrieve`` returned;
      * ``error`` / ``retry`` - reset to ``None`` / ``0``.

    The reset matters when the state is persisted between invocations (for
    example through a checkpointer - TODO confirm how the graph is compiled):
    a previous turn that exhausted its retries leaves ``error`` set and
    ``retry`` at 3, which would otherwise push the next, unrelated request
    into error-correction mode and leave it with no retry budget.
    """
    messages = state.get("messages")
    question = messages[-1].content
    scheme = retrieve(
        state.get("user_id"),
        question,
        state.get("connection_url"),
    )
    # Every turn starts with a clean failure record.
    return {"scheme": scheme, "error": None, "retry": 0}
def generate_node(state: State):
    """Ask the LLM for a SQL query answering the latest user message.

    The system prompt is assembled from the retrieved schema text, the
    earlier conversation turns, today's date and, when both a previous error
    and a previous query are present in the state, an error-correction
    section quoting them.

    Returns a partial state update: ``sql_query`` holds the generated text
    and ``error`` is cleared to ``None``.
    """
    conversation = state.get("messages")
    scheme = state.get("scheme")
    previous_error = state.get("error")
    previous_sql = state.get('sql_query')

    earlier_turns = conversation[:-1]
    latest_request = conversation[-1].content
    current_date = f"{datetime.datetime.now():%Y-%m-%d}"

    # Flatten the earlier turns into "Role: text" lines for the prompt.
    history_text = (
        "\n".join(
            f"{turn.type.capitalize()}: {turn.content}" for turn in earlier_turns
        )
        if earlier_turns
        else "This is the first user request. No history exists."
    )

    # The correction section is only added on a retry, i.e. when the state
    # carries both the failing query and its error message.
    error_context = ""
    if previous_error and previous_sql:
        error_context = f"""
=== 🚨 ERROR CORRECTION MODE 🚨 ===
Your previous attempt to answer this request failed.
[PREVIOUS BROKEN QUERY]:
{previous_sql}
[DATABASE ERROR MESSAGE]:
{previous_error}
INSTRUCTION: Analyze the error message and the schema carefully. Fix the syntax, column names, or logic, and generate a CORRECTED query.
"""

    prompt_text = f"""You are an expert Data Analyst and Database Engineer.
Your job is to write highly optimized, perfectly accurate database queries based on user requests.
=== DATABASE SCHEMA & DIALECT ===
Look at the metadata below to identify the targeted database engine dialect and table layout:
{scheme}
=== CONVERSATION HISTORY ===
Use this previous context to resolve ambiguous terms (e.g., if the user says "filter those by...", look here to see what "those" refers to):
{history_text}
{error_context}
=== CRITICAL RULES ===
1. ALIGNMENT: Only use the tables and columns provided in the schema above. Do not hallucinate column names.
2. DIALECT MATCHING: Look at the 'Dialect:' specified above and write strict queries matching that exact syntax.
3. JOINS: Pay close attention to the FOREIGN KEY constraints provided in the schema to perform accurate JOINs.
4. CURRENT DATE: Today's date is {current_date}. Use this exact date for any relative time filters (e.g., "last month", "this year").
5. CASE SENSITIVITY: When filtering by strings, use case-insensitive comparisons (e.g., LOWER(column) = LOWER('value')) unless instructed otherwise.
6. SECURITY: NEVER generate DML queries (INSERT, UPDATE, DELETE, DROP). Only generate SELECT statements.
=== OUTPUT SELECTION RULES ===
1. If the user asks WHO / WHICH / WHAT IS THE NAME / identify a person, customer, user, product, company, or entity, return the human-readable name field, not just the ID.
2. If the schema has both an ID column and a name column, prefer selecting the name column in the final output.
3. If the name is in another table, use the required JOIN to fetch it.
4. Only return an ID alone when the user explicitly asks for the ID, or when no name-like field exists in the schema.
5. For count/number questions, return an aggregate numeric result, not a list of rows.
6. For "who/which" questions, do not answer with only identifiers if a readable label exists in the schema.
=== INSTRUCTIONS ===
First, think through the necessary tables, filters, joins, and the exact type of answer expected.
Then, provide the final executable SQL query specifically for the LATEST USER REQUEST."""

    request = [
        SystemMessage(content=prompt_text),
        HumanMessage(content=f"LATEST USER REQUEST:\n{latest_request}"),
    ]

    # Constrain the reply to the sql_query schema so only the SQL comes back.
    structured_llm = llm.with_structured_output(sql_query)
    reply = structured_llm.invoke(request)
    return {'sql_query': reply.generated_sql_query, "error": None}
# Leading keywords of statements that change data, schema or permissions.
# The generation prompt forbids them; execute_node enforces that here.
_WRITE_KEYWORDS = frozenset({
    "INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "TRUNCATE",
    "CREATE", "REPLACE", "MERGE", "GRANT", "REVOKE",
})


def _blocked_keyword(sql):
    """Return the write keyword that starts any statement in ``sql``, else None."""
    import re  # local import: the file's import block does not provide ``re``

    # Quoted text and comments are blanked in one left-to-right pass so that
    # their contents (semicolons, keywords) are never mistaken for statements.
    noise = re.compile(
        r"'(?:[^']|'')*'"       # single-quoted literal
        r'|"(?:[^"]|"")*"'      # double-quoted identifier / literal
        r"|--[^\n]*"            # line comment
        r"|/\*.*?\*/",          # block comment
        re.DOTALL,
    )
    scrubbed = noise.sub(" ", sql)
    for statement in scrubbed.split(";"):
        first_word = re.match(r"[\s(]*([A-Za-z_]+)", statement)
        if first_word and first_word.group(1).upper() in _WRITE_KEYWORDS:
            return first_word.group(1).upper()
    return None


def execute_node(state: State):
    """Run the generated SQL against the user's database.

    Reads ``connection_url``, ``sql_query`` and ``retry`` from the state.

    Returns a partial state update:
      * success - ``query_result`` with the output of ``db.run``, ``error``
        set to ``None`` and ``retry`` reset to 0;
      * failure - ``error`` with a message and ``retry`` incremented, which
        lets ``routing`` send the graph back to ``generate_node``.

    Before a connection is opened, an empty query or a query containing a
    statement that starts with a write/DDL keyword is rejected and reported
    as a failure. The SQL is model-generated from user input, so the prompt's
    "SELECT only" rule must not be the only safeguard.

    NOTE(review): the guard is best-effort (it does not catch, for instance,
    data-modifying CTEs); the connection should also use read-only
    credentials.
    """
    url = state.get("connection_url")
    query_text = state.get("sql_query")
    retry = state.get("retry", 0)

    if not isinstance(query_text, str) or not query_text.strip():
        return {"error": "No SQL query was generated.", "retry": retry + 1}

    blocked = _blocked_keyword(query_text)
    if blocked is not None:
        return {
            "error": (
                f"Rejected {blocked} statement: only read-only SELECT "
                "queries may be executed."
            ),
            "retry": retry + 1,
        }

    try:
        db = SQLDatabase.from_uri(url)
        result = db.run(query_text)
    except Exception as e:
        # Deliberately broad: any connection or execution failure is turned
        # into state so the graph can retry instead of crashing.
        return {"error": str(e), "retry": retry + 1}
    return {"query_result": result, "error": None, "retry": 0}
def routing(state: State):
    """Pick the node that follows execute_node.

    Returns "generate_node" while the state carries an error and fewer than
    3 failed attempts have been recorded; otherwise returns "answer_node".
    """
    has_error = state.get("error")
    attempts = state.get('retry', 0)
    return "generate_node" if has_error and attempts < 3 else "answer_node"
def answer_node(state: State):
    """Turn the query outcome into a conversational reply for the user.

    Reads ``messages``, ``query_result`` and ``error`` from the state. When
    ``error`` is set, the model is asked to apologise for a technical issue;
    otherwise it is asked to answer strictly from the returned data.

    Returns a partial state update: the model's reply is appended to
    ``messages`` and its text is stored in ``final_result``.
    """
    messages = state.get("messages")
    # Fall back for a missing AND an empty result. Once execute_node has
    # written query_result, a plain .get() default is never used, so a query
    # matching no rows (presumably an empty string from SQLDatabase.run -
    # confirm against the LangChain docs) would reach the prompt as blank.
    query_result = state.get("query_result") or "No records found."
    error = state.get("error")

    history_messages = messages[:-1]
    user_query = messages[-1].content

    # Flatten the earlier turns into "Role: text" lines for the prompt.
    if history_messages:
        history_text = "\n".join(
            f"{msg.type.capitalize()}: {msg.content}" for msg in history_messages
        )
    else:
        history_text = "This is the first user request. No history exists."

    system_prompt = f"""You are a helpful Data Analyst communicating directly with a user.
=== CONVERSATION HISTORY ===
Use this to maintain the context and tone of the conversation:
{history_text}
=== EXECUTION CONTEXT ===\n"""

    if error:
        # Retries were exhausted: report the failure instead of data.
        system_prompt += f"""Unfortunately, the database returned an error and the data could not be retrieved.
Error details: {error}
INSTRUCTION: Politely apologize to the user and briefly explain that you encountered a technical issue retrieving their specific request."""
    else:
        system_prompt += f"""The database returned this raw data: {query_result}
INSTRUCTIONS:
1. Answer using ONLY the returned data.
2. Never invent a name, value, or entity that is not present in the result.
3. If the result contains both an ID and a name, use the name in the final answer and mention the ID only if helpful.
4. If the result contains only an ID and the user asked for a name/person/entity, say that the returned data only contains an identifier and no readable name.
5. Do not substitute or guess a name from a customer_id or any other identifier.
6. Do not mention SQL, the database, schemas, or how you got the data.
7. Give a clean, professional, and conversational response."""

    final_msg = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=f"LATEST USER REQUEST:\n{user_query}"),
    ]
    response = llm.invoke(final_msg)
    return {"messages": [response], "final_result": response.content}
# Graph wiring: START -> retrieve -> generate -> execute, after which routing
# either loops back to generate_node or hands over to answer_node -> END.
workflow = StateGraph(State)

# Each node is registered under its function's own name.
for _node_fn in (retrieve_node, generate_node, execute_node, answer_node):
    workflow.add_node(_node_fn.__name__, _node_fn)

# Unconditional edges of the linear part of the pipeline.
for _source, _target in (
    (START, "retrieve_node"),
    ("retrieve_node", "generate_node"),
    ("generate_node", "execute_node"),
):
    workflow.add_edge(_source, _target)

# routing returns one of the keys below; each maps to the node of the same name.
workflow.add_conditional_edges(
    "execute_node",
    routing,
    {"answer_node": "answer_node", "generate_node": "generate_node"},
)
workflow.add_edge("answer_node", END)