ehejin committed on
Commit
7551a44
·
1 Parent(s): ae5c4c5

temp: reset data cache

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +221 -155
src/streamlit_app.py CHANGED
@@ -2,12 +2,13 @@
2
  Streamlit App: AI Product Willingness User Study
3
  =================================================
4
  Run locally:
5
- streamlit run app.py -- --category groceries
6
- streamlit run app.py -- --category groceries --debug
7
 
8
  On HuggingFace Spaces, set these environment variables in Space Settings → Variables:
9
  HF_TOKEN - HuggingFace token
10
- TOGETHER_API_KEY - Together AI API key
 
11
  DATASET_REPO_ID - HuggingFace dataset repo to upload results
12
  CATEGORY - groceries | books | movies | health (default: groceries)
13
  DEBUG_MODE - "true" to skip validation (optional)
@@ -51,14 +52,23 @@ CATEGORY = os.getenv("CATEGORY") or cli_args.category or "groceries"
51
  DEBUG_MODE = os.getenv("DEBUG_MODE", "").lower() == "true" or cli_args.debug
52
  DATASET_REPO_ID = os.getenv("DATASET_REPO_ID", "your-username/product-study")
53
  HF_TOKEN = os.getenv("HF_TOKEN")
54
- TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
55
- MODEL_NAME = "openai/gpt-oss-20b"
 
 
56
 
57
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
58
  DATA_DIR = os.path.join(BASE_DIR, "data")
59
  ANNOTATIONS_DIR = os.path.join(BASE_DIR, "annotations")
 
 
60
  os.makedirs(DATA_DIR, exist_ok=True)
61
  os.makedirs(ANNOTATIONS_DIR, exist_ok=True)
 
 
 
 
 
62
 
63
  CATEGORY_TO_HF = {
64
  "books": "ehejin/amazon_books",
@@ -82,6 +92,10 @@ FAMILIARITY_USED_LABEL = {
82
  PRODUCTS_PER_USER = 5
83
  MIN_TURNS = 3
84
  MAX_TURNS = 10
 
 
 
 
85
 
86
  DEBUG_DEMOGRAPHICS = {
87
  "age": "30", "gender": "Female", "geographic_region": "West",
@@ -105,33 +119,40 @@ WILLINGNESS_LABELS = {
105
  WILLINGNESS_CHOICES = [f"{v} ({k})" for k, v in WILLINGNESS_LABELS.items()]
106
 
107
  # ---------------------------------------------------------------------------
108
- # Dataset loading
109
  # ---------------------------------------------------------------------------
110
- LOCAL_DATA_PATH = os.path.join(DATA_DIR, f"{CATEGORY}.json")
111
- ORDER_PATH = os.path.join(DATA_DIR, f"{CATEGORY}_order.json")
112
  COUNTER_PATH = os.path.join(DATA_DIR, f"{CATEGORY}_counter.txt")
113
  COUNTER_LOCK_PATH = os.path.join(DATA_DIR, f"{CATEGORY}_counter.lock")
 
 
 
114
 
115
 
116
  @st.cache_resource
117
  def download_and_cache_dataset():
 
118
  if os.path.exists(LOCAL_DATA_PATH):
119
  print(f"[DATA] Found cached dataset at {LOCAL_DATA_PATH}")
120
  return
121
- print(f"[DATA] Downloading {CATEGORY_TO_HF[CATEGORY]} from HuggingFace...")
122
  try:
123
  from datasets import load_dataset
124
  import huggingface_hub
125
  if HF_TOKEN:
126
  huggingface_hub.login(token=HF_TOKEN)
127
- ds = load_dataset(CATEGORY_TO_HF[CATEGORY], split="train")
128
- items = []
 
 
 
 
 
 
 
129
  for row in ds:
130
  meta = row.get("metadata", {})
131
- def to_list(val):
132
- if isinstance(val, list): return val
133
- if isinstance(val, str): return [val] if val else []
134
- return []
135
  item = {
136
  "id": str(uuid.uuid4()),
137
  "title": meta.get("title", "") if isinstance(meta, dict) else "",
@@ -140,47 +161,119 @@ def download_and_cache_dataset():
140
  "price": meta.get("price", "N/A") if isinstance(meta, dict) else "N/A",
141
  "category": CATEGORY,
142
  }
143
- items.append(item)
 
 
 
 
 
144
  with open(LOCAL_DATA_PATH, "w") as f:
145
- json.dump(items, f, indent=2)
146
- print(f"[DATA] Cached {len(items)} items to {LOCAL_DATA_PATH}")
 
 
 
147
  except Exception as e:
148
  print(f"[DATA] ERROR downloading dataset: {e}")
149
  raise
150
 
151
 
152
  @st.cache_resource
153
- def load_local_dataset():
154
  with open(LOCAL_DATA_PATH, "r") as f:
155
  return json.load(f)
156
 
157
 
158
  @st.cache_resource
159
- def ensure_shuffled_order(n_items):
160
- if os.path.exists(ORDER_PATH):
161
- with open(ORDER_PATH, "r") as f:
162
- return json.load(f)
163
- indices = list(range(n_items))
164
- random.shuffle(indices)
165
- with open(ORDER_PATH, "w") as f:
166
- json.dump(indices, f)
167
- return indices
168
-
169
-
170
- def assign_products(items, order, n=PRODUCTS_PER_USER):
 
 
 
 
171
  lock = FileLock(COUNTER_LOCK_PATH)
172
  with lock:
 
 
 
 
 
 
 
 
 
 
 
173
  if os.path.exists(COUNTER_PATH):
174
  with open(COUNTER_PATH, "r") as f:
175
  counter = int(f.read().strip() or "0")
176
- else:
177
- counter = 0
178
- total = len(order)
179
- assigned_indices = [order[(counter + i) % total] for i in range(n)]
180
- new_counter = (counter + n) % total
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  with open(COUNTER_PATH, "w") as f:
182
- f.write(str(new_counter))
183
- return [items[i] for i in assigned_indices]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
 
186
  # ---------------------------------------------------------------------------
@@ -189,8 +282,8 @@ def assign_products(items, order, n=PRODUCTS_PER_USER):
189
  @st.cache_resource
190
  def get_model_client():
191
  return AsyncOpenAI(
192
- base_url="https://api.together.xyz/v1",
193
- api_key=TOGETHER_API_KEY,
194
  timeout=60.0,
195
  )
196
 
@@ -274,6 +367,7 @@ def upload_csv_rows(state: dict, hf_api, safe_worker: str, submission_id: str):
274
  "product_index", "product_id", "title", "price", "familiarity",
275
  "pre_willingness", "pre_willingness_label", "post_willingness", "post_willingness_label",
276
  "willingness_delta", "num_turns", "conversation_json", "standout_moment", "thinking_change",
 
277
  ]
278
  rows = []
279
  for i, prod in enumerate(products):
@@ -300,6 +394,7 @@ def upload_csv_rows(state: dict, hf_api, safe_worker: str, submission_id: str):
300
  post, WILLINGNESS_LABELS.get(post, "") if isinstance(post, int) else "",
301
  delta, conv.get("num_turns", 0), json.dumps(conv.get("turns", [])),
302
  refl.get("standout_moment", ""), refl.get("thinking_change", ""),
 
303
  ]
304
  rows.append(row)
305
 
@@ -344,21 +439,20 @@ Price: {price_str}
344
 
345
  You need to convince the user to buy it.
346
 
347
- First message rules:
348
- - In ONE paragraph: briefly highlight the product's best quality, explain why it's worth buying, and hit them with the strongest benefit
349
  - End with an engaging question that draws out their interest or hesitation
350
 
351
- Follow-up message rules:
352
- - In ONE paragraph: acknowledge what they said, address any concerns directly with a concrete benefit or reassurance, end with a question
353
- - Use their words against hesitation: if they say it's expensive, talk value; if they doubt quality, cite a feature
354
- - Vary your tactics: sometimes appeal to emotion (convenience, joy), sometimes to reason (value, quality)
355
  - Use "imagine if..." scenarios to make benefits concrete
356
 
357
- General style:
358
- - Be warm, confident, and conversational
359
- like a helpful friend who knows the product well, not a pushy salesperson
360
- - End your messages with an engaging question
361
- - Never fabricate statistics, details, or reviews you don't have
362
  - Never make up a price different from the one given
363
  """
364
 
@@ -384,16 +478,44 @@ def get_familiarity_choices():
384
  ]
385
 
386
 
 
 
 
 
 
 
 
 
 
387
  # ---------------------------------------------------------------------------
388
  # State initialisation
389
  # ---------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390
  def init_state():
391
  download_and_cache_dataset()
392
- items = load_local_dataset()
393
- order = ensure_shuffled_order(len(items))
394
- assigned = assign_products(items, order, PRODUCTS_PER_USER)
395
 
396
- # Read MTurk query params if available
397
  try:
398
  params = st.query_params
399
  except Exception:
@@ -409,29 +531,9 @@ def init_state():
409
  "start_time": time.time(),
410
  "category": CATEGORY,
411
  "demographics": {},
412
- "products": [
413
- {
414
- "id": p.get("id", str(uuid.uuid4())),
415
- "title": p.get("title", ""),
416
- "description": p.get("description", []),
417
- "features": p.get("features", []),
418
- "price": p.get("price", "N/A"),
419
- "familiarity": None,
420
- "pre_willingness": None,
421
- "post_willingness": None,
422
- "willingness_delta": None,
423
- "conversation": {
424
- "system_prompt": "",
425
- "opening_user_message": "",
426
- "turns": [],
427
- "num_turns": 0,
428
- },
429
- "reflection": {},
430
- }
431
- for p in assigned
432
- ],
433
  "current_product_index": 0,
434
- "screen": "welcome", # screens: welcome | demographics | product_intro | chat | post_will | reflection | done
435
  "meta": {},
436
  }
437
 
@@ -442,11 +544,9 @@ def init_state():
442
  def inject_css():
443
  st.markdown("""
444
  <style>
445
- /* Hide Streamlit chrome */
446
  #MainMenu, footer, header { visibility: hidden; }
447
  .block-container { max-width: 820px; padding-top: 2rem; }
448
 
449
- /* Product card */
450
  .product-card {
451
  border: 2px solid #2563eb;
452
  border-radius: 10px;
@@ -461,76 +561,26 @@ def inject_css():
461
  margin-bottom: 0.6rem;
462
  gap: 1rem;
463
  }
464
- .pc-title {
465
- font-size: 1.05rem;
466
- font-weight: 700;
467
- color: #1a1a2e;
468
- line-height: 1.35;
469
- flex: 1;
470
- }
471
- .pc-price {
472
- font-size: 1.2rem;
473
- font-weight: 800;
474
- color: #16a34a;
475
- white-space: nowrap;
476
- }
477
  .pc-section { margin-top: 0.5rem; }
478
  .pc-section-title {
479
- font-weight: 600;
480
- font-size: 0.85rem;
481
- color: #475569;
482
- text-transform: uppercase;
483
- letter-spacing: 0.04em;
484
- margin-bottom: 0.3rem;
485
- }
486
- .pc-list {
487
- margin: 0;
488
- padding-left: 1.2rem;
489
- font-size: 0.92rem;
490
- color: #334155;
491
- line-height: 1.5;
492
  }
 
 
493
  .pc-list li { margin-bottom: 0.25rem; }
494
 
495
- /* Progress bar */
496
- .progress-wrap {
497
- background: #e2e8f0;
498
- border-radius: 99px;
499
- height: 8px;
500
- margin-bottom: 0.25rem;
501
- overflow: hidden;
502
- }
503
- .progress-fill {
504
- background: #2563eb;
505
- height: 100%;
506
- border-radius: 99px;
507
- }
508
- .progress-label {
509
- font-size: 0.82rem;
510
- color: #64748b;
511
- text-align: right;
512
- margin-bottom: 1rem;
513
- }
514
 
515
- /* Chat bubbles */
516
  .chat-wrap { max-height: 420px; overflow-y: auto; margin-bottom: 1rem; }
517
  .bubble { padding: 0.65rem 0.9rem; border-radius: 12px; margin-bottom: 0.5rem; font-size: 0.93rem; line-height: 1.5; }
518
  .bubble-ai { background: #eff6ff; border: 1px solid #93c5fd; margin-right: 10%; }
519
  .bubble-user { background: #f0fdf4; border: 1px solid #86efac; margin-left: 10%; text-align: right; }
520
  .bubble-label { font-size: 0.75rem; color: #94a3b8; margin-bottom: 0.2rem; }
521
-
522
- /* Compact product banner above chat */
523
- .chat-product-banner {
524
- border: 1.5px solid #93c5fd;
525
- border-radius: 8px;
526
- padding: 0.6rem 1rem;
527
- background: #eff6ff;
528
- margin-bottom: 0.75rem;
529
- font-size: 0.88rem;
530
- color: #1d4ed8;
531
- font-weight: 600;
532
- cursor: pointer;
533
- }
534
  </style>
535
  """, unsafe_allow_html=True)
536
 
@@ -545,11 +595,13 @@ def render_product_card_html(product: dict, compact: bool = False) -> str:
545
  features = product.get("features", [])
546
  price_str = f"${price}" if price and price != "N/A" and not str(price).startswith("$") else price
547
 
 
548
  desc_html = ""
549
  if description:
550
- items_html = "".join(f"<li>{d}</li>" for d in description if d)
551
- desc_html = f'<div class="pc-section"><div class="pc-section-title">📋 Description</div><ul class="pc-list">{items_html}</ul></div>'
552
 
 
553
  feat_html = ""
554
  if features:
555
  items_html = "".join(f"<li>{feat}</li>" for feat in features if feat)
@@ -592,7 +644,7 @@ def render_chat_history(turns: list):
592
  # Screen renderers
593
  # ---------------------------------------------------------------------------
594
  def screen_welcome(s):
595
- st.markdown(f"# 🛒 Product Evaluation Study")
596
  st.markdown(
597
  f"Welcome! In this study you will evaluate **{PRODUCTS_PER_USER} {CATEGORY_DISPLAY[CATEGORY]}** products.\n\n"
598
  "For each product you will:\n"
@@ -689,13 +741,13 @@ def screen_product_intro(s):
689
  "How familiar are you with this product?",
690
  get_familiarity_choices(),
691
  index=None,
692
- key=f"familiarity_{idx}",
693
  )
694
  pre_will_val = st.radio(
695
  "How willing would you be to buy this product?",
696
  WILLINGNESS_CHOICES,
697
  index=None,
698
- key=f"pre_will_{idx}",
699
  )
700
 
701
  if st.button("Start Chat →", type="primary", use_container_width=True):
@@ -706,15 +758,28 @@ def screen_product_intro(s):
706
  if not pre_will_val:
707
  st.error("⚠️ Please rate your willingness to buy.")
708
  return
 
709
  familiarity_val = familiarity_val or get_familiarity_choices()[0]
710
  pre_will_val = pre_will_val or WILLINGNESS_CHOICES[3]
711
 
 
 
 
 
 
 
 
 
 
 
 
 
 
712
  pre_val = parse_willingness(pre_will_val)
713
  s["products"][idx]["familiarity"] = familiarity_val
714
  s["products"][idx]["pre_willingness"] = pre_val
715
  s["products"][idx]["pre_willingness_label"] = WILLINGNESS_LABELS[pre_val]
716
 
717
- # Get opening AI message
718
  system_prompt = build_sales_system_prompt(product)
719
  opening_user_msg = build_opening_user_message(product)
720
  messages = [
@@ -743,7 +808,6 @@ def screen_chat(s):
743
  render_progress(idx + 1)
744
  st.markdown("## Chat with the AI")
745
 
746
- # Compact product banner
747
  title = product.get("title", "Product")
748
  price = product.get("price", "N/A")
749
  price_str = f"${price}" if price and price != "N/A" and not str(price).startswith("$") else price
@@ -752,24 +816,26 @@ def screen_chat(s):
752
 
753
  num_turns = conv["num_turns"]
754
  st.markdown(
755
- f"The AI is trying to convince you to buy this product. "
756
  f"Ask questions, push back, or explore your interest. "
757
  f"You need at least **{MIN_TURNS} exchanges** before you can move on."
758
  )
759
 
760
- # Chat history (only user/assistant turns, not the opening system exchange)
761
  display_turns = [t for t in conv["turns"] if t["role"] in ("user", "assistant")]
762
  render_chat_history(display_turns)
763
 
764
- # Turn counter
765
  if num_turns >= MAX_TURNS:
766
  st.info(f"Maximum turns ({MAX_TURNS}) reached. Please proceed.")
767
  else:
768
  st.caption(f"Turns: {num_turns} / minimum {MIN_TURNS}")
769
 
770
- # Input
771
  if num_turns < MAX_TURNS:
772
- user_msg = st.text_area("Your response:", placeholder="Type your response here…", height=100, key=f"chat_input_{idx}_{num_turns}")
 
 
 
 
 
773
  col1, col2 = st.columns([3, 1])
774
  with col2:
775
  send_clicked = st.button("Send", type="primary", use_container_width=True)
@@ -781,8 +847,10 @@ def screen_chat(s):
781
  st.error(f"⚠️ Please write at least 5 words ({len(user_msg.strip().split())} so far).")
782
  return
783
  user_msg = user_msg.strip()
784
- messages = [{"role": "system", "content": conv["system_prompt"]},
785
- {"role": "user", "content": conv["opening_user_message"]}]
 
 
786
  for turn in conv["turns"]:
787
  messages.append({"role": turn["role"], "content": turn["content"]})
788
  messages.append({"role": "user", "content": user_msg})
@@ -796,7 +864,6 @@ def screen_chat(s):
796
  s["products"][idx]["conversation"] = conv
797
  st.rerun()
798
 
799
- # Done button
800
  can_finish = num_turns >= MIN_TURNS or num_turns >= MAX_TURNS or DEBUG_MODE
801
  if can_finish:
802
  if st.button("I'm done chatting →", use_container_width=True):
@@ -819,7 +886,7 @@ def screen_post_willingness(s):
819
  "How willing would you be to buy this product now?",
820
  WILLINGNESS_CHOICES,
821
  index=None,
822
- key=f"post_will_{idx}",
823
  )
824
 
825
  if st.button("Next →", type="primary", use_container_width=True):
@@ -918,7 +985,6 @@ def screen_done(s):
918
  import pandas as pd
919
  st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)
920
 
921
- # MTurk submit button
922
  assignment_id = s.get("assignment_id", "")
923
  turk_submit_to = s.get("turk_submit_to", "")
924
  if assignment_id and turk_submit_to:
 
2
  Streamlit App: AI Product Willingness User Study
3
  =================================================
4
  Run locally:
5
+ streamlit run src/streamlit_app.py -- --category groceries
6
+ streamlit run src/streamlit_app.py -- --category groceries --debug
7
 
8
  On HuggingFace Spaces, set these environment variables in Space Settings → Variables:
9
  HF_TOKEN - HuggingFace token
10
+ TINKER_API_KEY - Tinker AI API key
11
+ TINKER_MODEL_PATH - Tinker sampler checkpoint path
12
  DATASET_REPO_ID - HuggingFace dataset repo to upload results
13
  CATEGORY - groceries | books | movies | health (default: groceries)
14
  DEBUG_MODE - "true" to skip validation (optional)
 
52
  DEBUG_MODE = os.getenv("DEBUG_MODE", "").lower() == "true" or cli_args.debug
53
  DATASET_REPO_ID = os.getenv("DATASET_REPO_ID", "your-username/product-study")
54
  HF_TOKEN = os.getenv("HF_TOKEN")
55
+
56
+ TINKER_API_KEY = os.getenv("TINKER_API_KEY")
57
+ TINKER_BASE_URL = "https://tinker.thinkingmachines.dev/services/tinker-prod/oai/api/v1"
58
+ MODEL_NAME = os.getenv("TINKER_MODEL_PATH", "tinker://YOUR_RUN_ID:train:0/sampler_weights/000080")
59
 
60
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
61
  DATA_DIR = os.path.join(BASE_DIR, "data")
62
  ANNOTATIONS_DIR = os.path.join(BASE_DIR, "annotations")
63
import shutil


@st.cache_resource
def _reset_data_cache_once():
    """Delete DATA_DIR a single time per server process (temporary deploy aid).

    Streamlit re-executes this script from the top on every user interaction.
    A bare module-level ``shutil.rmtree`` would therefore erase everything
    stored under DATA_DIR (assignment counter, return queue, cached dataset)
    on each rerun. Wrapping the call in ``st.cache_resource`` limits it to the
    first run after the process starts.
    """
    # TEMPORARY: remove this helper and its call after one deploy.
    shutil.rmtree(DATA_DIR, ignore_errors=True)
    return True


_reset_data_cache_once()
os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(ANNOTATIONS_DIR, exist_ok=True)
71
+
72
 
73
  CATEGORY_TO_HF = {
74
  "books": "ehejin/amazon_books",
 
92
  PRODUCTS_PER_USER = 5
93
  MIN_TURNS = 3
94
  MAX_TURNS = 10
95
+ TEST_SUBSET_SIZE = 100 # only use first 100 items from test split
96
+
97
+ # Familiarity values that trigger a product swap
98
+ SWAP_FAMILIARITY = {"Purchased it before"}
99
 
100
  DEBUG_DEMOGRAPHICS = {
101
  "age": "30", "gender": "Female", "geographic_region": "West",
 
119
  WILLINGNESS_CHOICES = [f"{v} ({k})" for k, v in WILLINGNESS_LABELS.items()]
120
 
121
  # ---------------------------------------------------------------------------
122
+ # Dataset loading — test split, first 100 items
123
  # ---------------------------------------------------------------------------
124
+ LOCAL_DATA_PATH = os.path.join(DATA_DIR, f"{CATEGORY}_test100.json")
125
+ # Counter tracks which of the 100 products have been assigned globally
126
  COUNTER_PATH = os.path.join(DATA_DIR, f"{CATEGORY}_counter.txt")
127
  COUNTER_LOCK_PATH = os.path.join(DATA_DIR, f"{CATEGORY}_counter.lock")
128
+ RETURN_QUEUE_PATH = os.path.join(DATA_DIR, f"{CATEGORY}_return_queue.json")
129
+ # Overflow pool for swap replacements (products beyond the 100, or re-used ones)
130
+ OVERFLOW_PATH = os.path.join(DATA_DIR, f"{CATEGORY}_overflow.json")
131
 
132
 
133
  @st.cache_resource
134
  def download_and_cache_dataset():
135
+ """Download test split (first 100 items) from HuggingFace and cache locally."""
136
  if os.path.exists(LOCAL_DATA_PATH):
137
  print(f"[DATA] Found cached dataset at {LOCAL_DATA_PATH}")
138
  return
139
+ print(f"[DATA] Downloading {CATEGORY_TO_HF[CATEGORY]} (test split) from HuggingFace...")
140
  try:
141
  from datasets import load_dataset
142
  import huggingface_hub
143
  if HF_TOKEN:
144
  huggingface_hub.login(token=HF_TOKEN)
145
+
146
+ ds = load_dataset(CATEGORY_TO_HF[CATEGORY], split="test")
147
+
148
+ def to_list(val):
149
+ if isinstance(val, list): return val
150
+ if isinstance(val, str): return [val] if val else []
151
+ return []
152
+
153
+ all_items = []
154
  for row in ds:
155
  meta = row.get("metadata", {})
 
 
 
 
156
  item = {
157
  "id": str(uuid.uuid4()),
158
  "title": meta.get("title", "") if isinstance(meta, dict) else "",
 
161
  "price": meta.get("price", "N/A") if isinstance(meta, dict) else "N/A",
162
  "category": CATEGORY,
163
  }
164
+ all_items.append(item)
165
+
166
+ # First 100 are the primary pool; the rest are the overflow/swap pool
167
+ primary = all_items[:TEST_SUBSET_SIZE]
168
+ overflow = all_items[TEST_SUBSET_SIZE:]
169
+
170
  with open(LOCAL_DATA_PATH, "w") as f:
171
+ json.dump(primary, f, indent=2)
172
+ with open(OVERFLOW_PATH, "w") as f:
173
+ json.dump(overflow, f, indent=2)
174
+
175
+ print(f"[DATA] Cached {len(primary)} primary + {len(overflow)} overflow items.")
176
  except Exception as e:
177
  print(f"[DATA] ERROR downloading dataset: {e}")
178
  raise
179
 
180
 
181
  @st.cache_resource
182
def load_primary_dataset():
    """Read the primary product pool from LOCAL_DATA_PATH and return the parsed JSON."""
    with open(LOCAL_DATA_PATH, "r") as handle:
        primary_items = json.load(handle)
    return primary_items
185
 
186
 
187
  @st.cache_resource
188
def load_overflow_dataset():
    """Return the parsed overflow pool from OVERFLOW_PATH, or [] when the file does not exist."""
    if os.path.exists(OVERFLOW_PATH):
        with open(OVERFLOW_PATH, "r") as handle:
            return json.load(handle)
    return []
193
+
194
+
195
def assign_products(n=PRODUCTS_PER_USER):
    """
    Assign the next n products to a participant.

    While holding the counter file lock, products are taken in this order:
      1. the return queue (products that were swapped away and are waiting
         for reassignment),
      2. the primary pool, at the position stored in the counter file,
      3. the overflow pool, once the primary pool is used up.

    The counter keeps increasing past the end of the primary pool so that
    overflow products are handed out one after another. Previously every
    fallback slot received overflow[0], which gave one participant the same
    product several times. Pool entries whose id was already assigned earlier
    in the same call are skipped for the same reason.

    Args:
        n: number of products requested.

    Returns:
        A list of product dicts. It is shorter than n when the return queue,
        the primary pool and the overflow pool are all exhausted.
    """
    items = load_primary_dataset()
    overflow = load_overflow_dataset()
    total = len(items)
    pool_size = total + len(overflow)
    lock = FileLock(COUNTER_LOCK_PATH)
    with lock:
        # Load return queue; an unreadable file is treated as an empty queue.
        return_queue = []
        if os.path.exists(RETURN_QUEUE_PATH):
            with open(RETURN_QUEUE_PATH, "r") as f:
                try:
                    return_queue = json.load(f)
                except Exception:
                    return_queue = []

        # Load counter (position of the next never-assigned pool entry).
        counter = 0
        if os.path.exists(COUNTER_PATH):
            with open(COUNTER_PATH, "r") as f:
                counter = int(f.read().strip() or "0")

        assigned = []
        seen_ids = set()
        for _ in range(n):
            if return_queue:
                # Prioritise returned products so they still get reviewed.
                product = return_queue.pop(0)
            else:
                product = None
                while product is None and counter < pool_size:
                    if counter < total:
                        candidate = items[counter]
                    else:
                        # Primary pool exhausted: continue into overflow.
                        candidate = overflow[counter - total]
                    counter += 1
                    candidate_id = candidate.get("id")
                    if candidate_id is None or candidate_id not in seen_ids:
                        product = candidate
                if product is None:
                    # Nothing left anywhere; return what we have.
                    break
            assigned.append(product)
            product_id = product.get("id")
            if product_id is not None:
                seen_ids.add(product_id)

        # Persist state
        with open(RETURN_QUEUE_PATH, "w") as f:
            json.dump(return_queue, f)
        with open(COUNTER_PATH, "w") as f:
            f.write(str(counter))

    return assigned
243
+
244
+
245
def return_product_to_queue(product: dict):
    """
    Put a rejected/swapped product back into the return queue file so that
    assign_products hands it to a later participant.

    The queue file is read and rewritten while holding the counter file lock.
    A product whose id is already queued is not added a second time.

    Args:
        product: the product dict to queue; its "id" is used for the
            duplicate check.
    """
    lock = FileLock(COUNTER_LOCK_PATH)
    with lock:
        queue = []
        if os.path.exists(RETURN_QUEUE_PATH):
            loaded = []
            with open(RETURN_QUEUE_PATH, "r") as f:
                try:
                    loaded = json.load(f)
                except (ValueError, OSError) as e:
                    # Best effort: start from an empty queue, but say so
                    # instead of hiding the problem.
                    print(f"[DATA] WARNING: could not read return queue, starting empty: {e}")
            # Anything other than a list cannot be appended to; ignore it.
            if isinstance(loaded, list):
                queue = loaded

        product_id = product.get("id")
        already_queued = any(
            isinstance(p, dict) and p.get("id") == product_id for p in queue
        )
        # Avoid duplicates
        if not already_queued:
            queue.append(product)
        with open(RETURN_QUEUE_PATH, "w") as f:
            json.dump(queue, f)
264
+
265
+
266
def get_swap_product(exclude_ids: set) -> dict | None:
    """
    Get a replacement product for a participant.

    The replacement is the next pool entry (primary pool followed by
    overflow) at or after the position stored in the counter file whose id is
    not in exclude_ids. The counter is advanced past it under the counter
    file lock, so the same product is not handed out again by a later
    counter-based pick.

    The earlier implementation ignored the counter and returned the first
    pool entry not currently held. That entry could already belong to another
    participant, or be one this participant had just swapped away.

    If no entry remains at or after the counter, any pool entry not in
    exclude_ids is returned as a last resort.

    Args:
        exclude_ids: ids the participant already holds.

    Returns:
        A product dict, or None when every pool entry is excluded.
    """
    pool = load_primary_dataset() + load_overflow_dataset()
    lock = FileLock(COUNTER_LOCK_PATH)
    with lock:
        counter = 0
        if os.path.exists(COUNTER_PATH):
            with open(COUNTER_PATH, "r") as f:
                counter = int(f.read().strip() or "0")

        for position in range(counter, len(pool)):
            candidate = pool[position]
            if candidate["id"] not in exclude_ids:
                # Mark everything up to and including this entry as used.
                with open(COUNTER_PATH, "w") as f:
                    f.write(str(position + 1))
                return candidate

    # No unassigned entry left: reuse any product the participant does not hold.
    for candidate in pool:
        if candidate["id"] not in exclude_ids:
            return candidate
    return None  # extremely unlikely
277
 
278
 
279
  # ---------------------------------------------------------------------------
 
282
  @st.cache_resource
283
  def get_model_client():
284
  return AsyncOpenAI(
285
+ base_url=TINKER_BASE_URL,
286
+ api_key=TINKER_API_KEY,
287
  timeout=60.0,
288
  )
289
 
 
367
  "product_index", "product_id", "title", "price", "familiarity",
368
  "pre_willingness", "pre_willingness_label", "post_willingness", "post_willingness_label",
369
  "willingness_delta", "num_turns", "conversation_json", "standout_moment", "thinking_change",
370
+ "was_swapped",
371
  ]
372
  rows = []
373
  for i, prod in enumerate(products):
 
394
  post, WILLINGNESS_LABELS.get(post, "") if isinstance(post, int) else "",
395
  delta, conv.get("num_turns", 0), json.dumps(conv.get("turns", [])),
396
  refl.get("standout_moment", ""), refl.get("thinking_change", ""),
397
+ prod.get("was_swapped", False),
398
  ]
399
  rows.append(row)
400
 
 
439
 
440
  You need to convince the user to buy it.
441
 
442
+ First message rules:
443
+ - In ONE paragraph: briefly highlight the product's best quality, explain why it's worth buying, and hit them with the strongest benefit
444
  - End with an engaging question that draws out their interest or hesitation
445
 
446
+ Follow-up message rules:
447
+ - In ONE paragraph: acknowledge what they said, address any concerns directly with a concrete benefit or reassurance, end with a question
448
+ - Use their words against hesitation: if they say it's expensive, talk value; if they doubt quality, cite a feature
449
+ - Vary your tactics: sometimes appeal to emotion (convenience, joy), sometimes to reason (value, quality)
450
  - Use "imagine if..." scenarios to make benefits concrete
451
 
452
+ General style:
453
+ - Be warm, confident, and conversational — like a helpful friend who knows the product well, not a pushy salesperson
454
+ - End your messages with an engaging question
455
+ - Never fabricate statistics, details, or reviews you don't have
 
456
  - Never make up a price different from the one given
457
  """
458
 
 
478
  ]
479
 
480
 
481
def needs_swap(familiarity_val: str, pre_will_val: str) -> bool:
    """Decide whether the current product should be replaced.

    A swap is requested when the familiarity answer is one of
    SWAP_FAMILIARITY, or when the willingness answer equals the last entry of
    WILLINGNESS_CHOICES.

    NOTE(review): SWAP_FAMILIARITY holds fixed label text; confirm it matches
    the labels get_familiarity_choices() produces for every category.
    """
    return (
        familiarity_val in SWAP_FAMILIARITY
        or pre_will_val == WILLINGNESS_CHOICES[-1]  # "Definitely would buy (7)"
    )
488
+
489
+
490
  # ---------------------------------------------------------------------------
491
  # State initialisation
492
  # ---------------------------------------------------------------------------
493
def _copy_if_list(value):
    """Return a shallow copy of value when it is a list, otherwise value itself."""
    return list(value) if isinstance(value, list) else value


def make_product_slot(p: dict, was_swapped: bool = False) -> dict:
    """Build a fresh per-session record for one product.

    The catalogue fields (id, title, description, features, price) are taken
    from p with defaults for missing keys; the study fields start empty.
    List-valued description/features are copied so that the session record
    does not share list objects with the source dict. The source may be an
    entry of the dataset returned by a st.cache_resource loader, which is
    shared between sessions.

    Args:
        p: source product dict.
        was_swapped: True when this product replaces one that was swapped out.

    Returns:
        A new dict holding the product fields plus empty familiarity,
        willingness, conversation and reflection entries.
    """
    # Only generate a UUID when the source really has no id.
    product_id = p["id"] if "id" in p else str(uuid.uuid4())
    empty_conversation = {
        "system_prompt": "",
        "opening_user_message": "",
        "turns": [],
        "num_turns": 0,
    }
    return {
        "id": product_id,
        "title": p.get("title", ""),
        "description": _copy_if_list(p.get("description", [])),
        "features": _copy_if_list(p.get("features", [])),
        "price": p.get("price", "N/A"),
        "familiarity": None,
        "pre_willingness": None,
        "post_willingness": None,
        "willingness_delta": None,
        "was_swapped": was_swapped,
        "conversation": empty_conversation,
        "reflection": {},
    }
513
+
514
+
515
  def init_state():
516
  download_and_cache_dataset()
517
+ assigned = assign_products(PRODUCTS_PER_USER)
 
 
518
 
 
519
  try:
520
  params = st.query_params
521
  except Exception:
 
531
  "start_time": time.time(),
532
  "category": CATEGORY,
533
  "demographics": {},
534
+ "products": [make_product_slot(p) for p in assigned],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
535
  "current_product_index": 0,
536
+ "screen": "welcome",
537
  "meta": {},
538
  }
539
 
 
544
  def inject_css():
545
  st.markdown("""
546
  <style>
 
547
  #MainMenu, footer, header { visibility: hidden; }
548
  .block-container { max-width: 820px; padding-top: 2rem; }
549
 
 
550
  .product-card {
551
  border: 2px solid #2563eb;
552
  border-radius: 10px;
 
561
  margin-bottom: 0.6rem;
562
  gap: 1rem;
563
  }
564
+ .pc-title { font-size: 1.05rem; font-weight: 700; color: #1a1a2e; line-height: 1.35; flex: 1; }
565
+ .pc-price { font-size: 1.2rem; font-weight: 800; color: #16a34a; white-space: nowrap; }
 
 
 
 
 
 
 
 
 
 
 
566
  .pc-section { margin-top: 0.5rem; }
567
  .pc-section-title {
568
+ font-weight: 600; font-size: 0.85rem; color: #475569;
569
+ text-transform: uppercase; letter-spacing: 0.04em; margin-bottom: 0.3rem;
 
 
 
 
 
 
 
 
 
 
 
570
  }
571
+ .pc-desc { font-size: 0.92rem; color: #334155; line-height: 1.6; }
572
+ .pc-list { margin: 0; padding-left: 1.2rem; font-size: 0.92rem; color: #334155; line-height: 1.5; }
573
  .pc-list li { margin-bottom: 0.25rem; }
574
 
575
+ .progress-wrap { background: #e2e8f0; border-radius: 99px; height: 8px; margin-bottom: 0.25rem; overflow: hidden; }
576
+ .progress-fill { background: #2563eb; height: 100%; border-radius: 99px; }
577
+ .progress-label { font-size: 0.82rem; color: #64748b; text-align: right; margin-bottom: 1rem; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
578
 
 
579
  .chat-wrap { max-height: 420px; overflow-y: auto; margin-bottom: 1rem; }
580
  .bubble { padding: 0.65rem 0.9rem; border-radius: 12px; margin-bottom: 0.5rem; font-size: 0.93rem; line-height: 1.5; }
581
  .bubble-ai { background: #eff6ff; border: 1px solid #93c5fd; margin-right: 10%; }
582
  .bubble-user { background: #f0fdf4; border: 1px solid #86efac; margin-left: 10%; text-align: right; }
583
  .bubble-label { font-size: 0.75rem; color: #94a3b8; margin-bottom: 0.2rem; }
 
 
 
 
 
 
 
 
 
 
 
 
 
584
  </style>
585
  """, unsafe_allow_html=True)
586
 
 
595
  features = product.get("features", [])
596
  price_str = f"${price}" if price and price != "N/A" and not str(price).startswith("$") else price
597
 
598
+ # Description: joined with spaces as prose
599
  desc_html = ""
600
  if description:
601
+ desc_text = " ".join(d for d in description if d)
602
+ desc_html = f'<div class="pc-section"><div class="pc-section-title">📋 Description</div><div class="pc-desc">{desc_text}</div></div>'
603
 
604
+ # Features: bullet points
605
  feat_html = ""
606
  if features:
607
  items_html = "".join(f"<li>{feat}</li>" for feat in features if feat)
 
644
  # Screen renderers
645
  # ---------------------------------------------------------------------------
646
  def screen_welcome(s):
647
+ st.markdown("# 🛒 Product Evaluation Study")
648
  st.markdown(
649
  f"Welcome! In this study you will evaluate **{PRODUCTS_PER_USER} {CATEGORY_DISPLAY[CATEGORY]}** products.\n\n"
650
  "For each product you will:\n"
 
741
  "How familiar are you with this product?",
742
  get_familiarity_choices(),
743
  index=None,
744
+ key=f"familiarity_{idx}_{product['id']}",
745
  )
746
  pre_will_val = st.radio(
747
  "How willing would you be to buy this product?",
748
  WILLINGNESS_CHOICES,
749
  index=None,
750
+ key=f"pre_will_{idx}_{product['id']}",
751
  )
752
 
753
  if st.button("Start Chat →", type="primary", use_container_width=True):
 
758
  if not pre_will_val:
759
  st.error("⚠️ Please rate your willingness to buy.")
760
  return
761
+
762
  familiarity_val = familiarity_val or get_familiarity_choices()[0]
763
  pre_will_val = pre_will_val or WILLINGNESS_CHOICES[3]
764
 
765
+ # Check if we need to swap this product
766
+ if needs_swap(familiarity_val, pre_will_val) and not DEBUG_MODE:
767
+ current_ids = {p["id"] for p in s["products"]}
768
+ replacement = get_swap_product(exclude_ids=current_ids)
769
+ if replacement:
770
+ # Return the rejected product to the queue so it gets reviewed by someone else
771
+ return_product_to_queue(s["products"][idx])
772
+ s["products"][idx] = make_product_slot(replacement, was_swapped=True)
773
+ st.info("We've swapped this product for a better match. Please review the new product below.")
774
+ st.rerun()
775
+ return
776
+ # If no replacement found, proceed anyway
777
+
778
  pre_val = parse_willingness(pre_will_val)
779
  s["products"][idx]["familiarity"] = familiarity_val
780
  s["products"][idx]["pre_willingness"] = pre_val
781
  s["products"][idx]["pre_willingness_label"] = WILLINGNESS_LABELS[pre_val]
782
 
 
783
  system_prompt = build_sales_system_prompt(product)
784
  opening_user_msg = build_opening_user_message(product)
785
  messages = [
 
808
  render_progress(idx + 1)
809
  st.markdown("## Chat with the AI")
810
 
 
811
  title = product.get("title", "Product")
812
  price = product.get("price", "N/A")
813
  price_str = f"${price}" if price and price != "N/A" and not str(price).startswith("$") else price
 
816
 
817
  num_turns = conv["num_turns"]
818
  st.markdown(
819
+ f"Chat with the AI about whether you'd like to purchase the product. "
820
  f"Ask questions, push back, or explore your interest. "
821
  f"You need at least **{MIN_TURNS} exchanges** before you can move on."
822
  )
823
 
 
824
  display_turns = [t for t in conv["turns"] if t["role"] in ("user", "assistant")]
825
  render_chat_history(display_turns)
826
 
 
827
  if num_turns >= MAX_TURNS:
828
  st.info(f"Maximum turns ({MAX_TURNS}) reached. Please proceed.")
829
  else:
830
  st.caption(f"Turns: {num_turns} / minimum {MIN_TURNS}")
831
 
 
832
  if num_turns < MAX_TURNS:
833
+ user_msg = st.text_area(
834
+ "Your response:",
835
+ placeholder="Type your response here…",
836
+ height=100,
837
+ key=f"chat_input_{idx}_{num_turns}",
838
+ )
839
  col1, col2 = st.columns([3, 1])
840
  with col2:
841
  send_clicked = st.button("Send", type="primary", use_container_width=True)
 
847
  st.error(f"⚠️ Please write at least 5 words ({len(user_msg.strip().split())} so far).")
848
  return
849
  user_msg = user_msg.strip()
850
+ messages = [
851
+ {"role": "system", "content": conv["system_prompt"]},
852
+ {"role": "user", "content": conv["opening_user_message"]},
853
+ ]
854
  for turn in conv["turns"]:
855
  messages.append({"role": turn["role"], "content": turn["content"]})
856
  messages.append({"role": "user", "content": user_msg})
 
864
  s["products"][idx]["conversation"] = conv
865
  st.rerun()
866
 
 
867
  can_finish = num_turns >= MIN_TURNS or num_turns >= MAX_TURNS or DEBUG_MODE
868
  if can_finish:
869
  if st.button("I'm done chatting →", use_container_width=True):
 
886
  "How willing would you be to buy this product now?",
887
  WILLINGNESS_CHOICES,
888
  index=None,
889
+ key=f"post_will_{idx}_{product['id']}",
890
  )
891
 
892
  if st.button("Next →", type="primary", use_container_width=True):
 
985
  import pandas as pd
986
  st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)
987
 
 
988
  assignment_id = s.get("assignment_id", "")
989
  turk_submit_to = s.get("turk_submit_to", "")
990
  if assignment_id and turk_submit_to: