Update app.py
Browse files
app.py
CHANGED
|
@@ -1,203 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
from openai import OpenAI
|
| 3 |
from pymongo.mongo_client import MongoClient
|
| 4 |
from pymongo.server_api import ServerApi
|
| 5 |
-
from datetime import datetime
|
| 6 |
-
import random
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
st.set_page_config(
|
| 9 |
page_title="Bot",
|
| 10 |
page_icon="🤖",
|
| 11 |
initial_sidebar_state="collapsed",
|
| 12 |
layout="wide",
|
| 13 |
menu_items={
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
}
|
| 17 |
)
|
|
|
|
| 18 |
st.markdown(
|
| 19 |
-
"""
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
|
|
|
| 26 |
)
|
| 27 |
|
| 28 |
-
### Setting up the session state
|
| 29 |
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
st.session_state.messages = []
|
| 34 |
-
st.session_state.
|
| 35 |
-
st.session_state.client = ""
|
| 36 |
-
|
| 37 |
-
st.session_state.user_data = {}
|
| 38 |
-
st.session_state.user_data["BASE_URL"] = ""
|
| 39 |
-
st.session_state.user_data["MODEL_PATH"] = ""
|
| 40 |
-
st.session_state.user_data["url_id"] = True
|
| 41 |
-
st.session_state.user_data["user_id"] = str(random.randint(100000, 999999))
|
| 42 |
-
st.session_state.user_data["start_time"] = datetime.now()
|
| 43 |
-
st.session_state.inserted = 1
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
st.session_state.convo_start_time = ''
|
| 47 |
st.session_state.setup = False
|
|
|
|
| 48 |
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
def setup_messages():
|
| 59 |
-
# 1 = true control,
|
| 60 |
-
# 2 = base,
|
| 61 |
-
# 3 = bridging,
|
| 62 |
-
# 4 = GPT
|
| 63 |
-
|
| 64 |
-
if st.query_params["p"] == "1":
|
| 65 |
-
pass
|
| 66 |
-
|
| 67 |
-
elif st.query_params["p"] == "2":
|
| 68 |
-
st.session_state.user_data["BASE_URL"] = "https://openrouter.ai/api/v1"
|
| 69 |
-
st.session_state.user_data["MODEL_PATH"] = "meta-llama/llama-3.3-70b-instruct"
|
| 70 |
-
st.session_state.API = "OPENROUTER_API_KEY"
|
| 71 |
-
|
| 72 |
-
elif st.query_params["p"] == "3":
|
| 73 |
-
st.session_state.user_data["BASE_URL"] = "https://tinker.thinkingmachines.dev/services/tinker-prod/oai/api/v1"
|
| 74 |
-
st.session_state.user_data["MODEL_PATH"] = "tinker://808e4f02-e847-54ae-bc75-f14ee885ce5a:train:0/sampler_weights/final_sampler"
|
| 75 |
-
st.session_state.API = "TINKER_API_KEY"
|
| 76 |
-
|
| 77 |
-
elif st.query_params["p"] == "4":
|
| 78 |
-
st.session_state.user_data["BASE_URL"] = "https://openrouter.ai/api/v1"
|
| 79 |
-
st.session_state.user_data["MODEL_PATH"] = "openai/gpt-5.4"
|
| 80 |
-
st.session_state.API = "OPENROUTER_API_KEY"
|
| 81 |
|
| 82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
st.session_state.convo_start_time = datetime.now()
|
| 84 |
-
st.session_state.user_data["random_pid"] = st.query_params['id']
|
| 85 |
-
st.session_state.user_data["condition"] = st.query_params['p']
|
| 86 |
|
| 87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
st.session_state.setup = True
|
| 89 |
|
| 90 |
-
if st.session_state.setup == False:
|
| 91 |
-
setup_messages()
|
| 92 |
-
|
| 93 |
-
if len(st.session_state.messages) == 1 and st.session_state.inserted < 2:
|
| 94 |
-
st.success("Ask, request, or talk to the chatbot about something you consider **politically polarizing** or something that people from different US political parties might disagree about.", icon='🎯')
|
| 95 |
|
| 96 |
-
#
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
|
|
|
|
|
|
| 100 |
|
| 101 |
-
st.success("Ask, request, or talk to the chatbot about something you consider **politically polarizing** or something that people from different US political parties might disagree about.", icon='🎯')
|
| 102 |
|
| 103 |
-
|
|
|
|
| 104 |
|
| 105 |
-
st.markdown(f"""# **Step 2. Use the *Submit Interaction* button to get your chatbot word**
|
| 106 |
-
|
| 107 |
-
⚠️ You must respond **at least 5 times** before you will see a *Submit Interaction* button. You can continue before submitting, but **you must Submit Interaction and enter your chatbot word to proceed with the survey**.
|
| 108 |
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
with st.chat_message(message["role"]):
|
| 115 |
-
st.markdown(message["content"])
|
| 116 |
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
)
|
| 121 |
|
| 122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
st.markdown("## Copy your WORD!")
|
| 124 |
-
st.markdown(
|
| 125 |
-
st.markdown(
|
| 126 |
-
st.markdown(
|
| 127 |
-
|
| 128 |
-
elif prompt := st.chat_input("Type to ask a question or respond..."):
|
| 129 |
-
if len(st.session_state.messages) == 1:
|
| 130 |
-
clean_prompt = prompt.strip().lower()
|
| 131 |
-
clean_prompt = clean_prompt.strip("!,.?")
|
| 132 |
-
greetings = {
|
| 133 |
-
"hi", "hello", "hey", "heya", "hiya", "yo", "howdy", "sup",
|
| 134 |
-
"good morning", "good afternoon", "good evening", "good day",
|
| 135 |
-
"what's up", "whats up", "how do you do", "greetings",
|
| 136 |
-
"salutations", "hi there", "hello there", "hey there",
|
| 137 |
-
"how's it going", "hows it going", "how are you", "how are you doing"
|
| 138 |
-
}
|
| 139 |
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
st.session_state.messages.append({"role": "user", "content": prompt})
|
|
|
|
| 148 |
with st.chat_message("user"):
|
| 149 |
st.markdown(prompt)
|
| 150 |
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
st.session_state.max_messages = len(st.session_state.messages)
|
| 155 |
st.rerun()
|
| 156 |
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
completion = st.session_state.client.chat.completions.create(
|
| 162 |
-
model=st.session_state.user_data["MODEL_PATH"],
|
| 163 |
-
messages=[
|
| 164 |
-
{"role": m["role"], "content": m["content"]}
|
| 165 |
-
for m in st.session_state.messages
|
| 166 |
-
],
|
| 167 |
-
stream=False,
|
| 168 |
-
)
|
| 169 |
-
|
| 170 |
-
response = completion.choices[0].message.content or ""
|
| 171 |
-
st.markdown(response)
|
| 172 |
-
|
| 173 |
-
st.session_state.messages.append({"role": "assistant", "content": response})
|
| 174 |
-
|
| 175 |
-
except:
|
| 176 |
-
rate_limit_message = "An error has occurred or you've reached the maximum conversation length. Please submit the conversation."
|
| 177 |
-
st.session_state.messages.append(
|
| 178 |
-
{"role": "assistant", "content": rate_limit_message}
|
| 179 |
-
)
|
| 180 |
-
st.session_state.max_messages = len(st.session_state.messages)
|
| 181 |
-
st.rerun()
|
| 182 |
-
|
| 183 |
-
if len(st.session_state.messages) > 10 or st.session_state.max_messages == len(st.session_state.messages):
|
| 184 |
-
columns = st.columns((1,1,1))
|
| 185 |
-
with columns[2]:
|
| 186 |
-
if st.button("Submit Interaction",use_container_width=True):
|
| 187 |
-
keys = ["inserted", "messages", "convo_start_time"]
|
| 188 |
|
| 189 |
-
st.
|
| 190 |
-
|
| 191 |
-
st.session_state.user_data["convo_end_time"] = datetime.now()
|
| 192 |
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
|
| 198 |
-
collection.insert_one(user_data)
|
| 199 |
-
st.session_state.inserted += 1
|
| 200 |
-
done = True
|
| 201 |
|
| 202 |
-
|
| 203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from datetime import datetime
|
| 3 |
+
|
| 4 |
import streamlit as st
|
| 5 |
from openai import OpenAI
|
| 6 |
from pymongo.mongo_client import MongoClient
|
| 7 |
from pymongo.server_api import ServerApi
|
|
|
|
|
|
|
| 8 |
|
| 9 |
+
|
| 10 |
+
# -----------------------------
|
| 11 |
+
# Page config
|
| 12 |
+
# -----------------------------
|
| 13 |
st.set_page_config(
|
| 14 |
page_title="Bot",
|
| 15 |
page_icon="🤖",
|
| 16 |
initial_sidebar_state="collapsed",
|
| 17 |
layout="wide",
|
| 18 |
menu_items={
|
| 19 |
+
"Report a bug": "mailto:yk408@cam.ac.uk",
|
| 20 |
+
"About": "Bot",
|
| 21 |
+
},
|
| 22 |
)
|
| 23 |
+
|
| 24 |
st.markdown(
|
| 25 |
+
"""
|
| 26 |
+
<style>
|
| 27 |
+
div[role="radiogroup"] > :first-child {
|
| 28 |
+
display: none !important;
|
| 29 |
+
}
|
| 30 |
+
</style>
|
| 31 |
+
""",
|
| 32 |
+
unsafe_allow_html=True,
|
| 33 |
)
|
| 34 |
|
|
|
|
| 35 |
|
| 36 |
+
# -----------------------------
|
| 37 |
+
# Constants
|
| 38 |
+
# -----------------------------
|
| 39 |
+
MAX_MESSAGES_DEFAULT = 50
|
| 40 |
+
SUBMIT_AFTER_USER_TURNS = 5
|
| 41 |
+
|
| 42 |
+
GREETINGS = {
|
| 43 |
+
"hi", "hello", "hey", "heya", "hiya", "yo", "howdy", "sup",
|
| 44 |
+
"good morning", "good afternoon", "good evening", "good day",
|
| 45 |
+
"what's up", "whats up", "how do you do", "greetings",
|
| 46 |
+
"salutations", "hi there", "hello there", "hey there",
|
| 47 |
+
"how's it going", "hows it going", "how are you", "how are you doing"
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
CONDITIONS = {
|
| 51 |
+
"1": {
|
| 52 |
+
"label": "true control",
|
| 53 |
+
"base_url": None,
|
| 54 |
+
"model": None,
|
| 55 |
+
"api_secret": None,
|
| 56 |
+
},
|
| 57 |
+
"2": {
|
| 58 |
+
"label": "base",
|
| 59 |
+
"base_url": "https://openrouter.ai/api/v1",
|
| 60 |
+
"model": "meta-llama/llama-3.3-70b-instruct",
|
| 61 |
+
"api_secret": "OPENROUTER_API_KEY",
|
| 62 |
+
},
|
| 63 |
+
"3": {
|
| 64 |
+
"label": "bridging",
|
| 65 |
+
"base_url": "https://tinker.thinkingmachines.dev/services/tinker-prod/oai/api/v1",
|
| 66 |
+
"model": "tinker://808e4f02-e847-54ae-bc75-f14ee885ce5a:train:0/sampler_weights/final_sampler",
|
| 67 |
+
"api_secret": "TINKER_API_KEY",
|
| 68 |
+
},
|
| 69 |
+
"4": {
|
| 70 |
+
"label": "gpt",
|
| 71 |
+
"base_url": "https://openrouter.ai/api/v1",
|
| 72 |
+
"model": "openai/gpt-5.4",
|
| 73 |
+
"api_secret": "OPENROUTER_API_KEY",
|
| 74 |
+
},
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# -----------------------------
|
| 79 |
+
# Session state helpers
|
| 80 |
+
# -----------------------------
|
| 81 |
+
def init_session_state() -> None:
|
| 82 |
+
if "initialized" in st.session_state:
|
| 83 |
+
return
|
| 84 |
+
|
| 85 |
+
user_id = str(random.randint(100000, 999999))
|
| 86 |
+
|
| 87 |
+
st.session_state.initialized = True
|
| 88 |
+
st.session_state.inserted = 0
|
| 89 |
+
st.session_state.max_messages = MAX_MESSAGES_DEFAULT
|
| 90 |
st.session_state.messages = []
|
| 91 |
+
st.session_state.client = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
st.session_state.setup = False
|
| 93 |
+
st.session_state.convo_start_time = None
|
| 94 |
|
| 95 |
+
st.session_state.user_data = {
|
| 96 |
+
"BASE_URL": "",
|
| 97 |
+
"MODEL_PATH": "",
|
| 98 |
+
"url_id": True,
|
| 99 |
+
"user_id": user_id,
|
| 100 |
+
"start_time": datetime.now(),
|
| 101 |
+
"random_pid": None,
|
| 102 |
+
"condition": None,
|
| 103 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
|
| 106 |
+
def reset_conversation_state() -> None:
|
| 107 |
+
st.session_state.messages = []
|
| 108 |
+
st.session_state.max_messages = MAX_MESSAGES_DEFAULT
|
| 109 |
+
st.session_state.convo_start_time = None
|
| 110 |
+
st.session_state.setup = False
|
| 111 |
+
st.session_state.client = None
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
# -----------------------------
|
| 115 |
+
# Query param / condition helpers
|
| 116 |
+
# -----------------------------
|
| 117 |
+
def ensure_query_params() -> None:
|
| 118 |
+
if "p" not in st.query_params or st.query_params["p"] not in CONDITIONS:
|
| 119 |
+
st.query_params["p"] = st.radio(
|
| 120 |
+
"Select a condition for the conversation",
|
| 121 |
+
["", "1", "2", "3", "4"],
|
| 122 |
+
help="1 = true control, 2 = base, 3 = bridging, 4 = gpt",
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
if "id" not in st.query_params or not st.query_params["id"]:
|
| 126 |
+
st.session_state.user_data["url_id"] = False
|
| 127 |
+
st.query_params["id"] = st.session_state.user_data["user_id"]
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def setup_conversation() -> None:
|
| 131 |
+
condition = st.query_params.get("p", "")
|
| 132 |
+
if condition not in CONDITIONS:
|
| 133 |
+
return
|
| 134 |
+
|
| 135 |
+
config = CONDITIONS[condition]
|
| 136 |
+
|
| 137 |
+
st.session_state.user_data["random_pid"] = st.query_params["id"]
|
| 138 |
+
st.session_state.user_data["condition"] = condition
|
| 139 |
+
st.session_state.user_data["BASE_URL"] = config["base_url"] or ""
|
| 140 |
+
st.session_state.user_data["MODEL_PATH"] = config["model"] or ""
|
| 141 |
st.session_state.convo_start_time = datetime.now()
|
|
|
|
|
|
|
| 142 |
|
| 143 |
+
if condition != "1":
|
| 144 |
+
st.session_state.client = OpenAI(
|
| 145 |
+
base_url=config["base_url"],
|
| 146 |
+
api_key=st.secrets[config["api_secret"]],
|
| 147 |
+
)
|
| 148 |
+
else:
|
| 149 |
+
st.session_state.client = None
|
| 150 |
+
|
| 151 |
st.session_state.setup = True
|
| 152 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
|
| 154 |
+
# -----------------------------
|
| 155 |
+
# Utility helpers
|
| 156 |
+
# -----------------------------
|
| 157 |
+
def is_greeting_only(text: str) -> bool:
|
| 158 |
+
clean = text.strip().lower().strip("!,.?")
|
| 159 |
+
return clean in GREETINGS
|
| 160 |
|
|
|
|
| 161 |
|
| 162 |
+
def user_turn_count() -> int:
|
| 163 |
+
return sum(1 for m in st.session_state.messages if m["role"] == "user")
|
| 164 |
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
+
def can_submit() -> bool:
|
| 167 |
+
return (
|
| 168 |
+
user_turn_count() >= SUBMIT_AFTER_USER_TURNS
|
| 169 |
+
or len(st.session_state.messages) >= st.session_state.max_messages
|
| 170 |
+
)
|
|
|
|
|
|
|
| 171 |
|
| 172 |
+
|
| 173 |
+
def render_sidebar() -> None:
|
| 174 |
+
with st.sidebar:
|
| 175 |
+
st.markdown("# Let's talk!")
|
| 176 |
+
st.markdown("# **Step 1. Type in the chat box to start a conversation**")
|
| 177 |
+
|
| 178 |
+
st.success(
|
| 179 |
+
"Ask, request, or talk to the chatbot about something you consider "
|
| 180 |
+
"**politically polarizing** or something that people from different "
|
| 181 |
+
"US political parties might disagree about.",
|
| 182 |
+
icon="🎯",
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
st.markdown(
|
| 186 |
+
"🚫 Please avoid greetings and start the conversation with a question "
|
| 187 |
+
"or a statement about a politically polarizing topic. "
|
| 188 |
+
"**Note: the chatbot's knowledge only goes up to late August 2025.**"
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
st.markdown(
|
| 192 |
+
"# **Step 2. Use the *Submit Interaction* button to get your chatbot word**\n\n"
|
| 193 |
+
"⚠️ You must respond **at least 5 times** before you will see a *Submit Interaction* button. "
|
| 194 |
+
"You can continue before submitting, but **you must Submit Interaction and enter your chatbot word "
|
| 195 |
+
"to proceed with the survey**.\n"
|
| 196 |
+
"❗ Do not share any personal information (e.g., name or address). "
|
| 197 |
+
"Do not use AI tools to write your responses. "
|
| 198 |
+
"If you encounter any technical issues, please let us know. "
|
| 199 |
+
"It might sometimes take 30 seconds or more to generate a response, so please be patient."
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def render_messages() -> None:
|
| 204 |
+
for message in st.session_state.messages:
|
| 205 |
+
if message["role"] != "system":
|
| 206 |
+
with st.chat_message(message["role"]):
|
| 207 |
+
st.markdown(message["content"])
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def save_conversation() -> None:
|
| 211 |
+
payload = dict(st.session_state.user_data)
|
| 212 |
+
payload["messages"] = st.session_state.messages
|
| 213 |
+
payload["convo_start_time"] = st.session_state.convo_start_time
|
| 214 |
+
payload["convo_end_time"] = datetime.now()
|
| 215 |
+
payload["inserted"] = st.session_state.inserted + 1
|
| 216 |
+
|
| 217 |
+
with MongoClient(st.secrets["mongo"], server_api=ServerApi("1")) as mongo_client:
|
| 218 |
+
db = mongo_client.bridge
|
| 219 |
+
collection = db.app2
|
| 220 |
+
collection.insert_one(payload)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def get_assistant_response() -> str:
|
| 224 |
+
completion = st.session_state.client.chat.completions.create(
|
| 225 |
+
model=st.session_state.user_data["MODEL_PATH"],
|
| 226 |
+
messages=st.session_state.messages,
|
| 227 |
+
stream=False,
|
| 228 |
+
)
|
| 229 |
+
return completion.choices[0].message.content or ""
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
# -----------------------------
|
| 233 |
+
# App bootstrap
|
| 234 |
+
# -----------------------------
|
| 235 |
+
init_session_state()
|
| 236 |
+
ensure_query_params()
|
| 237 |
+
|
| 238 |
+
if not st.session_state.setup and st.query_params.get("p") in CONDITIONS:
|
| 239 |
+
setup_conversation()
|
| 240 |
+
|
| 241 |
+
render_sidebar()
|
| 242 |
+
|
| 243 |
+
if st.session_state.setup and not st.session_state.messages:
|
| 244 |
+
st.success(
|
| 245 |
+
"Ask, request, or talk to the chatbot about something you consider "
|
| 246 |
+
"**politically polarizing** or something that people from different US political parties might disagree about.",
|
| 247 |
+
icon="🎯",
|
| 248 |
)
|
| 249 |
|
| 250 |
+
render_messages()
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
# -----------------------------
|
| 254 |
+
# Main app flow
|
| 255 |
+
# -----------------------------
|
| 256 |
+
if len(st.session_state.messages) >= st.session_state.max_messages:
|
| 257 |
+
st.info("You have reached the limit of messages for this conversation. Please end and submit the conversation.")
|
| 258 |
+
|
| 259 |
+
elif st.session_state.inserted > 0:
|
| 260 |
st.markdown("## Copy your WORD!")
|
| 261 |
+
st.markdown("**Your chatbot WORD is:**")
|
| 262 |
+
st.markdown("## TOMATOES")
|
| 263 |
+
st.markdown("**Please copy the WORD and enter it into the survey field below.**")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
|
| 265 |
+
elif prompt := st.chat_input("Type to ask a question or respond..."):
|
| 266 |
+
if not st.session_state.setup:
|
| 267 |
+
st.error("Please select a condition first.")
|
| 268 |
+
st.stop()
|
| 269 |
+
|
| 270 |
+
if not st.session_state.messages and is_greeting_only(prompt):
|
| 271 |
+
st.error(
|
| 272 |
+
"Please avoid greetings and start the conversation with a question or a statement about a politically polarizing topic.",
|
| 273 |
+
icon="🚫",
|
| 274 |
+
)
|
| 275 |
+
st.stop()
|
| 276 |
|
| 277 |
st.session_state.messages.append({"role": "user", "content": prompt})
|
| 278 |
+
|
| 279 |
with st.chat_message("user"):
|
| 280 |
st.markdown(prompt)
|
| 281 |
|
| 282 |
+
condition = st.session_state.user_data["condition"]
|
| 283 |
+
|
| 284 |
+
if condition == "1":
|
| 285 |
+
response = (
|
| 286 |
+
"Thank you for your question. You have been randomly assigned to a condition "
|
| 287 |
+
"without a chatbot. **Please submit your interaction anyway** to get your chatbot word "
|
| 288 |
+
"and proceed with the survey. Do not worry, this will not influence your compensation."
|
| 289 |
+
)
|
| 290 |
+
st.session_state.messages.append({"role": "assistant", "content": response})
|
| 291 |
st.session_state.max_messages = len(st.session_state.messages)
|
| 292 |
st.rerun()
|
| 293 |
|
| 294 |
+
with st.chat_message("assistant"):
|
| 295 |
+
try:
|
| 296 |
+
with st.spinner("Typing..."):
|
| 297 |
+
response = get_assistant_response()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
|
| 299 |
+
st.markdown(response)
|
| 300 |
+
st.session_state.messages.append({"role": "assistant", "content": response})
|
|
|
|
| 301 |
|
| 302 |
+
except Exception as e:
|
| 303 |
+
error_message = (
|
| 304 |
+
"An error has occurred or you've reached the maximum conversation length. "
|
| 305 |
+
"Please submit the conversation."
|
| 306 |
+
)
|
| 307 |
+
st.session_state.messages.append({"role": "assistant", "content": error_message})
|
| 308 |
+
st.session_state.max_messages = len(st.session_state.messages)
|
| 309 |
+
st.error(f"Request failed: {e}")
|
| 310 |
|
|
|
|
|
|
|
|
|
|
| 311 |
|
| 312 |
+
# -----------------------------
|
| 313 |
+
# Submit button
|
| 314 |
+
# -----------------------------
|
| 315 |
+
if can_submit() and st.session_state.inserted == 0:
|
| 316 |
+
cols = st.columns((1, 1, 1))
|
| 317 |
+
with cols[2]:
|
| 318 |
+
if st.button("Submit Interaction", use_container_width=True):
|
| 319 |
+
try:
|
| 320 |
+
save_conversation()
|
| 321 |
+
st.session_state.inserted += 1
|
| 322 |
+
reset_conversation_state()
|
| 323 |
+
st.rerun()
|
| 324 |
+
except Exception as e:
|
| 325 |
+
st.error(f"Failed to save conversation: {e}")
|