# File size: 3,078 Bytes
# 472e1d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
from agents.llms import LLM
from dotenv import load_dotenv
from langchain_community.utilities import SQLDatabase
from utils.consts import DB_PATH
from agents.sql_agent.states import SQLAgentState

# Load environment variables (e.g. from a local .env file) into os.environ at
# import time. This runs before SQLAgent constructs LLM() below; presumably
# LLM() reads its configuration (API keys, model names) from the environment —
# confirm against agents.llms.
load_dotenv()

# NOTE: The commented-out function below is an earlier implementation of this
# agent built on LangChain's SQLDatabaseChain. It is not executed, and
# SQLDatabaseChain is not imported in this module; it is kept for reference only.

# def get_sql_agent():
#     """
#     Initializes a LangChain SQLDatabaseChain for SQLite.
#     """
#     # Load SQLite DB
#     db = SQLDatabase.from_uri(f"sqlite:///{DB_PATH}")
#     # Patch run to strip Markdown fences and log
#     orig_run = db.run
#     def clean_run(query: str, **kwargs) -> str:
#         lines = query.splitlines()
#         if lines and lines[0].strip().startswith("```"):
#             lines = lines[1:]
#         if lines and lines[-1].strip().startswith("```"):
#             lines = lines[:-1]
#         cleaned = "\n".join(lines).strip()
#         print(f"[SQLDatabaseChain] Running SQL: {cleaned}")
#         return orig_run(cleaned, **kwargs)
#     db.run = clean_run
#     # Initialize LLM
#     llm_wrapper = LLM()
#     # Create SQLDatabaseChain
#     chain = SQLDatabaseChain.from_llm(llm_wrapper.chat_model, db, verbose=True)
#     return chain

class SQLAgent:
    """SQL agent that answers questions by running a compiled graph.

    Constructing an instance opens a LangChain ``SQLDatabase`` on the SQLite
    file at ``DB_PATH``, creates an ``LLM`` wrapper, and compiles the graph
    produced by ``agents.sql_agent.graph.build_graph``.

    Attributes:
        db: ``SQLDatabase`` connected to ``sqlite:///{DB_PATH}``.
        llm: ``LLM`` wrapper instance.
        graph: Compiled graph executed by :meth:`run`.
    """

    def __init__(self):
        # NOTE(review): build_graph() is called without arguments, so neither
        # self.db nor self.llm is handed to the graph here — confirm whether
        # the graph nodes obtain their own DB/LLM handles elsewhere.
        database_uri = f"sqlite:///{DB_PATH}"
        self.db = SQLDatabase.from_uri(database_uri)
        self.llm = LLM()
        self.graph = self.build_graph()

    def build_graph(self):
        """Build the SQL agent graph and return its compiled form.

        The graph module is imported inside this method rather than at module
        top level, presumably to avoid a circular import with this package —
        confirm before hoisting it.
        """
        from agents.sql_agent.graph import build_graph as make_uncompiled_graph

        workflow = make_uncompiled_graph()
        return workflow.compile()

    def run(self, state: SQLAgentState) -> SQLAgentState:
        """Invoke the compiled graph once on *state* and return the result.

        Args:
            state: The agent state to start from.

        Returns:
            The state returned by the compiled graph's ``invoke``.
        """
        compiled = self.graph
        return compiled.invoke(state)
    
def make_initial_state(question: str) -> dict:
    """Return a fresh initial state for a single SQL agent run.

    A brand-new dict, including a new nested ``db_info`` dict, is built on
    every call. Nothing a previous run may have stored or mutated in place can
    leak into the next question.

    Args:
        question: The user's natural-language question.

    Returns:
        A state dict with ``question`` set and every other field empty or None.
    """
    return {
        "question": question,
        "db_info": {
            "tables": [],
            "columns": {},
            "schema": None,
        },
        "sql_query": None,
        "sql_result": None,
        "error": None,
    }


if __name__ == "__main__":
    agent = SQLAgent()
    while True:
        try:
            question = input("Enter your query (or 'exit' to quit): ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt ends the session cleanly instead
            # of exiting with a traceback.
            print()
            question = ""
        if not question or question.lower() in ('exit', 'quit'):
            print("Goodbye!")
            break
        state = make_initial_state(question)
        # Execute the graph exactly once per question and print each update
        # as it is produced. Calling agent.run(state) in addition would run
        # every node (LLM calls and SQL queries) a second time.
        for step in agent.graph.stream(state, stream_mode="updates"):
            print(step)