# %%
import os
from operator import itemgetter

from dotenv import load_dotenv
from langchain.chains import create_sql_query_chain
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_community.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_groq import ChatGroq
from pyprojroot import here

# Copy variables from a local .env file into the process environment.
load_dotenv()

# %% [markdown]
# **Set the environment variables and load the LLM**

# %%
# Model served through Groq. Alternatives tried earlier in this notebook:
# "llama3-8b-8192", "mixtral-8x7b-32768".
GROQ_MODEL = "openai/gpt-oss-120b"

# os.environ only accepts str values, so writing a missing key's None back
# into it fails with an opaque TypeError. Check explicitly and fail with an
# actionable message instead.
if not os.getenv("GROQ_API_KEY"):
    raise RuntimeError(
        "GROQ_API_KEY is not set. Add it to your .env file or export it "
        "before running this notebook."
    )

llm = ChatGroq(model=GROQ_MODEL)

# %% [markdown]
# **Load and test the sqlite db**

# %%
# `here` resolves the path relative to the project root detected by pyprojroot.
sqldb_directory = here("data/Chinook.db")
db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")

print(db.dialect)
print(db.get_usable_table_names())
# Sanity check: list the user tables straight from SQLite's catalog.
db.run(""" SELECT name
FROM sqlite_master
WHERE type='table'
AND name NOT LIKE 'sqlite_%'; """)

# %% [markdown]
# **Create the SQL agent chain and run a test query**

# %%
# Prompt for the final answering step; it receives the user question, the
# generated SQL and the raw SQL result.
system_role = """Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n
 Question: {question}\n
 SQL Query: {query}\n
 SQL Result: {result}\n
 Answer:
 """

# Step 2 of the pipeline: run a SQL string against `db`.
execute_query = QuerySQLDataBaseTool(db=db)
# Step 1 of the pipeline: let the LLM turn {"question": ...} into SQL.
write_query = create_sql_query_chain(llm, db)
# Step 3 of the pipeline: phrase the answer from question + query + result.
answer_prompt = PromptTemplate.from_template(system_role)
answer = answer_prompt | llm | StrOutputParser()

# Input: {"question": str}. `query` and `result` are added to that dict in
# turn, and the enriched dict is handed to the answering prompt.
chain = (
    RunnablePassthrough.assign(query=write_query).assign(
        result=itemgetter("query") | execute_query
    )
    | answer
)
# %%
message = "How many tables do I have in the database? and what are their names?"
response = chain.invoke({"question": message})
response

# %% [markdown]
# **Travel SQL-agent Tool Design**

# %%
from operator import itemgetter

from langchain.chains import create_sql_query_chain
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_community.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.tools import tool
from langchain_groq import ChatGroq
# NOTE(review): ChatOpenAI is not used by the class below (it builds a
# ChatGroq client); the import is kept only so nothing relying on it breaks.
from langchain_openai import ChatOpenAI


class TravelSQLAgentTool:
    """
    Question-answering pipeline over a SQLite database, driven by a Groq-hosted LLM.

    A user question is turned into a SQL query by the language model, the query
    is executed against the configured SQLite database, and the model then
    phrases a final answer from the question, the query and the query result.

    Attributes:
        sql_agent_llm (ChatGroq): Chat model used both to write the SQL query and to phrase the final answer.
        system_role (str): Prompt template for the answering step; expects `question`, `query` and `result`.
        db (SQLDatabase): Handle on the SQLite database the queries run against.
        chain: Runnable pipeline; invoke it with `{"question": <str>}` to get the answer as a string.

    Methods:
        __init__: Builds the language model client, opens the database and assembles the pipeline.
    """

    def __init__(
        self,
        llm: str,
        sqldb_directory: str,
        llm_temperature: float | None = None,
        *,
        llm_temerature: float | None = None,
    ) -> None:
        """
        Build the LLM client, connect to the database and assemble the query/answer chain.

        Args:
            llm (str): Name of the Groq model used to generate and interpret SQL queries.
            sqldb_directory (str): Filesystem path of the SQLite database file.
            llm_temperature (float): Sampling temperature passed to the model. Required.
            llm_temerature (float): Deprecated misspelling of `llm_temperature`, still
                accepted as a keyword so existing callers keep working.

        Raises:
            TypeError: If no temperature is given, or if both spellings are given.
        """
        # The parameter used to be spelled `llm_temerature`; callers using the
        # correct spelling crashed with "unexpected keyword argument". Accept both.
        if llm_temerature is not None:
            if llm_temperature is not None:
                raise TypeError(
                    "pass either 'llm_temperature' or the deprecated "
                    "'llm_temerature', not both"
                )
            llm_temperature = llm_temerature
        if llm_temperature is None:
            raise TypeError(
                "TravelSQLAgentTool() missing required argument: 'llm_temperature'"
            )

        self.sql_agent_llm = ChatGroq(model=llm, temperature=llm_temperature)
        self.system_role = """Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n
 Question: {question}\n
 SQL Query: {query}\n
 SQL Result: {result}\n
 Answer:
 """
        self.db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
        # Prints the visible table names each time an instance is created.
        print(self.db.get_usable_table_names())

        # Runs a SQL string against the database.
        execute_query = QuerySQLDataBaseTool(db=self.db)
        # Turns {"question": ...} into a SQL string using the LLM.
        write_query = create_sql_query_chain(self.sql_agent_llm, self.db)
        answer_prompt = PromptTemplate.from_template(self.system_role)
        answer = answer_prompt | self.sql_agent_llm | StrOutputParser()

        # `query` and `result` are added to the input dict in turn, then the
        # enriched dict is rendered through the answering prompt.
        self.chain = (
            RunnablePassthrough.assign(query=write_query).assign(
                result=itemgetter("query") | execute_query
            )
            | answer
        )
# %%
import os
import sys

from pyprojroot import here

# Make the project's `src` package importable from this notebook.
# The previous hard-coded Windows path was machine-specific and contained
# invalid escape sequences ('\e', '\Q') in a non-raw string literal.
# Set QUERYMIND_ROOT to override the auto-detected location.
# NOTE(review): assumes `src/` sits directly under the root that pyprojroot
# detects (the same root used for data/Chinook.db) - confirm.
PROJECT_ROOT = os.path.abspath(os.environ.get("QUERYMIND_ROOT") or str(here()))
if PROJECT_ROOT not in sys.path:  # keep the cell idempotent on re-run
    sys.path.insert(0, PROJECT_ROOT)

from src.agent_graph.load_tools_config import LoadToolsConfig

# %%
# Tool settings (model name, database path, temperature) come from the
# project's configuration loader.
TOOLS_CFG = LoadToolsConfig()


@tool
def query_travel_sqldb(query: str) -> str:
    """Query the Swiss Airline SQL Database and access all the company's information. Input should be a search query."""
    # A new agent (LLM client + database handle + chain) is built on every call.
    # Arguments are passed positionally (model, database path, temperature) so
    # the call does not depend on the spelling of the temperature parameter in
    # TravelSQLAgentTool.__init__.
    agent = TravelSQLAgentTool(
        TOOLS_CFG.travel_sqlagent_llm,
        TOOLS_CFG.travel_sqldb_directory,
        TOOLS_CFG.travel_sqlagent_llm_temperature,
    )
    return agent.chain.invoke({"question": query})