# AskIDF / app.py
# waelstaha — "Update app.py" (commit 0354aa8), 4.62 kB
import streamlit as st
import numpy as np
import pandas as pd
import altair as alt
import chat as idf_chat
from langchain.sql_database import SQLDatabase
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents import create_sql_agent
from langchain import OpenAI
from langchain import PromptTemplate, OpenAI, LLMChain
from langchain.chains import SimpleSequentialChain
from langchain_experimental.sql import SQLDatabaseChain
from langchain_experimental.agents.agent_toolkits import create_python_agent
from langchain_experimental.tools import PythonREPLTool
from langchain_experimental.utilities import PythonREPL
# Session-state key under which the last JSON-formatted answer is kept for plotting.
JSON_DATA_LABEL = 'json_data'

# Shared completion model for every chain/agent below; temperature=0 minimises
# sampling randomness.
llm = OpenAI(temperature=0)

# SQLite database file FXTrades.db (three slashes: path relative to the
# current working directory).
db = SQLDatabase.from_uri("sqlite:///FXTrades.db")

# Seed the plotting slot only when this session does not have it yet, so an
# existing value is never overwritten here.
if JSON_DATA_LABEL not in st.session_state:
    st.session_state[JSON_DATA_LABEL] = {}
def get_db_chain():
    """Build the natural-language-to-SQL chain used to answer questions.

    The custom prompt names the assistant "IDF", adds table/column hints
    (join with Client for client names, volume = count of records,
    ccyBought/ccySold for currencies) and leaves ``{dialect}``,
    ``{table_info}`` and ``{input}`` for the chain to fill in.

    Returns:
        A ``SQLDatabaseChain`` over the module-level ``db`` driven by the
        module-level ``llm``, with verbose output enabled.
    """
    sql_prompt_text = """Your name is IDF. If you recieve a "Hey IDF" salute reply by saying, "Hey BBHer!". Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Use the following format:
Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"
Only use the following tables:
{table_info}
If someone asks any question involving client name, you need to join with Client table
volume: you need to count records
Trades: you need to get volume of trades
Currency Bought: you need to use ccyBought
Currency Sold: you need to use ccySold
Question: {input}"""
    sql_prompt = PromptTemplate(
        template=sql_prompt_text,
        input_variables=["input", "table_info", "dialect"],
    )
    return SQLDatabaseChain.from_llm(llm, db, verbose=True, prompt=sql_prompt)
def get_json_chain():
    """Build the chain that asks the LLM to rewrite an answer as JSON.

    The prompt has a single ``{result}`` input variable and is driven by the
    module-level ``llm``.

    Returns:
        An ``LLMChain`` wrapping that prompt.
    """
    json_prompt = PromptTemplate.from_template("Reformat this {result} in JSON format")
    return LLMChain(prompt=json_prompt, llm=llm)
def plot_chart():
    """Ask a Python-executing agent to plot the last stored JSON answer.

    Reads the text stored under ``st.session_state[JSON_DATA_LABEL]`` by a
    previous ``from_gpt(..., plot=True)`` call and hands it to a LangChain
    Python agent that writes and runs plotting code.

    Returns:
        The agent's final answer, or ``"no data to plot"`` when nothing has
        been stored yet. The slot starts as ``{}`` and is reset to ``{}`` after
        every non-JSON query.
    """
    json_data = st.session_state.get(JSON_DATA_LABEL)
    # An empty dict/string means there is nothing to plot. Bail out instead of
    # concatenating str + dict, which raises TypeError.
    if not json_data:
        return "no data to plot"
    # SECURITY NOTE(review): PythonREPLTool executes whatever code the LLM
    # writes, and json_data is itself LLM output derived from user prompts.
    # Only run this in a sandboxed environment.
    agent_executor = create_python_agent(
        llm=llm,
        tool=PythonREPLTool(),
        verbose=True,
    )
    question = f"Plot these results: {json_data}"
    return agent_executor.run(question)
# Build both chains once at module level; from_gpt reuses them for every query.
db_chain, json_chain = get_db_chain(), get_json_chain()
def from_gpt(query: str, plot: bool):
    """Answer a natural-language question against the trades database.

    Runs the SQL chain on ``query``. When ``plot`` is true, the SQL answer is
    additionally piped through the JSON-reformatting chain and stored in
    ``st.session_state[JSON_DATA_LABEL]`` for a later ``plot`` command. When
    ``plot`` is false, that slot is cleared to ``{}``.

    Args:
        query: The user's question, passed verbatim to the chain.
        plot: Whether to also reformat the answer as JSON and keep it.

    Returns:
        The chain's final text answer, or a fixed fallback message if running
        the chain raised.
    """
    import logging

    chains = [db_chain, json_chain] if plot else [db_chain]
    try:
        main_chain = SimpleSequentialChain(chains=chains, verbose=True)
        ans = main_chain.run(query)
    except Exception:
        # The chain can fail for many reasons (e.g. bad generated SQL, API
        # errors). Keep the friendly message for the user, but record the real
        # cause instead of silently discarding it.
        logging.getLogger(__name__).exception("IDF chain failed for query %r", query)
        return "Please ask a proper business question related to selected datasets"
    st.session_state[JSON_DATA_LABEL] = ans if plot else {}
    return ans
def circle_chart():
    """Return a demo Altair scatter chart of 200 random points.

    Columns ``a`` and ``b`` give the position. Column ``c`` drives both the
    circle size and the colour, and all three columns appear in the tooltip.
    """
    columns = ['a', 'b', 'c']
    samples = np.random.randn(200, len(columns))
    frame = pd.DataFrame(samples, columns=columns)
    encoding = dict(x='a', y='b', size='c', color='c', tooltip=list(columns))
    return alt.Chart(frame).mark_circle().encode(**encoding)
# Parse the prompt to pick an example and a render function
def get_response(prompt: str, *args):
    """Route a chat prompt to a handler and choose how to render its result.

    Commands are matched case-insensitively, ignoring surrounding whitespace:

    * ``line chart`` -- 30x30 random array rendered with ``st.line_chart``.
    * ``circle chart`` -- demo Altair scatter chart.
    * ``json <question>`` -- answer ``<question>`` via
      ``from_gpt(..., plot=True)`` so the JSON answer is kept for ``plot``.
    * ``plot`` -- plot the last stored JSON answer.

    Anything else is answered via ``from_gpt(prompt, plot=False)``.

    Args:
        prompt: The raw text the user typed.
        *args: Extra positional arguments from the caller; ignored.

    Returns:
        A ``(response, on_render)`` tuple, where ``on_render`` is the
        Streamlit function that should display ``response``.
    """
    command = prompt.strip().lower()
    if command == 'line chart':
        return (np.random.randn(30, 30), st.line_chart)
    if command == 'circle chart':
        return (circle_chart(), st.write)
    words = prompt.split(maxsplit=1)
    if words and words[0].lower() == 'json':
        # Use the original (case-preserved) text: client names end up in SQL
        # string comparisons. Only the leading keyword is removed, so a later
        # "json" inside the question is kept intact.
        question = words[1].strip() if len(words) > 1 else ''
        if not question:
            return ("Please add a question after 'json'.", st.write)
        return (from_gpt(question, plot=True), st.write)
    if command == 'plot':
        return (plot_chart(), st.write)
    return (from_gpt(prompt, plot=False), st.write)
chat = idf_chat.Chat()

# Markdown rendered in the sidebar: usage hints plus example prompts.
sidebar_text = """
# Ask IDF
## Example prompts
Replace `{client}` with a client name.
Hint: you can ask IDF to tell you what clients are available
```
What is the total USD Amount for client {client} for Jan 2022
```
```
What is the average USD Amount per month for client {client} for 2022
```
```
Trades volume for {client} in May 2022
```
```
Get the total USD amount where Currency Bought is {Currency (e.g. GBP)} for {client} in Jun 2022
```
```
Get the Currency Bought/Sold with least total USD amount for {client} in Mar 2022
```
"""
st.sidebar.markdown(sidebar_text)

# Hand the prompt router and the shared LLM to the chat helper. See
# chat.Chat.process for how they are used; presumably it renders the chat UI
# and calls get_response for each prompt (confirm in chat.py).
chat.process(get_response, llm)