File size: 10,768 Bytes
bb8ee77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
"""
Per-token usage tracking for assignment tokens.

Tracks how many runs each assignment token has used and enforces
a maximum. A "run" = one user-triggered generation job (cognitive
interview, expert review, QAS, or silicon sampling).

Storage: JSON file, either local or backed by a HuggingFace dataset
repo for persistence across HF Space restarts.

Usage in dashboard.py — see integration instructions at bottom of file.
"""

import csv
import hashlib
import json
import logging
import os
import threading
from pathlib import Path

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Default quota applied to every token without an explicit override.
MAX_RUNS_PER_TOKEN = 8

# Tokens granted a non-default quota (token -> max runs)
_TOKEN_OVERRIDES = {
    "DE9BBFB5": 15,  # Haonan Sun
}


def get_max_runs(token: str) -> int:
    """Return the run quota for ``token``.

    Surrounding whitespace is stripped before the lookup. A token listed in
    ``_TOKEN_OVERRIDES`` gets its override value; any other token gets
    ``MAX_RUNS_PER_TOKEN``.
    """
    key = token.strip()
    if key in _TOKEN_OVERRIDES:
        return _TOKEN_OVERRIDES[key]
    return MAX_RUNS_PER_TOKEN

# Directory containing this module; both data files below live next to it.
_MODULE_DIR = Path(__file__).parent

# CSV that lists the valid tokens (column: "token")
_TOKEN_CSV_PATH = _MODULE_DIR / "student_tokens.csv"

# Local JSON file holding the usage counters
_LOCAL_USAGE_PATH = _MODULE_DIR / "token_usage.json"

# Optional HuggingFace dataset repo used as persistent storage.
# Enable it by setting these environment variables on the HF Space:
#   HF_USAGE_REPO  = "Patricksturg/cogbot-token-usage"  (a private dataset repo)
#   HF_USAGE_TOKEN = a write-access HF token
# Both default to the empty string when unset.
_HF_USAGE_REPO = os.getenv("HF_USAGE_REPO", "")
_HF_USAGE_TOKEN = os.getenv("HF_USAGE_TOKEN", "")
_HF_USAGE_FILENAME = "token_usage.json"

# Serialises read-modify-write cycles between threads of this one process
_lock = threading.Lock()


# ---------------------------------------------------------------------------
# Token masking (for safe logging)
# ---------------------------------------------------------------------------

def mask_token(token: str) -> str:
    """Return the first 8 hex characters of the token's SHA-256 digest.

    Intended for log lines, so that the raw token never has to be written
    out.
    """
    digest = hashlib.sha256(token.encode())
    return digest.hexdigest()[:8]


# ---------------------------------------------------------------------------
# Valid token loading
# ---------------------------------------------------------------------------

# Cache of valid tokens; None until the CSV has been read (or found missing).
_valid_tokens: set[str] | None = None


def _load_valid_tokens() -> set[str]:
    """Load the set of valid assignment tokens from the CSV.

    The result is cached in the module-level ``_valid_tokens`` after the
    first call, so the CSV is read at most once per process. If the CSV does
    not exist, a warning is logged and an empty set is cached.

    Returns:
        The set of non-empty, whitespace-stripped values found in the
        ``token`` column.
    """
    global _valid_tokens
    if _valid_tokens is not None:
        return _valid_tokens

    tokens: set[str] = set()
    path = _TOKEN_CSV_PATH
    if path.exists():
        # "utf-8-sig" drops a leading byte-order mark if present; otherwise a
        # BOM-prefixed file would yield the header "\ufefftoken" and no row
        # would ever match the "token" column.
        with open(path, newline="", encoding="utf-8-sig") as f:
            for row in csv.DictReader(f):
                # DictReader fills fields missing from a short row with None,
                # so the value must be guarded before calling .strip().
                t = (row.get("token") or "").strip()
                if t:
                    tokens.add(t)
        logger.info(f"Loaded {len(tokens)} valid assignment tokens")
    else:
        logger.warning(f"Token CSV not found: {path}")

    _valid_tokens = tokens
    return _valid_tokens


def is_valid_token(token: str) -> bool:
    """Return True if ``token``, stripped of whitespace, is a known token."""
    candidate = token.strip()
    known = _load_valid_tokens()
    return candidate in known


# ---------------------------------------------------------------------------
# Usage persistence — local JSON
# ---------------------------------------------------------------------------

def _load_local() -> dict:
    """Read usage data from the local JSON file.

    Returns:
        The parsed mapping, or an empty dict when the file is missing,
        unreadable, not valid JSON, or does not hold a JSON object at the
        top level. All failure cases except "missing" log a warning.
    """
    if not _LOCAL_USAGE_PATH.exists():
        return {}
    try:
        with open(_LOCAL_USAGE_PATH) as f:
            data = json.load(f)
        # Callers use dict methods on the result, so a list or scalar at the
        # top level is treated as corruption.
        if not isinstance(data, dict):
            raise ValueError("top-level JSON value is not an object")
    except (ValueError, OSError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.warning("Corrupt local usage file — starting fresh")
        return {}
    return data


def _save_local(data: dict):
    """Write usage data to the local JSON file atomically.

    The data is first written to a sibling ``.tmp`` file and then moved over
    the real file with ``os.replace``. An interrupted write therefore never
    leaves a truncated ``token_usage.json`` behind, which the loader would
    treat as corrupt and discard.

    Raises:
        OSError: if the temporary file cannot be written or moved.
        TypeError: if ``data`` is not JSON-serialisable.
    """
    tmp_path = _LOCAL_USAGE_PATH.with_name(_LOCAL_USAGE_PATH.name + ".tmp")
    try:
        with open(tmp_path, "w") as f:
            json.dump(data, f, indent=2)
        os.replace(tmp_path, _LOCAL_USAGE_PATH)
    finally:
        # After a successful replace the temp file no longer exists; this
        # only removes a leftover from a failed write.
        tmp_path.unlink(missing_ok=True)


# ---------------------------------------------------------------------------
# Usage persistence — HuggingFace dataset repo
# ---------------------------------------------------------------------------

def _hf_enabled() -> bool:
    """Return True only when both the HF repo id and the HF token are set."""
    if not _HF_USAGE_REPO:
        return False
    return bool(_HF_USAGE_TOKEN)


def _load_hf() -> dict:
    """Download token_usage.json from the HF dataset repo.

    Returns:
        The parsed usage mapping. If the download or parsing fails for any
        reason, a warning is logged and the local usage file is returned
        instead (which is itself ``{}`` when missing or corrupt).

    Falling back to the local copy matters because callers write back
    whatever they loaded: returning ``{}`` on a transient failure would make
    the next save erase every persisted counter.
    """
    try:
        from huggingface_hub import hf_hub_download
        path = hf_hub_download(
            repo_id=_HF_USAGE_REPO,
            filename=_HF_USAGE_FILENAME,
            repo_type="dataset",
            token=_HF_USAGE_TOKEN,
        )
        with open(path) as f:
            data = json.load(f)
        # Callers use dict methods on the result; reject anything else.
        if not isinstance(data, dict):
            raise ValueError("usage file does not contain a JSON object")
        return data
    except Exception as e:
        # Deliberately broad: missing package, network and auth errors are
        # all handled as "remote copy unavailable".
        logger.warning(f"Could not load usage from HF repo: {e}")
        return _load_local()


def _save_hf(data: dict):
    """Upload token_usage.json to the HF dataset repo.

    Best-effort: any failure is logged as an error and not re-raised.

    The JSON is uploaded directly from memory. No temporary file is created,
    so nothing is left behind on disk when the upload fails.
    """
    try:
        from huggingface_hub import HfApi

        api = HfApi(token=_HF_USAGE_TOKEN)
        payload = json.dumps(data, indent=2).encode("utf-8")

        api.upload_file(
            path_or_fileobj=payload,  # upload_file accepts raw bytes
            path_in_repo=_HF_USAGE_FILENAME,
            repo_id=_HF_USAGE_REPO,
            repo_type="dataset",
            commit_message="Update token usage",
        )
    except Exception as e:
        logger.error(f"Could not save usage to HF repo: {e}")


# ---------------------------------------------------------------------------
# Unified load / save (prefers HF if configured, falls back to local)
# ---------------------------------------------------------------------------

def load_token_usage() -> dict:
    """Return the usage mapping, e.g. {"TOKEN": {"runs_used": 3}, ...}.

    Reads from the HF dataset repo when it is configured, otherwise from the
    local JSON file.
    """
    loader = _load_hf if _hf_enabled() else _load_local
    return loader()


def save_token_usage(data: dict):
    """Persist usage data locally and, when configured, to the HF repo."""
    # The local copy is written unconditionally.
    _save_local(data)
    if not _hf_enabled():
        return
    # HF storage is configured: push the same data there too.
    _save_hf(data)


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

def get_runs_used(token: str) -> int:
    """Return how many runs this token has consumed.

    Surrounding whitespace is stripped first so the lookup uses the same key
    as ``get_max_runs`` and ``can_use_token``. A token with no recorded
    usage yields 0.
    """
    token = token.strip()
    with _lock:
        data = load_token_usage()
        return data.get(token, {}).get("runs_used", 0)


def get_remaining_runs(token: str) -> int:
    """Return the token's quota minus the runs it has already used.

    The result is not clamped: it is negative when recorded usage exceeds
    the quota.
    """
    quota = get_max_runs(token)
    consumed = get_runs_used(token)
    return quota - consumed


def can_use_token(token: str) -> tuple[bool, str]:
    """
    Check that a token is valid and still has quota left.

    The token is stripped of surrounding whitespace before any check.

    Returns (ok, message):
        (True, "N runs remaining")  - token is valid and has quota
        (False, reason)             - token is invalid or exhausted
    """
    cleaned = token.strip()
    if is_valid_token(cleaned):
        runs_left = get_remaining_runs(cleaned)
        if runs_left > 0:
            return True, f"{runs_left} runs remaining"
        return False, (
            "You have reached the maximum number of runs allowed for this "
            "assignment token. If you encountered a technical problem, "
            "contact Patrick Sturgis."
        )
    return False, "Invalid assignment token."


def increment_token_usage(token: str):
    """Increment the run count for a token. Persists immediately.

    The token is stripped of surrounding whitespace so the count is stored
    under the same key that ``can_use_token`` checks. Without this, a token
    entered with stray whitespace would be counted under a separate key and
    its quota would never run out.

    The load-modify-save cycle runs under the module lock.
    """
    token = token.strip()
    with _lock:
        data = load_token_usage()
        entry = data.setdefault(token, {"runs_used": 0})
        # Tolerate an existing entry that lacks the counter key.
        entry["runs_used"] = entry.get("runs_used", 0) + 1
        save_token_usage(data)
        logger.info(
            f"Token {mask_token(token)}: runs_used -> {entry['runs_used']}"
        )


def rollback_token_usage(token: str):
    """Decrement the run count (on early failure). Persists immediately.

    The token is stripped of surrounding whitespace so the rollback targets
    the same key that ``can_use_token`` reads. Nothing is changed or saved
    when the token has no entry or its count is already zero.
    """
    token = token.strip()
    with _lock:
        data = load_token_usage()
        entry = data.get(token)
        if entry and entry.get("runs_used", 0) > 0:
            entry["runs_used"] -= 1
            save_token_usage(data)
            logger.info(
                f"Token {mask_token(token)}: rollback -> {entry['runs_used']}"
            )


def reset_token(token: str):
    """Reset a token's usage to zero. For admin use.

    The token is stripped of surrounding whitespace so the reset applies to
    the same key that ``can_use_token`` reads. The token's entry is replaced
    wholesale with ``{"runs_used": 0}`` and persisted immediately.
    """
    token = token.strip()
    with _lock:
        data = load_token_usage()
        data[token] = {"runs_used": 0}
        save_token_usage(data)
        save_token_usage(data)


def reset_all_tokens():
    """Wipe all recorded usage by persisting an empty mapping. For admin use."""
    cleared: dict = {}
    with _lock:
        save_token_usage(cleared)


# ---------------------------------------------------------------------------
# Integration instructions for dashboard.py (HF version)
# ---------------------------------------------------------------------------
#
# Below is the code to integrate into the HF Space dashboard.py.
# The exact line numbers will differ — adapt to the live file.
#
# ── 1. IMPORT (top of file) ──────────────────────────────────────────────
#
#   from token_usage import (
#       can_use_token, increment_token_usage, rollback_token_usage,
#       get_remaining_runs, is_valid_token, MAX_RUNS_PER_TOKEN,
#   )
#
# ── 2. SESSION STATE (near other session_state inits) ────────────────────
#
#   if 'active_token' not in st.session_state:
#       st.session_state.active_token = None
#
# ── 3. UI: show remaining runs (after token input widget) ────────────────
#
#   # Assuming assignment_token is the text_input value:
#   if assignment_token:
#       ok, msg = can_use_token(assignment_token)
#       if ok:
#           st.session_state.active_token = assignment_token
#           st.sidebar.info(f"✅ {msg}")
#       else:
#           st.session_state.active_token = None
#           st.sidebar.error(msg)
#
# ── 4. GATE: before the run starts (after button click, before sampler) ──
#
#   # Inside the `if st.button(...)` block, before any API calls:
#   if st.session_state.active_token:
#       ok, msg = can_use_token(st.session_state.active_token)
#       if not ok:
#           st.error(msg)
#           st.session_state.processing = False
#           st.stop()
#       increment_token_usage(st.session_state.active_token)
#
# ── 5. ROLLBACK: in the except block for system errors ───────────────────
#
#   except Exception as e:
#       if st.session_state.get('active_token'):
#           rollback_token_usage(st.session_state.active_token)
#       st.error(f"Error: {e}")
#       st.session_state.processing = False
#
# ── 6. DOUBLE-CLICK GUARD ────────────────────────────────────────────────
#
#   The existing `st.session_state.processing` flag already prevents this:
#   the button is disabled while processing=True. No extra work needed.
#