| | from langchain import SQLDatabaseChain |
| | from langchain.sql_database import SQLDatabase |
| | from langchain.llms.openai import OpenAI |
| | from langchain.chat_models import ChatOpenAI |
| | from langchain.prompts.prompt import PromptTemplate |
| |
|
# Chat model handed to the SQL chain built in answer_question().
llm = ChatOpenAI(
    model_name="gpt-3.5-turbo",
    temperature=0,
    verbose=True,
)
| |
|
# Tables that get_db() always adds to the user's selection when the large
# database is used.
DEFAULT_TABLES = [
    "Active Players",
    "Team_Per_Game_Statistics_2022_23",
    "Team_Totals_Statistics_2022_23",
    "Player_Total_Statistics_2022_23",
    "Player_Per_Game_Statistics_2022_23",
]
| |
|
def get_prompt():
    """Build the prompt template that drives the SQL chain.

    The template asks the model to write a query in the given ``{dialect}``,
    read the result, and answer the question, using only the tables listed
    in ``{table_info}``.

    Returns:
        A ``PromptTemplate`` with the input variables ``input``,
        ``table_info`` and ``dialect``.
    """
    # One entry per line of prompt text; empty strings are blank lines.
    template_lines = [
        (
            "Given an input question, first create a syntactically correct "
            "{dialect} query to run, then look at the results of the query "
            "and return the answer."
        ),
        "Use the following format:",
        "",
        'Question: "Question here"',
        'SQLQuery: "SQL Query to run"',
        'SQLResult: "Result of the SQLQuery"',
        "",
        'Answer: "Final answer here"',
        "",
        "Only use the following tables:",
        "",
        "{table_info}",
        "",
        "Question: {input}",
    ]
    return PromptTemplate(
        input_variables=["input", "table_info", "dialect"],
        template="\n".join(template_lines),
    )
| |
|
def check_query(query):
    """Classify a raw user query and extract its parts.

    A query that does not start with ``### Query`` is a plain question for
    the small database. Otherwise it must follow this layout, with a blank
    line between the two sections::

        ### Query
        single line question

        ### Tables
        table1
        table2

    Args:
        query: Raw text entered by the user.

    Returns:
        ``'small'`` for a plain question; ``'error'`` when the structured
        layout is malformed (missing tables section, missing question line,
        or a second section that does not start with ``### Tables``); or a
        dict with keys ``'q'`` (the question line) and ``'tables'`` (list of
        table names) for a well-formed structured query.
    """
    if not query.startswith("### Query"):
        return 'small'

    # Normalise Windows line endings so the blank-line split below works.
    normalized = query.replace('\r\n', '\n')

    sections = normalized.split('\n\n')
    if len(sections) < 2:
        # Header present but no blank-line-separated tables section.
        return 'error'
    q_text, t_text = sections[0], sections[1]
    if not t_text.startswith("### Tables"):
        return 'error'

    q_lines = q_text.split('\n')
    if len(q_lines) < 2 or not q_lines[1].strip():
        # "### Query" header without a question on the following line.
        return 'error'

    # Skip the "### Tables" header; drop blank lines and stray whitespace so
    # they are not mistaken for table names.
    table_names = [name.strip() for name in t_text.split('\n')[1:]]
    query_params = {
        'tables': [name for name in table_names if name],
        'q': q_lines[1].strip(),
    }
    print(query_params)
    return query_params
| |
|
def get_db(q, tables):
    """Open the SQLite database that should answer the question.

    Args:
        q: The question text. Not used here; kept so existing callers keep
            working.
        tables: Table names requested by the user. When empty, the small
            database is used.

    Returns:
        A ``SQLDatabase`` for ``nba_small.db`` when no tables were
        requested; otherwise one for ``nba.db`` restricted to the requested
        tables plus ``DEFAULT_TABLES``.
    """
    if not tables:
        return SQLDatabase.from_uri(
            "sqlite:///nba_small.db",
            sample_rows_in_table_info=2,
        )

    # Build a new list instead of extending ``tables`` in place, so the
    # caller's list is left untouched.
    selected_tables = [*tables, *DEFAULT_TABLES]
    return SQLDatabase.from_uri(
        "sqlite:///nba.db",
        include_tables=selected_tables,
        sample_rows_in_table_info=2,
    )
def answer_question(query):
    """Answer a user question by running an LLM-generated SQL query.

    Args:
        query: Raw text from the UI, either a plain question or the
            structured "### Query" / "### Tables" format.

    Returns:
        The chain's ``'result'`` value, or an error message string when
        ``check_query`` reports a malformed structured query.
    """
    parsed = check_query(query)
    if parsed == 'error':
        return 'ERROR: Wrong format for getting the big db schema'

    if isinstance(parsed, dict):
        # Structured query: use the extracted question and table list.
        q, tables = parsed['q'], parsed['tables']
    elif parsed == 'small':
        # Plain question: pass it through and request no extra tables.
        q, tables = query, []

    chain = SQLDatabaseChain.from_llm(
        llm,
        get_db(q, tables),
        prompt=get_prompt(),
        verbose=True,
        return_intermediate_steps=True,
    )
    outcome = chain(q)
    return outcome['result']
| |
|
if __name__ == "__main__":
    import gradio as gr

    # Help text shown under the title; one entry per line, empty strings
    # are blank lines.
    description = "\n".join([
        " Ask NBA Stats is a tool that let's you ask a question with",
        "the NBA SQL tables as a reference",
        "",
        "Ask a simple question to use the small database",
        "",
        "If you would like to access the large DB use format",
        "",
        "### Query",
        "single line query",
        "",
        "### Tables",
        "tables to access line by line",
        "table1",
        "table2",
    ])

    # gr.Textbox replaces the deprecated gr.inputs.Textbox and
    # gr.outputs.Textbox, which newer Gradio releases no longer provide.
    gr.Interface(
        answer_question,
        [gr.Textbox(lines=10, label="Query")],
        gr.Textbox(label="Response"),
        title="Ask NBA Stats",
        description=description,
    ).launch()