Spaces:
Sleeping
Sleeping
File size: 8,031 Bytes
52adb86 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 | from typing import TypedDict , Annotated , List , Optional
from langgraph.graph.message import add_messages
from langchain_core.messages import SystemMessage , HumanMessage
from langchain_openai import ChatOpenAI
from src.retrieval import retrieve
import os
from dotenv import load_dotenv
from langgraph.graph import StateGraph, START ,END
from pydantic import BaseModel , Field
import datetime
from langchain_community.utilities import SQLDatabase
# Load environment variables (python-dotenv reads a .env file, if present)
# before OPENROUTER_API_KEY is looked up for the chat model below.
load_dotenv()
class State(TypedDict):
    """Graph state shared by the retrieve / generate / execute / answer nodes."""

    connection_url: str  # DB URI passed to retrieve() and SQLDatabase.from_uri
    user_id: str  # forwarded to retrieve()
    messages: Annotated[List, add_messages]  # chat turns, reduced with add_messages
    scheme: str  # schema context produced by retrieve_node
    sql_query: str  # latest SQL produced by generate_node
    query_result: str  # output of db.run on a successful execution
    error: Optional[str]  # last execution error; None once cleared
    retry: int  # failed-execution counter, reset to 0 on success
    final_result: str  # text of the answer produced by answer_node
# Chat model shared by generate_node (structured SQL output) and answer_node
# (final reply). Requests are sent to OpenRouter's endpoint through the
# OpenAI-compatible client; temperature 0 asks for the least random output.
llm = ChatOpenAI(
    model="openai/gpt-4o-mini",
    temperature=0,
    openai_api_base="https://openrouter.ai/api/v1",
    openai_api_key=os.getenv("OPENROUTER_API_KEY"),
)
# Structured-output schema handed to llm.with_structured_output in
# generate_node; the model must fill `generated_sql_query` with bare SQL.
# NOTE: intentionally no docstring -- for pydantic schemas the class docstring
# becomes part of the schema description sent to the model.
# NOTE(review): the lowercase name collides with the `sql_query` state key;
# kept as-is because it is a module-level public name.
class sql_query(BaseModel):
    generated_sql_query: str = Field(
        ...,
        description="The raw, valid executable SQL query text. Contain absolutely NO markdown wrapping, code blocks, or conversational formatting.",
    )
def retrieve_node(state: State):
    """Fetch schema context for the latest user message and start a fresh turn.

    Calls ``retrieve(user_id, query, db_url)`` with the content of the last
    message and stores the result under ``scheme``.

    Because this is the first node after START, it also resets the per-turn
    retry bookkeeping. When state is carried over between invocations (e.g.
    the graph runs with a checkpointer on the same thread), a previous turn
    that exhausted its retries leaves ``error`` set and ``retry`` >= 3.
    Without this reset, ``generate_node`` would enter error-correction mode
    for an unrelated request, and ``routing`` would skip retries entirely.

    Returns:
        State update with ``scheme`` plus ``error`` cleared and ``retry`` = 0.
    """
    messages = state.get("messages")
    db_url = state.get("connection_url")
    user_id = state.get("user_id")
    query = messages[-1].content
    scheme = retrieve(user_id, query, db_url)
    return {"scheme": scheme, "error": None, "retry": 0}
def generate_node(state: State):
    """Ask the LLM for a SQL query answering the most recent user message.

    The system prompt bundles the retrieved schema (``scheme``), a transcript
    of the earlier turns and today's date. When both ``error`` and the
    previous ``sql_query`` are set, it also includes an error-correction
    section quoting the broken query and the database error. The model
    answers through the ``sql_query`` structured-output schema.

    Returns:
        State update with the new ``sql_query`` and ``error`` reset to None.
    """
    convo = state.get("messages")
    schema_context = state.get("scheme")
    last_error = state.get("error")
    last_query = state.get('sql_query')
    sql_writer = llm.with_structured_output(sql_query)

    earlier_turns = convo[:-1]
    latest_request = convo[-1].content
    today = f"{datetime.datetime.now():%Y-%m-%d}"

    # Flatten prior turns into "Role: text" lines so follow-ups such as
    # "filter those by ..." can be resolved against them.
    history_text = (
        "\n".join(
            f"{turn.type.capitalize()}: {turn.content}" for turn in earlier_turns
        )
        if earlier_turns
        else "This is the first user request. No history exists."
    )

    # Correction mode only makes sense when we can show both the failing
    # query and the error it produced.
    error_context = ""
    if last_error and last_query:
        error_context = f"""
=== 🚨 ERROR CORRECTION MODE 🚨 ===
Your previous attempt to answer this request failed.
[PREVIOUS BROKEN QUERY]:
{last_query}
[DATABASE ERROR MESSAGE]:
{last_error}
INSTRUCTION: Analyze the error message and the schema carefully. Fix the syntax, column names, or logic, and generate a CORRECTED query.
"""

    prompt_text = f"""You are an expert Data Analyst and Database Engineer.
Your job is to write highly optimized, perfectly accurate database queries based on user requests.
=== DATABASE SCHEMA & DIALECT ===
Look at the metadata below to identify the targeted database engine dialect and table layout:
{schema_context}
=== CONVERSATION HISTORY ===
Use this previous context to resolve ambiguous terms (e.g., if the user says "filter those by...", look here to see what "those" refers to):
{history_text}
{error_context}
=== CRITICAL RULES ===
1. ALIGNMENT: Only use the tables and columns provided in the schema above. Do not hallucinate column names.
2. DIALECT MATCHING: Look at the 'Dialect:' specified above and write strict queries matching that exact syntax.
3. JOINS: Pay close attention to the FOREIGN KEY constraints provided in the schema to perform accurate JOINs.
4. CURRENT DATE: Today's date is {today}. Use this exact date for any relative time filters (e.g., "last month", "this year").
5. CASE SENSITIVITY: When filtering by strings, use case-insensitive comparisons (e.g., LOWER(column) = LOWER('value')) unless instructed otherwise.
6. SECURITY: NEVER generate DML queries (INSERT, UPDATE, DELETE, DROP). Only generate SELECT statements.
=== OUTPUT SELECTION RULES ===
1. If the user asks WHO / WHICH / WHAT IS THE NAME / identify a person, customer, user, product, company, or entity, return the human-readable name field, not just the ID.
2. If the schema has both an ID column and a name column, prefer selecting the name column in the final output.
3. If the name is in another table, use the required JOIN to fetch it.
4. Only return an ID alone when the user explicitly asks for the ID, or when no name-like field exists in the schema.
5. For count/number questions, return an aggregate numeric result, not a list of rows.
6. For "who/which" questions, do not answer with only identifiers if a readable label exists in the schema.
=== INSTRUCTIONS ===
First, think through the necessary tables, filters, joins, and the exact type of answer expected.
Then, provide the final executable SQL query specifically for the LATEST USER REQUEST."""

    response = sql_writer.invoke([
        SystemMessage(content=prompt_text),
        HumanMessage(content=f"LATEST USER REQUEST:\n{latest_request}"),
    ])
    return {'sql_query': response.generated_sql_query, "error": None}
def _is_read_only_query(query):
    """Return True if *query* looks like a single read-only SELECT/WITH statement.

    Defence in depth for LLM-generated SQL. ``generate_node`` only *asks* the
    model for SELECT statements; nothing in this module enforced that before
    executing. This is a keyword filter, not a SQL parser. Connecting with a
    read-only database user remains the real safeguard.

    Args:
        query: SQL text produced by the model (may be None or empty).

    Returns:
        True when the query passes the filter, False otherwise.
    """
    import re

    if not query:
        return False
    # Blank out single-quoted literals (including '' escapes) so values such
    # as 'delete' or 'a;b' are not read as keywords or statement separators.
    text = re.sub(r"'(?:[^']|'')*'", "''", query)
    # Drop comments; they could otherwise hide or fake keywords.
    text = re.sub(r"--[^\n]*|/\*.*?\*/", " ", text, flags=re.DOTALL).strip()
    # Tolerate one trailing semicolon, but refuse chained statements.
    if text.endswith(";"):
        text = text[:-1]
    if ";" in text:
        return False
    if not re.match(r"\(*\s*(?:SELECT|WITH)\b", text, flags=re.IGNORECASE):
        return False
    # Whole-word match, so columns like created_at / updated_at are not hit.
    forbidden = (
        r"\b(?:INSERT|UPDATE|DELETE|DROP|ALTER|CREATE|TRUNCATE|MERGE|GRANT"
        r"|REVOKE|ATTACH|DETACH|PRAGMA|VACUUM|COPY|CALL|EXEC|EXECUTE|INTO)\b"
    )
    return re.search(forbidden, text, flags=re.IGNORECASE) is None


def execute_node(state: State):
    """Run the generated SQL against the database at ``connection_url``.

    Queries that fail the read-only check are rejected without touching the
    database. Rejections and database errors are reported the same way, so
    ``routing`` can send the request back to ``generate_node`` for correction.

    Returns:
        On success: ``{"query_result": <db.run output>, "error": None, "retry": 0}``.
        On failure: ``{"error": <message>, "retry": <previous count> + 1}``.
    """
    url = state.get("connection_url")
    # Named `query` (not `sql_query`) to avoid shadowing the module-level model.
    query = state.get("sql_query")
    # `or 0` also covers a key that is present but None.
    retry = state.get("retry") or 0
    if not _is_read_only_query(query):
        return {
            "error": "Rejected: only a single read-only SELECT statement may be executed.",
            "retry": retry + 1,
        }
    try:
        db = SQLDatabase.from_uri(url)
        result = db.run(query)
    except Exception as e:
        # Deliberately broad: any connection/SQL error is fed back to the LLM
        # as correction context instead of crashing the graph.
        return {"error": str(e), "retry": retry + 1}
    return {"query_result": result, "error": None, "retry": 0}
def routing(state: State):
    """Choose the node that follows ``execute_node``.

    Retries query generation while the last execution left an ``error`` and
    the failure counter ``retry`` is still below 3. Otherwise moves on to
    the answer.
    """
    failed = state.get("error")
    attempts = state.get('retry', 0)
    if not failed or attempts >= 3:
        return "answer_node"
    return "generate_node"
def answer_node(state: State):
    """Write the user-facing reply from the query result or the final error.

    Builds a system prompt containing the earlier conversation turns. If
    ``error`` is still set after the retry loop, the prompt gets the error
    details and an instruction to apologise. Otherwise it gets the raw query
    result and rules that forbid inventing data.

    Returns:
        State update with the model reply under ``messages`` and its text
        under ``final_result``.
    """
    messages = state.get("messages")
    # A falsy result (missing key, None, or the empty string that db.run
    # yields for an empty result set) becomes an explicit "No records found."
    # so the model is not handed blank data. The old `.get(key, default)`
    # form only covered a missing key.
    query_result = state.get("query_result") or "No records found."
    error = state.get("error")
    history_messages = messages[:-1]
    user_query = messages[-1].content
    if history_messages:
        history_text = "\n".join([
            f"{msg.type.capitalize()}: {msg.content}"
            for msg in history_messages
        ])
    else:
        history_text = "This is the first user request. No history exists."
    system_prompt = f"""You are a helpful Data Analyst communicating directly with a user.
=== CONVERSATION HISTORY ===
Use this to maintain the context and tone of the conversation:
{history_text}
=== EXECUTION CONTEXT ===\n"""
    if error:
        system_prompt += f"""Unfortunately, the database returned an error and the data could not be retrieved.
Error details: {error}
INSTRUCTION: Politely apologize to the user and briefly explain that you encountered a technical issue retrieving their specific request."""
    else:
        system_prompt += f"""The database returned this raw data: {query_result}
INSTRUCTIONS:
1. Answer using ONLY the returned data.
2. Never invent a name, value, or entity that is not present in the result.
3. If the result contains both an ID and a name, use the name in the final answer and mention the ID only if helpful.
4. If the result contains only an ID and the user asked for a name/person/entity, say that the returned data only contains an identifier and no readable name.
5. Do not substitute or guess a name from a customer_id or any other identifier.
6. Do not mention SQL, the database, schemas, or how you got the data.
7. Give a clean, professional, and conversational response."""
    final_msg = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=f"LATEST USER REQUEST:\n{user_query}"),
    ]
    response = llm.invoke(final_msg)
    return {"messages": [response], "final_result": response.content}
# Text-to-SQL pipeline:
#   START -> retrieve_node -> generate_node -> execute_node
#   execute_node -> generate_node (retry, decided by routing) or answer_node
#   answer_node -> END
workflow = StateGraph(State)

workflow.add_node("retrieve_node", retrieve_node)
workflow.add_node("generate_node", generate_node)
workflow.add_node("execute_node", execute_node)
workflow.add_node("answer_node", answer_node)

workflow.add_edge(START, "retrieve_node")
workflow.add_edge("retrieve_node", "generate_node")
workflow.add_edge("generate_node", "execute_node")
workflow.add_conditional_edges(
    "execute_node",
    routing,
    dict(answer_node="answer_node", generate_node="generate_node"),
)
workflow.add_edge("answer_node", END)