Spaces:
Runtime error
Runtime error
| import os | |
| import re | |
| import getpass | |
| from contextlib import contextmanager | |
| from typing import List | |
| from operator import itemgetter | |
| from sqlalchemy import create_engine, text, inspect | |
| from sqlalchemy.orm import sessionmaker | |
| from dotenv import load_dotenv | |
| from langchain_community.utilities import SQLDatabase | |
| from langchain_openai import ChatOpenAI | |
| from langchain_core.output_parsers.openai_tools import PydanticToolsParser | |
| from langchain.chains import create_sql_query_chain | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_core.runnables import RunnablePassthrough | |
| from langchain_core.pydantic_v1 import BaseModel, Field | |
# Load environment variables from .env file
load_dotenv()


def _ensure_env_var(name, prompt):
    """Make sure environment variable *name* has a non-empty value.

    If it is unset or empty, ask for it with ``getpass`` and store the answer in
    ``os.environ``.

    Args:
        name: Environment variable to check.
        prompt: Text shown when asking the user for the value.

    Returns:
        True if the variable ends up non-empty. False if no value could be
        obtained, either because the user entered nothing or because stdin is
        closed (``getpass`` raises ``EOFError`` in non-interactive runs).
    """
    if os.environ.get(name):
        return True
    try:
        value = getpass.getpass(prompt)
    except EOFError:
        # No interactive stdin (e.g. a hosted container): nothing to read.
        return False
    if not value:
        return False
    os.environ[name] = value
    return True


# The OpenAI key is mandatory: every chain below calls ChatOpenAI.
if not _ensure_env_var("OPENAI_API_KEY", "Enter your OpenAI API key: "):
    raise RuntimeError(
        "OPENAI_API_KEY is not set; export it or add it to a .env file."
    )
# LangSmith tracing is optional: only switch it on when a LangChain key is
# available, instead of tracing with an empty key.
if _ensure_env_var("LANGCHAIN_API_KEY", "Enter your LangChain API key: "):
    os.environ["LANGCHAIN_TRACING_V2"] = "true"
# Setup SQLite Database (chinook.db is expected next to this script)
db_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "chinook.db")
# Connecting to a missing file makes SQLite create a new, empty database, which
# would only surface later as confusing "no such table" errors. Fail fast.
if not os.path.isfile(db_path):
    raise FileNotFoundError(f"SQLite database not found: {db_path}")
engine = create_engine(f"sqlite:///{db_path}")
Session = sessionmaker(bind=engine)
# Share the engine with LangChain rather than building a second one for the
# same file (SQLDatabase.from_uri would call create_engine again).
db = SQLDatabase(engine)
print(db.dialect)
print(db.get_usable_table_names())
# Sanity check: print a few rows to show the database is readable.
with Session() as session:
    result = session.execute(text("SELECT * FROM artists LIMIT 10;")).fetchall()
    print(result)
# Initialize LLM. The model can be overridden with the OPENAI_MODEL environment
# variable; when it is unset or empty the original default model is used.
llm = ChatOpenAI(model=os.environ.get("OPENAI_MODEL") or "gpt-3.5-turbo-0125")
class Table(BaseModel):
    """Table in SQL database."""

    # This model is handed to `llm.bind_tools`, which builds the tool schema from
    # it (class docstring -> tool description, field description -> argument
    # description). Both strings therefore reach the model; edit with care.
    name: str = Field(..., description="Name of table in SQL database.")
# Function to get schema information
def get_schema_info():
    """Map every table name to a list of ``(column_name, column_type)`` pairs.

    The schema is reflected from ``engine`` with SQLAlchemy's inspector. Column
    types are rendered as strings with ``str()``.

    Returns:
        dict[str, list[tuple[str, str]]] keyed by table name, in the order the
        inspector reports the tables.
    """
    schema_inspector = inspect(engine)
    return {
        table_name: [
            (column["name"], str(column["type"]))
            for column in schema_inspector.get_columns(table_name)
        ]
        for table_name in schema_inspector.get_table_names()
    }
# Provide schema info to LLM
schema_info = get_schema_info()


def _describe_table(table, cols):
    """Render one table as a ``Table:`` line followed by its typed ``Columns:`` list."""
    column_list = ", ".join(f"{col_name} ({col_type})" for col_name, col_type in cols)
    return f"Table: {table}\nColumns: {column_list}"


formatted_schema_info = "\n".join(
    _describe_table(table, cols) for table, cols in schema_info.items()
)
# NOTE(review): `system` does not appear to be used by any chain in this file;
# it looks like a leftover alternative prompt. Confirm before relying on it.
system = f"""You are an expert in querying SQL databases. The database schema is as follows:
{formatted_schema_info}
Given an input question, create a syntactically correct SQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite.
You can order the results to return the most informative data in the database. Never query for all columns from a table.
You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist.
Also, pay attention to which column is in which table. Use the following format:
SQLQuery: """
# Table selection: the model is asked to call the `Table` tool once for each
# table that might be relevant, and the parser turns those tool calls back into
# `Table` objects.
table_names = "\n".join(db.get_usable_table_names())
# NOTE(review): this f-string is then parsed again as a prompt template by
# `from_messages`, so a brace in a table name would break it. Chinook table
# names contain none.
system_prompt = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
The tables are:
{table_names}
Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""
_table_selection_messages = [
    ("system", system_prompt),
    ("human", "{input}"),
]
prompt = ChatPromptTemplate.from_messages(_table_selection_messages)
llm_with_tools = llm.bind_tools(tools=[Table])
output_parser = PydanticToolsParser(tools=[Table])
table_chain = prompt | llm_with_tools | output_parser
# Function to get table names from the output
def get_table_names(output: List[Table]) -> List[str]:
    """Extract the ``name`` of each parsed ``Table`` tool call.

    Args:
        output: ``Table`` objects produced by the table-selection chain.

    Returns:
        The table names, in the same order as *output*.
    """
    names: List[str] = []
    for selected in output:
        names.append(selected.name)
    return names
# Create the SQL query chain
query_chain = create_sql_query_chain(llm, db)


def _select_table_names(question):
    """Ask the table-selection chain which tables are relevant to *question*.

    Names returned by the model are matched case-insensitively against the
    database's usable tables. Unknown (hallucinated) names are dropped and
    duplicates are removed, keeping the model's order.

    Returns:
        A list of real table names, or ``None`` when nothing valid remains.
        With ``None`` the query chain's ``table_names_to_use`` falls back to all
        usable tables, rather than receiving names ``SQLDatabase.get_table_info``
        does not know or an empty schema.
    """
    usable = {name.lower(): name for name in db.get_usable_table_names()}
    picked = get_table_names(table_chain.invoke({"input": question}))
    matched = [usable[name.lower()] for name in picked if name.lower() in usable]
    # dict.fromkeys de-duplicates while preserving first-seen order.
    return list(dict.fromkeys(matched)) or None


# Combine table selection and query generation
full_chain = (
    RunnablePassthrough.assign(
        table_names_to_use=lambda x: _select_table_names(x["question"])
    )
    | query_chain
)
# Function to strip markdown formatting from SQL query
def strip_markdown(text):
    """Return the bare SQL statement contained in raw model output *text*.

    Model-generated SQL often arrives wrapped in a Markdown code fence, with or
    without a ``sql``/``sqlite`` tag in any letter case. With the
    ``SQLQuery:``-style prompt used by ``create_sql_query_chain`` it may also
    echo a leading ``SQLQuery:`` label. Either would make SQLite reject the
    statement.

    Args:
        text: Raw model output.

    Returns:
        The statement with fences, a leading ``SQLQuery:`` label and
        surrounding whitespace removed.
    """
    # Replace every fence (opening with optional sql/sqlite tag, or closing),
    # together with the whitespace around it, by a newline. Consuming the tag in
    # the same match stops an opening fence preceded by whitespace from leaving
    # a stray "sql" word behind.
    text = re.sub(r"\s*```(?:sql(?:ite)?)?\s*", "\n", text, flags=re.IGNORECASE)
    text = text.strip()
    # Drop an echoed "SQLQuery:" label from the prompt format.
    text = re.sub(r"^SQLQuery:\s*", "", text, flags=re.IGNORECASE)
    return text.strip()
# Context manager that provides a database session
@contextmanager
def get_db_session():
    """Yield a new SQLAlchemy ``Session`` and always close it afterwards.

    Callers use it as ``with get_db_session() as session:``. The
    ``@contextmanager`` decorator is what makes this generator usable in a
    ``with`` statement; without it, entering the ``with`` raises and no query
    can run.

    Yields:
        A session created from the module-level ``Session`` factory.
    """
    session = Session()
    try:
        yield session
    finally:
        session.close()
def execute_sql_query(query: str) -> str:
    """Execute model-generated SQL and return the fetched rows as text.

    ``strip_markdown`` is applied to *query* first. The connection is switched
    to SQLite's ``query_only`` mode before the statement runs. The SQL is
    produced by an LLM from free-form user input, so any statement that would
    modify the database (INSERT/UPDATE/DELETE/DROP, ...) is rejected by SQLite
    instead of being executed.

    Args:
        query: Raw SQL text produced by the query-generation chain.

    Returns:
        ``str()`` of the fetched rows on success. On any failure, a string
        starting with ``"Error executing query: "``; errors are returned rather
        than raised so the answer prompt can explain them.
    """
    try:
        with get_db_session() as session:
            # NOTE: the pragma stays set on the pooled connection; nothing in
            # this script writes through `Session`, so that is harmless here.
            session.execute(text("PRAGMA query_only = ON"))
            # Strip markdown formatting before executing
            clean_query = strip_markdown(query)
            result = session.execute(text(clean_query)).fetchall()
            return str(result)
    except Exception as e:
        # Deliberately broad: the message is fed back to the LLM, which is asked
        # to explain the error and suggest a correction.
        return f"Error executing query: {str(e)}"
# Create the answer generation prompt: it combines the question, the generated
# SQL and the raw SQL result into a natural-language answer.
_answer_system_template = """Given the following user question, corresponding SQL query, and SQL result, answer the user question.
If there was an error in executing the SQL query, please explain the error and suggest a correction.
Do not include any SQL code formatting or markdown in your response."""
_answer_human_template = "Question: {question}\nSQL Query: {query}\nSQL Result: {result}\nAnswer:"
answer_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", _answer_system_template),
        ("human", _answer_human_template),
    ]
)


def _generate_sql(inputs):
    """Produce SQL for ``inputs["question"]`` with the table-aware query chain."""
    return full_chain.invoke(inputs)


def _run_generated_sql(inputs):
    """Execute the SQL stored under ``inputs["query"]`` and return its result text."""
    return execute_sql_query(inputs["query"])


# Assemble the final chain: question -> SQL -> SQL result -> answer text.
chain = (
    RunnablePassthrough.assign(query=_generate_sql)
    .assign(result=_run_generated_sql)
    | answer_prompt
    | llm
    | StrOutputParser()
)
# Smoke test: push one fixed question through the whole pipeline
def unit_test():
    """Send a sample question through ``chain`` and print the answer.

    Not an isolated unit test: ``chain`` calls the OpenAI chat model and the
    SQLite database.
    """
    print("Running unit test...")
    sample_input = {"question": "How many employees are there?"}
    answer = chain.invoke(sample_input)
    print("Final Answer:", answer)
    print("Unit test completed.")
# Main function
def main():
    """Print the schema, run the smoke test, then answer questions interactively.

    The loop ends when the user types ``quit`` (case-insensitive, surrounding
    whitespace ignored), or when stdin is closed or interrupted. A failure while
    answering one question is reported and the loop continues.
    """
    # Print schema information
    print("Database Schema Information:")
    print(formatted_schema_info)
    # Run unit test
    unit_test()
    # Continuously ask the user for queries until "quit" is entered
    while True:
        try:
            user_question = input("Please enter your query (or type 'quit' to exit): ")
        except (EOFError, KeyboardInterrupt):
            # stdin closed (non-interactive run) or Ctrl-C: exit cleanly instead
            # of crashing with a traceback.
            print("\nExiting the program.")
            break
        user_question = user_question.strip()
        if not user_question:
            # Don't spend an API call on an empty question.
            continue
        if user_question.lower() == "quit":
            print("Exiting the program.")
            break
        # Process user's query; keep the session alive if one question fails.
        try:
            response = chain.invoke({"question": user_question})
        except Exception as exc:
            print(f"Error while answering the question: {exc}")
            continue
        print("Final Answer:", response)
# Start the interactive session only when run as a script, not when imported.
if __name__ == "__main__":
    main()