Spaces:
Sleeping
Sleeping
| """ | |
Streamlit App: AI Product Willingness User Study
=================================================
Run locally (single category):
    streamlit run src/streamlit_app.py -- --category groceries
    streamlit run src/streamlit_app.py -- --category groceries --debug
Run locally (mixed mode — movies + groceries):
    streamlit run src/streamlit_app.py -- --mode mixed
    streamlit run src/streamlit_app.py -- --mode mixed --debug
On HuggingFace Spaces, set these environment variables in Space Settings → Variables:
    HF_TOKEN        - HuggingFace token
    TINKER_API_KEY  - Tinker AI API key
    DATASET_REPO_ID - HuggingFace dataset repo to upload results
    CATEGORY        - groceries | books | movies | health (single-category mode)
    MODE            - mixed (overrides CATEGORY; runs movies + groceries together)
    DEBUG_MODE      - "true" to skip validation (optional)
| """ | |
| import csv | |
| import json | |
| import os | |
| import random | |
| import re | |
| import sys | |
| import tempfile | |
| import time | |
| import uuid | |
| from datetime import datetime | |
| from pathlib import Path | |
| import streamlit as st | |
| from dotenv import load_dotenv | |
| from filelock import FileLock | |
| from huggingface_hub import HfApi | |
| load_dotenv() | |
| # --------------------------------------------------------------------------- | |
| # CLI args | |
| # --------------------------------------------------------------------------- | |
| import argparse | |
# Command-line flags for local runs.  parse_known_args() is used so that any
# argument this script does not declare is ignored instead of aborting.
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument(
    "--category",
    default=None,
    choices=["books", "groceries", "movies", "health"],
)
parser.add_argument("--mode", default=None, choices=["mixed"])
parser.add_argument("--debug", default=False, action="store_true")
cli_args, _ = parser.parse_known_args()
| # --------------------------------------------------------------------------- | |
| # Config | |
| # --------------------------------------------------------------------------- | |
# Runtime configuration.  For MODE and CATEGORY a non-empty environment
# variable wins over the CLI flag; DEBUG_MODE is on if either source enables it.
MODE = os.environ.get("MODE") or cli_args.mode  # "mixed" or None
# CATEGORY is only consulted in single-category mode.
CATEGORY = os.environ.get("CATEGORY") or cli_args.category or "groceries"
DEBUG_MODE = cli_args.debug or os.environ.get("DEBUG_MODE", "").lower() == "true"
DATASET_REPO_ID = os.environ.get("DATASET_REPO_ID", "your-username/product-study")
HF_TOKEN = os.environ.get("HF_TOKEN")
TINKER_API_KEY = os.environ.get("TINKER_API_KEY")
MODEL_NAME = "openai/gpt-oss-20b"
| # --------------------------------------------------------------------------- | |
| # Mixed-mode constants | |
| # --------------------------------------------------------------------------- | |
# Mixed mode always draws from exactly these two categories.
MIXED_CATEGORIES = ["movies", "groceries"]

# Primary-pool size per category.
#   mixed mode:  50 movies + 50 groceries = 100 items in total
#   single mode: 100 items from the one category (legacy behaviour)
MIXED_SUBSET_SIZE = 50
SINGLE_SUBSET_SIZE = 100
| # --------------------------------------------------------------------------- | |
| # Prolific config | |
| # --------------------------------------------------------------------------- | |
# Prolific completion code and the redirect URL that carries it.
PROLIFIC_COMPLETION_CODE = "CYC7ALM1"
PROLIFIC_COMPLETION_URL = (
    "https://app.prolific.com/submissions/complete?cc=" + PROLIFIC_COMPLETION_CODE
)

# Working directories live next to this file and are created on import.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(BASE_DIR, "data")
ANNOTATIONS_DIR = os.path.join(BASE_DIR, "annotations")
for _required_dir in (DATA_DIR, ANNOTATIONS_DIR):
    os.makedirs(_required_dir, exist_ok=True)
# Category key -> HuggingFace dataset repo holding that category's products.
CATEGORY_TO_HF = dict(
    books="ehejin/amazon_books",
    groceries="ehejin/amazon_Grocery_and_Gourmet_Food",
    movies="ehejin/amazon_Movies_and_TV",
    health="ehejin/amazon_Health_and_Household",
)

# Category key -> name shown to participants.
CATEGORY_DISPLAY = dict(
    books="Books",
    groceries="Grocery Products",
    movies="Movies & TV",
    health="Health & Household Products",
)

# Category key -> wording of the "I have used this" familiarity option,
# chosen per product because it depends on that product's category.
FAMILIARITY_USED_LABEL = dict(
    books="Read it before",
    movies="Watched it before",
    groceries="Used it before",
    health="Used it before",
)
PRODUCTS_PER_USER = 5
MIN_TURNS = 3
MAX_TURNS = 10

# Familiarity answers that cause the product to be swapped for another one.
SWAP_FAMILIARITY = {"Purchased it before"}

# Canned demographics used when DEBUG_MODE skips the demographics screen.
DEBUG_DEMOGRAPHICS = {
    "age": "30",
    "gender": "Female",
    "geographic_region": "West",
    "education_level": "College graduate/some postgrad",
    "race": "White",
    "us_citizen": "Yes",
    "marital_status": "Single",
    "religion": "Agnostic",
    "religious_attendance": "Never",
    "political_affiliation": "Independent",
    "income": "$50,000-$75,000",
    "political_views": "Moderate",
    "household_size": "2",
    "employment_status": "Full-time employment",
}

# 7-point purchase-willingness scale: score -> label.
WILLINGNESS_LABELS = dict(enumerate(
    [
        "Definitely would not buy",
        "Probably would not buy",
        "Slightly unlikely to buy",
        "Neutral",
        "Slightly likely to buy",
        "Probably would buy",
        "Definitely would buy",
    ],
    start=1,
))

# Option strings shown in the UI, e.g. "Neutral (4)".
WILLINGNESS_CHOICES = [
    f"{label} ({score})" for score, label in WILLINGNESS_LABELS.items()
]
| # --------------------------------------------------------------------------- | |
| # Helpers: per-category file paths | |
| # --------------------------------------------------------------------------- | |
def _data_path(category: str, suffix: str) -> str:
    """Build the path of a per-category data file inside DATA_DIR.

    The file name embeds the pool size of the active mode
    (MIXED_SUBSET_SIZE in mixed mode, SINGLE_SUBSET_SIZE otherwise), so the
    two modes never share files.
    """
    subset_size = SINGLE_SUBSET_SIZE
    if MODE == "mixed":
        subset_size = MIXED_SUBSET_SIZE
    filename = f"{category}_test{subset_size}_{suffix}"
    return os.path.join(DATA_DIR, filename)
def local_data_path(category: str) -> str:
    """Return the path of the cached primary-pool JSON for ``category``."""
    suffix = "primary.json"
    return _data_path(category, suffix)
def overflow_path(category: str) -> str:
    """Return the path of the cached overflow-pool JSON for ``category``."""
    suffix = "overflow.json"
    return _data_path(category, suffix)
def counter_path(category: str) -> str:
    """Return the path of the text file holding ``category``'s assignment counter."""
    suffix = "counter.txt"
    return _data_path(category, suffix)
def counter_lock_path(category: str) -> str:
    """Return the path of the lock file guarding ``category``'s counter and queue."""
    suffix = "counter.lock"
    return _data_path(category, suffix)
def return_queue_path(category: str) -> str:
    """Return the path of the JSON file holding ``category``'s return queue."""
    suffix = "return_queue.json"
    return _data_path(category, suffix)
| # --------------------------------------------------------------------------- | |
| # Dataset loading | |
| # --------------------------------------------------------------------------- | |
def download_and_cache_dataset(category: str, subset_size: int):
    """Download the test split for ``category`` and cache it as local JSON.

    The first ``subset_size`` rows become the primary pool; every remaining
    row goes to the overflow pool.  Nothing is downloaded when the primary
    file already exists.

    The primary file's existence is what marks the cache as valid, so it is
    produced last and atomically: the overflow file is written first, the
    primary pool is written to a temporary name, and only then renamed into
    place.  An interrupted run therefore never leaves a truncated primary
    file that a later run would trust.

    Args:
        category: key into CATEGORY_TO_HF.
        subset_size: number of leading rows kept in the primary pool.

    Raises:
        Exception: any download or write failure is logged and re-raised.
    """
    primary_path = local_data_path(category)
    if os.path.exists(primary_path):
        print(f"[DATA] Found cached dataset for {category} at {primary_path}")
        return
    over_path = overflow_path(category)
    print(f"[DATA] Downloading {CATEGORY_TO_HF[category]} (test split, first {subset_size}) from HuggingFace...")
    try:
        from datasets import load_dataset
        import huggingface_hub

        if HF_TOKEN:
            huggingface_hub.login(token=HF_TOKEN)
        ds = load_dataset(CATEGORY_TO_HF[category], split="test")

        def to_list(val):
            """Coerce a metadata field to a list (non-empty str -> [str])."""
            if isinstance(val, list):
                return val
            if isinstance(val, str) and val:
                return [val]
            return []

        all_items = []
        for row in ds:
            meta = row.get("metadata", {})
            if not isinstance(meta, dict):
                # Malformed metadata: fall back to the same defaults as an
                # empty dict would give.
                meta = {}
            all_items.append({
                # Each cached item gets a fresh random id.
                "id": str(uuid.uuid4()),
                "title": meta.get("title", ""),
                "description": to_list(meta.get("description", [])),
                "features": to_list(meta.get("features", [])),
                "price": meta.get("price", "N/A"),
                "category": category,
            })

        primary = all_items[:subset_size]
        overflow = all_items[subset_size:]

        with open(over_path, "w") as f:
            json.dump(overflow, f, indent=2)
        tmp_primary_path = primary_path + ".tmp"
        with open(tmp_primary_path, "w") as f:
            json.dump(primary, f, indent=2)
        # Atomic publish: readers see either no primary file or a complete one.
        os.replace(tmp_primary_path, primary_path)
        print(f"[DATA] {category}: cached {len(primary)} primary + {len(overflow)} overflow items.")
    except Exception as e:
        print(f"[DATA] ERROR downloading {category}: {e}")
        raise
def load_primary_dataset(category: str):
    """Load and return the cached primary pool for ``category``.

    The cache file must already exist; a missing file raises from ``open``.
    """
    path = local_data_path(category)
    with open(path) as handle:
        items = json.load(handle)
    return items
def load_overflow_dataset(category: str):
    """Load the cached overflow pool for ``category``; ``[]`` if none is cached."""
    path = overflow_path(category)
    if os.path.exists(path):
        with open(path) as handle:
            return json.load(handle)
    return []
def _ensure_datasets():
    """Make sure every category needed by the active mode is cached locally."""
    if MODE != "mixed":
        download_and_cache_dataset(CATEGORY, SINGLE_SUBSET_SIZE)
        return
    for mixed_category in MIXED_CATEGORIES:
        download_and_cache_dataset(mixed_category, MIXED_SUBSET_SIZE)
| # --------------------------------------------------------------------------- | |
| # Per-category counter helpers | |
| # --------------------------------------------------------------------------- | |
def _read_counter(category: str) -> int:
    """Return ``category``'s assignment counter (0 if the file is missing or blank).

    Raises:
        ValueError: if the file holds something that is not an integer.
    """
    path = counter_path(category)
    if os.path.exists(path):
        with open(path, "r") as handle:
            raw = handle.read().strip()
        return int(raw or "0")
    return 0
def _write_counter(category: str, value: int):
    """Overwrite ``category``'s counter file with ``value``."""
    path = counter_path(category)
    text = str(value)
    with open(path, "w") as handle:
        handle.write(text)
def _read_return_queue(category: str) -> list:
    """Return ``category``'s queue of returned products.

    A missing file, or one whose contents cannot be loaded as JSON, yields
    an empty list.
    """
    path = return_queue_path(category)
    if os.path.exists(path):
        with open(path, "r") as handle:
            try:
                queue = json.load(handle)
            except Exception:
                # Best effort: an unreadable queue is treated as empty.
                queue = []
        return queue
    return []
def _write_return_queue(category: str, queue: list):
    """Overwrite ``category``'s return-queue file with ``queue`` as JSON."""
    path = return_queue_path(category)
    with open(path, "w") as handle:
        json.dump(queue, handle)
| # --------------------------------------------------------------------------- | |
| # Product assignment | |
| # --------------------------------------------------------------------------- | |
def _assign_from_category(category: str, n: int) -> list:
    """
    Hand out n products from one category's pool while holding its file lock.

    - Products waiting in the return queue go out first, oldest first.
    - Remaining slots are filled sequentially from the primary pool.
    - The counter is taken modulo the pool size, so assignment wraps around
      to the start once the pool has been exhausted.

    The shortened queue and the advanced counter are written back before the
    lock is released.
    """
    pool = load_primary_dataset(category)
    pool_size = len(pool)
    with FileLock(counter_lock_path(category)):
        pending = _read_return_queue(category)
        position = _read_counter(category)
        # The queue is never refilled during this call, so draining it
        # one-by-one equals taking its first `take` entries up front.
        take = min(max(n, 0), len(pending))
        picked = pending[:take]
        leftover = pending[take:]
        for _ in range(n - take):
            picked.append(pool[position % pool_size])
            position += 1
        _write_return_queue(category, leftover)
        _write_counter(category, position)
    return picked
def assign_mixed_products(n: int = PRODUCTS_PER_USER) -> list:
    """
    Assign n products split across movies and groceries.

    The larger share (ceil(n/2)) goes to whichever category has handed out
    fewer primary items so far, with ties going to movies.  Starting from
    fresh counters and n = 5 this gives:
        User 1: 3 movies + 2 groceries   (counters 0/0 -> 3/2)
        User 2: 2 movies + 3 groceries   (counters 3/2 -> 5/5)
        User 3: 3 movies + 2 groceries ... etc.
    For even n both categories simply get n/2.

    This replaces a parity check on the movies counter alone.  That counter
    advances by 3 and then by 2, so it was odd after the first user and
    stayed odd, and every later user received 2 movies + 3 groceries.

    Returns:
        The assigned products in shuffled order.
    """
    # NOTE(review): the counters are read without the per-category locks, so
    # under concurrent sessions the split is best-effort; the assignment
    # itself is still done under the locks by _assign_from_category.
    movies_counter = _read_counter("movies")
    groceries_counter = _read_counter("groceries")

    smaller_share = n // 2
    larger_share = n - smaller_share
    if movies_counter <= groceries_counter:
        n_movies, n_groceries = larger_share, smaller_share
    else:
        n_movies, n_groceries = smaller_share, larger_share

    movie_items = _assign_from_category("movies", n_movies)
    grocery_items = _assign_from_category("groceries", n_groceries)
    combined = movie_items + grocery_items
    random.shuffle(combined)  # mix so user doesn't see all movies then all groceries
    return combined
def assign_products(n: int = PRODUCTS_PER_USER) -> list:
    """Assign n products using the strategy of the active mode."""
    if MODE != "mixed":
        # Single-category (legacy behaviour)
        return _assign_from_category(CATEGORY, n)
    return assign_mixed_products(n)
def return_product_to_queue(product: dict):
    """Append a rejected/swapped product to its category's return queue.

    The queue is left untouched when it already holds a product with the
    same id.  The category comes from the product itself, falling back to
    the global CATEGORY.
    """
    cat = product.get("category", CATEGORY)
    with FileLock(counter_lock_path(cat)):
        queue = _read_return_queue(cat)
        for queued in queue:
            if queued["id"] == product["id"]:
                break  # already queued; nothing to write
        else:
            queue.append(product)
            _write_return_queue(cat, queue)
def get_swap_product(exclude_ids: set, category: str) -> dict | None:
    """
    Get a replacement product for the given category.

    1. Scan the primary pool once, starting at the shared counter and
       wrapping around, and return the first product whose id is not in
       ``exclude_ids``.  The counter is advanced past that product.
    2. If every primary product is excluded, return the first acceptable
       product from the overflow pool (the counter is left unchanged).

    The former intermediate step ("any primary product not held by this
    user") was removed: step 1 already visits every primary product, so it
    could never find anything.  The overflow pool is now loaded only when
    step 1 fails instead of on every call.

    Returns:
        The replacement product, or None when both pools are exhausted.
    """
    items = load_primary_dataset(category)
    total = len(items)
    with FileLock(counter_lock_path(category)):
        counter = _read_counter(category)
        for offset in range(total):
            candidate = items[(counter + offset) % total]
            if candidate["id"] not in exclude_ids:
                _write_counter(category, counter + offset + 1)
                return candidate

    # Primary pool exhausted for this user: fall back to overflow (read-only).
    return next(
        (p for p in load_overflow_dataset(category) if p["id"] not in exclude_ids),
        None,
    )
| # --------------------------------------------------------------------------- | |
| # AI client (Tinker) | |
| # --------------------------------------------------------------------------- | |
@st.cache_resource
def get_tinker_clients():
    """Initialise and cache the Tinker sampling client and renderer.

    ``st.cache_resource`` makes the first call build the objects and every
    later call return the same ones.  Previously nothing was cached despite
    the docstring, so each model call created a new service client,
    tokenizer and renderer.

    Returns:
        ``(sampling_client, renderer, tinker_types)`` where ``tinker_types``
        is the ``tinker.types`` module.
    """
    # NOTE(review): the cached objects are shared by all sessions; assumes
    # the Tinker sampling client tolerates concurrent use — confirm.
    import tinker
    from tinker import types as tinker_types
    from tinker_cookbook import renderers
    from tinker_cookbook.tokenizer_utils import get_tokenizer
    from tinker_cookbook.model_info import get_recommended_renderer_name

    service_client = tinker.ServiceClient()
    sampling_client = service_client.create_sampling_client(base_model=MODEL_NAME)
    tokenizer = get_tokenizer(MODEL_NAME)
    renderer_name = get_recommended_renderer_name(MODEL_NAME)
    renderer = renderers.get_renderer(renderer_name, tokenizer)
    return sampling_client, renderer, tinker_types
def call_model(messages: list) -> str:
    """Sample one reply for ``messages`` from the Tinker model.

    Any ``<think>...</think>`` span is stripped from the reply.  This never
    raises: on any failure the error is printed and a
    ``"[Model error: ...]"`` string is returned in place of the reply.
    """
    try:
        from tinker_cookbook import renderers as tinker_renderers

        sampling_client, renderer, tinker_types = get_tinker_clients()
        prompt = renderer.build_generation_prompt(messages)
        sampling_params = tinker_types.SamplingParams(
            max_tokens=1000,
            temperature=0.7,
            stop=renderer.get_stop_sequences(),
        )
        future = sampling_client.sample(
            prompt=prompt,
            sampling_params=sampling_params,
            num_samples=1,
        )
        first_sequence = future.result().sequences[0]
        parsed_message, _ = renderer.parse_response(first_sequence.tokens)
        reply = tinker_renderers.format_content_as_string(parsed_message["content"])
        without_thinking = re.sub(r"<think>.*?</think>", "", reply, flags=re.DOTALL)
        return without_thinking.strip()
    except Exception as e:
        print(f"[MODEL] Tinker error: {e}")
        return f"[Model error: {e}]"
| # --------------------------------------------------------------------------- | |
| # HuggingFace upload | |
| # --------------------------------------------------------------------------- | |
def get_hf_api():
    """Return an ``HfApi`` client, creating the results repo if needed.

    Without HF_TOKEN an unauthenticated client is returned and no repo check
    is made.  With a token, the dataset repo DATASET_REPO_ID is looked up
    and, if the lookup error mentions "404" or "not found", created as a
    private dataset repo.  Any other lookup error is only printed.
    """
    if not HF_TOKEN:
        return HfApi()

    api = HfApi(token=HF_TOKEN)
    try:
        api.repo_info(repo_id=DATASET_REPO_ID, repo_type="dataset")
        print(f"[HF] Repo {DATASET_REPO_ID} exists.")
    except Exception as e:
        message = str(e)
        repo_missing = "404" in message or "not found" in message.lower()
        if repo_missing:
            api.create_repo(repo_id=DATASET_REPO_ID, repo_type="dataset", private=True)
            print(f"[HF] Created repo {DATASET_REPO_ID}.")
        else:
            print(f"[HF] WARNING: {e}")
    return api
def save_and_upload(state: dict):
    """Persist a finished session locally and, with HF_TOKEN, on the Hub.

    The full state is written as JSON to
    ``ANNOTATIONS_DIR/<worker>/<submission_id>_<mode>.json`` and, when
    HF_TOKEN is set, uploaded to DATASET_REPO_ID under the same relative
    path.  Upload errors are printed, not raised.  Afterwards the
    per-product CSV rows are handed to ``upload_csv_rows``.
    """
    hf_api = get_hf_api()

    worker_id = state.get("prolific_pid") or state.get("user_id", "anonymous")
    submission_id = state.get("submission_id", str(uuid.uuid4()))
    # Every non-alphanumeric character becomes "_" so the id is path-safe.
    safe_worker = "".join(
        [ch if ch.isalnum() else "_" for ch in str(worker_id)]
    )
    mode_tag = state.get("mode", "single")
    filename = f"{submission_id}_{mode_tag}.json"

    worker_dir = os.path.join(ANNOTATIONS_DIR, safe_worker)
    os.makedirs(worker_dir, exist_ok=True)
    file_path = os.path.join(worker_dir, filename)
    with open(file_path, "w") as out:
        json.dump(state, out, indent=2)
    print(f"[SAVE] Wrote {file_path}")

    if HF_TOKEN:
        try:
            hf_api.upload_file(
                path_or_fileobj=file_path,
                path_in_repo=f"{safe_worker}/{filename}",
                repo_id=DATASET_REPO_ID,
                repo_type="dataset",
            )
            print("[HF] Uploaded JSON.")
        except Exception as e:
            print(f"[HF] JSON upload error: {e}")

    upload_csv_rows(state, hf_api, safe_worker, submission_id)
def upload_csv_rows(state: dict, hf_api, safe_worker: str, submission_id: str):
    """Write one CSV row per product and upload the file to the Hub.

    The CSV is written to a temporary file, uploaded to
    ``csv_submissions/<timestamp>_<worker>_<random>.csv`` in DATASET_REPO_ID
    when HF_TOKEN is set, and the temporary file is always removed
    afterwards.  Removal now happens in a ``finally`` block, so the file is
    no longer leaked when writing fails or the upload is interrupted.
    Upload errors are printed, not raised.

    Args:
        state: session state with "demographics", "products" and "meta".
        hf_api: client whose ``upload_file`` is used for the upload.
        safe_worker: path-safe worker id used in the remote file name.
        submission_id: value of the first CSV column.
    """
    demographics = state.get("demographics", {})
    products = state.get("products", [])
    meta = state.get("meta", {})

    # Demographic columns, in CSV order; the names double as lookup keys.
    demographic_fields = [
        "age", "gender", "geographic_region", "education_level", "race",
        "us_citizen", "marital_status", "religion", "religious_attendance",
        "political_affiliation", "income", "political_views", "household_size", "employment_status",
    ]
    header = [
        "submission_id", "prolific_pid", "study_id", "session_id",
        "submission_time", "duration_seconds", "mode", "category",
        *demographic_fields,
        "product_index", "product_id", "title", "price", "familiarity",
        "pre_willingness", "pre_willingness_label", "post_willingness", "post_willingness_label",
        "willingness_delta", "num_turns", "conversation_json", "standout_moment", "thinking_change",
        "was_swapped",
    ]

    rows = []
    for i, prod in enumerate(products):
        conv = prod.get("conversation", {})
        refl = prod.get("reflection", {})
        pre = prod.get("pre_willingness", "")
        post = prod.get("post_willingness", "")
        # A delta only makes sense when both ratings were actually recorded.
        delta = (post - pre) if isinstance(pre, int) and isinstance(post, int) else ""
        rows.append([
            submission_id,
            state.get("prolific_pid", ""),
            state.get("study_id", ""),
            state.get("session_id", ""),
            meta.get("submission_time", ""),
            meta.get("duration_seconds", ""),
            state.get("mode", "single"),
            prod.get("category", ""),  # per-product category
            *[demographics.get(field, "") for field in demographic_fields],
            i + 1, prod.get("id", ""), prod.get("title", ""), prod.get("price", ""),
            prod.get("familiarity", ""),
            pre, WILLINGNESS_LABELS.get(pre, "") if isinstance(pre, int) else "",
            post, WILLINGNESS_LABELS.get(post, "") if isinstance(post, int) else "",
            delta, conv.get("num_turns", 0), json.dumps(conv.get("turns", [])),
            refl.get("standout_moment", ""), refl.get("thinking_change", ""),
            prod.get("was_swapped", False),
        ])

    timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    unique_id = uuid.uuid4().hex[:8]
    csv_filename = f"csv_submissions/{timestamp_str}_{safe_worker}_{unique_id}.csv"

    tmp = tempfile.NamedTemporaryFile(
        mode="w", suffix=".csv", delete=False, newline="", encoding="utf-8"
    )
    tmp_path = tmp.name
    try:
        # Close the file before uploading so all rows are flushed to disk.
        with tmp:
            writer = csv.writer(tmp)
            writer.writerow(header)
            writer.writerows(rows)
        if HF_TOKEN:
            try:
                hf_api.upload_file(
                    path_or_fileobj=tmp_path,
                    path_in_repo=csv_filename,
                    repo_id=DATASET_REPO_ID,
                    repo_type="dataset",
                )
                print("[HF] Uploaded CSV rows.")
            except Exception as e:
                print(f"[HF] CSV upload error: {e}")
    finally:
        os.unlink(tmp_path)
| # --------------------------------------------------------------------------- | |
| # Prompt builders | |
| # --------------------------------------------------------------------------- | |
def build_sales_system_prompt(product: dict) -> str:
    """Build the system prompt that casts the model as a sales agent.

    The prompt embeds the product's title, its description and features
    (each list joined with spaces) and its price.  A "$" is prepended to the
    price unless it is empty, "N/A", or already starts with "$".
    """
    title = product.get("title", "this product")
    price = product.get("price", "N/A")
    description = product.get("description", [])
    features = product.get("features", [])

    if description:
        desc_text = " ".join(description)
    else:
        desc_text = "No description available."
    feat_text = " ".join(features) if features else ""

    price_str = price
    if price and price != "N/A" and not str(price).startswith("$"):
        price_str = f"${price}"

    return f"""You are a warm, knowledgeable sales agent for an Amazon product. Your goal is to convince the user to buy this product.
Product name: {title}
Product description and features: {desc_text} {feat_text}
Price: {price_str}
You need to convince the user to buy it.
First message rules:
- In ONE paragraph: briefly highlight the product's best quality, explain why it's worth buying, and hit them with the strongest benefit
- End with an engaging question that draws out their interest or hesitation
Follow-up message rules:
- In ONE paragraph: acknowledge what they said, address any concerns directly with a concrete benefit or reassurance, end with a question
- Use their words against hesitation: if they say it's expensive, talk value; if they doubt quality, cite a feature
- Vary your tactics: sometimes appeal to emotion (convenience, joy), sometimes to reason (value, quality)
- Use "imagine if..." scenarios to make benefits concrete
General style:
- Be warm, confident, and conversational — like a helpful friend who knows the product well, not a pushy salesperson
- End your messages with an engaging question
- Never fabricate statistics, details, or reviews you don't have
- Never make up a price different from the one given
"""
def build_opening_user_message(product: dict) -> str:
    """Return the first user turn, asking the model to pitch the product by title."""
    title = product.get("title", "this product")
    template = 'Tell me about this product and why I should buy it: "{}"'
    return template.format(title)
def parse_willingness(choice_str: str) -> int:
    """Extract the numeric score from a choice such as ``"Neutral (4)"``.

    Anything that cannot be parsed yields 4, the neutral midpoint.
    """
    try:
        pieces = choice_str.split("(")
        digits = pieces[1].rstrip(")")
        return int(digits)
    except Exception:
        return 4
def get_familiarity_choices(category: str) -> list:
    """Return the four familiarity options for a product of ``category``.

    Only the third option varies: it comes from FAMILIARITY_USED_LABEL and
    defaults to "Used it before" for unknown categories.
    """
    choices = ["Never heard of it", "Heard of it, but not used/purchased"]
    choices.append(FAMILIARITY_USED_LABEL.get(category, "Used it before"))
    choices.append("Purchased it before")
    return choices
def needs_swap(familiarity_val: str, pre_will_val: str) -> bool:
    """Tell whether the current product must be swapped for another one.

    True when the familiarity answer is in SWAP_FAMILIARITY or the
    pre-chat willingness is the top choice, "Definitely would buy (7)".
    """
    top_willingness = WILLINGNESS_CHOICES[-1]
    return familiarity_val in SWAP_FAMILIARITY or pre_will_val == top_willingness
| # --------------------------------------------------------------------------- | |
| # Welcome screen helpers | |
| # --------------------------------------------------------------------------- | |
def study_display_name() -> str:
    """Return the participant-facing name of what is being evaluated."""
    if MODE != "mixed":
        # Unknown categories fall back to the raw category key.
        return CATEGORY_DISPLAY.get(CATEGORY, CATEGORY)
    return "Movies & TV and Grocery Products"
def study_category_breakdown() -> str:
    """Return the welcome-screen sentence describing the mix ("" outside mixed mode)."""
    if MODE != "mixed":
        return ""
    sentence = "You will evaluate a mix of **Movies & TV** and **Grocery Products** "
    sentence += "(roughly 2–3 of each)."
    return sentence
| # --------------------------------------------------------------------------- | |
| # State initialisation | |
| # --------------------------------------------------------------------------- | |
def make_product_slot(p: dict, was_swapped: bool = False) -> dict:
    """Create the per-product record stored in the session state.

    Product fields are copied from ``p`` with defaults for missing keys (a
    fresh UUID for "id", the global CATEGORY for "category").  Ratings start
    as None and the conversation and reflection start empty.
    """
    empty_conversation = {
        "system_prompt": "",
        "opening_user_message": "",
        "turns": [],
        "num_turns": 0,
    }
    slot = {
        "id": p.get("id", str(uuid.uuid4())),
        "title": p.get("title", ""),
        "description": p.get("description", []),
        "features": p.get("features", []),
        "price": p.get("price", "N/A"),
        "category": p.get("category", CATEGORY),  # per-product category
    }
    # Answers are filled in as the participant works through the product.
    for answer_key in ("familiarity", "pre_willingness", "post_willingness", "willingness_delta"):
        slot[answer_key] = None
    slot["was_swapped"] = was_swapped
    slot["conversation"] = empty_conversation
    slot["reflection"] = {}
    return slot
def init_state():
    """Create the state dict for a new participant session.

    Ensures the datasets are cached, assigns PRODUCTS_PER_USER products and
    reads the Prolific identifiers from the URL query parameters (empty
    strings when absent).  The session starts on the "welcome" screen.
    """
    _ensure_datasets()
    assigned = assign_products(PRODUCTS_PER_USER)

    try:
        params = st.query_params
    except Exception:
        # No query parameters available: continue without Prolific ids.
        params = {}

    is_mixed = MODE == "mixed"
    state = {
        "submission_id": str(uuid.uuid4()),
        "user_id": str(uuid.uuid4()),
        "prolific_pid": params.get("PROLIFIC_PID", ""),
        "study_id": params.get("STUDY_ID", ""),
        "session_id": params.get("SESSION_ID", ""),
        "start_time": time.time(),
        "mode": MODE or "single",
        "category": "mixed" if is_mixed else CATEGORY,
        "demographics": {},
        "products": [make_product_slot(product) for product in assigned],
        "current_product_index": 0,
        "screen": "welcome",
        "meta": {},
    }
    return state
| # --------------------------------------------------------------------------- | |
| # CSS | |
| # --------------------------------------------------------------------------- | |
def inject_css():
    """Inject the app's stylesheet into the page.

    The stylesheet hides Streamlit's menu, header and footer and styles the
    product card, progress bar and chat bubbles.
    """
    # Kept flush-left: the block is passed through st.markdown as raw HTML.
    stylesheet = """
<style>
#MainMenu, footer, header { visibility: hidden; }
.block-container { max-width: 820px; padding-top: 2rem; }
.product-card {
border: 2px solid #2563eb;
border-radius: 10px;
padding: 1rem 1.25rem;
background: #f0f6ff;
margin-bottom: 0.75rem;
}
.pc-header {
display: flex;
justify-content: space-between;
align-items: flex-start;
margin-bottom: 0.6rem;
gap: 1rem;
}
.pc-title { font-size: 1.05rem; font-weight: 700; color: #1a1a2e; line-height: 1.35; flex: 1; }
.pc-price { font-size: 1.2rem; font-weight: 800; color: #16a34a; white-space: nowrap; }
.pc-category-badge {
display: inline-block;
font-size: 0.75rem; font-weight: 600;
padding: 0.15rem 0.55rem;
border-radius: 99px;
margin-bottom: 0.4rem;
background: #dbeafe; color: #1e40af;
}
.pc-section { margin-top: 0.5rem; }
.pc-section-title {
font-weight: 600; font-size: 0.85rem; color: #475569;
text-transform: uppercase; letter-spacing: 0.04em; margin-bottom: 0.3rem;
}
.pc-desc { font-size: 0.92rem; color: #334155; line-height: 1.6; }
.pc-list { margin: 0; padding-left: 1.2rem; font-size: 0.92rem; color: #334155; line-height: 1.5; }
.pc-list li { margin-bottom: 0.25rem; }
.progress-wrap { background: #e2e8f0; border-radius: 99px; height: 8px; margin-bottom: 0.25rem; overflow: hidden; }
.progress-fill { background: #2563eb; height: 100%; border-radius: 99px; }
.progress-label { font-size: 0.82rem; color: #64748b; text-align: right; margin-bottom: 1rem; }
.chat-wrap { max-height: 420px; overflow-y: auto; margin-bottom: 1rem; }
.bubble { padding: 0.65rem 0.9rem; border-radius: 12px; margin-bottom: 0.5rem; font-size: 0.93rem; line-height: 1.5; }
.bubble-ai { background: #eff6ff; border: 1px solid #93c5fd; margin-right: 10%; }
.bubble-user { background: #f0fdf4; border: 1px solid #86efac; margin-left: 10%; text-align: right; }
.bubble-label { font-size: 0.75rem; color: #94a3b8; margin-bottom: 0.2rem; }
</style>
"""
    st.markdown(stylesheet, unsafe_allow_html=True)
| # --------------------------------------------------------------------------- | |
| # UI helpers | |
| # --------------------------------------------------------------------------- | |
def render_product_card_html(product: dict, compact: bool = False) -> str:
    """Return the HTML for a product card.

    Title, price, description, features and the category badge label are
    HTML-escaped before interpolation.  They come from an external dataset
    and the result is rendered with ``unsafe_allow_html=True``, so stray
    ``<``, ``&`` or markup in the data could otherwise break the card or
    inject HTML.  As a consequence, markup contained in the data is shown
    literally.

    Empty fragments (no badge, no description, no features) are left out of
    the output instead of producing blank lines inside the HTML.

    Args:
        product: product record; missing keys fall back to defaults.
        compact: when True the card is height-limited and scrollable.

    Returns:
        The card as an HTML string.
    """
    from html import escape  # local import: only this block needs it

    title = escape(str(product.get("title", "Unknown Product")))
    price = product.get("price", "N/A")
    description = product.get("description", [])
    features = product.get("features", [])
    category = product.get("category", "")

    # Prepend "$" unless the price is empty, "N/A" or already has one.
    price_str = f"${price}" if price and price != "N/A" and not str(price).startswith("$") else price
    price_html = escape(str(price_str))

    # Category badge — only shown in mixed mode
    badge_html = ""
    if category and MODE == "mixed":
        badge_label = escape(str(CATEGORY_DISPLAY.get(category, category)))
        badge_html = f'<div class="pc-category-badge">📂 {badge_label}</div>'

    desc_html = ""
    if description:
        desc_text = " ".join(escape(str(d)) for d in description if d)
        desc_html = f'<div class="pc-section"><div class="pc-section-title">📋 Description</div><div class="pc-desc">{desc_text}</div></div>'

    feat_html = ""
    if features:
        items_html = "".join(f"<li>{escape(str(feat))}</li>" for feat in features if feat)
        feat_html = f'<div class="pc-section"><div class="pc-section-title">✨ Features</div><ul class="pc-list">{items_html}</ul></div>'

    max_h = "max-height:240px;overflow-y:auto;" if compact else ""
    parts = [
        f'<div class="product-card" style="{max_h}">',
        badge_html,
        '<div class="pc-header">',
        f'<div class="pc-title">{title}</div>',
        f'<div class="pc-price">{price_html}</div>',
        "</div>",
        desc_html,
        feat_html,
        "</div>",
    ]
    return "\n".join(part for part in parts if part)
def render_progress(current: int, total: int = PRODUCTS_PER_USER):
    """Draw a progress bar with a "Product X of Y" label.

    The bar width is ``current / total`` as a whole-number percentage,
    truncated toward zero.
    """
    pct = int((current / total) * 100)
    markup = f"""
<div class="progress-wrap"><div class="progress-fill" style="width:{pct}%"></div></div>
<div class="progress-label">Product {current} of {total}</div>
"""
    st.markdown(markup, unsafe_allow_html=True)
def render_chat_history(turns: list):
    """Render the conversation so far as chat bubbles.

    Each turn is a dict with "role" ("assistant" or "user") and "content";
    turns with any other role are skipped.  The content is HTML-escaped
    because it is free text typed by the participant or generated by the
    model and the result is rendered with ``unsafe_allow_html=True``.
    """
    from html import escape  # local import: only this block needs it

    pieces = ['<div class="chat-wrap">']
    for turn in turns:
        role = turn.get("role", "")
        content = escape(str(turn.get("content", "")))
        if role == "assistant":
            pieces.append(
                f'<div class="bubble-label">🤖 AI Sales Agent</div><div class="bubble bubble-ai">{content}</div>'
            )
        elif role == "user":
            pieces.append(
                f'<div class="bubble-label" style="text-align:right">You</div><div class="bubble bubble-user">{content}</div>'
            )
    pieces.append("</div>")
    st.markdown("".join(pieces), unsafe_allow_html=True)
| # --------------------------------------------------------------------------- | |
| # Screen renderers | |
| # --------------------------------------------------------------------------- | |
def screen_welcome(s):
    """Render the welcome screen and move on when "Begin" is clicked.

    In DEBUG_MODE the demographics screen is skipped: the canned
    DEBUG_DEMOGRAPHICS are stored and the session jumps to "product_intro".
    Otherwise the next screen is "demographics".

    The closing sentence now takes the product count from PRODUCTS_PER_USER
    instead of a hard-coded 5, so it agrees with the first sentence.

    Args:
        s: mutable session state dict; "screen" (and in debug mode
            "demographics") is updated in place.
    """
    st.markdown("# 🛒 Product Evaluation Study")
    breakdown = study_category_breakdown()
    # NOTE(review): "at least 3 exchanges" presumably mirrors MIN_TURNS —
    # confirm against the chat screen before linking the two.
    st.markdown(
        f"Welcome! In this study you will evaluate **{PRODUCTS_PER_USER} {study_display_name()}** products.\n\n"
        + (f"{breakdown}\n\n" if breakdown else "")
        + "For each product you will:\n"
        "1. Rate how familiar you are with the product\n"
        "2. Rate how willing you are to buy it\n"
        "3. Chat with an AI about the product (**at least 3 exchanges**)\n"
        "4. Rate your willingness to buy it again\n"
        "5. Answer two brief reflection questions\n\n"
        f"After all {PRODUCTS_PER_USER} products, you're done! The study takes about **20–30 minutes**. "
        "Thank you for participating!"
    )
    if st.button("Begin →", type="primary", use_container_width=True):
        if DEBUG_MODE:
            s["demographics"] = DEBUG_DEMOGRAPHICS.copy()
            s["screen"] = "product_intro"
        else:
            s["screen"] = "demographics"
        st.rerun()
def screen_demographics(s):
    """Collect the participant's demographics and advance to the first product.

    Args:
        s: The mutable study-state mapping stored in the Streamlit session.

    Every field is mandatory. On a valid submission the answers are stored
    under ``s["demographics"]`` (age as the stripped text the participant
    typed) and the screen switches to ``"product_intro"``.

    Age is validated with ``str.isdecimal()`` rather than ``str.isdigit()``:
    ``isdigit()`` is also true for characters such as "²" that ``int()``
    cannot parse, which would crash the form with ``ValueError``.
    """
    st.markdown("## Demographics — About You")
    st.markdown("All fields are required before you can proceed.")
    age = st.text_input("Age (years)", placeholder="e.g. 34")
    gender = st.selectbox("Gender", ["", "Female", "Male"])
    geographic_region = st.selectbox(
        "Geographic region", ["", "West", "South", "Midwest", "Northeast", "Pacific"]
    )
    education_level = st.selectbox("Highest education level", [
        "", "Less than high school", "High school graduate",
        "Some college, no degree", "Associate's degree",
        "College graduate/some postgrad", "Postgraduate",
    ])
    race = st.selectbox("Race / ethnicity", ["", "Asian", "Hispanic", "White", "Black", "Other"])
    us_citizen = st.selectbox("Are you a U.S. citizen?", ["", "Yes", "No"])
    marital_status = st.selectbox("Marital status", [
        "", "Never been married", "Married", "Living with a partner",
        "Divorced", "Separated", "Widowed",
    ])
    religion = st.selectbox("Religion", [
        "", "Protestant", "Roman Catholic", "Mormon", "Orthodox", "Jewish",
        "Muslim", "Buddhist", "Atheist", "Agnostic", "Nothing in particular", "Other",
    ])
    religious_attendance = st.selectbox("How often do you attend religious services?", [
        "", "Never", "Seldom", "A few times a year", "Once or twice a month",
        "Once a week", "More than once a week",
    ])
    political_affiliation = st.selectbox("Political affiliation", [
        "", "Democrat", "Republican", "Independent", "Something else",
    ])
    income = st.selectbox("Household income", [
        "", "Less than $30,000", "$30,000-$50,000", "$50,000-$75,000",
        "$75,000-$100,000", "$100,000 or more",
    ])
    political_views = st.selectbox("Political views", [
        "", "Very liberal", "Liberal", "Moderate", "Conservative", "Very conservative",
    ])
    household_size = st.selectbox("Household size", ["", "1", "2", "3", "4", "More than 4"])
    employment_status = st.selectbox("Employment status", [
        "", "Full-time employment", "Part-time employment", "Self-employed",
        "Unemployed", "Retired", "Home-maker", "Student",
    ])

    if st.button("Next →", type="primary", use_container_width=True):
        fields = [age, gender, geographic_region, education_level, race, us_citizen,
                  marital_status, religion, religious_attendance, political_affiliation,
                  income, political_views, household_size, employment_status]
        # A field counts as answered only if it is non-empty after stripping.
        if not all(f and (f.strip() if isinstance(f, str) else f) for f in fields):
            st.error("⚠️ Please complete all fields.")
            return
        age_text = age.strip()
        # isdecimal() guarantees int() will succeed; isdigit() does not.
        if not age_text.isdecimal() or not (1 <= int(age_text) <= 120):
            st.error("⚠️ Please enter a valid age.")
            return
        s["demographics"] = {
            "age": age_text, "gender": gender, "geographic_region": geographic_region,
            "education_level": education_level, "race": race, "us_citizen": us_citizen,
            "marital_status": marital_status, "religion": religion,
            "religious_attendance": religious_attendance, "political_affiliation": political_affiliation,
            "income": income, "political_views": political_views,
            "household_size": household_size, "employment_status": employment_status,
        }
        s["screen"] = "product_intro"
        st.rerun()
def screen_product_intro(s):
    """Show the current product, collect pre-chat ratings, and open the chat.

    Args:
        s: The mutable study-state mapping stored in the Streamlit session.

    Flow when "Start Chat" is pressed:
      1. Outside ``DEBUG_MODE`` both ratings are required. In debug mode
         missing answers fall back to the first familiarity choice and
         ``WILLINGNESS_CHOICES[3]``.
      2. If ``needs_swap`` says so (never in debug mode) and a replacement is
         available, the current product is returned to the queue, the slot is
         replaced, and the screen is re-run for the new product.
      3. Otherwise the ratings are recorded, the opening model reply is
         fetched, the conversation is initialised, and the flow moves to the
         chat screen.

    The swap notice is persisted in the state and drawn on the next run.
    Calling ``st.info`` immediately before ``st.rerun()`` would never reach
    the participant, because the rerun discards the current page.
    """
    # State key holding the product index whose swap notice should be shown.
    # NOTE(review): relies on `s` being dict-like (supports .get/.pop) — confirm.
    swap_notice_key = "swap_notice_index"

    idx = s["current_product_index"]
    product = s["products"][idx]
    product_category = product.get("category", CATEGORY)
    render_progress(idx + 1)
    st.markdown("## Product Evaluation")
    if s.get(swap_notice_key) == idx:
        st.info("We've swapped this product for a better match. Please review the new product below.")
    st.markdown("Please read the product information carefully, then answer the two questions below.")
    st.markdown(render_product_card_html(product), unsafe_allow_html=True)

    # Use per-product familiarity choices based on the product's own category
    familiarity_choices = get_familiarity_choices(product_category)
    familiarity_val = st.radio(
        "How familiar are you with this product?",
        familiarity_choices,
        index=None,
        key=f"familiarity_{idx}_{product['id']}",
    )
    pre_will_val = st.radio(
        "How willing would you be to buy this product?",
        WILLINGNESS_CHOICES,
        index=None,
        key=f"pre_will_{idx}_{product['id']}",
    )

    if st.button("Start Chat →", type="primary", use_container_width=True):
        if not DEBUG_MODE:
            if not familiarity_val:
                st.error("⚠️ Please rate your familiarity.")
                return
            if not pre_will_val:
                st.error("⚠️ Please rate your willingness to buy.")
                return
        # Debug-mode defaults for unanswered questions.
        familiarity_val = familiarity_val or familiarity_choices[0]
        pre_will_val = pre_will_val or WILLINGNESS_CHOICES[3]

        # Check if we need to swap this product
        if needs_swap(familiarity_val, pre_will_val) and not DEBUG_MODE:
            current_ids = {p["id"] for p in s["products"]}
            replacement = get_swap_product(exclude_ids=current_ids, category=product_category)
            if replacement:
                return_product_to_queue(s["products"][idx])
                s["products"][idx] = make_product_slot(replacement, was_swapped=True)
                # Remember to show the notice after the rerun redraws the page.
                s[swap_notice_key] = idx
                st.rerun()
                return
            # No replacement found — proceed with this product anyway

        # The participant is leaving this screen; drop any pending notice so
        # it does not linger in the saved state.
        s.pop(swap_notice_key, None)

        pre_val = parse_willingness(pre_will_val)
        s["products"][idx]["familiarity"] = familiarity_val
        s["products"][idx]["pre_willingness"] = pre_val
        s["products"][idx]["pre_willingness_label"] = WILLINGNESS_LABELS[pre_val]

        system_prompt = build_sales_system_prompt(product)
        opening_user_msg = build_opening_user_message(product)
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": opening_user_msg},
        ]
        with st.spinner("Starting conversation…"):
            ai_reply = call_model(messages)

        s["products"][idx]["conversation"]["system_prompt"] = system_prompt
        s["products"][idx]["conversation"]["opening_user_message"] = opening_user_msg
        s["products"][idx]["conversation"]["turns"] = [
            {"turn_index": 0, "role": "assistant", "content": ai_reply,
             "timestamp": time.time(), "model": MODEL_NAME}
        ]
        # Only participant exchanges are counted; the opening reply is turn 0.
        s["products"][idx]["conversation"]["num_turns"] = 0
        s["screen"] = "chat"
        st.rerun()
def screen_chat(s):
    """Run the chat screen for the current product.

    Args:
        s: The mutable study-state mapping stored in the Streamlit session.

    Draws the collapsible product card, the conversation so far, and (until
    ``MAX_TURNS`` is reached) a message box. Sending a message replays the
    stored system prompt, opening message and all prior turns to the model,
    appends the new user/assistant pair, and bumps ``num_turns``. The "done"
    button unlocks after ``MIN_TURNS`` exchanges, at ``MAX_TURNS``, or in
    ``DEBUG_MODE``.
    """
    product_idx = s["current_product_index"]
    slot = s["products"][product_idx]
    conversation = slot["conversation"]

    render_progress(product_idx + 1)
    st.markdown("## Chat with the AI")

    title = slot.get("title", "Product")
    price = slot.get("price", "N/A")
    # Prefix a dollar sign only for real prices that do not already carry one.
    if price and price != "N/A" and not str(price).startswith("$"):
        price_label = f"${price}"
    else:
        price_label = price
    with st.expander(f"📦 {title} — {price_label} (click to expand product details)"):
        st.markdown(render_product_card_html(slot, compact=True), unsafe_allow_html=True)

    exchanges = conversation["num_turns"]
    instructions = (
        "Chat with the AI about whether you'd like to purchase the product. "
        "Ask questions, push back, or explore your interest. "
        f"You need at least **{MIN_TURNS} exchanges** before you can move on."
    )
    st.markdown(instructions)

    visible_turns = [
        entry for entry in conversation["turns"] if entry["role"] in ("user", "assistant")
    ]
    render_chat_history(visible_turns)

    at_limit = exchanges >= MAX_TURNS
    if at_limit:
        st.info(f"Maximum turns ({MAX_TURNS}) reached. Please proceed.")
    else:
        st.caption(f"Turns: {exchanges} / minimum {MIN_TURNS}")
    # NOTE(review): indentation was lost in the extracted source; this tip is
    # assumed to be shown regardless of the turn count — confirm.
    st.caption("💡 If you don't see the latest messages, scroll down while hovering over the conversation.")

    if not at_limit:
        draft = st.text_area(
            "Your response:",
            placeholder="Type your response here…",
            height=100,
            # Keyed by turn count so the box starts empty after each send.
            key=f"chat_input_{product_idx}_{exchanges}",
        )
        _spacer, send_col = st.columns([3, 1])
        with send_col:
            send_clicked = st.button("Send", type="primary", use_container_width=True)

        if send_clicked:
            text = (draft or "").strip()
            if not text:
                st.error("⚠️ Please type a message.")
                return
            word_count = len(text.split())
            if word_count < 5 and not DEBUG_MODE:
                st.error(f"⚠️ Please write at least 5 words ({word_count} so far).")
                return

            # Full context for the model: fixed preamble, history, new message.
            payload = [
                {"role": "system", "content": conversation["system_prompt"]},
                {"role": "user", "content": conversation["opening_user_message"]},
            ]
            payload.extend(
                {"role": entry["role"], "content": entry["content"]}
                for entry in conversation["turns"]
            )
            payload.append({"role": "user", "content": text})

            with st.spinner("AI is responding…"):
                reply = call_model(payload)

            history = conversation["turns"]
            history.append({"turn_index": len(history), "role": "user",
                            "content": text, "timestamp": time.time()})
            history.append({"turn_index": len(history), "role": "assistant",
                            "content": reply, "timestamp": time.time(), "model": MODEL_NAME})
            conversation["num_turns"] = exchanges + 1
            s["products"][product_idx]["conversation"] = conversation
            st.rerun()

    if exchanges >= MIN_TURNS or at_limit or DEBUG_MODE:
        if st.button("I'm done chatting →", use_container_width=True):
            s["screen"] = "post_will"
            st.rerun()
    else:
        st.button("I'm done chatting →", disabled=True, use_container_width=True,
                  help=f"Complete at least {MIN_TURNS} exchanges first.")
def screen_post_willingness(s):
    """Collect the post-chat willingness rating for the current product.

    Args:
        s: The mutable study-state mapping stored in the Streamlit session.

    On "Next" the parsed rating, its label, and the change relative to the
    stored pre-chat rating (defaulting to 4 when absent) are written to the
    product slot, and the flow moves to the reflection screen. In
    ``DEBUG_MODE`` a missing answer falls back to ``WILLINGNESS_CHOICES[3]``.
    """
    product_idx = s["current_product_index"]
    slot = s["products"][product_idx]

    render_progress(product_idx + 1)
    st.markdown("## Your View Now")
    st.markdown("Now that you've chatted with the AI, rate your willingness to buy again.")
    st.markdown(render_product_card_html(slot), unsafe_allow_html=True)

    choice = st.radio(
        "How willing would you be to buy this product now?",
        WILLINGNESS_CHOICES,
        index=None,
        key=f"post_will_{product_idx}_{slot['id']}",
    )

    if not st.button("Next →", type="primary", use_container_width=True):
        return

    if not (choice or DEBUG_MODE):
        st.error("⚠️ Please rate your willingness to buy.")
        return
    if not choice:
        # Debug-mode default for an unanswered question.
        choice = WILLINGNESS_CHOICES[3]

    after = parse_willingness(choice)
    before = s["products"][product_idx].get("pre_willingness", 4)
    s["products"][product_idx]["post_willingness"] = after
    s["products"][product_idx]["post_willingness_label"] = WILLINGNESS_LABELS[after]
    s["products"][product_idx]["willingness_delta"] = after - before
    s["screen"] = "reflection"
    st.rerun()
def screen_reflection(s):
    """Collect the two free-text reflections and advance the study.

    Args:
        s: The mutable study-state mapping stored in the Streamlit session.

    Outside ``DEBUG_MODE`` each answer must be non-blank and at least ten
    words long. Once stored, the product index is incremented. After the last
    product the submission metadata is written to ``s["meta"]``, the
    responses are saved via ``save_and_upload`` and the "done" screen is
    shown; otherwise the flow returns to the product intro.
    """
    product_idx = s["current_product_index"]
    render_progress(product_idx + 1)
    st.markdown("## Reflection")

    standout = st.text_area(
        "What did the AI say that stood out to you most?",
        placeholder="Describe a specific argument, question, or moment from the conversation…",
        height=120,
        key=f"standout_{product_idx}",
    )
    thinking_change = st.text_area(
        "How did your thinking about this product change (or not change) during the chat? Why?",
        placeholder="Be as specific as you can…",
        height=120,
        key=f"thinking_{product_idx}",
    )

    is_last = product_idx + 1 >= PRODUCTS_PER_USER
    button_label = "Submit Study →" if is_last else "Next Product →"
    if not st.button(button_label, type="primary", use_container_width=True):
        return

    if not DEBUG_MODE:
        # Same two checks for each question, reported in question order.
        for ordinal, answer in (("first", standout), ("second", thinking_change)):
            trimmed = (answer or "").strip()
            if not trimmed:
                st.error(f"⚠️ Please answer the {ordinal} reflection question.")
                return
            words = len(trimmed.split())
            if words < 10:
                st.error(f"⚠️ Please write at least 10 words for the {ordinal} question ({words} so far).")
                return

    # Blank answers can only get here in debug mode; record a placeholder.
    placeholder_text = "[debug placeholder]"
    s["products"][product_idx]["reflection"] = {
        "standout_moment": (standout or "").strip() or placeholder_text,
        "thinking_change": (thinking_change or "").strip() or placeholder_text,
    }

    following = product_idx + 1
    s["current_product_index"] = following
    if following < PRODUCTS_PER_USER:
        s["screen"] = "product_intro"
    else:
        finished_at = time.time()
        s["meta"] = {
            "submission_time": finished_at,
            "duration_seconds": round(finished_at - s.get("start_time", finished_at), 1),
            "model": MODEL_NAME,
            "mode": MODE or "single",
            "category": CATEGORY if MODE != "mixed" else "mixed",
        }
        with st.spinner("Saving your responses…"):
            save_and_upload(s)
        s["screen"] = "done"
    st.rerun()
def screen_done(s):
    """Show the completion screen: a per-product summary and the Prolific code.

    Args:
        s: The mutable study-state mapping stored in the Streamlit session.

    The table has one row per product with the before/after willingness
    labels and the signed change. A "Category" column is included only when
    ``MODE`` is ``"mixed"``.
    """
    st.markdown("## ✅ Study Complete!")
    st.markdown("**Thank you for completing the study!**")
    st.markdown(f"Here's a summary of how your willingness changed across the {PRODUCTS_PER_USER} products:")

    mixed = MODE == "mixed"
    summary = []
    for position, item in enumerate(s["products"], start=1):
        before = item.get("pre_willingness", "?")
        after = item.get("post_willingness", "?")
        change = item.get("willingness_delta", 0)

        if change == 0:
            arrow = "➡️"
        elif change > 0:
            arrow = "⬆️"
        else:
            arrow = "⬇️"

        # Insertion order defines the column order of the table.
        row = {"#": position}
        if mixed:
            row["Category"] = CATEGORY_DISPLAY.get(item.get("category", ""), "")
        # Titles longer than 55 characters are cut and marked with an ellipsis.
        row["Product"] = item.get("title", "")[:55] + ("…" if len(item.get("title", "")) > 55 else "")
        row["Before"] = WILLINGNESS_LABELS.get(before, str(before))
        row["After"] = WILLINGNESS_LABELS.get(after, str(after))
        row["Change"] = f"{arrow} {change:+d}" if isinstance(change, int) else "–"
        summary.append(row)

    import pandas as pd

    table = pd.DataFrame(summary)
    st.dataframe(table, use_container_width=True, hide_index=True)
    st.markdown("---")
    st.success(
        f"**Your completion code:** `{PROLIFIC_COMPLETION_CODE}`\n\n"
        "Please copy this code and paste it on the Prolific website to complete your submission."
    )
| # --------------------------------------------------------------------------- | |
| # Main | |
| # --------------------------------------------------------------------------- | |
def main():
    """Configure the page, ensure study state exists, and draw the active screen.

    The state lives in ``st.session_state.study_state`` and is created with
    ``init_state()`` on first use. The screen name stored under ``"screen"``
    (default ``"welcome"``) selects the renderer; an unrecognised name draws
    nothing.
    """
    st.set_page_config(page_title="Product Study", page_icon="🛒", layout="centered")
    inject_css()

    if "study_state" not in st.session_state:
        st.session_state.study_state = init_state()
    s = st.session_state.study_state

    # Screen name -> renderer; every renderer takes the state mapping.
    renderers = {
        "welcome": screen_welcome,
        "demographics": screen_demographics,
        "product_intro": screen_product_intro,
        "chat": screen_chat,
        "post_will": screen_post_willingness,
        "reflection": screen_reflection,
        "done": screen_done,
    }
    render = renderers.get(s.get("screen", "welcome"))
    if render is not None:
        render(s)


# Entry point when the file is executed as a script (e.g. via `streamlit run`,
# as described in the module docstring).
if __name__ == "__main__":
    main()