# text-to-sql / script.py
# Last change: commit a233a6d ("Update script.py") by jason137.
import os
import sqlite3
from contextlib import closing
from pathlib import Path

import emoji
import pandas as pd
import streamlit as st
from langchain import OpenAI, SQLDatabase, SQLDatabaseChain
# API key read from the environment; None when the variable is unset.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")

# Chinook tables whose columns are listed on the page and that the SQL
# chain is allowed to see.
TABLES = ["Album", "Artist", "Track"]

# Canned natural-language questions, answered in this order.
QUERIES = (
    "How many albums are there?",
    "Which album has the most tracks?",
    "What artist has the album with the most tracks?",
)
# Page header: emoji shortcodes are expanded to characters before rendering,
# then a heading for the schema table that follows.
_page_title = emoji.emojize(":robot_face: Text to SQL via LangChain :robot_face:")
st.title(_page_title)
st.subheader("table metadata")
def _get_metadata(db_path="Chinook.db", tables=None):
    """Return a DataFrame listing the column names of each requested table.

    Args:
        db_path: Path to the SQLite database file. Defaults to the
            ``Chinook.db`` file used elsewhere in this script.
        tables: Table names to describe, in display order. ``None`` (the
            default) means the module-level ``TABLES`` list.

    Returns:
        A ``pandas.DataFrame`` with one row per table and two columns:
        ``table`` (the table name) and ``columns`` (a list of that table's
        column names, in the order SQLite reports them).

    Raises:
        sqlite3.OperationalError: If the database file cannot be opened or a
            requested table does not exist.
    """
    if tables is None:
        tables = TABLES
    tables = list(tables)

    # Open read-only through a URI so a missing file raises instead of
    # silently creating an empty database next to the script.
    uri = Path(db_path).resolve().as_uri() + "?mode=ro"

    # sqlite3's own context manager only commits/rolls back; closing()
    # guarantees the connection is released even if a query fails.
    with closing(sqlite3.connect(uri, uri=True)) as con:
        columns = []
        for table in tables:
            # Identifiers cannot be bound as parameters, so quote the table
            # name as an SQL identifier (doubling any embedded quote).
            quoted = '"%s"' % table.replace('"', '""')
            # cursor.description is filled in for a SELECT even when no rows
            # match, so LIMIT 0 yields column names without reading data.
            cur = con.execute("select * from %s limit 0" % quoted)
            columns.append([desc[0] for desc in cur.description])

    return pd.DataFrame({'table': tables, 'columns': columns})
# Show each table's column names so the reader can see the schema the
# questions below are asked against.
metadata_df = _get_metadata()
st.write(metadata_df)
# Fail fast with a readable message instead of a library traceback when no
# API key is configured (os.getenv yields None for an unset variable).
if not OPENAI_API_KEY:
    st.error("OPENAI_API_KEY is not set; export it before starting the app.")
    st.stop()

# temperature=0.0 requests the least random completions the model offers.
llm = OpenAI(temperature=0.0, openai_api_key=OPENAI_API_KEY)
# Restrict the chain to the tables listed in TABLES.
db = SQLDatabase.from_uri("sqlite:///Chinook.db", include_tables=TABLES)
# return_intermediate_steps=True is what makes output["intermediate_steps"]
# available to the query loop; use_query_checker presumably enables
# LangChain's extra SQL-checking pass (see SQLDatabaseChain docs).
db_chain = SQLDatabaseChain.from_llm(llm, db, use_query_checker=True, return_intermediate_steps=True)
# Ask each canned question and render, under a 1-based heading, the
# question, the SQL shown for it, and the chain's final answer.
for query_no, question in enumerate(QUERIES, start=1):
    response = db_chain(question)
    # NOTE(review): index 1 of the intermediate steps is taken to be the
    # generated SQL; this relies on LangChain's internal step layout —
    # confirm when upgrading LangChain.
    generated_sql = response["intermediate_steps"][1]
    answer = response["result"]

    st.subheader(f"query {query_no}")
    st.write(f"Q: {question}")
    st.code(generated_sql)
    st.write(f"A: {answer}")