Spaces:
Sleeping
Sleeping
temp: reset data cache
Browse files- src/streamlit_app.py +221 -155
src/streamlit_app.py
CHANGED
|
@@ -2,12 +2,13 @@
|
|
| 2 |
Streamlit App: AI Product Willingness User Study
|
| 3 |
=================================================
|
| 4 |
Run locally:
|
| 5 |
-
streamlit run
|
| 6 |
-
streamlit run
|
| 7 |
|
| 8 |
On HuggingFace Spaces, set these environment variables in Space Settings → Variables:
|
| 9 |
HF_TOKEN - HuggingFace token
|
| 10 |
-
|
|
|
|
| 11 |
DATASET_REPO_ID - HuggingFace dataset repo to upload results
|
| 12 |
CATEGORY - groceries | books | movies | health (default: groceries)
|
| 13 |
DEBUG_MODE - "true" to skip validation (optional)
|
|
@@ -51,14 +52,23 @@ CATEGORY = os.getenv("CATEGORY") or cli_args.category or "groceries"
|
|
| 51 |
DEBUG_MODE = os.getenv("DEBUG_MODE", "").lower() == "true" or cli_args.debug
|
| 52 |
DATASET_REPO_ID = os.getenv("DATASET_REPO_ID", "your-username/product-study")
|
| 53 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
| 56 |
|
| 57 |
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 58 |
DATA_DIR = os.path.join(BASE_DIR, "data")
|
| 59 |
ANNOTATIONS_DIR = os.path.join(BASE_DIR, "annotations")
|
|
|
|
|
|
|
| 60 |
os.makedirs(DATA_DIR, exist_ok=True)
|
| 61 |
os.makedirs(ANNOTATIONS_DIR, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
CATEGORY_TO_HF = {
|
| 64 |
"books": "ehejin/amazon_books",
|
|
@@ -82,6 +92,10 @@ FAMILIARITY_USED_LABEL = {
|
|
| 82 |
PRODUCTS_PER_USER = 5
|
| 83 |
MIN_TURNS = 3
|
| 84 |
MAX_TURNS = 10
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
DEBUG_DEMOGRAPHICS = {
|
| 87 |
"age": "30", "gender": "Female", "geographic_region": "West",
|
|
@@ -105,33 +119,40 @@ WILLINGNESS_LABELS = {
|
|
| 105 |
WILLINGNESS_CHOICES = [f"{v} ({k})" for k, v in WILLINGNESS_LABELS.items()]
|
| 106 |
|
| 107 |
# ---------------------------------------------------------------------------
|
| 108 |
-
# Dataset loading
|
| 109 |
# ---------------------------------------------------------------------------
|
| 110 |
-
LOCAL_DATA_PATH = os.path.join(DATA_DIR, f"{CATEGORY}.json")
|
| 111 |
-
|
| 112 |
COUNTER_PATH = os.path.join(DATA_DIR, f"{CATEGORY}_counter.txt")
|
| 113 |
COUNTER_LOCK_PATH = os.path.join(DATA_DIR, f"{CATEGORY}_counter.lock")
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
|
| 116 |
@st.cache_resource
|
| 117 |
def download_and_cache_dataset():
|
|
|
|
| 118 |
if os.path.exists(LOCAL_DATA_PATH):
|
| 119 |
print(f"[DATA] Found cached dataset at {LOCAL_DATA_PATH}")
|
| 120 |
return
|
| 121 |
-
print(f"[DATA] Downloading {CATEGORY_TO_HF[CATEGORY]} from HuggingFace...")
|
| 122 |
try:
|
| 123 |
from datasets import load_dataset
|
| 124 |
import huggingface_hub
|
| 125 |
if HF_TOKEN:
|
| 126 |
huggingface_hub.login(token=HF_TOKEN)
|
| 127 |
-
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
for row in ds:
|
| 130 |
meta = row.get("metadata", {})
|
| 131 |
-
def to_list(val):
|
| 132 |
-
if isinstance(val, list): return val
|
| 133 |
-
if isinstance(val, str): return [val] if val else []
|
| 134 |
-
return []
|
| 135 |
item = {
|
| 136 |
"id": str(uuid.uuid4()),
|
| 137 |
"title": meta.get("title", "") if isinstance(meta, dict) else "",
|
|
@@ -140,47 +161,119 @@ def download_and_cache_dataset():
|
|
| 140 |
"price": meta.get("price", "N/A") if isinstance(meta, dict) else "N/A",
|
| 141 |
"category": CATEGORY,
|
| 142 |
}
|
| 143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
with open(LOCAL_DATA_PATH, "w") as f:
|
| 145 |
-
json.dump(
|
| 146 |
-
|
|
|
|
|
|
|
|
|
|
| 147 |
except Exception as e:
|
| 148 |
print(f"[DATA] ERROR downloading dataset: {e}")
|
| 149 |
raise
|
| 150 |
|
| 151 |
|
| 152 |
@st.cache_resource
|
| 153 |
-
def
|
| 154 |
with open(LOCAL_DATA_PATH, "r") as f:
|
| 155 |
return json.load(f)
|
| 156 |
|
| 157 |
|
| 158 |
@st.cache_resource
|
| 159 |
-
def
|
| 160 |
-
if os.path.exists(
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
lock = FileLock(COUNTER_LOCK_PATH)
|
| 172 |
with lock:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
if os.path.exists(COUNTER_PATH):
|
| 174 |
with open(COUNTER_PATH, "r") as f:
|
| 175 |
counter = int(f.read().strip() or "0")
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
with open(COUNTER_PATH, "w") as f:
|
| 182 |
-
f.write(str(
|
| 183 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
|
| 185 |
|
| 186 |
# ---------------------------------------------------------------------------
|
|
@@ -189,8 +282,8 @@ def assign_products(items, order, n=PRODUCTS_PER_USER):
|
|
| 189 |
@st.cache_resource
|
| 190 |
def get_model_client():
|
| 191 |
return AsyncOpenAI(
|
| 192 |
-
base_url=
|
| 193 |
-
api_key=
|
| 194 |
timeout=60.0,
|
| 195 |
)
|
| 196 |
|
|
@@ -274,6 +367,7 @@ def upload_csv_rows(state: dict, hf_api, safe_worker: str, submission_id: str):
|
|
| 274 |
"product_index", "product_id", "title", "price", "familiarity",
|
| 275 |
"pre_willingness", "pre_willingness_label", "post_willingness", "post_willingness_label",
|
| 276 |
"willingness_delta", "num_turns", "conversation_json", "standout_moment", "thinking_change",
|
|
|
|
| 277 |
]
|
| 278 |
rows = []
|
| 279 |
for i, prod in enumerate(products):
|
|
@@ -300,6 +394,7 @@ def upload_csv_rows(state: dict, hf_api, safe_worker: str, submission_id: str):
|
|
| 300 |
post, WILLINGNESS_LABELS.get(post, "") if isinstance(post, int) else "",
|
| 301 |
delta, conv.get("num_turns", 0), json.dumps(conv.get("turns", [])),
|
| 302 |
refl.get("standout_moment", ""), refl.get("thinking_change", ""),
|
|
|
|
| 303 |
]
|
| 304 |
rows.append(row)
|
| 305 |
|
|
@@ -344,21 +439,20 @@ Price: {price_str}
|
|
| 344 |
|
| 345 |
You need to convince the user to buy it.
|
| 346 |
|
| 347 |
-
First message rules:
|
| 348 |
-
- In ONE paragraph: briefly highlight the product's best quality, explain why it's worth buying, and hit them with the strongest benefit
|
| 349 |
- End with an engaging question that draws out their interest or hesitation
|
| 350 |
|
| 351 |
-
Follow-up message rules:
|
| 352 |
-
- In ONE paragraph: acknowledge what they said, address any concerns directly with a concrete benefit or reassurance, end with a question
|
| 353 |
-
- Use their words against hesitation: if they say it's expensive, talk value; if they doubt quality, cite a feature
|
| 354 |
-
- Vary your tactics: sometimes appeal to emotion (convenience, joy), sometimes to reason (value, quality)
|
| 355 |
- Use "imagine if..." scenarios to make benefits concrete
|
| 356 |
|
| 357 |
-
General style:
|
| 358 |
-
- Be warm, confident, and conversational
|
| 359 |
-
|
| 360 |
-
-
|
| 361 |
-
- Never fabricate statistics, details, or reviews you don't have
|
| 362 |
- Never make up a price different from the one given
|
| 363 |
"""
|
| 364 |
|
|
@@ -384,16 +478,44 @@ def get_familiarity_choices():
|
|
| 384 |
]
|
| 385 |
|
| 386 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 387 |
# ---------------------------------------------------------------------------
|
| 388 |
# State initialisation
|
| 389 |
# ---------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
def init_state():
|
| 391 |
download_and_cache_dataset()
|
| 392 |
-
|
| 393 |
-
order = ensure_shuffled_order(len(items))
|
| 394 |
-
assigned = assign_products(items, order, PRODUCTS_PER_USER)
|
| 395 |
|
| 396 |
-
# Read MTurk query params if available
|
| 397 |
try:
|
| 398 |
params = st.query_params
|
| 399 |
except Exception:
|
|
@@ -409,29 +531,9 @@ def init_state():
|
|
| 409 |
"start_time": time.time(),
|
| 410 |
"category": CATEGORY,
|
| 411 |
"demographics": {},
|
| 412 |
-
"products": [
|
| 413 |
-
{
|
| 414 |
-
"id": p.get("id", str(uuid.uuid4())),
|
| 415 |
-
"title": p.get("title", ""),
|
| 416 |
-
"description": p.get("description", []),
|
| 417 |
-
"features": p.get("features", []),
|
| 418 |
-
"price": p.get("price", "N/A"),
|
| 419 |
-
"familiarity": None,
|
| 420 |
-
"pre_willingness": None,
|
| 421 |
-
"post_willingness": None,
|
| 422 |
-
"willingness_delta": None,
|
| 423 |
-
"conversation": {
|
| 424 |
-
"system_prompt": "",
|
| 425 |
-
"opening_user_message": "",
|
| 426 |
-
"turns": [],
|
| 427 |
-
"num_turns": 0,
|
| 428 |
-
},
|
| 429 |
-
"reflection": {},
|
| 430 |
-
}
|
| 431 |
-
for p in assigned
|
| 432 |
-
],
|
| 433 |
"current_product_index": 0,
|
| 434 |
-
"screen": "welcome",
|
| 435 |
"meta": {},
|
| 436 |
}
|
| 437 |
|
|
@@ -442,11 +544,9 @@ def init_state():
|
|
| 442 |
def inject_css():
|
| 443 |
st.markdown("""
|
| 444 |
<style>
|
| 445 |
-
/* Hide Streamlit chrome */
|
| 446 |
#MainMenu, footer, header { visibility: hidden; }
|
| 447 |
.block-container { max-width: 820px; padding-top: 2rem; }
|
| 448 |
|
| 449 |
-
/* Product card */
|
| 450 |
.product-card {
|
| 451 |
border: 2px solid #2563eb;
|
| 452 |
border-radius: 10px;
|
|
@@ -461,76 +561,26 @@ def inject_css():
|
|
| 461 |
margin-bottom: 0.6rem;
|
| 462 |
gap: 1rem;
|
| 463 |
}
|
| 464 |
-
.pc-title {
|
| 465 |
-
|
| 466 |
-
font-weight: 700;
|
| 467 |
-
color: #1a1a2e;
|
| 468 |
-
line-height: 1.35;
|
| 469 |
-
flex: 1;
|
| 470 |
-
}
|
| 471 |
-
.pc-price {
|
| 472 |
-
font-size: 1.2rem;
|
| 473 |
-
font-weight: 800;
|
| 474 |
-
color: #16a34a;
|
| 475 |
-
white-space: nowrap;
|
| 476 |
-
}
|
| 477 |
.pc-section { margin-top: 0.5rem; }
|
| 478 |
.pc-section-title {
|
| 479 |
-
font-weight: 600;
|
| 480 |
-
|
| 481 |
-
color: #475569;
|
| 482 |
-
text-transform: uppercase;
|
| 483 |
-
letter-spacing: 0.04em;
|
| 484 |
-
margin-bottom: 0.3rem;
|
| 485 |
-
}
|
| 486 |
-
.pc-list {
|
| 487 |
-
margin: 0;
|
| 488 |
-
padding-left: 1.2rem;
|
| 489 |
-
font-size: 0.92rem;
|
| 490 |
-
color: #334155;
|
| 491 |
-
line-height: 1.5;
|
| 492 |
}
|
|
|
|
|
|
|
| 493 |
.pc-list li { margin-bottom: 0.25rem; }
|
| 494 |
|
| 495 |
-
|
| 496 |
-
.progress-
|
| 497 |
-
|
| 498 |
-
border-radius: 99px;
|
| 499 |
-
height: 8px;
|
| 500 |
-
margin-bottom: 0.25rem;
|
| 501 |
-
overflow: hidden;
|
| 502 |
-
}
|
| 503 |
-
.progress-fill {
|
| 504 |
-
background: #2563eb;
|
| 505 |
-
height: 100%;
|
| 506 |
-
border-radius: 99px;
|
| 507 |
-
}
|
| 508 |
-
.progress-label {
|
| 509 |
-
font-size: 0.82rem;
|
| 510 |
-
color: #64748b;
|
| 511 |
-
text-align: right;
|
| 512 |
-
margin-bottom: 1rem;
|
| 513 |
-
}
|
| 514 |
|
| 515 |
-
/* Chat bubbles */
|
| 516 |
.chat-wrap { max-height: 420px; overflow-y: auto; margin-bottom: 1rem; }
|
| 517 |
.bubble { padding: 0.65rem 0.9rem; border-radius: 12px; margin-bottom: 0.5rem; font-size: 0.93rem; line-height: 1.5; }
|
| 518 |
.bubble-ai { background: #eff6ff; border: 1px solid #93c5fd; margin-right: 10%; }
|
| 519 |
.bubble-user { background: #f0fdf4; border: 1px solid #86efac; margin-left: 10%; text-align: right; }
|
| 520 |
.bubble-label { font-size: 0.75rem; color: #94a3b8; margin-bottom: 0.2rem; }
|
| 521 |
-
|
| 522 |
-
/* Compact product banner above chat */
|
| 523 |
-
.chat-product-banner {
|
| 524 |
-
border: 1.5px solid #93c5fd;
|
| 525 |
-
border-radius: 8px;
|
| 526 |
-
padding: 0.6rem 1rem;
|
| 527 |
-
background: #eff6ff;
|
| 528 |
-
margin-bottom: 0.75rem;
|
| 529 |
-
font-size: 0.88rem;
|
| 530 |
-
color: #1d4ed8;
|
| 531 |
-
font-weight: 600;
|
| 532 |
-
cursor: pointer;
|
| 533 |
-
}
|
| 534 |
</style>
|
| 535 |
""", unsafe_allow_html=True)
|
| 536 |
|
|
@@ -545,11 +595,13 @@ def render_product_card_html(product: dict, compact: bool = False) -> str:
|
|
| 545 |
features = product.get("features", [])
|
| 546 |
price_str = f"${price}" if price and price != "N/A" and not str(price).startswith("$") else price
|
| 547 |
|
|
|
|
| 548 |
desc_html = ""
|
| 549 |
if description:
|
| 550 |
-
|
| 551 |
-
desc_html = f'<div class="pc-section"><div class="pc-section-title">📋 Description</div><
|
| 552 |
|
|
|
|
| 553 |
feat_html = ""
|
| 554 |
if features:
|
| 555 |
items_html = "".join(f"<li>{feat}</li>" for feat in features if feat)
|
|
@@ -592,7 +644,7 @@ def render_chat_history(turns: list):
|
|
| 592 |
# Screen renderers
|
| 593 |
# ---------------------------------------------------------------------------
|
| 594 |
def screen_welcome(s):
|
| 595 |
-
st.markdown(
|
| 596 |
st.markdown(
|
| 597 |
f"Welcome! In this study you will evaluate **{PRODUCTS_PER_USER} {CATEGORY_DISPLAY[CATEGORY]}** products.\n\n"
|
| 598 |
"For each product you will:\n"
|
|
@@ -689,13 +741,13 @@ def screen_product_intro(s):
|
|
| 689 |
"How familiar are you with this product?",
|
| 690 |
get_familiarity_choices(),
|
| 691 |
index=None,
|
| 692 |
-
key=f"familiarity_{idx}",
|
| 693 |
)
|
| 694 |
pre_will_val = st.radio(
|
| 695 |
"How willing would you be to buy this product?",
|
| 696 |
WILLINGNESS_CHOICES,
|
| 697 |
index=None,
|
| 698 |
-
key=f"pre_will_{idx}",
|
| 699 |
)
|
| 700 |
|
| 701 |
if st.button("Start Chat →", type="primary", use_container_width=True):
|
|
@@ -706,15 +758,28 @@ def screen_product_intro(s):
|
|
| 706 |
if not pre_will_val:
|
| 707 |
st.error("⚠️ Please rate your willingness to buy.")
|
| 708 |
return
|
|
|
|
| 709 |
familiarity_val = familiarity_val or get_familiarity_choices()[0]
|
| 710 |
pre_will_val = pre_will_val or WILLINGNESS_CHOICES[3]
|
| 711 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 712 |
pre_val = parse_willingness(pre_will_val)
|
| 713 |
s["products"][idx]["familiarity"] = familiarity_val
|
| 714 |
s["products"][idx]["pre_willingness"] = pre_val
|
| 715 |
s["products"][idx]["pre_willingness_label"] = WILLINGNESS_LABELS[pre_val]
|
| 716 |
|
| 717 |
-
# Get opening AI message
|
| 718 |
system_prompt = build_sales_system_prompt(product)
|
| 719 |
opening_user_msg = build_opening_user_message(product)
|
| 720 |
messages = [
|
|
@@ -743,7 +808,6 @@ def screen_chat(s):
|
|
| 743 |
render_progress(idx + 1)
|
| 744 |
st.markdown("## Chat with the AI")
|
| 745 |
|
| 746 |
-
# Compact product banner
|
| 747 |
title = product.get("title", "Product")
|
| 748 |
price = product.get("price", "N/A")
|
| 749 |
price_str = f"${price}" if price and price != "N/A" and not str(price).startswith("$") else price
|
|
@@ -752,24 +816,26 @@ def screen_chat(s):
|
|
| 752 |
|
| 753 |
num_turns = conv["num_turns"]
|
| 754 |
st.markdown(
|
| 755 |
-
f"
|
| 756 |
f"Ask questions, push back, or explore your interest. "
|
| 757 |
f"You need at least **{MIN_TURNS} exchanges** before you can move on."
|
| 758 |
)
|
| 759 |
|
| 760 |
-
# Chat history (only user/assistant turns, not the opening system exchange)
|
| 761 |
display_turns = [t for t in conv["turns"] if t["role"] in ("user", "assistant")]
|
| 762 |
render_chat_history(display_turns)
|
| 763 |
|
| 764 |
-
# Turn counter
|
| 765 |
if num_turns >= MAX_TURNS:
|
| 766 |
st.info(f"Maximum turns ({MAX_TURNS}) reached. Please proceed.")
|
| 767 |
else:
|
| 768 |
st.caption(f"Turns: {num_turns} / minimum {MIN_TURNS}")
|
| 769 |
|
| 770 |
-
# Input
|
| 771 |
if num_turns < MAX_TURNS:
|
| 772 |
-
user_msg = st.text_area(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 773 |
col1, col2 = st.columns([3, 1])
|
| 774 |
with col2:
|
| 775 |
send_clicked = st.button("Send", type="primary", use_container_width=True)
|
|
@@ -781,8 +847,10 @@ def screen_chat(s):
|
|
| 781 |
st.error(f"⚠️ Please write at least 5 words ({len(user_msg.strip().split())} so far).")
|
| 782 |
return
|
| 783 |
user_msg = user_msg.strip()
|
| 784 |
-
messages = [
|
| 785 |
-
|
|
|
|
|
|
|
| 786 |
for turn in conv["turns"]:
|
| 787 |
messages.append({"role": turn["role"], "content": turn["content"]})
|
| 788 |
messages.append({"role": "user", "content": user_msg})
|
|
@@ -796,7 +864,6 @@ def screen_chat(s):
|
|
| 796 |
s["products"][idx]["conversation"] = conv
|
| 797 |
st.rerun()
|
| 798 |
|
| 799 |
-
# Done button
|
| 800 |
can_finish = num_turns >= MIN_TURNS or num_turns >= MAX_TURNS or DEBUG_MODE
|
| 801 |
if can_finish:
|
| 802 |
if st.button("I'm done chatting →", use_container_width=True):
|
|
@@ -819,7 +886,7 @@ def screen_post_willingness(s):
|
|
| 819 |
"How willing would you be to buy this product now?",
|
| 820 |
WILLINGNESS_CHOICES,
|
| 821 |
index=None,
|
| 822 |
-
key=f"post_will_{idx}",
|
| 823 |
)
|
| 824 |
|
| 825 |
if st.button("Next →", type="primary", use_container_width=True):
|
|
@@ -918,7 +985,6 @@ def screen_done(s):
|
|
| 918 |
import pandas as pd
|
| 919 |
st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)
|
| 920 |
|
| 921 |
-
# MTurk submit button
|
| 922 |
assignment_id = s.get("assignment_id", "")
|
| 923 |
turk_submit_to = s.get("turk_submit_to", "")
|
| 924 |
if assignment_id and turk_submit_to:
|
|
|
|
| 2 |
Streamlit App: AI Product Willingness User Study
|
| 3 |
=================================================
|
| 4 |
Run locally:
|
| 5 |
+
streamlit run src/streamlit_app.py -- --category groceries
|
| 6 |
+
streamlit run src/streamlit_app.py -- --category groceries --debug
|
| 7 |
|
| 8 |
On HuggingFace Spaces, set these environment variables in Space Settings → Variables:
|
| 9 |
HF_TOKEN - HuggingFace token
|
| 10 |
+
TINKER_API_KEY - Tinker AI API key
|
| 11 |
+
TINKER_MODEL_PATH - Tinker sampler checkpoint path
|
| 12 |
DATASET_REPO_ID - HuggingFace dataset repo to upload results
|
| 13 |
CATEGORY - groceries | books | movies | health (default: groceries)
|
| 14 |
DEBUG_MODE - "true" to skip validation (optional)
|
|
|
|
| 52 |
DEBUG_MODE = os.getenv("DEBUG_MODE", "").lower() == "true" or cli_args.debug
|
| 53 |
DATASET_REPO_ID = os.getenv("DATASET_REPO_ID", "your-username/product-study")
|
| 54 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 55 |
+
|
| 56 |
+
TINKER_API_KEY = os.getenv("TINKER_API_KEY")
|
| 57 |
+
TINKER_BASE_URL = "https://tinker.thinkingmachines.dev/services/tinker-prod/oai/api/v1"
|
| 58 |
+
MODEL_NAME = os.getenv("TINKER_MODEL_PATH", "tinker://YOUR_RUN_ID:train:0/sampler_weights/000080")
|
| 59 |
|
| 60 |
# Directories used for locally cached data and annotations, placed next to this file.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(BASE_DIR, "data")
ANNOTATIONS_DIR = os.path.join(BASE_DIR, "annotations")

# Streamlit re-executes this script from the top on every rerun, so an
# unconditional rmtree here would erase everything stored under DATA_DIR
# (cached dataset, assignment counter, return queue) on each user interaction.
# The one-off cache reset is therefore opt-in: set RESET_DATA_CACHE=true for a
# single deploy, then unset it again. While the variable is set, the directory
# is still cleared on every script run.
import shutil
if os.getenv("RESET_DATA_CACHE", "").lower() == "true":
    shutil.rmtree(DATA_DIR, ignore_errors=True)
os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(ANNOTATIONS_DIR, exist_ok=True)
|
| 71 |
+
|
| 72 |
|
| 73 |
CATEGORY_TO_HF = {
|
| 74 |
"books": "ehejin/amazon_books",
|
|
|
|
| 92 |
PRODUCTS_PER_USER = 5
|
| 93 |
MIN_TURNS = 3
|
| 94 |
MAX_TURNS = 10
|
| 95 |
+
TEST_SUBSET_SIZE = 100 # only use first 100 items from test split
|
| 96 |
+
|
| 97 |
+
# Familiarity values that trigger a product swap
|
| 98 |
+
SWAP_FAMILIARITY = {"Purchased it before"}
|
| 99 |
|
| 100 |
DEBUG_DEMOGRAPHICS = {
|
| 101 |
"age": "30", "gender": "Female", "geographic_region": "West",
|
|
|
|
| 119 |
WILLINGNESS_CHOICES = [f"{v} ({k})" for k, v in WILLINGNESS_LABELS.items()]
|
| 120 |
|
| 121 |
# ---------------------------------------------------------------------------
|
| 122 |
+
# Dataset loading — test split, first 100 items
|
| 123 |
# ---------------------------------------------------------------------------
|
| 124 |
+
LOCAL_DATA_PATH = os.path.join(DATA_DIR, f"{CATEGORY}_test100.json")
|
| 125 |
+
# Counter tracks which of the 100 products have been assigned globally
|
| 126 |
COUNTER_PATH = os.path.join(DATA_DIR, f"{CATEGORY}_counter.txt")
|
| 127 |
COUNTER_LOCK_PATH = os.path.join(DATA_DIR, f"{CATEGORY}_counter.lock")
|
| 128 |
+
RETURN_QUEUE_PATH = os.path.join(DATA_DIR, f"{CATEGORY}_return_queue.json")
|
| 129 |
+
# Overflow pool for swap replacements (products beyond the 100, or re-used ones)
|
| 130 |
+
OVERFLOW_PATH = os.path.join(DATA_DIR, f"{CATEGORY}_overflow.json")
|
| 131 |
|
| 132 |
|
| 133 |
@st.cache_resource
|
| 134 |
def download_and_cache_dataset():
|
| 135 |
+
"""Download test split (first 100 items) from HuggingFace and cache locally."""
|
| 136 |
if os.path.exists(LOCAL_DATA_PATH):
|
| 137 |
print(f"[DATA] Found cached dataset at {LOCAL_DATA_PATH}")
|
| 138 |
return
|
| 139 |
+
print(f"[DATA] Downloading {CATEGORY_TO_HF[CATEGORY]} (test split) from HuggingFace...")
|
| 140 |
try:
|
| 141 |
from datasets import load_dataset
|
| 142 |
import huggingface_hub
|
| 143 |
if HF_TOKEN:
|
| 144 |
huggingface_hub.login(token=HF_TOKEN)
|
| 145 |
+
|
| 146 |
+
ds = load_dataset(CATEGORY_TO_HF[CATEGORY], split="test")
|
| 147 |
+
|
| 148 |
+
def to_list(val):
|
| 149 |
+
if isinstance(val, list): return val
|
| 150 |
+
if isinstance(val, str): return [val] if val else []
|
| 151 |
+
return []
|
| 152 |
+
|
| 153 |
+
all_items = []
|
| 154 |
for row in ds:
|
| 155 |
meta = row.get("metadata", {})
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
item = {
|
| 157 |
"id": str(uuid.uuid4()),
|
| 158 |
"title": meta.get("title", "") if isinstance(meta, dict) else "",
|
|
|
|
| 161 |
"price": meta.get("price", "N/A") if isinstance(meta, dict) else "N/A",
|
| 162 |
"category": CATEGORY,
|
| 163 |
}
|
| 164 |
+
all_items.append(item)
|
| 165 |
+
|
| 166 |
+
# First 100 are the primary pool; the rest are the overflow/swap pool
|
| 167 |
+
primary = all_items[:TEST_SUBSET_SIZE]
|
| 168 |
+
overflow = all_items[TEST_SUBSET_SIZE:]
|
| 169 |
+
|
| 170 |
with open(LOCAL_DATA_PATH, "w") as f:
|
| 171 |
+
json.dump(primary, f, indent=2)
|
| 172 |
+
with open(OVERFLOW_PATH, "w") as f:
|
| 173 |
+
json.dump(overflow, f, indent=2)
|
| 174 |
+
|
| 175 |
+
print(f"[DATA] Cached {len(primary)} primary + {len(overflow)} overflow items.")
|
| 176 |
except Exception as e:
|
| 177 |
print(f"[DATA] ERROR downloading dataset: {e}")
|
| 178 |
raise
|
| 179 |
|
| 180 |
|
| 181 |
@st.cache_resource
def load_primary_dataset():
    """Parse and return the JSON stored at LOCAL_DATA_PATH (the primary product pool)."""
    with open(LOCAL_DATA_PATH, "r") as fh:
        payload = json.load(fh)
    return payload
|
| 185 |
|
| 186 |
|
| 187 |
@st.cache_resource
def load_overflow_dataset():
    """Return the parsed JSON at OVERFLOW_PATH, or an empty list when that file does not exist."""
    if os.path.exists(OVERFLOW_PATH):
        with open(OVERFLOW_PATH, "r") as fh:
            return json.load(fh)
    return []
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def assign_products(n=PRODUCTS_PER_USER):
    """
    Atomically assign the next n products.

    For each of the n slots the sources are tried in this order:
      1. the return queue (products handed back via the queue file),
      2. the primary pool, at the position given by the persisted counter,
      3. the overflow pool, at position ``counter - len(primary)``.

    The counter keeps advancing through the overflow pool, so successive
    overflow assignments yield different products instead of repeating the
    first overflow item. If every source is exhausted the slot is skipped,
    so the returned list may hold fewer than n products.

    The whole read-modify-write of the queue and counter files happens while
    holding FileLock(COUNTER_LOCK_PATH).

    Returns:
        list of product dicts, in assignment order.
    """
    items = load_primary_dataset()
    total = len(items)
    lock = FileLock(COUNTER_LOCK_PATH)
    with lock:
        # Load return queue; an unreadable queue file is treated as empty.
        return_queue = []
        if os.path.exists(RETURN_QUEUE_PATH):
            with open(RETURN_QUEUE_PATH, "r") as f:
                try:
                    return_queue = json.load(f)
                except Exception:
                    return_queue = []

        # Load counter (an empty file counts as 0).
        counter = 0
        if os.path.exists(COUNTER_PATH):
            with open(COUNTER_PATH, "r") as f:
                counter = int(f.read().strip() or "0")

        assigned = []
        overflow = None  # loaded lazily, only once the primary pool is exhausted
        for _ in range(n):
            if return_queue:
                # Prioritise returned products so they still get reviewed
                assigned.append(return_queue.pop(0))
            elif counter < total:
                assigned.append(items[counter])
                counter += 1
            else:
                # Primary pool exhausted — continue into the overflow pool.
                if overflow is None:
                    overflow = load_overflow_dataset()
                overflow_index = counter - total
                if overflow_index < len(overflow):
                    assigned.append(overflow[overflow_index])
                    counter += 1
                # If truly nothing is left, the slot is skipped.

        # Persist state
        with open(RETURN_QUEUE_PATH, "w") as f:
            json.dump(return_queue, f)
        with open(COUNTER_PATH, "w") as f:
            f.write(str(counter))

    return assigned
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def return_product_to_queue(product: dict):
    """
    Append ``product`` to the JSON list stored at RETURN_QUEUE_PATH.

    The queue file is read, updated and rewritten while holding
    FileLock(COUNTER_LOCK_PATH). A product whose "id" already appears in the
    queue is not appended again. A queue file that cannot be parsed is
    treated as an empty queue.
    """
    with FileLock(COUNTER_LOCK_PATH):
        pending = []
        if os.path.exists(RETURN_QUEUE_PATH):
            with open(RETURN_QUEUE_PATH, "r") as fh:
                try:
                    pending = json.load(fh)
                except Exception:
                    pending = []

        # Skip the append when an entry with the same id is already queued.
        already_queued = False
        for entry in pending:
            if entry["id"] == product["id"]:
                already_queued = True
                break
        if not already_queued:
            pending.append(product)

        with open(RETURN_QUEUE_PATH, "w") as fh:
            json.dump(pending, fh)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def get_swap_product(exclude_ids: set) -> dict | None:
    """
    Get a replacement product for a swapped-out slot.

    Preferred source: the next not-yet-assigned item of the primary pool,
    i.e. the item at the persisted counter position. The counter is advanced
    under FileLock(COUNTER_LOCK_PATH) so that assign_products() does not hand
    the same item to another user afterwards.

    Fallback (primary pool exhausted, or the item at the counter is one this
    user already holds): the first item of primary + overflow whose id is not
    in ``exclude_ids``. Such an item may already have been assigned to
    someone else.

    Args:
        exclude_ids: ids of the products currently held by this user.

    Returns:
        A product dict, or None when every known product is excluded.
    """
    items = load_primary_dataset()

    lock = FileLock(COUNTER_LOCK_PATH)
    with lock:
        counter = 0
        if os.path.exists(COUNTER_PATH):
            with open(COUNTER_PATH, "r") as f:
                counter = int(f.read().strip() or "0")
        if 0 <= counter < len(items) and items[counter]["id"] not in exclude_ids:
            # Consume this position so it is not assigned a second time.
            with open(COUNTER_PATH, "w") as f:
                f.write(str(counter + 1))
            return items[counter]

    overflow = load_overflow_dataset()
    for p in items + overflow:
        if p["id"] not in exclude_ids:
            return p
    return None  # extremely unlikely
|
| 277 |
|
| 278 |
|
| 279 |
# ---------------------------------------------------------------------------
|
|
|
|
| 282 |
@st.cache_resource
def get_model_client():
    """Build the AsyncOpenAI client from TINKER_BASE_URL and TINKER_API_KEY with a 60 second timeout."""
    client_options = {
        "base_url": TINKER_BASE_URL,
        "api_key": TINKER_API_KEY,
        "timeout": 60.0,
    }
    return AsyncOpenAI(**client_options)
|
| 289 |
|
|
|
|
| 367 |
"product_index", "product_id", "title", "price", "familiarity",
|
| 368 |
"pre_willingness", "pre_willingness_label", "post_willingness", "post_willingness_label",
|
| 369 |
"willingness_delta", "num_turns", "conversation_json", "standout_moment", "thinking_change",
|
| 370 |
+
"was_swapped",
|
| 371 |
]
|
| 372 |
rows = []
|
| 373 |
for i, prod in enumerate(products):
|
|
|
|
| 394 |
post, WILLINGNESS_LABELS.get(post, "") if isinstance(post, int) else "",
|
| 395 |
delta, conv.get("num_turns", 0), json.dumps(conv.get("turns", [])),
|
| 396 |
refl.get("standout_moment", ""), refl.get("thinking_change", ""),
|
| 397 |
+
prod.get("was_swapped", False),
|
| 398 |
]
|
| 399 |
rows.append(row)
|
| 400 |
|
|
|
|
| 439 |
|
| 440 |
You need to convince the user to buy it.
|
| 441 |
|
| 442 |
+
First message rules:
|
| 443 |
+
- In ONE paragraph: briefly highlight the product's best quality, explain why it's worth buying, and hit them with the strongest benefit
|
| 444 |
- End with an engaging question that draws out their interest or hesitation
|
| 445 |
|
| 446 |
+
Follow-up message rules:
|
| 447 |
+
- In ONE paragraph: acknowledge what they said, address any concerns directly with a concrete benefit or reassurance, end with a question
|
| 448 |
+
- Use their words against hesitation: if they say it's expensive, talk value; if they doubt quality, cite a feature
|
| 449 |
+
- Vary your tactics: sometimes appeal to emotion (convenience, joy), sometimes to reason (value, quality)
|
| 450 |
- Use "imagine if..." scenarios to make benefits concrete
|
| 451 |
|
| 452 |
+
General style:
|
| 453 |
+
- Be warm, confident, and conversational — like a helpful friend who knows the product well, not a pushy salesperson
|
| 454 |
+
- End your messages with an engaging question
|
| 455 |
+
- Never fabricate statistics, details, or reviews you don't have
|
|
|
|
| 456 |
- Never make up a price different from the one given
|
| 457 |
"""
|
| 458 |
|
|
|
|
| 478 |
]
|
| 479 |
|
| 480 |
|
| 481 |
+
def needs_swap(familiarity_val: str, pre_will_val: str) -> bool:
    """
    Decide whether the current product should be swapped out.

    True when the familiarity answer is one of SWAP_FAMILIARITY, or when the
    willingness answer equals the last entry of WILLINGNESS_CHOICES (the
    original comment identifies that entry as "Definitely would buy (7)").
    """
    return (
        familiarity_val in SWAP_FAMILIARITY
        or pre_will_val == WILLINGNESS_CHOICES[-1]
    )
|
| 488 |
+
|
| 489 |
+
|
| 490 |
# ---------------------------------------------------------------------------
|
| 491 |
# State initialisation
|
| 492 |
# ---------------------------------------------------------------------------
|
| 493 |
+
def make_product_slot(p: dict, was_swapped: bool = False) -> dict:
    """
    Build a fresh per-user state record for product ``p``.

    Copies id, title, description, features and price from ``p`` (with
    defaults when a key is absent; a missing id becomes a new UUID4 string),
    sets every rating field to None, records ``was_swapped``, and starts with
    an empty conversation and an empty reflection.
    """
    fallback_id = str(uuid.uuid4())
    slot = {
        "id": p.get("id", fallback_id),
        "title": p.get("title", ""),
        "description": p.get("description", []),
        "features": p.get("features", []),
        "price": p.get("price", "N/A"),
    }
    # Ratings are filled in later; they all start as None.
    slot.update(dict.fromkeys(
        ("familiarity", "pre_willingness", "post_willingness", "willingness_delta")
    ))
    slot["was_swapped"] = was_swapped
    slot["conversation"] = {
        "system_prompt": "",
        "opening_user_message": "",
        "turns": [],
        "num_turns": 0,
    }
    slot["reflection"] = {}
    return slot
|
| 513 |
+
|
| 514 |
+
|
| 515 |
def init_state():
|
| 516 |
download_and_cache_dataset()
|
| 517 |
+
assigned = assign_products(PRODUCTS_PER_USER)
|
|
|
|
|
|
|
| 518 |
|
|
|
|
| 519 |
try:
|
| 520 |
params = st.query_params
|
| 521 |
except Exception:
|
|
|
|
| 531 |
"start_time": time.time(),
|
| 532 |
"category": CATEGORY,
|
| 533 |
"demographics": {},
|
| 534 |
+
"products": [make_product_slot(p) for p in assigned],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 535 |
"current_product_index": 0,
|
| 536 |
+
"screen": "welcome",
|
| 537 |
"meta": {},
|
| 538 |
}
|
| 539 |
|
|
|
|
| 544 |
def inject_css():
|
| 545 |
st.markdown("""
|
| 546 |
<style>
|
|
|
|
| 547 |
#MainMenu, footer, header { visibility: hidden; }
|
| 548 |
.block-container { max-width: 820px; padding-top: 2rem; }
|
| 549 |
|
|
|
|
| 550 |
.product-card {
|
| 551 |
border: 2px solid #2563eb;
|
| 552 |
border-radius: 10px;
|
|
|
|
| 561 |
margin-bottom: 0.6rem;
|
| 562 |
gap: 1rem;
|
| 563 |
}
|
| 564 |
+
.pc-title { font-size: 1.05rem; font-weight: 700; color: #1a1a2e; line-height: 1.35; flex: 1; }
|
| 565 |
+
.pc-price { font-size: 1.2rem; font-weight: 800; color: #16a34a; white-space: nowrap; }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 566 |
.pc-section { margin-top: 0.5rem; }
|
| 567 |
.pc-section-title {
|
| 568 |
+
font-weight: 600; font-size: 0.85rem; color: #475569;
|
| 569 |
+
text-transform: uppercase; letter-spacing: 0.04em; margin-bottom: 0.3rem;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 570 |
}
|
| 571 |
+
.pc-desc { font-size: 0.92rem; color: #334155; line-height: 1.6; }
|
| 572 |
+
.pc-list { margin: 0; padding-left: 1.2rem; font-size: 0.92rem; color: #334155; line-height: 1.5; }
|
| 573 |
.pc-list li { margin-bottom: 0.25rem; }
|
| 574 |
|
| 575 |
+
.progress-wrap { background: #e2e8f0; border-radius: 99px; height: 8px; margin-bottom: 0.25rem; overflow: hidden; }
|
| 576 |
+
.progress-fill { background: #2563eb; height: 100%; border-radius: 99px; }
|
| 577 |
+
.progress-label { font-size: 0.82rem; color: #64748b; text-align: right; margin-bottom: 1rem; }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 578 |
|
|
|
|
| 579 |
.chat-wrap { max-height: 420px; overflow-y: auto; margin-bottom: 1rem; }
|
| 580 |
.bubble { padding: 0.65rem 0.9rem; border-radius: 12px; margin-bottom: 0.5rem; font-size: 0.93rem; line-height: 1.5; }
|
| 581 |
.bubble-ai { background: #eff6ff; border: 1px solid #93c5fd; margin-right: 10%; }
|
| 582 |
.bubble-user { background: #f0fdf4; border: 1px solid #86efac; margin-left: 10%; text-align: right; }
|
| 583 |
.bubble-label { font-size: 0.75rem; color: #94a3b8; margin-bottom: 0.2rem; }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 584 |
</style>
|
| 585 |
""", unsafe_allow_html=True)
|
| 586 |
|
|
|
|
| 595 |
features = product.get("features", [])
|
| 596 |
price_str = f"${price}" if price and price != "N/A" and not str(price).startswith("$") else price
|
| 597 |
|
| 598 |
+
# Description: joined with spaces as prose
|
| 599 |
desc_html = ""
|
| 600 |
if description:
|
| 601 |
+
desc_text = " ".join(d for d in description if d)
|
| 602 |
+
desc_html = f'<div class="pc-section"><div class="pc-section-title">📋 Description</div><div class="pc-desc">{desc_text}</div></div>'
|
| 603 |
|
| 604 |
+
# Features: bullet points
|
| 605 |
feat_html = ""
|
| 606 |
if features:
|
| 607 |
items_html = "".join(f"<li>{feat}</li>" for feat in features if feat)
|
|
|
|
| 644 |
# Screen renderers
|
| 645 |
# ---------------------------------------------------------------------------
|
| 646 |
def screen_welcome(s):
|
| 647 |
+
st.markdown("# 🛒 Product Evaluation Study")
|
| 648 |
st.markdown(
|
| 649 |
f"Welcome! In this study you will evaluate **{PRODUCTS_PER_USER} {CATEGORY_DISPLAY[CATEGORY]}** products.\n\n"
|
| 650 |
"For each product you will:\n"
|
|
|
|
| 741 |
"How familiar are you with this product?",
|
| 742 |
get_familiarity_choices(),
|
| 743 |
index=None,
|
| 744 |
+
key=f"familiarity_{idx}_{product['id']}",
|
| 745 |
)
|
| 746 |
pre_will_val = st.radio(
|
| 747 |
"How willing would you be to buy this product?",
|
| 748 |
WILLINGNESS_CHOICES,
|
| 749 |
index=None,
|
| 750 |
+
key=f"pre_will_{idx}_{product['id']}",
|
| 751 |
)
|
| 752 |
|
| 753 |
if st.button("Start Chat →", type="primary", use_container_width=True):
|
|
|
|
| 758 |
if not pre_will_val:
|
| 759 |
st.error("⚠️ Please rate your willingness to buy.")
|
| 760 |
return
|
| 761 |
+
|
| 762 |
familiarity_val = familiarity_val or get_familiarity_choices()[0]
|
| 763 |
pre_will_val = pre_will_val or WILLINGNESS_CHOICES[3]
|
| 764 |
|
| 765 |
+
# Check if we need to swap this product
|
| 766 |
+
if needs_swap(familiarity_val, pre_will_val) and not DEBUG_MODE:
|
| 767 |
+
current_ids = {p["id"] for p in s["products"]}
|
| 768 |
+
replacement = get_swap_product(exclude_ids=current_ids)
|
| 769 |
+
if replacement:
|
| 770 |
+
# Return the rejected product to the queue so it gets reviewed by someone else
|
| 771 |
+
return_product_to_queue(s["products"][idx])
|
| 772 |
+
s["products"][idx] = make_product_slot(replacement, was_swapped=True)
|
| 773 |
+
st.info("We've swapped this product for a better match. Please review the new product below.")
|
| 774 |
+
st.rerun()
|
| 775 |
+
return
|
| 776 |
+
# If no replacement found, proceed anyway
|
| 777 |
+
|
| 778 |
pre_val = parse_willingness(pre_will_val)
|
| 779 |
s["products"][idx]["familiarity"] = familiarity_val
|
| 780 |
s["products"][idx]["pre_willingness"] = pre_val
|
| 781 |
s["products"][idx]["pre_willingness_label"] = WILLINGNESS_LABELS[pre_val]
|
| 782 |
|
|
|
|
| 783 |
system_prompt = build_sales_system_prompt(product)
|
| 784 |
opening_user_msg = build_opening_user_message(product)
|
| 785 |
messages = [
|
|
|
|
| 808 |
render_progress(idx + 1)
|
| 809 |
st.markdown("## Chat with the AI")
|
| 810 |
|
|
|
|
| 811 |
title = product.get("title", "Product")
|
| 812 |
price = product.get("price", "N/A")
|
| 813 |
price_str = f"${price}" if price and price != "N/A" and not str(price).startswith("$") else price
|
|
|
|
| 816 |
|
| 817 |
num_turns = conv["num_turns"]
|
| 818 |
st.markdown(
|
| 819 |
+
f"Chat with the AI about whether you'd like to purchase the product. "
|
| 820 |
f"Ask questions, push back, or explore your interest. "
|
| 821 |
f"You need at least **{MIN_TURNS} exchanges** before you can move on."
|
| 822 |
)
|
| 823 |
|
|
|
|
| 824 |
display_turns = [t for t in conv["turns"] if t["role"] in ("user", "assistant")]
|
| 825 |
render_chat_history(display_turns)
|
| 826 |
|
|
|
|
| 827 |
if num_turns >= MAX_TURNS:
|
| 828 |
st.info(f"Maximum turns ({MAX_TURNS}) reached. Please proceed.")
|
| 829 |
else:
|
| 830 |
st.caption(f"Turns: {num_turns} / minimum {MIN_TURNS}")
|
| 831 |
|
|
|
|
| 832 |
if num_turns < MAX_TURNS:
|
| 833 |
+
user_msg = st.text_area(
|
| 834 |
+
"Your response:",
|
| 835 |
+
placeholder="Type your response here…",
|
| 836 |
+
height=100,
|
| 837 |
+
key=f"chat_input_{idx}_{num_turns}",
|
| 838 |
+
)
|
| 839 |
col1, col2 = st.columns([3, 1])
|
| 840 |
with col2:
|
| 841 |
send_clicked = st.button("Send", type="primary", use_container_width=True)
|
|
|
|
| 847 |
st.error(f"⚠️ Please write at least 5 words ({len(user_msg.strip().split())} so far).")
|
| 848 |
return
|
| 849 |
user_msg = user_msg.strip()
|
| 850 |
+
messages = [
|
| 851 |
+
{"role": "system", "content": conv["system_prompt"]},
|
| 852 |
+
{"role": "user", "content": conv["opening_user_message"]},
|
| 853 |
+
]
|
| 854 |
for turn in conv["turns"]:
|
| 855 |
messages.append({"role": turn["role"], "content": turn["content"]})
|
| 856 |
messages.append({"role": "user", "content": user_msg})
|
|
|
|
| 864 |
s["products"][idx]["conversation"] = conv
|
| 865 |
st.rerun()
|
| 866 |
|
|
|
|
| 867 |
can_finish = num_turns >= MIN_TURNS or num_turns >= MAX_TURNS or DEBUG_MODE
|
| 868 |
if can_finish:
|
| 869 |
if st.button("I'm done chatting →", use_container_width=True):
|
|
|
|
| 886 |
"How willing would you be to buy this product now?",
|
| 887 |
WILLINGNESS_CHOICES,
|
| 888 |
index=None,
|
| 889 |
+
key=f"post_will_{idx}_{product['id']}",
|
| 890 |
)
|
| 891 |
|
| 892 |
if st.button("Next →", type="primary", use_container_width=True):
|
|
|
|
| 985 |
import pandas as pd
|
| 986 |
st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)
|
| 987 |
|
|
|
|
| 988 |
assignment_id = s.get("assignment_id", "")
|
| 989 |
turk_submit_to = s.get("turk_submit_to", "")
|
| 990 |
if assignment_id and turk_submit_to:
|