File size: 21,225 Bytes
6b23da9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6508a6c
6b23da9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d34de84
 
 
 
 
 
 
 
6b23da9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a40c8e4
584421e
6b23da9
6b0bcdc
6b23da9
 
 
 
 
 
 
 
 
 
 
 
 
 
a40c8e4
6b23da9
a40c8e4
6b23da9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a40c8e4
6b23da9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a40c8e4
 
 
 
6b23da9
 
 
 
 
 
a40c8e4
 
 
6b23da9
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
"""
Shared screens used by both preference and likelihood study types:
  welcome, demographics, background, chat, post_rating, reflection, done.

The chat screen contains the main conversation loop and is identical for both
study types — the only difference is which compact product view is shown in
the expander and which Likert labels are used.
"""
import time

import streamlit as st

from src.config import (
    BACKGROUND_QUESTIONS,
    CATEGORY_DISPLAY,
    LIKELIHOOD_LABELS,
    MIN_WORDS_BACKGROUND,
    MIN_WORDS_REFLECTION,
    PREFERENCE_LABELS,
)
from src.model import call_model
from src.lsp_wrappers import format_demographics
from src.ui.components import (
    familiarity_choices,
    parse_rating,
    rating_choices,
    render_chat_history,
    render_pair_cards,
    render_progress,
    render_single_card,
)
from src.upload import save_and_upload


# ── Debug defaults ────────────────────────────────────────────────────────────

_DEBUG_DEMOGRAPHICS = {
    "age": "30", "gender": "Female", "geographic_region": "West",
    "education_level": "College graduate/some postgrad", "race": "White",
    "us_citizen": "Yes", "marital_status": "Single",
    "religion": "Agnostic", "religious_attendance": "Never",
    "political_affiliation": "Independent", "income": "$50,000-$75,000",
    "political_views": "Moderate", "household_size": "2",
    "employment_status": "Full-time employment",
}


def _debug_background(cfg: dict) -> dict:
    """Build placeholder background answers for every active category.

    Used when debug mode skips the background screen. Each answer embeds its
    question key and is padded with six " placeholder" words.

    Args:
        cfg: Study config; only ``cfg["categories"]`` (dicts with a ``"name"``
            entry that indexes ``BACKGROUND_QUESTIONS``) is read.

    Returns:
        Mapping of background-question key to its placeholder answer.
    """
    padding = "".join(" placeholder" for _ in range(6))
    answers = {}
    for category in cfg["categories"]:
        for question in BACKGROUND_QUESTIONS[category["name"]]:
            key = question["key"]
            answers[key] = f"[debug β€” {key}{padding}]"
    return answers


# ── Welcome ───────────────────────────────────────────────────────────────────

def screen_welcome(s: dict, cfg: dict) -> None:
    """Render the landing screen describing the study and start it on click.

    The instructions depend on ``cfg["study_type"]``: ``"preference"`` studies
    compare pairs of products, any other value evaluates single products.

    When "Begin" is clicked, debug mode pre-fills demographics and background
    answers and jumps to ``"item_intro"``; otherwise the participant is sent
    to the demographics form.

    Args:
        s: Mutable session-state dict; ``"screen"`` (and, in debug mode,
            ``"demographics"`` / ``"background"``) are written to it.
        cfg: Study config (reads ``pairs_per_user``, ``study_type``,
            ``categories``, ``min_turns`` and, on click, ``debug_mode``).
    """
    count = cfg["pairs_per_user"]
    is_preference = cfg["study_type"] == "preference"
    labels = [CATEGORY_DISPLAY.get(c["name"], c["name"]) for c in cfg["categories"]]
    categories_md = " and ".join(f"**{label}**" for label in labels)
    exchanges = cfg["min_turns"]

    st.markdown("# πŸ›’ Product Study")

    if is_preference:
        opening = (
            f"Welcome! In this study you will compare **{count} pairs** of products "
            f"({categories_md}).\n\n"
        )
        unit = "pair"
        steps = [
            "Review two products (Product A and Product B)",
            "Rate your familiarity with each product",
            "Rate which product you would prefer to buy (1–7 scale)",
            f"Chat with an AI product agent for **exactly {exchanges} exchanges**",
            "Rate your preference again after the conversation",
            "Answer two brief reflection questions",
        ]
    else:
        opening = (
            f"Welcome! In this study you will evaluate **{count} products** "
            f"({categories_md}).\n\n"
        )
        unit = "product"
        steps = [
            "Review the product",
            "Rate your familiarity with it",
            "Rate how likely you are to buy it (1–7 scale)",
            f"Chat with an AI product agent for **exactly {exchanges} exchanges**",
            "Rate your likelihood of buying again after the conversation",
            "Answer two brief reflection questions",
        ]

    # "1. ...\n2. ...\n" etc.; the extra "\n" below leaves a blank line before
    # the closing paragraph.
    numbered = "".join(f"{pos}. {text}\n" for pos, text in enumerate(steps, start=1))
    st.markdown(
        opening
        + f"For each {unit} you will:\n"
        + numbered
        + "\n"
        + "The whole study takes about **30–40 minutes**. "
        + "Please read each product carefully before chatting."
    )

    if not st.button("Begin β†’", type="primary", use_container_width=True):
        return
    if cfg["debug_mode"]:
        s["demographics"] = _DEBUG_DEMOGRAPHICS.copy()
        s["background"] = _debug_background(cfg)
        s["screen"] = "item_intro"
    else:
        s["screen"] = "demographics"
    st.rerun()


# ── Demographics ──────────────────────────────────────────────────────────────

def screen_demographics(s: dict, cfg: dict) -> None:
    st.markdown("## About You")
    st.markdown("All fields are required.")

    age = st.text_input("Age (years)", placeholder="e.g. 34")

    gender = st.selectbox("Gender", [
        "", "Female", "Male", "Non-binary / third gender", "Prefer not to say",
    ])
    geographic_region = st.selectbox("Geographic region (U.S.)", [
        "", "West", "South", "Midwest", "Northeast", "Pacific",
    ])
    education_level = st.selectbox("Highest education level", [
        "", "Less than high school", "High school graduate",
        "Some college, no degree", "Associate's degree",
        "College graduate/some postgrad", "Postgraduate",
    ])
    race = st.selectbox("Race / ethnicity", [
        "", "Asian", "Hispanic", "White", "Black", "Other",
    ])
    us_citizen = st.selectbox("Are you a U.S. citizen?", ["", "Yes", "No"])
    marital_status = st.selectbox("Marital status", [
        "", "Never been married", "Married", "Living with a partner",
        "Divorced", "Separated", "Widowed",
    ])
    religion = st.selectbox("Religion", [
        "", "Protestant", "Roman Catholic", "Mormon", "Orthodox", "Jewish",
        "Muslim", "Buddhist", "Atheist", "Agnostic", "Nothing in particular", "Other",
    ])
    religious_attendance = st.selectbox("How often do you attend religious services?", [
        "", "Never", "Seldom", "A few times a year", "Once or twice a month",
        "Once a week", "More than once a week",
    ])
    political_affiliation = st.selectbox("Political party affiliation", [
        "", "Democrat", "Republican", "Independent", "Something else",
    ])
    income = st.selectbox("Annual household income", [
        "", "Less than $30,000", "$30,000-$50,000", "$50,000-$75,000",
        "$75,000-$100,000", "$100,000 or more",
    ])
    political_views = st.selectbox("Political views", [
        "", "Very liberal", "Liberal", "Moderate", "Conservative", "Very conservative",
    ])
    household_size = st.selectbox("Household size (number of people)", [
        "", "1", "2", "3", "4", "More than 4",
    ])
    employment_status = st.selectbox("Employment status", [
        "", "Full-time employment", "Part-time employment", "Self-employed",
        "Unemployed", "Retired", "Home-maker", "Student",
    ])

    if st.button("Next β†’", type="primary", use_container_width=True):
        all_fields = [
            age, gender, geographic_region, education_level, race,
            us_citizen, marital_status, religion, religious_attendance,
            political_affiliation, income, political_views,
            household_size, employment_status,
        ]
        if not all(str(f).strip() for f in all_fields):
            st.error("⚠️ Please complete all fields before continuing.")
            return
        if not age.strip().isdigit() or not (1 <= int(age.strip()) <= 120):
            st.error("⚠️ Please enter a valid age (a number between 1 and 120).")
            return

        s["demographics"] = {
            "age": age.strip(), "gender": gender,
            "geographic_region": geographic_region, "education_level": education_level,
            "race": race, "us_citizen": us_citizen, "marital_status": marital_status,
            "religion": religion, "religious_attendance": religious_attendance,
            "political_affiliation": political_affiliation, "income": income,
            "political_views": political_views, "household_size": household_size,
            "employment_status": employment_status,
        }
        s["screen"] = "background"
        st.rerun()


# ── Background questions ──────────────────────────────────────────────────────

def screen_background(s: dict, cfg: dict) -> None:
    """Render free-text background questions for each active category.

    Questions come from ``BACKGROUND_QUESTIONS`` for every category in
    ``cfg["categories"]``. On "Next", each answer must be non-empty and at
    least ``MIN_WORDS_BACKGROUND`` words long. The first failing question is
    reported and the screen stays put. On success the stripped answers are
    stored in ``s["background"]`` and the screen advances to ``"item_intro"``.

    Args:
        s: Mutable session-state dict (``"background"`` and ``"screen"``
            are written).
        cfg: Study config (reads ``categories``).
    """
    st.markdown("## Your Preferences β€” Before We Start")
    st.markdown(
        "Before you evaluate any products, we'd like to understand your general preferences. "
        f"Please write **at least {MIN_WORDS_BACKGROUND} words** for each question."
    )

    category_names = [c["name"] for c in cfg["categories"]]
    responses = {}

    for position, category in enumerate(category_names):
        # A divider separates consecutive category sections.
        if position > 0:
            st.markdown('<hr class="section-divider">', unsafe_allow_html=True)

        # Any category other than "movies" gets the grocery heading.
        heading = (
            '<div class="section-heading">🎬 Movies &amp; TV</div>'
            if category == "movies"
            else '<div class="section-heading-grocery">πŸ›’ Grocery &amp; Food Products</div>'
        )
        st.markdown(heading, unsafe_allow_html=True)

        for question in BACKGROUND_QUESTIONS[category]:
            responses[question["key"]] = st.text_area(
                question["label"],
                placeholder=question["placeholder"],
                height=100,
                key=f"bg_{question['key']}",
            )

    if not st.button("Next β†’", type="primary", use_container_width=True):
        return

    questions = [q for name in category_names for q in BACKGROUND_QUESTIONS[name]]
    for question in questions:
        answer = (responses.get(question["key"]) or "").strip()
        if not answer:
            st.error(f"⚠️ Please answer: *{question['label']}*")
            return
        words = len(answer.split())
        if words < MIN_WORDS_BACKGROUND:
            plural = "s" if words != 1 else ""
            st.error(
                f"⚠️ Please write at least {MIN_WORDS_BACKGROUND} words for: "
                f"*{question['label']}* ({words} word{plural} so far)."
            )
            return

    s["background"] = {q["key"]: responses[q["key"]].strip() for q in questions}
    s["screen"] = "item_intro"
    st.rerun()


# ── Chat (shared for both study types) ───────────────────────────────────────

def screen_chat(s: dict, cfg: dict) -> None:
    """Render the chat screen for the current item and run one exchange per Send.

    Shows progress, a collapsible compact product view, the conversation so
    far and, while fewer than ``cfg["max_turns"]`` exchanges are recorded, a
    message box.

    On Send, the message must be non-empty and (outside debug mode) at least
    5 words. The system prompt, every stored turn and the new message are then
    sent to ``call_model``. The item's ``model_name`` / ``sampler_path``
    override the study-wide config for that call. The user turn and the
    assistant reply are appended to ``conversation["turns"]`` and
    ``conversation["num_turns"]`` is incremented.

    The "I'm done chatting" button advances to ``"post_rating"`` once
    ``cfg["min_turns"]`` exchanges are complete, or at any time in debug mode.

    Args:
        s: Mutable session-state dict. Reads ``current_index`` and ``items``.
            Mutates the current item's ``conversation`` and may set ``screen``.
        cfg: Study config (reads ``study_type``, ``min_turns``, ``max_turns``,
            ``pairs_per_user``, ``debug_mode``; also forwarded to ``call_model``).
    """
    idx        = s["current_index"]
    item       = s["items"][idx]
    conv       = item["conversation"]
    study_type = cfg["study_type"]
    min_turns  = cfg["min_turns"]
    max_turns  = cfg["max_turns"]
    num_turns  = conv["num_turns"]

    render_progress(idx + 1, cfg["pairs_per_user"])
    st.markdown("## Chat with the AI")

    # Compact product reminder in a collapsible expander
    with st.expander("πŸ“¦ View product details"):
        if study_type == "preference":
            render_pair_cards(item, compact=True)
        else:
            render_single_card(item["product"], compact=True)

    st.markdown(
        f"Chat with the AI product agent about the "
        f"{'products' if study_type == 'preference' else 'product'}. "
        "Ask questions, push back, or share your thoughts. "
        f"You need **exactly {min_turns} exchanges** before you can move on."
    )

    st.markdown(
        '<div style="background:#fef9c3;border:1px solid #fde047;border-radius:8px;'
        'padding:0.5rem 0.9rem;margin-bottom:0.5rem;font-size:0.88rem;color:#713f12;">'
        'πŸ’¬ New messages appear at the bottom of the chat β€” scroll down to see the latest response.'
        '</div>',
        unsafe_allow_html=True,
    )

    render_chat_history(conv["turns"], study_type)

    if num_turns >= max_turns:
        st.info("βœ… You have completed all required exchanges. Please proceed.")
    else:
        st.caption(f"Exchanges completed: {num_turns} / {max_turns}")

    # Input area (hidden once max_turns reached)
    if num_turns < max_turns:
        # The key includes num_turns, so each exchange gets a new (empty) widget.
        user_msg = st.text_area(
            "Your response:",
            placeholder="Type your message here…",
            height=100,
            key=f"chat_input_{idx}_{num_turns}",
        )
        _, col_btn = st.columns([4, 1])
        with col_btn:
            send_clicked = st.button("Send β†’", type="primary", use_container_width=True)

        if send_clicked:
            user_msg = (user_msg or "").strip()
            if not user_msg:
                st.error("⚠️ Please type a message before sending.")
                return
            word_count = len(user_msg.split())
            if word_count < 5 and not cfg["debug_mode"]:
                st.error(
                    f"⚠️ Please write at least 5 words "
                    f"({word_count} word{'s' if word_count != 1 else ''} so far)."
                )
                return

            # Full message list: system prompt + all stored turns + new user message
            messages = [{"role": "system", "content": conv["system_prompt"]}]
            messages.extend(
                {"role": t["role"], "content": t["content"]} for t in conv["turns"]
            )
            messages.append({"role": "user", "content": user_msg})

            # Per-item model settings override the study-wide config.
            model_name = item.get("model_name", "")
            item_cfg = {
                **cfg,
                "model_name": model_name,
                "sampler_path": item.get("sampler_path", ""),
            }

            # Stamp the user turn when it is sent, so its timestamp is not
            # shifted by however long the model takes to reply.
            sent_at = time.time()
            with st.spinner("AI is responding…"):
                ai_reply = call_model(messages, item_cfg)
            replied_at = time.time()

            turn_base = len(conv["turns"])
            conv["turns"].append({
                "turn_index": turn_base,
                "role":       "user",
                "content":    user_msg,
                "timestamp":  sent_at,
            })
            conv["turns"].append({
                "turn_index": turn_base + 1,
                "role":       "assistant",
                "content":    ai_reply,
                "timestamp":  replied_at,
                "model":      model_name,   # per-item, not cfg
            })
            conv["num_turns"]               = num_turns + 1
            s["items"][idx]["conversation"] = conv
            st.rerun()

    # Done button — enabled only when min_turns reached (or debug mode)
    can_proceed = num_turns >= min_turns or cfg["debug_mode"]
    if can_proceed:
        if st.button("I'm done chatting β†’", use_container_width=True):
            s["screen"] = "post_rating"
            st.rerun()
    else:
        remaining = min_turns - num_turns
        st.button(
            "I'm done chatting β†’",
            disabled=True,
            use_container_width=True,
            help=f"Complete {remaining} more exchange{'s' if remaining != 1 else ''} first.",
        )


# ── Post-rating ───────────────────────────────────────────────────────────────

def screen_post_rating(s: dict, cfg: dict) -> None:
    """Collect the participant's rating for the current item after the chat.

    On "Next", the selected choice is parsed and stored as the item's
    ``post_rating``. ``rating_delta`` (post minus pre) is also stored, and the
    screen advances to ``"reflection"``. If the stored ``pre_rating`` is
    missing or not an ``int``, 4 is used as the baseline for the delta.

    Args:
        s: Mutable session-state dict (current item and ``"screen"`` are
            written).
        cfg: Study config (reads ``study_type``, ``pairs_per_user``,
            ``debug_mode``).
    """
    idx = s["current_index"]
    item = s["items"][idx]
    kind = cfg["study_type"]

    render_progress(idx + 1, cfg["pairs_per_user"])
    st.markdown("## Your Rating After the Conversation")
    st.markdown("Now that you have chatted with the AI, please rate again.")

    if kind == "preference":
        render_pair_cards(item)
        prompt = "Which product would you **prefer to buy** now?"
    else:
        render_single_card(item["product"])
        prompt = "How **likely** are you to purchase this product now?"

    options = rating_choices(kind)
    selection = st.radio(prompt, options, index=None, key=f"post_rating_{idx}")

    if not st.button("Next β†’", type="primary", use_container_width=True):
        return
    if not selection and not cfg["debug_mode"]:
        st.error("⚠️ Please select a rating before continuing.")
        return

    # Debug mode may reach here with nothing selected; fall back to the option
    # at index 3, the neutral midpoint of the 1–7 scale.
    post = parse_rating(selection or options[3])
    pre = item.get("pre_rating", 4)
    baseline = pre if isinstance(pre, int) else 4

    entry = s["items"][idx]
    entry["post_rating"] = post
    entry["rating_delta"] = post - baseline
    s["screen"] = "reflection"
    st.rerun()


# ── Reflection ────────────────────────────────────────────────────────────────

def screen_reflection(s: dict, cfg: dict) -> None:
    """Ask the two post-chat reflection questions for the current item.

    Outside debug mode, each answer must be non-empty and at least
    ``MIN_WORDS_REFLECTION`` words. The answers are stored on the item (as
    ``"[debug]"`` when left empty) and ``current_index`` is advanced. After
    the last item, ``s["meta"]`` is filled in, ``save_and_upload`` is called
    and the screen becomes ``"done"``. Otherwise the next item starts at
    ``"item_intro"``.

    Args:
        s: Mutable session-state dict (current item, ``current_index``,
            ``meta`` and ``screen`` may be written; ``start_time`` is read).
        cfg: Study config (reads ``pairs_per_user``, ``debug_mode``,
            ``study_type``; forwarded to ``save_and_upload``).
    """
    idx = s["current_index"]
    total = cfg["pairs_per_user"]

    render_progress(idx + 1, total)
    st.markdown("## Reflection")
    st.markdown("Please write at least a sentence or two for each question.")

    standout = st.text_area(
        "What did the AI say that stood out to you most?",
        placeholder="Describe a specific argument, question, or moment from the conversation…",
        height=120,
        key=f"standout_{idx}",
    )
    thinking_change = st.text_area(
        "How did your thinking about this product change (or not change) during the chat? Why?",
        placeholder="Be as specific as you can…",
        height=120,
        key=f"thinking_{idx}",
    )

    label = "Submit Study β†’" if idx + 1 >= total else "Next β†’"
    if not st.button(label, type="primary", use_container_width=True):
        return

    if not cfg["debug_mode"]:
        for ordinal, raw in (("first", standout), ("second", thinking_change)):
            text = (raw or "").strip()
            if not text:
                st.error(f"⚠️ Please answer the {ordinal} reflection question.")
                return
            count = len(text.split())
            if count < MIN_WORDS_REFLECTION:
                st.error(
                    f"⚠️ Please write at least {MIN_WORDS_REFLECTION} words for the "
                    f"{ordinal} question ({count} word{'s' if count != 1 else ''} so far)."
                )
                return

    s["items"][idx]["reflection"] = {
        "standout_moment": (standout or "").strip() or "[debug]",
        "thinking_change": (thinking_change or "").strip() or "[debug]",
    }

    following = idx + 1
    s["current_index"] = following

    if following < total:
        s["screen"] = "item_intro"
    else:
        finished_at = time.time()
        s["meta"] = {
            "submission_time":  finished_at,
            "duration_seconds": round(finished_at - s.get("start_time", finished_at), 1),
            "study_type":       cfg["study_type"],
        }
        with st.spinner("Saving your responses…"):
            save_and_upload(s, cfg)
        s["screen"] = "done"
    st.rerun()


# ── Done ──────────────────────────────────────────────────────────────────────

def screen_done(s: dict, cfg: dict) -> None:
    import pandas as pd

    study_type = cfg["study_type"]
    n          = cfg["pairs_per_user"]
    code       = cfg.get("prolific_completion_code", "")
    labels     = PREFERENCE_LABELS if study_type == "preference" else LIKELIHOOD_LABELS

    st.markdown("## βœ… Study Complete β€” Thank You!")
    st.markdown(
        f"You have finished all {n} {'pairs' if study_type == 'preference' else 'products'}. "
        "Here is a summary of how your ratings changed:"
    )

    rows = []
    for i, item in enumerate(s["items"]):
        pre   = item.get("pre_rating",   None)
        post  = item.get("post_rating",  None)
        delta = item.get("rating_delta", 0)
        arrow = "➑️" if delta == 0 else ("⬆️" if (delta or 0) > 0 else "⬇️")
        cat   = CATEGORY_DISPLAY.get(item.get("category", ""), "")

        if study_type == "preference":
            rows.append({
                "#":           i + 1,
                "Category":    cat,
                "Product A":   (item.get("product_a", {}).get("title", "") or "")[:45] + "…",
                "Product B":   (item.get("product_b", {}).get("title", "") or "")[:45] + "…",
                "Pre-rating":  labels.get(pre,  str(pre)  if pre  is not None else "β€”"),
                "Post-rating": labels.get(post, str(post) if post is not None else "β€”"),
                "Shift":       f"{arrow} {delta:+d}" if isinstance(delta, int) else "β€”",
            })
        else:
            rows.append({
                "#":           i + 1,
                "Category":    cat,
                "Product":     (item.get("product", {}).get("title", "") or "")[:65] + "…",
                "Pre-rating":  labels.get(pre,  str(pre)  if pre  is not None else "β€”"),
                "Post-rating": labels.get(post, str(post) if post is not None else "β€”"),
                "Shift":       f"{arrow} {delta:+d}" if isinstance(delta, int) else "β€”",
            })

    st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)
    st.markdown("---")
    st.success(
        f"**Your Prolific completion code: `{code}`**\n\n"
        "Please copy this code and paste it into the Prolific website to complete your submission."
    )