jason137 commited on
Commit
b1d1131
·
1 Parent(s): acc6beb

Update script.py

Browse files
Files changed (1) hide show
  1. script.py +26 -18
script.py CHANGED
@@ -7,38 +7,46 @@ import streamlit as st
7
 
8
  API_KEY = os.getenv('OPENAI_API_KEY')
9
 
10
- st.title("English to SQL via LangChain")
11
 
12
- tables = ['Album', 'Artist', 'Track']
13
- db = SQLDatabase.from_uri("sqlite:///Chinook.db", include_tables=tables)
 
 
14
 
15
- con = sqlite3.connect("Chinook.db")
16
- cur = con.cursor()
17
 
18
- st.subheader("table metadata")
19
- for table in tables:
 
 
 
20
 
21
- rows = cur.execute("select * from %s limit 1" % table)
22
- cols = [k[0] for k in rows.description]
23
- st.text("%s: %s" % (table, cols))
24
 
25
- con.close()
 
26
 
27
- # st.dataframe(pd.DataFrame(metadata, columns=['table', 'columns']))
 
 
 
 
 
28
 
29
  llm = OpenAI(temperature=0.0, openai_api_key=API_KEY)
30
- db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, use_query_checker=True)
 
31
 
32
- queries = ("How many albums are there?"
33
- , "Which album has the most tracks?"
34
- , "What artist has the album with the most tracks?"
35
- )
36
 
37
- for query in queries:
38
 
39
  result = db_chain.run(query)
40
  # print(result)
41
 
42
  st.text("Q: %s" % query)
 
43
  st.text("A: %s" % result)
44
 
 
7
 
8
  API_KEY = os.getenv('OPENAI_API_KEY')
9
 
10
+ TABLES = ['Album', 'Artist', 'Track']
11
 
12
+ QUERIES = ("How many albums are there?"
13
+ , "Which album has the most tracks?"
14
+ , "What artist has the album with the 2nd most tracks?"
15
+ )
16
 
17
+ def _get_metadata()
 
18
 
19
+ con = sqlite3.connect("Chinook.db")
20
+ cur = con.cursor()
21
+
22
+ ts, cs = list(), list()
23
+ for table in tables:
24
 
25
+ rows = cur.execute("select * from %s limit 1" % table)
26
+ cols = [k[0] for k in rows.description]
 
27
 
28
+ ts.append(table)
29
+ cs.append(cols)
30
 
31
+ con.close()
32
+ return {'table': ts, 'columns': cs}
33
+
34
+ st.title("English to SQL via LangChain")
35
+ st.subheader("table metadata")
36
+ st.dataframe(_get_metadata())
37
 
38
  llm = OpenAI(temperature=0.0, openai_api_key=API_KEY)
39
+ db = SQLDatabase.from_uri("sqlite:///Chinook.db", include_tables=TABLES)
40
+ db_chain = SQLDatabaseChain.from_llm(llm, db, use_query_checker=True, return_intermediate_steps=True)
41
 
42
+ # db_chain = SQLDatabaseChain.from_llm(llm, db, use_query_checker=True)
 
 
 
43
 
44
+ for query in QUERIES:
45
 
46
  result = db_chain.run(query)
47
  # print(result)
48
 
49
  st.text("Q: %s" % query)
50
+ st.text(result["intermediate_steps"])
51
  st.text("A: %s" % result)
52