File size: 8,031 Bytes
52adb86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
from typing import TypedDict , Annotated , List , Optional
from langgraph.graph.message import add_messages 
from langchain_core.messages import SystemMessage , HumanMessage
from langchain_openai import ChatOpenAI
from src.retrieval import retrieve
import os
from dotenv import load_dotenv
from langgraph.graph import StateGraph, START ,END
from pydantic import BaseModel , Field
import datetime
from langchain_community.utilities import SQLDatabase

# Load settings from a .env file into the environment before the os.getenv
# lookup below that supplies the model's API key.
load_dotenv()

class State(TypedDict):
    """Shared state passed between the nodes of the text-to-SQL graph."""

    # Connection string of the database the generated SQL is run against.
    connection_url: str
    # Identifier of the requesting user; forwarded to schema retrieval.
    user_id: str
    # Conversation messages, merged through the add_messages reducer.
    messages: Annotated[List, add_messages]
    # Schema context produced by the retrieval step.
    scheme: str
    # Most recently generated SQL text.
    sql_query: str
    # Output of the last successful query execution.
    query_result: str
    # Text of the last execution error, or None once cleared.
    error: Optional[str]
    # Count of consecutive failed executions.
    retry: int
    # Content of the final assistant reply.
    final_result: str


# Shared chat model used for both SQL generation and answer phrasing.
# Requests go to the OpenRouter endpoint and authenticate with the
# OPENROUTER_API_KEY environment variable.
llm = ChatOpenAI(
    temperature=0,
    model="openai/gpt-4o-mini",
    openai_api_base="https://openrouter.ai/api/v1",
    openai_api_key=os.getenv("OPENROUTER_API_KEY"),
)

class sql_query(BaseModel):
    # Structured-output schema for SQL generation: one required string field
    # holding the query text.
    # NOTE(review): documented with comments instead of a class docstring so
    # the schema handed to the model stays as it was -- confirm whether a
    # docstring would be forwarded as the schema description.
    generated_sql_query: str = Field(
        ...,
        description="The raw, valid executable SQL query text. Contain absolutely NO markdown wrapping, code blocks, or conversational formatting.",
    )

def retrieve_node(state: State):
    """Fetch schema context for the latest user message.

    Reads ``user_id``, ``connection_url`` and the last entry of ``messages``
    from the state, hands them to ``retrieve`` and stores whatever it returns
    under the ``scheme`` key.
    """
    latest_message = state.get("messages")[-1]
    schema_context = retrieve(
        state.get("user_id"),
        latest_message.content,
        state.get("connection_url"),
    )
    return {"scheme": schema_context}

def generate_node(state: State):
    """Generate a SQL query for the latest user message.

    Builds a system prompt from the retrieved schema, the earlier
    conversation and today's date, then asks the model for a structured
    ``sql_query`` response. When the state carries both an ``error`` and a
    previous ``sql_query``, an error-correction section quoting them is
    inserted into the prompt so the model can repair the failed attempt.

    Returns:
        A state update holding the new ``sql_query`` and ``error`` reset to
        ``None``.
    """
    messages = state.get("messages")
    scheme = state.get("scheme")
    error = state.get("error")
    wrong_query = state.get('sql_query')

    llm_with_structured_output = llm.with_structured_output(sql_query)

    # Everything before the last message is history; the last one is the request.
    history_messages = messages[:-1]
    current_query_string = messages[-1].content

    current_date = datetime.datetime.now().strftime("%Y-%m-%d")

    if history_messages:
        history_text = "\n".join(
            f"{msg.type.capitalize()}: {msg.content}"
            for msg in history_messages
        )
    else:
        history_text = "This is the first user request. No history exists."

    if error and wrong_query:
        error_context = f"""
=== 🚨 ERROR CORRECTION MODE 🚨 ===
Your previous attempt to answer this request failed.
[PREVIOUS BROKEN QUERY]: 
{wrong_query}

[DATABASE ERROR MESSAGE]: 
{error}

INSTRUCTION: Analyze the error message and the schema carefully. Fix the syntax, column names, or logic, and generate a CORRECTED query.
"""
    else:
        error_context = ""

    # The prompt must be built on BOTH paths. It is the only place where
    # error_context is used, so nesting it under the else branch would leave
    # system_prompt unbound exactly when a retry needs it.
    system_prompt = SystemMessage(content=f"""You are an expert Data Analyst and Database Engineer. 
Your job is to write highly optimized, perfectly accurate database queries based on user requests.

=== DATABASE SCHEMA & DIALECT ===
Look at the metadata below to identify the targeted database engine dialect and table layout:
{scheme}

=== CONVERSATION HISTORY ===
Use this previous context to resolve ambiguous terms (e.g., if the user says "filter those by...", look here to see what "those" refers to):
{history_text}
{error_context}

=== CRITICAL RULES ===
1. ALIGNMENT: Only use the tables and columns provided in the schema above. Do not hallucinate column names.
2. DIALECT MATCHING: Look at the 'Dialect:' specified above and write strict queries matching that exact syntax. 
3. JOINS: Pay close attention to the FOREIGN KEY constraints provided in the schema to perform accurate JOINs.
4. CURRENT DATE: Today's date is {current_date}. Use this exact date for any relative time filters (e.g., "last month", "this year").
5. CASE SENSITIVITY: When filtering by strings, use case-insensitive comparisons (e.g., LOWER(column) = LOWER('value')) unless instructed otherwise.
6. SECURITY: NEVER generate DML queries (INSERT, UPDATE, DELETE, DROP). Only generate SELECT statements.

=== OUTPUT SELECTION RULES ===
1. If the user asks WHO / WHICH / WHAT IS THE NAME / identify a person, customer, user, product, company, or entity, return the human-readable name field, not just the ID.
2. If the schema has both an ID column and a name column, prefer selecting the name column in the final output.
3. If the name is in another table, use the required JOIN to fetch it.
4. Only return an ID alone when the user explicitly asks for the ID, or when no name-like field exists in the schema.
5. For count/number questions, return an aggregate numeric result, not a list of rows.
6. For "who/which" questions, do not answer with only identifiers if a readable label exists in the schema.

=== INSTRUCTIONS ===
First, think through the necessary tables, filters, joins, and the exact type of answer expected.
Then, provide the final executable SQL query specifically for the LATEST USER REQUEST.""")

    final_msg = [
        system_prompt,
        HumanMessage(content=f"LATEST USER REQUEST:\n{current_query_string}")
    ]

    response = llm_with_structured_output.invoke(final_msg)

    # Clear the error so the next execution result is judged on its own.
    return {'sql_query': response.generated_sql_query, "error": None}

def execute_node(state: State):
    """Run the generated SQL against the target database.

    On success returns the query output under ``query_result`` with ``error``
    cleared and ``retry`` reset to 0. If anything raises, returns the
    exception text under ``error`` and ``retry`` incremented by one.
    """
    failures_so_far = state.get("retry", 0)
    connection_url = state.get("connection_url")
    statement = state.get("sql_query")

    # NOTE(review): the statement is executed exactly as generated; nothing
    # in this function enforces that it is read-only.
    try:
        database = SQLDatabase.from_uri(connection_url)
        rows = database.run(statement)
    except Exception as exc:
        # Hand the failure back as state so the router can trigger a retry.
        return {"error": str(exc), "retry": failures_so_far + 1}

    return {"query_result": rows, "error": None, "retry": 0}
    
def routing(state: State):
    """Choose the node that follows query execution.

    Returns ``"generate_node"`` while an error is present and fewer than
    three retries have been recorded; otherwise returns ``"answer_node"``.
    """
    attempts = state.get("retry", 0)
    should_regenerate = bool(state.get("error")) and attempts < 3
    return "generate_node" if should_regenerate else "answer_node"
    
def answer_node(state: State):
    """Turn the query outcome into a conversational reply.

    Builds a system prompt from the earlier conversation plus either the
    execution error (asking the model to apologise) or the query result
    (asking the model to answer strictly from that data), then invokes the
    chat model with the latest user message.

    Returns:
        A state update that appends the model response to ``messages`` and
        stores its text under ``final_result``.
    """
    messages = state.get("messages")
    # Fall back on ANY empty result, not only a missing key: execute_node
    # always writes query_result on success, so a .get default would never
    # apply and an empty result would reach the prompt as blank "raw data".
    query_result = state.get("query_result") or "No records found."
    error = state.get("error")

    # Everything before the last message is history; the last one is the request.
    history_messages = messages[:-1]
    user_query = messages[-1].content

    if history_messages:
        history_text = "\n".join(
            f"{msg.type.capitalize()}: {msg.content}"
            for msg in history_messages
        )
    else:
        history_text = "This is the first user request. No history exists."

    system_prompt = f"""You are a helpful Data Analyst communicating directly with a user.

=== CONVERSATION HISTORY ===
Use this to maintain the context and tone of the conversation:
{history_text}

=== EXECUTION CONTEXT ===\n"""

    if error:
        system_prompt += f"""Unfortunately, the database returned an error and the data could not be retrieved. 
Error details: {error}
INSTRUCTION: Politely apologize to the user and briefly explain that you encountered a technical issue retrieving their specific request."""
    else:
        system_prompt += f"""The database returned this raw data: {query_result}

INSTRUCTIONS:
1. Answer using ONLY the returned data.
2. Never invent a name, value, or entity that is not present in the result.
3. If the result contains both an ID and a name, use the name in the final answer and mention the ID only if helpful.
4. If the result contains only an ID and the user asked for a name/person/entity, say that the returned data only contains an identifier and no readable name.
5. Do not substitute or guess a name from a customer_id or any other identifier.
6. Do not mention SQL, the database, schemas, or how you got the data.
7. Give a clean, professional, and conversational response."""

    final_msg = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=f"LATEST USER REQUEST:\n{user_query}")
    ]

    response = llm.invoke(final_msg)

    return {"messages": [response], "final_result": response.content}

# Graph wiring: retrieve -> generate -> execute, after which routing() either
# loops back to generate (retry) or moves on to answer.
workflow = StateGraph(State)

workflow.add_node("retrieve_node", retrieve_node)
workflow.add_node("generate_node", generate_node)
workflow.add_node("execute_node", execute_node)
workflow.add_node("answer_node", answer_node)

# Fixed edges of the pipeline.
workflow.add_edge(START, "retrieve_node")
workflow.add_edge("retrieve_node", "generate_node")
workflow.add_edge("generate_node", "execute_node")

# routing() returns the name of the next node; the map sends each returned
# name to the node of the same name.
workflow.add_conditional_edges(
    "execute_node",
    routing,
    {
        "answer_node": "answer_node",
        "generate_node": "generate_node",
    },
)
workflow.add_edge("answer_node", END)