github-actions
committed on
Commit
·
7873f3c
1
Parent(s):
d4d8ea9
Sync updates from source repository
Browse files
- app.py +84 -5
- query.py +8 -7
- requirements.txt +2 -0
app.py
CHANGED
|
@@ -1,13 +1,59 @@
|
|
| 1 |
from omegaconf import OmegaConf
|
| 2 |
from query import VectaraQuery
|
| 3 |
import os
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
import streamlit as st
|
| 6 |
from streamlit_pills import pills
|
|
|
|
| 7 |
|
| 8 |
from PIL import Image
|
| 9 |
|
| 10 |
max_examples = 6
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
def isTrue(x) -> bool:
|
| 13 |
if isinstance(x, bool):
|
|
@@ -16,11 +62,11 @@ def isTrue(x) -> bool:
|
|
| 16 |
|
| 17 |
def launch_bot():
|
| 18 |
def generate_response(question):
|
| 19 |
-
response = vq.submit_query(question)
|
| 20 |
return response
|
| 21 |
|
| 22 |
def generate_streaming_response(question):
|
| 23 |
-
response = vq.submit_query_streaming(question)
|
| 24 |
return response
|
| 25 |
|
| 26 |
def show_example_questions():
|
|
@@ -41,11 +87,13 @@ def launch_bot():
|
|
| 41 |
'source_data_desc': os.environ['source_data_desc'],
|
| 42 |
'streaming': isTrue(os.environ.get('streaming', False)),
|
| 43 |
'prompt_name': os.environ.get('prompt_name', None),
|
| 44 |
-
'examples': os.environ.get('examples', None)
|
|
|
|
| 45 |
})
|
| 46 |
st.session_state.cfg = cfg
|
| 47 |
st.session_state.ex_prompt = None
|
| 48 |
-
st.session_state.first_turn = True
|
|
|
|
| 49 |
example_messages = [example.strip() for example in cfg.examples.split(",")]
|
| 50 |
st.session_state.example_messages = [em for em in example_messages if len(em)>0][:max_examples]
|
| 51 |
|
|
@@ -60,7 +108,13 @@ def launch_bot():
|
|
| 60 |
image = Image.open('Vectara-logo.png')
|
| 61 |
st.image(image, width=175)
|
| 62 |
st.markdown(f"## About\n\n"
|
| 63 |
-
f"This demo uses Retrieval Augmented Generation to ask questions about {cfg.source_data_desc}\n
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
st.markdown("---")
|
| 66 |
st.markdown(
|
|
@@ -111,7 +165,32 @@ def launch_bot():
|
|
| 111 |
st.write(response)
|
| 112 |
message = {"role": "assistant", "content": response}
|
| 113 |
st.session_state.messages.append(message)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
st.rerun()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
if __name__ == "__main__":
|
| 117 |
launch_bot()
|
|
|
|
| 1 |
from omegaconf import OmegaConf
|
| 2 |
from query import VectaraQuery
|
| 3 |
import os
|
| 4 |
+
import requests
|
| 5 |
+
import json
|
| 6 |
+
import uuid
|
| 7 |
|
| 8 |
import streamlit as st
|
| 9 |
from streamlit_pills import pills
|
| 10 |
+
from streamlit_feedback import streamlit_feedback
|
| 11 |
|
| 12 |
from PIL import Image
|
| 13 |
|
| 14 |
max_examples = 6
|
| 15 |
+
# Display name -> three-letter ISO 639-3 language code passed to the query
# API as the response language. French is 'fra' ('frs' is Eastern Frisian).
languages = {
    'English': 'eng', 'Spanish': 'spa', 'French': 'fra', 'Chinese': 'zho',
    'German': 'deu', 'Hindi': 'hin', 'Arabic': 'ara',
    'Portuguese': 'por', 'Italian': 'ita', 'Japanese': 'jpn', 'Korean': 'kor',
    'Russian': 'rus', 'Turkish': 'tur', 'Persian (Farsi)': 'fas',
    'Vietnamese': 'vie', 'Thai': 'tha', 'Hebrew': 'heb', 'Dutch': 'nld',
    'Indonesian': 'ind', 'Polish': 'pol', 'Ukrainian': 'ukr',
    'Romanian': 'ron', 'Swedish': 'swe', 'Czech': 'ces', 'Greek': 'ell',
    'Bengali': 'ben', 'Malay (or Malaysian)': 'msa', 'Urdu': 'urd',
}
|
| 19 |
+
|
| 20 |
+
# Setup for HTTP API Calls to Amplitude Analytics
|
| 21 |
+
if 'device_id' not in st.session_state:
|
| 22 |
+
st.session_state.device_id = str(uuid.uuid4())
|
| 23 |
+
|
| 24 |
+
headers = {
|
| 25 |
+
'Content-Type': 'application/json',
|
| 26 |
+
'Accept': '*/*'
|
| 27 |
+
}
|
| 28 |
+
amp_api_key = os.getenv('AMPLITUDE_TOKEN')
|
| 29 |
+
|
| 30 |
+
def thumbs_feedback(feedback, **kwargs):
    """
    Sends feedback to Amplitude Analytics.

    Args:
        feedback: Mapping supplied by the feedback widget; its "score" entry
            is forwarded as the "feedback" event property.
        **kwargs: Optional "title", "prompt" and "response" values describing
            the rated exchange. Placeholder strings are sent for any that
            are missing.

    Reporting is best-effort: a failed or timed-out request is printed and
    otherwise ignored, so an analytics outage does not break the chat UI.
    """
    data = {
        "api_key": amp_api_key,
        "events": [{
            "device_id": st.session_state.device_id,
            "event_type": "provided_feedback",
            "event_properties": {
                "Space Name": kwargs.get("title", "Unknown Space Name"),
                "Demo Type": "chatbot",
                "query": kwargs.get("prompt", "No user input"),
                "response": kwargs.get("response", "No chat response"),
                "feedback": feedback["score"],
                "Response Language": st.session_state.language
            }
        }]
    }
    try:
        # Without a timeout, an unresponsive endpoint would block this
        # callback (and therefore the app) indefinitely.
        response = requests.post('https://api2.amplitude.com/2/httpapi', headers=headers,
                                 data=json.dumps(data), timeout=10)
    except requests.RequestException as err:
        print(f"Request failed: {err}")
    else:
        if response.status_code != 200:
            print(f"Request failed with status code {response.status_code}. Response Text: {response.text}")

    # Always advance the key, even when reporting failed.
    # NOTE(review): presumably this gives the feedback widget a fresh key for
    # the next response -- confirm against the widget's `key=` argument.
    st.session_state.feedback_key += 1
|
| 54 |
+
|
| 55 |
+
if "feedback_key" not in st.session_state:
|
| 56 |
+
st.session_state.feedback_key = 0
|
| 57 |
|
| 58 |
def isTrue(x) -> bool:
|
| 59 |
if isinstance(x, bool):
|
|
|
|
| 62 |
|
| 63 |
def launch_bot():
|
| 64 |
def generate_response(question):
|
| 65 |
+
response = vq.submit_query(question, languages[st.session_state.language])
|
| 66 |
return response
|
| 67 |
|
| 68 |
def generate_streaming_response(question):
|
| 69 |
+
response = vq.submit_query_streaming(question, languages[st.session_state.language])
|
| 70 |
return response
|
| 71 |
|
| 72 |
def show_example_questions():
|
|
|
|
| 87 |
'source_data_desc': os.environ['source_data_desc'],
|
| 88 |
'streaming': isTrue(os.environ.get('streaming', False)),
|
| 89 |
'prompt_name': os.environ.get('prompt_name', None),
|
| 90 |
+
'examples': os.environ.get('examples', None),
|
| 91 |
+
'language': 'English'
|
| 92 |
})
|
| 93 |
st.session_state.cfg = cfg
|
| 94 |
st.session_state.ex_prompt = None
|
| 95 |
+
st.session_state.first_turn = True
|
| 96 |
+
st.session_state.language = cfg.language
|
| 97 |
example_messages = [example.strip() for example in cfg.examples.split(",")]
|
| 98 |
st.session_state.example_messages = [em for em in example_messages if len(em)>0][:max_examples]
|
| 99 |
|
|
|
|
| 108 |
image = Image.open('Vectara-logo.png')
|
| 109 |
st.image(image, width=175)
|
| 110 |
st.markdown(f"## About\n\n"
|
| 111 |
+
f"This demo uses Retrieval Augmented Generation to ask questions about {cfg.source_data_desc}\n")
|
| 112 |
+
|
| 113 |
+
cfg.language = st.selectbox('Language:', languages.keys())
|
| 114 |
+
if st.session_state.language != cfg.language:
|
| 115 |
+
st.session_state.language = cfg.language
|
| 116 |
+
print(f"DEBUG: Language changed to {st.session_state.language}")
|
| 117 |
+
st.rerun()
|
| 118 |
|
| 119 |
st.markdown("---")
|
| 120 |
st.markdown(
|
|
|
|
| 165 |
st.write(response)
|
| 166 |
message = {"role": "assistant", "content": response}
|
| 167 |
st.session_state.messages.append(message)
|
| 168 |
+
|
| 169 |
+
# Send query and response to Amplitude Analytics
|
| 170 |
+
data = {
|
| 171 |
+
"api_key": amp_api_key,
|
| 172 |
+
"events": [{
|
| 173 |
+
"device_id": st.session_state.device_id,
|
| 174 |
+
"event_type": "submitted_query",
|
| 175 |
+
"event_properties": {
|
| 176 |
+
"Space Name": cfg["title"],
|
| 177 |
+
"Demo Type": "chatbot",
|
| 178 |
+
"query": st.session_state.messages[-2]["content"],
|
| 179 |
+
"response": st.session_state.messages[-1]["content"],
|
| 180 |
+
"Response Language": st.session_state.language
|
| 181 |
+
}
|
| 182 |
+
}]
|
| 183 |
+
}
|
| 184 |
+
response = requests.post('https://api2.amplitude.com/2/httpapi', headers=headers, data=json.dumps(data))
|
| 185 |
+
if response.status_code != 200:
|
| 186 |
+
print(f"Amplitude request failed with status code {response.status_code}. Response Text: {response.text}")
|
| 187 |
st.rerun()
|
| 188 |
+
|
| 189 |
+
if (st.session_state.messages[-1]["role"] == "assistant") & (st.session_state.messages[-1]["content"] != "How may I help you?"):
|
| 190 |
+
streamlit_feedback(feedback_type="thumbs", on_submit = thumbs_feedback, key = st.session_state.feedback_key,
|
| 191 |
+
kwargs = {"prompt": st.session_state.messages[-2]["content"],
|
| 192 |
+
"response": st.session_state.messages[-1]["content"],
|
| 193 |
+
"title": cfg["title"]})
|
| 194 |
|
| 195 |
if __name__ == "__main__":
|
| 196 |
launch_bot()
|
query.py
CHANGED
|
@@ -10,7 +10,7 @@ class VectaraQuery():
|
|
| 10 |
self.conv_id = None
|
| 11 |
|
| 12 |
|
| 13 |
-
def get_body(self, query_str: str, stream: False):
|
| 14 |
corpora_list = [{
|
| 15 |
'corpus_key': corpus_key, 'lexical_interpolation': 0.005
|
| 16 |
} for corpus_key in self.corpus_keys
|
|
@@ -40,11 +40,12 @@ class VectaraQuery():
|
|
| 40 |
{
|
| 41 |
'prompt_name': self.prompt_name,
|
| 42 |
'max_used_search_results': 10,
|
| 43 |
-
'response_language':
|
| 44 |
'citations':
|
| 45 |
{
|
| 46 |
'style': 'none'
|
| 47 |
-
}
|
|
|
|
| 48 |
},
|
| 49 |
'chat':
|
| 50 |
{
|
|
@@ -70,14 +71,14 @@ class VectaraQuery():
|
|
| 70 |
"grpc-timeout": "60S"
|
| 71 |
}
|
| 72 |
|
| 73 |
-
def submit_query(self, query_str: str):
|
| 74 |
|
| 75 |
if self.conv_id:
|
| 76 |
endpoint = f"https://api.vectara.io/v2/chats/{self.conv_id}/turns"
|
| 77 |
else:
|
| 78 |
endpoint = "https://api.vectara.io/v2/chats"
|
| 79 |
|
| 80 |
-
body = self.get_body(query_str, stream=False)
|
| 81 |
|
| 82 |
response = requests.post(endpoint, data=json.dumps(body), verify=True, headers=self.get_headers())
|
| 83 |
|
|
@@ -96,14 +97,14 @@ class VectaraQuery():
|
|
| 96 |
|
| 97 |
return summary
|
| 98 |
|
| 99 |
-
def submit_query_streaming(self, query_str: str):
|
| 100 |
|
| 101 |
if self.conv_id:
|
| 102 |
endpoint = f"https://api.vectara.io/v2/chats/{self.conv_id}/turns"
|
| 103 |
else:
|
| 104 |
endpoint = "https://api.vectara.io/v2/chats"
|
| 105 |
|
| 106 |
-
body = self.get_body(query_str, stream=True)
|
| 107 |
|
| 108 |
response = requests.post(endpoint, data=json.dumps(body), verify=True, headers=self.get_stream_headers(), stream=True)
|
| 109 |
|
|
|
|
| 10 |
self.conv_id = None
|
| 11 |
|
| 12 |
|
| 13 |
+
def get_body(self, query_str: str, response_lang: str, stream: False):
|
| 14 |
corpora_list = [{
|
| 15 |
'corpus_key': corpus_key, 'lexical_interpolation': 0.005
|
| 16 |
} for corpus_key in self.corpus_keys
|
|
|
|
| 40 |
{
|
| 41 |
'prompt_name': self.prompt_name,
|
| 42 |
'max_used_search_results': 10,
|
| 43 |
+
'response_language': response_lang,
|
| 44 |
'citations':
|
| 45 |
{
|
| 46 |
'style': 'none'
|
| 47 |
+
},
|
| 48 |
+
'enable_factual_consistency_score': False
|
| 49 |
},
|
| 50 |
'chat':
|
| 51 |
{
|
|
|
|
| 71 |
"grpc-timeout": "60S"
|
| 72 |
}
|
| 73 |
|
| 74 |
+
def submit_query(self, query_str: str, language: str):
|
| 75 |
|
| 76 |
if self.conv_id:
|
| 77 |
endpoint = f"https://api.vectara.io/v2/chats/{self.conv_id}/turns"
|
| 78 |
else:
|
| 79 |
endpoint = "https://api.vectara.io/v2/chats"
|
| 80 |
|
| 81 |
+
body = self.get_body(query_str, language, stream=False)
|
| 82 |
|
| 83 |
response = requests.post(endpoint, data=json.dumps(body), verify=True, headers=self.get_headers())
|
| 84 |
|
|
|
|
| 97 |
|
| 98 |
return summary
|
| 99 |
|
| 100 |
+
def submit_query_streaming(self, query_str: str, language: str):
|
| 101 |
|
| 102 |
if self.conv_id:
|
| 103 |
endpoint = f"https://api.vectara.io/v2/chats/{self.conv_id}/turns"
|
| 104 |
else:
|
| 105 |
endpoint = "https://api.vectara.io/v2/chats"
|
| 106 |
|
| 107 |
+
body = self.get_body(query_str, language, stream=True)
|
| 108 |
|
| 109 |
response = requests.post(endpoint, data=json.dumps(body), verify=True, headers=self.get_stream_headers(), stream=True)
|
| 110 |
|
requirements.txt
CHANGED
|
@@ -3,3 +3,5 @@ toml==0.10.2
|
|
| 3 |
omegaconf==2.3.0
|
| 4 |
syrupy==4.0.8
|
| 5 |
streamlit_pills==0.3.0
|
|
|
|
|
|
|
|
|
| 3 |
omegaconf==2.3.0
|
| 4 |
syrupy==4.0.8
|
| 5 |
streamlit_pills==0.3.0
|
| 6 |
+
streamlit-feedback==0.1.3
|
| 7 |
+
uuid==1.30
|