Jesus Sanchez commited on
Commit
7c3a8b1
·
1 Parent(s): 8e461e8
Files changed (1) hide show
  1. app.py +24 -3
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import streamlit as st
2
  import numpy as np
3
  import pandas as pd
@@ -11,8 +12,13 @@ from langchain import OpenAI
11
  from langchain import PromptTemplate, OpenAI, LLMChain
12
  from langchain.chains import SimpleSequentialChain
13
 
 
14
  llm=OpenAI(temperature=0)
15
 
 
 
 
 
16
  def tables_from_db():
17
  db = sqlite3.connect('switrs.sqlite')
18
  cursor = db.cursor()
@@ -57,6 +63,13 @@ def get_json_chain():
57
  prompt=PromptTemplate.from_template(prompt_template)
58
  )
59
 
 
 
 
 
 
 
 
60
 
61
 
62
  def from_gpt(query: str, plot: bool):
@@ -69,7 +82,12 @@ def from_gpt(query: str, plot: bool):
69
  # Run the query using the agent executor
70
  main_chain = SimpleSequentialChain(chains=chains, verbose=True)
71
 
72
- return main_chain.run(query)
 
 
 
 
 
73
 
74
  def circle_chart():
75
  df = pd.DataFrame(np.random.randn(200, 3), columns=['a', 'b', 'c'])
@@ -96,10 +114,13 @@ def get_response(prompt: str, *kargs):
96
  table = prompt_lower.split(":")[1]
97
  on_render = st.write
98
  response = from_db(table)
99
- elif prompt_lower.startswith('plot'):
100
- p = prompt_lower.split('plot ')[1]
101
  on_render = st.write
102
  response = from_gpt(p, plot=True)
 
 
 
103
  else:
104
  on_render = st.write
105
  response = from_gpt(prompt, plot=False)
 
1
+ from os import write
2
  import streamlit as st
3
  import numpy as np
4
  import pandas as pd
 
12
  from langchain import PromptTemplate, OpenAI, LLMChain
13
  from langchain.chains import SimpleSequentialChain
14
 
15
+ JSON_DATA_LABEL = 'json_data'
16
  llm=OpenAI(temperature=0)
17
 
18
+
19
+ if JSON_DATA_LABEL not in st.session_state:
20
+ st.session_state[JSON_DATA_LABEL] = {}
21
+
22
  def tables_from_db():
23
  db = sqlite3.connect('switrs.sqlite')
24
  cursor = db.cursor()
 
63
  prompt=PromptTemplate.from_template(prompt_template)
64
  )
65
 
66
+ def plot_chart():
67
+ json = st.session_state[JSON_DATA_LABEL]
68
+ if not json:
69
+ return "no data to plot"
70
+
71
+ return alt.Chart(json).mark_line()
72
+
73
 
74
 
75
  def from_gpt(query: str, plot: bool):
 
82
  # Run the query using the agent executor
83
  main_chain = SimpleSequentialChain(chains=chains, verbose=True)
84
 
85
+ ans = main_chain.run(query)
86
+
87
+ # Save data as json if plot
88
+ st.session_state[JSON_DATA_LABEL] = ans if plot else {}
89
+
90
+ return ans
91
 
92
  def circle_chart():
93
  df = pd.DataFrame(np.random.randn(200, 3), columns=['a', 'b', 'c'])
 
114
  table = prompt_lower.split(":")[1]
115
  on_render = st.write
116
  response = from_db(table)
117
+ elif prompt_lower.startswith('json'):
118
+ p = prompt_lower.split('json ')[1]
119
  on_render = st.write
120
  response = from_gpt(p, plot=True)
121
+ elif prompt_lower == 'plot':
122
+ on_render = st.write
123
+ response = plot_chart()
124
  else:
125
  on_render = st.write
126
  response = from_gpt(prompt, plot=False)