Nam Fam
add files
472e1d4
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
from agents.llms import LLM
from dotenv import load_dotenv
from langchain_community.utilities import SQLDatabase
from utils.consts import DB_PATH
from agents.sql_agent.states import SQLAgentState
# Load environment variables from a local .env file into os.environ (python-dotenv).
# Presumably this supplies API keys / settings read by LLM() -- confirm in agents.llms.
load_dotenv()
# Legacy helper kept for reference only; apparently superseded by the SQLAgent
# class below. NOTE: SQLDatabaseChain (used at the end of this helper) is not
# imported in this module, so re-enabling the helper also requires that import.
# def get_sql_agent():
# """
# Initializes a LangChain SQLDatabaseChain for SQLite.
# """
# # Load SQLite DB
# db = SQLDatabase.from_uri(f"sqlite:///{DB_PATH}")
# # Patch run to strip Markdown fences and log
# orig_run = db.run
# def clean_run(query: str, **kwargs) -> str:
# lines = query.splitlines()
# if lines and lines[0].strip().startswith("```"):
# lines = lines[1:]
# if lines and lines[-1].strip().startswith("```"):
# lines = lines[:-1]
# cleaned = "\n".join(lines).strip()
# print(f"[SQLDatabaseChain] Running SQL: {cleaned}")
# return orig_run(cleaned, **kwargs)
# db.run = clean_run
# # Initialize LLM
# llm_wrapper = LLM()
# # Create SQLDatabaseChain
# chain = SQLDatabaseChain.from_llm(llm_wrapper.chat_model, db, verbose=True)
# return chain
class SQLAgent:
    """Bundle the SQL-agent workflow with its database and LLM handles.

    Construction opens the SQLite database at ``DB_PATH``, creates an ``LLM``
    wrapper, and compiles the workflow returned by
    ``agents.sql_agent.graph.build_graph``. The compiled object is driven via
    ``invoke`` / ``stream`` (presumably a LangGraph graph -- confirm in
    ``agents/sql_agent/graph.py``).

    NOTE(review): ``self.db`` and ``self.llm`` are not handed to the graph
    here; verify whether the graph's nodes rely on them or build their own.
    """

    def __init__(self):
        # Construction order is significant (each step may perform I/O):
        # database first, then the LLM wrapper, then the compiled graph.
        self.db = SQLDatabase.from_uri(f"sqlite:///{DB_PATH}")
        self.llm = LLM()
        self.graph = self.build_graph()

    def build_graph(self):
        """Create and compile the agent workflow.

        The graph module is imported lazily inside this method, presumably to
        avoid a circular import with ``agents.sql_agent.graph``.

        Returns:
            The compiled graph object.
        """
        from agents.sql_agent.graph import build_graph as make_workflow

        workflow = make_workflow()
        compiled = workflow.compile()
        return compiled

    def run(self, state: SQLAgentState) -> SQLAgentState:
        """Execute the compiled graph once on ``state`` and return its output.

        Args:
            state: Input state for the graph; the user's question is stored
                under its ``question`` key.

        Returns:
            Whatever ``self.graph.invoke`` produces for that state.
        """
        outcome = self.graph.invoke(state)
        return outcome
def initial_state(question):
    """Return a fresh input state for one run of the SQL agent graph.

    A new dict is built on every call so nothing from one question's run can
    leak into the next. Presumably these keys match ``SQLAgentState`` --
    confirm against ``agents/sql_agent/states.py``.

    Args:
        question: The user's natural-language question.

    Returns:
        A new state dict with ``question`` set and every other field empty.
    """
    return {
        "question": question,
        "db_info": {
            "tables": [],
            "columns": {},
            "schema": None,
        },
        "sql_query": None,
        "sql_result": None,
        "error": None,
    }


if __name__ == "__main__":
    agent = SQLAgent()
    while True:
        try:
            question = input("Enter your query (or 'exit' to quit): ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C / end of piped input: leave cleanly instead of
            # crashing with a traceback.
            print()
            question = ""
        if not question or question.lower() in ('exit', 'quit'):
            print("Goodbye!")
            break
        # Run the graph exactly once per question, printing each node's update
        # as it arrives. Invoking it and then streaming it again would execute
        # every LLM/SQL step twice.
        for step in agent.graph.stream(
            initial_state(question), stream_mode="updates"
        ):
            print(step)