File size: 10,216 Bytes
6b23da9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
06e3420
6b23da9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47a4db5
350ec72
47a4db5
89f6cd6
 
 
350ec72
89f6cd6
47a4db5
350ec72
 
 
 
 
0f4326e
 
d1c998c
89f6cd6
 
 
350ec72
 
89f6cd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350ec72
89f6cd6
 
 
 
350ec72
89f6cd6
d35bda7
6b23da9
 
 
d35bda7
 
2984dc5
6b23da9
 
 
89f6cd6
 
0300fef
 
89f6cd6
 
 
 
 
 
 
 
 
 
 
 
b610725
6b23da9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2984dc5
 
 
b610725
ce33f0d
 
0f4326e
 
 
ce33f0d
0f4326e
2984dc5
 
 
 
 
 
 
 
 
0f4326e
ce33f0d
 
 
0f4326e
 
 
2984dc5
 
0f4326e
 
 
 
 
 
 
 
2984dc5
ce33f0d
 
06e3420
2984dc5
0f4326e
b610725
ce33f0d
2984dc5
ce33f0d
 
2984dc5
b610725
2984dc5
 
 
 
 
 
0f4326e
2984dc5
 
 
 
 
 
 
b610725
2984dc5
 
 
b610725
2984dc5
 
 
 
 
 
 
6b23da9
 
 
 
 
 
 
 
 
 
 
2984dc5
 
 
 
 
 
 
 
 
 
6b23da9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
"""
Entry point for the user study Streamlit app.

Run from repo root:
    streamlit run src/app.py
    streamlit run src/app.py -- --debug   (sets DEBUG_MODE via env)

HuggingFace Space secrets required:
    HF_TOKEN        — read/write HuggingFace token
    GH_TOKEN        — GitHub PAT (ehejin account) for the private lsp submodule
    TINKER_API_KEY  — Tinker inference API key
    DEBUG_MODE      — "true" to skip all validation (optional)
"""
import os
import sys
import json
import subprocess
from pathlib import Path

# ---------------------------------------------------------------------------
# 1. Initialise lsp git submodule before any lsp imports
#    On a cold HF Space start the submodule directory exists but is empty;
#    GH_TOKEN lets us authenticate against the private GitHub repo.
# ---------------------------------------------------------------------------
_BASE     = Path(__file__).resolve().parent.parent
_LSP_PATH = _BASE / "lsp"


def _init_submodule() -> None:
    """Ensure the private ``lsp`` code is present and importable.

    If ``<repo>/lsp/src/prompts`` does not exist (e.g. on a cold HF Space
    start), any stale ``<repo>/lsp`` directory and ``.git/modules/lsp``
    metadata are removed. A tarball of a pinned ``batu-el/lsp`` commit is then
    downloaded from the GitHub API using ``GH_TOKEN``, extracted under
    ``/tmp``, and copied to ``<repo>/lsp``. Up to three download attempts are
    made.

    In every case, ``<repo>/lsp/src`` and the repo root are then prepended to
    ``sys.path`` if they are not already on it.

    Raises:
        RuntimeError: if a download is needed but ``GH_TOKEN`` is unset, or if
            every download attempt fails.
    """
    prompts_exist = (_LSP_PATH / "src" / "prompts").exists()
    if not prompts_exist:
        token = os.getenv("GH_TOKEN", "")
        if not token:
            raise RuntimeError("GH_TOKEN secret is not set.")

        import shutil
        import tarfile
        import urllib.request
        import time as _time

        # Clean any stale state
        if _LSP_PATH.exists():
            shutil.rmtree(str(_LSP_PATH), ignore_errors=True)
        git_modules = _BASE / ".git" / "modules" / "lsp"
        if git_modules.exists():
            shutil.rmtree(str(git_modules), ignore_errors=True)

        # GitHub serves a tarball of any branch/tag/SHA at this URL.
        # Pinned to a specific commit SHA so future lsp changes don't break us.
        branch       = "74582acd911f81309ba8b22cef9286c2887dda18"
        tarball_url  = f"https://api.github.com/repos/batu-el/lsp/tarball/{branch}"
        tmp_tar      = Path("/tmp/lsp.tar.gz")
        tmp_extract  = Path("/tmp/lsp_extract")
        max_attempts = 3
        last_msg     = ""

        for attempt in range(1, max_attempts + 1):
            print(f"[SUBMODULE] tarball download attempt {attempt}/{max_attempts} ...")
            try:
                req = urllib.request.Request(
                    tarball_url,
                    headers={
                        "Authorization": f"Bearer {token}",
                        "Accept":        "application/vnd.github+json",
                        "User-Agent":    "prolific-preferences",
                    },
                )
                with urllib.request.urlopen(req, timeout=60) as resp:
                    tmp_tar.write_bytes(resp.read())
                print(f"[SUBMODULE] downloaded {tmp_tar.stat().st_size} bytes")

                # Extract into a fresh scratch directory
                if tmp_extract.exists():
                    shutil.rmtree(str(tmp_extract), ignore_errors=True)
                tmp_extract.mkdir(parents=True)
                with tarfile.open(str(tmp_tar)) as tar:
                    if hasattr(tarfile, "data_filter"):
                        # Extraction filters reject absolute paths, ".."
                        # traversal and special files; newer Pythons warn
                        # (and 3.14 changes the default) when none is given.
                        tar.extractall(str(tmp_extract), filter="data")
                    else:
                        tar.extractall(str(tmp_extract))

                # GitHub tarballs have a top-level dir like batu-el-lsp-abc123/
                subdirs = [d for d in tmp_extract.iterdir() if d.is_dir()]
                if not subdirs:
                    raise RuntimeError("tarball had no top-level directory")
                top = subdirs[0]

                # Verify the prompts dir is present
                if not (top / "src" / "prompts").exists():
                    raise RuntimeError(f"src/prompts not found in extracted tarball at {top}")

                # copytree refuses an existing destination, so drop any partial
                # copy a previous failed attempt may have left behind.
                if _LSP_PATH.exists():
                    shutil.rmtree(str(_LSP_PATH), ignore_errors=True)
                shutil.copytree(str(top), str(_LSP_PATH))

                print("[SUBMODULE] ready.")
                break
            except Exception as e:
                # Never echo the token into logs or the final error message.
                last_msg = str(e).replace(token, "***")
                print(f"[SUBMODULE] attempt {attempt} failed: {last_msg}")
                if attempt < max_attempts:
                    _time.sleep(3)
            finally:
                # Remove scratch files whether the attempt succeeded or not.
                tmp_tar.unlink(missing_ok=True)
                shutil.rmtree(str(tmp_extract), ignore_errors=True)
        else:
            raise RuntimeError(
                f"Failed to download lsp tarball after {max_attempts} attempts. "
                f"Last error: {last_msg}"
            )

    lsp_src = str(_LSP_PATH / "src")
    if lsp_src not in sys.path:
        sys.path.insert(0, lsp_src)
    if str(_BASE) not in sys.path:
        sys.path.insert(0, str(_BASE))


_init_submodule()

# Clear stale local state once per container. This module-level code can run
# again on later Streamlit reruns, so a marker file records that the wipe has
# already happened and subsequent executions skip it. Completions stay durable
# in HF and are re-scanned fresh after the wipe.
_data_root = _BASE / "data"
_data_root.mkdir(parents=True, exist_ok=True)
_wipe_marker = _data_root / ".startup_wiped"
if not _wipe_marker.exists():
    _stale_patterns = ("reservations.json", "local_completions_*.json", "completion_cache_*.json")
    _stale_files = (hit for pat in _stale_patterns for hit in _data_root.glob(pat))
    for _stale in _stale_files:
        try:
            _stale.unlink()
            print(f"[STARTUP] Wiped stale file: {_stale.name}")
        except Exception as err:
            # Best effort: report and keep going with the remaining files.
            print(f"[STARTUP] Could not wipe {_stale.name}: {err}")
    _wipe_marker.touch()
    print("[STARTUP] Marked container as wiped")

# ---------------------------------------------------------------------------
# 2. App imports (only after submodule is initialised)
# ---------------------------------------------------------------------------
import streamlit as st

from src.config import load_config
from src.data import ensure_datasets, init_state
from src.ui.components import inject_css
from src.ui.screens_shared import (
    screen_background,
    screen_chat,
    screen_demographics,
    screen_done,
    screen_post_rating,
    screen_reflection,
    screen_welcome,
)
from src.ui.screens_likelihood import screen_item_intro
from src.ui.screens_preference import screen_pair_intro


# ---------------------------------------------------------------------------
# 3. Admin dashboard — visit ?admin=1
# ---------------------------------------------------------------------------
def _screen_admin(cfg: dict) -> None:
    """Render the study coverage dashboard (visit ``?admin=1`` to see it).

    Before rendering, expired and returned/timed-out reservations are released
    and saved while holding the reservation file lock. Then, for every entry in
    ``cfg["categories"]``, the dashboard shows:

    * the pool size;
    * how many items have at least one accepted completion;
    * how many uncovered items are reserved;
    * how many items still need a participant.

    The "Refresh" button deletes the local ``completion_cache*`` files and
    ``prolific_returned_cache.json``, so the next run re-scans HF and
    re-polls Prolific. It then reruns the script.
    """
    from src.data import (
        _get_accepted_counts, _load_pool, _pool_path,
        _load_reservations, _save_reservations,
        _expire_reservations, _release_returned_reservations,
        _reservation_lock_path,
    )
    from filelock import FileLock

    st.markdown("## 📊 Study Coverage Dashboard")
    st.caption(
        f"Study type: `{cfg['study_type']}` · "
        f"Seed: `{cfg['pair_selection_seed']}` · "
        f"Output repo: `{cfg['output_dataset_repo']}`"
    )

    if st.button("🔄 Refresh", type="primary"):
        # Invalidate caches so we re-scan HF and re-poll Prolific.
        # missing_ok: a file may already be gone (e.g. removed by a concurrent
        # session between the glob/lookup and the unlink).
        from src.data import _data_dir
        data_dir = _data_dir(cfg)
        for cache_file in data_dir.glob("completion_cache*"):
            cache_file.unlink(missing_ok=True)
        (data_dir / "prolific_returned_cache.json").unlink(missing_ok=True)
        st.rerun()

    # Release expired + returned/timed-out reservations before displaying
    lock = FileLock(str(_reservation_lock_path(cfg)), timeout=10)
    with lock:
        reservations = _load_reservations(cfg)
        _expire_reservations(reservations)
        _release_returned_reservations(reservations, cfg)
        _save_reservations(reservations, cfg)

    for cat_cfg in cfg["categories"]:
        cat   = cat_cfg["name"]
        pool  = _load_pool(str(_pool_path(cat, cfg)))
        total = len(pool)

        counts = _get_accepted_counts(cat, cfg)

        covered = sum(1 for v in counts.values() if v >= 1)
        # NOTE(review): `reservations` is loaded once for all categories and is
        # not filtered by `cat` here. If it holds keys from other categories,
        # they are counted as reserved-uncovered for this one too. Verify this
        # against the reservation schema.
        reserved_uncovered = sum(
            1 for k in reservations
            if counts.get(k, 0) == 0
        )
        # Clamp at 0 so an over-count above never shows a negative "still needed".
        truly_uncovered = max(0, total - covered - reserved_uncovered)

        st.markdown(f"### {cat.capitalize()}")
        col1, col2, col3, col4 = st.columns(4)
        col1.metric("Total items",     total)
        col2.metric("Covered ✅",      covered)
        col3.metric("In progress 🔄",  reserved_uncovered,
                    help="Reserved by active Prolific participants")
        col4.metric("Still needed ⚠️", truly_uncovered,
                    delta=f"-{truly_uncovered}" if truly_uncovered > 0 else None,
                    delta_color="inverse")

        if truly_uncovered == 0 and reserved_uncovered == 0:
            st.success(f"✅ All {total} items covered!")
        elif truly_uncovered == 0:
            st.info(f"🔄 {reserved_uncovered} item(s) in progress.")
        else:
            st.warning(
                f"⚠️ {truly_uncovered} item(s) still need a participant. "
                f"Send more Prolific slots."
            )

        st.markdown("---")


# ---------------------------------------------------------------------------
# 4. Main
# ---------------------------------------------------------------------------
def main() -> None:
    """Render one run of the study app.

    The config is loaded, then the page config and CSS are applied. With the
    query parameter ``admin=1``, the admin coverage dashboard is shown. Otherwise
    the per-session study state is created on the first run (after
    ``ensure_datasets``), and control is dispatched to the screen named by
    ``study_state["screen"]``, which defaults to ``"welcome"``.
    """
    cfg = load_config()

    st.set_page_config(
        page_title="Product Study",
        page_icon="🛒",
        layout="centered",
    )
    inject_css()

    # Admin dashboard — visit ?admin=1
    try:
        params = st.query_params
    except Exception:
        # Fall back to "no params" if st.query_params is unavailable.
        params = {}
    if params.get("admin") == "1":
        ensure_datasets(cfg)
        _screen_admin(cfg)
        return

    if "study_state" not in st.session_state:
        ensure_datasets(cfg)
        st.session_state.study_state = init_state(cfg)

    s      = st.session_state.study_state
    screen = s.get("screen", "welcome")

    dispatch = {
        "welcome":      lambda: screen_welcome(s, cfg),
        "demographics": lambda: screen_demographics(s, cfg),
        "background":   lambda: screen_background(s, cfg),
        # The intro screen depends on the study flavour in the config.
        "item_intro":   lambda: (
            screen_pair_intro(s, cfg)
            if cfg["study_type"] == "preference"
            else screen_item_intro(s, cfg)
        ),
        "chat":         lambda: screen_chat(s, cfg),
        "post_rating":  lambda: screen_post_rating(s, cfg),
        "reflection":   lambda: screen_reflection(s, cfg),
        "done":         lambda: screen_done(s, cfg),
    }

    handler = dispatch.get(screen)
    if handler:
        handler()
    else:
        st.error(f"Unknown screen: {screen!r}")


# Render the app when this file is run as the main script (e.g. via
# `streamlit run src/app.py`, per the module docstring); a plain import does
# not call main().
if __name__ == "__main__":
    main()