Jesus Sanchez commited on
Commit
54cda27
·
1 Parent(s): 40eb877

fixed sql chain

Browse files
Files changed (2) hide show
  1. app.py +18 -58
  2. chat.py +2 -14
app.py CHANGED
@@ -1,13 +1,14 @@
1
  from os import write
 
2
  import streamlit as st
3
  import numpy as np
4
  import pandas as pd
5
  import altair as alt
6
  import chat as idf_chat
7
- import sqlite3
8
  from langchain.sql_database import SQLDatabase
9
  from langchain.agents.agent_toolkits import SQLDatabaseToolkit
10
- from langchain.agents import create_sql_agent
11
  from langchain import OpenAI
12
  from langchain import PromptTemplate, OpenAI, LLMChain
13
  from langchain.chains import SimpleSequentialChain
@@ -16,6 +17,7 @@ from langchain import SQLDatabaseChain
16
  JSON_DATA_LABEL = 'json_data'
17
  DB_CHAIN = 'db_chain'
18
  llm=OpenAI(temperature=0)
 
19
 
20
 
21
  if JSON_DATA_LABEL not in st.session_state:
@@ -24,38 +26,12 @@ if JSON_DATA_LABEL not in st.session_state:
24
  if DB_CHAIN not in st.session_state:
25
  st.session_state[DB_CHAIN] = {}
26
 
27
- def tables_from_db():
28
- # db = sqlite3.connect('switrs.sqlite')
29
- db = sqlite3.connect('FXTrades.db')
30
- cursor = db.cursor()
31
- cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
32
- tables = cursor.fetchall()
33
- cursor.close()
34
- db.close()
35
- return tables
36
-
37
- def from_db(table: str):
38
- # db = sqlite3.connect('switrs.sqlite')
39
- db = sqlite3.connect('FXTrades.db')
40
- # Read into a panda DataFrame
41
- df = pd.read_sql_query(f"SELECT * FROM {table} LIMIT 50", db)
42
- db.close()
43
- columns = df.columns
44
- # Pick a random column for y axis
45
- column_index = np.random.randint(0, columns.size, 1)
46
- y_column_name = columns[column_index][0]
47
- print(columns)
48
- return alt.Chart(df).mark_circle().encode(x='case_id', y=y_column_name)
49
-
50
-
51
-
52
  def get_sql_agent():
53
  # db = SQLDatabase.from_uri("sqlite:///switrs.sqlite")
54
  db = SQLDatabase.from_uri("sqlite:///FXTrades.db")
55
 
56
  toolkit = SQLDatabaseToolkit(llm=llm,db=db)
57
 
58
-
59
  return create_sql_agent(
60
  llm=llm,
61
  toolkit=toolkit,
@@ -65,32 +41,31 @@ def get_sql_agent():
65
  )
66
 
67
  def get_db_chain():
68
-
69
- db = SQLDatabase.from_uri("sqlite:///FXTrades.db")
70
 
71
- _DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
72
  Use the following format:
73
  ccyBoughtccyBought
74
  Question: "Question here"
75
  SQLQuery: "SQL Query to run"
76
  SQLResult: "Result of the SQLQuery"
77
  Answer: "Final answer here"
78
-
79
  Only use the following tables:
80
-
81
  {table_info}
82
-
83
  If someone asks any question involving client name, you need to join with Client table
84
  volume: you need to count records
85
  Amounts: you need to use USD amount
86
  Trades: you need to get volume of trades
87
  Currency Bought: you need to use ccyBought
88
  Currency Sold: you need to use ccySold
89
-
90
-
91
  Question: {input}"""
92
  PROMPT = PromptTemplate(
93
- input_variables=["input", "table_info", "dialect"], template=_DEFAULT_TEMPLATE
 
94
  )
95
 
96
  return SQLDatabaseChain.from_llm(llm, db, prompt=PROMPT, verbose=True)
@@ -114,23 +89,15 @@ def plot_chart():
114
 
115
  return alt.Chart(dataframe).mark_line()
116
 
 
 
 
117
 
118
 
119
  def from_gpt(query: str, plot: bool):
120
 
121
- # sql_agent = get_sql_agent()
122
- sql_agent = st.session_state[DB_CHAIN]
123
- if not sql_agent:
124
- sql_agent = get_db_chain()
125
- st.session_state[DB_CHAIN] = sql_agent
126
-
127
- json_chain = get_json_chain()
128
-
129
- chains = [sql_agent, json_chain] if plot else [sql_agent]
130
-
131
- # Run the query using the agent executor
132
  main_chain = SimpleSequentialChain(chains=chains, verbose=True)
133
-
134
  ans = main_chain.run(query)
135
 
136
  # Save data as json if plot
@@ -156,13 +123,6 @@ def get_response(prompt: str, *kargs):
156
  elif prompt_lower == 'circle chart':
157
  on_render = st.write
158
  response = circle_chart()
159
- elif prompt_lower == 'db':
160
- on_render = st.write
161
- response = tables_from_db()
162
- elif prompt_lower.startswith('db:'):
163
- table = prompt_lower.split(":")[1]
164
- on_render = st.write
165
- response = from_db(table)
166
  elif prompt_lower.startswith('json'):
167
  p = prompt_lower.split('json ')[1]
168
  on_render = st.write
@@ -206,6 +166,6 @@ What trades did client {client} do in May 2022
206
  with st.sidebar:
207
  st.markdown(sidebar_text)
208
 
209
- prompt = chat.get_promt("Ask IDF Anything")
210
 
211
- chat.process(prompt, get_response, llm)
 
1
  from os import write
2
+ from langchain.llms.base import LLM
3
  import streamlit as st
4
  import numpy as np
5
  import pandas as pd
6
  import altair as alt
7
  import chat as idf_chat
8
+ # import sqlite3
9
  from langchain.sql_database import SQLDatabase
10
  from langchain.agents.agent_toolkits import SQLDatabaseToolkit
11
+ from langchain.agents import LLMSingleActionAgent, agent_types, create_sql_agent, initialize_agent, AgentType,load_tools
12
  from langchain import OpenAI
13
  from langchain import PromptTemplate, OpenAI, LLMChain
14
  from langchain.chains import SimpleSequentialChain
 
17
  JSON_DATA_LABEL = 'json_data'
18
  DB_CHAIN = 'db_chain'
19
  llm=OpenAI(temperature=0)
20
+ db = SQLDatabase.from_uri("sqlite:///FXTrades.db")
21
 
22
 
23
  if JSON_DATA_LABEL not in st.session_state:
 
26
  if DB_CHAIN not in st.session_state:
27
  st.session_state[DB_CHAIN] = {}
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  def get_sql_agent():
30
  # db = SQLDatabase.from_uri("sqlite:///switrs.sqlite")
31
  db = SQLDatabase.from_uri("sqlite:///FXTrades.db")
32
 
33
  toolkit = SQLDatabaseToolkit(llm=llm,db=db)
34
 
 
35
  return create_sql_agent(
36
  llm=llm,
37
  toolkit=toolkit,
 
41
  )
42
 
43
  def get_db_chain():
 
 
44
 
45
+ template = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
46
  Use the following format:
47
  ccyBoughtccyBought
48
  Question: "Question here"
49
  SQLQuery: "SQL Query to run"
50
  SQLResult: "Result of the SQLQuery"
51
  Answer: "Final answer here"
52
+
53
  Only use the following tables:
54
+
55
  {table_info}
56
+
57
  If someone asks any question involving client name, you need to join with Client table
58
  volume: you need to count records
59
  Amounts: you need to use USD amount
60
  Trades: you need to get volume of trades
61
  Currency Bought: you need to use ccyBought
62
  Currency Sold: you need to use ccySold
63
+
64
+
65
  Question: {input}"""
66
  PROMPT = PromptTemplate(
67
+ input_variables=["input", "table_info", "dialect"],
68
+ template=template
69
  )
70
 
71
  return SQLDatabaseChain.from_llm(llm, db, prompt=PROMPT, verbose=True)
 
89
 
90
  return alt.Chart(dataframe).mark_line()
91
 
92
+ # sql_agent = get_sql_agent()
93
+ db_chain = get_db_chain()
94
+ json_chain = get_json_chain()
95
 
96
 
97
  def from_gpt(query: str, plot: bool):
98
 
99
+ chains = [db_chain, json_chain] if plot else [db_chain]
 
 
 
 
 
 
 
 
 
 
100
  main_chain = SimpleSequentialChain(chains=chains, verbose=True)
 
101
  ans = main_chain.run(query)
102
 
103
  # Save data as json if plot
 
123
  elif prompt_lower == 'circle chart':
124
  on_render = st.write
125
  response = circle_chart()
 
 
 
 
 
 
 
126
  elif prompt_lower.startswith('json'):
127
  p = prompt_lower.split('json ')[1]
128
  on_render = st.write
 
166
  with st.sidebar:
167
  st.markdown(sidebar_text)
168
 
169
+ # prompt = chat.get_promt("Ask IDF Anything")
170
 
171
+ chat.process(get_response, llm)
chat.py CHANGED
@@ -7,7 +7,6 @@ from typing import Callable
7
 
8
  RESPONSE_LABEL = 'chat_response'
9
  PROMPT_LABEL = 'chat_prompt'
10
- TEXT_INPUT_LABEL = "chat_input"
11
 
12
  class Chat:
13
 
@@ -18,18 +17,7 @@ class Chat:
18
  if PROMPT_LABEL not in st.session_state:
19
  st.session_state[PROMPT_LABEL] = []
20
 
21
- if TEXT_INPUT_LABEL not in st.session_state:
22
- st.session_state[TEXT_INPUT_LABEL] = ''
23
-
24
- def get_promt(self, placeholder: str):
25
- # return st.text_input(
26
- # label="ChatIDF",
27
- # placeholder=placeholder,
28
- # key="chat_widget",
29
- # )
30
- return st.chat_input(placeholder=placeholder)
31
-
32
- def process(self, prompt: str, process_prompt: Callable, *args):
33
  """
34
  process_prompt(promt: str, *args) -> tuple(Any, Callable)
35
  callback to process the chat promt, it takes the promt for input
@@ -45,7 +33,7 @@ class Chat:
45
  on_render(response)
46
 
47
  # Compute prompt
48
- if prompt:
49
  st.session_state[PROMPT_LABEL].append(prompt)
50
  (response, on_render) = process_prompt(prompt, *args)
51
  st.session_state[RESPONSE_LABEL].append((response, on_render))
 
7
 
8
  RESPONSE_LABEL = 'chat_response'
9
  PROMPT_LABEL = 'chat_prompt'
 
10
 
11
  class Chat:
12
 
 
17
  if PROMPT_LABEL not in st.session_state:
18
  st.session_state[PROMPT_LABEL] = []
19
 
20
+ def process(self, process_prompt: Callable, *args):
 
 
 
 
 
 
 
 
 
 
 
21
  """
22
  process_prompt(promt: str, *args) -> tuple(Any, Callable)
23
  callback to process the chat promt, it takes the promt for input
 
33
  on_render(response)
34
 
35
  # Compute prompt
36
+ if prompt:= st.chat_input("Ask IDF Anything"):
37
  st.session_state[PROMPT_LABEL].append(prompt)
38
  (response, on_render) = process_prompt(prompt, *args)
39
  st.session_state[RESPONSE_LABEL].append((response, on_render))