# Mirrored from ehejin's repository, commit 0f4326e ("sync w/ detailed repo").
"""
Dataset download, item-pool caching, completion-aware assignment, and session-state init.
Assignment strategy
-------------------
Items are assigned based on how many *accepted* completions they already have,
ensuring the least-covered items are always prioritised.
Each assigned item is stamped with _pool_index and _pool_category at assignment
time so record_completion never needs to do a fuzzy pair_id match β€” it reads
the index directly.
Accepted completions = JSON files under json/ in the output repo.
Rejected completions = JSON files moved to rejected/ by the admin.
β†’ moving a file to rejected/ automatically makes that item available again.
Reservations
------------
When a user starts, their items are "reserved" in a local file for 80 min.
Concurrent users each get a FileLock on the reservation file so they
never receive the same items. Reservations expire automatically so abandoned
sessions don't permanently block items.
Each reservation stores the user's prolific_pid so we can release their items
immediately when Prolific reports them as RETURNED or TIMED-OUT β€” no need to
wait for the 80-min TTL.
Dropout / rejection recovery
-----------------------------
- Dropout (voluntary return): Prolific marks RETURNED, we query the API and
release the reservation on the next assignment.
- Dropout (silent): reservation expires after 80 min β†’ item re-enters pool.
- Rejection: admin moves json/{worker}/{id}.json β†’ rejected/{worker}/{id}.json
in the HF dataset repo. On next Space restart (or cache expiry) the item's
accepted count drops to 0 and it gets re-assigned.
"""
import json
import random
import time
import uuid
from pathlib import Path
import streamlit as st
from filelock import FileLock
from src.config import CATEGORY_TO_REPO
POOL_SIZE = 50  # items selected per (study_type, category)
RESERVATION_TTL = 60 * 80  # seconds (80 min): ~30 min expected session + ~2.5x buffer
COMPLETION_CACHE_TTL = 300  # seconds: re-scan the HF output repo at most every 5 minutes
PROLIFIC_POLL_CACHE_TTL = 120  # seconds: re-poll the Prolific API at most every 2 minutes
# ── Path helpers ──────────────────────────────────────────────────────────────
def _data_dir(cfg: dict) -> Path:
p = Path(cfg["data_dir"])
p.mkdir(parents=True, exist_ok=True)
return p
def _pool_path(category: str, cfg: dict) -> Path:
    """Path of the cached item pool for this study_type/category pair."""
    filename = "pool_{}_{}.json".format(cfg["study_type"], category)
    return _data_dir(cfg) / filename
def _reservation_path(cfg: dict) -> Path:
    """Location of the shared reservations file."""
    return _data_dir(cfg).joinpath("reservations.json")
def _reservation_lock_path(cfg: dict) -> Path:
    """Location of the FileLock file guarding the reservations file."""
    return _data_dir(cfg).joinpath("reservations.lock")
def _local_completions_path(category: str, cfg: dict) -> Path:
    """
    Local file tracking completed item counts for this container session.

    Updated immediately on each completion so subsequent assignments see
    accurate counts without waiting for an HF re-scan. Reset on container
    restart — HF is the durable source of truth.
    """
    filename = "local_completions_{}_{}.json".format(cfg["study_type"], category)
    return _data_dir(cfg) / filename
# ── Dataset download + normalisation ─────────────────────────────────────────
@st.cache_resource
def _download_and_cache(
    study_type: str,
    category: str,
    seed: int,
    hf_token: str,
    data_dir: str,
) -> None:
    """
    Download one (study_type, category) dataset from HF, select a seeded
    random subset of POOL_SIZE rows, normalise them, and cache to disk.

    @st.cache_resource makes this run at most once per container process
    for a given argument tuple; the on-disk pool file also short-circuits
    repeated work across processes.
    """
    pool_path = Path(data_dir) / f"pool_{study_type}_{category}.json"
    if pool_path.exists():
        print(f"[DATA] Pool already cached: {pool_path}")
        return
    from datasets import load_dataset
    repo_id = CATEGORY_TO_REPO[(study_type, category)]
    # An empty-string token would be sent literally; pass None to mean "no token".
    token_arg = hf_token or None
    print(f"[DATA] Downloading {repo_id} …")
    ds = load_dataset(repo_id, token=token_arg, trust_remote_code=True)
    if study_type == "preference":
        # Prefer a dedicated "test" split; otherwise filter train rows tagged split=="test".
        if "test" in ds:
            rows = [dict(r) for r in ds["test"]]
        else:
            rows = [dict(r) for r in ds["train"] if r.get("split") == "test"]
    else:
        # Likelihood study: take "test" if present, else the first available split.
        split_key = "test" if "test" in ds else list(ds.keys())[0]
        rows = [dict(r) for r in ds[split_key]]
    # Seeded shuffle so every container selects the same POOL_SIZE items.
    rng = random.Random(seed)
    rng.shuffle(rows)
    selected = rows[:POOL_SIZE]
    if study_type == "likelihood":
        # Flatten each row to its metadata dict, stamped with a deterministic
        # item_id (uuid5 of repo/index/seed) and the category name.
        normalised = []
        for i, row in enumerate(selected):
            meta = row["metadata"]
            if isinstance(meta, str):
                meta = json.loads(meta)
            else:
                meta = dict(meta)
            meta["item_id"] = str(uuid.uuid5(uuid.NAMESPACE_DNS, f"{repo_id}_{i}_{seed}"))
            meta["category"] = category
            normalised.append(meta)
        selected = normalised
    else:
        # Preference study: shallow-copy rows and make sure both products
        # carry a category (falling back to the row's, then this pool's).
        cleaned = []
        for row in selected:
            r = dict(row)
            r["product_a"] = dict(r["product_a"])
            r["product_b"] = dict(r["product_b"])
            r["product_a"].setdefault("category", r.get("category", category))
            r["product_b"].setdefault("category", r.get("category", category))
            cleaned.append(r)
        selected = cleaned
    pool_path.parent.mkdir(parents=True, exist_ok=True)
    with open(pool_path, "w") as f:
        json.dump(selected, f, indent=2)
    print(f"[DATA] {study_type}/{category}: cached {len(selected)} items (seed={seed}).")
def ensure_datasets(cfg: dict) -> None:
    """Download and cache the item pool for every configured category."""
    study_type = cfg["study_type"]
    seed = cfg["pair_selection_seed"]
    for category_cfg in cfg["categories"]:
        _download_and_cache(
            study_type=study_type,
            category=category_cfg["name"],
            seed=seed,
            hf_token=cfg.get("hf_token", ""),
            data_dir=cfg["data_dir"],
        )
@st.cache_data
def _load_pool(pool_path_str: str) -> list:
    """Load a cached pool JSON file (result memoised per path by Streamlit)."""
    return json.loads(Path(pool_path_str).read_text())
# ── Accepted completion counts ────────────────────────────────────────────────
def _get_accepted_counts(category: str, cfg: dict) -> dict:
    """
    Return how many times each pool item has been accepted.
    Sources (merged, highest count wins):
      1. Local completions file — written immediately on each completion this session.
      2. HF output repo scan — authoritative after a container restart.
    Results cached for COMPLETION_CACHE_TTL seconds.
    Rejected submissions live under rejected/ and are NOT counted.

    Returns a dict mapping str(pool_index) -> accepted count.
    """
    pool = _load_pool(str(_pool_path(category, cfg)))
    counts = {str(i): 0 for i in range(len(pool))}
    # ── Source 1: local completions (most up-to-date within this session) ────
    local_path = _local_completions_path(category, cfg)
    if local_path.exists():
        try:
            with open(local_path) as f:
                local = json.load(f)
            for k, v in local.items():
                counts[k] = max(counts.get(k, 0), v)
            print(f"[ASSIGN] Local completions for {category}: "
                  f"{sum(1 for v in local.values() if v > 0)} items completed")
        except Exception as e:
            # Best-effort: a corrupt local file just means we rely on the HF scan.
            print(f"[ASSIGN] Could not read local completions: {e}")
    # ── Source 2: HF scan (authoritative after restart, with 5-min cache) ───
    cache_path = _data_dir(cfg) / f"completion_cache_{cfg['study_type']}_{category}.json"
    now = time.time()
    hf_counts = None
    if cache_path.exists():
        try:
            with open(cache_path) as f:
                cache = json.load(f)
            if now - cache.get("timestamp", 0) < COMPLETION_CACHE_TTL:
                hf_counts = cache["counts"]  # cache still fresh — skip the repo scan
        except Exception:
            pass  # unreadable cache — fall through to a fresh scan
    if hf_counts is None:
        hf_counts = {str(i): 0 for i in range(len(pool))}
        hf_token = cfg.get("hf_token", "")
        output_repo = cfg.get("output_dataset_repo", "")
        if hf_token and output_repo:
            try:
                from huggingface_hub import HfApi
                api = HfApi(token=hf_token)
                files = list(api.list_repo_files(repo_id=output_repo, repo_type="dataset"))
                # Only accepted submissions live under json/; rejected/ is excluded by design.
                json_files = [f for f in files if f.startswith("json/") and f.endswith(".json")]
                # Build pair_id → pool_index lookup for fallback matching
                # (older submissions may predate the _pool_index stamp).
                id_to_index = {}
                for i, p in enumerate(pool):
                    pid = p.get("pair_id") or p.get("item_id", "")
                    if pid:
                        id_to_index[pid] = i
                for filepath in json_files:
                    try:
                        content = api.hf_hub_download(
                            repo_id=output_repo,
                            filename=filepath,
                            repo_type="dataset",
                            token=hf_token,
                        )
                        with open(content) as f:
                            submission = json.load(f)
                        for item in submission.get("items", []):
                            if item.get("category") != category:
                                continue
                            # Prefer the index stamped at assignment time;
                            # fall back to the id lookup for legacy files.
                            idx = item.get("_pool_index")
                            if idx is None:
                                pid = item.get("pair_id") or item.get("item_id", "")
                                idx = id_to_index.get(pid)
                            if idx is not None:
                                hf_counts[str(idx)] = hf_counts.get(str(idx), 0) + 1
                    except Exception as e:
                        # One bad submission file shouldn't abort the whole scan.
                        print(f"[ASSIGN] Could not parse {filepath}: {e}")
            except Exception as e:
                print(f"[ASSIGN] Could not scan HF repo: {e}")
        # Persist scan results (best-effort) so calls within the TTL skip the scan.
        try:
            with open(cache_path, "w") as f:
                json.dump({"timestamp": now, "counts": hf_counts}, f)
        except Exception:
            pass
    # Merge: per index, the higher of local and HF counts wins.
    for k, v in hf_counts.items():
        counts[k] = max(counts.get(k, 0), v)
    return counts
# ── Reservation management ────────────────────────────────────────────────────
def _load_reservations(cfg: dict) -> dict:
    """Read the reservation map from disk; missing or corrupt file yields {}."""
    try:
        with open(_reservation_path(cfg)) as fh:
            return json.load(fh)
    except Exception:
        return {}
def _save_reservations(reservations: dict, cfg: dict) -> None:
    """Persist the reservation map to disk (callers hold the file lock)."""
    _reservation_path(cfg).write_text(json.dumps(reservations))
def _expire_reservations(reservations: dict) -> dict:
now = time.time()
expired = [k for k, v in reservations.items() if v["expiry"] < now]
for k in expired:
print(f"[ASSIGN] Reservation expired for item index {k}")
del reservations[k]
return reservations
def release_reservation(user_id: str, cfg: dict) -> None:
    """Drop every reservation held by *user_id*, e.g. right after completion."""
    lock = FileLock(str(_reservation_lock_path(cfg)), timeout=10)
    with lock:
        reservations = _load_reservations(cfg)
        _expire_reservations(reservations)
        mine = [key for key, res in reservations.items() if res["user_id"] == user_id]
        for key in mine:
            reservations.pop(key)
        _save_reservations(reservations, cfg)
        print(f"[ASSIGN] Released {len(mine)} reservations for user {user_id}")
def record_completion(user_id: str, items: list, cfg: dict) -> None:
    """
    Record completed item indices to the local completions file immediately.
    Uses _pool_index stamped on each item at assignment time — no fuzzy matching.
    Called after successful HF upload AND by the simulation script.
    """
    # Group completed pool indices by category so each category's file is
    # read and rewritten at most once.
    by_category: dict = {}
    for item in items:
        cat = item.get("_pool_category") or item.get("category", "")
        idx = item.get("_pool_index")
        if idx is None:
            # Should not happen — items are stamped at assignment time.
            print(f"[ASSIGN] WARNING: item missing _pool_index, skipping: "
                  f"{item.get('pair_id') or item.get('item_id', '?')}")
            continue
        by_category.setdefault(cat, []).append(idx)
    for cat, indices in by_category.items():
        pool = _load_pool(str(_pool_path(cat, cfg)))
        completions_path = _local_completions_path(cat, cfg)
        # Load existing per-session counts; start from all-zero on a
        # missing or unreadable file.
        if completions_path.exists():
            try:
                with open(completions_path) as f:
                    completions = json.load(f)
            except Exception:
                completions = {str(i): 0 for i in range(len(pool))}
        else:
            completions = {str(i): 0 for i in range(len(pool))}
        for idx in indices:
            completions[str(idx)] = completions.get(str(idx), 0) + 1
        with open(completions_path, "w") as f:
            json.dump(completions, f)
        # Invalidate HF cache so next scan re-reads fresh
        cache_path = _data_dir(cfg) / f"completion_cache_{cfg['study_type']}_{cat}.json"
        if cache_path.exists():
            try:
                cache_path.unlink()
            except Exception:
                pass
        print(f"[ASSIGN] Recorded completions for {cat}: indices {indices} "
              f"(user {user_id[:8]})")
# ── Prolific status polling ───────────────────────────────────────────────────
def _prolific_returned_pids(cfg: dict) -> set:
    """
    Query Prolific for participants who have RETURNED or TIMED-OUT from the
    active study. Returns a set of their PIDs. Cached for PROLIFIC_POLL_CACHE_TTL.
    """
    token = cfg.get("prolific_api_token", "")
    study_id = cfg.get("prolific_study_id", "")
    if not token or not study_id:
        # Polling disabled when Prolific credentials aren't configured.
        return set()
    cache_path = _data_dir(cfg) / "prolific_returned_cache.json"
    now = time.time()
    # Serve from the on-disk cache while fresh to avoid hammering the API.
    if cache_path.exists():
        try:
            with open(cache_path) as f:
                c = json.load(f)
            if now - c.get("timestamp", 0) < PROLIFIC_POLL_CACHE_TTL:
                return set(c.get("returned_pids", []))
        except Exception:
            pass  # unreadable cache — fall through to a live poll
    returned = set()
    try:
        import requests
        url = f"https://api.prolific.com/api/v1/studies/{study_id}/submissions/"
        headers = {"Authorization": f"Token {token}"}
        resp = requests.get(url, headers=headers, timeout=10)
        resp.raise_for_status()
        for sub in resp.json().get("results", []):
            status = sub.get("status", "")
            # Both separator spellings accepted defensively.
            if status in ("RETURNED", "TIMED-OUT", "TIMED_OUT"):
                pid = sub.get("participant_id") or sub.get("participant", "")
                if pid:
                    returned.add(pid)
        print(f"[PROLIFIC] Found {len(returned)} returned/timed-out participants")
    except Exception as e:
        print(f"[PROLIFIC] Could not query API: {e}")
    # Best-effort cache write. NOTE(review): a failed poll caches an empty
    # set, suppressing retries until the TTL lapses — confirm this is intended.
    try:
        with open(cache_path, "w") as f:
            json.dump({"timestamp": now, "returned_pids": list(returned)}, f)
    except Exception:
        pass
    return returned
def _release_returned_reservations(reservations: dict, cfg: dict) -> None:
    """
    Remove reservations held by Prolific participants who have RETURNED or
    TIMED-OUT. Mutates the reservations dict in place.
    """
    gone_pids = _prolific_returned_pids(cfg)
    if not gone_pids:
        return
    released = []
    for key in list(reservations):
        pid = reservations[key].get("prolific_pid", "")
        if pid and pid in gone_pids:
            released.append(key)
            del reservations[key]
    if released:
        print(f"[ASSIGN] Released {len(released)} reservations from returned/timed-out participants: {released}")
def all_items_covered(cfg: dict) -> bool:
    """
    Returns True if every item in every category has been accepted at least once.
    Used for auto-pausing the Prolific study.
    """
    for cat_cfg in cfg["categories"]:
        name = cat_cfg["name"]
        pool_size = len(_load_pool(str(_pool_path(name, cfg))))
        counts = _get_accepted_counts(name, cfg)
        if any(counts.get(str(i), 0) < 1 for i in range(pool_size)):
            return False
    return True
def pause_prolific_study(cfg: dict) -> bool:
    """
    Call Prolific's API to pause the study. Returns True on success.
    Requires prolific_api_token (env PROLIFIC_API_TOKEN) and prolific_study_id.
    Idempotent — safe to call multiple times.
    """
    token = cfg.get("prolific_api_token", "")
    study_id = cfg.get("prolific_study_id", "")
    if not (token and study_id):
        print("[PROLIFIC] Cannot auto-pause: no API token or study_id configured")
        return False
    # Marker file acts as an idempotency guard so we don't hit the API on
    # every completion once all items are covered.
    paused_marker = _data_dir(cfg) / ".prolific_paused"
    if paused_marker.exists():
        return True
    try:
        import requests
        resp = requests.post(
            f"https://api.prolific.com/api/v1/studies/{study_id}/transition/",
            headers={"Authorization": f"Token {token}", "Content-Type": "application/json"},
            json={"action": "PAUSE"},
            timeout=10,
        )
        resp.raise_for_status()
        paused_marker.touch()
        print(f"[PROLIFIC] βœ… Study {study_id} paused automatically β€” all items covered.")
        return True
    except Exception as e:
        print(f"[PROLIFIC] Could not auto-pause study: {e}")
        return False
# ── Core assignment ───────────────────────────────────────────────────────────
def _assign_from_category(category: str, n: int, user_id: str, cfg: dict) -> list:
    """
    Assign n items using least-coverage-first strategy.
    Priority order (via sort key):
      1. Uncovered + unreserved (count=0, not reserved)
      2. Uncovered + reserved by other (count=0, reserved)
      3. Covered + unreserved (count>0, not reserved)
      4. Covered + reserved by other (count>0, reserved)
    Reservations are ONLY created for participants who come via Prolific
    (i.e. have a non-empty prolific_pid in the URL). Non-Prolific visitors
    (testers, previewers, direct-URL visitors) still get items assigned so
    they can run through the study, but they don't hold reservations.
    Reservations from participants who have RETURNED/TIMED-OUT on Prolific
    are released BEFORE the sort, so their items are treated as unreserved.
    """
    pool = _load_pool(str(_pool_path(category, cfg)))
    accepted_counts = _get_accepted_counts(category, cfg)
    lock = FileLock(str(_reservation_lock_path(cfg)), timeout=10)
    # Capture prolific_pid early so we can decide whether to reserve.
    # Read from query_params directly — session_state.study_state doesn't
    # exist yet during init_state, which is what calls this function.
    prolific_pid = ""
    try:
        params = st.query_params
        prolific_pid = params.get("PROLIFIC_PID", "") or ""
    except Exception:
        pass  # not inside a Streamlit script context (e.g. simulation script)
    is_prolific = bool(prolific_pid)
    with lock:
        reservations = _load_reservations(cfg)
        _expire_reservations(reservations)
        _release_returned_reservations(reservations, cfg)
        # If this Prolific PID already has reservations (e.g. they refreshed
        # the tab, got a new user_id, and came back), release the old ones
        # before creating new ones. Prevents the same participant from
        # accumulating multiple reservations.
        if is_prolific:
            stale = [
                idx for idx, r in list(reservations.items())
                if r.get("prolific_pid") == prolific_pid
            ]
            for idx in stale:
                del reservations[idx]
            if stale:
                print(f"[ASSIGN] Released {len(stale)} prior reservations "
                      f"for returning PID {prolific_pid}")
        def is_reserved_by_other(i):
            # Reservation keys are str(pool_index); a user's own reservation
            # must not penalise the item in the sort below.
            r = reservations.get(str(i))
            return r is not None and r["user_id"] != user_id
        def sort_key(i):
            # (accepted count, reserved-by-other) — lowest tuple wins,
            # giving the priority order documented in the docstring.
            count = accepted_counts.get(str(i), 0)
            reserved = int(is_reserved_by_other(i))
            return (count, reserved)
        all_indices = sorted(range(len(pool)), key=sort_key)
        selected_indices = all_indices[:n]
        # Only reserve if this is a Prolific participant — keeps the
        # admin "in progress" count accurate and stops testers/bouncers
        # from blocking items for real users.
        if is_prolific:
            expiry = time.time() + RESERVATION_TTL
            for i in selected_indices:
                reservations[str(i)] = {
                    "user_id": user_id,
                    "prolific_pid": prolific_pid,
                    "expiry": expiry,
                }
            _save_reservations(reservations, cfg)
            print(f"[ASSIGN] Reserved for Prolific PID {prolific_pid}")
        else:
            print(f"[ASSIGN] Non-Prolific visitor β€” no reservation created")
    # Stamp each assigned item with its pool index/category so
    # record_completion can update counts without any fuzzy matching.
    selected = []
    for i in selected_indices:
        item = dict(pool[i])
        item["_pool_index"] = i
        item["_pool_category"] = category
        selected.append(item)
    print(f"[ASSIGN] {category}: assigned indices {selected_indices} "
          f"(counts: {[accepted_counts.get(str(i), 0) for i in selected_indices]})")
    return selected
# ── Variant assignment ────────────────────────────────────────────────────────
def _assign_variants(cfg: dict, n: int) -> list:
variants = cfg.get("model_variants")
if not variants:
return [{"name": "default",
"model_name": cfg["model_name"],
"prompt_variant": cfg["prompt_variant"]}] * n
if len(variants) == 1:
return [variants[0]] * n
lock = FileLock(str(_data_dir(cfg) / "variant_counter.lock"), timeout=10)
with lock:
counter_path = _data_dir(cfg) / "variant_counter.txt"
ctr = int(counter_path.read_text().strip()) if counter_path.exists() else 0
counter_path.write_text(str(ctr + 1))
v0, v1 = variants[0], variants[1]
if ctr % 2 == 1:
v0, v1 = v1, v0
from itertools import zip_longest
interleaved = []
for a, b in zip_longest([v0] * v0["count"], [v1] * v1["count"]):
if a: interleaved.append(a)
if b: interleaved.append(b)
print(f"[VARIANTS] user {ctr}: {[v['name'] for v in interleaved]}")
return interleaved
# ── Category count computation ────────────────────────────────────────────────
def _compute_counts(cfg: dict) -> dict:
cats = cfg["categories"]
n = cfg["pairs_per_user"]
if len(cats) == 1:
return {cats[0]["name"]: n}
lock = FileLock(str(_data_dir(cfg) / "alternation_counter.lock"), timeout=10)
with lock:
path = _data_dir(cfg) / "alternation_counter.txt"
ctr = int(path.read_text().strip()) if path.exists() else 0
path.write_text(str(ctr + 1))
base = {c["name"]: c["count"] for c in cats}
if sum(base.values()) != n:
base = {}
for i, c in enumerate(cats):
base[c["name"]] = n // len(cats) + (1 if i < n % len(cats) else 0)
return base
if ctr % 2 == 1:
names = [c["name"] for c in cats]
base[names[0]], base[names[1]] = base[names[1]], base[names[0]]
return base
def assign_items(cfg: dict, user_id: str) -> list:
    """Assemble the user's full item list across categories, shuffled."""
    assigned = []
    for category, count in _compute_counts(cfg).items():
        assigned += _assign_from_category(category, count, user_id, cfg)
    random.shuffle(assigned)
    return assigned
# ── Item slot construction ────────────────────────────────────────────────────
def _make_item_slot(item: dict, study_type: str) -> dict:
base = {
"_pool_index": item.get("_pool_index"),
"_pool_category": item.get("_pool_category", item.get("category", "")),
"conversation": {
"system_prompt": "",
"closing_message": "",
"turns": [],
"num_turns": 0,
},
"reflection": {},
"pre_rating": None,
"post_rating": None,
"rating_delta": None,
}
if study_type == "preference":
base.update({
"pair_id": item.get("pair_id", str(uuid.uuid4())),
"category": item.get("category", ""),
"product_a": item.get("product_a", {}),
"product_b": item.get("product_b", {}),
"familiarity_a": None,
"familiarity_b": None,
})
else:
base.update({
"item_id": item.get("item_id", str(uuid.uuid4())),
"category": item.get("category", ""),
"product": item,
"familiarity": None,
})
return base
# ── Session-state construction ────────────────────────────────────────────────
def init_state(cfg: dict) -> dict:
    """Build the initial session-state dict for a new participant."""
    n = cfg["pairs_per_user"]
    user_id = str(uuid.uuid4())
    variants = _assign_variants(cfg, n)
    # Defensive truncation to n — assign_items normally returns exactly n.
    items = assign_items(cfg, user_id)[:n]
    slots = [_make_item_slot(it, cfg["study_type"]) for it in items]
    # Attach each slot's model/prompt variant (zip pairs slot i with variant i).
    for slot, variant in zip(slots, variants):
        slot["model_name"] = variant["model_name"]
        slot["prompt_variant"] = variant["prompt_variant"]
        slot["sampler_path"] = variant.get("sampler_path", "")
    for i, slot in enumerate(slots):
        print(f"[ITEM {i}] category={slot.get('category')} "
              f"pool_index={slot.get('_pool_index')} "
              f"model={slot.get('model_name')} "
              f"personalization={slot.get('prompt_variant', {}).get('personalization')}")
    # st.query_params raises outside a running Streamlit script context
    # (e.g. the simulation script) — fall back to no URL parameters.
    try:
        params = st.query_params
    except Exception:
        params = {}
    return {
        "submission_id": str(uuid.uuid4()),
        "user_id": user_id,
        "prolific_pid": params.get("PROLIFIC_PID", ""),
        "study_id": params.get("STUDY_ID", ""),
        "session_id": params.get("SESSION_ID", ""),
        "start_time": time.time(),
        "study_type": cfg["study_type"],
        "demographics": {},
        "background": {},
        "items": slots,
        "current_index": 0,
        "screen": "welcome",
        "meta": {},
    }