solution_challenge_backend / backend /refresh_token_store.py
github-actions
Deploy to Hugging Face
c794b6b
Raw
History Blame Contribute Delete
2.62 kB
"""
Refresh token persistence — file-backed by default; optional Redis for multi-instance Cloud Run.
Set REDIS_URL to share the session store across replicas.
"""
from __future__ import annotations
import json
import logging
import os
import threading
import time
from typing import Any, Callable, Optional
import persistence as persist
logger = logging.getLogger(__name__)

# Key namespace for per-token entries when Redis is configured.
_KEY_PREFIX: str = "cepheus:refresh:"

# Directory of the file-backed store, located beside this module.
_DATA_DIR: str = os.path.join(os.path.dirname(__file__), "data")
# JSON file holding all file-backed entries, keyed by jti.
_REFRESH_FILE: str = os.path.join(_DATA_DIR, "refresh_tokens.json")

# Serialises load/modify/save cycles on the JSON file among threads of this process.
_file_rmw_lock = threading.Lock()
def _mutate_file_store(mutator: Callable[[dict[str, dict[str, Any]]], None]) -> None:
    """Run *mutator* on the file-backed token map as one locked load/modify/save.

    Holding ``_file_rmw_lock`` across the whole cycle prevents two threads in
    this process from interleaving their updates and losing one of them.
    The map is obtained through ``load_all``, so entries it filters out
    (expired ones) are not written back by ``save_all``.

    Args:
        mutator: Callback that edits the token dict in place; its return value
            is ignored.
    """
    with _file_rmw_lock:
        snapshot = load_all()
        mutator(snapshot)
        save_all(snapshot)
# Per-URL cache of the Redis client, or of None when building it failed, so
# repeated calls reuse one client/connection pool and warn only once.
_redis_clients: dict[str, Any] = {}


def _redis():
    """Return the Redis client for ``REDIS_URL``, or ``None`` for file mode.

    Returns ``None`` when ``REDIS_URL`` is unset or blank, or when the ``redis``
    package cannot be imported or the client cannot be built from the URL.
    In that case a warning is logged the first time for that URL.

    The outcome is cached per URL. Previously every call ran
    ``redis.from_url`` and so created a fresh client (and connection pool) on
    each get/set/delete, and re-logged the warning on every call when Redis
    was unusable. Keying the cache by URL means a changed ``REDIS_URL``
    is still honoured.
    """
    url = os.getenv("REDIS_URL", "").strip()
    if not url:
        return None
    if url in _redis_clients:
        return _redis_clients[url]
    client = None
    try:
        import redis  # type: ignore

        client = redis.from_url(url, decode_responses=True)
    except Exception as exc:
        logger.warning("REDIS_URL set but redis unavailable: %s", exc)
    # A benign race may build the client twice on first use; last write wins.
    _redis_clients[url] = client
    return client
def _live_tokens(data: Any, now: int) -> dict[str, dict[str, Any]]:
    """Return the entries of *data* whose numeric ``exp`` is later than *now*.

    Non-dict *data* yields ``{}``. Malformed records (non-dict values, or a
    missing/non-numeric ``exp``) are skipped rather than raising, so a single
    bad record cannot break every load.
    """
    if not isinstance(data, dict):
        return {}
    live: dict[str, dict[str, Any]] = {}
    for jti, entry in data.items():
        if not isinstance(entry, dict):
            continue
        exp = entry.get("exp", 0)
        if isinstance(exp, (int, float)) and exp > now:
            live[jti] = entry
    return live


def load_all() -> dict[str, dict[str, Any]]:
    """Return every unexpired entry from the file-backed store.

    Always returns ``{}`` when a Redis client is configured. In that mode,
    entries are only accessed per key through get/set/delete_entry.
    Otherwise this ensures the data directory exists and returns ``{}`` when
    the JSON file does not exist yet. It then returns the file's entries
    whose ``exp`` (epoch seconds) is still in the future.
    """
    if _redis():
        return {}
    os.makedirs(_DATA_DIR, exist_ok=True)
    if not os.path.exists(_REFRESH_FILE):
        return {}
    data = persist.load_json(_REFRESH_FILE, {})
    return _live_tokens(data, int(time.time()))
def save_all(tokens: dict[str, dict[str, Any]]) -> None:
    """Overwrite the JSON store file with *tokens*.

    Does nothing when a Redis client is configured. In that mode entries are
    written one at a time by ``set_entry`` with their own TTL.
    """
    if not _redis():
        persist.save_json(_REFRESH_FILE, tokens)
def set_entry(jti: str, entry: dict[str, Any], ttl_seconds: int) -> None:
    """Store *entry* under token id *jti*, replacing any previous value.

    With Redis, the entry is written as JSON under ``_KEY_PREFIX + jti`` with
    a TTL of *ttl_seconds*. In file mode, *ttl_seconds* is not used. The entry
    is upserted under the read-modify-write lock, and expiry is governed by
    the entry's own ``exp`` field when the file is loaded.
    """
    client = _redis()
    if not client:
        _mutate_file_store(lambda tokens: tokens.update({jti: entry}))
        return
    client.setex(f"{_KEY_PREFIX}{jti}", ttl_seconds, json.dumps(entry))
def get_entry(jti: str) -> Optional[dict[str, Any]]:
    """Fetch the stored entry for token id *jti*.

    Returns ``None`` when no entry is found. That covers unknown ids, expired
    ids (filtered by ``load_all`` in file mode, or by TTL in Redis), and a
    Redis value that is not valid JSON.
    """
    client = _redis()
    if not client:
        return load_all().get(jti)
    raw = client.get(f"{_KEY_PREFIX}{jti}")
    if not raw:
        return None
    try:
        decoded = json.loads(raw)
    except json.JSONDecodeError:
        decoded = None
    return decoded
def delete_entry(jti: str) -> None:
    """Remove the entry for token id *jti*; an unknown *jti* is not an error.

    With Redis, deletes the ``_KEY_PREFIX + jti`` key. In file mode, drops the
    id from the JSON map under the read-modify-write lock.
    """
    client = _redis()
    if not client:
        _mutate_file_store(lambda tokens: tokens.pop(jti, None))
        return
    client.delete(f"{_KEY_PREFIX}{jti}")
def using_redis() -> bool:
    """Report whether ``_redis()`` currently yields a client (i.e. Redis mode)."""
    client = _redis()
    return client is not None