Spaces:
Sleeping
Sleeping
Jesus Sanchez commited on
Commit ·
d72f607
1
Parent(s): 485e943
merge conlict
Browse files
app.py
CHANGED
|
@@ -10,6 +10,11 @@ from langchain.agents import create_sql_agent
|
|
| 10 |
from langchain import OpenAI
|
| 11 |
from langchain import PromptTemplate, OpenAI, LLMChain
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
def tables_from_db():
|
| 15 |
db = sqlite3.connect('switrs.sqlite')
|
|
@@ -32,6 +37,10 @@ def from_db(table: str):
|
|
| 32 |
print(columns)
|
| 33 |
return alt.Chart(df).mark_circle().encode(x='case_id', y=y_column_name)
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
def from_gpt(query: str):
|
| 36 |
# Create the agent executor
|
| 37 |
db = SQLDatabase.from_uri("sqlite:///switrs.sqlite")
|
|
@@ -49,7 +58,7 @@ def from_gpt(query: str):
|
|
| 49 |
)
|
| 50 |
|
| 51 |
# Run the query using the agent executor
|
| 52 |
-
|
| 53 |
|
| 54 |
def parse_into_json(result: str):
|
| 55 |
prompt_template = "Reformat this {result} in JSON format"
|
|
@@ -62,13 +71,19 @@ def parse_into_json(result: str):
|
|
| 62 |
|
| 63 |
Final_Results['text']
|
| 64 |
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
return response
|
| 68 |
|
| 69 |
-
|
| 70 |
-
return alt.Chart(data).mark_circle()
|
| 71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
|
| 74 |
def circle_chart():
|
|
@@ -96,10 +111,12 @@ def get_response(prompt: str, *kargs):
|
|
| 96 |
table = prompt_lower.split(":")[1]
|
| 97 |
on_render = st.write
|
| 98 |
response = from_db(table)
|
|
|
|
|
|
|
|
|
|
| 99 |
else:
|
| 100 |
on_render = st.write
|
| 101 |
response = from_gpt(prompt)
|
| 102 |
-
# response = parse_into_chart(response)
|
| 103 |
|
| 104 |
return (response, on_render)
|
| 105 |
|
|
|
|
| 10 |
from langchain import OpenAI
|
| 11 |
from langchain import PromptTemplate, OpenAI, LLMChain
|
| 12 |
|
| 13 |
+
JSON_DATA_LABEL = 'json_data'
|
| 14 |
+
|
| 15 |
+
if JSON_DATA_LABEL not in st.session_state:
|
| 16 |
+
st.session_state[JSON_DATA_LABEL] = {}
|
| 17 |
+
|
| 18 |
|
| 19 |
def tables_from_db():
|
| 20 |
db = sqlite3.connect('switrs.sqlite')
|
|
|
|
| 37 |
print(columns)
|
| 38 |
return alt.Chart(df).mark_circle().encode(x='case_id', y=y_column_name)
|
| 39 |
|
| 40 |
+
def parse_into_json(ans: str) -> (str, str):
|
| 41 |
+
# json parsing goes here
|
| 42 |
+
return ans
|
| 43 |
+
|
| 44 |
def from_gpt(query: str):
|
| 45 |
# Create the agent executor
|
| 46 |
db = SQLDatabase.from_uri("sqlite:///switrs.sqlite")
|
|
|
|
| 58 |
)
|
| 59 |
|
| 60 |
# Run the query using the agent executor
|
| 61 |
+
ans = agent_executor.run(query)
|
| 62 |
|
| 63 |
def parse_into_json(result: str):
|
| 64 |
prompt_template = "Reformat this {result} in JSON format"
|
|
|
|
| 71 |
|
| 72 |
Final_Results['text']
|
| 73 |
|
| 74 |
+
# Extract json data from response
|
| 75 |
+
(response, json_data) = parse_into_json(ans)
|
|
|
|
| 76 |
|
| 77 |
+
st.session_state[JSON_DATA_LABEL] = json_data
|
|
|
|
| 78 |
|
| 79 |
+
return response
|
| 80 |
+
|
| 81 |
+
def plot():
|
| 82 |
+
"""
|
| 83 |
+
Tries to plot the last generated answer
|
| 84 |
+
"""
|
| 85 |
+
json = st.session_state[JSON_DATA_LABEL]
|
| 86 |
+
return {} if not json else alt.Chart(json).mark_circle()
|
| 87 |
|
| 88 |
|
| 89 |
def circle_chart():
|
|
|
|
| 111 |
table = prompt_lower.split(":")[1]
|
| 112 |
on_render = st.write
|
| 113 |
response = from_db(table)
|
| 114 |
+
elif prompt_lower == 'plot':
|
| 115 |
+
on_render = st.write
|
| 116 |
+
response = plot()
|
| 117 |
else:
|
| 118 |
on_render = st.write
|
| 119 |
response = from_gpt(prompt)
|
|
|
|
| 120 |
|
| 121 |
return (response, on_render)
|
| 122 |
|