"""
Shared screens used by both preference and likelihood study types:
welcome, demographics, background, chat, post_rating, reflection, done.
The chat screen contains the main conversation loop and is identical for both
study types — the only difference is which compact product view is shown in
the expander and which Likert labels are used.
"""
import time
import streamlit as st
from src.config import (
BACKGROUND_QUESTIONS,
CATEGORY_DISPLAY,
LIKELIHOOD_LABELS,
MIN_WORDS_BACKGROUND,
MIN_WORDS_REFLECTION,
PREFERENCE_LABELS,
)
from src.model import call_model
from src.lsp_wrappers import format_demographics
from src.ui.components import (
familiarity_choices,
parse_rating,
rating_choices,
render_chat_history,
render_pair_cards,
render_progress,
render_single_card,
)
from src.upload import save_and_upload
# ── Debug defaults ────────────────────────────────────────────────────────────
_DEBUG_DEMOGRAPHICS = {
"age": "30", "gender": "Female", "geographic_region": "West",
"education_level": "College graduate/some postgrad", "race": "White",
"us_citizen": "Yes", "marital_status": "Single",
"religion": "Agnostic", "religious_attendance": "Never",
"political_affiliation": "Independent", "income": "$50,000-$75,000",
"political_views": "Moderate", "household_size": "2",
"employment_status": "Full-time employment",
}
def _debug_background(cfg: dict) -> dict:
    """Build placeholder background answers for debug mode.

    Walks every category listed in ``cfg["categories"]`` and every question
    that BACKGROUND_QUESTIONS defines for it, producing one filler answer per
    question key.
    """
    padding = " placeholder" * 6
    answers = {}
    for category in cfg["categories"]:
        for question in BACKGROUND_QUESTIONS[category["name"]]:
            key = question["key"]
            answers[key] = f"[debug — {key}{padding}]"
    return answers
# ── Welcome ───────────────────────────────────────────────────────────────────
def screen_welcome(s: dict, cfg: dict) -> None:
    """Render the study landing page and route the participant onward.

    The overview is tailored to ``cfg["study_type"]``: "preference" describes
    comparing pairs, anything else describes evaluating single products.
    Pressing "Begin" goes to the demographics screen. In debug mode it instead
    fills in canned demographics and background answers and jumps straight to
    the first item's intro screen.
    """
    count = cfg["pairs_per_user"]
    is_preference = cfg["study_type"] == "preference"
    display_names = [
        CATEGORY_DISPLAY.get(category["name"], category["name"])
        for category in cfg["categories"]
    ]
    categories_md = " and ".join(f"**{name}**" for name in display_names)
    turns = cfg["min_turns"]

    if is_preference:
        opening = (
            f"Welcome! In this study you will compare **{count} pairs** of products "
            f"({categories_md})."
        )
        unit = "pair"
        steps = [
            "Review two products (Product A and Product B)",
            "Rate your familiarity with each product",
            "Rate which product you would prefer to buy (1–7 scale)",
            f"Chat with an AI product agent for **exactly {turns} exchanges**",
            "Rate your preference again after the conversation",
            "Answer two brief reflection questions",
        ]
    else:
        opening = (
            f"Welcome! In this study you will evaluate **{count} products** "
            f"({categories_md})."
        )
        unit = "product"
        steps = [
            "Review the product",
            "Rate your familiarity with it",
            "Rate how likely you are to buy it (1–7 scale)",
            f"Chat with an AI product agent for **exactly {turns} exchanges**",
            "Rate your likelihood of buying again after the conversation",
            "Answer two brief reflection questions",
        ]
    closing = (
        "The whole study takes about **30–40 minutes**. "
        "Please read each product carefully before chatting."
    )
    numbered = "\n".join(f"{pos}. {text}" for pos, text in enumerate(steps, start=1))

    st.markdown("# 🛒 Product Study")
    st.markdown(f"{opening}\n\nFor each {unit} you will:\n{numbered}\n\n{closing}")

    if not st.button("Begin →", type="primary", use_container_width=True):
        return
    if cfg["debug_mode"]:
        # Skip the questionnaires entirely with placeholder answers.
        s["demographics"] = dict(_DEBUG_DEMOGRAPHICS)
        s["background"] = _debug_background(cfg)
        s["screen"] = "item_intro"
    else:
        s["screen"] = "demographics"
    st.rerun()
# ── Demographics ──────────────────────────────────────────────────────────────
def screen_demographics(s: dict, cfg: dict) -> None:
    """Collect the participant's demographic profile.

    Every field is required. Age must be a whole number from 1 to 120 and is
    stored as a canonical ASCII digit string (e.g. "034" -> "34"). On success
    the answers are saved to ``s["demographics"]`` and the flow advances to
    the background screen. ``cfg`` is unused here.
    """
    st.markdown("## About You")
    st.markdown("All fields are required.")
    age = st.text_input("Age (years)", placeholder="e.g. 34")
    gender = st.selectbox("Gender", [
        "", "Female", "Male", "Non-binary / third gender", "Prefer not to say",
    ])
    geographic_region = st.selectbox("Geographic region (U.S.)", [
        "", "West", "South", "Midwest", "Northeast", "Pacific",
    ])
    education_level = st.selectbox("Highest education level", [
        "", "Less than high school", "High school graduate",
        "Some college, no degree", "Associate's degree",
        "College graduate/some postgrad", "Postgraduate",
    ])
    race = st.selectbox("Race / ethnicity", [
        "", "Asian", "Hispanic", "White", "Black", "Other",
    ])
    us_citizen = st.selectbox("Are you a U.S. citizen?", ["", "Yes", "No"])
    marital_status = st.selectbox("Marital status", [
        "", "Never been married", "Married", "Living with a partner",
        "Divorced", "Separated", "Widowed",
    ])
    religion = st.selectbox("Religion", [
        "", "Protestant", "Roman Catholic", "Mormon", "Orthodox", "Jewish",
        "Muslim", "Buddhist", "Atheist", "Agnostic", "Nothing in particular", "Other",
    ])
    religious_attendance = st.selectbox("How often do you attend religious services?", [
        "", "Never", "Seldom", "A few times a year", "Once or twice a month",
        "Once a week", "More than once a week",
    ])
    political_affiliation = st.selectbox("Political party affiliation", [
        "", "Democrat", "Republican", "Independent", "Something else",
    ])
    income = st.selectbox("Annual household income", [
        "", "Less than $30,000", "$30,000-$50,000", "$50,000-$75,000",
        "$75,000-$100,000", "$100,000 or more",
    ])
    political_views = st.selectbox("Political views", [
        "", "Very liberal", "Liberal", "Moderate", "Conservative", "Very conservative",
    ])
    household_size = st.selectbox("Household size (number of people)", [
        "", "1", "2", "3", "4", "More than 4",
    ])
    employment_status = st.selectbox("Employment status", [
        "", "Full-time employment", "Part-time employment", "Self-employed",
        "Unemployed", "Retired", "Home-maker", "Student",
    ])

    if not st.button("Next →", type="primary", use_container_width=True):
        return

    # One dict is both the thing we validate and the thing we store, so the
    # field list cannot drift between the two.
    demographics = {
        "age": age.strip(), "gender": gender,
        "geographic_region": geographic_region, "education_level": education_level,
        "race": race, "us_citizen": us_citizen, "marital_status": marital_status,
        "religion": religion, "religious_attendance": religious_attendance,
        "political_affiliation": political_affiliation, "income": income,
        "political_views": political_views, "household_size": household_size,
        "employment_status": employment_status,
    }
    if not all(str(value).strip() for value in demographics.values()):
        st.error("⚠️ Please complete all fields before continuing.")
        return

    age_text = demographics["age"]
    # isdecimal() rather than isdigit(): isdigit() is also true for characters
    # such as "²" that int() rejects with ValueError. Every valid age has at
    # most three digits, so longer input is rejected before int() is called.
    age_is_valid = (
        age_text.isdecimal()
        and len(age_text) <= 3
        and 1 <= int(age_text) <= 120
    )
    if not age_is_valid:
        st.error("⚠️ Please enter a valid age (a number between 1 and 120).")
        return
    demographics["age"] = str(int(age_text))

    s["demographics"] = demographics
    s["screen"] = "background"
    st.rerun()
# ── Background questions ──────────────────────────────────────────────────────
def screen_background(s: dict, cfg: dict) -> None:
    """Collect free-text answers about the participant's general preferences.

    Shows the BACKGROUND_QUESTIONS of every category in ``cfg["categories"]``,
    grouped under one heading per category. On "Next", each answer must be
    non-empty and contain at least MIN_WORDS_BACKGROUND words. The first
    failing question is reported and nothing is saved. On success the stripped
    answers are stored in ``s["background"]`` (keyed by question key) and the
    flow moves to the first item's intro screen.
    """
    st.markdown("## Your Preferences — Before We Start")
    st.markdown(
        "Before you evaluate any products, we'd like to understand your general preferences. "
        f"Please write **at least {MIN_WORDS_BACKGROUND} words** for each question."
    )
    active_cats = [c["name"] for c in cfg["categories"]]
    answers = {}
    for position, cat_name in enumerate(active_cats):
        # Plain Markdown for the divider and heading, so no raw HTML (and no
        # unsafe_allow_html) is needed.
        if position > 0:
            st.markdown("---")
        # Any category other than "movies" gets the grocery heading.
        heading = "🎬 Movies & TV" if cat_name == "movies" else "🛒 Grocery & Food Products"
        st.markdown(f"#### {heading}")
        for q in BACKGROUND_QUESTIONS[cat_name]:
            answers[q["key"]] = st.text_area(
                q["label"],
                placeholder=q["placeholder"],
                height=100,
                key=f"bg_{q['key']}",
            )

    if not st.button("Next →", type="primary", use_container_width=True):
        return

    all_qs = [q for cat in active_cats for q in BACKGROUND_QUESTIONS[cat]]
    cleaned = {}
    for q in all_qs:
        val = (answers.get(q["key"]) or "").strip()
        if not val:
            st.error(f"⚠️ Please answer: *{q['label']}*")
            return
        wc = len(val.split())
        if wc < MIN_WORDS_BACKGROUND:
            st.error(
                f"⚠️ Please write at least {MIN_WORDS_BACKGROUND} words for: "
                f"*{q['label']}* ({wc} word{'s' if wc != 1 else ''} so far)."
            )
            return
        cleaned[q["key"]] = val
    s["background"] = cleaned
    s["screen"] = "item_intro"
    st.rerun()
# ── Chat (shared for both study types) ───────────────────────────────────────
def screen_chat(s: dict, cfg: dict) -> None:
    """Run the participant <-> AI conversation for the current item.

    Renders the product(s) in a collapsible expander and the chat history so
    far. While fewer than ``cfg["max_turns"]`` exchanges have happened, it also
    shows an input box.

    On "Send", the message must be non-empty and, unless in debug mode, at
    least 5 words long. It is then passed to ``call_model`` after the
    conversation's system prompt and all prior turns. The user turn
    (timestamped when sent) and the assistant reply (timestamped when received)
    are appended to ``item["conversation"]["turns"]``.

    The "I'm done chatting" button is enabled once ``cfg["min_turns"]``
    exchanges are complete (always in debug mode) and moves to the post-rating
    screen.
    """
    idx = s["current_index"]
    item = s["items"][idx]
    conv = item["conversation"]
    study_type = cfg["study_type"]
    min_turns = cfg["min_turns"]
    max_turns = cfg["max_turns"]
    num_turns = conv["num_turns"]

    render_progress(idx + 1, cfg["pairs_per_user"])
    st.markdown("## Chat with the AI")
    # Compact product reminder in a collapsible expander
    with st.expander("📦 View product details"):
        if study_type == "preference":
            render_pair_cards(item, compact=True)
        else:
            render_single_card(item["product"], compact=True)
    st.markdown(
        f"Chat with the AI product agent about the "
        f"{'products' if study_type == 'preference' else 'product'}. "
        "Ask questions, push back, or share your thoughts. "
        f"You need **exactly {min_turns} exchanges** before you can move on."
    )
    # Plain caption rather than raw HTML, so unsafe_allow_html is not needed.
    st.caption(
        "💬 New messages appear at the bottom of the chat — "
        "scroll down to see the latest response."
    )
    render_chat_history(conv["turns"], study_type)

    if num_turns >= max_turns:
        st.info("✅ You have completed all required exchanges. Please proceed.")
    else:
        st.caption(f"Exchanges completed: {num_turns} / {max_turns}")
        # Input area (only shown until max_turns is reached)
        user_msg = st.text_area(
            "Your response:",
            placeholder="Type your message here…",
            height=100,
            key=f"chat_input_{idx}_{num_turns}",
        )
        _, col_btn = st.columns([4, 1])
        with col_btn:
            send_clicked = st.button("Send →", type="primary", use_container_width=True)
        if send_clicked:
            text = (user_msg or "").strip()
            if not text:
                st.error("⚠️ Please type a message before sending.")
                return
            word_count = len(text.split())
            if word_count < 5 and not cfg["debug_mode"]:
                st.error(
                    f"⚠️ Please write at least 5 words "
                    f"({word_count} word{'s' if word_count != 1 else ''} so far)."
                )
                return

            # Full message list: system prompt + all stored turns + new user message
            messages = [{"role": "system", "content": conv["system_prompt"]}]
            messages.extend(
                {"role": t["role"], "content": t["content"]} for t in conv["turns"]
            )
            messages.append({"role": "user", "content": text})

            # Per-item model settings override the global config. An item
            # without them deliberately passes "" rather than cfg's value.
            item_model = item.get("model_name", "")
            item_cfg = {**cfg, "model_name": item_model, "sampler_path": item.get("sampler_path", "")}

            sent_at = time.time()  # the user's turn happened before the model call
            with st.spinner("AI is responding…"):
                ai_reply = call_model(messages, item_cfg)
            replied_at = time.time()

            turn_base = len(conv["turns"])
            conv["turns"].append({
                "turn_index": turn_base,
                "role": "user",
                "content": text,
                "timestamp": sent_at,
            })
            conv["turns"].append({
                "turn_index": turn_base + 1,
                "role": "assistant",
                "content": ai_reply,
                "timestamp": replied_at,
                "model": item_model,  # per-item, not cfg
            })
            conv["num_turns"] = num_turns + 1
            s["items"][idx]["conversation"] = conv
            st.rerun()

    # Done button — enabled only when min_turns reached (or debug mode)
    if num_turns >= min_turns or cfg["debug_mode"]:
        if st.button("I'm done chatting →", use_container_width=True):
            s["screen"] = "post_rating"
            st.rerun()
    else:
        remaining = min_turns - num_turns
        st.button(
            "I'm done chatting →",
            disabled=True,
            use_container_width=True,
            help=f"Complete {remaining} more exchange{'s' if remaining != 1 else ''} first.",
        )
# ── Post-rating ───────────────────────────────────────────────────────────────
def screen_post_rating(s: dict, cfg: dict) -> None:
    """Ask for the participant's rating again after the conversation.

    Shows the pair (preference study) or the single product (other study
    types) with the scale from ``rating_choices(study_type)``. The chosen
    option is converted with ``parse_rating`` and stored as
    ``item["post_rating"]``. ``item["rating_delta"]`` is post minus pre, where
    a missing or non-int pre-rating counts as 4. In debug mode an unanswered
    question falls back to the middle option of the scale.
    """
    idx = s["current_index"]
    item = s["items"][idx]
    study_type = cfg["study_type"]

    render_progress(idx + 1, cfg["pairs_per_user"])
    st.markdown("## Your Rating After the Conversation")
    st.markdown("Now that you have chatted with the AI, please rate again.")
    if study_type == "preference":
        render_pair_cards(item)
        question = "Which product would you **prefer to buy** now?"
    else:
        render_single_card(item["product"])
        question = "How **likely** are you to purchase this product now?"
    choices = rating_choices(study_type)
    post_val = st.radio(question, choices, index=None, key=f"post_rating_{idx}")

    if not st.button("Next →", type="primary", use_container_width=True):
        return
    if not post_val and not cfg["debug_mode"]:
        st.error("⚠️ Please select a rating before continuing.")
        return
    if not post_val:
        # Debug-mode fallback: the midpoint of the scale, so it does not
        # depend on the scale having exactly 7 options (index 3 when it does).
        post_val = choices[len(choices) // 2]

    post_int = parse_rating(post_val)
    pre_rating = item.get("pre_rating", 4)
    baseline = pre_rating if isinstance(pre_rating, int) else 4
    s["items"][idx]["post_rating"] = post_int
    s["items"][idx]["rating_delta"] = post_int - baseline
    s["screen"] = "reflection"
    st.rerun()
# ── Reflection ────────────────────────────────────────────────────────────────
def screen_reflection(s: dict, cfg: dict) -> None:
    """Collect two free-text reflections on the conversation for the current item.

    Outside debug mode both answers must be non-empty and at least
    MIN_WORDS_REFLECTION words long. Answers are stored in
    ``item["reflection"]``. Empty answers become "[debug]", which can only
    happen in debug mode.

    Then ``s["current_index"]`` advances. After the last item, submission
    metadata is recorded, the session is saved with ``save_and_upload`` and
    the done screen follows. Otherwise the next item's intro screen follows.
    If saving raises, the index advance is undone before the error
    propagates, so the next rerun shows this same item again instead of an
    out-of-range index.
    """
    idx = s["current_index"]
    n = cfg["pairs_per_user"]
    render_progress(idx + 1, n)
    st.markdown("## Reflection")
    st.markdown("Please write at least a sentence or two for each question.")
    standout = st.text_area(
        "What did the AI say that stood out to you most?",
        placeholder="Describe a specific argument, question, or moment from the conversation…",
        height=120,
        key=f"standout_{idx}",
    )
    thinking_change = st.text_area(
        "How did your thinking about this product change (or not change) during the chat? Why?",
        placeholder="Be as specific as you can…",
        height=120,
        key=f"thinking_{idx}",
    )
    is_last = idx + 1 >= n
    next_label = "Submit Study →" if is_last else "Next →"

    if not st.button(next_label, type="primary", use_container_width=True):
        return
    if not cfg["debug_mode"]:
        for field_name, val in [("first", standout), ("second", thinking_change)]:
            val = (val or "").strip()
            if not val:
                st.error(f"⚠️ Please answer the {field_name} reflection question.")
                return
            wc = len(val.split())
            if wc < MIN_WORDS_REFLECTION:
                st.error(
                    f"⚠️ Please write at least {MIN_WORDS_REFLECTION} words for the "
                    f"{field_name} question ({wc} word{'s' if wc != 1 else ''} so far)."
                )
                return

    s["items"][idx]["reflection"] = {
        "standout_moment": (standout or "").strip() or "[debug]",
        "thinking_change": (thinking_change or "").strip() or "[debug]",
    }
    next_idx = idx + 1
    s["current_index"] = next_idx
    if next_idx >= n:
        end_time = time.time()
        s["meta"] = {
            "submission_time": end_time,
            "duration_seconds": round(end_time - s.get("start_time", end_time), 1),
            "study_type": cfg["study_type"],
        }
        try:
            with st.spinner("Saving your responses…"):
                save_and_upload(s, cfg)
        except Exception:
            # The screen is still "reflection". Restore the index so the next
            # rerun renders this item again (rather than s["items"][n]) and
            # submission can be retried.
            s["current_index"] = idx
            raise
        s["screen"] = "done"
    else:
        s["screen"] = "item_intro"
    st.rerun()
# ── Done ──────────────────────────────────────────────────────────────────────
def _truncate_title(title, limit: int) -> str:
    """Return *title* (``None`` treated as "") cut to *limit* characters.

    An ellipsis is appended only when characters were actually removed.
    """
    text = title or ""
    return text if len(text) <= limit else text[:limit] + "…"


def screen_done(s: dict, cfg: dict) -> None:
    """Show the completion page: per-item rating summary plus the Prolific code.

    Builds one table row per item in ``s["items"]``, with its category, product
    title(s), pre/post ratings (mapped through PREFERENCE_LABELS or
    LIKELIHOOD_LABELS), and the rating shift with a direction arrow.
    """
    import pandas as pd  # imported lazily; only this screen needs it

    study_type = cfg["study_type"]
    n = cfg["pairs_per_user"]
    code = cfg.get("prolific_completion_code", "")
    labels = PREFERENCE_LABELS if study_type == "preference" else LIKELIHOOD_LABELS

    def rating_label(value):
        # Label if one exists, else the raw value, else an em dash when missing.
        return labels.get(value, str(value) if value is not None else "—")

    st.markdown("## ✅ Study Complete — Thank You!")
    st.markdown(
        f"You have finished all {n} {'pairs' if study_type == 'preference' else 'products'}. "
        "Here is a summary of how your ratings changed:"
    )

    rows = []
    for i, item in enumerate(s["items"]):
        delta = item.get("rating_delta", 0)
        if isinstance(delta, int):
            arrow = "➡️" if delta == 0 else ("⬆️" if delta > 0 else "⬇️")
            shift = f"{arrow} {delta:+d}"
        else:
            shift = "—"

        row = {
            "#": i + 1,
            "Category": CATEGORY_DISPLAY.get(item.get("category", ""), ""),
        }
        if study_type == "preference":
            row["Product A"] = _truncate_title((item.get("product_a") or {}).get("title"), 45)
            row["Product B"] = _truncate_title((item.get("product_b") or {}).get("title"), 45)
        else:
            row["Product"] = _truncate_title((item.get("product") or {}).get("title"), 65)
        row["Pre-rating"] = rating_label(item.get("pre_rating"))
        row["Post-rating"] = rating_label(item.get("post_rating"))
        row["Shift"] = shift
        rows.append(row)

    st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)
    st.markdown("---")
    st.success(
        f"**Your Prolific completion code: `{code}`**\n\n"
        "Please copy this code and paste it into the Prolific website to complete your submission."
    )