Jesus Sanchez commited on
Commit
d72f607
·
1 Parent(s): 485e943

merge conlict

Browse files
Files changed (1) hide show
  1. app.py +24 -7
app.py CHANGED
@@ -10,6 +10,11 @@ from langchain.agents import create_sql_agent
10
  from langchain import OpenAI
11
  from langchain import PromptTemplate, OpenAI, LLMChain
12
 
 
 
 
 
 
13
 
14
  def tables_from_db():
15
  db = sqlite3.connect('switrs.sqlite')
@@ -32,6 +37,10 @@ def from_db(table: str):
32
  print(columns)
33
  return alt.Chart(df).mark_circle().encode(x='case_id', y=y_column_name)
34
 
 
 
 
 
35
  def from_gpt(query: str):
36
  # Create the agent executor
37
  db = SQLDatabase.from_uri("sqlite:///switrs.sqlite")
@@ -49,7 +58,7 @@ def from_gpt(query: str):
49
  )
50
 
51
  # Run the query using the agent executor
52
- return agent_executor.run(query)
53
 
54
  def parse_into_json(result: str):
55
  prompt_template = "Reformat this {result} in JSON format"
@@ -62,13 +71,19 @@ def parse_into_json(result: str):
62
 
63
  Final_Results['text']
64
 
65
- def parse_into_chart(response: str):
66
- if ":" not in response:
67
- return response
68
 
69
- data = response.split(":")[1]
70
- return alt.Chart(data).mark_circle()
71
 
 
 
 
 
 
 
 
 
72
 
73
 
74
  def circle_chart():
@@ -96,10 +111,12 @@ def get_response(prompt: str, *kargs):
96
  table = prompt_lower.split(":")[1]
97
  on_render = st.write
98
  response = from_db(table)
 
 
 
99
  else:
100
  on_render = st.write
101
  response = from_gpt(prompt)
102
- # response = parse_into_chart(response)
103
 
104
  return (response, on_render)
105
 
 
10
  from langchain import OpenAI
11
  from langchain import PromptTemplate, OpenAI, LLMChain
12
 
13
+ JSON_DATA_LABEL = 'json_data'
14
+
15
+ if JSON_DATA_LABEL not in st.session_state:
16
+ st.session_state[JSON_DATA_LABEL] = {}
17
+
18
 
19
  def tables_from_db():
20
  db = sqlite3.connect('switrs.sqlite')
 
37
  print(columns)
38
  return alt.Chart(df).mark_circle().encode(x='case_id', y=y_column_name)
39
 
40
+ def parse_into_json(ans: str) -> (str, str):
41
+ # json parsing goes here
42
+ return ans
43
+
44
  def from_gpt(query: str):
45
  # Create the agent executor
46
  db = SQLDatabase.from_uri("sqlite:///switrs.sqlite")
 
58
  )
59
 
60
  # Run the query using the agent executor
61
+ ans = agent_executor.run(query)
62
 
63
  def parse_into_json(result: str):
64
  prompt_template = "Reformat this {result} in JSON format"
 
71
 
72
  Final_Results['text']
73
 
74
+ # Extract json data from response
75
+ (response, json_data) = parse_into_json(ans)
 
76
 
77
+ st.session_state[JSON_DATA_LABEL] = json_data
 
78
 
79
+ return response
80
+
81
+ def plot():
82
+ """
83
+ Tries to plot the last generated answer
84
+ """
85
+ json = st.session_state[JSON_DATA_LABEL]
86
+ return {} if not json else alt.Chart(json).mark_circle()
87
 
88
 
89
  def circle_chart():
 
111
  table = prompt_lower.split(":")[1]
112
  on_render = st.write
113
  response = from_db(table)
114
+ elif prompt_lower == 'plot':
115
+ on_render = st.write
116
+ response = plot()
117
  else:
118
  on_render = st.write
119
  response = from_gpt(prompt)
 
120
 
121
  return (response, on_render)
122