Spaces:
Sleeping
Sleeping
Jesus Sanchez commited on
Commit ·
e470cbb
1
Parent(s): ae74824
cleanup
Browse files- README.md +8 -0
- app.py +14 -7
- chat/__init__.py +0 -0
- chat/chat.py +0 -39
README.md
CHANGED
|
@@ -21,3 +21,11 @@ $ python -m venv .venv
|
|
| 21 |
$ pip install streamlit
|
| 22 |
$ python3 app.py
|
| 23 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
$ pip install streamlit
|
| 22 |
$ python3 app.py
|
| 23 |
```
|
| 24 |
+
|
| 25 |
+
There are 4 special promts:
|
| 26 |
+
- `db`: returns a list of the table names
|
| 27 |
+
- `db:<table>`: loads the table `<table>` and returns a chart where:
|
| 28 |
+
- x: column `case_id`
|
| 29 |
+
- y: random column from table
|
| 30 |
+
- `line chart`: a randomly generated line chart
|
| 31 |
+
- `circle chart`: a randomly generated circle chart
|
app.py
CHANGED
|
@@ -2,7 +2,7 @@ import streamlit as st
|
|
| 2 |
import numpy as np
|
| 3 |
import pandas as pd
|
| 4 |
import altair as alt
|
| 5 |
-
|
| 6 |
import sqlite3
|
| 7 |
|
| 8 |
|
|
@@ -17,9 +17,15 @@ def tables_from_db():
|
|
| 17 |
|
| 18 |
def from_db(table: str):
|
| 19 |
db = sqlite3.connect('switrs.sqlite')
|
| 20 |
-
|
|
|
|
| 21 |
db.close()
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
|
|
@@ -29,6 +35,7 @@ def circle_chart():
|
|
| 29 |
x='a', y='b', size='c', color='c', tooltip=['a', 'b', 'c']
|
| 30 |
)
|
| 31 |
|
|
|
|
| 32 |
def get_response(prompt: str, *kargs):
|
| 33 |
on_render = st.write
|
| 34 |
response = f"Here's what you asked: '{prompt}'"
|
|
@@ -40,12 +47,13 @@ def get_response(prompt: str, *kargs):
|
|
| 40 |
elif prompt_lower == 'circle chart':
|
| 41 |
on_render = st.write
|
| 42 |
response = circle_chart()
|
| 43 |
-
elif prompt_lower == '
|
| 44 |
on_render = st.write
|
| 45 |
response = tables_from_db()
|
| 46 |
-
elif prompt_lower
|
|
|
|
| 47 |
on_render = st.write
|
| 48 |
-
response = from_db(
|
| 49 |
|
| 50 |
return (response, on_render)
|
| 51 |
|
|
@@ -53,4 +61,3 @@ def get_response(prompt: str, *kargs):
|
|
| 53 |
|
| 54 |
chat = idf_chat.Chat()
|
| 55 |
chat.get_input(get_response)
|
| 56 |
-
chat.display()
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
import pandas as pd
|
| 4 |
import altair as alt
|
| 5 |
+
import chat as idf_chat
|
| 6 |
import sqlite3
|
| 7 |
|
| 8 |
|
|
|
|
| 17 |
|
| 18 |
def from_db(table: str):
|
| 19 |
db = sqlite3.connect('switrs.sqlite')
|
| 20 |
+
# Read into a panda DataFrame
|
| 21 |
+
df = pd.read_sql_query(f"SELECT * FROM {table} LIMIT 50", db)
|
| 22 |
db.close()
|
| 23 |
+
columns = df.columns
|
| 24 |
+
# Pick a random column for y axis
|
| 25 |
+
column_index = np.random.randint(0, columns.size, 1)
|
| 26 |
+
y_column_name = columns[column_index][0]
|
| 27 |
+
print(columns)
|
| 28 |
+
return alt.Chart(df).mark_circle().encode(x='case_id', y=y_column_name)
|
| 29 |
|
| 30 |
|
| 31 |
|
|
|
|
| 35 |
x='a', y='b', size='c', color='c', tooltip=['a', 'b', 'c']
|
| 36 |
)
|
| 37 |
|
| 38 |
+
# Parse the prompt to pick an example and a render function
|
| 39 |
def get_response(prompt: str, *kargs):
|
| 40 |
on_render = st.write
|
| 41 |
response = f"Here's what you asked: '{prompt}'"
|
|
|
|
| 47 |
elif prompt_lower == 'circle chart':
|
| 48 |
on_render = st.write
|
| 49 |
response = circle_chart()
|
| 50 |
+
elif prompt_lower == 'db':
|
| 51 |
on_render = st.write
|
| 52 |
response = tables_from_db()
|
| 53 |
+
elif prompt_lower.startswith('db:'):
|
| 54 |
+
table = prompt_lower.split(":")[1]
|
| 55 |
on_render = st.write
|
| 56 |
+
response = from_db(table)
|
| 57 |
|
| 58 |
return (response, on_render)
|
| 59 |
|
|
|
|
| 61 |
|
| 62 |
chat = idf_chat.Chat()
|
| 63 |
chat.get_input(get_response)
|
|
|
chat/__init__.py
DELETED
|
File without changes
|
chat/chat.py
DELETED
|
@@ -1,39 +0,0 @@
|
|
| 1 |
-
import streamlit as st
|
| 2 |
-
from typing import Callable
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
RESPONSE_LABEL = 'chat_response'
|
| 7 |
-
PROMPT_LABEL = 'chat_user_input'
|
| 8 |
-
|
| 9 |
-
class Chat:
|
| 10 |
-
def __init__(self):
|
| 11 |
-
if RESPONSE_LABEL not in st.session_state:
|
| 12 |
-
st.session_state[RESPONSE_LABEL] = []
|
| 13 |
-
|
| 14 |
-
if PROMPT_LABEL not in st.session_state:
|
| 15 |
-
st.session_state[PROMPT_LABEL] = []
|
| 16 |
-
|
| 17 |
-
def get_input(self, process_prompt: Callable, *args):
|
| 18 |
-
"""
|
| 19 |
-
process_prompt(promt: str, *args) -> tuple(Any, Callable)
|
| 20 |
-
callback to process the chat promt, it takes the promt for input
|
| 21 |
-
and returns a tuple with the response and a render callback
|
| 22 |
-
"""
|
| 23 |
-
promt = st.chat_input(placeholder="Ask Me Anything", key='chat_text_input')
|
| 24 |
-
# promt = st.text_input(label="Ask Me Anything", key='chat_text_input')
|
| 25 |
-
|
| 26 |
-
if promt:
|
| 27 |
-
st.session_state[PROMPT_LABEL].append(promt)
|
| 28 |
-
(response, on_render) = process_prompt(promt, *args)
|
| 29 |
-
st.session_state[RESPONSE_LABEL].append((response, on_render))
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
def display(self):
|
| 33 |
-
if st.session_state[RESPONSE_LABEL]:
|
| 34 |
-
messages = zip(st.session_state[PROMPT_LABEL], st.session_state[RESPONSE_LABEL])
|
| 35 |
-
for prompt, (response, on_render) in messages:
|
| 36 |
-
st.chat_message("user").write(prompt)
|
| 37 |
-
with st.chat_message("assistant"):
|
| 38 |
-
on_render(response)
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|