"""
Dataset download, item-pool caching, completion-aware assignment, and session-state init.

Assignment strategy
-------------------
Items are assigned based on how many *accepted* completions they already have,
ensuring the least-covered items are always prioritised.
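
A minimal sketch of that ordering, assuming counts are keyed by stringified pool
index and reserved indices are passed as a set (the helper name
order_by_coverage is illustrative and not part of this module):

```python
def order_by_coverage(accepted_counts, reserved_indices, pool_size):
    # Sort pool indices by (accepted count, reserved flag): least-covered items
    # come first, and among equally covered items unreserved ones win.
    def sort_key(i):
        return (accepted_counts.get(str(i), 0), int(i in reserved_indices))
    return sorted(range(pool_size), key=sort_key)


counts = {"0": 2, "1": 0, "2": 0, "3": 1}
order = order_by_coverage(counts, reserved_indices={1}, pool_size=4)
print(order)  # [2, 1, 3, 0]
```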

Each assigned item is stamped with _pool_index and _pool_category at assignment
time, so record_completion never needs to do a fuzzy pair_id match: it reads
the index directly.

Accepted completions = JSON files under json/ in the output repo.
Rejected completions = JSON files moved to rejected/ by the admin.
  → moving a file to rejected/ automatically makes that item available again.
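
As a rough illustration of the accepted/rejected split, the sketch below keeps
only paths under json/. The file names are made up for the example, and
accepted_paths is a hypothetical helper rather than a function in this module:

```python
def accepted_paths(repo_files):
    # Only JSON files under json/ count as accepted; anything under rejected/
    # (or elsewhere in the repo) is ignored.
    return [f for f in repo_files if f.startswith("json/") and f.endswith(".json")]


repo_files = [
    "json/worker1/a.json",
    "rejected/worker2/b.json",
    "json/worker3/c.json",
    "README.md",
]
kept = accepted_paths(repo_files)
print(kept)  # ['json/worker1/a.json', 'json/worker3/c.json']
```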

Reservations
------------
When a user starts, their items are "reserved" in a local file for 80 min.
Concurrent users each get a FileLock on the reservation file so they
never receive the same items. Reservations expire automatically so abandoned
sessions don't permanently block items.

Each reservation stores the user's prolific_pid so we can release their items
immediately when Prolific reports them as RETURNED or TIMED-OUT, with no need
to wait for the 80-min TTL.
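
The reservation lifecycle can be sketched as follows. This is a simplified
in-memory version that takes an explicit now argument and leaves out the
FileLock and the JSON file; reserve and prune are hypothetical helper names:

```python
RESERVATION_TTL = 60 * 80  # seconds (80 min)


def reserve(reservations, indices, user_id, prolific_pid, now):
    # Stamp each reserved index with its owner and an absolute expiry time.
    for i in indices:
        reservations[str(i)] = {
            "user_id": user_id,
            "prolific_pid": prolific_pid,
            "expiry": now + RESERVATION_TTL,
        }


def prune(reservations, now, returned_pids=frozenset()):
    # Drop reservations that have expired or whose participant has returned.
    return {
        k: r for k, r in reservations.items()
        if r["expiry"] >= now and r["prolific_pid"] not in returned_pids
    }


res = {}
reserve(res, [0, 1], "user-1", "pidA", now=1000.0)
reserve(res, [2], "user-2", "pidB", now=1600.0)

after_return = prune(res, now=1700.0, returned_pids={"pidB"})
after_ttl = prune(res, now=1000.0 + RESERVATION_TTL + 1)
print(sorted(after_return))  # ['0', '1']
print(sorted(after_ttl))     # ['2']
```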

Dropout / rejection recovery
-----------------------------
- Dropout (voluntary return): Prolific marks RETURNED, we query the API and
  release the reservation on the next assignment.
- Dropout (silent): reservation expires after 80 min → item re-enters pool.
- Rejection: admin moves json/{worker}/{id}.json → rejected/{worker}/{id}.json
  in the HF dataset repo. On the next Space restart (or cache expiry) the item's
  accepted count drops to 0 and it gets re-assigned.
"""
import json
import random
import time
import uuid
from pathlib import Path

import streamlit as st
from filelock import FileLock

from src.config import CATEGORY_TO_REPO

POOL_SIZE               = 50        # items selected per (study_type, category)
RESERVATION_TTL         = 60 * 80   # 80 min: ~2.7x the expected 30-min session
COMPLETION_CACHE_TTL    = 300       # re-scan HF repo every 5 minutes
PROLIFIC_POLL_CACHE_TTL = 120       # re-poll Prolific every 2 minutes


# ── Path helpers ──────────────────────────────────────────────────────────────

def _data_dir(cfg: dict) -> Path:
    p = Path(cfg["data_dir"])
    p.mkdir(parents=True, exist_ok=True)
    return p


def _pool_path(category: str, cfg: dict) -> Path:
    return _data_dir(cfg) / f"pool_{cfg['study_type']}_{category}.json"


def _reservation_path(cfg: dict) -> Path:
    return _data_dir(cfg) / "reservations.json"


def _reservation_lock_path(cfg: dict) -> Path:
    return _data_dir(cfg) / "reservations.lock"


def _local_completions_path(category: str, cfg: dict) -> Path:
    """
    Local file tracking completed item counts this container session.
    Updated immediately on each completion so subsequent assignments
    see accurate counts without waiting for an HF re-scan.
    Reset on container restart; HF is the durable source of truth.
    """
    return _data_dir(cfg) / f"local_completions_{cfg['study_type']}_{category}.json"


# ── Dataset download + normalisation ─────────────────────────────────────────

@st.cache_resource
def _download_and_cache(
    study_type: str,
    category: str,
    seed: int,
    hf_token: str,
    data_dir: str,
) -> None:
    pool_path = Path(data_dir) / f"pool_{study_type}_{category}.json"
    if pool_path.exists():
        print(f"[DATA] Pool already cached: {pool_path}")
        return

    from datasets import load_dataset

    repo_id   = CATEGORY_TO_REPO[(study_type, category)]
    token_arg = hf_token or None
    print(f"[DATA] Downloading {repo_id} …")

    ds = load_dataset(repo_id, token=token_arg, trust_remote_code=True)

    if study_type == "preference":
        if "test" in ds:
            rows = [dict(r) for r in ds["test"]]
        else:
            rows = [dict(r) for r in ds["train"] if r.get("split") == "test"]
    else:
        split_key = "test" if "test" in ds else list(ds.keys())[0]
        rows = [dict(r) for r in ds[split_key]]

    rng = random.Random(seed)
    rng.shuffle(rows)
    selected = rows[:POOL_SIZE]

    if study_type == "likelihood":
        normalised = []
        for i, row in enumerate(selected):
            meta = row["metadata"]
            if isinstance(meta, str):
                meta = json.loads(meta)
            else:
                meta = dict(meta)
            meta["item_id"]  = str(uuid.uuid5(uuid.NAMESPACE_DNS, f"{repo_id}_{i}_{seed}"))
            meta["category"] = category
            normalised.append(meta)
        selected = normalised
    else:
        cleaned = []
        for row in selected:
            r = dict(row)
            r["product_a"] = dict(r["product_a"])
            r["product_b"] = dict(r["product_b"])
            r["product_a"].setdefault("category", r.get("category", category))
            r["product_b"].setdefault("category", r.get("category", category))
            cleaned.append(r)
        selected = cleaned

    pool_path.parent.mkdir(parents=True, exist_ok=True)
    with open(pool_path, "w") as f:
        json.dump(selected, f, indent=2)

    print(f"[DATA] {study_type}/{category}: cached {len(selected)} items (seed={seed}).")


def ensure_datasets(cfg: dict) -> None:
    for cat_cfg in cfg["categories"]:
        _download_and_cache(
            study_type=cfg["study_type"],
            category=cat_cfg["name"],
            seed=cfg["pair_selection_seed"],
            hf_token=cfg.get("hf_token", ""),
            data_dir=cfg["data_dir"],
        )


@st.cache_data
def _load_pool(pool_path_str: str) -> list:
    with open(pool_path_str) as f:
        return json.load(f)


# ── Accepted completion counts ────────────────────────────────────────────────

def _get_accepted_counts(category: str, cfg: dict) -> dict:
    """
    Return how many times each pool item has been accepted.

    Sources (merged, highest count wins):
    1. Local completions file: written immediately on each completion this session.
    2. HF output repo scan: authoritative after a container restart.
       Results cached for COMPLETION_CACHE_TTL seconds.

    Rejected submissions live under rejected/ and are NOT counted.
    """
    pool   = _load_pool(str(_pool_path(category, cfg)))
    counts = {str(i): 0 for i in range(len(pool))}

    # ── Source 1: local completions (most up-to-date within this session) ────
    local_path = _local_completions_path(category, cfg)
    if local_path.exists():
        try:
            with open(local_path) as f:
                local = json.load(f)
            for k, v in local.items():
                counts[k] = max(counts.get(k, 0), v)
            print(f"[ASSIGN] Local completions for {category}: "
                  f"{sum(1 for v in local.values() if v > 0)} items completed")
        except Exception as e:
            print(f"[ASSIGN] Could not read local completions: {e}")

    # ── Source 2: HF scan (authoritative after restart, with 5-min cache) ───
    cache_path = _data_dir(cfg) / f"completion_cache_{cfg['study_type']}_{category}.json"
    now        = time.time()
    hf_counts  = None

    if cache_path.exists():
        try:
            with open(cache_path) as f:
                cache = json.load(f)
            if now - cache.get("timestamp", 0) < COMPLETION_CACHE_TTL:
                hf_counts = cache["counts"]
        except Exception:
            pass

    if hf_counts is None:
        hf_counts   = {str(i): 0 for i in range(len(pool))}
        hf_token    = cfg.get("hf_token", "")
        output_repo = cfg.get("output_dataset_repo", "")
        if hf_token and output_repo:
            try:
                from huggingface_hub import HfApi
                api        = HfApi(token=hf_token)
                files      = list(api.list_repo_files(repo_id=output_repo, repo_type="dataset"))
                json_files = [f for f in files if f.startswith("json/") and f.endswith(".json")]

                # Build pair_id → pool_index lookup for fallback matching
                id_to_index = {}
                for i, p in enumerate(pool):
                    pid = p.get("pair_id") or p.get("item_id", "")
                    if pid:
                        id_to_index[pid] = i

                for filepath in json_files:
                    try:
                        content = api.hf_hub_download(
                            repo_id=output_repo,
                            filename=filepath,
                            repo_type="dataset",
                            token=hf_token,
                        )
                        with open(content) as f:
                            submission = json.load(f)
                        for item in submission.get("items", []):
                            if item.get("category") != category:
                                continue
                            idx = item.get("_pool_index")
                            if idx is None:
                                pid = item.get("pair_id") or item.get("item_id", "")
                                idx = id_to_index.get(pid)
                            if idx is not None:
                                hf_counts[str(idx)] = hf_counts.get(str(idx), 0) + 1
                    except Exception as e:
                        print(f"[ASSIGN] Could not parse {filepath}: {e}")
            except Exception as e:
                print(f"[ASSIGN] Could not scan HF repo: {e}")
        try:
            with open(cache_path, "w") as f:
                json.dump({"timestamp": now, "counts": hf_counts}, f)
        except Exception:
            pass

    for k, v in hf_counts.items():
        counts[k] = max(counts.get(k, 0), v)

    return counts


# ── Reservation management ────────────────────────────────────────────────────

def _load_reservations(cfg: dict) -> dict:
    path = _reservation_path(cfg)
    if not path.exists():
        return {}
    try:
        with open(path) as f:
            return json.load(f)
    except Exception:
        return {}


def _save_reservations(reservations: dict, cfg: dict) -> None:
    with open(_reservation_path(cfg), "w") as f:
        json.dump(reservations, f)


def _expire_reservations(reservations: dict) -> dict:
    now     = time.time()
    expired = [k for k, v in reservations.items() if v["expiry"] < now]
    for k in expired:
        print(f"[ASSIGN] Reservation expired for item index {k}")
        del reservations[k]
    return reservations


def release_reservation(user_id: str, cfg: dict) -> None:
    """Release all reservations held by this user immediately after completion."""
    lock = FileLock(str(_reservation_lock_path(cfg)), timeout=10)
    with lock:
        reservations = _load_reservations(cfg)
        _expire_reservations(reservations)
        released = [k for k, v in reservations.items() if v["user_id"] == user_id]
        for k in released:
            del reservations[k]
        _save_reservations(reservations, cfg)
        print(f"[ASSIGN] Released {len(released)} reservations for user {user_id}")


def record_completion(user_id: str, items: list, cfg: dict) -> None:
    """
    Record completed item indices to the local completions file immediately.
    Uses _pool_index stamped on each item at assignment time; no fuzzy matching.
    Called after successful HF upload AND by the simulation script.
    """
    by_category: dict = {}
    for item in items:
        cat = item.get("_pool_category") or item.get("category", "")
        idx = item.get("_pool_index")
        if idx is None:
            print(f"[ASSIGN] WARNING: item missing _pool_index, skipping: "
                  f"{item.get('pair_id') or item.get('item_id', '?')}")
            continue
        by_category.setdefault(cat, []).append(idx)

    for cat, indices in by_category.items():
        pool             = _load_pool(str(_pool_path(cat, cfg)))
        completions_path = _local_completions_path(cat, cfg)

        if completions_path.exists():
            try:
                with open(completions_path) as f:
                    completions = json.load(f)
            except Exception:
                completions = {str(i): 0 for i in range(len(pool))}
        else:
            completions = {str(i): 0 for i in range(len(pool))}

        for idx in indices:
            completions[str(idx)] = completions.get(str(idx), 0) + 1

        with open(completions_path, "w") as f:
            json.dump(completions, f)

        # Invalidate HF cache so next scan re-reads fresh
        cache_path = _data_dir(cfg) / f"completion_cache_{cfg['study_type']}_{cat}.json"
        if cache_path.exists():
            try:
                cache_path.unlink()
            except Exception:
                pass

        print(f"[ASSIGN] Recorded completions for {cat}: indices {indices} "
              f"(user {user_id[:8]})")


# ── Prolific status polling ───────────────────────────────────────────────────

def _prolific_returned_pids(cfg: dict) -> set:
    """
    Query Prolific for participants who have RETURNED or TIMED-OUT from the
    active study. Returns a set of their PIDs. Cached for PROLIFIC_POLL_CACHE_TTL.
    """
    token    = cfg.get("prolific_api_token", "")
    study_id = cfg.get("prolific_study_id", "")
    if not token or not study_id:
        return set()

    cache_path = _data_dir(cfg) / "prolific_returned_cache.json"
    now        = time.time()

    if cache_path.exists():
        try:
            with open(cache_path) as f:
                c = json.load(f)
            if now - c.get("timestamp", 0) < PROLIFIC_POLL_CACHE_TTL:
                return set(c.get("returned_pids", []))
        except Exception:
            pass

    returned = set()
    try:
        import requests
        url     = f"https://api.prolific.com/api/v1/studies/{study_id}/submissions/"
        headers = {"Authorization": f"Token {token}"}
        resp    = requests.get(url, headers=headers, timeout=10)
        resp.raise_for_status()
        for sub in resp.json().get("results", []):
            status = sub.get("status", "")
            if status in ("RETURNED", "TIMED-OUT", "TIMED_OUT"):
                pid = sub.get("participant_id") or sub.get("participant", "")
                if pid:
                    returned.add(pid)
        print(f"[PROLIFIC] Found {len(returned)} returned/timed-out participants")
    except Exception as e:
        print(f"[PROLIFIC] Could not query API: {e}")

    try:
        with open(cache_path, "w") as f:
            json.dump({"timestamp": now, "returned_pids": list(returned)}, f)
    except Exception:
        pass

    return returned


def _release_returned_reservations(reservations: dict, cfg: dict) -> None:
    """
    Remove reservations held by Prolific participants who have RETURNED or
    TIMED-OUT. Mutates the reservations dict in place.
    """
    returned_pids = _prolific_returned_pids(cfg)
    if not returned_pids:
        return

    released = []
    for idx, r in list(reservations.items()):
        pid = r.get("prolific_pid", "")
        if pid and pid in returned_pids:
            released.append(idx)
            del reservations[idx]
    if released:
        print(f"[ASSIGN] Released {len(released)} reservations from returned/timed-out participants: {released}")


def all_items_covered(cfg: dict) -> bool:
    """
    Returns True if every item in every category has been accepted at least once.
    Used for auto-pausing the Prolific study.
    """
    for cat_cfg in cfg["categories"]:
        cat   = cat_cfg["name"]
        pool  = _load_pool(str(_pool_path(cat, cfg)))
        counts = _get_accepted_counts(cat, cfg)
        for i in range(len(pool)):
            if counts.get(str(i), 0) < 1:
                return False
    return True


def pause_prolific_study(cfg: dict) -> bool:
    """
    Call Prolific's API to pause the study. Returns True on success.
    Requires prolific_api_token (env PROLIFIC_API_TOKEN) and prolific_study_id.
    Idempotent: safe to call multiple times (Prolific treats repeated pauses as no-ops).
    """
    token    = cfg.get("prolific_api_token", "")
    study_id = cfg.get("prolific_study_id", "")
    if not token or not study_id:
        print("[PROLIFIC] Cannot auto-pause: no API token or study_id configured")
        return False

    # Idempotency marker so we don't spam the API on every completion after
    # the first time all items are covered.
    paused_marker = _data_dir(cfg) / ".prolific_paused"
    if paused_marker.exists():
        return True

    try:
        import requests
        url     = f"https://api.prolific.com/api/v1/studies/{study_id}/transition/"
        headers = {"Authorization": f"Token {token}", "Content-Type": "application/json"}
        resp    = requests.post(url, headers=headers, json={"action": "PAUSE"}, timeout=10)
        resp.raise_for_status()
        paused_marker.touch()
        print(f"[PROLIFIC] ✅ Study {study_id} paused automatically: all items covered.")
        return True
    except Exception as e:
        print(f"[PROLIFIC] Could not auto-pause study: {e}")
        return False


# ── Core assignment ───────────────────────────────────────────────────────────

def _assign_from_category(category: str, n: int, user_id: str, cfg: dict) -> list:
    """
    Assign n items using least-coverage-first strategy.

    Priority order (via sort key):
      1. Uncovered + unreserved         (count=0, not reserved)
      2. Uncovered + reserved by other  (count=0, reserved)
      3. Covered   + unreserved         (count>0, not reserved)
      4. Covered   + reserved by other  (count>0, reserved)

    Reservations are ONLY created for participants who come via Prolific
    (i.e. have a non-empty prolific_pid in the URL). Non-Prolific visitors
    (testers, previewers, direct-URL visitors) still get items assigned so
    they can run through the study, but they don't hold reservations.

    Reservations from participants who have RETURNED/TIMED-OUT on Prolific
    are released BEFORE the sort, so their items are treated as unreserved.
    """
    pool            = _load_pool(str(_pool_path(category, cfg)))
    accepted_counts = _get_accepted_counts(category, cfg)
    lock            = FileLock(str(_reservation_lock_path(cfg)), timeout=10)

    # Capture prolific_pid early so we can decide whether to reserve.
    # Read from query_params directly, because session_state.study_state doesn't
    # exist yet during init_state, which is what calls this function.
    prolific_pid = ""
    try:
        params = st.query_params
        prolific_pid = params.get("PROLIFIC_PID", "") or ""
    except Exception:
        pass
    is_prolific = bool(prolific_pid)

    with lock:
        reservations = _load_reservations(cfg)
        _expire_reservations(reservations)
        _release_returned_reservations(reservations, cfg)

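        # If this Prolific PID already has reservations from an earlier session
        # (e.g. they refreshed the tab, got a new user_id, and came back),
        # release those before creating new ones so the same participant does
        # not accumulate multiple reservations. Reservations made under the
        # current user_id are kept: assign_items calls this function once per
        # category, and a later call must not drop an earlier call's items.
        if is_prolific:
            stale = [
                idx for idx, r in list(reservations.items())
                if r.get("prolific_pid") == prolific_pid
                and r.get("user_id") != user_id
            ]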
            for idx in stale:
                del reservations[idx]
            if stale:
                print(f"[ASSIGN] Released {len(stale)} prior reservations "
                      f"for returning PID {prolific_pid}")

        def is_reserved_by_other(i):
            r = reservations.get(str(i))
            return r is not None and r["user_id"] != user_id

        def sort_key(i):
            count    = accepted_counts.get(str(i), 0)
            reserved = int(is_reserved_by_other(i))
            return (count, reserved)

        all_indices      = sorted(range(len(pool)), key=sort_key)
        selected_indices = all_indices[:n]

        # Only reserve if this is a Prolific participant. This keeps the
        # admin "in progress" count accurate and stops testers/bouncers
        # from blocking items for real users.
        if is_prolific:
            expiry = time.time() + RESERVATION_TTL
            for i in selected_indices:
                reservations[str(i)] = {
                    "user_id":      user_id,
                    "prolific_pid": prolific_pid,
                    "expiry":       expiry,
                }
            _save_reservations(reservations, cfg)
            print(f"[ASSIGN] Reserved for Prolific PID {prolific_pid}")
        else:
            print("[ASSIGN] Non-Prolific visitor: no reservation created")

    selected = []
    for i in selected_indices:
        item = dict(pool[i])
        item["_pool_index"]    = i
        item["_pool_category"] = category
        selected.append(item)

    print(f"[ASSIGN] {category}: assigned indices {selected_indices} "
          f"(counts: {[accepted_counts.get(str(i), 0) for i in selected_indices]})")
    return selected


# ── Variant assignment ────────────────────────────────────────────────────────

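def _assign_variants(cfg: dict, n: int) -> list:
    """
    Return the per-item model variants for one participant.

    With no variants configured, or exactly one, the same variant is repeated
    n times. With two or more, only the first two are used: they are interleaved
    according to each variant's "count", and the starting variant alternates
    between successive participants via a file-backed counter.
    """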
    variants = cfg.get("model_variants")
    if not variants:
        return [{"name": "default",
                 "model_name":     cfg["model_name"],
                 "prompt_variant": cfg["prompt_variant"]}] * n

    if len(variants) == 1:
        return [variants[0]] * n

    lock = FileLock(str(_data_dir(cfg) / "variant_counter.lock"), timeout=10)
    with lock:
        counter_path = _data_dir(cfg) / "variant_counter.txt"
        ctr = int(counter_path.read_text().strip()) if counter_path.exists() else 0
        counter_path.write_text(str(ctr + 1))

    v0, v1 = variants[0], variants[1]
    if ctr % 2 == 1:
        v0, v1 = v1, v0

    from itertools import zip_longest
    interleaved = []
    for a, b in zip_longest([v0] * v0["count"], [v1] * v1["count"]):
        if a is not None:
            interleaved.append(a)
        if b is not None:
            interleaved.append(b)

    print(f"[VARIANTS] user {ctr}: {[v['name'] for v in interleaved]}")
    return interleaved


# ── Category count computation ────────────────────────────────────────────────

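def _compute_counts(cfg: dict) -> dict:
    """
    Return {category_name: item_count} for one participant.

    A single category receives all pairs_per_user items. Otherwise the configured
    per-category counts are used, with the first two categories' counts swapped
    for every other participant. If the configured counts do not sum to
    pairs_per_user, the items are split as evenly as possible instead.
    """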
    cats = cfg["categories"]
    n    = cfg["pairs_per_user"]

    if len(cats) == 1:
        return {cats[0]["name"]: n}

    lock = FileLock(str(_data_dir(cfg) / "alternation_counter.lock"), timeout=10)
    with lock:
        path = _data_dir(cfg) / "alternation_counter.txt"
        ctr  = int(path.read_text().strip()) if path.exists() else 0
        path.write_text(str(ctr + 1))

    base = {c["name"]: c["count"] for c in cats}
    if sum(base.values()) != n:
        base = {}
        for i, c in enumerate(cats):
            base[c["name"]] = n // len(cats) + (1 if i < n % len(cats) else 0)
        return base

    if ctr % 2 == 1:
        names = [c["name"] for c in cats]
        base[names[0]], base[names[1]] = base[names[1]], base[names[0]]

    return base


def assign_items(cfg: dict, user_id: str) -> list:
    counts = _compute_counts(cfg)
    items  = []
    for cat_name, n in counts.items():
        items.extend(_assign_from_category(cat_name, n, user_id, cfg))
    random.shuffle(items)
    return items


# ── Item slot construction ────────────────────────────────────────────────────

def _make_item_slot(item: dict, study_type: str) -> dict:
    base = {
        "_pool_index":    item.get("_pool_index"),
        "_pool_category": item.get("_pool_category", item.get("category", "")),
        "conversation": {
            "system_prompt":   "",
            "closing_message": "",
            "turns":           [],
            "num_turns":       0,
        },
        "reflection":   {},
        "pre_rating":   None,
        "post_rating":  None,
        "rating_delta": None,
    }
    if study_type == "preference":
        base.update({
            "pair_id":       item.get("pair_id",  str(uuid.uuid4())),
            "category":      item.get("category", ""),
            "product_a":     item.get("product_a", {}),
            "product_b":     item.get("product_b", {}),
            "familiarity_a": None,
            "familiarity_b": None,
        })
    else:
        base.update({
            "item_id":    item.get("item_id",  str(uuid.uuid4())),
            "category":   item.get("category", ""),
            "product":    item,
            "familiarity": None,
        })
    return base


# ── Session-state construction ────────────────────────────────────────────────

def init_state(cfg: dict) -> dict:
    """Build the initial session-state dict for a new participant."""
    n        = cfg["pairs_per_user"]
    user_id  = str(uuid.uuid4())
    variants = _assign_variants(cfg, n)
    items    = assign_items(cfg, user_id)[:n]

    slots = [_make_item_slot(it, cfg["study_type"]) for it in items]
    for slot, variant in zip(slots, variants):
        slot["model_name"]     = variant["model_name"]
        slot["prompt_variant"] = variant["prompt_variant"]
        slot["sampler_path"]   = variant.get("sampler_path", "")

    for i, slot in enumerate(slots):
        print(f"[ITEM {i}] category={slot.get('category')} "
              f"pool_index={slot.get('_pool_index')} "
              f"model={slot.get('model_name')} "
              f"personalization={slot.get('prompt_variant', {}).get('personalization')}")

    try:
        params = st.query_params
    except Exception:
        params = {}

    return {
        "submission_id": str(uuid.uuid4()),
        "user_id":       user_id,
        "prolific_pid":  params.get("PROLIFIC_PID", ""),
        "study_id":      params.get("STUDY_ID",     ""),
        "session_id":    params.get("SESSION_ID",   ""),
        "start_time":    time.time(),
        "study_type":    cfg["study_type"],
        "demographics":  {},
        "background":    {},
        "items":         slots,
        "current_index": 0,
        "screen":        "welcome",
        "meta":          {},
    }