Spaces:
Sleeping
Sleeping
File size: 4,616 Bytes
02ec225 920ff5a e470cbb d3c12b0 787b5cd d3c12b0 485e943 0a473b6 465bcde 0354aa8 6e5c75d 02ec225 7c3a8b1 0cbd39d 54cda27 bb18880 7c3a8b1 3ee3873 83678c1 fe7c766 83678c1 54cda27 83678c1 54cda27 83678c1 54cda27 83678c1 54cda27 83678c1 54cda27 83678c1 0cbd39d 485e943 0a473b6 485e943 7c3a8b1 03d7502 b2feba5 09a1315 ffd6b0c 09a1315 7c3a8b1 ffd6b0c 54cda27 485e943 0cbd39d 4be3544 ffd6b0c 0a87fb1 408e4a2 7c3a8b1 920ff5a e470cbb b7580ac 920ff5a 5c3a841 920ff5a 5c3a841 920ff5a 21dfa43 920ff5a 7c3a8b1 d72f607 9952459 7c3a8b1 ffd6b0c 8e461e8 3ee3873 9952459 920ff5a 0a473b6 920ff5a ba3ca71 cb7f516 ba3ca71 5aff5f0 ba3ca71 cb7f516 ba3ca71 cb7f516 ba3ca71 5aff5f0 cb7f516 d202154 ba3ca71 54cda27 82afe3d 54cda27 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 | import streamlit as st
import numpy as np
import pandas as pd
import altair as alt
import chat as idf_chat
from langchain.sql_database import SQLDatabase
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents import create_sql_agent
from langchain import OpenAI
from langchain import PromptTemplate, OpenAI, LLMChain
from langchain.chains import SimpleSequentialChain
from langchain_experimental.sql import SQLDatabaseChain
from langchain_experimental.agents.agent_toolkits import create_python_agent
from langchain_experimental.tools import PythonREPLTool
from langchain_experimental.utilities import PythonREPL
# Session-state key holding the latest JSON-formatted answer; written by
# from_gpt() and read by plot_chart().
JSON_DATA_LABEL = 'json_data'
# Shared OpenAI LLM (temperature=0 for repeatable output) used by the SQL chain,
# the JSON chain, the plotting agent and the chat helper.
llm=OpenAI(temperature=0)
# SQLite database the SQL chain queries. "sqlite:///" with a relative path is
# resolved against the process's current working directory.
db = SQLDatabase.from_uri("sqlite:///FXTrades.db")
# Seed the key only when absent so a value already stored for this session is
# not overwritten.
if JSON_DATA_LABEL not in st.session_state:
    st.session_state[JSON_DATA_LABEL] = {}
def get_db_chain():
    """Build the natural-language-to-SQL chain over the module-level ``db``.

    The prompt tells the LLM to write a ``{dialect}`` query restricted to the
    tables described in ``{table_info}``, run it, and phrase a final answer.
    It also carries domain hints: join the Client table for client-name
    questions, treat "volume" as a record count, and map "Currency Bought" /
    "Currency Sold" to the ``ccyBought`` / ``ccySold`` columns.

    Returns:
        A verbose ``SQLDatabaseChain`` built from the module-level ``llm`` and
        ``db`` using the custom prompt.
    """
    sql_template = """Your name is IDF. If you recieve a "Hey IDF" salute reply by saying, "Hey BBHer!". Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Use the following format:
Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"
Only use the following tables:
{table_info}
If someone asks any question involving client name, you need to join with Client table
volume: you need to count records
Trades: you need to get volume of trades
Currency Bought: you need to use ccyBought
Currency Sold: you need to use ccySold
Question: {input}"""
    sql_prompt = PromptTemplate(
        template=sql_template,
        input_variables=["input", "table_info", "dialect"],
    )
    return SQLDatabaseChain.from_llm(llm, db, prompt=sql_prompt, verbose=True)
def get_json_chain():
    """Build a one-step chain asking the LLM to re-express text as JSON.

    The prompt's single variable is ``result``; from_gpt() runs this chain
    right after the SQL chain, feeding it the SQL chain's answer.

    Returns:
        An ``LLMChain`` bound to the module-level ``llm``.
    """
    reformat_prompt = PromptTemplate.from_template("Reformat this {result} in JSON format")
    return LLMChain(prompt=reformat_prompt, llm=llm)
def plot_chart():
    """Ask a Python-REPL agent to plot the last JSON result saved by from_gpt().

    Returns:
        The agent's final answer, or ``"no data to plot"`` when nothing has
        been stored yet. The session-state default is an empty dict, and
        from_gpt() resets the value to ``{}`` after non-JSON queries.
    """
    stored = st.session_state[JSON_DATA_LABEL]
    # Concatenating the default ``{}`` to a str raised TypeError; bail out
    # early instead (the intent of the previously commented-out guard).
    if not stored:
        return "no data to plot"
    # NOTE: PythonREPLTool executes LLM-generated Python in this process;
    # only expose this to trusted users.
    agent_executor = create_python_agent(
        llm=llm,
        tool=PythonREPLTool(),
        verbose=True,
    )
    # str() so a non-string value left in session state cannot break the concat.
    question = "Plot these results: " + str(stored)
    return agent_executor.run(question)
# Build both chains once at import time; from_gpt() reuses them for every query.
db_chain, json_chain = get_db_chain(), get_json_chain()
def from_gpt(query: str, plot: bool) -> str:
    """Answer a natural-language question via the SQL chain, optionally as JSON.

    Args:
        query: The user's business question.
        plot: If True, pipe the SQL chain's answer through the JSON-reformatting
            chain and store the result in session state for plot_chart();
            otherwise reset the stored value to ``{}``.

    Returns:
        The chain's final answer, or a fixed guidance message if the chain
        fails. Failures are logged with their traceback rather than silently
        discarded.
    """
    import logging

    chains = [db_chain, json_chain] if plot else [db_chain]
    try:
        main_chain = SimpleSequentialChain(chains=chains, verbose=True)
        ans = main_chain.run(query)
    except Exception:
        # Top-level UI boundary: keep the friendly message for the user, but
        # record why it failed (API/DB/LLM errors were previously invisible).
        logging.getLogger(__name__).exception("IDF query failed: %r", query)
        return "Please ask a proper business question related to selected datasets"
    st.session_state[JSON_DATA_LABEL] = ans if plot else {}
    return ans
def circle_chart():
    """Return a demo Altair circle (scatter) chart of random data.

    The data is 200 rows of standard-normal samples in columns a, b and c.
    a and b are the axes; c drives both marker size and colour, and all three
    appear in the tooltip.
    """
    data = pd.DataFrame(np.random.randn(200, 3), columns=['a', 'b', 'c'])
    encoding = dict(x='a', y='b', size='c', color='c', tooltip=['a', 'b', 'c'])
    base = alt.Chart(data)
    return base.mark_circle().encode(**encoding)
def get_response(prompt: str, *kargs):
    """Route a chat prompt to a response and the Streamlit function that renders it.

    Routing (case-insensitive on the command word):
      * ``line chart``   -> 30x30 random array rendered with ``st.line_chart``
      * ``circle chart`` -> demo Altair chart from circle_chart()
      * ``json <question>`` -> from_gpt(question, plot=True); the answer is
        stored for a later ``plot`` command
      * ``plot``         -> plot_chart() on the stored JSON answer
      * anything else    -> from_gpt(prompt, plot=False)

    Args:
        prompt: Raw text typed by the user.
        *kargs: Accepted and ignored; kept for caller compatibility.

    Returns:
        A ``(response, on_render)`` tuple.
    """
    prompt_lower = prompt.lower()
    if prompt_lower == 'line chart':
        return np.random.randn(30, 30), st.line_chart
    if prompt_lower == 'circle chart':
        return circle_chart(), st.write
    if prompt_lower.startswith('json'):
        # Slice the ORIGINAL prompt after the 4-char prefix rather than
        # splitting the lowercased text on 'json ':
        #   * the split raised IndexError for "json" or "json<no space>...";
        #   * a second "json " later in the question truncated it;
        #   * lowercasing lost client-name casing used in SQL comparisons.
        question = prompt[len('json'):].strip()
        return from_gpt(question, plot=True), st.write
    if prompt_lower == 'plot':
        return plot_chart(), st.write
    return from_gpt(prompt, plot=False), st.write
# Chat UI helper from the local ``chat`` module (imported as idf_chat).
chat = idf_chat.Chat()
# Markdown shown in the sidebar: an intro plus example prompts. The `{client}`
# placeholders are literal text for the user to replace; the string is passed
# unformatted to st.markdown below.
sidebar_text = """
# Ask IDF
## Example prompts
Replace `{client}` with a client name.
Hint: you can ask IDF to tell you what clients are available
```
What is the total USD Amount for client {client} for Jan 2022
```
```
What is the average USD Amount per month for client {client} for 2022
```
```
Trades volume for {client} in May 2022
```
```
Get the total USD amount where Currency Bought is {Currency (e.g. GBP)} for {client} in Jun 2022
```
```
Get the Currency Bought/Sold with least total USD amount for {client} in Mar 2022
```
"""
with st.sidebar:
    st.markdown(sidebar_text)
# NOTE(review): indentation was lost in this copy of the file; the lines below
# are assumed to be at module level (main page), not inside the sidebar — confirm.
# prompt = chat.get_promt("Ask IDF Anything")
# Hand the prompt router and the shared LLM to the chat helper; how it uses
# them is defined in the chat module, which is not visible here.
chat.process(get_response, llm)
|