Spaces:
Sleeping
Sleeping
| """ | |
| Shared screens used by both preference and likelihood study types: | |
| welcome, demographics, background, chat, post_rating, reflection, done. | |
| The chat screen contains the main conversation loop and is identical for both | |
| study types β the only difference is which compact product view is shown in | |
| the expander and which Likert labels are used. | |
| """ | |
import logging
import time

import streamlit as st

from src.config import (
    BACKGROUND_QUESTIONS,
    CATEGORY_DISPLAY,
    LIKELIHOOD_LABELS,
    MIN_WORDS_BACKGROUND,
    MIN_WORDS_REFLECTION,
    PREFERENCE_LABELS,
)
from src.lsp_wrappers import format_demographics
from src.model import call_model
from src.ui.components import (
    familiarity_choices,
    parse_rating,
    rating_choices,
    render_chat_history,
    render_pair_cards,
    render_progress,
    render_single_card,
)
from src.upload import save_and_upload
# ── Debug defaults ─────────────────────────────────────────────────────────────
# Canned demographics stored when cfg["debug_mode"] lets the welcome screen skip
# the demographics form. Each value is one of the options screen_demographics
# offers (e.g. marital status "Never been married" — there is no "Single"
# option), so debug records have the same shape and vocabulary as real ones.
_DEBUG_DEMOGRAPHICS = {
    "age": "30", "gender": "Female", "geographic_region": "West",
    "education_level": "College graduate/some postgrad", "race": "White",
    "us_citizen": "Yes", "marital_status": "Never been married",
    "religion": "Agnostic", "religious_attendance": "Never",
    "political_affiliation": "Independent", "income": "$50,000-$75,000",
    "political_views": "Moderate", "household_size": "2",
    "employment_status": "Full-time employment",
}
def _debug_background(cfg: dict, questions: "dict | None" = None) -> dict:
    """Build placeholder background answers so debug runs can skip that screen.

    Args:
        cfg: Study config; only ``cfg["categories"][i]["name"]`` is read.
        questions: Optional mapping of category name -> list of question dicts
            (each with a ``"key"``). Defaults to ``BACKGROUND_QUESTIONS``.

    Returns:
        ``{question_key: placeholder_text}`` for every question of every
        configured category, in category order.
    """
    if questions is None:
        questions = BACKGROUND_QUESTIONS
    filler = " placeholder" * 6
    return {
        q["key"]: f"[debug — {q['key']}{filler}]"
        for cat in cfg["categories"]
        for q in questions[cat["name"]]
    }
# ── Welcome ────────────────────────────────────────────────────────────────────
def screen_welcome(s: dict, cfg: dict) -> None:
    """Render the study introduction and start the flow on "Begin".

    The instructions depend on ``cfg["study_type"]``: "preference" compares
    pairs of products, any other value evaluates single products. In debug
    mode the canned demographics and background answers are stored in ``s``
    and the flow jumps straight to ``item_intro``; otherwise it continues to
    the ``demographics`` screen.
    """
    n = cfg["pairs_per_user"]
    turns = cfg["min_turns"]
    cat_names = [CATEGORY_DISPLAY.get(c["name"], c["name"]) for c in cfg["categories"]]
    cat_str = " and ".join(f"**{c}**" for c in cat_names)
    chat_step = f"Chat with an AI product agent for **exactly {turns} exchanges**"

    if cfg["study_type"] == "preference":
        task = f"compare **{n} pairs** of products"
        unit = "pair"
        steps = [
            "Review two products (Product A and Product B)",
            "Rate your familiarity with each product",
            "Rate which product you would prefer to buy (1–7 scale)",
            chat_step,
            "Rate your preference again after the conversation",
            "Answer two brief reflection questions",
        ]
    else:
        task = f"evaluate **{n} products**"
        unit = "product"
        steps = [
            "Review the product",
            "Rate your familiarity with it",
            "Rate how likely you are to buy it (1–7 scale)",
            chat_step,
            "Rate your likelihood of buying again after the conversation",
            "Answer two brief reflection questions",
        ]
    numbered_steps = "\n".join(f"{i}. {step}" for i, step in enumerate(steps, start=1))

    st.markdown("# Product Study")
    st.markdown(
        f"Welcome! In this study you will {task} ({cat_str}).\n\n"
        f"For each {unit} you will:\n"
        f"{numbered_steps}\n\n"
        "The whole study takes about **30–40 minutes**. "
        "Please read each product carefully before chatting."
    )

    if st.button("Begin →", type="primary", use_container_width=True):
        if cfg["debug_mode"]:
            # Debug runs skip the demographics/background forms entirely.
            s["demographics"] = _DEBUG_DEMOGRAPHICS.copy()
            s["background"] = _debug_background(cfg)
            s["screen"] = "item_intro"
        else:
            s["screen"] = "demographics"
        st.rerun()
# ── Demographics ───────────────────────────────────────────────────────────────
# (result key, widget label, options) for every demographic dropdown, in
# on-screen order. The leading "" option leaves each selectbox blank until the
# participant picks a value, which is how "not answered" is detected.
# NOTE(review): "Pacific" overlaps with "West" as a U.S. region — confirm intended.
_DEMOGRAPHIC_SELECTS = (
    ("gender", "Gender", [
        "", "Female", "Male", "Non-binary / third gender", "Prefer not to say",
    ]),
    ("geographic_region", "Geographic region (U.S.)", [
        "", "West", "South", "Midwest", "Northeast", "Pacific",
    ]),
    ("education_level", "Highest education level", [
        "", "Less than high school", "High school graduate",
        "Some college, no degree", "Associate's degree",
        "College graduate/some postgrad", "Postgraduate",
    ]),
    ("race", "Race / ethnicity", [
        "", "Asian", "Hispanic", "White", "Black", "Other",
    ]),
    ("us_citizen", "Are you a U.S. citizen?", ["", "Yes", "No"]),
    ("marital_status", "Marital status", [
        "", "Never been married", "Married", "Living with a partner",
        "Divorced", "Separated", "Widowed",
    ]),
    ("religion", "Religion", [
        "", "Protestant", "Roman Catholic", "Mormon", "Orthodox", "Jewish",
        "Muslim", "Buddhist", "Atheist", "Agnostic", "Nothing in particular", "Other",
    ]),
    ("religious_attendance", "How often do you attend religious services?", [
        "", "Never", "Seldom", "A few times a year", "Once or twice a month",
        "Once a week", "More than once a week",
    ]),
    ("political_affiliation", "Political party affiliation", [
        "", "Democrat", "Republican", "Independent", "Something else",
    ]),
    ("income", "Annual household income", [
        "", "Less than $30,000", "$30,000-$50,000", "$50,000-$75,000",
        "$75,000-$100,000", "$100,000 or more",
    ]),
    ("political_views", "Political views", [
        "", "Very liberal", "Liberal", "Moderate", "Conservative", "Very conservative",
    ]),
    ("household_size", "Household size (number of people)", [
        "", "1", "2", "3", "4", "More than 4",
    ]),
    ("employment_status", "Employment status", [
        "", "Full-time employment", "Part-time employment", "Self-employed",
        "Unemployed", "Retired", "Home-maker", "Student",
    ]),
)


def screen_demographics(s: dict, cfg: dict) -> None:
    """Render the demographics form and validate it on "Next".

    Every field is required and the age must be a whole number from 1 to 120.
    On success ``s["demographics"]`` is set to ``{"age": ..., <select keys>...}``
    and the flow moves on to the ``background`` screen; on failure an error is
    shown and the participant stays on this screen.
    """
    st.markdown("## About You")
    st.markdown("All fields are required.")
    age = st.text_input("Age (years)", placeholder="e.g. 34")
    answers = {
        key: st.selectbox(label, options)
        for key, label, options in _DEMOGRAPHIC_SELECTS
    }

    if not st.button("Next →", type="primary", use_container_width=True):
        return

    if not all(str(value).strip() for value in (age, *answers.values())):
        st.error("⚠️ Please complete all fields before continuing.")
        return

    age_text = age.strip()
    # isdecimal(), not isdigit(): isdigit() accepts characters such as "²" that
    # int() cannot parse, which would crash the page with a ValueError.
    if not (age_text.isdecimal() and 1 <= int(age_text) <= 120):
        st.error("⚠️ Please enter a valid age (a number between 1 and 120).")
        return

    s["demographics"] = {"age": age_text, **answers}
    s["screen"] = "background"
    st.rerun()
# ── Background questions ───────────────────────────────────────────────────────
def screen_background(s: dict, cfg: dict) -> None:
    """Collect free-text background answers for every configured category.

    One text area is rendered per entry of ``BACKGROUND_QUESTIONS[category]``
    for each category in ``cfg["categories"]``. On "Next", each answer must be
    non-empty and contain at least ``MIN_WORDS_BACKGROUND`` whitespace-separated
    words; the first failing question is reported and the screen stays put.
    On success the stripped answers are stored in ``s["background"]`` and the
    flow moves to ``item_intro``.
    """
    st.markdown("## Your Preferences — Before We Start")
    st.markdown(
        "Before you evaluate any products, we'd like to understand your general preferences. "
        f"Please write **at least {MIN_WORDS_BACKGROUND} words** for each question."
    )

    active_cats = [c["name"] for c in cfg["categories"]]
    answers = {}
    for position, cat_name in enumerate(active_cats):
        if position > 0:
            st.markdown('<hr class="section-divider">', unsafe_allow_html=True)
        if cat_name == "movies":
            st.markdown('<div class="section-heading">🎬 Movies & TV</div>', unsafe_allow_html=True)
        else:
            # NOTE(review): every non-"movies" category gets the grocery heading.
            st.markdown(
                '<div class="section-heading-grocery">🛒 Grocery & Food Products</div>',
                unsafe_allow_html=True,
            )
        for q in BACKGROUND_QUESTIONS[cat_name]:
            answers[q["key"]] = st.text_area(
                q["label"],
                placeholder=q["placeholder"],
                height=100,
                key=f"bg_{q['key']}",
            )

    if not st.button("Next →", type="primary", use_container_width=True):
        return

    cleaned = {}
    for q in (q for cat in active_cats for q in BACKGROUND_QUESTIONS[cat]):
        val = (answers.get(q["key"]) or "").strip()
        if not val:
            st.error(f"⚠️ Please answer: *{q['label']}*")
            return
        wc = len(val.split())
        if wc < MIN_WORDS_BACKGROUND:
            st.error(
                f"⚠️ Please write at least {MIN_WORDS_BACKGROUND} words for: "
                f"*{q['label']}* ({wc} word{'s' if wc != 1 else ''} so far)."
            )
            return
        cleaned[q["key"]] = val

    s["background"] = cleaned
    s["screen"] = "item_intro"
    st.rerun()
# ── Chat (shared for both study types) ────────────────────────────────────────
def screen_chat(s: dict, cfg: dict) -> None:
    """Render the chat screen for the current item and handle one exchange per send.

    Shows a collapsible product reminder, the conversation so far and, until
    ``cfg["max_turns"]`` exchanges are complete, an input box. Each accepted
    send calls the item's model with the system prompt plus all stored turns,
    appends one user turn and one assistant turn to
    ``s["items"][idx]["conversation"]["turns"]`` and increments ``num_turns``.
    The "done" button is enabled once ``cfg["min_turns"]`` exchanges are
    complete, or always in debug mode.
    """
    idx = s["current_index"]
    item = s["items"][idx]
    conv = item["conversation"]
    study_type = cfg["study_type"]
    min_turns = cfg["min_turns"]
    max_turns = cfg["max_turns"]
    num_turns = conv["num_turns"]

    render_progress(idx + 1, cfg["pairs_per_user"])
    st.markdown("## Chat with the AI")

    # Compact product reminder in a collapsible expander
    with st.expander("📦 View product details"):
        if study_type == "preference":
            render_pair_cards(item, compact=True)
        else:
            render_single_card(item["product"], compact=True)

    st.markdown(
        "Chat with the AI product agent about the "
        f"{'products' if study_type == 'preference' else 'product'}. "
        "Ask questions, push back, or share your thoughts. "
        f"You need **exactly {min_turns} exchanges** before you can move on."
    )
    st.markdown(
        '<div style="background:#fef9c3;border:1px solid #fde047;border-radius:8px;'
        'padding:0.5rem 0.9rem;margin-bottom:0.5rem;font-size:0.88rem;color:#713f12;">'
        '💬 New messages appear at the bottom of the chat — scroll down to see the latest response.'
        '</div>',
        unsafe_allow_html=True,
    )
    render_chat_history(conv["turns"], study_type)

    if num_turns >= max_turns:
        st.info("✅ You have completed all required exchanges. Please proceed.")
    else:
        st.caption(f"Exchanges completed: {num_turns} / {max_turns}")

    # Input area (hidden once max_turns reached)
    if num_turns < max_turns:
        user_msg = st.text_area(
            "Your response:",
            placeholder="Type your message here…",
            height=100,
            key=f"chat_input_{idx}_{num_turns}",
        )
        _, col_btn = st.columns([4, 1])
        with col_btn:
            send_clicked = st.button("Send →", type="primary", use_container_width=True)

        if send_clicked:
            user_msg = (user_msg or "").strip()
            if not user_msg:
                st.error("⚠️ Please type a message before sending.")
                return
            word_count = len(user_msg.split())
            if word_count < 5 and not cfg["debug_mode"]:
                st.error(
                    "⚠️ Please write at least 5 words "
                    f"({word_count} word{'s' if word_count != 1 else ''} so far)."
                )
                return

            # Full message list: system prompt + all stored turns + new user message
            messages = [{"role": "system", "content": conv["system_prompt"]}]
            messages.extend({"role": t["role"], "content": t["content"]} for t in conv["turns"])
            messages.append({"role": "user", "content": user_msg})

            # Each item may be served by a different model, so override cfg per item.
            model_name = item.get("model_name", "")
            item_cfg = {**cfg, "model_name": model_name, "sampler_path": item.get("sampler_path", "")}

            # Stamp the user turn when it is sent and the reply when it arrives,
            # instead of giving both turns the post-response time.
            sent_at = time.time()
            with st.spinner("AI is responding…"):
                ai_reply = call_model(messages, item_cfg)
            replied_at = time.time()

            turn_base = len(conv["turns"])
            conv["turns"].append({
                "turn_index": turn_base,
                "role": "user",
                "content": user_msg,
                "timestamp": sent_at,
            })
            conv["turns"].append({
                "turn_index": turn_base + 1,
                "role": "assistant",
                "content": ai_reply,
                "timestamp": replied_at,
                "model": model_name,  # per-item model, not cfg's
            })
            conv["num_turns"] = num_turns + 1
            s["items"][idx]["conversation"] = conv
            st.rerun()

    # Done button — enabled only when min_turns reached (or debug mode)
    if num_turns >= min_turns or cfg["debug_mode"]:
        if st.button("I'm done chatting →", use_container_width=True):
            s["screen"] = "post_rating"
            st.rerun()
    else:
        remaining = min_turns - num_turns
        st.button(
            "I'm done chatting →",
            disabled=True,
            use_container_width=True,
            help=f"Complete {remaining} more exchange{'s' if remaining != 1 else ''} first.",
        )
# ── Post-rating ────────────────────────────────────────────────────────────────
def screen_post_rating(s: dict, cfg: dict) -> None:
    """Ask for the post-conversation rating and store it with its change.

    On "Next" the selected choice is parsed with ``parse_rating`` and saved as
    ``post_rating`` on the current item, together with ``rating_delta``
    (post minus pre), and the flow moves to ``reflection``. A selection is
    required unless ``cfg["debug_mode"]`` is set, in which case the middle
    choice of the scale is used.
    """
    idx = s["current_index"]
    item = s["items"][idx]
    study_type = cfg["study_type"]
    render_progress(idx + 1, cfg["pairs_per_user"])
    st.markdown("## Your Rating After the Conversation")
    st.markdown("Now that you have chatted with the AI, please rate again.")

    if study_type == "preference":
        render_pair_cards(item)
        question = "Which product would you **prefer to buy** now?"
    else:
        render_single_card(item["product"])
        question = "How **likely** are you to purchase this product now?"

    choices = rating_choices(study_type)
    post_val = st.radio(question, choices, index=None, key=f"post_rating_{idx}")

    if not st.button("Next →", type="primary", use_container_width=True):
        return
    if not post_val and not cfg["debug_mode"]:
        st.error("⚠️ Please select a rating before continuing.")
        return
    if not post_val:
        # Debug fallback: the scale's middle choice (index 3, "Neutral", on a
        # 7-point scale), derived from the scale length rather than hard-coded.
        post_val = choices[len(choices) // 2]

    post_int = parse_rating(post_val)
    pre_rating = item.get("pre_rating", 4)
    # NOTE(review): a missing or non-int pre-rating is compared against 4
    # (presumably the neutral point), which yields a synthetic delta.
    baseline = pre_rating if isinstance(pre_rating, int) else 4
    s["items"][idx]["post_rating"] = post_int
    s["items"][idx]["rating_delta"] = post_int - baseline
    s["screen"] = "reflection"
    st.rerun()
# ── Reflection ─────────────────────────────────────────────────────────────────
def screen_reflection(s: dict, cfg: dict) -> None:
    """Collect the two reflection answers and advance to the next item or finish.

    Outside debug mode both answers must be non-empty and contain at least
    ``MIN_WORDS_REFLECTION`` words. After the last item, ``s["meta"]`` is filled
    in and ``save_and_upload`` is called before showing the ``done`` screen.
    If saving raises, the error is logged, ``s["current_index"]`` is restored
    to this item so the participant can retry, and the screen does not advance.
    """
    idx = s["current_index"]
    n = cfg["pairs_per_user"]
    render_progress(idx + 1, n)
    st.markdown("## Reflection")
    st.markdown("Please write at least a sentence or two for each question.")

    standout = st.text_area(
        "What did the AI say that stood out to you most?",
        placeholder="Describe a specific argument, question, or moment from the conversation…",
        height=120,
        key=f"standout_{idx}",
    )
    thinking_change = st.text_area(
        "How did your thinking about this product change (or not change) during the chat? Why?",
        placeholder="Be as specific as you can…",
        height=120,
        key=f"thinking_{idx}",
    )

    is_last = idx + 1 >= n
    next_label = "Submit Study →" if is_last else "Next →"
    if not st.button(next_label, type="primary", use_container_width=True):
        return

    if not cfg["debug_mode"]:
        for field_name, raw in (("first", standout), ("second", thinking_change)):
            val = (raw or "").strip()
            if not val:
                st.error(f"⚠️ Please answer the {field_name} reflection question.")
                return
            wc = len(val.split())
            if wc < MIN_WORDS_REFLECTION:
                st.error(
                    f"⚠️ Please write at least {MIN_WORDS_REFLECTION} words for the "
                    f"{field_name} question ({wc} word{'s' if wc != 1 else ''} so far)."
                )
                return

    s["items"][idx]["reflection"] = {
        "standout_moment": (standout or "").strip() or "[debug]",
        "thinking_change": (thinking_change or "").strip() or "[debug]",
    }

    next_idx = idx + 1
    s["current_index"] = next_idx
    if next_idx >= n:
        end_time = time.time()
        s["meta"] = {
            "submission_time": end_time,
            "duration_seconds": round(end_time - s.get("start_time", end_time), 1),
            "study_type": cfg["study_type"],
        }
        try:
            with st.spinner("Saving your responses…"):
                save_and_upload(s, cfg)
        except Exception:
            # Without this rollback, a failed save left current_index == n and the
            # next click crashed on s["items"][n] (IndexError).
            # NOTE(review): if save_and_upload can partially succeed, a retry may
            # write twice — confirm it is idempotent.
            logging.getLogger(__name__).exception("save_and_upload failed")
            s["current_index"] = idx
            st.error(
                "⚠️ Your responses could not be saved. "
                'Please click "Submit Study" again to retry.'
            )
            return
        s["screen"] = "done"
    else:
        s["screen"] = "item_intro"
    st.rerun()
# ── Done ───────────────────────────────────────────────────────────────────────
def _truncate_title(title, width: int) -> str:
    """Return *title* (``None`` treated as "") cut to *width* chars, adding "…" only if cut."""
    text = title or ""
    return text if len(text) <= width else text[:width] + "…"


def screen_done(s: dict, cfg: dict) -> None:
    """Show the completion page: a per-item rating summary and the Prolific code.

    One table row is built per entry of ``s["items"]`` with its category,
    product title(s), pre/post rating labels and the signed rating shift.
    """
    import pandas as pd

    study_type = cfg["study_type"]
    n = cfg["pairs_per_user"]
    code = cfg.get("prolific_completion_code", "")
    is_preference = study_type == "preference"
    labels = PREFERENCE_LABELS if is_preference else LIKELIHOOD_LABELS

    def describe(rating):
        # Label from the Likert map, falling back to the raw value or a dash.
        if rating is None:
            return "—"
        return labels.get(rating, str(rating))

    st.markdown("## ✅ Study Complete — Thank You!")
    st.markdown(
        f"You have finished all {n} {'pairs' if is_preference else 'products'}. "
        "Here is a summary of how your ratings changed:"
    )

    rows = []
    for number, item in enumerate(s["items"], start=1):
        delta = item.get("rating_delta", 0)
        if isinstance(delta, int):
            arrow = "➡️" if delta == 0 else ("⬆️" if delta > 0 else "⬇️")
            shift = f"{arrow} {delta:+d}"
        else:
            shift = "—"

        category_key = item.get("category") or ""
        row = {
            "#": number,
            # Unknown categories show their raw name, as on the welcome screen.
            "Category": CATEGORY_DISPLAY.get(category_key, category_key),
        }
        if is_preference:
            row["Product A"] = _truncate_title((item.get("product_a") or {}).get("title"), 45)
            row["Product B"] = _truncate_title((item.get("product_b") or {}).get("title"), 45)
        else:
            row["Product"] = _truncate_title((item.get("product") or {}).get("title"), 65)
        row["Pre-rating"] = describe(item.get("pre_rating"))
        row["Post-rating"] = describe(item.get("post_rating"))
        row["Shift"] = shift
        rows.append(row)

    st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)
    st.markdown("---")
    if code:
        st.success(
            f"**Your Prolific completion code: `{code}`**\n\n"
            "Please copy this code and paste it into the Prolific website to complete your submission."
        )
    else:
        # An empty code would otherwise render as a blank code span.
        st.warning(
            "⚠️ No completion code is configured for this study. "
            "Please contact the researchers through Prolific."
        )