# (Removed "Spaces: Running" banner — HuggingFace Spaces page-status text
#  scraped into the file; not part of the app source.)
"""
Entry point for the user study Streamlit app.

Run from repo root:
    streamlit run src/app.py
    streamlit run src/app.py -- --debug   (sets DEBUG_MODE via env)

HuggingFace Space secrets required:
    HF_TOKEN       — read/write HuggingFace token
    GH_TOKEN       — GitHub PAT (ehejin account) for the private lsp submodule
    TINKER_API_KEY — Tinker inference API key
    DEBUG_MODE     — "true" to skip all validation (optional)
"""
import os
import sys
import json        # NOTE(review): appears unused in this file — confirm before removing
import subprocess  # NOTE(review): appears unused in this file — confirm before removing
from pathlib import Path

# ---------------------------------------------------------------------------
# 1. Initialise lsp git submodule before any lsp imports
#    On a cold HF Space start the submodule directory exists but is empty;
#    GH_TOKEN lets us authenticate against the private GitHub repo.
# ---------------------------------------------------------------------------
# Repo root (this file lives one level below it) and the expected location
# of the private lsp checkout.
_BASE = Path(__file__).resolve().parent.parent
_LSP_PATH = _BASE / "lsp"
def _init_submodule() -> None:
    """Ensure the private ``lsp`` checkout exists under ``_LSP_PATH``.

    On a cold HF Space start the submodule directory is empty, so we download
    a tarball of a pinned commit from GitHub (authenticated with the GH_TOKEN
    secret) and unpack it into place.  No-op when ``lsp/src/prompts`` already
    exists.

    Raises:
        RuntimeError: if GH_TOKEN is unset, or the download/extract fails
            on all three attempts.
    """
    prompts_exist = (_LSP_PATH / "src" / "prompts").exists()
    if prompts_exist:
        return

    token = os.getenv("GH_TOKEN", "")
    if not token:
        raise RuntimeError("GH_TOKEN secret is not set.")

    # Imported lazily: only needed on a cold start.
    import shutil
    import tarfile
    import urllib.request
    import time as _time

    # Clean any stale state left by a partially-initialised submodule.
    if _LSP_PATH.exists():
        shutil.rmtree(str(_LSP_PATH), ignore_errors=True)
    git_modules = _BASE / ".git" / "modules" / "lsp"
    if git_modules.exists():
        shutil.rmtree(str(git_modules), ignore_errors=True)

    # GitHub serves a tarball of any branch/tag/SHA at this URL.
    # Pinned to a specific commit SHA so future lsp changes don't break us.
    ref = "74582acd911f81309ba8b22cef9286c2887dda18"
    tarball_url = f"https://api.github.com/repos/batu-el/lsp/tarball/{ref}"
    tmp_tar = Path("/tmp/lsp.tar.gz")
    tmp_extract = Path("/tmp/lsp_extract")

    max_attempts = 3
    for attempt in range(1, max_attempts + 1):
        print(f"[SUBMODULE] tarball download attempt {attempt}/3 ...")
        try:
            req = urllib.request.Request(
                tarball_url,
                headers={
                    "Authorization": f"Bearer {token}",
                    "Accept": "application/vnd.github+json",
                    "User-Agent": "prolific-preferences",
                },
            )
            with urllib.request.urlopen(req, timeout=60) as resp:
                tmp_tar.write_bytes(resp.read())
            print(f"[SUBMODULE] downloaded {tmp_tar.stat().st_size} bytes")

            # Extract into a scratch dir first so a bad archive never leaves
            # a half-written _LSP_PATH behind.  The archive comes from an
            # authenticated, SHA-pinned GitHub endpoint, so extractall on it
            # is acceptable here.
            if tmp_extract.exists():
                shutil.rmtree(str(tmp_extract), ignore_errors=True)
            tmp_extract.mkdir(parents=True)
            with tarfile.open(str(tmp_tar)) as tar:
                tar.extractall(str(tmp_extract))

            # GitHub tarballs have a top-level dir like batu-el-lsp-abc123/
            subdirs = [d for d in tmp_extract.iterdir() if d.is_dir()]
            if not subdirs:
                raise RuntimeError("tarball had no top-level directory")
            top = subdirs[0]
            # Verify the prompts dir is present before installing.
            if not (top / "src" / "prompts").exists():
                raise RuntimeError(f"src/prompts not found in extracted tarball at {top}")

            # Move extracted contents to /app/lsp and tidy up scratch files.
            shutil.copytree(str(top), str(_LSP_PATH))
            tmp_tar.unlink(missing_ok=True)
            shutil.rmtree(str(tmp_extract), ignore_errors=True)
            print("[SUBMODULE] ready.")
            break
        except Exception as e:
            # Redact the token — urllib errors can echo the request URL/headers.
            msg = str(e).replace(token, "***") if token else str(e)
            print(f"[SUBMODULE] attempt {attempt} failed: {msg}")
            # Fix: don't waste a 3s sleep after the final failed attempt.
            if attempt < max_attempts:
                _time.sleep(3)
    else:
        # Fix: was a pointless f-string with no placeholders.
        raise RuntimeError("Failed to download lsp tarball after 3 attempts.")
# Make the lsp sources and the repo root importable.  Each path is inserted
# at the front of sys.path at most once; inserting lsp/src first and the
# repo root second leaves the repo root with highest priority.
for _path in (str(_LSP_PATH / "src"), str(_BASE)):
    if _path not in sys.path:
        sys.path.insert(0, _path)

# Fetch the private lsp checkout before anything imports from it.
_init_submodule()
# Wipe stale local state ONLY on the first container load (not on every
# Streamlit rerun).  A marker file records that the wipe already happened,
# so subsequent imports skip it.  Completions stay durable in HF; we
# re-scan HF fresh after the wipe.
_data_root = _BASE / "data"
_data_root.mkdir(parents=True, exist_ok=True)
_wipe_marker = _data_root / ".startup_wiped"
if not _wipe_marker.exists():
    _stale_patterns = (
        "reservations.json",
        "local_completions_*.json",
        "completion_cache_*.json",
    )
    for _pattern in _stale_patterns:
        for _stale in _data_root.glob(_pattern):
            try:
                _stale.unlink()
            except Exception as _exc:
                print(f"[STARTUP] Could not wipe {_stale.name}: {_exc}")
            else:
                print(f"[STARTUP] Wiped stale file: {_stale.name}")
    _wipe_marker.touch()
    print("[STARTUP] Marked container as wiped")
| # --------------------------------------------------------------------------- | |
| # 2. App imports (only after submodule is initialised) | |
| # --------------------------------------------------------------------------- | |
| import streamlit as st | |
| from src.config import load_config | |
| from src.data import ensure_datasets, init_state | |
| from src.ui.components import inject_css | |
| from src.ui.screens_shared import ( | |
| screen_background, | |
| screen_chat, | |
| screen_demographics, | |
| screen_done, | |
| screen_post_rating, | |
| screen_reflection, | |
| screen_welcome, | |
| ) | |
| from src.ui.screens_likelihood import screen_item_intro | |
| from src.ui.screens_preference import screen_pair_intro | |
| # --------------------------------------------------------------------------- | |
| # 3. Admin dashboard β visit ?admin=1 | |
| # --------------------------------------------------------------------------- | |
def _screen_admin(cfg: dict) -> None:
    """Coverage dashboard — visit ?admin=1 to see this.

    For each configured category, shows how many pool items have at least
    one accepted completion, how many are reserved by active participants,
    and how many still need a participant.  Expired and returned/timed-out
    reservations are released (under a file lock) before counting.
    """
    # Imported lazily so the data helpers load only on the admin path.
    from src.data import (
        _get_accepted_counts, _load_pool, _pool_path,
        _load_reservations, _save_reservations,
        _expire_reservations, _release_returned_reservations,
        _reservation_lock_path,
    )
    from filelock import FileLock
    st.markdown("## π Study Coverage Dashboard")
    st.caption(
        f"Study type: `{cfg['study_type']}` Β· "
        f"Seed: `{cfg['pair_selection_seed']}` Β· "
        f"Output repo: `{cfg['output_dataset_repo']}`"
    )
    if st.button("π Refresh", type="primary"):
        # Invalidate caches so we re-scan HF and re-poll Prolific
        from src.data import _data_dir
        for f in _data_dir(cfg).glob("completion_cache*"):
            f.unlink()
        prolific_cache = _data_dir(cfg) / "prolific_returned_cache.json"
        if prolific_cache.exists():
            prolific_cache.unlink()
        st.rerun()
    # Release expired + returned/timed-out reservations before displaying
    lock = FileLock(str(_reservation_lock_path(cfg)), timeout=10)
    with lock:
        reservations = _load_reservations(cfg)
        _expire_reservations(reservations)
        _release_returned_reservations(reservations, cfg)
        _save_reservations(reservations, cfg)
    for cat_cfg in cfg["categories"]:
        cat = cat_cfg["name"]
        pool = _load_pool(str(_pool_path(cat, cfg)))
        total = len(pool)
        # counts: per-item accepted-completion counts for this category.
        counts = _get_accepted_counts(cat, cfg)
        covered = sum(1 for v in counts.values() if v >= 1)
        # NOTE(review): this iterates ALL reservations inside a per-category
        # loop — if reservation keys span categories, "in progress" (and thus
        # "still needed") may be misstated per category; confirm the key
        # structure in src.data before relying on these numbers.
        reserved_uncovered = sum(
            1 for k in reservations
            if counts.get(k, 0) == 0
        )
        truly_uncovered = total - covered - reserved_uncovered
        st.markdown(f"### {cat.capitalize()}")
        col1, col2, col3, col4 = st.columns(4)
        col1.metric("Total items", total)
        col2.metric("Covered β ", covered)
        col3.metric("In progress π", reserved_uncovered,
                    help="Reserved by active Prolific participants")
        col4.metric("Still needed β οΈ", truly_uncovered,
                    delta=f"-{truly_uncovered}" if truly_uncovered > 0 else None,
                    delta_color="inverse")
        if truly_uncovered == 0 and reserved_uncovered == 0:
            st.success(f"β All {total} items covered!")
        elif truly_uncovered == 0:
            st.info(f"π {reserved_uncovered} item(s) in progress.")
        else:
            st.warning(
                f"β οΈ {truly_uncovered} item(s) still need a participant. "
                f"Send more Prolific slots."
            )
        st.markdown("---")
| # --------------------------------------------------------------------------- | |
| # 4. Main | |
| # --------------------------------------------------------------------------- | |
def main() -> None:
    """Streamlit entry point: route the session to the correct screen.

    ``?admin=1`` short-circuits to the coverage dashboard; otherwise the
    per-session study state decides which screen to render.
    """
    cfg = load_config()
    st.set_page_config(
        page_title="Product Study",
        page_icon="π",
        layout="centered",
    )
    inject_css()

    # Admin dashboard — visit ?admin=1
    try:
        params = st.query_params
    except Exception:
        # Older Streamlit versions lack st.query_params.
        params = {}
    if params.get("admin") == "1":
        ensure_datasets(cfg)
        _screen_admin(cfg)
        return

    # Initialise per-session study state exactly once.
    if "study_state" not in st.session_state:
        ensure_datasets(cfg)
        st.session_state.study_state = init_state(cfg)
    state = st.session_state.study_state

    def _item_intro() -> None:
        # The intro screen depends on study type, resolved at render time.
        if cfg["study_type"] == "preference":
            screen_pair_intro(state, cfg)
        else:
            screen_item_intro(state, cfg)

    routes = {
        "welcome": lambda: screen_welcome(state, cfg),
        "demographics": lambda: screen_demographics(state, cfg),
        "background": lambda: screen_background(state, cfg),
        "item_intro": _item_intro,
        "chat": lambda: screen_chat(state, cfg),
        "post_rating": lambda: screen_post_rating(state, cfg),
        "reflection": lambda: screen_reflection(state, cfg),
        "done": lambda: screen_done(state, cfg),
    }

    current = state.get("screen", "welcome")
    render = routes.get(current)
    if render is None:
        st.error(f"Unknown screen: {current!r}")
    else:
        render()


if __name__ == "__main__":
    main()