QueryMind / src /agent_graph /tool_travel_sqlagent.py
7beshoyarnest's picture
Update src/agent_graph/tool_travel_sqlagent.py
aacd4aa verified
Raw
History Blame Contribute Delete
3.59 kB
from langchain_core.tools import tool
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from operator import itemgetter
from langchain_groq import ChatGroq
from agent_graph.load_tools_config import LoadToolsConfig
TOOLS_CFG = LoadToolsConfig()
class TravelSQLAgentTool:
"""
A tool for interacting with a travel-related SQL database using an LLM (Language Model) to generate and execute SQL queries.
This tool enables users to ask travel-related questions, which are transformed into SQL queries by a language model.
The SQL queries are executed on the provided SQLite database, and the results are processed by the language model to
generate a final answer for the user.
Attributes:
sql_agent_llm (ChatGroq): An instance of a ChatGroq language model used to generate and process SQL queries.
system_role (str): A system prompt template that guides the language model in answering user questions based on SQL query results.
db (SQLDatabase): An instance of the SQL database used to execute queries.
chain (RunnablePassthrough): A chain of operations that creates SQL queries, executes them, and generates a response.
Methods:
__init__: Initializes the TravelSQLAgentTool by setting up the language model, SQL database, and query-answering pipeline.
"""
def __init__(self, llm: str, sqldb_directory: str, llm_temerature: float) -> None:
"""
Initializes the TravelSQLAgentTool with the necessary configurations.
Args:
llm (str): The name of the language model to be used for generating and interpreting SQL queries.
sqldb_directory (str): The directory path where the SQLite database is stored.
llm_temerature (float): The temperature setting for the language model, controlling response randomness.
"""
self.sql_agent_llm = ChatGroq(
model=llm, temperature=llm_temerature)
self.system_role = """Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n
Question: {question}\n
SQL Query: {query}\n
SQL Result: {result}\n
Answer:
"""
self.db = SQLDatabase.from_uri(
f"sqlite:///{sqldb_directory}")
print(self.db.get_usable_table_names())
execute_query = QuerySQLDataBaseTool(db=self.db)
write_query = create_sql_query_chain(
self.sql_agent_llm, self.db)
answer_prompt = PromptTemplate.from_template(
self.system_role)
answer = answer_prompt | self.sql_agent_llm | StrOutputParser()
self.chain = (
RunnablePassthrough.assign(query=write_query).assign(
result=itemgetter("query") | execute_query
)
| answer
)
@tool
def query_travel_sqldb(query: str) -> str:
"""Query the Swiss Airline SQL Database and access all the company's information. Input should be a search query."""
agent = TravelSQLAgentTool(
llm=TOOLS_CFG.travel_sqlagent_llm,
sqldb_directory=TOOLS_CFG.travel_sqldb_directory,
llm_temerature=TOOLS_CFG.travel_sqlagent_llm_temperature
)
response = agent.chain.invoke({"question": query})
return response