ehejin's picture
0505 np3 prolfiic
d1c998c
"""
Entry point for the user study Streamlit app.
Run from repo root:
streamlit run src/app.py
streamlit run src/app.py -- --debug (sets DEBUG_MODE via env)
HuggingFace Space secrets required:
HF_TOKEN — read/write HuggingFace token
GH_TOKEN — GitHub PAT (ehejin account) for the private lsp submodule
TINKER_API_KEY — Tinker inference API key
DEBUG_MODE — "true" to skip all validation (optional)
"""
import os
import sys
import json
import subprocess
from pathlib import Path
# ---------------------------------------------------------------------------
# 1. Initialise lsp git submodule before any lsp imports
# On a cold HF Space start the submodule directory exists but is empty;
# GH_TOKEN lets us authenticate against the private GitHub repo.
# ---------------------------------------------------------------------------
# Directory two levels above this file (the parent of the directory holding it).
_BASE = Path(__file__).resolve().parent.parent
# Where the lsp sources are expected to live, directly under _BASE.
_LSP_PATH = _BASE / "lsp"
def _download_lsp_snapshot(token: str, ref: str, dest: Path) -> None:
    """Download the lsp repository at *ref* from GitHub and install it at *dest*.

    Performs a single attempt. All scratch files live in a private temporary
    directory that is removed on exit, whether or not the attempt succeeds.

    Args:
        token: GitHub token sent as a Bearer credential.
        ref: Branch, tag or commit SHA to fetch.
        dest: Directory that will hold the extracted repository contents.

    Raises:
        RuntimeError: If the tarball has no top-level directory or does not
            contain ``src/prompts``.
        Exception: Any network, archive or filesystem error is propagated to
            the caller, which decides whether to retry.
    """
    import shutil
    import tarfile
    import tempfile
    import urllib.request

    # GitHub serves a tarball of any branch/tag/SHA at this URL.
    tarball_url = f"https://api.github.com/repos/batu-el/lsp/tarball/{ref}"
    req = urllib.request.Request(
        tarball_url,
        headers={
            "Authorization": f"Bearer {token}",
            "Accept": "application/vnd.github+json",
            "User-Agent": "prolific-preferences",
        },
    )

    with tempfile.TemporaryDirectory(prefix="lsp_") as scratch:
        tmp_tar = Path(scratch) / "lsp.tar.gz"
        tmp_extract = Path(scratch) / "extract"

        with urllib.request.urlopen(req, timeout=60) as resp:
            tmp_tar.write_bytes(resp.read())
        print(f"[SUBMODULE] downloaded {tmp_tar.stat().st_size} bytes")

        tmp_extract.mkdir(parents=True)
        with tarfile.open(str(tmp_tar)) as tar:
            # Use the "data" extraction filter where the running Python
            # provides it, so members cannot escape tmp_extract.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(str(tmp_extract), filter="data")
            else:
                tar.extractall(str(tmp_extract))

        # GitHub tarballs have a top-level dir like batu-el-lsp-abc123/
        subdirs = [d for d in tmp_extract.iterdir() if d.is_dir()]
        if not subdirs:
            raise RuntimeError("tarball had no top-level directory")
        top = subdirs[0]

        # Verify the prompts dir is present
        if not (top / "src" / "prompts").exists():
            raise RuntimeError(f"src/prompts not found in extracted tarball at {top}")

        # Copy into a sibling staging directory first and rename it into place,
        # so an interrupted copy never leaves a half-populated *dest* that a
        # later start-up would mistake for a complete checkout, and so a retry
        # does not trip over leftovers from the previous attempt.
        staging = dest.with_name(dest.name + ".partial")
        shutil.rmtree(str(staging), ignore_errors=True)
        shutil.copytree(str(top), str(staging))
        shutil.rmtree(str(dest), ignore_errors=True)
        os.replace(str(staging), str(dest))


def _init_submodule() -> None:
    """Make the lsp sources importable, downloading them first if missing.

    If ``_LSP_PATH / "src" / "prompts"`` does not exist, the lsp repository is
    fetched from GitHub as a tarball (authenticated with the ``GH_TOKEN``
    environment variable), with up to three attempts. Afterwards
    ``_LSP_PATH / "src"`` and ``_BASE`` are put at the front of ``sys.path``
    unless they are already present.

    Raises:
        RuntimeError: If ``GH_TOKEN`` is unset/empty, or if every download
            attempt fails.
    """
    if not (_LSP_PATH / "src" / "prompts").exists():
        token = os.getenv("GH_TOKEN", "")
        if not token:
            raise RuntimeError("GH_TOKEN secret is not set.")

        import shutil
        import time as _time

        # Clean any stale state
        if _LSP_PATH.exists():
            shutil.rmtree(str(_LSP_PATH), ignore_errors=True)
        git_modules = _BASE / ".git" / "modules" / "lsp"
        if git_modules.exists():
            shutil.rmtree(str(git_modules), ignore_errors=True)

        # Pinned to a specific commit SHA so future lsp changes don't break us.
        commit_sha = "74582acd911f81309ba8b22cef9286c2887dda18"
        max_attempts = 3
        last_error = ""

        for attempt in range(1, max_attempts + 1):
            print(f"[SUBMODULE] tarball download attempt {attempt}/{max_attempts} ...")
            try:
                _download_lsp_snapshot(token, commit_sha, _LSP_PATH)
            except Exception as e:
                # Never let the token leak into logs via an error message.
                last_error = str(e).replace(token, "***")
                print(f"[SUBMODULE] attempt {attempt} failed: {last_error}")
                # Back off only when another attempt will follow.
                if attempt < max_attempts:
                    _time.sleep(3)
            else:
                print("[SUBMODULE] ready.")
                break
        else:
            raise RuntimeError(
                f"Failed to download lsp tarball after {max_attempts} attempts. "
                f"Last error: {last_error}"
            )

    lsp_src = str(_LSP_PATH / "src")
    if lsp_src not in sys.path:
        sys.path.insert(0, lsp_src)
    if str(_BASE) not in sys.path:
        sys.path.insert(0, str(_BASE))
_init_submodule()

# One-time cleanup of stale local state. A marker file inside the data
# directory records that the wipe already happened, so it runs only on the
# first load and is skipped on every later import / Streamlit rerun.
# Completions stay durable in HF; we re-scan HF fresh after the wipe.
_data_root = _BASE / "data"
_data_root.mkdir(parents=True, exist_ok=True)
_wipe_marker = _data_root / ".startup_wiped"
if not _wipe_marker.exists():
    _stale_globs = (
        "reservations.json",
        "local_completions_*.json",
        "completion_cache_*.json",
    )
    for _glob in _stale_globs:
        for _stale in _data_root.glob(_glob):
            # Best effort: a file that cannot be removed is reported, not fatal.
            try:
                _stale.unlink()
                print(f"[STARTUP] Wiped stale file: {_stale.name}")
            except Exception as _err:
                print(f"[STARTUP] Could not wipe {_stale.name}: {_err}")
    _wipe_marker.touch()
    print("[STARTUP] Marked container as wiped")
# ---------------------------------------------------------------------------
# 2. App imports (only after submodule is initialised)
# ---------------------------------------------------------------------------
import streamlit as st
from src.config import load_config
from src.data import ensure_datasets, init_state
from src.ui.components import inject_css
from src.ui.screens_shared import (
screen_background,
screen_chat,
screen_demographics,
screen_done,
screen_post_rating,
screen_reflection,
screen_welcome,
)
from src.ui.screens_likelihood import screen_item_intro
from src.ui.screens_preference import screen_pair_intro
# ---------------------------------------------------------------------------
# 3. Admin dashboard — visit ?admin=1
# ---------------------------------------------------------------------------
def _screen_admin(cfg: dict) -> None:
    """Render the coverage dashboard (shown when the page is opened with ?admin=1).

    For every entry in ``cfg["categories"]`` the dashboard shows the pool size,
    how many items have at least one accepted completion, how many are
    currently reserved but not yet covered, and how many still need a
    participant. Before rendering, expired and returned reservations are
    released and the reservation store is saved back, all under the
    reservation file lock.

    Args:
        cfg: Study configuration. Keys read here: ``study_type``,
            ``pair_selection_seed``, ``output_dataset_repo`` and
            ``categories`` (each entry providing ``name``).
    """
    from src.data import (
        _get_accepted_counts, _load_pool, _pool_path,
        _load_reservations, _save_reservations,
        _expire_reservations, _release_returned_reservations,
        _reservation_lock_path,
    )
    from filelock import FileLock

    st.markdown("## 📊 Study Coverage Dashboard")
    st.caption(
        f"Study type: `{cfg['study_type']}` · "
        f"Seed: `{cfg['pair_selection_seed']}` · "
        f"Output repo: `{cfg['output_dataset_repo']}`"
    )

    if st.button("🔄 Refresh", type="primary"):
        # Invalidate caches so we re-scan HF and re-poll Prolific
        from src.data import _data_dir
        data_dir = _data_dir(cfg)
        # missing_ok: another session may delete the same cache file between
        # the glob and the unlink; that must not crash the dashboard.
        for cache_file in data_dir.glob("completion_cache*"):
            cache_file.unlink(missing_ok=True)
        (data_dir / "prolific_returned_cache.json").unlink(missing_ok=True)
        st.rerun()

    # Release expired + returned/timed-out reservations before displaying
    lock = FileLock(str(_reservation_lock_path(cfg)), timeout=10)
    with lock:
        reservations = _load_reservations(cfg)
        _expire_reservations(reservations)
        _release_returned_reservations(reservations, cfg)
        _save_reservations(reservations, cfg)

    for cat_cfg in cfg["categories"]:
        cat = cat_cfg["name"]
        pool = _load_pool(str(_pool_path(cat, cfg)))
        total = len(pool)
        counts = _get_accepted_counts(cat, cfg)
        covered = sum(1 for v in counts.values() if v >= 1)
        # NOTE(review): this iterates every reservation key, not only those of
        # the current category. If reservation keys are not scoped per
        # category, items reserved in other categories are counted here too
        # and "Still needed" can go negative — confirm against src.data.
        reserved_uncovered = sum(
            1 for k in reservations
            if counts.get(k, 0) == 0
        )
        truly_uncovered = total - covered - reserved_uncovered

        st.markdown(f"### {cat.capitalize()}")
        col1, col2, col3, col4 = st.columns(4)
        col1.metric("Total items", total)
        col2.metric("Covered ✅", covered)
        col3.metric("In progress 🔄", reserved_uncovered,
                    help="Reserved by active Prolific participants")
        col4.metric("Still needed ⚠️", truly_uncovered,
                    delta=f"-{truly_uncovered}" if truly_uncovered > 0 else None,
                    delta_color="inverse")

        if truly_uncovered == 0 and reserved_uncovered == 0:
            st.success(f"✅ All {total} items covered!")
        elif truly_uncovered == 0:
            st.info(f"🔄 {reserved_uncovered} item(s) in progress.")
        else:
            st.warning(
                f"⚠️ {truly_uncovered} item(s) still need a participant. "
                f"Send more Prolific slots."
            )
        st.markdown("---")
# ---------------------------------------------------------------------------
# 4. Main
# ---------------------------------------------------------------------------
def main() -> None:
    """Configure the page and render the screen for the current study state.

    Opening the app with ``?admin=1`` shows the coverage dashboard instead of
    the participant flow. Otherwise the per-session ``study_state`` is created
    on first use and the screen named by its ``"screen"`` entry (default
    ``"welcome"``) is rendered; an unrecognised name is reported with
    ``st.error``.
    """
    cfg = load_config()
    st.set_page_config(
        page_title="Product Study",
        page_icon="🛒",
        layout="centered",
    )
    inject_css()

    # Admin dashboard — visit ?admin=1
    try:
        params = st.query_params
    except Exception:
        # Fall back to "no parameters" if query params cannot be read.
        params = {}
    if params.get("admin") == "1":
        ensure_datasets(cfg)
        _screen_admin(cfg)
        return

    # First run of this session: make sure datasets exist, then build state.
    if "study_state" not in st.session_state:
        ensure_datasets(cfg)
        st.session_state.study_state = init_state(cfg)
    s = st.session_state.study_state

    screen = s.get("screen", "welcome")
    dispatch = {
        "welcome": lambda: screen_welcome(s, cfg),
        "demographics": lambda: screen_demographics(s, cfg),
        "background": lambda: screen_background(s, cfg),
        # The intro screen differs between the preference and other study types.
        "item_intro": lambda: (
            screen_pair_intro(s, cfg)
            if cfg["study_type"] == "preference"
            else screen_item_intro(s, cfg)
        ),
        "chat": lambda: screen_chat(s, cfg),
        "post_rating": lambda: screen_post_rating(s, cfg),
        "reflection": lambda: screen_reflection(s, cfg),
        "done": lambda: screen_done(s, cfg),
    }
    handler = dispatch.get(screen)
    if handler:
        handler()
    else:
        st.error(f"Unknown screen: {screen!r}")


if __name__ == "__main__":
    main()