Jesus Sanchez commited on
Commit
e470cbb
·
1 Parent(s): ae74824
Files changed (4) hide show
  1. README.md +8 -0
  2. app.py +14 -7
  3. chat/__init__.py +0 -0
  4. chat/chat.py +0 -39
README.md CHANGED
@@ -21,3 +21,11 @@ $ python -m venv .venv
21
  $ pip install streamlit
22
  $ python3 app.py
23
  ```
 
 
 
 
 
 
 
 
 
21
  $ pip install streamlit
22
  $ python3 app.py
23
  ```
24
+
25
+ There are 4 special promts:
26
+ - `db`: returns a list of the table names
27
+ - `db:<table>`: loads the table `<table>` and returns a chart where:
28
+ - x: column `case_id`
29
+ - y: random column from table
30
+ - `line chart`: a randomly generated line chart
31
+ - `circle chart`: a randomly generated circle chart
app.py CHANGED
@@ -2,7 +2,7 @@ import streamlit as st
2
  import numpy as np
3
  import pandas as pd
4
  import altair as alt
5
- from chat import chat as idf_chat
6
  import sqlite3
7
 
8
 
@@ -17,9 +17,15 @@ def tables_from_db():
17
 
18
  def from_db(table: str):
19
  db = sqlite3.connect('switrs.sqlite')
20
- df = pd.read_sql_query(f"SELECT * FROM {table} LIMIT 100", db)
 
21
  db.close()
22
- return alt.Chart(df).mark_circle()
 
 
 
 
 
23
 
24
 
25
 
@@ -29,6 +35,7 @@ def circle_chart():
29
  x='a', y='b', size='c', color='c', tooltip=['a', 'b', 'c']
30
  )
31
 
 
32
  def get_response(prompt: str, *kargs):
33
  on_render = st.write
34
  response = f"Here's what you asked: '{prompt}'"
@@ -40,12 +47,13 @@ def get_response(prompt: str, *kargs):
40
  elif prompt_lower == 'circle chart':
41
  on_render = st.write
42
  response = circle_chart()
43
- elif prompt_lower == 'sql':
44
  on_render = st.write
45
  response = tables_from_db()
46
- elif prompt_lower == 'parties':
 
47
  on_render = st.write
48
- response = from_db(prompt_lower)
49
 
50
  return (response, on_render)
51
 
@@ -53,4 +61,3 @@ def get_response(prompt: str, *kargs):
53
 
54
  chat = idf_chat.Chat()
55
  chat.get_input(get_response)
56
- chat.display()
 
2
  import numpy as np
3
  import pandas as pd
4
  import altair as alt
5
+ import chat as idf_chat
6
  import sqlite3
7
 
8
 
 
17
 
18
  def from_db(table: str):
19
  db = sqlite3.connect('switrs.sqlite')
20
+ # Read into a panda DataFrame
21
+ df = pd.read_sql_query(f"SELECT * FROM {table} LIMIT 50", db)
22
  db.close()
23
+ columns = df.columns
24
+ # Pick a random column for y axis
25
+ column_index = np.random.randint(0, columns.size, 1)
26
+ y_column_name = columns[column_index][0]
27
+ print(columns)
28
+ return alt.Chart(df).mark_circle().encode(x='case_id', y=y_column_name)
29
 
30
 
31
 
 
35
  x='a', y='b', size='c', color='c', tooltip=['a', 'b', 'c']
36
  )
37
 
38
+ # Parse the prompt to pick an example and a render function
39
  def get_response(prompt: str, *kargs):
40
  on_render = st.write
41
  response = f"Here's what you asked: '{prompt}'"
 
47
  elif prompt_lower == 'circle chart':
48
  on_render = st.write
49
  response = circle_chart()
50
+ elif prompt_lower == 'db':
51
  on_render = st.write
52
  response = tables_from_db()
53
+ elif prompt_lower.startswith('db:'):
54
+ table = prompt_lower.split(":")[1]
55
  on_render = st.write
56
+ response = from_db(table)
57
 
58
  return (response, on_render)
59
 
 
61
 
62
  chat = idf_chat.Chat()
63
  chat.get_input(get_response)
 
chat/__init__.py DELETED
File without changes
chat/chat.py DELETED
@@ -1,39 +0,0 @@
1
- import streamlit as st
2
- from typing import Callable
3
-
4
-
5
-
6
- RESPONSE_LABEL = 'chat_response'
7
- PROMPT_LABEL = 'chat_user_input'
8
-
9
- class Chat:
10
- def __init__(self):
11
- if RESPONSE_LABEL not in st.session_state:
12
- st.session_state[RESPONSE_LABEL] = []
13
-
14
- if PROMPT_LABEL not in st.session_state:
15
- st.session_state[PROMPT_LABEL] = []
16
-
17
- def get_input(self, process_prompt: Callable, *args):
18
- """
19
- process_prompt(promt: str, *args) -> tuple(Any, Callable)
20
- callback to process the chat promt, it takes the promt for input
21
- and returns a tuple with the response and a render callback
22
- """
23
- promt = st.chat_input(placeholder="Ask Me Anything", key='chat_text_input')
24
- # promt = st.text_input(label="Ask Me Anything", key='chat_text_input')
25
-
26
- if promt:
27
- st.session_state[PROMPT_LABEL].append(promt)
28
- (response, on_render) = process_prompt(promt, *args)
29
- st.session_state[RESPONSE_LABEL].append((response, on_render))
30
-
31
-
32
- def display(self):
33
- if st.session_state[RESPONSE_LABEL]:
34
- messages = zip(st.session_state[PROMPT_LABEL], st.session_state[RESPONSE_LABEL])
35
- for prompt, (response, on_render) in messages:
36
- st.chat_message("user").write(prompt)
37
- with st.chat_message("assistant"):
38
- on_render(response)
39
-