"""
Per-token usage tracking for assignment tokens.
Tracks how many runs each assignment token has used and enforces
a maximum. A "run" = one user-triggered generation job (cognitive
interview, expert review, QAS, or silicon sampling).
Storage: JSON file, either local or backed by a HuggingFace dataset
repo for persistence across HF Space restarts.
Usage in dashboard.py — see integration instructions at bottom of file.
"""
import csv
import hashlib
import json
import logging
import os
import threading
from pathlib import Path
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Default allowance applied to every token without an explicit override.
MAX_RUNS_PER_TOKEN = 8
# Per-token overrides (token -> max runs)
_TOKEN_OVERRIDES = {
    "DE9BBFB5": 15,  # Haonan Sun
}


def get_max_runs(token: str) -> int:
    """Look up the run allowance for *token*.

    Surrounding whitespace is ignored. A token listed in
    ``_TOKEN_OVERRIDES`` gets its own limit; every other token gets
    ``MAX_RUNS_PER_TOKEN``.
    """
    key = token.strip()
    if key in _TOKEN_OVERRIDES:
        return _TOKEN_OVERRIDES[key]
    return MAX_RUNS_PER_TOKEN
# Path to the CSV that lists valid tokens (column: "token").
# Resolved relative to this module's directory.
_TOKEN_CSV_PATH = Path(__file__).parent / "student_tokens.csv"
# Local JSON file for usage data, also stored next to this module.
_LOCAL_USAGE_PATH = Path(__file__).parent / "token_usage.json"
# HuggingFace dataset repo for persistent storage (optional).
# Set these environment variables on the HF Space to enable:
#   HF_USAGE_REPO  = "Patricksturg/cogbot-token-usage"  (a private dataset repo)
#   HF_USAGE_TOKEN = a write-access HF token
# Both are read once, at import time, and default to "" when unset.
_HF_USAGE_REPO = os.environ.get("HF_USAGE_REPO", "")
_HF_USAGE_TOKEN = os.environ.get("HF_USAGE_TOKEN", "")
# Name of the usage file inside the dataset repo.
_HF_USAGE_FILENAME = "token_usage.json"
# Thread lock for safe concurrent access within one process.
# threading.Lock is not re-entrant: code that already holds it must not
# call another function that acquires it.
_lock = threading.Lock()
# ---------------------------------------------------------------------------
# Token masking (for safe logging)
# ---------------------------------------------------------------------------
def mask_token(token: str) -> str:
    """Derive a short, non-reversible label for *token* to use in log lines.

    The label is the first 8 hex characters of the token's SHA-256 digest,
    so the raw token never has to appear in the logs.
    """
    digest = hashlib.sha256(token.encode())
    hex_digest = digest.hexdigest()
    return hex_digest[:8]
# ---------------------------------------------------------------------------
# Valid token loading
# ---------------------------------------------------------------------------
# Cache of valid tokens; None means "not loaded yet".
_valid_tokens: set[str] | None = None


def _load_valid_tokens() -> set[str]:
    """Load the set of valid assignment tokens from the CSV.

    The CSV at ``_TOKEN_CSV_PATH`` is read once; the result (possibly an
    empty set when the file is missing) is cached in ``_valid_tokens`` and
    returned on every later call.

    Returns:
        The set of non-empty, whitespace-stripped values of the "token"
        column.
    """
    global _valid_tokens
    if _valid_tokens is not None:
        return _valid_tokens
    tokens: set[str] = set()
    path = _TOKEN_CSV_PATH
    if path.exists():
        # utf-8-sig reads plain UTF-8 too, but also drops a leading BOM
        # (as written by spreadsheet exports). With a BOM left in place the
        # first header becomes "\ufefftoken" and no token would be found.
        with open(path, newline="", encoding="utf-8-sig") as f:
            for row in csv.DictReader(f):
                # DictReader fills fields missing from a short row with
                # None, so a plain .get("token", "") can still yield None.
                t = (row.get("token") or "").strip()
                if t:
                    tokens.add(t)
        logger.info(f"Loaded {len(tokens)} valid assignment tokens")
    else:
        logger.warning(f"Token CSV not found: {path}")
    _valid_tokens = tokens
    return _valid_tokens
def is_valid_token(token: str) -> bool:
    """Report whether *token*, ignoring surrounding whitespace, is a known token."""
    candidate = token.strip()
    return candidate in _load_valid_tokens()
# ---------------------------------------------------------------------------
# Usage persistence — local JSON
# ---------------------------------------------------------------------------
def _load_local() -> dict:
    """Read usage data from the local JSON file.

    Returns:
        The stored mapping, or an empty dict when the file is missing,
        unreadable, not valid JSON, or does not contain a JSON object.
    """
    if not _LOCAL_USAGE_PATH.exists():
        return {}
    try:
        with open(_LOCAL_USAGE_PATH) as f:
            data = json.load(f)
    # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
    except (ValueError, OSError):
        logger.warning("Corrupt local usage file — starting fresh")
        return {}
    # Valid JSON that is not an object (e.g. [] or null) would break every
    # caller that does data.get(...) / data.setdefault(...).
    if not isinstance(data, dict):
        logger.warning("Corrupt local usage file — starting fresh")
        return {}
    return data
def _save_local(data: dict):
    """Write usage data to the local JSON file atomically.

    The data is first written to a sibling ``.tmp`` file and then moved
    over the real file with ``os.replace``. A crash or error part-way
    through the write therefore cannot leave a truncated usage file, which
    the loader would treat as corrupt and answer with an empty mapping.

    Raises:
        OSError: if the temp file cannot be written or moved into place.
        TypeError: if *data* is not JSON-serialisable.
    """
    # Same directory as the target so os.replace stays on one filesystem.
    tmp_path = _LOCAL_USAGE_PATH.with_name(_LOCAL_USAGE_PATH.name + ".tmp")
    with open(tmp_path, "w") as f:
        json.dump(data, f, indent=2)
    os.replace(tmp_path, _LOCAL_USAGE_PATH)
# ---------------------------------------------------------------------------
# Usage persistence — HuggingFace dataset repo
# ---------------------------------------------------------------------------
def _hf_enabled() -> bool:
    """Tell whether HF persistence is configured (repo and token both non-empty)."""
    if not _HF_USAGE_REPO:
        return False
    return bool(_HF_USAGE_TOKEN)
def _load_hf() -> dict:
    """Download token_usage.json from the HF dataset repo.

    Returns:
        The usage mapping stored in the repo. If the download or parsing
        fails, or the file does not hold a JSON object, the local copy
        (``_load_local()``) is returned instead.

    Falling back to the local copy rather than ``{}`` matters because the
    callers load, modify and save: starting from an empty dict after a
    transient failure would overwrite the stored usage of every token.
    """
    try:
        # Imported lazily so the module works without huggingface_hub
        # when HF persistence is not configured.
        from huggingface_hub import hf_hub_download

        path = hf_hub_download(
            repo_id=_HF_USAGE_REPO,
            filename=_HF_USAGE_FILENAME,
            repo_type="dataset",
            token=_HF_USAGE_TOKEN,
        )
        with open(path) as f:
            data = json.load(f)
    except Exception as e:
        # Deliberately broad: any failure here is best-effort and logged.
        logger.warning(f"Could not load usage from HF repo: {e}")
        return _load_local()
    if not isinstance(data, dict):
        logger.warning(
            "Could not load usage from HF repo: file is not a JSON object"
        )
        return _load_local()
    return data
def _save_hf(data: dict):
    """Upload token_usage.json to the HF dataset repo.

    Best-effort: any failure is logged and swallowed so that a problem
    with the remote repo does not abort the caller.

    The JSON is serialised in memory and passed to ``upload_file`` as
    bytes, so no temporary file is created and none can be left behind
    when the upload raises.
    """
    try:
        # Imported lazily so the module works without huggingface_hub
        # when HF persistence is not configured.
        from huggingface_hub import HfApi

        api = HfApi(token=_HF_USAGE_TOKEN)
        payload = json.dumps(data, indent=2).encode("utf-8")
        api.upload_file(
            path_or_fileobj=payload,
            path_in_repo=_HF_USAGE_FILENAME,
            repo_id=_HF_USAGE_REPO,
            repo_type="dataset",
            commit_message="Update token usage",
        )
    except Exception as e:
        logger.error(f"Could not save usage to HF repo: {e}")
# ---------------------------------------------------------------------------
# Unified load / save (prefers HF if configured, falls back to local)
# ---------------------------------------------------------------------------
def load_token_usage() -> dict:
    """Fetch the usage mapping, shaped like {"TOKEN": {"runs_used": 3}, ...}.

    Reads from the HF dataset repo when HF persistence is enabled and from
    the local JSON file otherwise.
    """
    loader = _load_hf if _hf_enabled() else _load_local
    return loader()
def save_token_usage(data: dict):
    """Store the usage mapping.

    The local JSON file is always written; the HF dataset repo is updated
    as well when HF persistence is enabled.
    """
    _save_local(data)
    if not _hf_enabled():
        return
    _save_hf(data)
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def get_runs_used(token: str) -> int:
    """Return how many runs this token has consumed.

    Surrounding whitespace is stripped first, as get_max_runs() and
    can_use_token() do, so a padded token reads the same usage entry.

    Returns:
        The stored ``runs_used`` value, or 0 when the token has no entry.
    """
    token = token.strip()
    with _lock:
        data = load_token_usage()
    return data.get(token, {}).get("runs_used", 0)
def get_remaining_runs(token: str) -> int:
    """Compute the token's allowance minus the runs it has already used.

    The result is not clamped: it is negative when the recorded usage
    exceeds the allowance.
    """
    allowance = get_max_runs(token)
    consumed = get_runs_used(token)
    return allowance - consumed
def can_use_token(token: str) -> tuple[bool, str]:
    """
    Check that a token is known and still has runs left.

    The token is stripped of surrounding whitespace before any check.

    Returns (ok, message):
        (True, "N runs remaining") — token is valid and has quota
        (False, reason) — token is invalid or exhausted
    """
    token = token.strip()
    if is_valid_token(token):
        remaining = get_remaining_runs(token)
        if remaining > 0:
            return True, f"{remaining} runs remaining"
        exhausted_message = (
            "You have reached the maximum number of runs allowed for this "
            "assignment token. If you encountered a technical problem, "
            "contact Patrick Sturgis."
        )
        return False, exhausted_message
    return False, "Invalid assignment token."
def increment_token_usage(token: str):
    """Increment the run count for a token. Persists immediately.

    The token is stripped of surrounding whitespace so the count is stored
    under the same key that can_use_token() checks. Without this, a token
    entered with stray whitespace would be counted under a different key
    and its quota would never be reached.
    """
    token = token.strip()
    with _lock:
        data = load_token_usage()
        entry = data.setdefault(token, {"runs_used": 0})
        # .get() tolerates an existing entry that lacks "runs_used".
        entry["runs_used"] = entry.get("runs_used", 0) + 1
        save_token_usage(data)
        logger.info(
            f"Token {mask_token(token)}: runs_used -> {entry['runs_used']}"
        )
def rollback_token_usage(token: str):
    """Decrement the run count (on early failure). Persists immediately.

    The token is stripped of surrounding whitespace so the rollback hits
    the same key that can_use_token() checks. Nothing is changed or saved
    when the token has no entry or its count is already 0.
    """
    token = token.strip()
    with _lock:
        data = load_token_usage()
        entry = data.get(token)
        if entry and entry.get("runs_used", 0) > 0:
            entry["runs_used"] -= 1
            save_token_usage(data)
            logger.info(
                f"Token {mask_token(token)}: rollback -> {entry['runs_used']}"
            )
def reset_token(token: str):
    """Reset a token's usage to zero. For admin use.

    The token is stripped of surrounding whitespace so the reset applies
    to the same key that can_use_token() checks.
    """
    token = token.strip()
    with _lock:
        data = load_token_usage()
        data[token] = {"runs_used": 0}
        save_token_usage(data)
def reset_all_tokens():
    """Wipe every token's usage by persisting an empty mapping. For admin use."""
    cleared: dict = {}
    with _lock:
        save_token_usage(cleared)
# ---------------------------------------------------------------------------
# Integration instructions for dashboard.py (HF version)
# ---------------------------------------------------------------------------
#
# Below is the code to integrate into the HF Space dashboard.py.
# The exact line numbers will differ — adapt to the live file.
#
# ── 1. IMPORT (top of file) ──────────────────────────────────────────────
#
#     from token_usage import (
#         can_use_token, increment_token_usage, rollback_token_usage,
#         get_remaining_runs, is_valid_token, MAX_RUNS_PER_TOKEN,
#     )
#
# ── 2. SESSION STATE (near other session_state inits) ────────────────────
#
#     if 'active_token' not in st.session_state:
#         st.session_state.active_token = None
#
# ── 3. UI: show remaining runs (after token input widget) ────────────────
#
#     # Assuming assignment_token is the text_input value:
#     if assignment_token:
#         ok, msg = can_use_token(assignment_token)
#         if ok:
#             st.session_state.active_token = assignment_token
#             st.sidebar.info(f"✅ {msg}")
#         else:
#             st.session_state.active_token = None
#             st.sidebar.error(msg)
#
# ── 4. GATE: before the run starts (after button click, before sampler) ──
#
#     # Inside the `if st.button(...)` block, before any API calls:
#     if st.session_state.active_token:
#         ok, msg = can_use_token(st.session_state.active_token)
#         if not ok:
#             st.error(msg)
#             st.session_state.processing = False
#             st.stop()
#         increment_token_usage(st.session_state.active_token)
#
# ── 5. ROLLBACK: in the except block for system errors ───────────────────
#
#     except Exception as e:
#         if st.session_state.get('active_token'):
#             rollback_token_usage(st.session_state.active_token)
#         st.error(f"Error: {e}")
#         st.session_state.processing = False
#
# ── 6. DOUBLE-CLICK GUARD ────────────────────────────────────────────────
#
# The existing `st.session_state.processing` flag already prevents this:
# the button is disabled while processing=True. No extra work needed.
#