ehejin commited on
Commit
cc5590d
·
1 Parent(s): 6d6e203

added user study

Browse files
Files changed (2) hide show
  1. requirements.txt +7 -3
  2. src/streamlit_app.py +965 -34
requirements.txt CHANGED
@@ -1,3 +1,7 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
 
1
+ streamlit>=1.32.0
2
+ openai>=1.0.0
3
+ huggingface_hub>=0.20.0
4
+ datasets>=2.18.0
5
+ filelock>=3.13.0
6
+ python-dotenv>=1.0.0
7
+ pandas>=2.0.0
src/streamlit_app.py CHANGED
@@ -1,40 +1,971 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
- import streamlit as st
 
 
5
 
 
 
 
 
 
 
6
  """
7
- # Welcome to Streamlit!
8
 
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- In the meantime, below is an example of what you can do with just a few lines of code:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  """
15
 
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Streamlit App: AI Product Willingness User Study
3
+ =================================================
4
+ Run locally:
5
+ streamlit run app.py -- --category groceries
6
+ streamlit run app.py -- --category groceries --debug
7
 
8
+ On HuggingFace Spaces, set these environment variables in Space Settings → Variables:
9
+ HF_TOKEN - HuggingFace token
10
+ TOGETHER_API_KEY - Together AI API key
11
+ DATASET_REPO_ID - HuggingFace dataset repo to upload results
12
+ CATEGORY - groceries | books | movies | health (default: groceries)
13
+ DEBUG_MODE - "true" to skip validation (optional)
14
  """
 
15
 
16
+ import asyncio
17
+ import concurrent.futures
18
+ import csv
19
+ import json
20
+ import os
21
+ import random
22
+ import re
23
+ import sys
24
+ import tempfile
25
+ import time
26
+ import uuid
27
+ from datetime import datetime
28
+ from pathlib import Path
29
+
30
+ import streamlit as st
31
+ from dotenv import load_dotenv
32
+ from filelock import FileLock
33
+ from huggingface_hub import HfApi
34
+ from openai import AsyncOpenAI
35
+
36
+ load_dotenv()
37
+
38
+ # ---------------------------------------------------------------------------
39
+ # CLI args (supported locally; ignored on HF Spaces — use env vars instead)
40
+ # ---------------------------------------------------------------------------
41
+ import argparse
42
+ parser = argparse.ArgumentParser(add_help=False)
43
+ parser.add_argument("--category", choices=["books", "groceries", "movies", "health"], default=None)
44
+ parser.add_argument("--debug", action="store_true", default=False)
45
+ cli_args, _ = parser.parse_known_args()
46
+
47
+ # ---------------------------------------------------------------------------
48
+ # Config (env vars take precedence, then CLI args, then defaults)
49
+ # ---------------------------------------------------------------------------
50
+ CATEGORY = os.getenv("CATEGORY") or cli_args.category or "groceries"
51
+ DEBUG_MODE = os.getenv("DEBUG_MODE", "").lower() == "true" or cli_args.debug
52
+ DATASET_REPO_ID = os.getenv("DATASET_REPO_ID", "your-username/product-study")
53
+ HF_TOKEN = os.getenv("HF_TOKEN")
54
+ TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
55
+ MODEL_NAME = "openai/gpt-oss-20b"
56
+
57
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
58
+ DATA_DIR = os.path.join(BASE_DIR, "data")
59
+ ANNOTATIONS_DIR = os.path.join(BASE_DIR, "annotations")
60
+ os.makedirs(DATA_DIR, exist_ok=True)
61
+ os.makedirs(ANNOTATIONS_DIR, exist_ok=True)
62
+
63
+ CATEGORY_TO_HF = {
64
+ "books": "ehejin/amazon_books",
65
+ "groceries": "ehejin/amazon_Grocery_and_Gourmet_Food",
66
+ "movies": "ehejin/amazon_Movies_and_TV",
67
+ "health": "ehejin/amazon_Health_and_Household",
68
+ }
69
+ CATEGORY_DISPLAY = {
70
+ "books": "Books",
71
+ "groceries": "Grocery Products",
72
+ "movies": "Movies & TV",
73
+ "health": "Health & Household Products",
74
+ }
75
+ FAMILIARITY_USED_LABEL = {
76
+ "books": "Read it before",
77
+ "movies": "Watched it before",
78
+ "groceries": "Used it before",
79
+ "health": "Used it before",
80
+ }
81
+
82
+ PRODUCTS_PER_USER = 5
83
+ MIN_TURNS = 3
84
+ MAX_TURNS = 10
85
+
86
+ DEBUG_DEMOGRAPHICS = {
87
+ "age": "30", "gender": "Female", "geographic_region": "West",
88
+ "education_level": "College graduate/some postgrad", "race": "White",
89
+ "us_citizen": "Yes", "marital_status": "Single",
90
+ "religion": "Agnostic", "religious_attendance": "Never",
91
+ "political_affiliation": "Independent", "income": "$50,000-$75,000",
92
+ "political_views": "Moderate", "household_size": "2",
93
+ "employment_status": "Full-time employment",
94
+ }
95
+
96
+ WILLINGNESS_LABELS = {
97
+ 1: "Definitely would not buy",
98
+ 2: "Probably would not buy",
99
+ 3: "Slightly unlikely to buy",
100
+ 4: "Neutral",
101
+ 5: "Slightly likely to buy",
102
+ 6: "Probably would buy",
103
+ 7: "Definitely would buy",
104
+ }
105
+ WILLINGNESS_CHOICES = [f"{v} ({k})" for k, v in WILLINGNESS_LABELS.items()]
106
+
107
+ # ---------------------------------------------------------------------------
108
+ # Dataset loading
109
+ # ---------------------------------------------------------------------------
110
+ LOCAL_DATA_PATH = os.path.join(DATA_DIR, f"{CATEGORY}.json")
111
+ ORDER_PATH = os.path.join(DATA_DIR, f"{CATEGORY}_order.json")
112
+ COUNTER_PATH = os.path.join(DATA_DIR, f"{CATEGORY}_counter.txt")
113
+ COUNTER_LOCK_PATH = os.path.join(DATA_DIR, f"{CATEGORY}_counter.lock")
114
+
115
+
116
+ @st.cache_resource
117
+ def download_and_cache_dataset():
118
+ if os.path.exists(LOCAL_DATA_PATH):
119
+ print(f"[DATA] Found cached dataset at {LOCAL_DATA_PATH}")
120
+ return
121
+ print(f"[DATA] Downloading {CATEGORY_TO_HF[CATEGORY]} from HuggingFace...")
122
+ try:
123
+ from datasets import load_dataset
124
+ import huggingface_hub
125
+ if HF_TOKEN:
126
+ huggingface_hub.login(token=HF_TOKEN)
127
+ ds = load_dataset(CATEGORY_TO_HF[CATEGORY], split="train")
128
+ items = []
129
+ for row in ds:
130
+ meta = row.get("metadata", {})
131
+ def to_list(val):
132
+ if isinstance(val, list): return val
133
+ if isinstance(val, str): return [val] if val else []
134
+ return []
135
+ item = {
136
+ "id": str(uuid.uuid4()),
137
+ "title": meta.get("title", "") if isinstance(meta, dict) else "",
138
+ "description": to_list(meta.get("description", []) if isinstance(meta, dict) else []),
139
+ "features": to_list(meta.get("features", []) if isinstance(meta, dict) else []),
140
+ "price": meta.get("price", "N/A") if isinstance(meta, dict) else "N/A",
141
+ "category": CATEGORY,
142
+ }
143
+ items.append(item)
144
+ with open(LOCAL_DATA_PATH, "w") as f:
145
+ json.dump(items, f, indent=2)
146
+ print(f"[DATA] Cached {len(items)} items to {LOCAL_DATA_PATH}")
147
+ except Exception as e:
148
+ print(f"[DATA] ERROR downloading dataset: {e}")
149
+ raise
150
+
151
+
152
+ @st.cache_resource
153
+ def load_local_dataset():
154
+ with open(LOCAL_DATA_PATH, "r") as f:
155
+ return json.load(f)
156
+
157
+
158
+ @st.cache_resource
159
+ def ensure_shuffled_order(n_items):
160
+ if os.path.exists(ORDER_PATH):
161
+ with open(ORDER_PATH, "r") as f:
162
+ return json.load(f)
163
+ indices = list(range(n_items))
164
+ random.shuffle(indices)
165
+ with open(ORDER_PATH, "w") as f:
166
+ json.dump(indices, f)
167
+ return indices
168
+
169
+
170
+ def assign_products(items, order, n=PRODUCTS_PER_USER):
171
+ lock = FileLock(COUNTER_LOCK_PATH)
172
+ with lock:
173
+ if os.path.exists(COUNTER_PATH):
174
+ with open(COUNTER_PATH, "r") as f:
175
+ counter = int(f.read().strip() or "0")
176
+ else:
177
+ counter = 0
178
+ total = len(order)
179
+ assigned_indices = [order[(counter + i) % total] for i in range(n)]
180
+ new_counter = (counter + n) % total
181
+ with open(COUNTER_PATH, "w") as f:
182
+ f.write(str(new_counter))
183
+ return [items[i] for i in assigned_indices]
184
+
185
+
186
+ # ---------------------------------------------------------------------------
187
+ # AI client
188
+ # ---------------------------------------------------------------------------
189
+ @st.cache_resource
190
+ def get_model_client():
191
+ return AsyncOpenAI(
192
+ base_url="https://api.together.xyz/v1",
193
+ api_key=TOGETHER_API_KEY,
194
+ timeout=60.0,
195
+ )
196
+
197
+
198
+ def call_model(messages: list) -> str:
199
+ async def _call():
200
+ try:
201
+ client = get_model_client()
202
+ response = await client.chat.completions.create(
203
+ model=MODEL_NAME,
204
+ messages=messages,
205
+ max_tokens=1000,
206
+ temperature=0.7,
207
+ top_p=0.9,
208
+ )
209
+ content = response.choices[0].message.content.strip()
210
+ content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()
211
+ return content
212
+ except Exception as e:
213
+ print(f"[MODEL] Error: {e}")
214
+ return f"[Model error: {e}]"
215
+
216
+ with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
217
+ future = pool.submit(asyncio.run, _call())
218
+ return future.result()
219
+
220
 
221
+ # ---------------------------------------------------------------------------
222
+ # HuggingFace upload
223
+ # ---------------------------------------------------------------------------
224
+ @st.cache_resource
225
+ def get_hf_api():
226
+ api = HfApi(token=HF_TOKEN) if HF_TOKEN else HfApi()
227
+ if HF_TOKEN:
228
+ try:
229
+ api.repo_info(repo_id=DATASET_REPO_ID, repo_type="dataset")
230
+ print(f"[HF] Repo {DATASET_REPO_ID} exists.")
231
+ except Exception as e:
232
+ if "404" in str(e) or "not found" in str(e).lower():
233
+ api.create_repo(repo_id=DATASET_REPO_ID, repo_type="dataset", private=True)
234
+ print(f"[HF] Created repo {DATASET_REPO_ID}.")
235
+ else:
236
+ print(f"[HF] WARNING: {e}")
237
+ return api
238
+
239
+
240
+ def save_and_upload(state: dict):
241
+ hf_api = get_hf_api()
242
+ worker_id = state.get("worker_id") or state.get("user_id", "anonymous")
243
+ submission_id = state.get("submission_id", str(uuid.uuid4()))
244
+ safe_worker = "".join(c if c.isalnum() else "_" for c in str(worker_id))
245
+ filename = f"{submission_id}_{CATEGORY}.json"
246
+ folder = os.path.join(ANNOTATIONS_DIR, safe_worker)
247
+ os.makedirs(folder, exist_ok=True)
248
+ file_path = os.path.join(folder, filename)
249
+ with open(file_path, "w") as f:
250
+ json.dump(state, f, indent=2)
251
+ print(f"[SAVE] Wrote {file_path}")
252
+ if HF_TOKEN:
253
+ try:
254
+ hf_api.upload_file(
255
+ path_or_fileobj=file_path,
256
+ path_in_repo=f"{safe_worker}/{filename}",
257
+ repo_id=DATASET_REPO_ID,
258
+ repo_type="dataset",
259
+ )
260
+ print("[HF] Uploaded JSON.")
261
+ except Exception as e:
262
+ print(f"[HF] JSON upload error: {e}")
263
+ upload_csv_rows(state, hf_api, safe_worker, submission_id)
264
+
265
+
266
+ def upload_csv_rows(state: dict, hf_api, safe_worker: str, submission_id: str):
267
+ demographics = state.get("demographics", {})
268
+ products = state.get("products", [])
269
+ header = [
270
+ "submission_id", "worker_id", "submission_time", "duration_seconds", "category",
271
+ "age", "gender", "geographic_region", "education_level", "race",
272
+ "us_citizen", "marital_status", "religion", "religious_attendance",
273
+ "political_affiliation", "income", "political_views", "household_size", "employment_status",
274
+ "product_index", "product_id", "title", "price", "familiarity",
275
+ "pre_willingness", "pre_willingness_label", "post_willingness", "post_willingness_label",
276
+ "willingness_delta", "num_turns", "conversation_json", "standout_moment", "thinking_change",
277
+ ]
278
+ rows = []
279
+ for i, prod in enumerate(products):
280
+ conv = prod.get("conversation", {})
281
+ refl = prod.get("reflection", {})
282
+ pre = prod.get("pre_willingness", "")
283
+ post = prod.get("post_willingness", "")
284
+ delta = (post - pre) if isinstance(pre, int) and isinstance(post, int) else ""
285
+ row = [
286
+ submission_id, state.get("worker_id", ""),
287
+ state.get("meta", {}).get("submission_time", ""),
288
+ state.get("meta", {}).get("duration_seconds", ""),
289
+ CATEGORY,
290
+ demographics.get("age", ""), demographics.get("gender", ""),
291
+ demographics.get("geographic_region", ""), demographics.get("education_level", ""),
292
+ demographics.get("race", ""), demographics.get("us_citizen", ""),
293
+ demographics.get("marital_status", ""), demographics.get("religion", ""),
294
+ demographics.get("religious_attendance", ""), demographics.get("political_affiliation", ""),
295
+ demographics.get("income", ""), demographics.get("political_views", ""),
296
+ demographics.get("household_size", ""), demographics.get("employment_status", ""),
297
+ i + 1, prod.get("id", ""), prod.get("title", ""), prod.get("price", ""),
298
+ prod.get("familiarity", ""),
299
+ pre, WILLINGNESS_LABELS.get(pre, "") if isinstance(pre, int) else "",
300
+ post, WILLINGNESS_LABELS.get(post, "") if isinstance(post, int) else "",
301
+ delta, conv.get("num_turns", 0), json.dumps(conv.get("turns", [])),
302
+ refl.get("standout_moment", ""), refl.get("thinking_change", ""),
303
+ ]
304
+ rows.append(row)
305
+
306
+ timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
307
+ unique_id = uuid.uuid4().hex[:8]
308
+ csv_filename = f"csv_submissions/{timestamp_str}_{safe_worker}_{unique_id}.csv"
309
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False, newline="", encoding="utf-8") as tmp:
310
+ tmp_path = tmp.name
311
+ writer = csv.writer(tmp)
312
+ writer.writerow(header)
313
+ writer.writerows(rows)
314
+ if HF_TOKEN:
315
+ try:
316
+ hf_api.upload_file(
317
+ path_or_fileobj=tmp_path,
318
+ path_in_repo=csv_filename,
319
+ repo_id=DATASET_REPO_ID,
320
+ repo_type="dataset",
321
+ )
322
+ print("[HF] Uploaded CSV rows.")
323
+ except Exception as e:
324
+ print(f"[HF] CSV upload error: {e}")
325
+ os.unlink(tmp_path)
326
+
327
+
328
+ # ---------------------------------------------------------------------------
329
+ # Prompt builders
330
+ # ---------------------------------------------------------------------------
331
+ def build_sales_system_prompt(product: dict) -> str:
332
+ title = product.get("title", "this product")
333
+ price = product.get("price", "N/A")
334
+ description = product.get("description", [])
335
+ features = product.get("features", [])
336
+ desc_text = " ".join(description) if description else "No description available."
337
+ feat_text = " ".join(features) if features else ""
338
+ price_str = f"${price}" if price and price != "N/A" and not str(price).startswith("$") else price
339
+ return f"""You are a warm, knowledgeable sales agent for an Amazon product. Your goal is to convince the user to buy this product.
340
+
341
+ Product name: {title}
342
+ Product description and features: {desc_text} {feat_text}
343
+ Price: {price_str}
344
+
345
+ Rules:
346
+ - Be warm, confident, and conversational — like a helpful friend who knows the product well, not a pushy salesperson
347
+ - Be persuasive and proactive: do not just answer questions; actively make the case for buying.
348
+ - Tailor your argument to what the user cares about or seems hesitant about.
349
+ - Use only the information provided in the title, description, features, and price.
350
+ - Do not fabricate reviews, ratings, statistics, comparisons, discounts, or product details.
351
+ - If the user raises a concern, respond directly and turn the conversation back to why the product is worth buying.
352
+ - Keep the reply concise and natural.
353
+
354
+ First message:
355
+ - Open with the product's strongest selling point.
356
+ - Explain why it is worth buying.
357
+ - End with a question that helps uncover the user's preferences, objections, or hesitation.
358
+
359
+ Follow-up rules:
360
+ - Acknowledge what the user said.
361
+ - Address their main objection directly with a concrete benefit grounded in the product info.
362
+ - Continue steering the conversation toward purchase.
363
+ - Usually end with a natural question that keeps the user engaged.
364
  """
365
 
366
+
367
+ def build_opening_user_message(product: dict) -> str:
368
+ return f'Tell me about this product and why I should buy it: "{product.get("title", "this product")}"'
369
+
370
+
371
+ def parse_willingness(choice_str: str) -> int:
372
+ try:
373
+ return int(choice_str.split("(")[1].rstrip(")"))
374
+ except Exception:
375
+ return 4
376
+
377
+
378
+ def get_familiarity_choices():
379
+ used_label = FAMILIARITY_USED_LABEL.get(CATEGORY, "Used it before")
380
+ return [
381
+ "Never heard of it",
382
+ "Heard of it, but not used/purchased",
383
+ used_label,
384
+ "Purchased it before",
385
+ ]
386
+
387
+
388
+ # ---------------------------------------------------------------------------
389
+ # State initialisation
390
+ # ---------------------------------------------------------------------------
391
+ def init_state():
392
+ download_and_cache_dataset()
393
+ items = load_local_dataset()
394
+ order = ensure_shuffled_order(len(items))
395
+ assigned = assign_products(items, order, PRODUCTS_PER_USER)
396
+
397
+ # Read MTurk query params if available
398
+ try:
399
+ params = st.query_params
400
+ except Exception:
401
+ params = {}
402
+
403
+ return {
404
+ "submission_id": str(uuid.uuid4()),
405
+ "user_id": str(uuid.uuid4()),
406
+ "worker_id": params.get("workerId", ""),
407
+ "assignment_id": params.get("assignmentId", ""),
408
+ "hit_id": params.get("hitId", ""),
409
+ "turk_submit_to": params.get("turkSubmitTo", ""),
410
+ "start_time": time.time(),
411
+ "category": CATEGORY,
412
+ "demographics": {},
413
+ "products": [
414
+ {
415
+ "id": p.get("id", str(uuid.uuid4())),
416
+ "title": p.get("title", ""),
417
+ "description": p.get("description", []),
418
+ "features": p.get("features", []),
419
+ "price": p.get("price", "N/A"),
420
+ "familiarity": None,
421
+ "pre_willingness": None,
422
+ "post_willingness": None,
423
+ "willingness_delta": None,
424
+ "conversation": {
425
+ "system_prompt": "",
426
+ "opening_user_message": "",
427
+ "turns": [],
428
+ "num_turns": 0,
429
+ },
430
+ "reflection": {},
431
+ }
432
+ for p in assigned
433
+ ],
434
+ "current_product_index": 0,
435
+ "screen": "welcome", # screens: welcome | demographics | product_intro | chat | post_will | reflection | done
436
+ "meta": {},
437
+ }
438
+
439
+
440
+ # ---------------------------------------------------------------------------
441
+ # CSS
442
+ # ---------------------------------------------------------------------------
443
+ def inject_css():
444
+ st.markdown("""
445
+ <style>
446
+ /* Hide Streamlit chrome */
447
+ #MainMenu, footer, header { visibility: hidden; }
448
+ .block-container { max-width: 820px; padding-top: 2rem; }
449
+
450
+ /* Product card */
451
+ .product-card {
452
+ border: 2px solid #2563eb;
453
+ border-radius: 10px;
454
+ padding: 1rem 1.25rem;
455
+ background: #f0f6ff;
456
+ margin-bottom: 0.75rem;
457
+ }
458
+ .pc-header {
459
+ display: flex;
460
+ justify-content: space-between;
461
+ align-items: flex-start;
462
+ margin-bottom: 0.6rem;
463
+ gap: 1rem;
464
+ }
465
+ .pc-title {
466
+ font-size: 1.05rem;
467
+ font-weight: 700;
468
+ color: #1a1a2e;
469
+ line-height: 1.35;
470
+ flex: 1;
471
+ }
472
+ .pc-price {
473
+ font-size: 1.2rem;
474
+ font-weight: 800;
475
+ color: #16a34a;
476
+ white-space: nowrap;
477
+ }
478
+ .pc-section { margin-top: 0.5rem; }
479
+ .pc-section-title {
480
+ font-weight: 600;
481
+ font-size: 0.85rem;
482
+ color: #475569;
483
+ text-transform: uppercase;
484
+ letter-spacing: 0.04em;
485
+ margin-bottom: 0.3rem;
486
+ }
487
+ .pc-list {
488
+ margin: 0;
489
+ padding-left: 1.2rem;
490
+ font-size: 0.92rem;
491
+ color: #334155;
492
+ line-height: 1.5;
493
+ }
494
+ .pc-list li { margin-bottom: 0.25rem; }
495
+
496
+ /* Progress bar */
497
+ .progress-wrap {
498
+ background: #e2e8f0;
499
+ border-radius: 99px;
500
+ height: 8px;
501
+ margin-bottom: 0.25rem;
502
+ overflow: hidden;
503
+ }
504
+ .progress-fill {
505
+ background: #2563eb;
506
+ height: 100%;
507
+ border-radius: 99px;
508
+ }
509
+ .progress-label {
510
+ font-size: 0.82rem;
511
+ color: #64748b;
512
+ text-align: right;
513
+ margin-bottom: 1rem;
514
+ }
515
+
516
+ /* Chat bubbles */
517
+ .chat-wrap { max-height: 420px; overflow-y: auto; margin-bottom: 1rem; }
518
+ .bubble { padding: 0.65rem 0.9rem; border-radius: 12px; margin-bottom: 0.5rem; font-size: 0.93rem; line-height: 1.5; }
519
+ .bubble-ai { background: #eff6ff; border: 1px solid #93c5fd; margin-right: 10%; }
520
+ .bubble-user { background: #f0fdf4; border: 1px solid #86efac; margin-left: 10%; text-align: right; }
521
+ .bubble-label { font-size: 0.75rem; color: #94a3b8; margin-bottom: 0.2rem; }
522
+
523
+ /* Compact product banner above chat */
524
+ .chat-product-banner {
525
+ border: 1.5px solid #93c5fd;
526
+ border-radius: 8px;
527
+ padding: 0.6rem 1rem;
528
+ background: #eff6ff;
529
+ margin-bottom: 0.75rem;
530
+ font-size: 0.88rem;
531
+ color: #1d4ed8;
532
+ font-weight: 600;
533
+ cursor: pointer;
534
+ }
535
+ </style>
536
+ """, unsafe_allow_html=True)
537
+
538
+
539
+ # ---------------------------------------------------------------------------
540
+ # UI helpers
541
+ # ---------------------------------------------------------------------------
542
+ def render_product_card_html(product: dict, compact: bool = False) -> str:
543
+ title = product.get("title", "Unknown Product")
544
+ price = product.get("price", "N/A")
545
+ description = product.get("description", [])
546
+ features = product.get("features", [])
547
+ price_str = f"${price}" if price and price != "N/A" and not str(price).startswith("$") else price
548
+
549
+ desc_html = ""
550
+ if description:
551
+ items_html = "".join(f"<li>{d}</li>" for d in description if d)
552
+ desc_html = f'<div class="pc-section"><div class="pc-section-title">📋 Description</div><ul class="pc-list">{items_html}</ul></div>'
553
+
554
+ feat_html = ""
555
+ if features:
556
+ items_html = "".join(f"<li>{feat}</li>" for feat in features if feat)
557
+ feat_html = f'<div class="pc-section"><div class="pc-section-title">✨ Features</div><ul class="pc-list">{items_html}</ul></div>'
558
+
559
+ max_h = "max-height:240px;overflow-y:auto;" if compact else ""
560
+ return f"""
561
+ <div class="product-card" style="{max_h}">
562
+ <div class="pc-header">
563
+ <div class="pc-title">{title}</div>
564
+ <div class="pc-price">{price_str}</div>
565
+ </div>
566
+ {desc_html}
567
+ {feat_html}
568
+ </div>"""
569
+
570
+
571
+ def render_progress(current: int, total: int = PRODUCTS_PER_USER):
572
+ pct = int((current / total) * 100)
573
+ st.markdown(f"""
574
+ <div class="progress-wrap"><div class="progress-fill" style="width:{pct}%"></div></div>
575
+ <div class="progress-label">Product {current} of {total}</div>
576
+ """, unsafe_allow_html=True)
577
+
578
+
579
+ def render_chat_history(turns: list):
580
+ html = '<div class="chat-wrap">'
581
+ for turn in turns:
582
+ role = turn.get("role", "")
583
+ content = turn.get("content", "")
584
+ if role == "assistant":
585
+ html += f'<div class="bubble-label">🤖 AI Sales Agent</div><div class="bubble bubble-ai">{content}</div>'
586
+ elif role == "user":
587
+ html += f'<div class="bubble-label" style="text-align:right">You</div><div class="bubble bubble-user">{content}</div>'
588
+ html += "</div>"
589
+ st.markdown(html, unsafe_allow_html=True)
590
+
591
+
592
+ # ---------------------------------------------------------------------------
593
+ # Screen renderers
594
+ # ---------------------------------------------------------------------------
595
+ def screen_welcome(s):
596
+ st.markdown(f"# 🛒 Product Evaluation Study")
597
+ st.markdown(
598
+ f"Welcome! In this study you will evaluate **{PRODUCTS_PER_USER} {CATEGORY_DISPLAY[CATEGORY]}** products.\n\n"
599
+ "For each product you will:\n"
600
+ "1. Rate how familiar you are with the product\n"
601
+ "2. Rate how willing you are to buy it\n"
602
+ "3. Chat with an AI about the product (**at least 3 exchanges**)\n"
603
+ "4. Rate your willingness to buy it again\n"
604
+ "5. Answer two brief reflection questions\n\n"
605
+ "After all 5 products, you're done! The study takes about **20–30 minutes**. "
606
+ "Thank you for participating!"
607
+ )
608
+ if st.button("Begin →", type="primary", use_container_width=True):
609
+ if DEBUG_MODE:
610
+ s["demographics"] = DEBUG_DEMOGRAPHICS.copy()
611
+ s["screen"] = "product_intro"
612
+ else:
613
+ s["screen"] = "demographics"
614
+ st.rerun()
615
+
616
+
617
+ def screen_demographics(s):
618
+ st.markdown("## Demographics — About You")
619
+ st.markdown("All fields are required before you can proceed.")
620
+
621
+ age = st.text_input("Age (years)", placeholder="e.g. 34")
622
+ gender = st.selectbox("Gender", ["", "Female", "Male"])
623
+ geographic_region = st.selectbox("Geographic region", ["", "West", "South", "Midwest", "Northeast", "Pacific"])
624
+ education_level = st.selectbox("Highest education level", [
625
+ "", "Less than high school", "High school graduate",
626
+ "Some college, no degree", "Associate's degree",
627
+ "College graduate/some postgrad", "Postgraduate",
628
+ ])
629
+ race = st.selectbox("Race / ethnicity", ["", "Asian", "Hispanic", "White", "Black", "Other"])
630
+ us_citizen = st.selectbox("Are you a U.S. citizen?", ["", "Yes", "No"])
631
+ marital_status = st.selectbox("Marital status", [
632
+ "", "Never been married", "Married", "Living with a partner",
633
+ "Divorced", "Separated", "Widowed",
634
+ ])
635
+ religion = st.selectbox("Religion", [
636
+ "", "Protestant", "Roman Catholic", "Mormon", "Orthodox", "Jewish",
637
+ "Muslim", "Buddhist", "Atheist", "Agnostic", "Nothing in particular", "Other",
638
+ ])
639
+ religious_attendance = st.selectbox("How often do you attend religious services?", [
640
+ "", "Never", "Seldom", "A few times a year", "Once or twice a month",
641
+ "Once a week", "More than once a week",
642
+ ])
643
+ political_affiliation = st.selectbox("Political affiliation", [
644
+ "", "Democrat", "Republican", "Independent", "Something else",
645
+ ])
646
+ income = st.selectbox("Household income", [
647
+ "", "Less than $30,000", "$30,000-$50,000", "$50,000-$75,000",
648
+ "$75,000-$100,000", "$100,000 or more",
649
+ ])
650
+ political_views = st.selectbox("Political views", [
651
+ "", "Very liberal", "Liberal", "Moderate", "Conservative", "Very conservative",
652
+ ])
653
+ household_size = st.selectbox("Household size", ["", "1", "2", "3", "4", "More than 4"])
654
+ employment_status = st.selectbox("Employment status", [
655
+ "", "Full-time employment", "Part-time employment", "Self-employed",
656
+ "Unemployed", "Retired", "Home-maker", "Student",
657
+ ])
658
+
659
+ if st.button("Next →", type="primary", use_container_width=True):
660
+ fields = [age, gender, geographic_region, education_level, race, us_citizen,
661
+ marital_status, religion, religious_attendance, political_affiliation,
662
+ income, political_views, household_size, employment_status]
663
+ if not all([f and (f.strip() if isinstance(f, str) else f) for f in fields]):
664
+ st.error("⚠️ Please complete all fields.")
665
+ return
666
+ if not age.strip().isdigit() or not (1 <= int(age.strip()) <= 120):
667
+ st.error("⚠️ Please enter a valid age.")
668
+ return
669
+ s["demographics"] = {
670
+ "age": age.strip(), "gender": gender, "geographic_region": geographic_region,
671
+ "education_level": education_level, "race": race, "us_citizen": us_citizen,
672
+ "marital_status": marital_status, "religion": religion,
673
+ "religious_attendance": religious_attendance, "political_affiliation": political_affiliation,
674
+ "income": income, "political_views": political_views,
675
+ "household_size": household_size, "employment_status": employment_status,
676
+ }
677
+ s["screen"] = "product_intro"
678
+ st.rerun()
679
+
680
+
681
+ def screen_product_intro(s):
682
+ idx = s["current_product_index"]
683
+ product = s["products"][idx]
684
+ render_progress(idx + 1)
685
+ st.markdown("## Product Evaluation")
686
+ st.markdown("Please read the product information carefully, then answer the two questions below.")
687
+ st.markdown(render_product_card_html(product), unsafe_allow_html=True)
688
+
689
+ familiarity_val = st.radio(
690
+ "How familiar are you with this product?",
691
+ get_familiarity_choices(),
692
+ index=None,
693
+ key=f"familiarity_{idx}",
694
+ )
695
+ pre_will_val = st.radio(
696
+ "How willing would you be to buy this product?",
697
+ WILLINGNESS_CHOICES,
698
+ index=None,
699
+ key=f"pre_will_{idx}",
700
+ )
701
+
702
+ if st.button("Start Chat →", type="primary", use_container_width=True):
703
+ if not DEBUG_MODE:
704
+ if not familiarity_val:
705
+ st.error("⚠️ Please rate your familiarity.")
706
+ return
707
+ if not pre_will_val:
708
+ st.error("⚠️ Please rate your willingness to buy.")
709
+ return
710
+ familiarity_val = familiarity_val or get_familiarity_choices()[0]
711
+ pre_will_val = pre_will_val or WILLINGNESS_CHOICES[3]
712
+
713
+ pre_val = parse_willingness(pre_will_val)
714
+ s["products"][idx]["familiarity"] = familiarity_val
715
+ s["products"][idx]["pre_willingness"] = pre_val
716
+ s["products"][idx]["pre_willingness_label"] = WILLINGNESS_LABELS[pre_val]
717
+
718
+ # Get opening AI message
719
+ system_prompt = build_sales_system_prompt(product)
720
+ opening_user_msg = build_opening_user_message(product)
721
+ messages = [
722
+ {"role": "system", "content": system_prompt},
723
+ {"role": "user", "content": opening_user_msg},
724
+ ]
725
+ with st.spinner("Starting conversation…"):
726
+ ai_reply = call_model(messages)
727
+
728
+ s["products"][idx]["conversation"]["system_prompt"] = system_prompt
729
+ s["products"][idx]["conversation"]["opening_user_message"] = opening_user_msg
730
+ s["products"][idx]["conversation"]["turns"] = [
731
+ {"turn_index": 0, "role": "assistant", "content": ai_reply,
732
+ "timestamp": time.time(), "model": MODEL_NAME}
733
+ ]
734
+ s["products"][idx]["conversation"]["num_turns"] = 0
735
+ s["screen"] = "chat"
736
+ st.rerun()
737
+
738
+
739
+ def screen_chat(s):
740
+ idx = s["current_product_index"]
741
+ product = s["products"][idx]
742
+ conv = s["products"][idx]["conversation"]
743
+
744
+ render_progress(idx + 1)
745
+ st.markdown("## Chat with the AI")
746
+
747
+ # Compact product banner
748
+ title = product.get("title", "Product")
749
+ price = product.get("price", "N/A")
750
+ price_str = f"${price}" if price and price != "N/A" and not str(price).startswith("$") else price
751
+ with st.expander(f"📦 {title} — {price_str} (click to expand product details)"):
752
+ st.markdown(render_product_card_html(product, compact=True), unsafe_allow_html=True)
753
+
754
+ num_turns = conv["num_turns"]
755
+ st.markdown(
756
+ f"The AI is trying to convince you to buy this product. "
757
+ f"Ask questions, push back, or explore your interest. "
758
+ f"You need at least **{MIN_TURNS} exchanges** before you can move on."
759
+ )
760
+
761
+ # Chat history (only user/assistant turns, not the opening system exchange)
762
+ display_turns = [t for t in conv["turns"] if t["role"] in ("user", "assistant")]
763
+ render_chat_history(display_turns)
764
+
765
+ # Turn counter
766
+ if num_turns >= MAX_TURNS:
767
+ st.info(f"Maximum turns ({MAX_TURNS}) reached. Please proceed.")
768
+ else:
769
+ st.caption(f"Turns: {num_turns} / minimum {MIN_TURNS}")
770
+
771
+ # Input
772
+ if num_turns < MAX_TURNS:
773
+ user_msg = st.text_area("Your response:", placeholder="Type your response here…", height=100, key=f"chat_input_{idx}_{num_turns}")
774
+ col1, col2 = st.columns([3, 1])
775
+ with col2:
776
+ send_clicked = st.button("Send", type="primary", use_container_width=True)
777
+ if send_clicked:
778
+ if not user_msg or not user_msg.strip():
779
+ st.error("⚠️ Please type a message.")
780
+ return
781
+ if len(user_msg.strip().split()) < 5 and not DEBUG_MODE:
782
+ st.error(f"⚠️ Please write at least 5 words ({len(user_msg.strip().split())} so far).")
783
+ return
784
+ user_msg = user_msg.strip()
785
+ messages = [{"role": "system", "content": conv["system_prompt"]},
786
+ {"role": "user", "content": conv["opening_user_message"]}]
787
+ for turn in conv["turns"]:
788
+ messages.append({"role": turn["role"], "content": turn["content"]})
789
+ messages.append({"role": "user", "content": user_msg})
790
+ with st.spinner("AI is responding…"):
791
+ ai_reply = call_model(messages)
792
+ conv["turns"].append({"turn_index": len(conv["turns"]), "role": "user",
793
+ "content": user_msg, "timestamp": time.time()})
794
+ conv["turns"].append({"turn_index": len(conv["turns"]), "role": "assistant",
795
+ "content": ai_reply, "timestamp": time.time(), "model": MODEL_NAME})
796
+ conv["num_turns"] = num_turns + 1
797
+ s["products"][idx]["conversation"] = conv
798
+ st.rerun()
799
+
800
+ # Done button
801
+ can_finish = num_turns >= MIN_TURNS or num_turns >= MAX_TURNS or DEBUG_MODE
802
+ if can_finish:
803
+ if st.button("I'm done chatting →", use_container_width=True):
804
+ s["screen"] = "post_will"
805
+ st.rerun()
806
+ else:
807
+ st.button("I'm done chatting →", disabled=True, use_container_width=True,
808
+ help=f"Complete at least {MIN_TURNS} exchanges first.")
809
+
810
+
811
+ def screen_post_willingness(s):
812
+ idx = s["current_product_index"]
813
+ product = s["products"][idx]
814
+ render_progress(idx + 1)
815
+ st.markdown("## Your View Now")
816
+ st.markdown("Now that you've chatted with the AI, rate your willingness to buy again.")
817
+ st.markdown(render_product_card_html(product), unsafe_allow_html=True)
818
+
819
+ post_will_val = st.radio(
820
+ "How willing would you be to buy this product now?",
821
+ WILLINGNESS_CHOICES,
822
+ index=None,
823
+ key=f"post_will_{idx}",
824
+ )
825
+
826
+ if st.button("Next →", type="primary", use_container_width=True):
827
+ if not post_will_val and not DEBUG_MODE:
828
+ st.error("⚠️ Please rate your willingness to buy.")
829
+ return
830
+ post_will_val = post_will_val or WILLINGNESS_CHOICES[3]
831
+ post_val = parse_willingness(post_will_val)
832
+ pre_val = s["products"][idx].get("pre_willingness", 4)
833
+ delta = post_val - pre_val
834
+ s["products"][idx]["post_willingness"] = post_val
835
+ s["products"][idx]["post_willingness_label"] = WILLINGNESS_LABELS[post_val]
836
+ s["products"][idx]["willingness_delta"] = delta
837
+ s["screen"] = "reflection"
838
+ st.rerun()
839
+
840
+
841
+ def screen_reflection(s):
842
+ idx = s["current_product_index"]
843
+ render_progress(idx + 1)
844
+ st.markdown("## Reflection")
845
+
846
+ standout = st.text_area(
847
+ "What did the AI say that stood out to you most?",
848
+ placeholder="Describe a specific argument, question, or moment from the conversation…",
849
+ height=120,
850
+ key=f"standout_{idx}",
851
+ )
852
+ thinking_change = st.text_area(
853
+ "How did your thinking about this product change (or not change) during the chat? Why?",
854
+ placeholder="Be as specific as you can…",
855
+ height=120,
856
+ key=f"thinking_{idx}",
857
+ )
858
+
859
+ next_label = "Next Product →" if idx + 1 < PRODUCTS_PER_USER else "Submit Study →"
860
+ if st.button(next_label, type="primary", use_container_width=True):
861
+ if not DEBUG_MODE:
862
+ if not standout or not standout.strip():
863
+ st.error("⚠️ Please answer the first reflection question.")
864
+ return
865
+ if len(standout.strip().split()) < 10:
866
+ st.error(f"⚠️ Please write at least 10 words for the first question ({len(standout.strip().split())} so far).")
867
+ return
868
+ if not thinking_change or not thinking_change.strip():
869
+ st.error("⚠️ Please answer the second reflection question.")
870
+ return
871
+ if len(thinking_change.strip().split()) < 10:
872
+ st.error(f"⚠️ Please write at least 10 words for the second question ({len(thinking_change.strip().split())} so far).")
873
+ return
874
+
875
+ standout = (standout or "").strip() or "[debug placeholder]"
876
+ thinking_change = (thinking_change or "").strip() or "[debug placeholder]"
877
+ s["products"][idx]["reflection"] = {
878
+ "standout_moment": standout,
879
+ "thinking_change": thinking_change,
880
+ }
881
+
882
+ next_idx = idx + 1
883
+ s["current_product_index"] = next_idx
884
+
885
+ if next_idx >= PRODUCTS_PER_USER:
886
+ end_time = time.time()
887
+ s["meta"] = {
888
+ "submission_time": end_time,
889
+ "duration_seconds": round(end_time - s.get("start_time", end_time), 1),
890
+ "model": MODEL_NAME,
891
+ "category": CATEGORY,
892
+ }
893
+ with st.spinner("Saving your responses…"):
894
+ save_and_upload(s)
895
+ s["screen"] = "done"
896
+ else:
897
+ s["screen"] = "product_intro"
898
+ st.rerun()
899
+
900
+
901
+ def screen_done(s):
902
+ st.markdown("## ✅ Study Complete!")
903
+ st.markdown("**Thank you for completing the study!**")
904
+ st.markdown(f"Here's a summary of how your willingness changed across the {PRODUCTS_PER_USER} products:")
905
+
906
+ rows = []
907
+ for i, p in enumerate(s["products"]):
908
+ pre = p.get("pre_willingness", "?")
909
+ post = p.get("post_willingness", "?")
910
+ delta = p.get("willingness_delta", 0)
911
+ arrow = "➡️" if delta == 0 else ("⬆️" if delta > 0 else "⬇️")
912
+ rows.append({
913
+ "#": i + 1,
914
+ "Product": p.get("title", "")[:60] + ("…" if len(p.get("title", "")) > 60 else ""),
915
+ "Before": WILLINGNESS_LABELS.get(pre, str(pre)),
916
+ "After": WILLINGNESS_LABELS.get(post, str(post)),
917
+ "Change": f"{arrow} {delta:+d}" if isinstance(delta, int) else "–",
918
+ })
919
+ import pandas as pd
920
+ st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)
921
+
922
+ # MTurk submit button
923
+ assignment_id = s.get("assignment_id", "")
924
+ turk_submit_to = s.get("turk_submit_to", "")
925
+ if assignment_id and turk_submit_to:
926
+ submit_url = f"{turk_submit_to}/mturk/externalSubmit"
927
+ submission_id = s.get("submission_id", "")
928
+ st.markdown(f"""
929
+ <form id="mturk-submit-form" method="POST" action="{submit_url}">
930
+ <input type="hidden" name="assignmentId" value="{assignment_id}" />
931
+ <input type="hidden" name="submission_id" value="{submission_id}" />
932
+ <button type="submit" style="
933
+ background:#2563eb; color:white; border:none; padding:12px 28px;
934
+ font-size:1rem; border-radius:6px; cursor:pointer; margin-top:12px;">
935
+ ✅ Submit to MTurk
936
+ </button>
937
+ </form>
938
+ """, unsafe_allow_html=True)
939
+
940
+
941
+ # ---------------------------------------------------------------------------
942
+ # Main
943
+ # ---------------------------------------------------------------------------
944
+ def main():
945
+ st.set_page_config(page_title="Product Study", page_icon="🛒", layout="centered")
946
+ inject_css()
947
+
948
+ if "study_state" not in st.session_state:
949
+ st.session_state.study_state = init_state()
950
+
951
+ s = st.session_state.study_state
952
+ screen = s.get("screen", "welcome")
953
+
954
+ if screen == "welcome":
955
+ screen_welcome(s)
956
+ elif screen == "demographics":
957
+ screen_demographics(s)
958
+ elif screen == "product_intro":
959
+ screen_product_intro(s)
960
+ elif screen == "chat":
961
+ screen_chat(s)
962
+ elif screen == "post_will":
963
+ screen_post_willingness(s)
964
+ elif screen == "reflection":
965
+ screen_reflection(s)
966
+ elif screen == "done":
967
+ screen_done(s)
968
+
969
+
970
+ if __name__ == "__main__":
971
+ main()