Spaces:
Running
Running
| """ | |
| Dataset download, item-pool caching, completion-aware assignment, and session-state init. | |
| Assignment strategy | |
| ------------------- | |
| Items are assigned based on how many *accepted* completions they already have, | |
| ensuring the least-covered items are always prioritised. | |
| Each assigned item is stamped with _pool_index and _pool_category at assignment | |
| time so record_completion never needs to do a fuzzy pair_id match β it reads | |
| the index directly. | |
| Accepted completions = JSON files under json/ in the output repo. | |
| Rejected completions = JSON files moved to rejected/ by the admin. | |
| β moving a file to rejected/ automatically makes that item available again. | |
| Reservations | |
| ------------ | |
| When a user starts, their items are "reserved" in a local file for 80 min. | |
| Concurrent users each get a FileLock on the reservation file so they | |
| never receive the same items. Reservations expire automatically so abandoned | |
| sessions don't permanently block items. | |
| Each reservation stores the user's prolific_pid so we can release their items | |
| immediately when Prolific reports them as RETURNED or TIMED-OUT β no need to | |
| wait for the 80-min TTL. | |
| Dropout / rejection recovery | |
| ----------------------------- | |
| - Dropout (voluntary return): Prolific marks RETURNED, we query the API and | |
| release the reservation on the next assignment. | |
| - Dropout (silent): reservation expires after 80 min β item re-enters pool. | |
| - Rejection: admin moves json/{worker}/{id}.json β rejected/{worker}/{id}.json | |
| in the HF dataset repo. On next Space restart (or cache expiry) the item's | |
| accepted count drops to 0 and it gets re-assigned. | |
| """ | |
| import json | |
| import random | |
| import time | |
| import uuid | |
| from pathlib import Path | |
| import streamlit as st | |
| from filelock import FileLock | |
| from src.config import CATEGORY_TO_REPO | |
# Tuning knobs for pool size, reservation lifetime, and cache freshness.
POOL_SIZE = 50                  # items selected per (study_type, category)
RESERVATION_TTL = 60 * 80       # 80 min: 30 min expected + ~2.5x buffer
COMPLETION_CACHE_TTL = 300      # re-scan HF repo every 5 minutes
PROLIFIC_POLL_CACHE_TTL = 120   # re-poll Prolific every 2 minutes
# ── Path helpers ──────────────────────────────────────────────────────────────
| def _data_dir(cfg: dict) -> Path: | |
| p = Path(cfg["data_dir"]) | |
| p.mkdir(parents=True, exist_ok=True) | |
| return p | |
def _pool_path(category: str, cfg: dict) -> Path:
    """Path of the cached item pool for this (study_type, category)."""
    filename = f"pool_{cfg['study_type']}_{category}.json"
    return _data_dir(cfg) / filename
def _reservation_path(cfg: dict) -> Path:
    """Path of the shared reservations file (guarded by the .lock sibling)."""
    return _data_dir(cfg).joinpath("reservations.json")
def _reservation_lock_path(cfg: dict) -> Path:
    """Path of the FileLock file protecting reservations.json."""
    return _data_dir(cfg).joinpath("reservations.lock")
def _local_completions_path(category: str, cfg: dict) -> Path:
    """
    Local file tracking completed item counts for this container session.

    Updated immediately on each completion so subsequent assignments
    see accurate counts without waiting for an HF re-scan.
    Reset on container restart — HF is the durable source of truth.
    """
    filename = f"local_completions_{cfg['study_type']}_{category}.json"
    return _data_dir(cfg) / filename
# ── Dataset download + normalisation ──────────────────────────────────────────
def _download_and_cache(
    study_type: str,
    category: str,
    seed: int,
    hf_token: str,
    data_dir: str,
) -> None:
    """
    Download one category's dataset from HF, select POOL_SIZE items with a
    seeded shuffle, normalise the rows, and cache them as JSON under data_dir.

    No-op if the pool file already exists (cache hit). The seed makes the
    selection reproducible across restarts, so _pool_index stays stable.

    Args:
        study_type: "preference" or "likelihood" — controls split choice
            and row normalisation below.
        category: category name; part of the cache filename and stamped
            onto each item.
        seed: RNG seed for the reproducible shuffle/selection.
        hf_token: HF access token; empty string means anonymous access.
        data_dir: directory where the pool JSON is written.
    """
    pool_path = Path(data_dir) / f"pool_{study_type}_{category}.json"
    if pool_path.exists():
        print(f"[DATA] Pool already cached: {pool_path}")
        return
    # Imported lazily: `datasets` is heavy and only needed on a cache miss.
    from datasets import load_dataset
    repo_id = CATEGORY_TO_REPO[(study_type, category)]
    token_arg = hf_token or None
    print(f"[DATA] Downloading {repo_id} β¦")
    # NOTE(review): trust_remote_code=True executes repo-provided loading
    # code — acceptable only because the repos come from our own config.
    ds = load_dataset(repo_id, token=token_arg, trust_remote_code=True)
    if study_type == "preference":
        # Prefer a dedicated "test" split; otherwise filter train rows
        # whose "split" column marks them as test.
        if "test" in ds:
            rows = [dict(r) for r in ds["test"]]
        else:
            rows = [dict(r) for r in ds["train"] if r.get("split") == "test"]
    else:
        # Likelihood datasets: take "test" if present, else the first split.
        split_key = "test" if "test" in ds else list(ds.keys())[0]
        rows = [dict(r) for r in ds[split_key]]
    # Seeded local RNG: reproducible selection without touching global state.
    rng = random.Random(seed)
    rng.shuffle(rows)
    selected = rows[:POOL_SIZE]
    if study_type == "likelihood":
        # Flatten each row to its metadata dict (may arrive JSON-encoded),
        # stamping a deterministic item_id derived from repo/index/seed.
        normalised = []
        for i, row in enumerate(selected):
            meta = row["metadata"]
            if isinstance(meta, str):
                meta = json.loads(meta)
            else:
                meta = dict(meta)
            meta["item_id"] = str(uuid.uuid5(uuid.NAMESPACE_DNS, f"{repo_id}_{i}_{seed}"))
            meta["category"] = category
            normalised.append(meta)
        selected = normalised
    else:
        # Preference rows: deep-copy the product dicts and make sure both
        # carry a category (row-level value wins over the function argument).
        cleaned = []
        for row in selected:
            r = dict(row)
            r["product_a"] = dict(r["product_a"])
            r["product_b"] = dict(r["product_b"])
            r["product_a"].setdefault("category", r.get("category", category))
            r["product_b"].setdefault("category", r.get("category", category))
            cleaned.append(r)
        selected = cleaned
    pool_path.parent.mkdir(parents=True, exist_ok=True)
    with open(pool_path, "w") as f:
        json.dump(selected, f, indent=2)
    print(f"[DATA] {study_type}/{category}: cached {len(selected)} items (seed={seed}).")
def ensure_datasets(cfg: dict) -> None:
    """Make sure every configured category's item pool is downloaded and cached."""
    study_type = cfg["study_type"]
    seed = cfg["pair_selection_seed"]
    token = cfg.get("hf_token", "")
    data_dir = cfg["data_dir"]
    for cat_cfg in cfg["categories"]:
        _download_and_cache(study_type=study_type, category=cat_cfg["name"],
                            seed=seed, hf_token=token, data_dir=data_dir)
| def _load_pool(pool_path_str: str) -> list: | |
| with open(pool_path_str) as f: | |
| return json.load(f) | |
# ── Accepted completion counts ────────────────────────────────────────────────
def _get_accepted_counts(category: str, cfg: dict) -> dict:
    """
    Return how many times each pool item has been accepted.

    Keys are stringified pool indices; values are accepted-completion counts.

    Sources (merged, highest count wins):
      1. Local completions file — written immediately on each completion this session.
      2. HF output repo scan — authoritative after a container restart.

    HF scan results are cached for COMPLETION_CACHE_TTL seconds.
    Rejected submissions live under rejected/ and are NOT counted.
    """
    pool = _load_pool(str(_pool_path(category, cfg)))
    counts = {str(i): 0 for i in range(len(pool))}
    # ── Source 1: local completions (most up-to-date within this session) ────
    local_path = _local_completions_path(category, cfg)
    if local_path.exists():
        try:
            with open(local_path) as f:
                local = json.load(f)
            # max() merge: never let a stale zero overwrite a known count.
            for k, v in local.items():
                counts[k] = max(counts.get(k, 0), v)
            print(f"[ASSIGN] Local completions for {category}: "
                  f"{sum(1 for v in local.values() if v > 0)} items completed")
        except Exception as e:
            # Best-effort: a corrupt local file must not block assignment.
            print(f"[ASSIGN] Could not read local completions: {e}")
    # ── Source 2: HF scan (authoritative after restart, with 5-min cache) ───
    cache_path = _data_dir(cfg) / f"completion_cache_{cfg['study_type']}_{category}.json"
    now = time.time()
    hf_counts = None
    if cache_path.exists():
        try:
            with open(cache_path) as f:
                cache = json.load(f)
            if now - cache.get("timestamp", 0) < COMPLETION_CACHE_TTL:
                hf_counts = cache["counts"]
        except Exception:
            pass  # unreadable cache → fall through to a fresh scan
    if hf_counts is None:
        hf_counts = {str(i): 0 for i in range(len(pool))}
        hf_token = cfg.get("hf_token", "")
        output_repo = cfg.get("output_dataset_repo", "")
        if hf_token and output_repo:
            try:
                from huggingface_hub import HfApi
                api = HfApi(token=hf_token)
                files = list(api.list_repo_files(repo_id=output_repo, repo_type="dataset"))
                # Accepted = under json/; rejected/ files are deliberately excluded.
                json_files = [f for f in files if f.startswith("json/") and f.endswith(".json")]
                # Build pair_id → pool_index lookup for fallback matching of
                # legacy submissions that predate _pool_index stamping.
                id_to_index = {}
                for i, p in enumerate(pool):
                    pid = p.get("pair_id") or p.get("item_id", "")
                    if pid:
                        id_to_index[pid] = i
                for filepath in json_files:
                    try:
                        content = api.hf_hub_download(
                            repo_id=output_repo,
                            filename=filepath,
                            repo_type="dataset",
                            token=hf_token,
                        )
                        with open(content) as f:
                            submission = json.load(f)
                        for item in submission.get("items", []):
                            if item.get("category") != category:
                                continue
                            idx = item.get("_pool_index")
                            if idx is None:
                                pid = item.get("pair_id") or item.get("item_id", "")
                                idx = id_to_index.get(pid)
                            if idx is not None:
                                hf_counts[str(idx)] = hf_counts.get(str(idx), 0) + 1
                    except Exception as e:
                        # One bad submission file must not abort the whole scan.
                        print(f"[ASSIGN] Could not parse {filepath}: {e}")
            except Exception as e:
                print(f"[ASSIGN] Could not scan HF repo: {e}")
        # Persist the fresh scan; written only on a cache miss so the
        # timestamp genuinely reflects when the repo was last scanned.
        try:
            with open(cache_path, "w") as f:
                json.dump({"timestamp": now, "counts": hf_counts}, f)
        except Exception:
            pass
    for k, v in hf_counts.items():
        counts[k] = max(counts.get(k, 0), v)
    return counts
# ── Reservation management ────────────────────────────────────────────────────
def _load_reservations(cfg: dict) -> dict:
    """Read the reservations file; a missing or corrupt file yields {}."""
    try:
        with open(_reservation_path(cfg)) as fh:
            return json.load(fh)
    except Exception:
        return {}
def _save_reservations(reservations: dict, cfg: dict) -> None:
    """Overwrite the reservations file (caller is expected to hold the FileLock)."""
    _reservation_path(cfg).write_text(json.dumps(reservations))
| def _expire_reservations(reservations: dict) -> dict: | |
| now = time.time() | |
| expired = [k for k, v in reservations.items() if v["expiry"] < now] | |
| for k in expired: | |
| print(f"[ASSIGN] Reservation expired for item index {k}") | |
| del reservations[k] | |
| return reservations | |
def release_reservation(user_id: str, cfg: dict) -> None:
    """Release every reservation held by *user_id* immediately after completion."""
    lock = FileLock(str(_reservation_lock_path(cfg)), timeout=10)
    with lock:
        reservations = _load_reservations(cfg)
        _expire_reservations(reservations)
        mine = [key for key, entry in reservations.items() if entry["user_id"] == user_id]
        for key in mine:
            del reservations[key]
        _save_reservations(reservations, cfg)
    print(f"[ASSIGN] Released {len(mine)} reservations for user {user_id}")
def record_completion(user_id: str, items: list, cfg: dict) -> None:
    """
    Record completed item indices to the local completions file immediately.

    Uses _pool_index stamped on each item at assignment time — no fuzzy matching.
    Also invalidates the per-category HF completion cache so the next
    assignment re-scans fresh counts.
    Called after successful HF upload AND by the simulation script.
    """
    # Group completed pool indices by their category, skipping unstamped items.
    by_category: dict = {}
    for item in items:
        cat = item.get("_pool_category") or item.get("category", "")
        idx = item.get("_pool_index")
        if idx is None:
            print(f"[ASSIGN] WARNING: item missing _pool_index, skipping: "
                  f"{item.get('pair_id') or item.get('item_id', '?')}")
            continue
        by_category.setdefault(cat, []).append(idx)
    for cat, indices in by_category.items():
        pool = _load_pool(str(_pool_path(cat, cfg)))
        completions_path = _local_completions_path(cat, cfg)
        # Start from the existing counts; a missing/corrupt file resets to zeros
        # (HF remains the durable source of truth, so this is safe).
        if completions_path.exists():
            try:
                with open(completions_path) as f:
                    completions = json.load(f)
            except Exception:
                completions = {str(i): 0 for i in range(len(pool))}
        else:
            completions = {str(i): 0 for i in range(len(pool))}
        for idx in indices:
            completions[str(idx)] = completions.get(str(idx), 0) + 1
        with open(completions_path, "w") as f:
            json.dump(completions, f)
        # Invalidate HF cache so next scan re-reads fresh
        cache_path = _data_dir(cfg) / f"completion_cache_{cfg['study_type']}_{cat}.json"
        if cache_path.exists():
            try:
                cache_path.unlink()
            except Exception:
                pass  # best-effort: a stale cache only delays freshness by its TTL
        print(f"[ASSIGN] Recorded completions for {cat}: indices {indices} "
              f"(user {user_id[:8]})")
# ── Prolific status polling ───────────────────────────────────────────────────
def _prolific_returned_pids(cfg: dict) -> set:
    """
    Query Prolific for participants who have RETURNED or TIMED-OUT from the
    active study. Returns a set of their PIDs. Cached for PROLIFIC_POLL_CACHE_TTL.

    Returns an empty set when no API token / study id is configured or when
    the API is unreachable (best-effort: reservations then just wait out
    their normal TTL).
    """
    token = cfg.get("prolific_api_token", "")
    study_id = cfg.get("prolific_study_id", "")
    if not token or not study_id:
        return set()
    cache_path = _data_dir(cfg) / "prolific_returned_cache.json"
    now = time.time()
    if cache_path.exists():
        try:
            with open(cache_path) as f:
                c = json.load(f)
            if now - c.get("timestamp", 0) < PROLIFIC_POLL_CACHE_TTL:
                return set(c.get("returned_pids", []))
        except Exception:
            pass  # unreadable cache → fall through to a fresh poll
    returned = set()
    try:
        # Lazy import: requests only needed when actually polling.
        import requests
        url = f"https://api.prolific.com/api/v1/studies/{study_id}/submissions/"
        headers = {"Authorization": f"Token {token}"}
        resp = requests.get(url, headers=headers, timeout=10)
        resp.raise_for_status()
        for sub in resp.json().get("results", []):
            status = sub.get("status", "")
            # Both spellings accepted; API payloads have used each variant.
            if status in ("RETURNED", "TIMED-OUT", "TIMED_OUT"):
                pid = sub.get("participant_id") or sub.get("participant", "")
                if pid:
                    returned.add(pid)
        print(f"[PROLIFIC] Found {len(returned)} returned/timed-out participants")
    except Exception as e:
        print(f"[PROLIFIC] Could not query API: {e}")
    # Cache even an empty/failed result so we don't hammer the API.
    try:
        with open(cache_path, "w") as f:
            json.dump({"timestamp": now, "returned_pids": list(returned)}, f)
    except Exception:
        pass
    return returned
def _release_returned_reservations(reservations: dict, cfg: dict) -> None:
    """
    Remove reservations held by Prolific participants who have RETURNED or
    TIMED-OUT. Mutates the reservations dict in place.
    """
    gone_pids = _prolific_returned_pids(cfg)
    if not gone_pids:
        return
    released = []
    for key, entry in list(reservations.items()):
        holder = entry.get("prolific_pid", "")
        if holder and holder in gone_pids:
            released.append(key)
            del reservations[key]
    if released:
        print(f"[ASSIGN] Released {len(released)} reservations from returned/timed-out participants: {released}")
def all_items_covered(cfg: dict) -> bool:
    """
    Return True if every item in every category has been accepted at least once.

    Used for auto-pausing the Prolific study.
    """
    for cat_cfg in cfg["categories"]:
        name = cat_cfg["name"]
        pool = _load_pool(str(_pool_path(name, cfg)))
        counts = _get_accepted_counts(name, cfg)
        if any(counts.get(str(i), 0) < 1 for i in range(len(pool))):
            return False
    return True
def pause_prolific_study(cfg: dict) -> bool:
    """
    Call Prolific's API to pause the study. Returns True on success.

    Requires prolific_api_token (env PROLIFIC_API_TOKEN) and prolific_study_id.
    Idempotent — safe to call multiple times (Prolific treats repeated pauses
    as no-ops, and a local marker file short-circuits repeat calls anyway).
    """
    token = cfg.get("prolific_api_token", "")
    study_id = cfg.get("prolific_study_id", "")
    if not token or not study_id:
        print("[PROLIFIC] Cannot auto-pause: no API token or study_id configured")
        return False
    # Idempotency marker so we don't spam the API on every completion after
    # the first time all items are covered.
    paused_marker = _data_dir(cfg) / ".prolific_paused"
    if paused_marker.exists():
        return True
    try:
        import requests
        url = f"https://api.prolific.com/api/v1/studies/{study_id}/transition/"
        headers = {"Authorization": f"Token {token}", "Content-Type": "application/json"}
        resp = requests.post(url, headers=headers, json={"action": "PAUSE"}, timeout=10)
        resp.raise_for_status()
        # Only mark paused after the API confirms — a failed call retries later.
        paused_marker.touch()
        print(f"[PROLIFIC] β Study {study_id} paused automatically β all items covered.")
        return True
    except Exception as e:
        print(f"[PROLIFIC] Could not auto-pause study: {e}")
        return False
# ── Core assignment ───────────────────────────────────────────────────────────
def _assign_from_category(category: str, n: int, user_id: str, cfg: dict) -> list:
    """
    Assign n items using least-coverage-first strategy.

    Priority order (via sort key):
      1. Uncovered + unreserved (count=0, not reserved)
      2. Uncovered + reserved by other (count=0, reserved)
      3. Covered + unreserved (count>0, not reserved)
      4. Covered + reserved by other (count>0, reserved)

    Reservations are ONLY created for participants who come via Prolific
    (i.e. have a non-empty prolific_pid in the URL). Non-Prolific visitors
    (testers, previewers, direct-URL visitors) still get items assigned so
    they can run through the study, but they don't hold reservations.
    Reservations from participants who have RETURNED/TIMED-OUT on Prolific
    are released BEFORE the sort, so their items are treated as unreserved.

    Returns a list of item dicts, each stamped with _pool_index and
    _pool_category for exact-match completion recording later.
    """
    pool = _load_pool(str(_pool_path(category, cfg)))
    accepted_counts = _get_accepted_counts(category, cfg)
    lock = FileLock(str(_reservation_lock_path(cfg)), timeout=10)
    # Capture prolific_pid early so we can decide whether to reserve.
    # Read from query_params directly — session_state.study_state doesn't
    # exist yet during init_state, which is what calls this function.
    prolific_pid = ""
    try:
        params = st.query_params
        prolific_pid = params.get("PROLIFIC_PID", "") or ""
    except Exception:
        pass  # outside a Streamlit request context (e.g. simulation script)
    is_prolific = bool(prolific_pid)
    with lock:
        reservations = _load_reservations(cfg)
        _expire_reservations(reservations)
        _release_returned_reservations(reservations, cfg)
        # If this Prolific PID already has reservations (e.g. they refreshed
        # the tab, got a new user_id, and came back), release the old ones
        # before creating new ones. Prevents the same participant from
        # accumulating multiple reservations.
        if is_prolific:
            stale = [
                idx for idx, r in list(reservations.items())
                if r.get("prolific_pid") == prolific_pid
            ]
            for idx in stale:
                del reservations[idx]
            if stale:
                print(f"[ASSIGN] Released {len(stale)} prior reservations "
                      f"for returning PID {prolific_pid}")
        def is_reserved_by_other(i):
            # True when someone ELSE holds a live reservation on index i.
            r = reservations.get(str(i))
            return r is not None and r["user_id"] != user_id
        def sort_key(i):
            # (count, reserved) tuples implement the 4-tier priority above.
            count = accepted_counts.get(str(i), 0)
            reserved = int(is_reserved_by_other(i))
            return (count, reserved)
        all_indices = sorted(range(len(pool)), key=sort_key)
        selected_indices = all_indices[:n]
        # Only reserve if this is a Prolific participant — keeps the
        # admin "in progress" count accurate and stops testers/bouncers
        # from blocking items for real users.
        if is_prolific:
            expiry = time.time() + RESERVATION_TTL
            for i in selected_indices:
                reservations[str(i)] = {
                    "user_id": user_id,
                    "prolific_pid": prolific_pid,
                    "expiry": expiry,
                }
            _save_reservations(reservations, cfg)
            print(f"[ASSIGN] Reserved for Prolific PID {prolific_pid}")
        else:
            print(f"[ASSIGN] Non-Prolific visitor β no reservation created")
    # Stamp each returned item with its pool location (copy — don't mutate pool).
    selected = []
    for i in selected_indices:
        item = dict(pool[i])
        item["_pool_index"] = i
        item["_pool_category"] = category
        selected.append(item)
    print(f"[ASSIGN] {category}: assigned indices {selected_indices} "
          f"(counts: {[accepted_counts.get(str(i), 0) for i in selected_indices]})")
    return selected
# ── Variant assignment ────────────────────────────────────────────────────────
| def _assign_variants(cfg: dict, n: int) -> list: | |
| variants = cfg.get("model_variants") | |
| if not variants: | |
| return [{"name": "default", | |
| "model_name": cfg["model_name"], | |
| "prompt_variant": cfg["prompt_variant"]}] * n | |
| if len(variants) == 1: | |
| return [variants[0]] * n | |
| lock = FileLock(str(_data_dir(cfg) / "variant_counter.lock"), timeout=10) | |
| with lock: | |
| counter_path = _data_dir(cfg) / "variant_counter.txt" | |
| ctr = int(counter_path.read_text().strip()) if counter_path.exists() else 0 | |
| counter_path.write_text(str(ctr + 1)) | |
| v0, v1 = variants[0], variants[1] | |
| if ctr % 2 == 1: | |
| v0, v1 = v1, v0 | |
| from itertools import zip_longest | |
| interleaved = [] | |
| for a, b in zip_longest([v0] * v0["count"], [v1] * v1["count"]): | |
| if a: interleaved.append(a) | |
| if b: interleaved.append(b) | |
| print(f"[VARIANTS] user {ctr}: {[v['name'] for v in interleaved]}") | |
| return interleaved | |
# ── Category count computation ────────────────────────────────────────────────
| def _compute_counts(cfg: dict) -> dict: | |
| cats = cfg["categories"] | |
| n = cfg["pairs_per_user"] | |
| if len(cats) == 1: | |
| return {cats[0]["name"]: n} | |
| lock = FileLock(str(_data_dir(cfg) / "alternation_counter.lock"), timeout=10) | |
| with lock: | |
| path = _data_dir(cfg) / "alternation_counter.txt" | |
| ctr = int(path.read_text().strip()) if path.exists() else 0 | |
| path.write_text(str(ctr + 1)) | |
| base = {c["name"]: c["count"] for c in cats} | |
| if sum(base.values()) != n: | |
| base = {} | |
| for i, c in enumerate(cats): | |
| base[c["name"]] = n // len(cats) + (1 if i < n % len(cats) else 0) | |
| return base | |
| if ctr % 2 == 1: | |
| names = [c["name"] for c in cats] | |
| base[names[0]], base[names[1]] = base[names[1]], base[names[0]] | |
| return base | |
def assign_items(cfg: dict, user_id: str) -> list:
    """Assign this participant's items across all categories, shuffled together."""
    per_category = _compute_counts(cfg)
    assigned = []
    for name, count in per_category.items():
        assigned.extend(_assign_from_category(name, count, user_id, cfg))
    random.shuffle(assigned)
    return assigned
# ── Item slot construction ────────────────────────────────────────────────────
| def _make_item_slot(item: dict, study_type: str) -> dict: | |
| base = { | |
| "_pool_index": item.get("_pool_index"), | |
| "_pool_category": item.get("_pool_category", item.get("category", "")), | |
| "conversation": { | |
| "system_prompt": "", | |
| "closing_message": "", | |
| "turns": [], | |
| "num_turns": 0, | |
| }, | |
| "reflection": {}, | |
| "pre_rating": None, | |
| "post_rating": None, | |
| "rating_delta": None, | |
| } | |
| if study_type == "preference": | |
| base.update({ | |
| "pair_id": item.get("pair_id", str(uuid.uuid4())), | |
| "category": item.get("category", ""), | |
| "product_a": item.get("product_a", {}), | |
| "product_b": item.get("product_b", {}), | |
| "familiarity_a": None, | |
| "familiarity_b": None, | |
| }) | |
| else: | |
| base.update({ | |
| "item_id": item.get("item_id", str(uuid.uuid4())), | |
| "category": item.get("category", ""), | |
| "product": item, | |
| "familiarity": None, | |
| }) | |
| return base | |
# ── Session-state construction ────────────────────────────────────────────────
def init_state(cfg: dict) -> dict:
    """
    Build the initial session-state dict for a new participant.

    Generates a fresh user_id, assigns items (least-coverage-first) and
    model/prompt variants, builds per-item slots, and captures the Prolific
    query parameters (PROLIFIC_PID / STUDY_ID / SESSION_ID) if present.
    """
    n = cfg["pairs_per_user"]
    user_id = str(uuid.uuid4())
    variants = _assign_variants(cfg, n)
    # Defensive [:n]: assignment should already return exactly n items.
    items = assign_items(cfg, user_id)[:n]
    slots = [_make_item_slot(it, cfg["study_type"]) for it in items]
    # Pair each slot with its assigned variant (zip stops at the shorter list).
    for slot, variant in zip(slots, variants):
        slot["model_name"] = variant["model_name"]
        slot["prompt_variant"] = variant["prompt_variant"]
        slot["sampler_path"] = variant.get("sampler_path", "")
    for i, slot in enumerate(slots):
        print(f"[ITEM {i}] category={slot.get('category')} "
              f"pool_index={slot.get('_pool_index')} "
              f"model={slot.get('model_name')} "
              f"personalization={slot.get('prompt_variant', {}).get('personalization')}")
    # Outside a Streamlit request context (tests, simulation) there are no
    # query params — fall back to an empty mapping.
    try:
        params = st.query_params
    except Exception:
        params = {}
    return {
        "submission_id": str(uuid.uuid4()),
        "user_id": user_id,
        "prolific_pid": params.get("PROLIFIC_PID", ""),
        "study_id": params.get("STUDY_ID", ""),
        "session_id": params.get("SESSION_ID", ""),
        "start_time": time.time(),
        "study_type": cfg["study_type"],
        "demographics": {},
        "background": {},
        "items": slots,
        "current_index": 0,
        "screen": "welcome",
        "meta": {},
    }