| """ |
| Production authentication — JWT access + refresh tokens. |
| Enable with CEPHEUS_JWT_SECRET (production) or CEPHEUS_AUTH_DEV_MODE=1 (local defaults). |
| """ |
| from __future__ import annotations |
|
|
| import json |
| import os |
| import secrets |
| import threading |
| import time |
| import uuid |
| from typing import Any, Optional |
|
|
| import jwt |
| from passlib.context import CryptContext |
|
|
| import refresh_token_store |
| import security_config |
|
|
# passlib hashing context used by _hash_password / verify_user; bcrypt is the only scheme.
_pwd = CryptContext(deprecated="auto", schemes=["bcrypt"])
|
|
# Token lifetimes in seconds, read once at import time.
# Defaults: 900 s (15 minutes) for access tokens, 604800 s (7 days) for refresh tokens.
ACCESS_TTL = int(os.environ.get("CEPHEUS_ACCESS_TOKEN_TTL", "900"))
REFRESH_TTL = int(os.environ.get("CEPHEUS_REFRESH_TOKEN_TTL", "604800"))
|
|
| |
# In-process registry of issued refresh tokens, keyed by jti; each entry holds
# "sub", "role" and "exp". Entries are mirrored into refresh_token_store.
_refresh_tokens: dict[str, dict[str, Any]] = dict()
# Serializes refresh-token rotation within this process (plain, non-reentrant lock).
_refresh_lock = threading.Lock()
|
|
|
|
def init_refresh_store() -> None:
    """Replace the in-process refresh-token map with the persisted snapshot.

    Rebinds the module-level ``_refresh_tokens`` to whatever
    ``refresh_token_store.load_all()`` returns.
    """
    global _refresh_tokens
    snapshot = refresh_token_store.load_all()
    _refresh_tokens = snapshot
|
|
|
|
def refresh_store_backend() -> str:
    """Name the refresh-token persistence backend: ``"redis"`` or ``"file"``."""
    if refresh_token_store.using_redis():
        return "redis"
    return "file"
|
|
|
|
def auth_enabled() -> bool:
    """Report whether authentication is switched on.

    True as soon as a non-blank ``CEPHEUS_JWT_SECRET`` is set; otherwise the
    answer is whatever ``security_config.is_auth_dev_mode()`` returns.
    """
    if _jwt_secret():
        return True
    return security_config.is_auth_dev_mode()
|
|
|
|
def _jwt_secret() -> str:
    """Return ``CEPHEUS_JWT_SECRET`` with surrounding whitespace removed (``""`` if unset)."""
    raw = os.environ.get("CEPHEUS_JWT_SECRET", "")
    return raw.strip()
|
|
|
|
def _signing_secret() -> str:
    """Pick the key used to sign and verify HS256 tokens.

    Prefers the production secret (``CEPHEUS_JWT_SECRET``). Only in auth dev
    mode does it fall back to ``CEPHEUS_DEV_JWT_SECRET``. Returns ``""`` when
    no key is available; callers treat that as "auth not configured".
    """
    production_secret = _jwt_secret()
    if production_secret:
        return production_secret
    if not security_config.is_auth_dev_mode():
        return ""
    return os.getenv("CEPHEUS_DEV_JWT_SECRET", "").strip()
|
|
|
|
def _parse_users_json(raw: str, env_name: str) -> Optional[list[dict[str, str]]]:
    """Parse a JSON user list taken from env var *env_name*.

    Returns the list with non-object entries dropped, or ``None`` when *raw* is
    not valid JSON or not a JSON list. Problems are logged (without echoing
    *raw*, which may contain passwords) instead of being silently ignored.
    """
    import logging

    logger = logging.getLogger(__name__)
    try:
        users = json.loads(raw)
    except json.JSONDecodeError as exc:
        logger.warning("Ignoring %s: invalid JSON (%s)", env_name, exc)
        return None
    if not isinstance(users, list):
        logger.warning(
            "Ignoring %s: expected a JSON list, got %s", env_name, type(users).__name__
        )
        return None
    records = [user for user in users if isinstance(user, dict)]
    if len(records) != len(users):
        # Non-object entries would make callers fail on user.get(...).
        logger.warning(
            "Ignoring %d non-object entries in %s", len(users) - len(records), env_name
        )
    return records


def _load_users() -> list[dict[str, str]]:
    """Return the configured auth user records.

    ``CEPHEUS_AUTH_USERS`` is used when it holds a JSON list (even an empty
    one). Otherwise, and only in auth dev mode, ``CEPHEUS_DEV_AUTH_USERS`` is
    tried the same way. Returns ``[]`` when neither yields a list.
    """
    raw = os.getenv("CEPHEUS_AUTH_USERS", "").strip()
    if raw:
        users = _parse_users_json(raw, "CEPHEUS_AUTH_USERS")
        if users is not None:
            return users
    if security_config.is_auth_dev_mode():
        raw_dev = os.getenv("CEPHEUS_DEV_AUTH_USERS", "").strip()
        if raw_dev:
            users = _parse_users_json(raw_dev, "CEPHEUS_DEV_AUTH_USERS")
            if users is not None:
                return users
    return []
|
|
|
|
def _hash_password(password: str) -> str:
    """Hash *password* with the module's bcrypt ``CryptContext`` and return the hash string."""
    digest = _pwd.hash(password)
    return digest
|
|
|
|
def _production_mode() -> bool:
    """True only when ``CEPHEUS_PRODUCTION`` is exactly ``"1"`` (surrounding whitespace ignored)."""
    flag = os.environ.get("CEPHEUS_PRODUCTION", "")
    return flag.strip() == "1"
|
|
|
|
def verify_user(username: str, password: str) -> Optional[dict[str, str]]:
    """Check a username/password pair and return the matching principal.

    Args:
        username: Login name; surrounding whitespace is stripped.
        password: Candidate password, compared as given.

    Returns:
        ``{"username": ..., "role": ...}`` on success, otherwise ``None``.
        A user record without ``role`` gets ``"staff"``.

    Raises:
        RuntimeError: in production mode, when a record matching *username*
            has a plaintext ``password`` and no ``password_hash``.
    """
    username = (username or "").strip()
    password = password or ""
    if not username or not password:
        return None

    production = _production_mode()

    # Built-in admin/admin and staff/staff logins for local development.
    # These used to be accepted unconditionally, which was a hard-coded
    # backdoor in production. They now require explicit auth dev mode and a
    # non-production environment.
    if (
        username in ("admin", "staff")
        and password == username
        and not production
        and security_config.is_auth_dev_mode()
    ):
        return {"username": username, "role": username}

    for user in _load_users():
        if user.get("username") != username:
            continue
        role = user.get("role", "staff")
        stored_hash = user.get("password_hash") or ""
        if stored_hash:
            if _pwd.verify(password, stored_hash):
                return {"username": username, "role": role}
            continue
        stored_plain = user.get("password") or ""
        if not stored_plain:
            continue
        if production:
            raise RuntimeError(f"Plaintext password for user '{username}' rejected in production")
        # Dev-only plaintext path. Compare in constant time so response timing
        # does not reveal how much of the password matched.
        if isinstance(stored_plain, str) and secrets.compare_digest(
            stored_plain.encode("utf-8"), password.encode("utf-8")
        ):
            return {"username": username, "role": role}
    return None
|
|
|
|
def validate_production_users() -> None:
    """Refuse production start-up when any configured user has only a plaintext password.

    Does nothing outside production mode.

    Raises:
        RuntimeError: naming the first user record that has a ``password``
            but no ``password_hash``.
    """
    if not _production_mode():
        return
    for record in _load_users():
        has_plaintext = bool(record.get("password"))
        has_hash = bool(record.get("password_hash"))
        if has_plaintext and not has_hash:
            offender = record.get("username", "")
            raise RuntimeError(
                f"Plaintext password for user '{offender}' rejected in production"
            )
|
|
|
|
def has_role(principal: dict, *allowed: str) -> bool:
    """Return whether *principal* may act in any of the *allowed* roles.

    A missing or empty ``role`` is treated as ``"service"``. The ``"admin"``
    role passes every check regardless of *allowed*.
    """
    effective = principal.get("role") or "service"
    return effective == "admin" or effective in allowed
|
|
|
|
def create_ws_ticket(username: str, role: str, ttl_seconds: int | None = None) -> str:
    """Mint an HS256 JWT ticket for a WebSocket handshake.

    Lets a WebSocket client authenticate without passing an api_key in the
    query string.

    Args:
        username: Stored as the ``sub`` claim.
        role: Stored as the ``role`` claim.
        ttl_seconds: Lifetime in seconds; when omitted, read from
            ``CEPHEUS_WS_TICKET_TTL`` (default 900).

    Raises:
        RuntimeError: when no signing secret is configured.
    """
    lifetime = (
        int(os.getenv("CEPHEUS_WS_TICKET_TTL", "900"))
        if ttl_seconds is None
        else ttl_seconds
    )
    key = _signing_secret()
    if not key:
        raise RuntimeError("JWT secret not configured")
    issued = int(time.time())
    claims = {
        "sub": username,
        "role": role,
        "type": "ws_ticket",
        "iat": issued,
        "exp": issued + lifetime,
        "jti": str(uuid.uuid4()),
    }
    return jwt.encode(claims, key, algorithm="HS256")
|
|
|
|
def decode_ws_ticket(token: Optional[str]) -> Optional[dict[str, Any]]:
    """Validate a WebSocket ticket and return ``{"sub", "role"}`` from it.

    Returns ``None`` when the token is missing, no signing secret is
    configured, verification fails, or the token is not of type ``"ws_ticket"``.
    """
    if not token:
        return None
    key = _signing_secret()
    if not key:
        return None
    try:
        claims = jwt.decode(token, key, algorithms=["HS256"])
    except jwt.PyJWTError:
        return None
    if claims.get("type") == "ws_ticket":
        return {"sub": claims.get("sub"), "role": claims.get("role")}
    return None
|
|
|
|
def create_token_pair(username: str, role: str) -> dict[str, Any]:
    """Issue a new HS256 access token and refresh token for *username*.

    The refresh token's ``jti`` is registered in the in-process map and in
    ``refresh_token_store`` (with ``REFRESH_TTL``), so it can later be redeemed
    by ``refresh_access_token`` or removed by ``revoke_refresh_token``.

    Returns:
        A dict with ``access_token``, ``refresh_token``, ``token_type``
        (``"bearer"``), ``expires_in`` (``ACCESS_TTL``) and ``user``
        (``{"username", "role"}``).

    Raises:
        RuntimeError: when no signing secret is configured.
    """
    key = _signing_secret()
    if not key:
        raise RuntimeError("JWT secret not configured")
    issued = int(time.time())

    def claims(kind: str, lifetime: int, token_id: str) -> dict[str, Any]:
        # Same claim set, in the same order, for both token kinds.
        return {
            "sub": username,
            "role": role,
            "type": kind,
            "iat": issued,
            "exp": issued + lifetime,
            "jti": token_id,
        }

    access_claims = claims("access", ACCESS_TTL, str(uuid.uuid4()))
    refresh_id = str(uuid.uuid4())
    refresh_claims = claims("refresh", REFRESH_TTL, refresh_id)

    # Register the refresh token before handing it out. _refresh_lock is NOT
    # taken here: refresh_access_token calls this function while already
    # holding that (non-reentrant) lock.
    record = {"sub": username, "role": role, "exp": refresh_claims["exp"]}
    _refresh_tokens[refresh_id] = record
    refresh_token_store.set_entry(refresh_id, record, REFRESH_TTL)

    access_jwt = jwt.encode(access_claims, key, algorithm="HS256")
    refresh_jwt = jwt.encode(refresh_claims, key, algorithm="HS256")
    return {
        "access_token": access_jwt,
        "refresh_token": refresh_jwt,
        "token_type": "bearer",
        "expires_in": ACCESS_TTL,
        "user": {"username": username, "role": role},
    }
|
|
|
|
def decode_access_token(token: str) -> dict[str, Any]:
    """Verify an HS256 access token and return its full claim set.

    Raises:
        ValueError: when no signing secret is configured.
        jwt.PyJWTError: any verification/decoding failure from ``jwt.decode``.
        jwt.InvalidTokenError: when the ``type`` claim is not ``"access"``.
    """
    key = _signing_secret()
    if not key:
        raise ValueError("Auth not configured")
    claims = jwt.decode(token, key, algorithms=["HS256"])
    if claims.get("type") == "access":
        return claims
    raise jwt.InvalidTokenError("Not an access token")
|
|
|
|
def refresh_access_token(refresh_token: str) -> dict[str, Any]:
    """Rotate a refresh token: consume it and issue a fresh token pair.

    The presented token must be a valid HS256 JWT of type ``"refresh"`` whose
    ``jti`` is still registered, either in memory or in ``refresh_token_store``.
    Its registration is deleted before the new pair is issued, so each refresh
    token can be redeemed once.

    Returns:
        The dict produced by ``create_token_pair`` for the stored sub/role.

    Raises:
        jwt.InvalidTokenError: when auth is not configured, the token is not a
            refresh token, has no ``jti``, or is unknown/revoked/expired.
        jwt.PyJWTError: any verification/decoding failure from ``jwt.decode``.
    """
    secret = _signing_secret()
    if not secret:
        # Never verify with an empty HMAC key: it would accept tokens forged
        # with an empty key. Raise from the same exception family jwt.decode
        # uses, so existing callers' handlers still apply.
        raise jwt.InvalidTokenError("Auth not configured")
    # Signature/claim verification is pure, so it runs outside the lock.
    payload = jwt.decode(refresh_token, secret, algorithms=["HS256"])
    if payload.get("type") != "refresh":
        raise jwt.InvalidTokenError("Not a refresh token")
    jti = payload.get("jti")
    if not jti:
        # Validate before touching the store (previously get_entry(None) could run).
        raise jwt.InvalidTokenError("Refresh token revoked or unknown")

    with _refresh_lock:
        # NOTE(review): the in-memory map is consulted first, so a revocation
        # done by another process (store-only delete) is not seen here while
        # this process still caches the entry. The lock is also process-local.
        # Confirm whether refresh_token_store should be authoritative.
        entry = _refresh_tokens.get(jti) or refresh_token_store.get_entry(jti)
        if not entry:
            raise jwt.InvalidTokenError("Refresh token revoked or unknown")
        # The entry is consumed whether it is still valid or already expired.
        _refresh_tokens.pop(jti, None)
        refresh_token_store.delete_entry(jti)
        if entry["exp"] < int(time.time()):
            raise jwt.InvalidTokenError("Refresh token expired")
        return create_token_pair(entry["sub"], entry["role"])
|
|
|
|
def revoke_refresh_token(refresh_token: str) -> None:
    """Best-effort revocation of a refresh token.

    Removes the token's ``jti`` from the in-memory map and from
    ``refresh_token_store``. Tokens that fail verification, or an unconfigured
    signing secret, are ignored silently.
    """
    secret = _signing_secret()
    if not secret:
        # Never verify with an empty HMAC key.
        return
    try:
        # Expiry is not enforced: an expired token's leftover entry can still
        # be cleaned up, and removing it grants nothing. The signature is
        # still verified.
        payload = jwt.decode(
            refresh_token,
            secret,
            algorithms=["HS256"],
            options={"verify_exp": False},
        )
    except jwt.PyJWTError:
        return
    jti = payload.get("jti")
    if not jti:
        return
    # Same lock as refresh_access_token, so a revocation cannot interleave
    # with a rotation that is in progress in this process.
    with _refresh_lock:
        _refresh_tokens.pop(jti, None)
        refresh_token_store.delete_entry(jti)
|
|
|
|
def decode_ws_token(token: Optional[str]) -> Optional[dict[str, Any]]:
    """Decode an access token presented on a WebSocket connection.

    Returns the access-token claims, or ``None`` when the token is missing or
    invalid, or when auth is not configured.
    """
    if not token:
        return None
    try:
        return decode_access_token(token)
    except (jwt.PyJWTError, ValueError):
        # decode_access_token raises ValueError when no signing secret is
        # configured. Map it to None like decode_ws_ticket does, instead of
        # letting it escape this Optional-returning helper.
        return None
|
|
|
|
def generate_bootstrap_password() -> str:
    """Return a random URL-safe password built from 16 bytes of ``secrets`` randomness."""
    entropy_bytes = 16
    return secrets.token_urlsafe(entropy_bytes)
|
|