""" Entry point for the user study Streamlit app. Run from repo root: streamlit run src/app.py streamlit run src/app.py -- --debug (sets DEBUG_MODE via env) HuggingFace Space secrets required: HF_TOKEN — read/write HuggingFace token GH_TOKEN — GitHub PAT (ehejin account) for the private lsp submodule TINKER_API_KEY — Tinker inference API key DEBUG_MODE — "true" to skip all validation (optional) """ import os import sys import json import subprocess from pathlib import Path # --------------------------------------------------------------------------- # 1. Initialise lsp git submodule before any lsp imports # On a cold HF Space start the submodule directory exists but is empty; # GH_TOKEN lets us authenticate against the private GitHub repo. # --------------------------------------------------------------------------- _BASE = Path(__file__).resolve().parent.parent _LSP_PATH = _BASE / "lsp" def _init_submodule() -> None: prompts_exist = (_LSP_PATH / "src" / "prompts").exists() if not prompts_exist: token = os.getenv("GH_TOKEN", "") if not token: raise RuntimeError("GH_TOKEN secret is not set.") import shutil import tarfile import urllib.request import time as _time # Clean any stale state if _LSP_PATH.exists(): shutil.rmtree(str(_LSP_PATH), ignore_errors=True) git_modules = _BASE / ".git" / "modules" / "lsp" if git_modules.exists(): shutil.rmtree(str(git_modules), ignore_errors=True) # GitHub serves a tarball of any branch/tag/SHA at this URL. # Pinned to a specific commit SHA so future lsp changes don't break us. 
branch = "74582acd911f81309ba8b22cef9286c2887dda18" tarball_url = f"https://api.github.com/repos/batu-el/lsp/tarball/{branch}" tmp_tar = Path("/tmp/lsp.tar.gz") tmp_extract = Path("/tmp/lsp_extract") for attempt in range(1, 4): print(f"[SUBMODULE] tarball download attempt {attempt}/3 ...") try: req = urllib.request.Request( tarball_url, headers={ "Authorization": f"Bearer {token}", "Accept": "application/vnd.github+json", "User-Agent": "prolific-preferences", }, ) with urllib.request.urlopen(req, timeout=60) as resp: tmp_tar.write_bytes(resp.read()) print(f"[SUBMODULE] downloaded {tmp_tar.stat().st_size} bytes") # Extract if tmp_extract.exists(): shutil.rmtree(str(tmp_extract), ignore_errors=True) tmp_extract.mkdir(parents=True) with tarfile.open(str(tmp_tar)) as tar: tar.extractall(str(tmp_extract)) # GitHub tarballs have a top-level dir like batu-el-lsp-abc123/ subdirs = [d for d in tmp_extract.iterdir() if d.is_dir()] if not subdirs: raise RuntimeError("tarball had no top-level directory") top = subdirs[0] # Verify the prompts dir is present if not (top / "src" / "prompts").exists(): raise RuntimeError(f"src/prompts not found in extracted tarball at {top}") # Move extracted contents to /app/lsp shutil.copytree(str(top), str(_LSP_PATH)) tmp_tar.unlink(missing_ok=True) shutil.rmtree(str(tmp_extract), ignore_errors=True) print("[SUBMODULE] ready.") break except Exception as e: msg = str(e).replace(token, "***") if token else str(e) print(f"[SUBMODULE] attempt {attempt} failed: {msg}") _time.sleep(3) else: raise RuntimeError(f"Failed to download lsp tarball after 3 attempts.") lsp_src = str(_LSP_PATH / "src") if lsp_src not in sys.path: sys.path.insert(0, lsp_src) if str(_BASE) not in sys.path: sys.path.insert(0, str(_BASE)) _init_submodule() # Wipe stale local state ONLY on the first container load (not on every Streamlit rerun). # We use a marker file — once created, subsequent imports skip the wipe. 
# Completions stay durable in HF; we re-scan HF fresh after wipe.
_data_root = _BASE / "data"
_data_root.mkdir(parents=True, exist_ok=True)
_wipe_marker = _data_root / ".startup_wiped"

if not _wipe_marker.exists():
    # Gather every stale reservation / completion file first, then remove
    # them one at a time so a single failure does not stop the others.
    _stale_files = [
        stale_path
        for glob_pattern in (
            "reservations.json",
            "local_completions_*.json",
            "completion_cache_*.json",
        )
        for stale_path in _data_root.glob(glob_pattern)
    ]
    for stale_path in _stale_files:
        try:
            stale_path.unlink()
            print(f"[STARTUP] Wiped stale file: {stale_path.name}")
        except Exception as exc:
            print(f"[STARTUP] Could not wipe {stale_path.name}: {exc}")

    # Creating the marker is what makes later imports skip this branch.
    _wipe_marker.touch()
    print("[STARTUP] Marked container as wiped")

# ---------------------------------------------------------------------------
# 2. App imports (only after submodule is initialised)
# ---------------------------------------------------------------------------

import streamlit as st

from src.config import load_config
from src.data import ensure_datasets, init_state
from src.ui.components import inject_css
from src.ui.screens_shared import (
    screen_background,
    screen_chat,
    screen_demographics,
    screen_done,
    screen_post_rating,
    screen_reflection,
    screen_welcome,
)
from src.ui.screens_likelihood import screen_item_intro
from src.ui.screens_preference import screen_pair_intro

# ---------------------------------------------------------------------------
# 3.
# Admin dashboard — visit ?admin=1
# ---------------------------------------------------------------------------

def _screen_admin(cfg: dict) -> None:
    """Render the study coverage dashboard (shown for ``?admin=1``).

    Shows a header with the study type, seed and output repo, a Refresh
    button that deletes the local cache files and reruns the app, and then,
    for each configured category, four metrics (total / covered /
    in progress / still needed) followed by a status message.

    Args:
        cfg: Study configuration; this function reads ``study_type``,
            ``pair_selection_seed``, ``output_dataset_repo`` and
            ``categories`` (each entry providing ``name``).
    """
    from src.data import (
        _get_accepted_counts,
        _load_pool,
        _pool_path,
        _load_reservations,
        _save_reservations,
        _expire_reservations,
        _release_returned_reservations,
        _reservation_lock_path,
    )
    from filelock import FileLock

    st.markdown("## 📊 Study Coverage Dashboard")
    st.caption(
        f"Study type: `{cfg['study_type']}` · "
        f"Seed: `{cfg['pair_selection_seed']}` · "
        f"Output repo: `{cfg['output_dataset_repo']}`"
    )

    if st.button("🔄 Refresh", type="primary"):
        # Drop the cache files so the next run re-scans HF and re-polls Prolific
        from src.data import _data_dir

        for cache_file in _data_dir(cfg).glob("completion_cache*"):
            cache_file.unlink()
        returned_cache = _data_dir(cfg) / "prolific_returned_cache.json"
        if returned_cache.exists():
            returned_cache.unlink()
        st.rerun()

    # Release expired + returned/timed-out reservations before displaying,
    # holding the reservation file lock for the load/modify/save cycle.
    with FileLock(str(_reservation_lock_path(cfg)), timeout=10):
        reservations = _load_reservations(cfg)
        _expire_reservations(reservations)
        _release_returned_reservations(reservations, cfg)
        _save_reservations(reservations, cfg)

    for category_cfg in cfg["categories"]:
        category = category_cfg["name"]
        total = len(_load_pool(str(_pool_path(category, cfg))))
        counts = _get_accepted_counts(category, cfg)

        covered = len([count for count in counts.values() if count >= 1])
        # NOTE(review): this walks every reservation key, not only the keys of
        # this category — confirm reservation keys cannot collide across
        # categories, otherwise "Still needed" can be under-reported.
        reserved_uncovered = len(
            [key for key in reservations if counts.get(key, 0) == 0]
        )
        truly_uncovered = total - covered - reserved_uncovered

        if truly_uncovered > 0:
            needed_delta = f"-{truly_uncovered}"
        else:
            needed_delta = None

        st.markdown(f"### {category.capitalize()}")
        metric_cols = st.columns(4)
        metric_cols[0].metric("Total items", total)
        metric_cols[1].metric("Covered ✅", covered)
        metric_cols[2].metric(
            "In progress 🔄",
            reserved_uncovered,
            help="Reserved by active Prolific participants",
        )
        metric_cols[3].metric(
            "Still needed ⚠️",
            truly_uncovered,
            delta=needed_delta,
            delta_color="inverse",
        )

        if truly_uncovered != 0:
            st.warning(
                f"⚠️ {truly_uncovered} item(s) still need a participant. "
                f"Send more Prolific slots."
            )
        elif reserved_uncovered == 0:
            st.success(f"✅ All {total} items covered!")
        else:
            st.info(f"🔄 {reserved_uncovered} item(s) in progress.")

        st.markdown("---")


# ---------------------------------------------------------------------------
# 4. Main
# ---------------------------------------------------------------------------

def main() -> None:
    """Configure the page and render the screen for the current session.

    With ``?admin=1`` in the query string the coverage dashboard is rendered
    instead of the study. Otherwise the per-session study state is created on
    first use and the screen named by its ``screen`` entry (default
    ``"welcome"``) is rendered; an unrecognised name is reported via
    ``st.error``.
    """
    cfg = load_config()
    st.set_page_config(
        page_title="Product Study",
        page_icon="🛒",
        layout="centered",
    )
    inject_css()

    # Admin dashboard — visit ?admin=1
    try:
        params = st.query_params
    except Exception:
        # Fall back to "no parameters" if query params cannot be read.
        params = {}

    if params.get("admin") == "1":
        ensure_datasets(cfg)
        _screen_admin(cfg)
        return

    if "study_state" not in st.session_state:
        ensure_datasets(cfg)
        st.session_state.study_state = init_state(cfg)

    state = st.session_state.study_state
    screen = state.get("screen", "welcome")

    # The intro screen is the only one whose renderer depends on study type.
    if screen == "item_intro":
        if cfg["study_type"] == "preference":
            screen_pair_intro(state, cfg)
        else:
            screen_item_intro(state, cfg)
        return

    renderers = {
        "welcome": screen_welcome,
        "demographics": screen_demographics,
        "background": screen_background,
        "chat": screen_chat,
        "post_rating": screen_post_rating,
        "reflection": screen_reflection,
        "done": screen_done,
    }
    render = renderers.get(screen)
    if render is None:
        st.error(f"Unknown screen: {screen!r}")
    else:
        render(state, cfg)


if __name__ == "__main__":
    main()