Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -10,13 +10,9 @@ from langchain import OpenAI
|
|
| 10 |
from langchain import PromptTemplate, OpenAI, LLMChain
|
| 11 |
from langchain.chains import SimpleSequentialChain
|
| 12 |
from langchain import SQLDatabaseChain
|
| 13 |
-
from langchain.agents.agent_toolkits import create_python_agent
|
| 14 |
-
from langchain.tools.python.tool import PythonREPLTool
|
| 15 |
-
from langchain.python import PythonREPL
|
| 16 |
|
| 17 |
|
| 18 |
JSON_DATA_LABEL = 'json_data'
|
| 19 |
-
PY_CHAIN = 'py_chain'
|
| 20 |
llm=OpenAI(temperature=0)
|
| 21 |
db = SQLDatabase.from_uri("sqlite:///FXTrades.db")
|
| 22 |
|
|
@@ -24,9 +20,6 @@ db = SQLDatabase.from_uri("sqlite:///FXTrades.db")
|
|
| 24 |
if JSON_DATA_LABEL not in st.session_state:
|
| 25 |
st.session_state[JSON_DATA_LABEL] = {}
|
| 26 |
|
| 27 |
-
if PY_CHAIN not in st.session_state:
|
| 28 |
-
st.session_state[PY_CHAIN] = {}
|
| 29 |
-
|
| 30 |
def get_sql_agent():
|
| 31 |
toolkit = SQLDatabaseToolkit(llm=llm,db=db)
|
| 32 |
|
|
@@ -75,50 +68,31 @@ def get_json_chain():
|
|
| 75 |
prompt=PromptTemplate.from_template(prompt_template)
|
| 76 |
)
|
| 77 |
|
| 78 |
-
def get_python_chain():
|
| 79 |
-
return create_python_agent(
|
| 80 |
-
llm=llm,
|
| 81 |
-
tool=PythonREPLTool(),
|
| 82 |
-
verbose=True
|
| 83 |
-
)
|
| 84 |
-
|
| 85 |
def plot_chart():
|
| 86 |
json = st.session_state[JSON_DATA_LABEL]
|
| 87 |
# print(json)
|
| 88 |
if not json:
|
| 89 |
return "no data to plot"
|
|
|
|
|
|
|
| 90 |
|
| 91 |
-
|
| 92 |
-
dataframe = alt.Data(values=json)
|
| 93 |
-
|
| 94 |
-
return alt.Chart(dataframe).mark_line()
|
| 95 |
-
|
| 96 |
-
def python_exec():
|
| 97 |
-
ans = st.session_state[PY_CHAIN]
|
| 98 |
-
|
| 99 |
-
if not ans:
|
| 100 |
-
return "no data to plot"
|
| 101 |
-
return python_chain.run("Plot as a bar chart this result:{ans}",ans)
|
| 102 |
-
|
| 103 |
# sql_agent = get_sql_agent()
|
| 104 |
db_chain = get_db_chain()
|
| 105 |
json_chain = get_json_chain()
|
| 106 |
python_chain = get_python_chain()
|
| 107 |
|
| 108 |
def from_gpt(query: str, plot: bool):
|
| 109 |
-
ans=""
|
| 110 |
try:
|
| 111 |
chains = [db_chain, json_chain] if plot else [db_chain]
|
| 112 |
main_chain = SimpleSequentialChain(chains=chains, verbose=True)
|
| 113 |
ans = main_chain.run(query)
|
|
|
|
|
|
|
|
|
|
| 114 |
except Exception as e:
|
| 115 |
return "Please ask a proper business question related to selected datasets"
|
| 116 |
|
| 117 |
-
# Execute Python if plot
|
| 118 |
-
st.session_state[PY_CHAIN] = ans if plot else {}
|
| 119 |
-
|
| 120 |
-
return ans
|
| 121 |
-
|
| 122 |
def circle_chart():
|
| 123 |
df = pd.DataFrame(np.random.randn(200, 3), columns=['a', 'b', 'c'])
|
| 124 |
return alt.Chart(df).mark_circle().encode(
|
|
@@ -143,7 +117,7 @@ def get_response(prompt: str, *kargs):
|
|
| 143 |
response = from_gpt(p, plot=True)
|
| 144 |
elif prompt_lower == 'plot':
|
| 145 |
on_render = st.write
|
| 146 |
-
response =
|
| 147 |
else:
|
| 148 |
on_render = st.write
|
| 149 |
response = from_gpt(prompt, plot=False)
|
|
|
|
| 10 |
from langchain import PromptTemplate, OpenAI, LLMChain
|
| 11 |
from langchain.chains import SimpleSequentialChain
|
| 12 |
from langchain import SQLDatabaseChain
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
JSON_DATA_LABEL = 'json_data'
|
|
|
|
| 16 |
llm=OpenAI(temperature=0)
|
| 17 |
db = SQLDatabase.from_uri("sqlite:///FXTrades.db")
|
| 18 |
|
|
|
|
| 20 |
if JSON_DATA_LABEL not in st.session_state:
|
| 21 |
st.session_state[JSON_DATA_LABEL] = {}
|
| 22 |
|
|
|
|
|
|
|
|
|
|
| 23 |
def get_sql_agent():
|
| 24 |
toolkit = SQLDatabaseToolkit(llm=llm,db=db)
|
| 25 |
|
|
|
|
| 68 |
prompt=PromptTemplate.from_template(prompt_template)
|
| 69 |
)
|
| 70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
def plot_chart():
|
| 72 |
json = st.session_state[JSON_DATA_LABEL]
|
| 73 |
# print(json)
|
| 74 |
if not json:
|
| 75 |
return "no data to plot"
|
| 76 |
+
# Add Open AI call to format outcome in a table
|
| 77 |
+
return ""
|
| 78 |
|
| 79 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
# sql_agent = get_sql_agent()
|
| 81 |
db_chain = get_db_chain()
|
| 82 |
json_chain = get_json_chain()
|
| 83 |
python_chain = get_python_chain()
|
| 84 |
|
| 85 |
def from_gpt(query: str, plot: bool):
|
|
|
|
| 86 |
try:
|
| 87 |
chains = [db_chain, json_chain] if plot else [db_chain]
|
| 88 |
main_chain = SimpleSequentialChain(chains=chains, verbose=True)
|
| 89 |
ans = main_chain.run(query)
|
| 90 |
+
# Execute Python if plot
|
| 91 |
+
st.session_state[JSON_DATA_LABEL] = ans if plot else {}
|
| 92 |
+
return ans
|
| 93 |
except Exception as e:
|
| 94 |
return "Please ask a proper business question related to selected datasets"
|
| 95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
def circle_chart():
|
| 97 |
df = pd.DataFrame(np.random.randn(200, 3), columns=['a', 'b', 'c'])
|
| 98 |
return alt.Chart(df).mark_circle().encode(
|
|
|
|
| 117 |
response = from_gpt(p, plot=True)
|
| 118 |
elif prompt_lower == 'plot':
|
| 119 |
on_render = st.write
|
| 120 |
+
response = plot_chart()
|
| 121 |
else:
|
| 122 |
on_render = st.write
|
| 123 |
response = from_gpt(prompt, plot=False)
|