jason137 commited on
Commit
67928b0
·
1 Parent(s): 57803c2

Update script.py

Browse files
Files changed (1) hide show
  1. script.py +27 -3
script.py CHANGED
@@ -4,6 +4,8 @@ from langchain import OpenAI, SQLDatabase, SQLDatabaseChain
4
  import sqlite3
5
  import streamlit as st
6
 
 
 
7
  st.title("English to SQL via LangChain")
8
 
9
  tables = ['Album', 'Artist', 'Track']
@@ -12,8 +14,30 @@ db = SQLDatabase.from_uri("sqlite:///Chinook.db", include_tables=tables)
12
  con = sqlite3.connect("Chinook.db")
13
  cur = con.cursor()
14
 
 
15
  for table in tables:
16
- rows = cur.execute("select * from %s limit 5" % table).fetchall()
17
- st.table(rows)
18
 
19
- con.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import sqlite3
5
  import streamlit as st
6
 
7
+ API_KEY = os.getenv('OPENAI_API_KEY')
8
+
9
  st.title("English to SQL via LangChain")
10
 
11
  tables = ['Album', 'Artist', 'Track']
 
14
  con = sqlite3.connect("Chinook.db")
15
  cur = con.cursor()
16
 
17
+ metadata = dict()
18
  for table in tables:
 
 
19
 
20
+ rows = cur.execute("select * from %s limit 1" % table)
21
+ cols = [k[0] for k in rows.description]
22
+
23
+ metadata[table] = [cols]
24
+
25
+ con.close()
26
+
27
+ llm = OpenAI(temperature=0.0, openai_api_key=API_KEY)
28
+
29
+ db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, use_query_checker=True)
30
+
31
+ queries = ("How many albums are there?"
32
+ , "Which album has the most tracks?"
33
+ , "What artist has the album with the most tracks?"
34
+ )
35
+
36
+ for query in queries:
37
+ result = db_chain.run(query)
38
+ print(result)
39
+ st.text(query)
40
+ st.text(result)
41
+
42
+ st.subheader("table metadata")
43
+ st.dataframe(pd.DataFrame(metadata), columns=['table', 'columns'])