File size: 4,616 Bytes
02ec225
920ff5a
 
 
e470cbb
d3c12b0
 
787b5cd
d3c12b0
485e943
0a473b6
465bcde
0354aa8
 
 
6e5c75d
02ec225
7c3a8b1
0cbd39d
54cda27
bb18880
7c3a8b1
 
 
 
3ee3873
83678c1
 
fe7c766
83678c1
 
 
 
 
54cda27
83678c1
54cda27
83678c1
54cda27
83678c1
 
 
 
 
54cda27
 
83678c1
 
54cda27
 
83678c1
 
 
 
 
0cbd39d
485e943
0a473b6
485e943
 
 
 
7c3a8b1
 
03d7502
b2feba5
 
 
 
 
 
 
 
09a1315
ffd6b0c
09a1315
7c3a8b1
ffd6b0c
54cda27
 
485e943
0cbd39d
4be3544
 
 
 
ffd6b0c
 
 
0a87fb1
408e4a2
7c3a8b1
920ff5a
 
 
 
 
 
e470cbb
b7580ac
920ff5a
5c3a841
920ff5a
5c3a841
 
920ff5a
 
21dfa43
920ff5a
 
7c3a8b1
 
d72f607
9952459
7c3a8b1
 
ffd6b0c
8e461e8
3ee3873
9952459
920ff5a
 
 
 
0a473b6
920ff5a
 
ba3ca71
 
cb7f516
ba3ca71
 
 
 
5aff5f0
 
ba3ca71
 
cb7f516
ba3ca71
 
cb7f516
ba3ca71
 
5aff5f0
cb7f516
d202154
 
 
 
 
 
ba3ca71
 
 
 
 
54cda27
82afe3d
54cda27
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import streamlit as st
import numpy as np
import pandas as pd
import altair as alt
import chat as idf_chat
from langchain.sql_database import SQLDatabase
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents import create_sql_agent
from langchain import OpenAI
from langchain import PromptTemplate, OpenAI, LLMChain
from langchain.chains import SimpleSequentialChain
from langchain_experimental.sql import SQLDatabaseChain
from langchain_experimental.agents.agent_toolkits import create_python_agent
from langchain_experimental.tools import PythonREPLTool
from langchain_experimental.utilities import PythonREPL


# Key in st.session_state holding the latest JSON-formatted answer produced by
# a "json <question>" prompt; plot_chart() reads it, from_gpt() writes it.
JSON_DATA_LABEL = 'json_data'
# Single LLM instance shared by every chain/agent in this module
# (temperature=0 for output that is as deterministic as possible).
llm=OpenAI(temperature=0)
# SQLite trades database queried by the SQL chain; "sqlite:///FXTrades.db" is
# a path relative to the process working directory.
db = SQLDatabase.from_uri("sqlite:///FXTrades.db")


# Initialise the JSON slot only when absent, so a value stored in session
# state is not overwritten when this script body runs again.
if JSON_DATA_LABEL not in st.session_state:
    st.session_state[JSON_DATA_LABEL] = {}


def get_db_chain():
    """Build the text-to-SQL chain that answers business questions.

    The prompt sets up the "IDF" persona and the greeting reply. It also fixes
    the Question / SQLQuery / SQLResult / Answer format. Finally, it maps
    business terms (client name, volume, trades, currency bought/sold) onto
    the schema.

    Returns:
        A verbose ``SQLDatabaseChain`` over the module-level ``llm`` and ``db``.
    """
    sql_template = """Your name is IDF. If you recieve a "Hey IDF" salute reply by saying, "Hey BBHer!". Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
    Use the following format:
    Question: "Question here"
    SQLQuery: "SQL Query to run"
    SQLResult: "Result of the SQLQuery"
    Answer: "Final answer here"

    Only use the following tables:

    {table_info}

    If someone asks any question involving client name, you need to join with Client table
    volume: you need to count records
    Trades: you need to get volume of trades
    Currency Bought: you need to use ccyBought
    Currency Sold: you need to use ccySold


    Question: {input}"""
    # The chain fills {input}, {table_info} and {dialect} at run time.
    prompt_vars = ["input", "table_info", "dialect"]
    sql_prompt = PromptTemplate(template=sql_template, input_variables=prompt_vars)
    return SQLDatabaseChain.from_llm(llm, db, prompt=sql_prompt, verbose=True)


def get_json_chain():
    """Build the chain that asks the LLM to restate a result as JSON.

    The template's only variable is ``result``. When used as the second step
    of the ``SimpleSequentialChain`` in ``from_gpt``, it receives the SQL
    chain's answer.

    Returns:
        An ``LLMChain`` over the module-level ``llm``.
    """
    reformat_prompt = PromptTemplate.from_template("Reformat this {result} in JSON format")
    return LLMChain(llm=llm, prompt=reformat_prompt)

def plot_chart():
    """Ask a Python-REPL agent to plot the result saved by the last JSON query.

    Reads the payload stored under ``JSON_DATA_LABEL`` in ``st.session_state``
    (written by ``from_gpt(..., plot=True)``). It then hands it to a LangChain
    Python agent with a "Plot these results" instruction.

    Returns:
        The agent's final answer. Returns ``"no data to plot"`` when nothing
        usable is stored, e.g. the initial ``{}`` or the ``{}`` left by a
        non-plot question.
    """
    data = st.session_state[JSON_DATA_LABEL]
    # The slot holds {} until a "json" query succeeds; previously that dict was
    # concatenated to a str below and raised TypeError.
    if not data:
        return "no data to plot"

    # SECURITY NOTE(review): PythonREPLTool executes code written by the LLM
    # in this process; only expose this to trusted users.
    agent_executor = create_python_agent(
        llm=llm,
        tool=PythonREPLTool(),
        verbose=True,
    )
    # str() guards against any non-string payload stored in the slot.
    question = "Plot these results: " + str(data)
    return agent_executor.run(question)

   
# Build both chains once per script run; from_gpt() reuses them for every question.
db_chain = get_db_chain()
json_chain = get_json_chain()

def from_gpt(query: str, plot: bool):
    """Answer a natural-language business question against the trades DB.

    Args:
        query: The user's question, passed to the SQL chain.
        plot: When True, the SQL answer is also piped through ``json_chain``.
            The JSON output is stored in ``st.session_state[JSON_DATA_LABEL]``
            for a later ``plot`` command. When False, the slot is reset to ``{}``.

    Returns:
        The final chain output. If any step raises, returns a fixed hint
        asking for a proper business question instead.
    """
    import logging

    chains = [db_chain, json_chain] if plot else [db_chain]
    try:
        main_chain = SimpleSequentialChain(chains=chains, verbose=True)
        ans = main_chain.run(query)
    except Exception:
        # LLM / SQL failures surface as many unrelated exception types, so this
        # boundary stays broad, but the traceback is logged instead of lost.
        logging.getLogger(__name__).exception("IDF query failed: %r", query)
        # Drop any previous result so "plot" cannot render a stale answer.
        st.session_state[JSON_DATA_LABEL] = {}
        return "Please ask a proper business question related to selected datasets"

    st.session_state[JSON_DATA_LABEL] = ans if plot else {}
    return ans

def circle_chart():
    """Return a demo Altair scatter chart of 200 random points.

    Column ``a`` is plotted on x and ``b`` on y. Column ``c`` drives both
    point size and colour, and all three columns appear in the tooltip.
    """
    random_points = np.random.randn(200, 3)
    frame = pd.DataFrame(random_points, columns=['a', 'b', 'c'])
    circles = alt.Chart(frame).mark_circle()
    return circles.encode(
        x='a',
        y='b',
        size='c',
        color='c',
        tooltip=['a', 'b', 'c'],
    )

def get_response(prompt: str, *kargs):
    """Map a chat prompt to a response and the Streamlit function that renders it.

    Commands are matched case-insensitively, ignoring surrounding whitespace:

    * ``line chart``: random 30x30 array rendered with ``st.line_chart``.
    * ``circle chart``: demo Altair chart from ``circle_chart()``.
    * ``json <question>``: SQL answer reformatted as JSON and stored for plotting.
    * ``plot``: plot the stored JSON result via ``plot_chart()``.

    Anything else is sent to ``from_gpt`` as a plain question.

    Args:
        prompt: Raw user input.
        *kargs: Extra positional arguments from the caller; ignored.

    Returns:
        A ``(response, on_render)`` tuple. Presumably the caller invokes
        ``on_render(response)``; confirm against chat.py.
    """
    text = prompt.strip()
    command = text.lower()

    if command == 'line chart':
        return np.random.randn(30, 30), st.line_chart
    if command == 'circle chart':
        return circle_chart(), st.write

    # "json <question>": take the question from the original-case text. Client
    # names and currency codes end up in SQL, where SQLite '=' on TEXT is
    # case-sensitive. Splitting on the first word also avoids the IndexError
    # the old split('json ')[1] raised for "json" or "jsonX".
    words = text.split(maxsplit=1)
    if words and words[0].lower() == 'json':
        question = words[1] if len(words) > 1 else ''
        if not question:
            return ("Please add a question after 'json', e.g. "
                    "'json Trades volume for {client} in May 2022'"), st.write
        return from_gpt(question, plot=True), st.write

    if command == 'plot':
        return plot_chart(), st.write
    return from_gpt(prompt, plot=False), st.write




# Chat UI helper from the local `chat` module (imported as idf_chat).
chat = idf_chat.Chat()

# Markdown rendered in the sidebar: title plus copy-pasteable example prompts.
# NOTE: trailing spaces inside this string are part of the rendered markdown.
sidebar_text = """
# Ask IDF



## Example prompts
Replace `{client}` with a client name. 
Hint: you can ask IDF to tell you what clients are available

```
What is the total USD Amount for client {client} for Jan 2022 
```
```
What is the average USD Amount per month for client {client} for 2022
```
```
Trades volume for {client} in May 2022
```
```
Get the total USD amount where Currency Bought is {Currency (e.g. GBP)} for {client} in Jun 2022
```
```
Get the Currency Bought/Sold with least total USD amount for {client} in Mar 2022
```
"""

with st.sidebar:
    st.markdown(sidebar_text)

# NOTE(review): leftover commented-out code; delete once confirmed unused.
# prompt = chat.get_promt("Ask IDF Anything")

# Hand the prompt router (get_response) and shared llm to the chat UI; how
# they are used is defined in chat.py (TODO confirm the render contract).
chat.process(get_response, llm)