File size: 1,375 Bytes
9225820
 
 
005e4ad
d472eeb
2c74cef
 
9225820
da36009
67928b0
b1d1131
9225820
b1d1131
 
a233a6d
b1d1131
9225820
749a38e
 
 
16ea00b
e3a559e
b1d1131
 
 
 
da36009
e3a559e
b1d1131
 
67928b0
b1d1131
 
67928b0
b1d1131
5fec911
b1d1131
5ff241f
67928b0
da36009
b1d1131
 
67928b0
5ff241f
afd0ea0
2f02f10
a8b63a7
87601fe
45a70a5
5ff241f
 
da5957b
5ff241f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import os
import sqlite3
from contextlib import closing
from pathlib import Path

import emoji
import pandas as pd
import streamlit as st
from langchain import OpenAI, SQLDatabase, SQLDatabaseChain

# API key for the OpenAI LLM, taken from the environment (None when unset).
OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')

# Tables handed to SQLDatabase via include_tables and summarized in the
# metadata table shown at the top of the page.
TABLES = ['Album', 'Artist', 'Track']

# Canned natural-language questions that are translated to SQL and answered
# each time the script runs.
QUERIES = (
    "How many albums are there?",
    "Which album has the most tracks?",
    "What artist has the album with the most tracks?",
)

# Page header.
st.title(emoji.emojize(":robot_face: Text to SQL via LangChain :robot_face:"))
st.subheader("table metadata")

def _get_metadata(db_path="Chinook.db", tables=None):
    """Return a DataFrame listing the column names of each table of interest.

    The database is opened read-only, so a missing or misnamed file raises
    ``sqlite3.OperationalError`` instead of silently creating an empty
    database file in the working directory.

    Args:
        db_path: Path to the SQLite database file (relative paths resolve
            against the current working directory).
        tables: Table names to describe. ``None`` means the module-level
            ``TABLES`` list.

    Returns:
        A ``pandas.DataFrame`` with one row per table and two columns:
        ``'table'`` (the table name) and ``'columns'`` (a list of that
        table's column names, in the order ``SELECT *`` returns them).

    Raises:
        sqlite3.OperationalError: if the file cannot be opened or a table
            does not exist.
    """
    if tables is None:
        tables = TABLES

    # URI form lets us request mode=ro; Path.as_uri percent-encodes the path.
    uri = Path(db_path).resolve().as_uri() + "?mode=ro"

    names, columns = [], []
    # closing() releases the connection even if a query raises; sqlite3's own
    # context manager only commits/rolls back and does not close.
    with closing(sqlite3.connect(uri, uri=True)) as con:
        for table in tables:
            # Identifiers cannot be bound as parameters, so quote them instead.
            quoted = '"%s"' % table.replace('"', '""')
            cursor = con.execute("select * from %s limit 1" % quoted)
            # Each description entry describes one result column; [0] is its name.
            names.append(table)
            columns.append([entry[0] for entry in cursor.description])

    return pd.DataFrame({'table': names, 'columns': columns})

# Show the table/column summary before any LLM work starts.
st.write(_get_metadata())

# Text-to-SQL pipeline: an OpenAI LLM (temperature 0.0 requests the least
# random completions) over a SQLAlchemy-style URI to the same SQLite file,
# restricted to TABLES via include_tables.
llm = OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0.0)
db = SQLDatabase.from_uri(
    "sqlite:///Chinook.db",
    include_tables=TABLES,
)
# return_intermediate_steps=True makes the chain's output carry an
# "intermediate_steps" entry, which the loop below reads the SQL from.
# use_query_checker=True enables LangChain's extra LLM check of the generated
# SQL before execution (per SQLDatabaseChain docs).
db_chain = SQLDatabaseChain.from_llm(
    llm,
    db,
    return_intermediate_steps=True,
    use_query_checker=True,
)

# Run each canned question through the chain and render question, SQL and answer.
# NOTE(review): the LLM is called on every script run (each page load/rerun);
# consider Streamlit caching if that cost matters.
for number, query in enumerate(QUERIES, start=1):

    st.subheader(f"query {number}")
    st.write(f"Q: {query}")

    try:
        output = db_chain(query)
    except Exception as exc:  # UI boundary: show the error, keep answering the rest
        st.exception(exc)
        continue

    # Assumed layout with use_query_checker=True: index 1 of intermediate_steps
    # is the (checked) SQL string that was executed — verify against the
    # installed LangChain version.
    sql, result = output["intermediate_steps"][1], output["result"]

    st.code(sql)
    st.write(f"A: {result}")