Spaces:
Sleeping
Sleeping
| import logging | |
| import json | |
| import os | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_core.example_selectors import SemanticSimilarityExampleSelector | |
| from langchain_openai import OpenAIEmbeddings, ChatOpenAI | |
| from langchain_community.agent_toolkits import create_sql_agent | |
| from langchain_core.prompts import ( | |
| ChatPromptTemplate, | |
| FewShotPromptTemplate, | |
| MessagesPlaceholder, | |
| PromptTemplate, | |
| SystemMessagePromptTemplate, | |
| ) | |
| from langchain_community.utilities import SQLDatabase | |
| from dotenv import load_dotenv | |
load_dotenv(".env")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Save this module's log records to extractor.log as well. The handler must be
# attached to the logger; merely constructing it writes nothing to the file.
handler = logging.FileHandler('extractor.log')
logger.addHandler(handler)


def _copy_env(target: str, source: str) -> None:
    """Set ``os.environ[target]`` from environment variable *source* if it is defined.

    ``os.environ`` only accepts string values, so assigning ``os.getenv(source)``
    directly would raise ``TypeError`` at import time whenever *source* is unset.
    Unset variables are therefore skipped.
    """
    value = os.getenv(source)
    if value is not None:
        os.environ[target] = value


_copy_env("OPENAI_API_KEY", "OPENAI_API_KEY")
if os.getenv('LANGSMITH'):
    # Configure LangSmith tracing through the LANGCHAIN_* variables.
    os.environ['LANGCHAIN_TRACING_V2'] = 'true'
    os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'
    _copy_env('LANGCHAIN_API_KEY', 'LANGSMITH_API_KEY')
    _copy_env('LANGCHAIN_PROJECT', 'LANGSMITH_PROJECT')
def load_json(file_path: str) -> "dict | list":
    """Read and parse the JSON document at *file_path*.

    Args:
        file_path: Path to a UTF-8 encoded JSON file.

    Returns:
        The parsed JSON value. SqlChain passes it straight to
        ``SemanticSimilarityExampleSelector.from_examples`` as the example list.
    """
    # Explicit encoding: the platform default is not guaranteed to be UTF-8.
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)
class SqlChain:
    """Question-answering agent over a SQL database with few-shot example retrieval.

    Wraps a LangChain SQL agent (``create_sql_agent``) and a semantic-similarity
    selector over few-shot examples loaded from a JSON file. Each example must
    provide the ``input`` and ``query`` keys used by the example template.
    """

    def __init__(self, few_shot_prompts: str, llm_model="gpt-3.5-turbo", db_uri="sqlite:///data/games.db",
                 few_shot_k=2):
        """Build the LLM, database handle, few-shot selector and SQL agent.

        Args:
            few_shot_prompts: Path to the JSON file of few-shot examples.
            llm_model: OpenAI chat model name.
            db_uri: SQLAlchemy database URI.
            few_shot_k: Number of similar examples retrieved per question.
        """
        self.llm = ChatOpenAI(model=llm_model, temperature=0)
        self.db = SQLDatabase.from_uri(db_uri)
        self.few_shot_k = few_shot_k
        self.few_shot = self._set_up_few_shot_prompts(load_json(few_shot_prompts))
        # Rebuilt per question by few_prompt_construct(). It is still None here,
        # so the agent is created without a custom prompt. The rendered few-shot
        # prompt is instead sent later as the agent's "input".
        self.full_prompt = None
        self.agent = create_sql_agent(
            llm=self.llm,
            db=self.db,
            prompt=self.full_prompt,
            max_iterations=10,
            verbose=True,
            agent_type="openai-tools",
            # Row limit given to create_sql_agent. The per-question few-shot
            # prompt states its own limit (few_prompt_construct's top_k).
            top_k=30,
        )

    def _set_up_few_shot_prompts(self, few_shot_prompts: "list[dict]") -> "SemanticSimilarityExampleSelector":
        """Index the few-shot examples for similarity search on their ``input`` key.

        Args:
            few_shot_prompts: Example dicts loaded from the few-shot JSON file.

        Returns:
            A selector that returns the ``self.few_shot_k`` examples closest to a question.
        """
        few_shots = SemanticSimilarityExampleSelector.from_examples(
            few_shot_prompts,
            OpenAIEmbeddings(),
            FAISS,
            k=self.few_shot_k,
            input_keys=["input"],
        )
        return few_shots

    def few_prompt_construct(self, query: str, top_k=5, dialect="SQLite") -> None:
        """Render the few-shot chat prompt for *query* into ``self.full_prompt``.

        Args:
            query: The user's question; it also drives example selection.
            top_k: Row limit stated in the system instructions.
            dialect: SQL dialect name stated in the system instructions.
        """
        system_prefix = """You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
ALWAYS query the database before returning an answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
If the question does not seem related to the database, just return 'I don't know' as the answer.
DO NOT include information that is not present in the database in your answer.
Here are some examples of user inputs and their corresponding SQL queries. They are tested and works.
Use them as a guide when creating your own queries:"""
        SUFFIX = """Begin!
Question: {input}
Thought: I should look at the examples provided and see if I can use them to identify tables and how to build the query.
Then I should query the schema of the most relevant tables.
I will not stop until I query the database and return the answer.
{agent_scratchpad}"""
        # The selected examples are inserted between the prefix and the suffix.
        few_shot_prompt = FewShotPromptTemplate(
            example_selector=self.few_shot,
            example_prompt=PromptTemplate.from_template(
                "User input: {input}\nSQL query: {query}"
            ),
            input_variables=["input", "dialect", "top_k"],
            prefix=system_prefix,
            suffix=SUFFIX,
        )
        full_prompt = ChatPromptTemplate.from_messages(
            [
                SystemMessagePromptTemplate(prompt=few_shot_prompt),
                ("human", "{input}"),
                MessagesPlaceholder("agent_scratchpad"),
            ]
        )
        # NOTE(review): agent_scratchpad=[] is also substituted into the
        # "{agent_scratchpad}" at the end of SUFFIX in the system text — confirm intended.
        self.full_prompt = full_prompt.invoke(
            {
                "input": query,
                "top_k": top_k,
                "dialect": dialect,
                "agent_scratchpad": [],
            }
        )

    def prompt_no_few_shot(self, query: str, dialect="SQLite", top_k=5) -> str:
        """Return the plain system instructions followed by *query*, without examples.

        Args:
            query: The user's question, appended after the instructions.
            dialect: SQL dialect name substituted into the instructions.
            top_k: Row limit substituted into the instructions.
        """
        system_prefix = """You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
If the question does not seem related to the database, just return 'I don't know' as the answer.
DO NOT include information that is not present in the database in your answer."""
        # Fill the placeholders. Previously "{dialect}" and "{top_k}" reached the
        # LLM verbatim and the dialect argument was ignored.
        instructions = system_prefix.format(dialect=dialect, top_k=top_k)
        return f"{instructions}\n{query}"

    def _retrieved_examples(self, query: str) -> "list[str]":
        """Return the few-shot examples selected for *query*, one string per example.

        Each element is the example text after its "User input: " marker
        (``"<input>\\nSQL query: <query>"`` plus any separator that follows).
        """
        self.few_prompt_construct(query)
        text = self.full_prompt.messages[0].content
        # Keep only what follows the fixed system prefix...
        text = text.split("Use them as a guide when creating your own queries:\n\n")[1]
        # ...and drop the suffix. Splitting on "\n\nBegin!" (not "\n\nBegin!\n\n")
        # also works when the suffix has no blank line right after "Begin!".
        text = text.split("\n\nBegin!")[0]
        # The piece before the first "User input: " marker is empty; drop it.
        return text.split("User input: ")[1:]

    def ask(self, query: str, few_prompt: bool = True, rag_test=False) -> "tuple | list[str]":
        """Answer *query* with the SQL agent.

        Args:
            query: The user's question.
            few_prompt: If true, send the rendered few-shot prompt; otherwise
                send the plain instructions from ``prompt_no_few_shot``.
            rag_test: If true, skip the agent and only return the few-shot
                examples retrieved for *query*.

        Returns:
            The retrieved example strings when *rag_test* is set; otherwise a
            tuple ``(agent_result, prompt_sent)``.
        """
        if rag_test:
            return self._retrieved_examples(query)
        if few_prompt:
            self.few_prompt_construct(query)
            # NOTE(review): the rendered prompt value itself is passed as "input";
            # check how the agent stringifies it.
            return self.agent.invoke({"input": self.full_prompt}), self.full_prompt
        prompt = self.prompt_no_few_shot(query)
        return self.agent.invoke(prompt), prompt
def create_agent(few_shot_prompts: str = "src/conf/sqls.json", llm_model="gpt-3.5-turbo-0125",
                 db_uri="config", few_shot_k=2):
    """Create a SqlChain agent; call it with ``agent.ask(prompt)``.

    Args:
        few_shot_prompts: Path to the JSON file of few-shot examples.
        llm_model: OpenAI chat model name.
        db_uri: SQLAlchemy URI, or the sentinel ``"config"`` to build a SQLite
            URI from the ``DATABASE_PATH`` environment variable.
        few_shot_k: Number of similar examples retrieved per question.

    Raises:
        ValueError: If ``db_uri == "config"`` and ``DATABASE_PATH`` is unset or empty.
    """
    if db_uri == "config":
        db_path = os.getenv('DATABASE_PATH')
        if not db_path:
            # Without this check the URI silently became "sqlite:///None".
            raise ValueError(
                "DATABASE_PATH is not set; define it or pass db_uri explicitly"
            )
        db_uri = f"sqlite:///{db_path}"
    return SqlChain(few_shot_prompts, llm_model, db_uri, few_shot_k)
if __name__ == "__main__":
    # Smoke test: rag_test=True only retrieves the few-shot examples for the
    # question (the agent is not run), so print them to make the run visible.
    chain = SqlChain("src/conf/sqls.json")
    for example in chain.ask("Is Manchester United in the database?", rag_test=True):
        print(example)