Spaces:
Running
Running
| """ | |
| Streamlit App: AI Product Preference User Study (Pairs) | |
| ======================================================== | |
| Participants compare two similar products on a 7-point scale | |
| (Product A ↔ Product B), chat with an AI that tries to change | |
| their mind, and then rate their preference again. | |
| Run locally (mixed mode — movies + groceries): | |
| streamlit run src/streamlit_app.py | |
| streamlit run src/streamlit_app.py -- --debug | |
| On HuggingFace Spaces, set these environment variables in Space Settings → Variables: | |
| HF_TOKEN - HuggingFace token | |
| TINKER_API_KEY - Tinker AI API key | |
| DATASET_REPO_ID - HuggingFace dataset repo to upload results | |
| DEBUG_MODE - "true" to skip validation (optional) | |
| """ | |
| import re | |
| import csv | |
| import html as html_lib | |
| import json | |
| import os | |
| import random | |
| import re | |
| import sys | |
| import tempfile | |
| import time | |
| import uuid | |
| from datetime import datetime | |
| from pathlib import Path | |
| import streamlit as st | |
| from dotenv import load_dotenv | |
| from filelock import FileLock | |
| from huggingface_hub import HfApi, hf_hub_download | |
# Load variables from a local .env file (if present) before the os.getenv()
# calls below read them.
load_dotenv()
# ---------------------------------------------------------------------------
# CLI args
# ---------------------------------------------------------------------------
import argparse
# add_help=False leaves -h/--help unclaimed, and parse_known_args() returns
# unrecognised arguments instead of exiting, so only --debug is consumed here.
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument("--debug", action="store_true", default=False)
cli_args, _ = parser.parse_known_args()
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# NOTE(review): DEBUG_MODE is hard-coded to False. Neither cli_args.debug nor
# the DEBUG_MODE environment variable described in the module docstring is
# read in this block — confirm whether that is intentional.
DEBUG_MODE = False
# Dataset repo that receives uploaded results; the default is a placeholder.
DATASET_REPO_ID = os.getenv("DATASET_REPO_ID", "your-username/preference-study")
# None when unset; the upload code paths check this before contacting the Hub.
HF_TOKEN = os.getenv("HF_TOKEN")
# NOTE(review): read here but not passed explicitly to the Tinker client in
# the visible code — presumably picked up from the environment; confirm.
TINKER_API_KEY = os.getenv("TINKER_API_KEY")
# Base model used for the Tinker sampling client, tokenizer and renderer.
MODEL_NAME = "openai/gpt-oss-20b"
| # --------------------------------------------------------------------------- | |
| # Pair selection | |
| # --------------------------------------------------------------------------- | |
PAIR_SELECTION_SEED = 42  # fixed seed for reproducible pair selection
PAIRS_PER_CATEGORY = 50  # 50 movies + 50 groceries = 100 pool
# Categories that are downloaded and assigned to participants.
CATEGORIES = ["movies", "groceries"]
# ---------------------------------------------------------------------------
# Prolific config
# ---------------------------------------------------------------------------
# The `cc` query parameter of the URL carries the same value as
# PROLIFIC_COMPLETION_CODE; keep the two in sync.
PROLIFIC_COMPLETION_URL = "https://app.prolific.com/submissions/complete?cc=C1JEJWOQ"
PROLIFIC_COMPLETION_CODE = "C1JEJWOQ"
# Working directories live next to this file and are created at import time.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(BASE_DIR, "data")
ANNOTATIONS_DIR = os.path.join(BASE_DIR, "annotations")
os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(ANNOTATIONS_DIR, exist_ok=True)
# HuggingFace repos that hold the pairs JSON files (created by collect_pairs.py)
CATEGORY_TO_PAIRS_REPO = {
    "movies": "lms-shape-preferences/pairs_Movies_and_TV",
    "groceries": "lms-shape-preferences/pairs_Grocery_and_Gourmet_Food",
}
# Human-readable category names for the product-card badge. "books" and
# "health" have entries here although they are not in CATEGORIES.
CATEGORY_DISPLAY = {
    "books": "Books",
    "groceries": "Grocery Products",
    "movies": "Movies & TV",
    "health": "Health & Household Products",
}
# Per-product familiarity label (depends on the individual product's category)
FAMILIARITY_USED_LABEL = {
    "books": "Read it before",
    "movies": "Watched it before",
    "groceries": "Used it before",
    "health": "Used it before",
}
# Number of pairs assigned to each participant (default for assign_pairs).
PAIRS_PER_USER = 5
# NOTE(review): presumably the minimum / maximum number of chat turns per
# pair — their use is not visible in this part of the file; confirm.
MIN_TURNS = 3
MAX_TURNS = 10
| # --------------------------------------------------------------------------- | |
| # Preference background questions | |
| # --------------------------------------------------------------------------- | |
# NOTE(review): presumably the minimum word count required for each
# background answer — the validation itself is not visible here; confirm.
MIN_WORDS_BACKGROUND = 20
# One entry per free-text background question. "key" is the field name under
# which the answer is exported (see the CSV header in upload_csv_rows).
BACKGROUND_QUESTIONS = [
    {
        "key": "movies_criteria",
        "label": "When picking between movies to purchase, what matters to you?",
        "placeholder": "e.g. I look for strong storytelling, good reviews, genre, director, cast…",
    },
    {
        "key": "movies_enjoy",
        "label": "What kinds of movies do you usually enjoy, and why?",
        "placeholder": "e.g. I love sci-fi thrillers because they keep me on the edge of my seat…",
    },
    {
        "key": "movies_avoid",
        "label": "What kinds of movies do you usually avoid, and why?",
        "placeholder": "e.g. I tend to skip horror movies because I don't enjoy being scared…",
    },
    {
        "key": "groceries_criteria",
        "label": "When picking between foods or grocery items to purchase, what matters to you?",
        "placeholder": "e.g. Price, ingredients, brand trust, nutritional value, taste…",
    },
    {
        "key": "groceries_enjoy",
        "label": "What kinds of foods or grocery items do you usually enjoy, and why?",
        "placeholder": "e.g. I enjoy organic snacks because they feel healthier and taste fresh…",
    },
    {
        "key": "groceries_avoid",
        "label": "What kinds of foods or grocery items do you usually avoid, and why?",
        "placeholder": "e.g. I avoid heavily processed foods because of the artificial ingredients…",
    },
]
# Canned answers keyed by question key.
# NOTE(review): `" " * 20` pads with 20 space characters, which adds no words;
# if the intent was to reach MIN_WORDS_BACKGROUND words, this does not do it.
DEBUG_BACKGROUND = {q["key"]: "[debug placeholder — " + q["key"] + " " * 20 + "]" for q in BACKGROUND_QUESTIONS}
# Canned demographics; the keys match the demographic columns of the CSV export.
DEBUG_DEMOGRAPHICS = {
    "age": "30", "gender": "Female", "geographic_region": "West",
    "education_level": "College graduate/some postgrad", "race": "White",
    "us_citizen": "Yes", "marital_status": "Single",
    "religion": "Agnostic", "religious_attendance": "Never",
    "political_affiliation": "Independent", "income": "$50,000-$75,000",
    "political_views": "Moderate", "household_size": "2",
    "employment_status": "Full-time employment",
}
# 7-point scale: 1-3 lean towards Product A, 4 is neutral, 5-7 lean towards B.
PREFERENCE_LABELS = {
    1: "Definitely would buy Product A",
    2: "Probably would buy Product A",
    3: "Slightly likely to buy Product A",
    4: "Neutral",
    5: "Slightly likely to buy Product B",
    6: "Probably would buy Product B",
    7: "Definitely would buy Product B",
}
# Display strings of the form "<label> (<score>)"; parse_preference() recovers
# the score from the trailing parenthesised number.
PREFERENCE_CHOICES = [f"{v} ({k})" for k, v in PREFERENCE_LABELS.items()]
| # --------------------------------------------------------------------------- | |
| # Helpers: file paths | |
| # --------------------------------------------------------------------------- | |
def _data_path(name: str) -> str:
    """Return the path of ``name`` inside ``DATA_DIR``."""
    full_path = os.path.join(DATA_DIR, name)
    return full_path


def local_pairs_path(category: str) -> str:
    """Return the path of the cached pair selection for ``category``."""
    filename = f"pairs_{category}_selected.json"
    return _data_path(filename)


def counter_path(category: str) -> str:
    """Return the path of the assignment counter file for ``category``."""
    filename = f"pairs_{category}_counter.txt"
    return _data_path(filename)


def counter_lock_path(category: str) -> str:
    """Return the path of the lock file guarding the ``category`` counter."""
    filename = f"pairs_{category}_counter.lock"
    return _data_path(filename)


def alternation_counter_path() -> str:
    """Return the path of the counter that alternates the category split."""
    filename = "alternation_counter.txt"
    return _data_path(filename)


def alternation_lock_path() -> str:
    """Return the path of the lock file guarding the alternation counter."""
    filename = "alternation_counter.lock"
    return _data_path(filename)


def return_queue_path(category: str) -> str:
    """Return the path of the return-queue JSON file for ``category``."""
    filename = f"pairs_{category}_return_queue.json"
    return _data_path(filename)
| # --------------------------------------------------------------------------- | |
| # Dataset loading: download pairs, select 50 per category reproducibly | |
| # --------------------------------------------------------------------------- | |
def download_and_select_pairs(category: str):
    """Download ``pairs_test.json`` for ``category`` and cache a fixed-seed subset.

    If the cached selection file already exists nothing is downloaded.
    Otherwise the full list of test pairs is fetched from the HuggingFace
    dataset repo mapped to ``category``, its indices are shuffled with
    ``PAIR_SELECTION_SEED`` and the first ``PAIRS_PER_CATEGORY`` entries are
    written to the local cache file.

    Args:
        category: A key of ``CATEGORY_TO_PAIRS_REPO``.

    Raises:
        KeyError: If ``category`` has no configured repo.
        Exception: Any download, parse or write failure is logged and re-raised.
    """
    selected_path = local_pairs_path(category)
    if os.path.exists(selected_path):
        print(f"[DATA] Found cached pairs for {category} at {selected_path}")
        return

    repo_id = CATEGORY_TO_PAIRS_REPO[category]
    print(f"[DATA] Downloading pairs_test.json from {repo_id}...")
    try:
        import huggingface_hub

        if HF_TOKEN:
            huggingface_hub.login(token=HF_TOKEN)
        downloaded = hf_hub_download(
            repo_id=repo_id,
            filename="pairs_test.json",
            repo_type="dataset",
            token=HF_TOKEN,
        )
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(downloaded, "r", encoding="utf-8") as f:
            all_pairs = json.load(f)
        print(f"[DATA] {category}: loaded {len(all_pairs)} test pairs from HF.")

        # Reproducible selection with fixed seed. The shuffle-then-slice
        # order is kept exactly so the selected subset does not change.
        rng = random.Random(PAIR_SELECTION_SEED)
        indices = list(range(len(all_pairs)))
        rng.shuffle(indices)
        selected = [all_pairs[i] for i in indices[:PAIRS_PER_CATEGORY]]

        # Write to a temp file in the same directory and rename it into place.
        # The existence check above treats any file at selected_path as a
        # complete cache, so a partially written file must never be visible.
        tmp_path = None
        try:
            with tempfile.NamedTemporaryFile(
                mode="w",
                dir=os.path.dirname(selected_path),
                suffix=".tmp",
                delete=False,
                encoding="utf-8",
            ) as tmp:
                tmp_path = tmp.name
                json.dump(selected, tmp, indent=2)
            os.replace(tmp_path, selected_path)
            tmp_path = None
        finally:
            # Only reached with a path when the write or the rename failed.
            if tmp_path is not None and os.path.exists(tmp_path):
                os.unlink(tmp_path)
        print(f"[DATA] {category}: selected {len(selected)} pairs (seed={PAIR_SELECTION_SEED}).")
    except Exception as e:
        print(f"[DATA] ERROR downloading {category} pairs: {e}")
        raise
def load_selected_pairs(category: str) -> list:
    """Read and return the cached pair selection for ``category``."""
    pairs_file = local_pairs_path(category)
    with open(pairs_file, "r") as handle:
        selected = json.load(handle)
    return selected
def _ensure_datasets():
    """Make sure the selected-pairs cache exists for every entry of CATEGORIES."""
    for category_name in CATEGORIES:
        download_and_select_pairs(category_name)
| # --------------------------------------------------------------------------- | |
| # Counter helpers | |
| # --------------------------------------------------------------------------- | |
| def _read_counter(path: str) -> int: | |
| if not os.path.exists(path): | |
| return 0 | |
| with open(path, "r") as f: | |
| return int(f.read().strip() or "0") | |
| def _write_counter(path: str, value: int): | |
| with open(path, "w") as f: | |
| f.write(str(value)) | |
def _read_return_queue(category: str) -> list:
    """Load the return queue for ``category``.

    Returns an empty list when the queue file does not exist or its content
    cannot be parsed as JSON.
    """
    queue_file = return_queue_path(category)
    if not os.path.exists(queue_file):
        return []
    with open(queue_file, "r") as handle:
        try:
            queue = json.load(handle)
        except Exception:
            # An unreadable queue is treated as empty rather than fatal.
            queue = []
    return queue
def _write_return_queue(category: str, queue: list):
    """Serialise ``queue`` as JSON into the return-queue file of ``category``."""
    destination = return_queue_path(category)
    with open(destination, "w") as handle:
        json.dump(queue, handle)
| # --------------------------------------------------------------------------- | |
| # Pair assignment | |
| # --------------------------------------------------------------------------- | |
def _assign_from_category(category: str, n: int) -> list:
    """Assign the next ``n`` pairs from one category pool.

    The per-category counter is read, used and advanced while holding the
    category's file lock, so concurrent callers receive consecutive,
    non-overlapping slices. Indexing wraps around (modulo pool size) when the
    pool is exhausted.

    Args:
        category: Category whose cached selection is used.
        n: Number of pairs to hand out.

    Returns:
        A list of ``n`` pair dicts (empty when ``n`` is not positive).

    Raises:
        ValueError: If pairs are requested from an empty pool.
    """
    pairs = load_selected_pairs(category)
    total = len(pairs)
    if n > 0 and total == 0:
        # Fail with a clear message instead of a ZeroDivisionError from `%`.
        raise ValueError(f"No pairs available for category {category!r}")

    counter_file = counter_path(category)
    with FileLock(counter_lock_path(category)):
        start = _read_counter(counter_file)
        assigned = [pairs[(start + offset) % total] for offset in range(n)]
        # Advance by the number actually handed out (0 for non-positive n).
        _write_counter(counter_file, start + len(assigned))
    return assigned
def assign_pairs(n: int = PAIRS_PER_USER) -> list:
    """Assign ``n`` pairs split across movies and groceries.

    A dedicated alternation counter (incremented by exactly 1 per call, under
    its own file lock) decides which category receives the larger share, so
    the split truly alternates between users:

        User 1: 3 movies + 2 groceries
        User 2: 2 movies + 3 groceries
        User 3: 3 movies + 2 groceries ... etc.

    The per-category product counters cannot be used for this because they
    advance by 2 or 3 per user, which breaks the parity after the first user.

    The split is derived from ``n`` itself, so any odd ``n`` alternates the
    extra pair in the same way and an even ``n`` is split evenly.

    Args:
        n: Total number of pairs to assign.

    Returns:
        The assigned pair dicts in random order.
    """
    with FileLock(alternation_lock_path()):
        call_count = _read_counter(alternation_counter_path())
        _write_counter(alternation_counter_path(), call_count + 1)

    smaller, remainder = divmod(n, 2)
    larger = smaller + remainder
    if call_count % 2 == 0:
        n_movies, n_groceries = larger, smaller
    else:
        n_movies, n_groceries = smaller, larger

    movie_pairs = _assign_from_category("movies", n_movies)
    grocery_pairs = _assign_from_category("groceries", n_groceries)
    combined = movie_pairs + grocery_pairs
    random.shuffle(combined)  # mix so user doesn't see all movies then all groceries
    return combined
| # --------------------------------------------------------------------------- | |
| # AI client (Tinker) | |
| # --------------------------------------------------------------------------- | |
@st.cache_resource
def get_tinker_clients():
    """Initialise and cache the Tinker sampling client, renderer and types module.

    Building the service client, sampling client and tokenizer on every chat
    turn is wasteful, and Streamlit re-executes this script on each
    interaction, so the result is held by ``st.cache_resource`` rather than a
    module-level global.

    Returns:
        A tuple ``(sampling_client, renderer, tinker_types)``.
    """
    # Imported lazily so the app can start without the Tinker packages loaded.
    import tinker
    from tinker import types as tinker_types
    from tinker_cookbook import renderers
    from tinker_cookbook.tokenizer_utils import get_tokenizer
    from tinker_cookbook.model_info import get_recommended_renderer_name

    service_client = tinker.ServiceClient()
    sampling_client = service_client.create_sampling_client(base_model=MODEL_NAME)
    tokenizer = get_tokenizer(MODEL_NAME)
    renderer_name = get_recommended_renderer_name(MODEL_NAME)
    renderer = renderers.get_renderer(renderer_name, tokenizer)
    return sampling_client, renderer, tinker_types
def _clean_model_output(content: str) -> str:
    """Strip reasoning blocks, control tokens and runaway repetition.

    Raises:
        ValueError: If fewer than three words remain after cleanup.
    """
    # 1. Strip <think>...</think> blocks
    without_think = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()
    # 2. Strip leaked control tokens like <|channel|>, <|message|>, <|end|>, etc.
    cleaned = re.sub(r"<\|[^|]*\|>", "", without_think).strip()
    # 3. Degenerate repetition: if any 40+ char substring repeats 5+ times,
    #    keep the text only up to the end of its first occurrence.
    repeated = re.search(r"(.{40,}?)\1{4,}", cleaned, flags=re.DOTALL)
    if repeated is not None:
        cutoff = repeated.start() + len(repeated.group(1))
        cleaned = cleaned[:cutoff].strip()
    # 4. Nothing usable left: signal failure to the caller.
    if len(cleaned.split()) < 3:
        raise ValueError("Model output cleanup failure")
    return cleaned


def call_model(messages: list) -> str:
    """Sample one reply for ``messages`` and return it as cleaned text.

    Any failure (client setup, sampling, parsing or cleanup) is logged and
    turned into a ``"[Model error: ...]"`` string instead of being raised.
    """
    try:
        from tinker_cookbook import renderers as tinker_renderers

        sampling_client, renderer, tinker_types = get_tinker_clients()
        prompt = renderer.build_generation_prompt(messages)
        sampling_params = tinker_types.SamplingParams(
            max_tokens=1000,
            temperature=0.7,
            stop=renderer.get_stop_sequences(),
        )
        future = sampling_client.sample(
            prompt=prompt,
            sampling_params=sampling_params,
            num_samples=1,
        )
        result = future.result()
        parsed_message, _ = renderer.parse_response(result.sequences[0].tokens)
        raw_text = tinker_renderers.format_content_as_string(parsed_message["content"])
        return _clean_model_output(raw_text)
    except Exception as e:
        print(f"[MODEL] Tinker error: {e}")
        return f"[Model error: {e}]"
| # --------------------------------------------------------------------------- | |
| # HuggingFace upload | |
| # --------------------------------------------------------------------------- | |
def get_hf_api():
    """Return an ``HfApi`` client, creating the results repo if it is missing.

    Without ``HF_TOKEN`` an anonymous client is returned and the Hub is not
    contacted. With a token, the dataset repo ``DATASET_REPO_ID`` is probed
    and, if the probe reports it missing, created as a private dataset.
    All Hub failures are logged as warnings; this function does not raise for
    them, so callers can still save results locally.
    """
    api = HfApi(token=HF_TOKEN) if HF_TOKEN else HfApi()
    if not HF_TOKEN:
        return api

    try:
        api.repo_info(repo_id=DATASET_REPO_ID, repo_type="dataset")
        print(f"[HF] Repo {DATASET_REPO_ID} exists.")
    except Exception as e:
        message = str(e)
        if "404" in message or "not found" in message.lower():
            try:
                # exist_ok=True tolerates another session creating the repo
                # between the probe above and this call.
                api.create_repo(
                    repo_id=DATASET_REPO_ID,
                    repo_type="dataset",
                    private=True,
                    exist_ok=True,
                )
                print(f"[HF] Created repo {DATASET_REPO_ID}.")
            except Exception as create_err:
                # Keep this best-effort like the branch below.
                print(f"[HF] WARNING: could not create repo {DATASET_REPO_ID}: {create_err}")
        else:
            print(f"[HF] WARNING: {e}")
    return api
def save_and_upload(state: dict):
    """Persist a finished session locally, then upload it to the Hub.

    The session is first written as JSON under
    ``ANNOTATIONS_DIR/<worker>/<submission_id>_preference.json``. This happens
    before any Hub call so that a network or authentication problem cannot
    lose the participant's data. When ``HF_TOKEN`` is set the JSON file is
    then uploaded (errors are logged, not raised), and finally the per-pair
    CSV export is produced via ``upload_csv_rows``.

    Args:
        state: Session state dict; ``prolific_pid`` / ``user_id`` select the
            worker folder and ``submission_id`` the file name.
    """
    worker_id = state.get("prolific_pid") or state.get("user_id", "anonymous")
    submission_id = state.get("submission_id", str(uuid.uuid4()))
    # Anything that is not alphanumeric is replaced so the id is a safe
    # folder name locally and in the repo.
    safe_worker = "".join(c if c.isalnum() else "_" for c in str(worker_id))
    filename = f"{submission_id}_preference.json"
    folder = os.path.join(ANNOTATIONS_DIR, safe_worker)
    os.makedirs(folder, exist_ok=True)
    file_path = os.path.join(folder, filename)
    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(state, f, indent=2)
    print(f"[SAVE] Wrote {file_path}")

    # Only now touch the network: the local copy already exists.
    hf_api = get_hf_api()
    if HF_TOKEN:
        try:
            hf_api.upload_file(
                path_or_fileobj=file_path,
                path_in_repo=f"{safe_worker}/{filename}",
                repo_id=DATASET_REPO_ID,
                repo_type="dataset",
            )
            print("[HF] Uploaded JSON.")
        except Exception as e:
            print(f"[HF] JSON upload error: {e}")
    upload_csv_rows(state, hf_api, safe_worker, submission_id)
# Column groups shared by the CSV header and the per-pair row builder, so the
# two cannot drift apart.
_CSV_DEMOGRAPHIC_FIELDS = (
    "age", "gender", "geographic_region", "education_level", "race",
    "us_citizen", "marital_status", "religion", "religious_attendance",
    "political_affiliation", "income", "political_views", "household_size",
    "employment_status",
)
_CSV_BACKGROUND_FIELDS = (
    "movies_criteria", "movies_enjoy", "movies_avoid",
    "groceries_criteria", "groceries_enjoy", "groceries_avoid",
)
_CSV_HEADER = [
    "submission_id", "prolific_pid", "study_id", "session_id",
    "submission_time", "duration_seconds", "study_type", "category",
    # demographics
    *_CSV_DEMOGRAPHIC_FIELDS,
    # preferences background
    *_CSV_BACKGROUND_FIELDS,
    # pair info
    "pair_index", "pair_id",
    "product_a_id", "product_a_title", "product_a_price", "familiarity_a",
    "product_b_id", "product_b_title", "product_b_price", "familiarity_b",
    # preference
    "pre_preference", "pre_preference_label",
    "post_preference", "post_preference_label",
    "preference_delta", "persuasion_target",
    # conversation
    "num_turns", "conversation_json",
    # reflection
    "standout_moment", "thinking_change",
]


def _csv_row_for_pair(state: dict, submission_id: str, pair_index: int, pair: dict) -> list:
    """Build one CSV row (in ``_CSV_HEADER`` order) for a single pair."""
    demographics = state.get("demographics", {})
    background = state.get("preferences_background", {})
    meta = state.get("meta", {})
    conv = pair.get("conversation", {})
    refl = pair.get("reflection", {})
    product_a = pair.get("product_a", {})
    product_b = pair.get("product_b", {})
    pre = pair.get("pre_preference", "")
    post = pair.get("post_preference", "")
    # The delta is only meaningful when both ratings were recorded.
    delta = (post - pre) if isinstance(pre, int) and isinstance(post, int) else ""
    return [
        submission_id,
        state.get("prolific_pid", ""),
        state.get("study_id", ""),
        state.get("session_id", ""),
        meta.get("submission_time", ""),
        meta.get("duration_seconds", ""),
        "preference",
        pair.get("category", ""),
        *(demographics.get(field, "") for field in _CSV_DEMOGRAPHIC_FIELDS),
        *(background.get(field, "") for field in _CSV_BACKGROUND_FIELDS),
        pair_index, pair.get("pair_id", ""),
        product_a.get("id", ""), product_a.get("title", ""),
        product_a.get("price", ""), pair.get("familiarity_a", ""),
        product_b.get("id", ""), product_b.get("title", ""),
        product_b.get("price", ""), pair.get("familiarity_b", ""),
        pre, PREFERENCE_LABELS.get(pre, "") if isinstance(pre, int) else "",
        post, PREFERENCE_LABELS.get(post, "") if isinstance(post, int) else "",
        delta, pair.get("persuasion_target", ""),
        conv.get("num_turns", 0), json.dumps(conv.get("turns", [])),
        refl.get("standout_moment", ""), refl.get("thinking_change", ""),
    ]


def upload_csv_rows(state: dict, hf_api, safe_worker: str, submission_id: str):
    """Export one CSV row per pair and upload the file to the dataset repo.

    The rows are written to a temporary CSV file which is uploaded as
    ``csv_submissions/<timestamp>_<worker>_<random>.csv`` when ``HF_TOKEN`` is
    set. Upload errors are logged, not raised. The temporary file is removed
    in all cases, including when writing the CSV itself fails.

    Args:
        state: Session state dict holding ``pairs``, ``demographics``,
            ``preferences_background`` and ``meta``.
        hf_api: Client exposing ``upload_file``; only used when ``HF_TOKEN``
            is set.
        safe_worker: Sanitised worker id used in the remote file name.
        submission_id: Identifier written into every row.
    """
    rows = [
        _csv_row_for_pair(state, submission_id, position, pair)
        for position, pair in enumerate(state.get("pairs", []), start=1)
    ]
    timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    unique_id = uuid.uuid4().hex[:8]
    csv_filename = f"csv_submissions/{timestamp_str}_{safe_worker}_{unique_id}.csv"

    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False, newline="",
                                         encoding="utf-8") as tmp:
            tmp_path = tmp.name
            writer = csv.writer(tmp)
            writer.writerow(_CSV_HEADER)
            writer.writerows(rows)
        if HF_TOKEN:
            try:
                hf_api.upload_file(
                    path_or_fileobj=tmp_path,
                    path_in_repo=csv_filename,
                    repo_id=DATASET_REPO_ID,
                    repo_type="dataset",
                )
                print("[HF] Uploaded CSV rows.")
            except Exception as e:
                print(f"[HF] CSV upload error: {e}")
    finally:
        # delete=False means nothing else removes the file for us.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)
| # --------------------------------------------------------------------------- | |
| # Prompt builders | |
| # --------------------------------------------------------------------------- | |
| def _product_info_block(product: dict, label: str) -> str: | |
| """Format one product's info for the system prompt.""" | |
| title = product.get("title", "Unknown") | |
| price = product.get("price", "N/A") | |
| description = product.get("description", []) | |
| features = product.get("features", []) | |
| desc_text = " ".join(description) if isinstance(description, list) else str(description) | |
| desc_text = desc_text | |
| feat_text = " ".join(features) if features else "" | |
| price_str = f"${price}" if price and price != "N/A" and not str(price).startswith("$") else price | |
| block = f"""{label}: | |
| Name: {title} | |
| Price: {price_str} | |
| Description: {desc_text}""" | |
| if feat_text: | |
| block += f"\n Features: {feat_text}" | |
| return block | |
def build_persuasion_system_prompt(pair: dict, persuasion_target: str, pre_preference: int) -> str:
    """Build the system prompt instructing the AI to argue for one product.

    The user's own preference is deliberately left out of this prompt; it is
    supplied as the first user message of the conversation instead.

    Args:
        pair: Pair dict with ``product_a`` and ``product_b`` entries.
        persuasion_target: ``"A"`` or ``"B"`` — the product to advocate.
        pre_preference: The user's 1-7 rating (4 = neutral). Accepted for
            logging purposes only; it does not appear in the prompt.

    Raises:
        KeyError: If ``pair`` lacks ``product_a`` or ``product_b``.
    """
    block_a = _product_info_block(pair["product_a"], "Product A")
    block_b = _product_info_block(pair["product_b"], "Product B")
    prompt = f"""You are a warm, knowledgeable sales agent for an Amazon product. A user is comparing two products. Your goal is to persuade them to buy Product {persuasion_target}.
{block_a}
{block_b}
You are advocating for Product {persuasion_target}.
The conversation will begin with you asking the user which product they'd prefer, and the user will state their preference. After that, begin persuading them toward Product {persuasion_target}.
IMPORTANT: Even if the user says they are neutral or undecided, do NOT ask them to clarify — immediately begin persuading them toward Product {persuasion_target}.
First message rules (your reply after the user states their preference):
- In ONE paragraph: briefly highlight Product {persuasion_target}'s best quality, explain why it's worth buying, and hit them with the strongest benefit
- End with an engaging question that draws out their interest or hesitation
Follow-up message rules:
- In ONE paragraph: acknowledge what they said, address any concerns directly with a concrete benefit or reassurance, end with a question
- Use their words against hesitation: if they say it's expensive, talk value; if they doubt quality, cite a feature
- Vary your tactics: sometimes appeal to emotion (convenience, joy), sometimes to reason (value, quality)
- Use "imagine if..." scenarios to make benefits concrete
General style:
- Be warm, confident, and conversational — like a helpful friend who knows the product well, not a pushy salesperson
- End your messages with an engaging question
- Never fabricate statistics, details, or reviews you don't have
- Never make up a price different from the one given
"""
    return prompt
def build_preference_statement(pre_preference: int) -> str:
    """Return the sentence that states the user's rating as the first user turn.

    Raises:
        KeyError: If ``pre_preference`` is not a key of ``PREFERENCE_LABELS``.
    """
    return f"I'd say: {PREFERENCE_LABELS[pre_preference]}."
# Hardcoded opening question from the AI (not generated by the model).
# The system prompt from build_persuasion_system_prompt() tells the model that
# the conversation begins with it asking which product the user would prefer.
OPENING_AI_QUESTION = "Which of these two products would you prefer to buy?"
def parse_preference(choice_str: str) -> int:
    """Extract the numeric score from a choice such as ``"Neutral (4)"``.

    The score is taken from the last parenthesised group, so labels that
    contain parentheses themselves, or trailing whitespace, still parse.

    Args:
        choice_str: A display string ending in ``"(<score>)"``.

    Returns:
        The parsed score, or 4 (neutral) when no score can be extracted.
    """
    try:
        tail = choice_str.rsplit("(", 1)[1]
        return int(tail.strip().rstrip(")"))
    except (AttributeError, IndexError, TypeError, ValueError):
        # Not a string, no "(" present, or no integer inside the parentheses.
        return 4
def get_familiarity_choices(category: str) -> list:
    """Return the four familiarity options for a product of ``category``.

    The third option is category-specific (looked up in
    ``FAMILIARITY_USED_LABEL``) and falls back to ``"Used it before"``.
    """
    choices = ["Never heard of it", "Heard of it, but not used/purchased"]
    choices.append(FAMILIARITY_USED_LABEL.get(category, "Used it before"))
    choices.append("Purchased it before")
    return choices
def determine_persuasion_target(pre_preference: int) -> str:
    """Pick the product the AI should argue for: the one the user did not lean to.

    - Ratings below 4 (leaning A) give ``"B"``.
    - Ratings above 4 (leaning B) give ``"A"``.
    - Anything else (neutral) gives a random choice of ``"A"`` or ``"B"``.
    """
    if pre_preference < 4:
        return "B"
    if pre_preference > 4:
        return "A"
    return random.choice(["A", "B"])
| # --------------------------------------------------------------------------- | |
| # State initialisation | |
| # --------------------------------------------------------------------------- | |
def make_pair_slot(pair_data: dict) -> dict:
    """Create the empty per-pair record for one assigned pair.

    Identity fields are copied from ``pair_data`` (a fresh UUID is used when
    ``pair_id`` is absent); every response field starts out unset.
    """
    empty_conversation = {
        "system_prompt": "",
        "opening_user_message": "",
        "turns": [],
        "num_turns": 0,
    }
    slot = {
        "pair_id": pair_data.get("pair_id", str(uuid.uuid4())),
        "category": pair_data.get("category", ""),
        "product_a": pair_data.get("product_a", {}),
        "product_b": pair_data.get("product_b", {}),
    }
    # Answers collected later in the flow.
    for pending_field in (
        "familiarity_a",
        "familiarity_b",
        "pre_preference",
        "post_preference",
        "preference_delta",
        "persuasion_target",
    ):
        slot[pending_field] = None
    slot["conversation"] = empty_conversation
    slot["reflection"] = {}
    return slot
def init_state():
    """Create the session state for a new participant.

    Ensures the pair datasets are cached, assigns ``PAIRS_PER_USER`` pairs
    (which advances the assignment counters) and reads the Prolific
    identifiers from the URL query parameters, defaulting each to ``""``.
    """
    _ensure_datasets()
    assigned = assign_pairs(PAIRS_PER_USER)
    try:
        params = st.query_params
    except Exception:
        # Fall back to "no parameters" if the query string is unavailable.
        params = {}

    state = {
        "submission_id": str(uuid.uuid4()),
        "user_id": str(uuid.uuid4()),
    }
    for state_key, param_name in (
        ("prolific_pid", "PROLIFIC_PID"),
        ("study_id", "STUDY_ID"),
        ("session_id", "SESSION_ID"),
    ):
        state[state_key] = params.get(param_name, "")
    state.update({
        "start_time": time.time(),
        "study_type": "preference",
        "demographics": {},
        "preferences_background": {},
        "pairs": [make_pair_slot(p) for p in assigned],
        "current_pair_index": 0,
        "screen": "welcome",
        "meta": {},
    })
    return state
| # --------------------------------------------------------------------------- | |
| # CSS | |
| # --------------------------------------------------------------------------- | |
def inject_css():
    """Inject the app-wide stylesheet into the page via ``st.markdown``."""
    stylesheet = """
<style>
#MainMenu, footer, header { visibility: hidden; }
.block-container { max-width: 860px; padding-top: 2rem; }
.product-card {
border-radius: 10px;
padding: 1rem 1.25rem;
margin-bottom: 0.75rem;
}
.product-card-a {
border: 2px solid #2563eb;
background: #eff6ff;
}
.product-card-b {
border: 2px solid #9333ea;
background: #faf5ff;
}
.pc-header {
display: flex;
justify-content: space-between;
align-items: flex-start;
margin-bottom: 0.6rem;
gap: 1rem;
}
.pc-title { font-size: 1.05rem; font-weight: 700; color: #1a1a2e; line-height: 1.35; flex: 1; }
.pc-price { font-size: 1.2rem; font-weight: 800; white-space: nowrap; }
.pc-price-a { color: #16a34a; }
.pc-price-b { color: #16a34a; }
.pc-label {
display: inline-block;
font-size: 0.8rem; font-weight: 700;
padding: 0.2rem 0.6rem;
border-radius: 99px;
margin-bottom: 0.4rem;
}
.pc-label-a { background: #dbeafe; color: #1e40af; }
.pc-label-b { background: #ede9fe; color: #6b21a8; }
.pc-category-badge {
display: inline-block;
font-size: 0.7rem; font-weight: 600;
padding: 0.12rem 0.5rem;
border-radius: 99px;
margin-left: 0.4rem;
background: #f1f5f9; color: #475569;
}
.pc-section { margin-top: 0.5rem; }
.pc-section-title {
font-weight: 600; font-size: 0.85rem; color: #475569;
text-transform: uppercase; letter-spacing: 0.04em; margin-bottom: 0.3rem;
}
.pc-desc { font-size: 0.92rem; color: #334155; line-height: 1.6; }
.pc-list { margin: 0; padding-left: 1.2rem; font-size: 0.92rem; color: #334155; line-height: 1.5; }
.pc-list li { margin-bottom: 0.25rem; }
.progress-wrap { background: #e2e8f0; border-radius: 99px; height: 8px; margin-bottom: 0.25rem; overflow: hidden; }
.progress-fill { background: #2563eb; height: 100%; border-radius: 99px; }
.progress-label { font-size: 0.82rem; color: #64748b; text-align: right; margin-bottom: 1rem; }
.chat-wrap { max-height: 420px; overflow-y: auto; margin-bottom: 1rem; }
.bubble { padding: 0.65rem 0.9rem; border-radius: 12px; margin-bottom: 0.5rem; font-size: 0.93rem; line-height: 1.5; }
.bubble-ai { background: #eff6ff; border: 1px solid #93c5fd; margin-right: 10%; }
.bubble-user { background: #f0fdf4; border: 1px solid #86efac; margin-left: 10%; text-align: right; }
.bubble-label { font-size: 0.75rem; color: #94a3b8; margin-bottom: 0.2rem; }
.vs-divider {
text-align: center; font-size: 1.4rem; font-weight: 800;
color: #94a3b8; margin: 0.3rem 0;
}
.section-divider {
border: none;
border-top: 2px solid #e2e8f0;
margin: 1.5rem 0 1rem 0;
}
.section-heading {
font-size: 1rem; font-weight: 700; color: #1e40af;
margin-bottom: 0.5rem;
}
.section-heading-grocery {
font-size: 1rem; font-weight: 700; color: #16a34a;
margin-bottom: 0.5rem;
}
</style>
"""
    # Raw HTML must be explicitly allowed for the <style> tag to take effect.
    st.markdown(stylesheet, unsafe_allow_html=True)
| # --------------------------------------------------------------------------- | |
| # HTML escaping | |
| # --------------------------------------------------------------------------- | |
| def _safe(text: str) -> str: | |
| unescaped = html_lib.unescape(str(text)) | |
| unescaped = re.sub(r'([.!?:])([A-Z])', r'\1 \2', unescaped) | |
| escaped = html_lib.escape(unescaped) | |
| for ch in ['*', '_', '~', '`', '[', ']']: | |
| escaped = escaped.replace(ch, f'&#{ord(ch)};') | |
| escaped = escaped.replace('\n', ' ') | |
| return escaped | |
| # --------------------------------------------------------------------------- | |
| # UI helpers | |
| # --------------------------------------------------------------------------- | |
def render_single_product_card_html(product: dict, label: str, compact: bool = False) -> str:
    """Render one product card with an A/B label.

    Args:
        product: Product dict; ``title``, ``price``, ``description``,
            ``features`` and ``category`` are read, all optional.
            ``description`` and ``features`` may each be a list or a single
            string.
        label: ``"A"`` or ``"B"``; selects the colour scheme and the heading.
        compact: When true the card is limited to 220px and scrolls.

    Returns:
        The card's HTML. All product text is passed through ``_safe``; a
        missing price is shown as ``N/A``.
    """
    title = _safe(product.get("title", "Unknown Product"))
    price = product.get("price", "N/A")
    description = product.get("description", [])
    features = product.get("features", [])
    category = product.get("category", "")

    if price is None or price == "":
        # Avoid showing the literal text "None" to participants.
        price_str = "N/A"
    elif price == "N/A" or str(price).startswith("$"):
        price_str = _safe(str(price))
    else:
        price_str = f"${_safe(str(price))}"

    side = "a" if label == "A" else "b"

    cat_badge = ""
    if category:
        cat_label = _safe(CATEGORY_DISPLAY.get(category, category))
        cat_badge = f'<span class="pc-category-badge">{cat_label}</span>'

    desc_html = ""
    if description:
        if isinstance(description, list):
            desc_text = " ".join(str(d) for d in description if d)
        else:
            desc_text = str(description)
        desc_html = f'<div class="pc-section"><div class="pc-section-title">Description</div><div class="pc-desc">{_safe(desc_text)}</div></div>'

    feat_html = ""
    if features:
        # A bare string is one feature, not an iterable of characters.
        feature_items = [features] if isinstance(features, str) else features
        items_html = "".join(f"<li>{_safe(feat)}</li>" for feat in feature_items if feat)
        if items_html:
            feat_html = f'<div class="pc-section"><div class="pc-section-title">Features</div><ul class="pc-list">{items_html}</ul></div>'

    max_h = "max-height:220px;overflow-y:auto;" if compact else ""
    return f"""
<div class="product-card product-card-{side}" style="{max_h}">
<div class="pc-label pc-label-{side}">Product {label}{cat_badge}</div>
<div class="pc-header">
<div class="pc-title">{title}</div>
<div class="pc-price pc-price-{side}">{price_str}</div>
</div>
{desc_html}
{feat_html}
</div>"""
def render_pair_cards_html(pair: dict, compact: bool = False) -> str:
    """Return the markup for both product cards of *pair* with a "VS" divider between them.

    ``pair["product_a"]`` is rendered with label "A" and ``pair["product_b"]``
    with label "B"; *compact* is forwarded to each card.
    """
    cards = [
        render_single_product_card_html(pair[key], label, compact=compact)
        for key, label in (("product_a", "A"), ("product_b", "B"))
    ]
    return '<div class="vs-divider">— VS —</div>'.join(cards)
def render_progress(current: int, total: int = PAIRS_PER_USER):
    """Draw the progress bar and the "Pair X of Y" caption.

    Args:
        current: 1-based number of the pair being shown.
        total: Number of pairs in the study.

    The fill width is ``current / total`` as a whole percentage. A
    non-positive *total* yields 0% instead of raising ``ZeroDivisionError``,
    and the width is clamped to the 0-100 range.
    """
    pct = int((current / total) * 100) if total > 0 else 0
    pct = max(0, min(100, pct))
    st.markdown(f"""
<div class="progress-wrap"><div class="progress-fill" style="width:{pct}%"></div></div>
<div class="progress-label">Pair {current} of {total}</div>
""", unsafe_allow_html=True)
def render_chat_history(turns: list):
    """Render *turns* as chat bubbles inside one ``chat-wrap`` container.

    Each turn is a mapping with ``role`` and ``content``. Assistant turns get
    the AI bubble, user turns the right-aligned user bubble; turns with any
    other role are not rendered. Content is passed through ``_safe``.
    """
    pieces = ['<div class="chat-wrap">']
    for entry in turns:
        speaker = entry.get("role", "")
        content = _safe(entry.get("content", ""))
        if speaker == "assistant":
            pieces.append(
                f'<div class="bubble-label">🤖 AI Product Agent</div><div class="bubble bubble-ai">{content}</div>'
            )
        elif speaker == "user":
            pieces.append(
                f'<div class="bubble-label" style="text-align:right">You</div><div class="bubble bubble-user">{content}</div>'
            )
    pieces.append("</div>")
    st.markdown("".join(pieces), unsafe_allow_html=True)
| # --------------------------------------------------------------------------- | |
| # Screen renderers | |
| # --------------------------------------------------------------------------- | |
def screen_welcome(s):
    """Render the landing screen and route the participant to the next screen.

    Args:
        s: The mutable study-state mapping for this session.

    When "Begin" is clicked, debug mode fills ``s["demographics"]`` and
    ``s["preferences_background"]`` from the debug fixtures and jumps straight
    to the first pair; otherwise the demographics screen is shown next.
    """
    st.markdown("# 🛒 Product Preference Study")
    # The pair count and minimum exchange count come from the study constants
    # so the instructions cannot drift from what the app enforces.
    st.markdown(
        f"Welcome! In this study you will compare **{PAIRS_PER_USER} pairs** of products "
        f"(**Movies & TV** and **Grocery Products**).\n\n"
        "For each pair you will:\n"
        "1. Review two similar products (Product A and Product B)\n"
        "2. Rate how familiar you are with each product\n"
        "3. Rate which product you'd prefer to buy on a 7-point scale\n"
        f"4. Chat with an AI about the products (**at least {MIN_TURNS} exchanges**)\n"
        "5. Rate your preference again\n"
        "6. Answer two brief reflection questions\n\n"
        f"After all {PAIRS_PER_USER} pairs, you're done! The study takes about **30-40 minutes**. "
        "Thank you for participating!"
    )
    if st.button("Begin →", type="primary", use_container_width=True):
        if DEBUG_MODE:
            s["demographics"] = DEBUG_DEMOGRAPHICS.copy()
            s["preferences_background"] = DEBUG_BACKGROUND.copy()
            s["screen"] = "pair_intro"
        else:
            s["screen"] = "demographics"
        st.rerun()
def screen_demographics(s):
    """Render the demographics form and store the answers in ``s["demographics"]``.

    Args:
        s: The mutable study-state mapping for this session.

    Every field is required. Age must be a whole number between 1 and 120.
    On success the screen advances to ``"preferences_background"``.
    """
    st.markdown("## Demographics — About You")
    st.markdown("All fields are required before you can proceed.")
    age = st.text_input("Age (years)", placeholder="e.g. 34")
    gender = st.selectbox("Gender", ["", "Female", "Male"])
    geographic_region = st.selectbox("Geographic region",
        ["", "West", "South", "Midwest", "Northeast", "Pacific"])
    education_level = st.selectbox("Highest education level", [
        "", "Less than high school", "High school graduate",
        "Some college, no degree", "Associate's degree",
        "College graduate/some postgrad", "Postgraduate",
    ])
    race = st.selectbox("Race / ethnicity", ["", "Asian", "Hispanic", "White", "Black", "Other"])
    us_citizen = st.selectbox("Are you a U.S. citizen?", ["", "Yes", "No"])
    marital_status = st.selectbox("Marital status", [
        "", "Never been married", "Married", "Living with a partner",
        "Divorced", "Separated", "Widowed",
    ])
    religion = st.selectbox("Religion", [
        "", "Protestant", "Roman Catholic", "Mormon", "Orthodox", "Jewish",
        "Muslim", "Buddhist", "Atheist", "Agnostic", "Nothing in particular", "Other",
    ])
    religious_attendance = st.selectbox("How often do you attend religious services?", [
        "", "Never", "Seldom", "A few times a year", "Once or twice a month",
        "Once a week", "More than once a week",
    ])
    political_affiliation = st.selectbox("Political affiliation", [
        "", "Democrat", "Republican", "Independent", "Something else",
    ])
    income = st.selectbox("Household income", [
        "", "Less than $30,000", "$30,000-$50,000", "$50,000-$75,000",
        "$75,000-$100,000", "$100,000 or more",
    ])
    political_views = st.selectbox("Political views", [
        "", "Very liberal", "Liberal", "Moderate", "Conservative", "Very conservative",
    ])
    household_size = st.selectbox("Household size", ["", "1", "2", "3", "4", "More than 4"])
    employment_status = st.selectbox("Employment status", [
        "", "Full-time employment", "Part-time employment", "Self-employed",
        "Unemployed", "Retired", "Home-maker", "Student",
    ])
    if st.button("Next →", type="primary", use_container_width=True):
        fields = [age, gender, geographic_region, education_level, race, us_citizen,
                  marital_status, religion, religious_attendance, political_affiliation,
                  income, political_views, household_size, employment_status]
        # Whitespace-only text counts as unanswered.
        if not all(f and (f.strip() if isinstance(f, str) else f) for f in fields):
            st.error("⚠️ Please complete all fields.")
            return
        age_text = age.strip()
        # isdecimal(), not isdigit(): isdigit() also accepts characters such
        # as "²" that int() cannot parse, which would raise ValueError here.
        if not age_text.isdecimal() or not (1 <= int(age_text) <= 120):
            st.error("⚠️ Please enter a valid age.")
            return
        s["demographics"] = {
            "age": age_text, "gender": gender, "geographic_region": geographic_region,
            "education_level": education_level, "race": race, "us_citizen": us_citizen,
            "marital_status": marital_status, "religion": religion,
            "religious_attendance": religious_attendance,
            "political_affiliation": political_affiliation,
            "income": income, "political_views": political_views,
            "household_size": household_size, "employment_status": employment_status,
        }
        s["screen"] = "preferences_background"
        st.rerun()
def screen_preferences_background(s):
    """Collect free-text background answers before the product pairs start.

    The first three entries of ``BACKGROUND_QUESTIONS`` are shown under the
    Movies & TV heading and the remaining ones under Grocery Products. Each
    answer must be non-empty and contain at least ``MIN_WORDS_BACKGROUND``
    words. Valid answers are stored, stripped, in
    ``s["preferences_background"]`` and the screen advances to ``"pair_intro"``.
    """
    st.markdown("## Your Preferences — Before We Start")
    st.markdown(
        "Before you begin evaluating products, we'd like to understand your general preferences. "
        f"Please write at least **{MIN_WORDS_BACKGROUND} words** for each question."
    )

    def ask(question):
        # One text area per question, keyed by the question's own key.
        return st.text_area(
            question["label"],
            placeholder=question["placeholder"],
            height=100,
            key=f"bg_{question['key']}",
        )

    # --- Movies section ---
    st.markdown('<div class="section-heading">🎬 Movies & TV</div>', unsafe_allow_html=True)
    answers = {}
    for question in BACKGROUND_QUESTIONS[:3]:
        answers[question["key"]] = ask(question)

    # --- Groceries section ---
    st.markdown('<hr class="section-divider">', unsafe_allow_html=True)
    st.markdown('<div class="section-heading-grocery">🛒 Grocery Products</div>', unsafe_allow_html=True)
    for question in BACKGROUND_QUESTIONS[3:]:
        answers[question["key"]] = ask(question)

    if not st.button("Next →", type="primary", use_container_width=True):
        return

    # Validate in question order and stop at the first problem.
    cleaned = {}
    for question in BACKGROUND_QUESTIONS:
        text = (answers.get(question["key"]) or "").strip()
        if not text:
            st.error(f"⚠️ Please answer: *{question['label']}*")
            return
        n_words = len(text.split())
        if n_words < MIN_WORDS_BACKGROUND:
            st.error(
                f"⚠️ Please write at least {MIN_WORDS_BACKGROUND} words for: "
                f"*{question['label']}* ({n_words} so far)."
            )
            return
        cleaned[question["key"]] = text

    s["preferences_background"] = cleaned
    s["screen"] = "pair_intro"
    st.rerun()
def screen_pair_intro(s):
    """Show the current pair, collect the pre-chat ratings, and open the chat.

    The participant rates familiarity with each product and an initial
    preference. On "Start Chat" the ratings are stored on the pair record,
    the persuasion target and system prompt are derived from the preference,
    and the model is called once to produce the first AI reply. In debug
    mode unanswered questions fall back to default choices.
    """
    idx = s["current_pair_index"]
    pair = s["pairs"][idx]
    product_a, product_b = pair["product_a"], pair["product_b"]
    pair_category = pair.get("category", "")
    key_suffix = f"{idx}_{pair['pair_id']}"

    render_progress(idx + 1)
    st.markdown("## Product Comparison")
    st.markdown("Please read both products carefully, then answer the questions below.")

    # Show both product cards
    st.markdown(render_pair_cards_html(pair), unsafe_allow_html=True)

    # Familiarity for Product A
    st.markdown("---")
    fam_choices_a = get_familiarity_choices(product_a.get("category", pair_category))
    familiarity_a = st.radio(
        f"How familiar are you with **Product A** (*{product_a.get('title', '')[:60]}*)?",
        fam_choices_a,
        index=None,
        key=f"fam_a_{key_suffix}",
    )

    # Familiarity for Product B
    fam_choices_b = get_familiarity_choices(product_b.get("category", pair_category))
    familiarity_b = st.radio(
        f"How familiar are you with **Product B** (*{product_b.get('title', '')[:60]}*)?",
        fam_choices_b,
        index=None,
        key=f"fam_b_{key_suffix}",
    )

    # Initial preference
    st.markdown("---")
    pre_pref_val = st.radio(
        "Which product would you prefer to buy?",
        PREFERENCE_CHOICES,
        index=None,
        key=f"pre_pref_{key_suffix}",
    )

    if not st.button("Start Chat →", type="primary", use_container_width=True):
        return

    if not DEBUG_MODE:
        if not familiarity_a:
            st.error("⚠️ Please rate your familiarity with Product A.")
            return
        if not familiarity_b:
            st.error("⚠️ Please rate your familiarity with Product B.")
            return
        if not pre_pref_val:
            st.error("⚠️ Please rate your preference.")
            return

    # Only reachable with blanks in debug mode: fall back to default choices.
    familiarity_a = familiarity_a or fam_choices_a[0]
    familiarity_b = familiarity_b or fam_choices_b[0]
    pre_pref_val = pre_pref_val or PREFERENCE_CHOICES[3]

    pre_val = parse_preference(pre_pref_val)
    persuasion_target = determine_persuasion_target(pre_val)

    pair["familiarity_a"] = familiarity_a
    pair["familiarity_b"] = familiarity_b
    pair["pre_preference"] = pre_val
    pair["pre_preference_label"] = PREFERENCE_LABELS[pre_val]
    pair["persuasion_target"] = persuasion_target

    system_prompt = build_persuasion_system_prompt(pair, persuasion_target, pre_val)
    preference_statement = build_preference_statement(pre_val)

    # Build the conversation: AI asks → user states preference → model generates persuasion
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "assistant", "content": OPENING_AI_QUESTION},
        {"role": "user", "content": preference_statement},
    ]
    with st.spinner("Starting conversation…"):
        ai_reply = call_model(messages)

    conversation = pair["conversation"]
    conversation["system_prompt"] = system_prompt
    conversation["opening_user_message"] = ""  # no longer used
    # The first two turns are scripted (flagged synthetic); the third is the model's reply.
    conversation["turns"] = [
        {"turn_index": 0, "role": "assistant", "content": OPENING_AI_QUESTION,
         "timestamp": time.time(), "synthetic": True},
        {"turn_index": 1, "role": "user", "content": preference_statement,
         "timestamp": time.time(), "synthetic": True},
        {"turn_index": 2, "role": "assistant", "content": ai_reply,
         "timestamp": time.time(), "model": MODEL_NAME},
    ]
    conversation["num_turns"] = 0
    s["screen"] = "chat"
    st.rerun()
def screen_chat(s):
    """Render the chat screen for the current pair and handle one "Send" click.

    Args:
        s: The mutable study-state mapping for this session.

    The stored transcript is displayed. While fewer than ``MAX_TURNS``
    participant messages have been sent, an input box is offered. A sent
    message must be non-empty and, outside debug mode, at least 5 words. It
    is forwarded to the model with the system prompt and the full transcript,
    and both the message and the reply are appended to the conversation.
    "I'm done chatting" is enabled once ``MIN_TURNS`` (or ``MAX_TURNS``)
    exchanges are reached, or always in debug mode.
    """
    idx = s["current_pair_index"]
    pair = s["pairs"][idx]
    conv = s["pairs"][idx]["conversation"]
    render_progress(idx + 1)
    st.markdown("## Chat with the AI")
    with st.expander("📦 Click to expand product details"):
        st.markdown(render_pair_cards_html(pair, compact=True), unsafe_allow_html=True)
    num_turns = conv["num_turns"]
    st.markdown(
        "Chat with the AI about which product you'd prefer. "
        "Ask questions, push back, or explore your thinking. "
        f"You need at least **{MIN_TURNS} exchanges** before you can move on."
    )
    display_turns = [t for t in conv["turns"] if t["role"] in ("user", "assistant")]
    render_chat_history(display_turns)
    if num_turns >= MAX_TURNS:
        st.info(f"Maximum turns ({MAX_TURNS}) reached. Please proceed.")
    else:
        st.caption(f"Turns: {num_turns} / minimum {MIN_TURNS}")
    # NOTE(review): shown on every render here; confirm it was not meant to
    # sit under the else branch above.
    st.caption("💡 If you don't see the latest messages, scroll down while hovering over the conversation.")
    if num_turns < MAX_TURNS:
        # Keying the widget on num_turns gives a fresh, empty box after each send.
        user_msg = st.text_area(
            "Your response:",
            placeholder="Type your response here…",
            height=100,
            key=f"chat_input_{idx}_{num_turns}",
        )
        _, send_col = st.columns([3, 1])
        with send_col:
            send_clicked = st.button("Send", type="primary", use_container_width=True)
        if send_clicked:
            user_msg = (user_msg or "").strip()
            if not user_msg:
                st.error("⚠️ Please type a message.")
                return
            word_count = len(user_msg.split())
            if word_count < 5 and not DEBUG_MODE:
                st.error(f"⚠️ Please write at least 5 words ({word_count} so far).")
                return
            messages = [{"role": "system", "content": conv["system_prompt"]}]
            messages.extend(
                {"role": turn["role"], "content": turn["content"]}
                for turn in conv["turns"]
            )
            messages.append({"role": "user", "content": user_msg})
            with st.spinner("AI is responding…"):
                ai_reply = call_model(messages)
            # Turns are recorded only after the model call has returned.
            conv["turns"].append({"turn_index": len(conv["turns"]), "role": "user",
                                  "content": user_msg, "timestamp": time.time()})
            conv["turns"].append({"turn_index": len(conv["turns"]), "role": "assistant",
                                  "content": ai_reply, "timestamp": time.time(),
                                  "model": MODEL_NAME})
            conv["num_turns"] = num_turns + 1
            s["pairs"][idx]["conversation"] = conv
            st.rerun()
    can_finish = num_turns >= MIN_TURNS or num_turns >= MAX_TURNS or DEBUG_MODE
    if can_finish:
        if st.button("I'm done chatting →", use_container_width=True):
            s["screen"] = "post_pref"
            st.rerun()
    else:
        st.button("I'm done chatting →", disabled=True, use_container_width=True,
                  help=f"Complete at least {MIN_TURNS} exchanges first.")
def screen_post_preference(s):
    """Collect the post-chat preference for the current pair.

    The selected rating, its label, and its difference from the pre-chat
    rating are stored on the pair record, then the screen advances to
    ``"reflection"``. In debug mode an unanswered question falls back to
    ``PREFERENCE_CHOICES[3]``.
    """
    idx = s["current_pair_index"]
    pair = s["pairs"][idx]

    render_progress(idx + 1)
    st.markdown("## Your Preference Now")
    st.markdown("Now that you've chatted with the AI, rate your preference again.")
    st.markdown(render_pair_cards_html(pair), unsafe_allow_html=True)

    post_pref_val = st.radio(
        "Which product would you prefer to buy now?",
        PREFERENCE_CHOICES,
        index=None,
        key=f"post_pref_{idx}_{pair['pair_id']}",
    )

    if not st.button("Next →", type="primary", use_container_width=True):
        return

    if not (post_pref_val or DEBUG_MODE):
        st.error("⚠️ Please rate your preference.")
        return

    chosen = post_pref_val if post_pref_val else PREFERENCE_CHOICES[3]
    post_val = parse_preference(chosen)
    # Falls back to 4 if no pre-chat rating was recorded for this pair.
    pre_val = pair.get("pre_preference", 4)

    pair["post_preference"] = post_val
    pair["post_preference_label"] = PREFERENCE_LABELS[post_val]
    pair["preference_delta"] = post_val - pre_val
    s["screen"] = "reflection"
    st.rerun()
def screen_reflection(s):
    """Collect the two reflection answers, then move to the next pair or finish.

    Args:
        s: The mutable study-state mapping for this session.

    Outside debug mode each answer must be non-empty and at least 10 words.
    The answers are stored on the pair record and ``current_pair_index`` is
    advanced. After the last pair, submission metadata is written to
    ``s["meta"]``, ``save_and_upload`` is called, and the ``"done"`` screen is
    shown; otherwise the next pair's intro screen follows.
    """
    idx = s["current_pair_index"]
    render_progress(idx + 1)
    st.markdown("## Reflection")
    standout = st.text_area(
        "What did the AI say that stood out to you most?",
        placeholder="Describe a specific argument, question, or moment from the conversation…",
        height=120,
        key=f"standout_{idx}",
    )
    thinking_change = st.text_area(
        "How did your thinking about these products change (or not change) during the chat? Why?",
        placeholder="Be as specific as you can…",
        height=120,
        key=f"thinking_{idx}",
    )
    next_label = "Next Pair →" if idx + 1 < PAIRS_PER_USER else "Submit Study →"
    if not st.button(next_label, type="primary", use_container_width=True):
        return

    if not DEBUG_MODE:
        # Same two checks for both questions, reported in question order.
        for answer, ordinal in ((standout, "first"), (thinking_change, "second")):
            text = (answer or "").strip()
            if not text:
                st.error(f"⚠️ Please answer the {ordinal} reflection question.")
                return
            n_words = len(text.split())
            if n_words < 10:
                st.error(
                    f"⚠️ Please write at least 10 words for the {ordinal} question "
                    f"({n_words} so far)."
                )
                return

    # Blank answers can only reach this point in debug mode.
    standout = (standout or "").strip() or "[debug placeholder]"
    thinking_change = (thinking_change or "").strip() or "[debug placeholder]"
    s["pairs"][idx]["reflection"] = {
        "standout_moment": standout,
        "thinking_change": thinking_change,
    }
    next_idx = idx + 1
    s["current_pair_index"] = next_idx
    if next_idx >= PAIRS_PER_USER:
        end_time = time.time()
        s["meta"] = {
            "submission_time": end_time,
            "duration_seconds": round(end_time - s.get("start_time", end_time), 1),
            "model": MODEL_NAME,
            "study_type": "preference",
        }
        with st.spinner("Saving your responses…"):
            try:
                save_and_upload(s)
            except Exception:
                # Roll the index back so a retry re-renders this pair's
                # reflection instead of indexing past the end of s["pairs"].
                s["current_pair_index"] = idx
                raise
        s["screen"] = "done"
    else:
        s["screen"] = "pair_intro"
    st.rerun()
def screen_done(s):
    """Show the completion screen: a per-pair summary table and the completion code.

    Args:
        s: The mutable study-state mapping for this session; ``s["pairs"]`` is
            read to build one table row per pair.
    """

    def short_title(product, limit=40):
        # Append the ellipsis only when the title was actually cut, and
        # tolerate a missing or None title.
        title = str((product or {}).get("title") or "")
        return title[:limit] + "…" if len(title) > limit else title

    st.markdown("## ✅ Study Complete!")
    st.markdown("**Thank you for completing the study!**")
    st.markdown(
        f"Here's a summary of how your preferences changed across the {PAIRS_PER_USER} pairs:"
    )
    rows = []
    for number, pair in enumerate(s["pairs"], start=1):
        pre = pair.get("pre_preference", "?")
        post = pair.get("post_preference", "?")
        delta = pair.get("preference_delta", 0)
        target = pair.get("persuasion_target", "?")
        # Only integer deltas get an arrow and signed value; anything else shows "–".
        if isinstance(delta, int):
            arrow = "➡️" if delta == 0 else ("⬆️" if delta > 0 else "⬇️")
            shift = f"{arrow} {delta:+d}"
        else:
            shift = "–"
        rows.append({
            "#": number,
            "Category": CATEGORY_DISPLAY.get(pair.get("category", ""), ""),
            "Product A": short_title(pair.get("product_a")),
            "Product B": short_title(pair.get("product_b")),
            "Before": PREFERENCE_LABELS.get(pre, str(pre)),
            "After": PREFERENCE_LABELS.get(post, str(post)),
            "AI argued for": f"Product {target}",
            "Shift": shift,
        })
    import pandas as pd
    st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)
    st.markdown("---")
    st.success(
        f"**Your completion code:** `{PROLIFIC_COMPLETION_CODE}`\n\n"
        "Please copy this code and paste it on the Prolific website to complete your submission."
    )
| # --------------------------------------------------------------------------- | |
| # Main | |
| # --------------------------------------------------------------------------- | |
def main():
    """Configure the page, ensure session state exists, and render the active screen.

    The active screen name is read from the study state (defaulting to
    ``"welcome"``). A name with no matching renderer draws nothing.
    """
    st.set_page_config(page_title="Product Preference Study", page_icon="🛒", layout="centered")
    inject_css()

    if "study_state" not in st.session_state:
        st.session_state.study_state = init_state()
    s = st.session_state.study_state

    # Screen name -> renderer; each renderer takes the study state.
    renderers = {
        "welcome": screen_welcome,
        "demographics": screen_demographics,
        "preferences_background": screen_preferences_background,
        "pair_intro": screen_pair_intro,
        "chat": screen_chat,
        "post_pref": screen_post_preference,
        "reflection": screen_reflection,
        "done": screen_done,
    }
    render = renderers.get(s.get("screen", "welcome"))
    if render is not None:
        render(s)


if __name__ == "__main__":
    main()