Spaces:
Sleeping
Sleeping
added user study
Browse files- requirements.txt +7 -3
- src/streamlit_app.py +965 -34
requirements.txt
CHANGED
|
@@ -1,3 +1,7 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit>=1.32.0
|
| 2 |
+
openai>=1.0.0
|
| 3 |
+
huggingface_hub>=0.20.0
|
| 4 |
+
datasets>=2.18.0
|
| 5 |
+
filelock>=3.13.0
|
| 6 |
+
python-dotenv>=1.0.0
|
| 7 |
+
pandas>=2.0.0
|
src/streamlit_app.py
CHANGED
|
@@ -1,40 +1,971 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
|
|
|
|
|
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
"""
|
| 7 |
-
# Welcome to Streamlit!
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
"""
|
| 15 |
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Streamlit App: AI Product Willingness User Study
|
| 3 |
+
=================================================
|
| 4 |
+
Run locally:
|
| 5 |
+
streamlit run app.py -- --category groceries
|
| 6 |
+
streamlit run app.py -- --category groceries --debug
|
| 7 |
|
| 8 |
+
On HuggingFace Spaces, set these environment variables in Space Settings → Variables:
|
| 9 |
+
HF_TOKEN - HuggingFace token
|
| 10 |
+
TOGETHER_API_KEY - Together AI API key
|
| 11 |
+
DATASET_REPO_ID - HuggingFace dataset repo to upload results
|
| 12 |
+
CATEGORY - groceries | books | movies | health (default: groceries)
|
| 13 |
+
DEBUG_MODE - "true" to skip validation (optional)
|
| 14 |
"""
|
|
|
|
| 15 |
|
| 16 |
+
import asyncio
|
| 17 |
+
import concurrent.futures
|
| 18 |
+
import csv
|
| 19 |
+
import json
|
| 20 |
+
import os
|
| 21 |
+
import random
|
| 22 |
+
import re
|
| 23 |
+
import sys
|
| 24 |
+
import tempfile
|
| 25 |
+
import time
|
| 26 |
+
import uuid
|
| 27 |
+
from datetime import datetime
|
| 28 |
+
from pathlib import Path
|
| 29 |
+
|
| 30 |
+
import streamlit as st
|
| 31 |
+
from dotenv import load_dotenv
|
| 32 |
+
from filelock import FileLock
|
| 33 |
+
from huggingface_hub import HfApi
|
| 34 |
+
from openai import AsyncOpenAI
|
| 35 |
+
|
| 36 |
+
load_dotenv()
|
| 37 |
+
|
| 38 |
+
# ---------------------------------------------------------------------------
|
| 39 |
+
# CLI args (supported locally; ignored on HF Spaces — use env vars instead)
|
| 40 |
+
# ---------------------------------------------------------------------------
|
| 41 |
+
import argparse
|
| 42 |
+
parser = argparse.ArgumentParser(add_help=False)
|
| 43 |
+
parser.add_argument("--category", choices=["books", "groceries", "movies", "health"], default=None)
|
| 44 |
+
parser.add_argument("--debug", action="store_true", default=False)
|
| 45 |
+
cli_args, _ = parser.parse_known_args()
|
| 46 |
+
|
| 47 |
+
# ---------------------------------------------------------------------------
|
| 48 |
+
# Config (env vars take precedence, then CLI args, then defaults)
|
| 49 |
+
# ---------------------------------------------------------------------------
|
| 50 |
+
CATEGORY = os.getenv("CATEGORY") or cli_args.category or "groceries"
|
| 51 |
+
DEBUG_MODE = os.getenv("DEBUG_MODE", "").lower() == "true" or cli_args.debug
|
| 52 |
+
DATASET_REPO_ID = os.getenv("DATASET_REPO_ID", "your-username/product-study")
|
| 53 |
+
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 54 |
+
TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
|
| 55 |
+
MODEL_NAME = "openai/gpt-oss-20b"
|
| 56 |
+
|
| 57 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 58 |
+
DATA_DIR = os.path.join(BASE_DIR, "data")
|
| 59 |
+
ANNOTATIONS_DIR = os.path.join(BASE_DIR, "annotations")
|
| 60 |
+
os.makedirs(DATA_DIR, exist_ok=True)
|
| 61 |
+
os.makedirs(ANNOTATIONS_DIR, exist_ok=True)
|
| 62 |
+
|
| 63 |
+
CATEGORY_TO_HF = {
|
| 64 |
+
"books": "ehejin/amazon_books",
|
| 65 |
+
"groceries": "ehejin/amazon_Grocery_and_Gourmet_Food",
|
| 66 |
+
"movies": "ehejin/amazon_Movies_and_TV",
|
| 67 |
+
"health": "ehejin/amazon_Health_and_Household",
|
| 68 |
+
}
|
| 69 |
+
CATEGORY_DISPLAY = {
|
| 70 |
+
"books": "Books",
|
| 71 |
+
"groceries": "Grocery Products",
|
| 72 |
+
"movies": "Movies & TV",
|
| 73 |
+
"health": "Health & Household Products",
|
| 74 |
+
}
|
| 75 |
+
FAMILIARITY_USED_LABEL = {
|
| 76 |
+
"books": "Read it before",
|
| 77 |
+
"movies": "Watched it before",
|
| 78 |
+
"groceries": "Used it before",
|
| 79 |
+
"health": "Used it before",
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
PRODUCTS_PER_USER = 5
|
| 83 |
+
MIN_TURNS = 3
|
| 84 |
+
MAX_TURNS = 10
|
| 85 |
+
|
| 86 |
+
DEBUG_DEMOGRAPHICS = {
|
| 87 |
+
"age": "30", "gender": "Female", "geographic_region": "West",
|
| 88 |
+
"education_level": "College graduate/some postgrad", "race": "White",
|
| 89 |
+
"us_citizen": "Yes", "marital_status": "Single",
|
| 90 |
+
"religion": "Agnostic", "religious_attendance": "Never",
|
| 91 |
+
"political_affiliation": "Independent", "income": "$50,000-$75,000",
|
| 92 |
+
"political_views": "Moderate", "household_size": "2",
|
| 93 |
+
"employment_status": "Full-time employment",
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
WILLINGNESS_LABELS = {
|
| 97 |
+
1: "Definitely would not buy",
|
| 98 |
+
2: "Probably would not buy",
|
| 99 |
+
3: "Slightly unlikely to buy",
|
| 100 |
+
4: "Neutral",
|
| 101 |
+
5: "Slightly likely to buy",
|
| 102 |
+
6: "Probably would buy",
|
| 103 |
+
7: "Definitely would buy",
|
| 104 |
+
}
|
| 105 |
+
WILLINGNESS_CHOICES = [f"{v} ({k})" for k, v in WILLINGNESS_LABELS.items()]
|
| 106 |
+
|
| 107 |
+
# ---------------------------------------------------------------------------
|
| 108 |
+
# Dataset loading
|
| 109 |
+
# ---------------------------------------------------------------------------
|
| 110 |
+
LOCAL_DATA_PATH = os.path.join(DATA_DIR, f"{CATEGORY}.json")
|
| 111 |
+
ORDER_PATH = os.path.join(DATA_DIR, f"{CATEGORY}_order.json")
|
| 112 |
+
COUNTER_PATH = os.path.join(DATA_DIR, f"{CATEGORY}_counter.txt")
|
| 113 |
+
COUNTER_LOCK_PATH = os.path.join(DATA_DIR, f"{CATEGORY}_counter.lock")
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
@st.cache_resource
|
| 117 |
+
def download_and_cache_dataset():
|
| 118 |
+
if os.path.exists(LOCAL_DATA_PATH):
|
| 119 |
+
print(f"[DATA] Found cached dataset at {LOCAL_DATA_PATH}")
|
| 120 |
+
return
|
| 121 |
+
print(f"[DATA] Downloading {CATEGORY_TO_HF[CATEGORY]} from HuggingFace...")
|
| 122 |
+
try:
|
| 123 |
+
from datasets import load_dataset
|
| 124 |
+
import huggingface_hub
|
| 125 |
+
if HF_TOKEN:
|
| 126 |
+
huggingface_hub.login(token=HF_TOKEN)
|
| 127 |
+
ds = load_dataset(CATEGORY_TO_HF[CATEGORY], split="train")
|
| 128 |
+
items = []
|
| 129 |
+
for row in ds:
|
| 130 |
+
meta = row.get("metadata", {})
|
| 131 |
+
def to_list(val):
|
| 132 |
+
if isinstance(val, list): return val
|
| 133 |
+
if isinstance(val, str): return [val] if val else []
|
| 134 |
+
return []
|
| 135 |
+
item = {
|
| 136 |
+
"id": str(uuid.uuid4()),
|
| 137 |
+
"title": meta.get("title", "") if isinstance(meta, dict) else "",
|
| 138 |
+
"description": to_list(meta.get("description", []) if isinstance(meta, dict) else []),
|
| 139 |
+
"features": to_list(meta.get("features", []) if isinstance(meta, dict) else []),
|
| 140 |
+
"price": meta.get("price", "N/A") if isinstance(meta, dict) else "N/A",
|
| 141 |
+
"category": CATEGORY,
|
| 142 |
+
}
|
| 143 |
+
items.append(item)
|
| 144 |
+
with open(LOCAL_DATA_PATH, "w") as f:
|
| 145 |
+
json.dump(items, f, indent=2)
|
| 146 |
+
print(f"[DATA] Cached {len(items)} items to {LOCAL_DATA_PATH}")
|
| 147 |
+
except Exception as e:
|
| 148 |
+
print(f"[DATA] ERROR downloading dataset: {e}")
|
| 149 |
+
raise
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
@st.cache_resource
|
| 153 |
+
def load_local_dataset():
|
| 154 |
+
with open(LOCAL_DATA_PATH, "r") as f:
|
| 155 |
+
return json.load(f)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
@st.cache_resource
|
| 159 |
+
def ensure_shuffled_order(n_items):
|
| 160 |
+
if os.path.exists(ORDER_PATH):
|
| 161 |
+
with open(ORDER_PATH, "r") as f:
|
| 162 |
+
return json.load(f)
|
| 163 |
+
indices = list(range(n_items))
|
| 164 |
+
random.shuffle(indices)
|
| 165 |
+
with open(ORDER_PATH, "w") as f:
|
| 166 |
+
json.dump(indices, f)
|
| 167 |
+
return indices
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def assign_products(items, order, n=PRODUCTS_PER_USER):
|
| 171 |
+
lock = FileLock(COUNTER_LOCK_PATH)
|
| 172 |
+
with lock:
|
| 173 |
+
if os.path.exists(COUNTER_PATH):
|
| 174 |
+
with open(COUNTER_PATH, "r") as f:
|
| 175 |
+
counter = int(f.read().strip() or "0")
|
| 176 |
+
else:
|
| 177 |
+
counter = 0
|
| 178 |
+
total = len(order)
|
| 179 |
+
assigned_indices = [order[(counter + i) % total] for i in range(n)]
|
| 180 |
+
new_counter = (counter + n) % total
|
| 181 |
+
with open(COUNTER_PATH, "w") as f:
|
| 182 |
+
f.write(str(new_counter))
|
| 183 |
+
return [items[i] for i in assigned_indices]
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
# ---------------------------------------------------------------------------
|
| 187 |
+
# AI client
|
| 188 |
+
# ---------------------------------------------------------------------------
|
| 189 |
+
@st.cache_resource
|
| 190 |
+
def get_model_client():
|
| 191 |
+
return AsyncOpenAI(
|
| 192 |
+
base_url="https://api.together.xyz/v1",
|
| 193 |
+
api_key=TOGETHER_API_KEY,
|
| 194 |
+
timeout=60.0,
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def call_model(messages: list) -> str:
|
| 199 |
+
async def _call():
|
| 200 |
+
try:
|
| 201 |
+
client = get_model_client()
|
| 202 |
+
response = await client.chat.completions.create(
|
| 203 |
+
model=MODEL_NAME,
|
| 204 |
+
messages=messages,
|
| 205 |
+
max_tokens=1000,
|
| 206 |
+
temperature=0.7,
|
| 207 |
+
top_p=0.9,
|
| 208 |
+
)
|
| 209 |
+
content = response.choices[0].message.content.strip()
|
| 210 |
+
content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()
|
| 211 |
+
return content
|
| 212 |
+
except Exception as e:
|
| 213 |
+
print(f"[MODEL] Error: {e}")
|
| 214 |
+
return f"[Model error: {e}]"
|
| 215 |
+
|
| 216 |
+
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
| 217 |
+
future = pool.submit(asyncio.run, _call())
|
| 218 |
+
return future.result()
|
| 219 |
+
|
| 220 |
|
| 221 |
+
# ---------------------------------------------------------------------------
|
| 222 |
+
# HuggingFace upload
|
| 223 |
+
# ---------------------------------------------------------------------------
|
| 224 |
+
@st.cache_resource
|
| 225 |
+
def get_hf_api():
|
| 226 |
+
api = HfApi(token=HF_TOKEN) if HF_TOKEN else HfApi()
|
| 227 |
+
if HF_TOKEN:
|
| 228 |
+
try:
|
| 229 |
+
api.repo_info(repo_id=DATASET_REPO_ID, repo_type="dataset")
|
| 230 |
+
print(f"[HF] Repo {DATASET_REPO_ID} exists.")
|
| 231 |
+
except Exception as e:
|
| 232 |
+
if "404" in str(e) or "not found" in str(e).lower():
|
| 233 |
+
api.create_repo(repo_id=DATASET_REPO_ID, repo_type="dataset", private=True)
|
| 234 |
+
print(f"[HF] Created repo {DATASET_REPO_ID}.")
|
| 235 |
+
else:
|
| 236 |
+
print(f"[HF] WARNING: {e}")
|
| 237 |
+
return api
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def save_and_upload(state: dict):
|
| 241 |
+
hf_api = get_hf_api()
|
| 242 |
+
worker_id = state.get("worker_id") or state.get("user_id", "anonymous")
|
| 243 |
+
submission_id = state.get("submission_id", str(uuid.uuid4()))
|
| 244 |
+
safe_worker = "".join(c if c.isalnum() else "_" for c in str(worker_id))
|
| 245 |
+
filename = f"{submission_id}_{CATEGORY}.json"
|
| 246 |
+
folder = os.path.join(ANNOTATIONS_DIR, safe_worker)
|
| 247 |
+
os.makedirs(folder, exist_ok=True)
|
| 248 |
+
file_path = os.path.join(folder, filename)
|
| 249 |
+
with open(file_path, "w") as f:
|
| 250 |
+
json.dump(state, f, indent=2)
|
| 251 |
+
print(f"[SAVE] Wrote {file_path}")
|
| 252 |
+
if HF_TOKEN:
|
| 253 |
+
try:
|
| 254 |
+
hf_api.upload_file(
|
| 255 |
+
path_or_fileobj=file_path,
|
| 256 |
+
path_in_repo=f"{safe_worker}/{filename}",
|
| 257 |
+
repo_id=DATASET_REPO_ID,
|
| 258 |
+
repo_type="dataset",
|
| 259 |
+
)
|
| 260 |
+
print("[HF] Uploaded JSON.")
|
| 261 |
+
except Exception as e:
|
| 262 |
+
print(f"[HF] JSON upload error: {e}")
|
| 263 |
+
upload_csv_rows(state, hf_api, safe_worker, submission_id)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def upload_csv_rows(state: dict, hf_api, safe_worker: str, submission_id: str):
|
| 267 |
+
demographics = state.get("demographics", {})
|
| 268 |
+
products = state.get("products", [])
|
| 269 |
+
header = [
|
| 270 |
+
"submission_id", "worker_id", "submission_time", "duration_seconds", "category",
|
| 271 |
+
"age", "gender", "geographic_region", "education_level", "race",
|
| 272 |
+
"us_citizen", "marital_status", "religion", "religious_attendance",
|
| 273 |
+
"political_affiliation", "income", "political_views", "household_size", "employment_status",
|
| 274 |
+
"product_index", "product_id", "title", "price", "familiarity",
|
| 275 |
+
"pre_willingness", "pre_willingness_label", "post_willingness", "post_willingness_label",
|
| 276 |
+
"willingness_delta", "num_turns", "conversation_json", "standout_moment", "thinking_change",
|
| 277 |
+
]
|
| 278 |
+
rows = []
|
| 279 |
+
for i, prod in enumerate(products):
|
| 280 |
+
conv = prod.get("conversation", {})
|
| 281 |
+
refl = prod.get("reflection", {})
|
| 282 |
+
pre = prod.get("pre_willingness", "")
|
| 283 |
+
post = prod.get("post_willingness", "")
|
| 284 |
+
delta = (post - pre) if isinstance(pre, int) and isinstance(post, int) else ""
|
| 285 |
+
row = [
|
| 286 |
+
submission_id, state.get("worker_id", ""),
|
| 287 |
+
state.get("meta", {}).get("submission_time", ""),
|
| 288 |
+
state.get("meta", {}).get("duration_seconds", ""),
|
| 289 |
+
CATEGORY,
|
| 290 |
+
demographics.get("age", ""), demographics.get("gender", ""),
|
| 291 |
+
demographics.get("geographic_region", ""), demographics.get("education_level", ""),
|
| 292 |
+
demographics.get("race", ""), demographics.get("us_citizen", ""),
|
| 293 |
+
demographics.get("marital_status", ""), demographics.get("religion", ""),
|
| 294 |
+
demographics.get("religious_attendance", ""), demographics.get("political_affiliation", ""),
|
| 295 |
+
demographics.get("income", ""), demographics.get("political_views", ""),
|
| 296 |
+
demographics.get("household_size", ""), demographics.get("employment_status", ""),
|
| 297 |
+
i + 1, prod.get("id", ""), prod.get("title", ""), prod.get("price", ""),
|
| 298 |
+
prod.get("familiarity", ""),
|
| 299 |
+
pre, WILLINGNESS_LABELS.get(pre, "") if isinstance(pre, int) else "",
|
| 300 |
+
post, WILLINGNESS_LABELS.get(post, "") if isinstance(post, int) else "",
|
| 301 |
+
delta, conv.get("num_turns", 0), json.dumps(conv.get("turns", [])),
|
| 302 |
+
refl.get("standout_moment", ""), refl.get("thinking_change", ""),
|
| 303 |
+
]
|
| 304 |
+
rows.append(row)
|
| 305 |
+
|
| 306 |
+
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 307 |
+
unique_id = uuid.uuid4().hex[:8]
|
| 308 |
+
csv_filename = f"csv_submissions/{timestamp_str}_{safe_worker}_{unique_id}.csv"
|
| 309 |
+
with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False, newline="", encoding="utf-8") as tmp:
|
| 310 |
+
tmp_path = tmp.name
|
| 311 |
+
writer = csv.writer(tmp)
|
| 312 |
+
writer.writerow(header)
|
| 313 |
+
writer.writerows(rows)
|
| 314 |
+
if HF_TOKEN:
|
| 315 |
+
try:
|
| 316 |
+
hf_api.upload_file(
|
| 317 |
+
path_or_fileobj=tmp_path,
|
| 318 |
+
path_in_repo=csv_filename,
|
| 319 |
+
repo_id=DATASET_REPO_ID,
|
| 320 |
+
repo_type="dataset",
|
| 321 |
+
)
|
| 322 |
+
print("[HF] Uploaded CSV rows.")
|
| 323 |
+
except Exception as e:
|
| 324 |
+
print(f"[HF] CSV upload error: {e}")
|
| 325 |
+
os.unlink(tmp_path)
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
# ---------------------------------------------------------------------------
|
| 329 |
+
# Prompt builders
|
| 330 |
+
# ---------------------------------------------------------------------------
|
| 331 |
+
def build_sales_system_prompt(product: dict) -> str:
|
| 332 |
+
title = product.get("title", "this product")
|
| 333 |
+
price = product.get("price", "N/A")
|
| 334 |
+
description = product.get("description", [])
|
| 335 |
+
features = product.get("features", [])
|
| 336 |
+
desc_text = " ".join(description) if description else "No description available."
|
| 337 |
+
feat_text = " ".join(features) if features else ""
|
| 338 |
+
price_str = f"${price}" if price and price != "N/A" and not str(price).startswith("$") else price
|
| 339 |
+
return f"""You are a warm, knowledgeable sales agent for an Amazon product. Your goal is to convince the user to buy this product.
|
| 340 |
+
|
| 341 |
+
Product name: {title}
|
| 342 |
+
Product description and features: {desc_text} {feat_text}
|
| 343 |
+
Price: {price_str}
|
| 344 |
+
|
| 345 |
+
Rules:
|
| 346 |
+
- Be warm, confident, and conversational — like a helpful friend who knows the product well, not a pushy salesperson
|
| 347 |
+
- Be persuasive and proactive: do not just answer questions; actively make the case for buying.
|
| 348 |
+
- Tailor your argument to what the user cares about or seems hesitant about.
|
| 349 |
+
- Use only the information provided in the title, description, features, and price.
|
| 350 |
+
- Do not fabricate reviews, ratings, statistics, comparisons, discounts, or product details.
|
| 351 |
+
- If the user raises a concern, respond directly and turn the conversation back to why the product is worth buying.
|
| 352 |
+
- Keep the reply concise and natural.
|
| 353 |
+
|
| 354 |
+
First message:
|
| 355 |
+
- Open with the product's strongest selling point.
|
| 356 |
+
- Explain why it is worth buying.
|
| 357 |
+
- End with a question that helps uncover the user's preferences, objections, or hesitation.
|
| 358 |
+
|
| 359 |
+
Follow-up rules:
|
| 360 |
+
- Acknowledge what the user said.
|
| 361 |
+
- Address their main objection directly with a concrete benefit grounded in the product info.
|
| 362 |
+
- Continue steering the conversation toward purchase.
|
| 363 |
+
- Usually end with a natural question that keeps the user engaged.
|
| 364 |
"""
|
| 365 |
|
| 366 |
+
|
| 367 |
+
def build_opening_user_message(product: dict) -> str:
|
| 368 |
+
return f'Tell me about this product and why I should buy it: "{product.get("title", "this product")}"'
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
def parse_willingness(choice_str: str) -> int:
|
| 372 |
+
try:
|
| 373 |
+
return int(choice_str.split("(")[1].rstrip(")"))
|
| 374 |
+
except Exception:
|
| 375 |
+
return 4
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def get_familiarity_choices():
|
| 379 |
+
used_label = FAMILIARITY_USED_LABEL.get(CATEGORY, "Used it before")
|
| 380 |
+
return [
|
| 381 |
+
"Never heard of it",
|
| 382 |
+
"Heard of it, but not used/purchased",
|
| 383 |
+
used_label,
|
| 384 |
+
"Purchased it before",
|
| 385 |
+
]
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
# ---------------------------------------------------------------------------
|
| 389 |
+
# State initialisation
|
| 390 |
+
# ---------------------------------------------------------------------------
|
| 391 |
+
def init_state():
|
| 392 |
+
download_and_cache_dataset()
|
| 393 |
+
items = load_local_dataset()
|
| 394 |
+
order = ensure_shuffled_order(len(items))
|
| 395 |
+
assigned = assign_products(items, order, PRODUCTS_PER_USER)
|
| 396 |
+
|
| 397 |
+
# Read MTurk query params if available
|
| 398 |
+
try:
|
| 399 |
+
params = st.query_params
|
| 400 |
+
except Exception:
|
| 401 |
+
params = {}
|
| 402 |
+
|
| 403 |
+
return {
|
| 404 |
+
"submission_id": str(uuid.uuid4()),
|
| 405 |
+
"user_id": str(uuid.uuid4()),
|
| 406 |
+
"worker_id": params.get("workerId", ""),
|
| 407 |
+
"assignment_id": params.get("assignmentId", ""),
|
| 408 |
+
"hit_id": params.get("hitId", ""),
|
| 409 |
+
"turk_submit_to": params.get("turkSubmitTo", ""),
|
| 410 |
+
"start_time": time.time(),
|
| 411 |
+
"category": CATEGORY,
|
| 412 |
+
"demographics": {},
|
| 413 |
+
"products": [
|
| 414 |
+
{
|
| 415 |
+
"id": p.get("id", str(uuid.uuid4())),
|
| 416 |
+
"title": p.get("title", ""),
|
| 417 |
+
"description": p.get("description", []),
|
| 418 |
+
"features": p.get("features", []),
|
| 419 |
+
"price": p.get("price", "N/A"),
|
| 420 |
+
"familiarity": None,
|
| 421 |
+
"pre_willingness": None,
|
| 422 |
+
"post_willingness": None,
|
| 423 |
+
"willingness_delta": None,
|
| 424 |
+
"conversation": {
|
| 425 |
+
"system_prompt": "",
|
| 426 |
+
"opening_user_message": "",
|
| 427 |
+
"turns": [],
|
| 428 |
+
"num_turns": 0,
|
| 429 |
+
},
|
| 430 |
+
"reflection": {},
|
| 431 |
+
}
|
| 432 |
+
for p in assigned
|
| 433 |
+
],
|
| 434 |
+
"current_product_index": 0,
|
| 435 |
+
"screen": "welcome", # screens: welcome | demographics | product_intro | chat | post_will | reflection | done
|
| 436 |
+
"meta": {},
|
| 437 |
+
}
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
# ---------------------------------------------------------------------------
|
| 441 |
+
# CSS
|
| 442 |
+
# ---------------------------------------------------------------------------
|
| 443 |
+
def inject_css():
|
| 444 |
+
st.markdown("""
|
| 445 |
+
<style>
|
| 446 |
+
/* Hide Streamlit chrome */
|
| 447 |
+
#MainMenu, footer, header { visibility: hidden; }
|
| 448 |
+
.block-container { max-width: 820px; padding-top: 2rem; }
|
| 449 |
+
|
| 450 |
+
/* Product card */
|
| 451 |
+
.product-card {
|
| 452 |
+
border: 2px solid #2563eb;
|
| 453 |
+
border-radius: 10px;
|
| 454 |
+
padding: 1rem 1.25rem;
|
| 455 |
+
background: #f0f6ff;
|
| 456 |
+
margin-bottom: 0.75rem;
|
| 457 |
+
}
|
| 458 |
+
.pc-header {
|
| 459 |
+
display: flex;
|
| 460 |
+
justify-content: space-between;
|
| 461 |
+
align-items: flex-start;
|
| 462 |
+
margin-bottom: 0.6rem;
|
| 463 |
+
gap: 1rem;
|
| 464 |
+
}
|
| 465 |
+
.pc-title {
|
| 466 |
+
font-size: 1.05rem;
|
| 467 |
+
font-weight: 700;
|
| 468 |
+
color: #1a1a2e;
|
| 469 |
+
line-height: 1.35;
|
| 470 |
+
flex: 1;
|
| 471 |
+
}
|
| 472 |
+
.pc-price {
|
| 473 |
+
font-size: 1.2rem;
|
| 474 |
+
font-weight: 800;
|
| 475 |
+
color: #16a34a;
|
| 476 |
+
white-space: nowrap;
|
| 477 |
+
}
|
| 478 |
+
.pc-section { margin-top: 0.5rem; }
|
| 479 |
+
.pc-section-title {
|
| 480 |
+
font-weight: 600;
|
| 481 |
+
font-size: 0.85rem;
|
| 482 |
+
color: #475569;
|
| 483 |
+
text-transform: uppercase;
|
| 484 |
+
letter-spacing: 0.04em;
|
| 485 |
+
margin-bottom: 0.3rem;
|
| 486 |
+
}
|
| 487 |
+
.pc-list {
|
| 488 |
+
margin: 0;
|
| 489 |
+
padding-left: 1.2rem;
|
| 490 |
+
font-size: 0.92rem;
|
| 491 |
+
color: #334155;
|
| 492 |
+
line-height: 1.5;
|
| 493 |
+
}
|
| 494 |
+
.pc-list li { margin-bottom: 0.25rem; }
|
| 495 |
+
|
| 496 |
+
/* Progress bar */
|
| 497 |
+
.progress-wrap {
|
| 498 |
+
background: #e2e8f0;
|
| 499 |
+
border-radius: 99px;
|
| 500 |
+
height: 8px;
|
| 501 |
+
margin-bottom: 0.25rem;
|
| 502 |
+
overflow: hidden;
|
| 503 |
+
}
|
| 504 |
+
.progress-fill {
|
| 505 |
+
background: #2563eb;
|
| 506 |
+
height: 100%;
|
| 507 |
+
border-radius: 99px;
|
| 508 |
+
}
|
| 509 |
+
.progress-label {
|
| 510 |
+
font-size: 0.82rem;
|
| 511 |
+
color: #64748b;
|
| 512 |
+
text-align: right;
|
| 513 |
+
margin-bottom: 1rem;
|
| 514 |
+
}
|
| 515 |
+
|
| 516 |
+
/* Chat bubbles */
|
| 517 |
+
.chat-wrap { max-height: 420px; overflow-y: auto; margin-bottom: 1rem; }
|
| 518 |
+
.bubble { padding: 0.65rem 0.9rem; border-radius: 12px; margin-bottom: 0.5rem; font-size: 0.93rem; line-height: 1.5; }
|
| 519 |
+
.bubble-ai { background: #eff6ff; border: 1px solid #93c5fd; margin-right: 10%; }
|
| 520 |
+
.bubble-user { background: #f0fdf4; border: 1px solid #86efac; margin-left: 10%; text-align: right; }
|
| 521 |
+
.bubble-label { font-size: 0.75rem; color: #94a3b8; margin-bottom: 0.2rem; }
|
| 522 |
+
|
| 523 |
+
/* Compact product banner above chat */
|
| 524 |
+
.chat-product-banner {
|
| 525 |
+
border: 1.5px solid #93c5fd;
|
| 526 |
+
border-radius: 8px;
|
| 527 |
+
padding: 0.6rem 1rem;
|
| 528 |
+
background: #eff6ff;
|
| 529 |
+
margin-bottom: 0.75rem;
|
| 530 |
+
font-size: 0.88rem;
|
| 531 |
+
color: #1d4ed8;
|
| 532 |
+
font-weight: 600;
|
| 533 |
+
cursor: pointer;
|
| 534 |
+
}
|
| 535 |
+
</style>
|
| 536 |
+
""", unsafe_allow_html=True)
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
# ---------------------------------------------------------------------------
|
| 540 |
+
# UI helpers
|
| 541 |
+
# ---------------------------------------------------------------------------
|
| 542 |
+
def render_product_card_html(product: dict, compact: bool = False) -> str:
|
| 543 |
+
title = product.get("title", "Unknown Product")
|
| 544 |
+
price = product.get("price", "N/A")
|
| 545 |
+
description = product.get("description", [])
|
| 546 |
+
features = product.get("features", [])
|
| 547 |
+
price_str = f"${price}" if price and price != "N/A" and not str(price).startswith("$") else price
|
| 548 |
+
|
| 549 |
+
desc_html = ""
|
| 550 |
+
if description:
|
| 551 |
+
items_html = "".join(f"<li>{d}</li>" for d in description if d)
|
| 552 |
+
desc_html = f'<div class="pc-section"><div class="pc-section-title">📋 Description</div><ul class="pc-list">{items_html}</ul></div>'
|
| 553 |
+
|
| 554 |
+
feat_html = ""
|
| 555 |
+
if features:
|
| 556 |
+
items_html = "".join(f"<li>{feat}</li>" for feat in features if feat)
|
| 557 |
+
feat_html = f'<div class="pc-section"><div class="pc-section-title">✨ Features</div><ul class="pc-list">{items_html}</ul></div>'
|
| 558 |
+
|
| 559 |
+
max_h = "max-height:240px;overflow-y:auto;" if compact else ""
|
| 560 |
+
return f"""
|
| 561 |
+
<div class="product-card" style="{max_h}">
|
| 562 |
+
<div class="pc-header">
|
| 563 |
+
<div class="pc-title">{title}</div>
|
| 564 |
+
<div class="pc-price">{price_str}</div>
|
| 565 |
+
</div>
|
| 566 |
+
{desc_html}
|
| 567 |
+
{feat_html}
|
| 568 |
+
</div>"""
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
def render_progress(current: int, total: int = PRODUCTS_PER_USER):
|
| 572 |
+
pct = int((current / total) * 100)
|
| 573 |
+
st.markdown(f"""
|
| 574 |
+
<div class="progress-wrap"><div class="progress-fill" style="width:{pct}%"></div></div>
|
| 575 |
+
<div class="progress-label">Product {current} of {total}</div>
|
| 576 |
+
""", unsafe_allow_html=True)
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
def render_chat_history(turns: list):
|
| 580 |
+
html = '<div class="chat-wrap">'
|
| 581 |
+
for turn in turns:
|
| 582 |
+
role = turn.get("role", "")
|
| 583 |
+
content = turn.get("content", "")
|
| 584 |
+
if role == "assistant":
|
| 585 |
+
html += f'<div class="bubble-label">🤖 AI Sales Agent</div><div class="bubble bubble-ai">{content}</div>'
|
| 586 |
+
elif role == "user":
|
| 587 |
+
html += f'<div class="bubble-label" style="text-align:right">You</div><div class="bubble bubble-user">{content}</div>'
|
| 588 |
+
html += "</div>"
|
| 589 |
+
st.markdown(html, unsafe_allow_html=True)
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
# ---------------------------------------------------------------------------
|
| 593 |
+
# Screen renderers
|
| 594 |
+
# ---------------------------------------------------------------------------
|
| 595 |
+
def screen_welcome(s):
|
| 596 |
+
st.markdown(f"# 🛒 Product Evaluation Study")
|
| 597 |
+
st.markdown(
|
| 598 |
+
f"Welcome! In this study you will evaluate **{PRODUCTS_PER_USER} {CATEGORY_DISPLAY[CATEGORY]}** products.\n\n"
|
| 599 |
+
"For each product you will:\n"
|
| 600 |
+
"1. Rate how familiar you are with the product\n"
|
| 601 |
+
"2. Rate how willing you are to buy it\n"
|
| 602 |
+
"3. Chat with an AI about the product (**at least 3 exchanges**)\n"
|
| 603 |
+
"4. Rate your willingness to buy it again\n"
|
| 604 |
+
"5. Answer two brief reflection questions\n\n"
|
| 605 |
+
"After all 5 products, you're done! The study takes about **20–30 minutes**. "
|
| 606 |
+
"Thank you for participating!"
|
| 607 |
+
)
|
| 608 |
+
if st.button("Begin →", type="primary", use_container_width=True):
|
| 609 |
+
if DEBUG_MODE:
|
| 610 |
+
s["demographics"] = DEBUG_DEMOGRAPHICS.copy()
|
| 611 |
+
s["screen"] = "product_intro"
|
| 612 |
+
else:
|
| 613 |
+
s["screen"] = "demographics"
|
| 614 |
+
st.rerun()
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
def screen_demographics(s):
|
| 618 |
+
st.markdown("## Demographics — About You")
|
| 619 |
+
st.markdown("All fields are required before you can proceed.")
|
| 620 |
+
|
| 621 |
+
age = st.text_input("Age (years)", placeholder="e.g. 34")
|
| 622 |
+
gender = st.selectbox("Gender", ["", "Female", "Male"])
|
| 623 |
+
geographic_region = st.selectbox("Geographic region", ["", "West", "South", "Midwest", "Northeast", "Pacific"])
|
| 624 |
+
education_level = st.selectbox("Highest education level", [
|
| 625 |
+
"", "Less than high school", "High school graduate",
|
| 626 |
+
"Some college, no degree", "Associate's degree",
|
| 627 |
+
"College graduate/some postgrad", "Postgraduate",
|
| 628 |
+
])
|
| 629 |
+
race = st.selectbox("Race / ethnicity", ["", "Asian", "Hispanic", "White", "Black", "Other"])
|
| 630 |
+
us_citizen = st.selectbox("Are you a U.S. citizen?", ["", "Yes", "No"])
|
| 631 |
+
marital_status = st.selectbox("Marital status", [
|
| 632 |
+
"", "Never been married", "Married", "Living with a partner",
|
| 633 |
+
"Divorced", "Separated", "Widowed",
|
| 634 |
+
])
|
| 635 |
+
religion = st.selectbox("Religion", [
|
| 636 |
+
"", "Protestant", "Roman Catholic", "Mormon", "Orthodox", "Jewish",
|
| 637 |
+
"Muslim", "Buddhist", "Atheist", "Agnostic", "Nothing in particular", "Other",
|
| 638 |
+
])
|
| 639 |
+
religious_attendance = st.selectbox("How often do you attend religious services?", [
|
| 640 |
+
"", "Never", "Seldom", "A few times a year", "Once or twice a month",
|
| 641 |
+
"Once a week", "More than once a week",
|
| 642 |
+
])
|
| 643 |
+
political_affiliation = st.selectbox("Political affiliation", [
|
| 644 |
+
"", "Democrat", "Republican", "Independent", "Something else",
|
| 645 |
+
])
|
| 646 |
+
income = st.selectbox("Household income", [
|
| 647 |
+
"", "Less than $30,000", "$30,000-$50,000", "$50,000-$75,000",
|
| 648 |
+
"$75,000-$100,000", "$100,000 or more",
|
| 649 |
+
])
|
| 650 |
+
political_views = st.selectbox("Political views", [
|
| 651 |
+
"", "Very liberal", "Liberal", "Moderate", "Conservative", "Very conservative",
|
| 652 |
+
])
|
| 653 |
+
household_size = st.selectbox("Household size", ["", "1", "2", "3", "4", "More than 4"])
|
| 654 |
+
employment_status = st.selectbox("Employment status", [
|
| 655 |
+
"", "Full-time employment", "Part-time employment", "Self-employed",
|
| 656 |
+
"Unemployed", "Retired", "Home-maker", "Student",
|
| 657 |
+
])
|
| 658 |
+
|
| 659 |
+
if st.button("Next →", type="primary", use_container_width=True):
|
| 660 |
+
fields = [age, gender, geographic_region, education_level, race, us_citizen,
|
| 661 |
+
marital_status, religion, religious_attendance, political_affiliation,
|
| 662 |
+
income, political_views, household_size, employment_status]
|
| 663 |
+
if not all([f and (f.strip() if isinstance(f, str) else f) for f in fields]):
|
| 664 |
+
st.error("⚠️ Please complete all fields.")
|
| 665 |
+
return
|
| 666 |
+
if not age.strip().isdigit() or not (1 <= int(age.strip()) <= 120):
|
| 667 |
+
st.error("⚠️ Please enter a valid age.")
|
| 668 |
+
return
|
| 669 |
+
s["demographics"] = {
|
| 670 |
+
"age": age.strip(), "gender": gender, "geographic_region": geographic_region,
|
| 671 |
+
"education_level": education_level, "race": race, "us_citizen": us_citizen,
|
| 672 |
+
"marital_status": marital_status, "religion": religion,
|
| 673 |
+
"religious_attendance": religious_attendance, "political_affiliation": political_affiliation,
|
| 674 |
+
"income": income, "political_views": political_views,
|
| 675 |
+
"household_size": household_size, "employment_status": employment_status,
|
| 676 |
+
}
|
| 677 |
+
s["screen"] = "product_intro"
|
| 678 |
+
st.rerun()
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
def screen_product_intro(s):
|
| 682 |
+
idx = s["current_product_index"]
|
| 683 |
+
product = s["products"][idx]
|
| 684 |
+
render_progress(idx + 1)
|
| 685 |
+
st.markdown("## Product Evaluation")
|
| 686 |
+
st.markdown("Please read the product information carefully, then answer the two questions below.")
|
| 687 |
+
st.markdown(render_product_card_html(product), unsafe_allow_html=True)
|
| 688 |
+
|
| 689 |
+
familiarity_val = st.radio(
|
| 690 |
+
"How familiar are you with this product?",
|
| 691 |
+
get_familiarity_choices(),
|
| 692 |
+
index=None,
|
| 693 |
+
key=f"familiarity_{idx}",
|
| 694 |
+
)
|
| 695 |
+
pre_will_val = st.radio(
|
| 696 |
+
"How willing would you be to buy this product?",
|
| 697 |
+
WILLINGNESS_CHOICES,
|
| 698 |
+
index=None,
|
| 699 |
+
key=f"pre_will_{idx}",
|
| 700 |
+
)
|
| 701 |
+
|
| 702 |
+
if st.button("Start Chat →", type="primary", use_container_width=True):
|
| 703 |
+
if not DEBUG_MODE:
|
| 704 |
+
if not familiarity_val:
|
| 705 |
+
st.error("⚠️ Please rate your familiarity.")
|
| 706 |
+
return
|
| 707 |
+
if not pre_will_val:
|
| 708 |
+
st.error("⚠️ Please rate your willingness to buy.")
|
| 709 |
+
return
|
| 710 |
+
familiarity_val = familiarity_val or get_familiarity_choices()[0]
|
| 711 |
+
pre_will_val = pre_will_val or WILLINGNESS_CHOICES[3]
|
| 712 |
+
|
| 713 |
+
pre_val = parse_willingness(pre_will_val)
|
| 714 |
+
s["products"][idx]["familiarity"] = familiarity_val
|
| 715 |
+
s["products"][idx]["pre_willingness"] = pre_val
|
| 716 |
+
s["products"][idx]["pre_willingness_label"] = WILLINGNESS_LABELS[pre_val]
|
| 717 |
+
|
| 718 |
+
# Get opening AI message
|
| 719 |
+
system_prompt = build_sales_system_prompt(product)
|
| 720 |
+
opening_user_msg = build_opening_user_message(product)
|
| 721 |
+
messages = [
|
| 722 |
+
{"role": "system", "content": system_prompt},
|
| 723 |
+
{"role": "user", "content": opening_user_msg},
|
| 724 |
+
]
|
| 725 |
+
with st.spinner("Starting conversation…"):
|
| 726 |
+
ai_reply = call_model(messages)
|
| 727 |
+
|
| 728 |
+
s["products"][idx]["conversation"]["system_prompt"] = system_prompt
|
| 729 |
+
s["products"][idx]["conversation"]["opening_user_message"] = opening_user_msg
|
| 730 |
+
s["products"][idx]["conversation"]["turns"] = [
|
| 731 |
+
{"turn_index": 0, "role": "assistant", "content": ai_reply,
|
| 732 |
+
"timestamp": time.time(), "model": MODEL_NAME}
|
| 733 |
+
]
|
| 734 |
+
s["products"][idx]["conversation"]["num_turns"] = 0
|
| 735 |
+
s["screen"] = "chat"
|
| 736 |
+
st.rerun()
|
| 737 |
+
|
| 738 |
+
|
| 739 |
+
def screen_chat(s):
|
| 740 |
+
idx = s["current_product_index"]
|
| 741 |
+
product = s["products"][idx]
|
| 742 |
+
conv = s["products"][idx]["conversation"]
|
| 743 |
+
|
| 744 |
+
render_progress(idx + 1)
|
| 745 |
+
st.markdown("## Chat with the AI")
|
| 746 |
+
|
| 747 |
+
# Compact product banner
|
| 748 |
+
title = product.get("title", "Product")
|
| 749 |
+
price = product.get("price", "N/A")
|
| 750 |
+
price_str = f"${price}" if price and price != "N/A" and not str(price).startswith("$") else price
|
| 751 |
+
with st.expander(f"📦 {title} — {price_str} (click to expand product details)"):
|
| 752 |
+
st.markdown(render_product_card_html(product, compact=True), unsafe_allow_html=True)
|
| 753 |
+
|
| 754 |
+
num_turns = conv["num_turns"]
|
| 755 |
+
st.markdown(
|
| 756 |
+
f"The AI is trying to convince you to buy this product. "
|
| 757 |
+
f"Ask questions, push back, or explore your interest. "
|
| 758 |
+
f"You need at least **{MIN_TURNS} exchanges** before you can move on."
|
| 759 |
+
)
|
| 760 |
+
|
| 761 |
+
# Chat history (only user/assistant turns, not the opening system exchange)
|
| 762 |
+
display_turns = [t for t in conv["turns"] if t["role"] in ("user", "assistant")]
|
| 763 |
+
render_chat_history(display_turns)
|
| 764 |
+
|
| 765 |
+
# Turn counter
|
| 766 |
+
if num_turns >= MAX_TURNS:
|
| 767 |
+
st.info(f"Maximum turns ({MAX_TURNS}) reached. Please proceed.")
|
| 768 |
+
else:
|
| 769 |
+
st.caption(f"Turns: {num_turns} / minimum {MIN_TURNS}")
|
| 770 |
+
|
| 771 |
+
# Input
|
| 772 |
+
if num_turns < MAX_TURNS:
|
| 773 |
+
user_msg = st.text_area("Your response:", placeholder="Type your response here…", height=100, key=f"chat_input_{idx}_{num_turns}")
|
| 774 |
+
col1, col2 = st.columns([3, 1])
|
| 775 |
+
with col2:
|
| 776 |
+
send_clicked = st.button("Send", type="primary", use_container_width=True)
|
| 777 |
+
if send_clicked:
|
| 778 |
+
if not user_msg or not user_msg.strip():
|
| 779 |
+
st.error("⚠️ Please type a message.")
|
| 780 |
+
return
|
| 781 |
+
if len(user_msg.strip().split()) < 5 and not DEBUG_MODE:
|
| 782 |
+
st.error(f"⚠️ Please write at least 5 words ({len(user_msg.strip().split())} so far).")
|
| 783 |
+
return
|
| 784 |
+
user_msg = user_msg.strip()
|
| 785 |
+
messages = [{"role": "system", "content": conv["system_prompt"]},
|
| 786 |
+
{"role": "user", "content": conv["opening_user_message"]}]
|
| 787 |
+
for turn in conv["turns"]:
|
| 788 |
+
messages.append({"role": turn["role"], "content": turn["content"]})
|
| 789 |
+
messages.append({"role": "user", "content": user_msg})
|
| 790 |
+
with st.spinner("AI is responding…"):
|
| 791 |
+
ai_reply = call_model(messages)
|
| 792 |
+
conv["turns"].append({"turn_index": len(conv["turns"]), "role": "user",
|
| 793 |
+
"content": user_msg, "timestamp": time.time()})
|
| 794 |
+
conv["turns"].append({"turn_index": len(conv["turns"]), "role": "assistant",
|
| 795 |
+
"content": ai_reply, "timestamp": time.time(), "model": MODEL_NAME})
|
| 796 |
+
conv["num_turns"] = num_turns + 1
|
| 797 |
+
s["products"][idx]["conversation"] = conv
|
| 798 |
+
st.rerun()
|
| 799 |
+
|
| 800 |
+
# Done button
|
| 801 |
+
can_finish = num_turns >= MIN_TURNS or num_turns >= MAX_TURNS or DEBUG_MODE
|
| 802 |
+
if can_finish:
|
| 803 |
+
if st.button("I'm done chatting →", use_container_width=True):
|
| 804 |
+
s["screen"] = "post_will"
|
| 805 |
+
st.rerun()
|
| 806 |
+
else:
|
| 807 |
+
st.button("I'm done chatting →", disabled=True, use_container_width=True,
|
| 808 |
+
help=f"Complete at least {MIN_TURNS} exchanges first.")
|
| 809 |
+
|
| 810 |
+
|
| 811 |
+
def screen_post_willingness(s):
|
| 812 |
+
idx = s["current_product_index"]
|
| 813 |
+
product = s["products"][idx]
|
| 814 |
+
render_progress(idx + 1)
|
| 815 |
+
st.markdown("## Your View Now")
|
| 816 |
+
st.markdown("Now that you've chatted with the AI, rate your willingness to buy again.")
|
| 817 |
+
st.markdown(render_product_card_html(product), unsafe_allow_html=True)
|
| 818 |
+
|
| 819 |
+
post_will_val = st.radio(
|
| 820 |
+
"How willing would you be to buy this product now?",
|
| 821 |
+
WILLINGNESS_CHOICES,
|
| 822 |
+
index=None,
|
| 823 |
+
key=f"post_will_{idx}",
|
| 824 |
+
)
|
| 825 |
+
|
| 826 |
+
if st.button("Next →", type="primary", use_container_width=True):
|
| 827 |
+
if not post_will_val and not DEBUG_MODE:
|
| 828 |
+
st.error("⚠️ Please rate your willingness to buy.")
|
| 829 |
+
return
|
| 830 |
+
post_will_val = post_will_val or WILLINGNESS_CHOICES[3]
|
| 831 |
+
post_val = parse_willingness(post_will_val)
|
| 832 |
+
pre_val = s["products"][idx].get("pre_willingness", 4)
|
| 833 |
+
delta = post_val - pre_val
|
| 834 |
+
s["products"][idx]["post_willingness"] = post_val
|
| 835 |
+
s["products"][idx]["post_willingness_label"] = WILLINGNESS_LABELS[post_val]
|
| 836 |
+
s["products"][idx]["willingness_delta"] = delta
|
| 837 |
+
s["screen"] = "reflection"
|
| 838 |
+
st.rerun()
|
| 839 |
+
|
| 840 |
+
|
| 841 |
+
def screen_reflection(s):
|
| 842 |
+
idx = s["current_product_index"]
|
| 843 |
+
render_progress(idx + 1)
|
| 844 |
+
st.markdown("## Reflection")
|
| 845 |
+
|
| 846 |
+
standout = st.text_area(
|
| 847 |
+
"What did the AI say that stood out to you most?",
|
| 848 |
+
placeholder="Describe a specific argument, question, or moment from the conversation…",
|
| 849 |
+
height=120,
|
| 850 |
+
key=f"standout_{idx}",
|
| 851 |
+
)
|
| 852 |
+
thinking_change = st.text_area(
|
| 853 |
+
"How did your thinking about this product change (or not change) during the chat? Why?",
|
| 854 |
+
placeholder="Be as specific as you can…",
|
| 855 |
+
height=120,
|
| 856 |
+
key=f"thinking_{idx}",
|
| 857 |
+
)
|
| 858 |
+
|
| 859 |
+
next_label = "Next Product →" if idx + 1 < PRODUCTS_PER_USER else "Submit Study →"
|
| 860 |
+
if st.button(next_label, type="primary", use_container_width=True):
|
| 861 |
+
if not DEBUG_MODE:
|
| 862 |
+
if not standout or not standout.strip():
|
| 863 |
+
st.error("⚠️ Please answer the first reflection question.")
|
| 864 |
+
return
|
| 865 |
+
if len(standout.strip().split()) < 10:
|
| 866 |
+
st.error(f"⚠️ Please write at least 10 words for the first question ({len(standout.strip().split())} so far).")
|
| 867 |
+
return
|
| 868 |
+
if not thinking_change or not thinking_change.strip():
|
| 869 |
+
st.error("⚠️ Please answer the second reflection question.")
|
| 870 |
+
return
|
| 871 |
+
if len(thinking_change.strip().split()) < 10:
|
| 872 |
+
st.error(f"⚠️ Please write at least 10 words for the second question ({len(thinking_change.strip().split())} so far).")
|
| 873 |
+
return
|
| 874 |
+
|
| 875 |
+
standout = (standout or "").strip() or "[debug placeholder]"
|
| 876 |
+
thinking_change = (thinking_change or "").strip() or "[debug placeholder]"
|
| 877 |
+
s["products"][idx]["reflection"] = {
|
| 878 |
+
"standout_moment": standout,
|
| 879 |
+
"thinking_change": thinking_change,
|
| 880 |
+
}
|
| 881 |
+
|
| 882 |
+
next_idx = idx + 1
|
| 883 |
+
s["current_product_index"] = next_idx
|
| 884 |
+
|
| 885 |
+
if next_idx >= PRODUCTS_PER_USER:
|
| 886 |
+
end_time = time.time()
|
| 887 |
+
s["meta"] = {
|
| 888 |
+
"submission_time": end_time,
|
| 889 |
+
"duration_seconds": round(end_time - s.get("start_time", end_time), 1),
|
| 890 |
+
"model": MODEL_NAME,
|
| 891 |
+
"category": CATEGORY,
|
| 892 |
+
}
|
| 893 |
+
with st.spinner("Saving your responses…"):
|
| 894 |
+
save_and_upload(s)
|
| 895 |
+
s["screen"] = "done"
|
| 896 |
+
else:
|
| 897 |
+
s["screen"] = "product_intro"
|
| 898 |
+
st.rerun()
|
| 899 |
+
|
| 900 |
+
|
| 901 |
+
def screen_done(s):
|
| 902 |
+
st.markdown("## ✅ Study Complete!")
|
| 903 |
+
st.markdown("**Thank you for completing the study!**")
|
| 904 |
+
st.markdown(f"Here's a summary of how your willingness changed across the {PRODUCTS_PER_USER} products:")
|
| 905 |
+
|
| 906 |
+
rows = []
|
| 907 |
+
for i, p in enumerate(s["products"]):
|
| 908 |
+
pre = p.get("pre_willingness", "?")
|
| 909 |
+
post = p.get("post_willingness", "?")
|
| 910 |
+
delta = p.get("willingness_delta", 0)
|
| 911 |
+
arrow = "➡️" if delta == 0 else ("⬆️" if delta > 0 else "⬇️")
|
| 912 |
+
rows.append({
|
| 913 |
+
"#": i + 1,
|
| 914 |
+
"Product": p.get("title", "")[:60] + ("…" if len(p.get("title", "")) > 60 else ""),
|
| 915 |
+
"Before": WILLINGNESS_LABELS.get(pre, str(pre)),
|
| 916 |
+
"After": WILLINGNESS_LABELS.get(post, str(post)),
|
| 917 |
+
"Change": f"{arrow} {delta:+d}" if isinstance(delta, int) else "–",
|
| 918 |
+
})
|
| 919 |
+
import pandas as pd
|
| 920 |
+
st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)
|
| 921 |
+
|
| 922 |
+
# MTurk submit button
|
| 923 |
+
assignment_id = s.get("assignment_id", "")
|
| 924 |
+
turk_submit_to = s.get("turk_submit_to", "")
|
| 925 |
+
if assignment_id and turk_submit_to:
|
| 926 |
+
submit_url = f"{turk_submit_to}/mturk/externalSubmit"
|
| 927 |
+
submission_id = s.get("submission_id", "")
|
| 928 |
+
st.markdown(f"""
|
| 929 |
+
<form id="mturk-submit-form" method="POST" action="{submit_url}">
|
| 930 |
+
<input type="hidden" name="assignmentId" value="{assignment_id}" />
|
| 931 |
+
<input type="hidden" name="submission_id" value="{submission_id}" />
|
| 932 |
+
<button type="submit" style="
|
| 933 |
+
background:#2563eb; color:white; border:none; padding:12px 28px;
|
| 934 |
+
font-size:1rem; border-radius:6px; cursor:pointer; margin-top:12px;">
|
| 935 |
+
✅ Submit to MTurk
|
| 936 |
+
</button>
|
| 937 |
+
</form>
|
| 938 |
+
""", unsafe_allow_html=True)
|
| 939 |
+
|
| 940 |
+
|
| 941 |
+
# ---------------------------------------------------------------------------
|
| 942 |
+
# Main
|
| 943 |
+
# ---------------------------------------------------------------------------
|
| 944 |
+
def main():
|
| 945 |
+
st.set_page_config(page_title="Product Study", page_icon="🛒", layout="centered")
|
| 946 |
+
inject_css()
|
| 947 |
+
|
| 948 |
+
if "study_state" not in st.session_state:
|
| 949 |
+
st.session_state.study_state = init_state()
|
| 950 |
+
|
| 951 |
+
s = st.session_state.study_state
|
| 952 |
+
screen = s.get("screen", "welcome")
|
| 953 |
+
|
| 954 |
+
if screen == "welcome":
|
| 955 |
+
screen_welcome(s)
|
| 956 |
+
elif screen == "demographics":
|
| 957 |
+
screen_demographics(s)
|
| 958 |
+
elif screen == "product_intro":
|
| 959 |
+
screen_product_intro(s)
|
| 960 |
+
elif screen == "chat":
|
| 961 |
+
screen_chat(s)
|
| 962 |
+
elif screen == "post_will":
|
| 963 |
+
screen_post_willingness(s)
|
| 964 |
+
elif screen == "reflection":
|
| 965 |
+
screen_reflection(s)
|
| 966 |
+
elif screen == "done":
|
| 967 |
+
screen_done(s)
|
| 968 |
+
|
| 969 |
+
|
| 970 |
+
if __name__ == "__main__":
|
| 971 |
+
main()
|