"""
Dataset download, item-pool caching, completion-aware assignment, and
session-state init.

Assignment strategy
-------------------
Items are assigned based on how many *accepted* completions they already
have, so the least-covered items are always prioritised. Each assigned item
is stamped with _pool_index and _pool_category at assignment time, so
record_completion never needs a fuzzy pair_id match — it reads the index
directly.

Accepted completions = JSON files under json/ in the output repo.
Rejected completions = JSON files moved to rejected/ by the admin.
  → Moving a file to rejected/ makes that item available again (see
    "Dropout / rejection recovery" below for when this takes effect).

Reservations
------------
When a Prolific participant (a visitor with a PROLIFIC_PID) starts, their
items are "reserved" in a local file for 80 min. Visitors without a
PROLIFIC_PID (testers, previewers) are assigned items but do not reserve
them. All assignments serialise on a single FileLock around the reservation
file, so concurrent participants never receive the same reserved items.
Reservations are keyed by (category, pool index) and expire automatically,
so abandoned sessions don't permanently block items.

Each reservation stores the participant's prolific_pid so we can release
their items as soon as Prolific reports them as RETURNED or TIMED-OUT — no
need to wait for the 80-min TTL.

Dropout / rejection recovery
----------------------------
- Dropout (voluntary return): Prolific marks RETURNED; we query the API and
  release the reservation on the next assignment.
- Dropout (silent): the reservation expires after 80 min → the item
  re-enters the pool.
- Rejection: the admin moves json/{worker}/{id}.json →
  rejected/{worker}/{id}.json in the HF dataset repo. Once the HF completion
  cache expires the item is no longer counted from HF. If the completion was
  also recorded locally in this container, the local count keeps it covered
  until the next Space restart. After that the item's accepted count drops
  back to 0 and it gets re-assigned.
"""
import json
import random
import time
import uuid
from pathlib import Path

import streamlit as st
from filelock import FileLock

from src.config import CATEGORY_TO_REPO

POOL_SIZE = 50  # items selected per (study_type, category)
RESERVATION_TTL = 60 * 80  # 80 min: 30 min expected + ~2.5x buffer
COMPLETION_CACHE_TTL = 300  # re-scan HF repo every 5 minutes
PROLIFIC_POLL_CACHE_TTL = 120  # re-poll Prolific every 2 minutes


# ── Path helpers ──────────────────────────────────────────────────────────────

def _data_dir(cfg: dict) -> Path:
    """Return cfg["data_dir"] as a Path, creating the directory if needed."""
    p = Path(cfg["data_dir"])
    p.mkdir(parents=True, exist_ok=True)
    return p


def _pool_path(category: str, cfg: dict) -> Path:
    """Path of the cached item pool for (study_type, category)."""
    return _data_dir(cfg) / f"pool_{cfg['study_type']}_{category}.json"


def _reservation_path(cfg: dict) -> Path:
    """Path of the single reservations file shared by all categories."""
    return _data_dir(cfg) / "reservations.json"


def _reservation_lock_path(cfg: dict) -> Path:
    """Lock file serialising every read-modify-write of the reservations file."""
    return _data_dir(cfg) / "reservations.lock"


def _local_completions_path(category: str, cfg: dict) -> Path:
    """
    Local file tracking completed item counts this container session.

    Updated immediately on each completion so subsequent assignments see
    accurate counts without waiting for an HF re-scan. Reset on container
    restart — HF is the durable source of truth.
    """
    return _data_dir(cfg) / f"local_completions_{cfg['study_type']}_{category}.json"


def _local_completions_lock_path(category: str, cfg: dict) -> Path:
    """Lock file guarding read-modify-write of the local completions file."""
    return _data_dir(cfg) / f"local_completions_{cfg['study_type']}_{category}.lock"


def _reservation_key(category: str, index: int) -> str:
    """
    Reservation-file key for one pool item.

    Category-qualified because every category's pool is indexed from 0, and
    all categories share one reservations file.
    """
    return f"{category}:{index}"


# ── Dataset download + normalisation ─────────────────────────────────────────

@st.cache_resource
def _download_and_cache(
    study_type: str,
    category: str,
    seed: int,
    hf_token: str,
    data_dir: str,
) -> None:
    """
    Download the source dataset for (study_type, category) and cache a pool.

    The pool is a seeded random sample of POOL_SIZE rows, normalised and
    written to pool_{study_type}_{category}.json under data_dir. The call is a
    no-op if that file already exists.

    - "likelihood" rows are reduced to their (possibly JSON-encoded)
      "metadata" dict, given a deterministic uuid5 item_id and stamped with
      the category.
    - Other (preference) rows keep their shape. product_a/product_b get a
      default "category".
    """
    pool_path = Path(data_dir) / f"pool_{study_type}_{category}.json"
    if pool_path.exists():
        print(f"[DATA] Pool already cached: {pool_path}")
        return

    from datasets import load_dataset

    repo_id = CATEGORY_TO_REPO[(study_type, category)]
    token_arg = hf_token or None
    print(f"[DATA] Downloading {repo_id} …")
    # NOTE(review): trust_remote_code=True lets the dataset repo run its own
    # loading code. This is only acceptable because repo IDs come from our own
    # CATEGORY_TO_REPO config, never from user input.
    ds = load_dataset(repo_id, token=token_arg, trust_remote_code=True)

    if study_type == "preference":
        if "test" in ds:
            rows = [dict(r) for r in ds["test"]]
        else:
            # Some preference datasets ship a single train split with a
            # per-row "split" column instead of a dedicated test split.
            rows = [dict(r) for r in ds["train"] if r.get("split") == "test"]
    else:
        split_key = "test" if "test" in ds else list(ds.keys())[0]
        rows = [dict(r) for r in ds[split_key]]

    # Seeded shuffle so the same config always yields the same pool.
    rng = random.Random(seed)
    rng.shuffle(rows)
    selected = rows[:POOL_SIZE]

    if study_type == "likelihood":
        normalised = []
        for i, row in enumerate(selected):
            meta = row["metadata"]
            if isinstance(meta, str):
                meta = json.loads(meta)
            else:
                meta = dict(meta)
            meta["item_id"] = str(uuid.uuid5(uuid.NAMESPACE_DNS, f"{repo_id}_{i}_{seed}"))
            meta["category"] = category
            normalised.append(meta)
        selected = normalised
    else:
        cleaned = []
        for row in selected:
            r = dict(row)
            r["product_a"] = dict(r["product_a"])
            r["product_b"] = dict(r["product_b"])
            r["product_a"].setdefault("category", r.get("category", category))
            r["product_b"].setdefault("category", r.get("category", category))
            cleaned.append(r)
        selected = cleaned

    pool_path.parent.mkdir(parents=True, exist_ok=True)
    with open(pool_path, "w") as f:
        json.dump(selected, f, indent=2)
    print(f"[DATA] {study_type}/{category}: cached {len(selected)} items (seed={seed}).")


def ensure_datasets(cfg: dict) -> None:
    """Make sure a cached pool exists for every configured category."""
    for cat_cfg in cfg["categories"]:
        _download_and_cache(
            study_type=cfg["study_type"],
            category=cat_cfg["name"],
            seed=cfg["pair_selection_seed"],
            hf_token=cfg.get("hf_token", ""),
            data_dir=cfg["data_dir"],
        )


@st.cache_data
def _load_pool(pool_path_str: str) -> list:
    """Load a cached pool (list of item dicts); memoised by Streamlit per path."""
    with open(pool_path_str) as f:
        return json.load(f)


# ── Accepted completion counts ────────────────────────────────────────────────

def _get_accepted_counts(category: str, cfg: dict) -> dict:
    """
    Return how many times each pool item has been accepted.

    Keys are str(pool_index); values are counts. Two sources are merged, and
    the highest count per item wins:

    1. Local completions file — written immediately on each completion this
       session.
    2. HF output repo scan — authoritative after a container restart. Only
       this scan is cached, for COMPLETION_CACHE_TTL seconds.

    Rejected submissions live under rejected/ and are NOT counted.
    """
    pool = _load_pool(str(_pool_path(category, cfg)))
    counts = {str(i): 0 for i in range(len(pool))}

    # ── Source 1: local completions (most up-to-date within this session) ────
    local_path = _local_completions_path(category, cfg)
    if local_path.exists():
        try:
            # Same lock as record_completion, so we never read a half-written file.
            with FileLock(str(_local_completions_lock_path(category, cfg)), timeout=10):
                with open(local_path) as f:
                    local = json.load(f)
            for k, v in local.items():
                counts[k] = max(counts.get(k, 0), v)
            print(f"[ASSIGN] Local completions for {category}: "
                  f"{sum(1 for v in local.values() if v > 0)} items completed")
        except Exception as e:
            print(f"[ASSIGN] Could not read local completions: {e}")

    # ── Source 2: HF scan (authoritative after restart, with 5-min cache) ───
    cache_path = _data_dir(cfg) / f"completion_cache_{cfg['study_type']}_{category}.json"
    now = time.time()
    hf_counts = None
    if cache_path.exists():
        try:
            with open(cache_path) as f:
                cache = json.load(f)
            if now - cache.get("timestamp", 0) < COMPLETION_CACHE_TTL:
                hf_counts = cache["counts"]
        except Exception:
            pass  # unreadable cache → fall through to a fresh scan

    if hf_counts is None:
        hf_counts = {str(i): 0 for i in range(len(pool))}
        hf_token = cfg.get("hf_token", "")
        output_repo = cfg.get("output_dataset_repo", "")
        if hf_token and output_repo:
            try:
                from huggingface_hub import HfApi

                api = HfApi(token=hf_token)
                files = list(api.list_repo_files(repo_id=output_repo, repo_type="dataset"))
                json_files = [f for f in files if f.startswith("json/") and f.endswith(".json")]

                # Build pair_id/item_id → pool_index lookup for submissions
                # that predate _pool_index stamping.
                id_to_index = {}
                for i, p in enumerate(pool):
                    pid = p.get("pair_id") or p.get("item_id", "")
                    if pid:
                        id_to_index[pid] = i

                for filepath in json_files:
                    try:
                        content = api.hf_hub_download(
                            repo_id=output_repo,
                            filename=filepath,
                            repo_type="dataset",
                            token=hf_token,
                        )
                        with open(content) as f:
                            submission = json.load(f)
                        for item in submission.get("items", []):
                            # Match record_completion: _pool_category is the
                            # pool _pool_index refers to. Fall back to
                            # "category" for older submissions without it.
                            item_category = item.get("_pool_category") or item.get("category")
                            if item_category != category:
                                continue
                            idx = item.get("_pool_index")
                            if idx is None:
                                pid = item.get("pair_id") or item.get("item_id", "")
                                idx = id_to_index.get(pid)
                            if idx is not None:
                                hf_counts[str(idx)] = hf_counts.get(str(idx), 0) + 1
                    except Exception as e:
                        print(f"[ASSIGN] Could not parse {filepath}: {e}")
            except Exception as e:
                print(f"[ASSIGN] Could not scan HF repo: {e}")
        # NOTE: a failed scan still caches its (zero) counts for the TTL. The
        # local completions above remain in effect meanwhile.
        try:
            with open(cache_path, "w") as f:
                json.dump({"timestamp": now, "counts": hf_counts}, f)
        except Exception:
            pass

    for k, v in hf_counts.items():
        counts[k] = max(counts.get(k, 0), v)
    return counts


# ── Reservation management ────────────────────────────────────────────────────

def _load_reservations(cfg: dict) -> dict:
    """Read the reservations file; a missing or unreadable file yields {}."""
    path = _reservation_path(cfg)
    if not path.exists():
        return {}
    try:
        with open(path) as f:
            return json.load(f)
    except Exception:
        return {}


def _save_reservations(reservations: dict, cfg: dict) -> None:
    """Overwrite the reservations file. Callers must hold the reservation lock."""
    with open(_reservation_path(cfg), "w") as f:
        json.dump(reservations, f)


def _expire_reservations(reservations: dict) -> dict:
    """Drop reservations whose expiry has passed (mutates and returns the dict)."""
    now = time.time()
    expired = [k for k, v in reservations.items() if v["expiry"] < now]
    for k in expired:
        print(f"[ASSIGN] Reservation expired for item index {k}")
        del reservations[k]
    return reservations


def release_reservation(user_id: str, cfg: dict) -> None:
    """Release all reservations held by this user immediately after completion."""
    lock = FileLock(str(_reservation_lock_path(cfg)), timeout=10)
    with lock:
        reservations = _load_reservations(cfg)
        _expire_reservations(reservations)
        released = [k for k, v in reservations.items() if v["user_id"] == user_id]
        for k in released:
            del reservations[k]
        _save_reservations(reservations, cfg)
    print(f"[ASSIGN] Released {len(released)} reservations for user {user_id}")


def record_completion(user_id: str, items: list, cfg: dict) -> None:
    """
    Record completed item indices to the local completions file immediately.

    Uses the _pool_index stamped on each item at assignment time — no fuzzy
    matching. Items without a _pool_index are skipped with a warning. The
    per-category read-modify-write runs under a FileLock so concurrent
    completions don't lose increments. Afterwards the category's HF
    completion cache is deleted, so the next count forces a fresh scan.

    Called after a successful HF upload AND by the simulation script.
    """
    by_category: dict = {}
    for item in items:
        cat = item.get("_pool_category") or item.get("category", "")
        idx = item.get("_pool_index")
        if idx is None:
            print(f"[ASSIGN] WARNING: item missing _pool_index, skipping: "
                  f"{item.get('pair_id') or item.get('item_id', '?')}")
            continue
        by_category.setdefault(cat, []).append(idx)

    for cat, indices in by_category.items():
        pool = _load_pool(str(_pool_path(cat, cfg)))
        completions_path = _local_completions_path(cat, cfg)
        with FileLock(str(_local_completions_lock_path(cat, cfg)), timeout=10):
            completions = None
            if completions_path.exists():
                try:
                    with open(completions_path) as f:
                        completions = json.load(f)
                except Exception:
                    completions = None  # corrupt file → start from zeros
            if completions is None:
                completions = {str(i): 0 for i in range(len(pool))}
            for idx in indices:
                completions[str(idx)] = completions.get(str(idx), 0) + 1
            with open(completions_path, "w") as f:
                json.dump(completions, f)

        # Invalidate HF cache so next scan re-reads fresh
        cache_path = _data_dir(cfg) / f"completion_cache_{cfg['study_type']}_{cat}.json"
        if cache_path.exists():
            try:
                cache_path.unlink()
            except Exception:
                pass
        print(f"[ASSIGN] Recorded completions for {cat}: indices {indices} "
              f"(user {user_id[:8]})")


# ── Prolific status polling ───────────────────────────────────────────────────

def _prolific_returned_pids(cfg: dict) -> set:
    """
    Query Prolific for participants who RETURNED or TIMED-OUT from the study.

    Returns a set of their PIDs, or an empty set when no API token or study
    id is configured. Results (including an empty result after an API error)
    are cached on disk for PROLIFIC_POLL_CACHE_TTL seconds.
    """
    token = cfg.get("prolific_api_token", "")
    study_id = cfg.get("prolific_study_id", "")
    if not token or not study_id:
        return set()

    cache_path = _data_dir(cfg) / "prolific_returned_cache.json"
    now = time.time()
    if cache_path.exists():
        try:
            with open(cache_path) as f:
                c = json.load(f)
            if now - c.get("timestamp", 0) < PROLIFIC_POLL_CACHE_TTL:
                return set(c.get("returned_pids", []))
        except Exception:
            pass

    returned = set()
    try:
        import requests

        url = f"https://api.prolific.com/api/v1/studies/{study_id}/submissions/"
        headers = {"Authorization": f"Token {token}"}
        resp = requests.get(url, headers=headers, timeout=10)
        resp.raise_for_status()
        # NOTE(review): only the first response page is read. If Prolific
        # paginates this endpoint for large studies, later pages are ignored
        # — confirm against the API docs.
        for sub in resp.json().get("results", []):
            status = sub.get("status", "")
            if status in ("RETURNED", "TIMED-OUT", "TIMED_OUT"):
                pid = sub.get("participant_id") or sub.get("participant", "")
                if pid:
                    returned.add(pid)
        print(f"[PROLIFIC] Found {len(returned)} returned/timed-out participants")
    except Exception as e:
        print(f"[PROLIFIC] Could not query API: {e}")

    try:
        with open(cache_path, "w") as f:
            json.dump({"timestamp": now, "returned_pids": list(returned)}, f)
    except Exception:
        pass
    return returned


def _release_returned_reservations(reservations: dict, cfg: dict) -> None:
    """
    Remove reservations held by participants who RETURNED or TIMED-OUT.

    Mutates the reservations dict in place; the caller saves it.
    """
    returned_pids = _prolific_returned_pids(cfg)
    if not returned_pids:
        return
    released = []
    for key, r in list(reservations.items()):
        pid = r.get("prolific_pid", "")
        if pid and pid in returned_pids:
            released.append(key)
            del reservations[key]
    if released:
        print(f"[ASSIGN] Released {len(released)} reservations from returned/timed-out participants: {released}")


def all_items_covered(cfg: dict) -> bool:
    """
    Return True if every item in every category has been accepted at least once.

    Used for auto-pausing the Prolific study.
    """
    for cat_cfg in cfg["categories"]:
        cat = cat_cfg["name"]
        pool = _load_pool(str(_pool_path(cat, cfg)))
        counts = _get_accepted_counts(cat, cfg)
        for i in range(len(pool)):
            if counts.get(str(i), 0) < 1:
                return False
    return True


def pause_prolific_study(cfg: dict) -> bool:
    """
    Ask Prolific's API to pause the study. Returns True on success.

    Requires prolific_api_token (env PROLIFIC_API_TOKEN) and prolific_study_id.
    Safe to call repeatedly: after the first successful pause a local marker
    file short-circuits later calls without hitting the API.
    """
    token = cfg.get("prolific_api_token", "")
    study_id = cfg.get("prolific_study_id", "")
    if not token or not study_id:
        print("[PROLIFIC] Cannot auto-pause: no API token or study_id configured")
        return False

    # Idempotency marker so we don't spam the API on every completion after
    # the first time all items are covered.
    paused_marker = _data_dir(cfg) / ".prolific_paused"
    if paused_marker.exists():
        return True

    try:
        import requests

        url = f"https://api.prolific.com/api/v1/studies/{study_id}/transition/"
        headers = {"Authorization": f"Token {token}", "Content-Type": "application/json"}
        resp = requests.post(url, headers=headers, json={"action": "PAUSE"}, timeout=10)
        resp.raise_for_status()
        paused_marker.touch()
        print(f"[PROLIFIC] ✅ Study {study_id} paused automatically — all items covered.")
        return True
    except Exception as e:
        print(f"[PROLIFIC] Could not auto-pause study: {e}")
        return False


# ── Core assignment ───────────────────────────────────────────────────────────

def _assign_from_category(category: str, n: int, user_id: str, cfg: dict) -> list:
    """
    Assign n items from one category using a least-coverage-first strategy.

    Priority order (via sort key; ties keep pool order):
      1. Uncovered + unreserved (count=0, not reserved)
      2. Uncovered + reserved by other (count=0, reserved)
      3. Covered + unreserved (count>0, not reserved)
      4. Covered + reserved by other (count>0, reserved)

    Reservations are ONLY created for participants who come via Prolific
    (i.e. have a non-empty PROLIFIC_PID in the URL). Non-Prolific visitors
    (testers, previewers, direct-URL visitors) still get items assigned so
    they can run through the study, but they don't hold reservations.

    Reservations are keyed by _reservation_key(category, index). Legacy
    bare-index keys from older versions are simply ignored until they expire.

    Reservations from participants who have RETURNED/TIMED-OUT on Prolific
    are released BEFORE the sort, so their items are treated as unreserved.

    Returns copies of the chosen pool items stamped with _pool_index and
    _pool_category.
    """
    pool = _load_pool(str(_pool_path(category, cfg)))
    accepted_counts = _get_accepted_counts(category, cfg)
    lock = FileLock(str(_reservation_lock_path(cfg)), timeout=10)

    # Capture prolific_pid early so we can decide whether to reserve.
    # Read from query_params directly — session_state.study_state doesn't
    # exist yet during init_state, which is what calls this function.
    prolific_pid = ""
    try:
        params = st.query_params
        prolific_pid = params.get("PROLIFIC_PID", "") or ""
    except Exception:
        pass
    is_prolific = bool(prolific_pid)

    with lock:
        reservations = _load_reservations(cfg)
        _expire_reservations(reservations)
        _release_returned_reservations(reservations, cfg)

        # If this Prolific PID already holds reservations under a DIFFERENT
        # user_id (e.g. they refreshed the tab, got a new user_id, and came
        # back), release those before creating new ones. Reservations under
        # the current user_id are kept. assign_items calls this once per
        # category with the same user_id, and earlier categories' reservations
        # must survive.
        if is_prolific:
            stale = [
                key for key, r in reservations.items()
                if r.get("prolific_pid") == prolific_pid and r.get("user_id") != user_id
            ]
            for key in stale:
                del reservations[key]
            if stale:
                print(f"[ASSIGN] Released {len(stale)} prior reservations "
                      f"for returning PID {prolific_pid}")

        def is_reserved_by_other(i):
            r = reservations.get(_reservation_key(category, i))
            return r is not None and r["user_id"] != user_id

        def sort_key(i):
            count = accepted_counts.get(str(i), 0)
            reserved = int(is_reserved_by_other(i))
            return (count, reserved)

        all_indices = sorted(range(len(pool)), key=sort_key)
        selected_indices = all_indices[:n]

        # Only reserve if this is a Prolific participant — keeps the
        # admin "in progress" count accurate and stops testers/bouncers
        # from blocking items for real users.
        if is_prolific:
            expiry = time.time() + RESERVATION_TTL
            for i in selected_indices:
                reservations[_reservation_key(category, i)] = {
                    "user_id": user_id,
                    "prolific_pid": prolific_pid,
                    "expiry": expiry,
                    "category": category,
                    "pool_index": i,
                }
            _save_reservations(reservations, cfg)
            print(f"[ASSIGN] Reserved for Prolific PID {prolific_pid}")
        else:
            print("[ASSIGN] Non-Prolific visitor — no reservation created")

    selected = []
    for i in selected_indices:
        item = dict(pool[i])
        item["_pool_index"] = i
        item["_pool_category"] = category
        selected.append(item)

    print(f"[ASSIGN] {category}: assigned indices {selected_indices} "
          f"(counts: {[accepted_counts.get(str(i), 0) for i in selected_indices]})")
    return selected


# ── Variant assignment ────────────────────────────────────────────────────────

def _assign_variants(cfg: dict, n: int) -> list:
    """
    Return the ordered list of model/prompt variants for one participant.

    - Without cfg["model_variants"]: n copies of a "default" variant built
      from cfg["model_name"] and cfg["prompt_variant"].
    - With one variant: n copies of it.
    - Otherwise only the first two variants are used. Each is repeated
      variant["count"] times and the two are interleaved. A persistent
      counter alternates which variant goes first across participants.

    NOTE(review): in the multi-variant case the result has
    v0["count"] + v1["count"] entries, which is assumed to equal n.
    init_state zips it against the slots, so a shorter list would leave
    trailing slots without a model.
    """
    variants = cfg.get("model_variants")
    if not variants:
        return [{"name": "default", "model_name": cfg["model_name"],
                 "prompt_variant": cfg["prompt_variant"]}] * n
    if len(variants) == 1:
        return [variants[0]] * n

    lock = FileLock(str(_data_dir(cfg) / "variant_counter.lock"), timeout=10)
    with lock:
        counter_path = _data_dir(cfg) / "variant_counter.txt"
        ctr = int(counter_path.read_text().strip()) if counter_path.exists() else 0
        counter_path.write_text(str(ctr + 1))

    v0, v1 = variants[0], variants[1]
    if ctr % 2 == 1:
        v0, v1 = v1, v0

    from itertools import zip_longest

    interleaved = []
    for a, b in zip_longest([v0] * v0["count"], [v1] * v1["count"]):
        if a:
            interleaved.append(a)
        if b:
            interleaved.append(b)

    print(f"[VARIANTS] user {ctr}: {[v['name'] for v in interleaved]}")
    return interleaved


# ── Category count computation ────────────────────────────────────────────────

def _compute_counts(cfg: dict) -> dict:
    """
    Decide how many items each category contributes for one participant.

    - With one category, it gets all cfg["pairs_per_user"] items.
    - Otherwise each category's configured "count" is used when the counts
      sum to pairs_per_user. If they don't (or "count" is missing), items are
      split as evenly as possible.
    - A persistent counter swaps the first two categories' shares on every
      other participant. This balances uneven splits, including the extra
      item of an odd even-split.
    """
    cats = cfg["categories"]
    n = cfg["pairs_per_user"]
    if not cats:
        return {}
    if len(cats) == 1:
        return {cats[0]["name"]: n}

    lock = FileLock(str(_data_dir(cfg) / "alternation_counter.lock"), timeout=10)
    with lock:
        path = _data_dir(cfg) / "alternation_counter.txt"
        ctr = int(path.read_text().strip()) if path.exists() else 0
        path.write_text(str(ctr + 1))

    base = {c["name"]: c.get("count", 0) for c in cats}
    if sum(base.values()) != n:
        k = len(cats)
        base = {c["name"]: n // k + (1 if i < n % k else 0) for i, c in enumerate(cats)}

    if ctr % 2 == 1:
        names = [c["name"] for c in cats]
        base[names[0]], base[names[1]] = base[names[1]], base[names[0]]
    return base


def assign_items(cfg: dict, user_id: str) -> list:
    """Assign items from every category per _compute_counts, shuffled together."""
    counts = _compute_counts(cfg)
    items = []
    for cat_name, n in counts.items():
        items.extend(_assign_from_category(cat_name, n, user_id, cfg))
    random.shuffle(items)
    return items


# ── Item slot construction ────────────────────────────────────────────────────

def _make_item_slot(item: dict, study_type: str) -> dict:
    """
    Build the per-item session record for an assigned pool item.

    Common fields: pool stamps, an empty conversation, reflection and rating
    placeholders. "preference" slots add the product pair and two
    familiarity fields. Other study types add a single product and one
    familiarity field. A random uuid4 id is generated if the item lacks one.
    """
    base = {
        "_pool_index": item.get("_pool_index"),
        "_pool_category": item.get("_pool_category", item.get("category", "")),
        "conversation": {
            "system_prompt": "",
            "closing_message": "",
            "turns": [],
            "num_turns": 0,
        },
        "reflection": {},
        "pre_rating": None,
        "post_rating": None,
        "rating_delta": None,
    }
    if study_type == "preference":
        base.update({
            "pair_id": item.get("pair_id", str(uuid.uuid4())),
            "category": item.get("category", ""),
            "product_a": item.get("product_a", {}),
            "product_b": item.get("product_b", {}),
            "familiarity_a": None,
            "familiarity_b": None,
        })
    else:
        base.update({
            "item_id": item.get("item_id", str(uuid.uuid4())),
            "category": item.get("category", ""),
            "product": item,
            "familiarity": None,
        })
    return base


# ── Session-state construction ────────────────────────────────────────────────

def init_state(cfg: dict) -> dict:
    """
    Build the initial session-state dict for a new participant.

    Generates a fresh user_id and assigns model variants and items. Each slot
    is paired with its variant; zip stops at the shorter list. Prolific
    identifiers are captured from the URL query parameters, defaulting to ""
    when unavailable.
    """
    n = cfg["pairs_per_user"]
    user_id = str(uuid.uuid4())
    variants = _assign_variants(cfg, n)
    items = assign_items(cfg, user_id)[:n]
    slots = [_make_item_slot(it, cfg["study_type"]) for it in items]
    for slot, variant in zip(slots, variants):
        slot["model_name"] = variant["model_name"]
        slot["prompt_variant"] = variant["prompt_variant"]
        slot["sampler_path"] = variant.get("sampler_path", "")

    for i, slot in enumerate(slots):
        print(f"[ITEM {i}] category={slot.get('category')} "
              f"pool_index={slot.get('_pool_index')} "
              f"model={slot.get('model_name')} "
              f"personalization={slot.get('prompt_variant', {}).get('personalization')}")

    try:
        params = st.query_params
    except Exception:
        params = {}

    return {
        "submission_id": str(uuid.uuid4()),
        "user_id": user_id,
        "prolific_pid": params.get("PROLIFIC_PID", ""),
        "study_id": params.get("STUDY_ID", ""),
        "session_id": params.get("SESSION_ID", ""),
        "start_time": time.time(),
        "study_type": cfg["study_type"],
        "demographics": {},
        "background": {},
        "items": slots,
        "current_index": 0,
        "screen": "welcome",
        "meta": {},
    }