Spaces:
Sleeping
Sleeping
| from langchain_core.tools import tool | |
| from langchain_community.utilities import SQLDatabase | |
| from langchain.chains import create_sql_query_chain | |
| from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool | |
| from langchain_core.prompts import PromptTemplate | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.runnables import RunnablePassthrough | |
| from operator import itemgetter | |
| from langchain_groq import ChatGroq | |
| from agent_graph.load_tools_config import LoadToolsConfig | |
| TOOLS_CFG = LoadToolsConfig() | |
| class TravelSQLAgentTool: | |
| """ | |
| A tool for interacting with a travel-related SQL database using an LLM (Language Model) to generate and execute SQL queries. | |
| This tool enables users to ask travel-related questions, which are transformed into SQL queries by a language model. | |
| The SQL queries are executed on the provided SQLite database, and the results are processed by the language model to | |
| generate a final answer for the user. | |
| Attributes: | |
| sql_agent_llm (ChatGroq): An instance of a ChatGroq language model used to generate and process SQL queries. | |
| system_role (str): A system prompt template that guides the language model in answering user questions based on SQL query results. | |
| db (SQLDatabase): An instance of the SQL database used to execute queries. | |
| chain (RunnablePassthrough): A chain of operations that creates SQL queries, executes them, and generates a response. | |
| Methods: | |
| __init__: Initializes the TravelSQLAgentTool by setting up the language model, SQL database, and query-answering pipeline. | |
| """ | |
| def __init__(self, llm: str, sqldb_directory: str, llm_temerature: float) -> None: | |
| """ | |
| Initializes the TravelSQLAgentTool with the necessary configurations. | |
| Args: | |
| llm (str): The name of the language model to be used for generating and interpreting SQL queries. | |
| sqldb_directory (str): The directory path where the SQLite database is stored. | |
| llm_temerature (float): The temperature setting for the language model, controlling response randomness. | |
| """ | |
| self.sql_agent_llm = ChatGroq( | |
| model=llm, temperature=llm_temerature) | |
| self.system_role = """Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n | |
| Question: {question}\n | |
| SQL Query: {query}\n | |
| SQL Result: {result}\n | |
| Answer: | |
| """ | |
| self.db = SQLDatabase.from_uri( | |
| f"sqlite:///{sqldb_directory}") | |
| print(self.db.get_usable_table_names()) | |
| execute_query = QuerySQLDataBaseTool(db=self.db) | |
| write_query = create_sql_query_chain( | |
| self.sql_agent_llm, self.db) | |
| answer_prompt = PromptTemplate.from_template( | |
| self.system_role) | |
| answer = answer_prompt | self.sql_agent_llm | StrOutputParser() | |
| self.chain = ( | |
| RunnablePassthrough.assign(query=write_query).assign( | |
| result=itemgetter("query") | execute_query | |
| ) | |
| | answer | |
| ) | |
| def query_travel_sqldb(query: str) -> str: | |
| """Query the Swiss Airline SQL Database and access all the company's information. Input should be a search query.""" | |
| agent = TravelSQLAgentTool( | |
| llm=TOOLS_CFG.travel_sqlagent_llm, | |
| sqldb_directory=TOOLS_CFG.travel_sqldb_directory, | |
| llm_temerature=TOOLS_CFG.travel_sqlagent_llm_temperature | |
| ) | |
| response = agent.chain.invoke({"question": query}) | |
| return response | |