Spaces:
Sleeping
Sleeping
Jesus Sanchez commited on
Commit ·
7c3a8b1
1
Parent(s): 8e461e8
plot
Browse files
app.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
import numpy as np
|
| 3 |
import pandas as pd
|
|
@@ -11,8 +12,13 @@ from langchain import OpenAI
|
|
| 11 |
from langchain import PromptTemplate, OpenAI, LLMChain
|
| 12 |
from langchain.chains import SimpleSequentialChain
|
| 13 |
|
|
|
|
| 14 |
llm=OpenAI(temperature=0)
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
def tables_from_db():
|
| 17 |
db = sqlite3.connect('switrs.sqlite')
|
| 18 |
cursor = db.cursor()
|
|
@@ -57,6 +63,13 @@ def get_json_chain():
|
|
| 57 |
prompt=PromptTemplate.from_template(prompt_template)
|
| 58 |
)
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
|
| 62 |
def from_gpt(query: str, plot: bool):
|
|
@@ -69,7 +82,12 @@ def from_gpt(query: str, plot: bool):
|
|
| 69 |
# Run the query using the agent executor
|
| 70 |
main_chain = SimpleSequentialChain(chains=chains, verbose=True)
|
| 71 |
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
def circle_chart():
|
| 75 |
df = pd.DataFrame(np.random.randn(200, 3), columns=['a', 'b', 'c'])
|
|
@@ -96,10 +114,13 @@ def get_response(prompt: str, *kargs):
|
|
| 96 |
table = prompt_lower.split(":")[1]
|
| 97 |
on_render = st.write
|
| 98 |
response = from_db(table)
|
| 99 |
-
elif prompt_lower.startswith('
|
| 100 |
-
p = prompt_lower.split('
|
| 101 |
on_render = st.write
|
| 102 |
response = from_gpt(p, plot=True)
|
|
|
|
|
|
|
|
|
|
| 103 |
else:
|
| 104 |
on_render = st.write
|
| 105 |
response = from_gpt(prompt, plot=False)
|
|
|
|
| 1 |
+
from os import write
|
| 2 |
import streamlit as st
|
| 3 |
import numpy as np
|
| 4 |
import pandas as pd
|
|
|
|
| 12 |
from langchain import PromptTemplate, OpenAI, LLMChain
|
| 13 |
from langchain.chains import SimpleSequentialChain
|
| 14 |
|
| 15 |
+
JSON_DATA_LABEL = 'json_data'
|
| 16 |
llm=OpenAI(temperature=0)
|
| 17 |
|
| 18 |
+
|
| 19 |
+
if JSON_DATA_LABEL not in st.session_state:
|
| 20 |
+
st.session_state[JSON_DATA_LABEL] = {}
|
| 21 |
+
|
| 22 |
def tables_from_db():
|
| 23 |
db = sqlite3.connect('switrs.sqlite')
|
| 24 |
cursor = db.cursor()
|
|
|
|
| 63 |
prompt=PromptTemplate.from_template(prompt_template)
|
| 64 |
)
|
| 65 |
|
| 66 |
+
def plot_chart():
|
| 67 |
+
json = st.session_state[JSON_DATA_LABEL]
|
| 68 |
+
if not json:
|
| 69 |
+
return "no data to plot"
|
| 70 |
+
|
| 71 |
+
return alt.Chart(json).mark_line()
|
| 72 |
+
|
| 73 |
|
| 74 |
|
| 75 |
def from_gpt(query: str, plot: bool):
|
|
|
|
| 82 |
# Run the query using the agent executor
|
| 83 |
main_chain = SimpleSequentialChain(chains=chains, verbose=True)
|
| 84 |
|
| 85 |
+
ans = main_chain.run(query)
|
| 86 |
+
|
| 87 |
+
# Save data as json if plot
|
| 88 |
+
st.session_state[JSON_DATA_LABEL] = ans if plot else {}
|
| 89 |
+
|
| 90 |
+
return ans
|
| 91 |
|
| 92 |
def circle_chart():
|
| 93 |
df = pd.DataFrame(np.random.randn(200, 3), columns=['a', 'b', 'c'])
|
|
|
|
| 114 |
table = prompt_lower.split(":")[1]
|
| 115 |
on_render = st.write
|
| 116 |
response = from_db(table)
|
| 117 |
+
elif prompt_lower.startswith('json'):
|
| 118 |
+
p = prompt_lower.split('json ')[1]
|
| 119 |
on_render = st.write
|
| 120 |
response = from_gpt(p, plot=True)
|
| 121 |
+
elif prompt_lower == 'plot':
|
| 122 |
+
on_render = st.write
|
| 123 |
+
response = plot_chart()
|
| 124 |
else:
|
| 125 |
on_render = st.write
|
| 126 |
response = from_gpt(prompt, plot=False)
|