github-actions
Deploy to Hugging Face
c794b6b
Raw
History Blame Contribute Delete
7.87 kB
"""
Production authentication — JWT access + refresh tokens.
Enable with CEPHEUS_JWT_SECRET (production) or CEPHEUS_AUTH_DEV_MODE=1 (local defaults).
"""
from __future__ import annotations
import json
import os
import secrets
import threading
import time
import uuid
from typing import Any, Optional
import jwt
from passlib.context import CryptContext
import refresh_token_store
import security_config
# Password hashing/verification context; only the bcrypt scheme is configured.
_pwd = CryptContext(schemes=["bcrypt"], deprecated="auto")

# Token lifetimes, added to an ``int(time.time())`` timestamp to form the
# ``exp`` claim. Overridable via environment; defaults are 900 and 604800.
ACCESS_TTL = int(os.getenv("CEPHEUS_ACCESS_TOKEN_TTL", "900"))
REFRESH_TTL = int(os.getenv("CEPHEUS_REFRESH_TOKEN_TTL", "604800"))

# refresh token jti -> {sub, role, exp} (memory cache; persisted via refresh_token_store)
_refresh_tokens: dict[str, dict[str, Any]] = {}

# Held by refresh_access_token() while it validates and rotates a refresh token.
_refresh_lock = threading.Lock()
def init_refresh_store() -> None:
    """Replace the in-memory refresh-token map with the persisted entries.

    Rebinds the module-level ``_refresh_tokens`` to whatever
    ``refresh_token_store.load_all()`` returns.
    """
    global _refresh_tokens
    persisted = refresh_token_store.load_all()
    _refresh_tokens = persisted
def refresh_store_backend() -> str:
    """Name the refresh-token persistence backend: ``"redis"`` or ``"file"``."""
    if refresh_token_store.using_redis():
        return "redis"
    return "file"
def auth_enabled() -> bool:
    """Report whether authentication is active.

    Active when a production JWT secret is configured; otherwise the answer
    is whatever ``security_config.is_auth_dev_mode()`` reports.
    """
    if _jwt_secret():
        return True
    return security_config.is_auth_dev_mode()
def _jwt_secret() -> str:
return os.getenv("CEPHEUS_JWT_SECRET", "").strip()
def _signing_secret() -> str:
    """Pick the key used to sign and verify tokens.

    The production secret wins when set. Otherwise, in auth dev mode, the
    stripped ``CEPHEUS_DEV_JWT_SECRET`` value is used. With neither
    available the result is the empty string.
    """
    primary = _jwt_secret()
    if not primary and security_config.is_auth_dev_mode():
        return os.getenv("CEPHEUS_DEV_JWT_SECRET", "").strip()
    # Either the configured secret, or "" when nothing is configured.
    return primary
def _load_users() -> list[dict[str, str]]:
raw = os.getenv("CEPHEUS_AUTH_USERS", "").strip()
if raw:
try:
users = json.loads(raw)
if isinstance(users, list):
return users
except json.JSONDecodeError:
pass
if security_config.is_auth_dev_mode():
raw_dev = os.getenv("CEPHEUS_DEV_AUTH_USERS", "").strip()
if raw_dev:
try:
users = json.loads(raw_dev)
if isinstance(users, list):
return users
except json.JSONDecodeError:
pass
return []
def _hash_password(password: str) -> str:
    """Hash *password* with the module's ``_pwd`` password context."""
    digest = _pwd.hash(password)
    return digest
def _production_mode() -> bool:
return os.getenv("CEPHEUS_PRODUCTION", "").strip() == "1"
def verify_user(username: str, password: str) -> Optional[dict[str, str]]:
    """Check a username/password pair against the configured users.

    Args:
        username: Login name; surrounding whitespace is ignored.
        password: Candidate password.

    Returns:
        ``{"username": ..., "role": ...}`` on success, ``None`` when the
        credentials are missing or do not match any configured user.

    Raises:
        RuntimeError: In production mode, when the matching user record
            carries only a plaintext ``password`` (no ``password_hash``).
    """
    username = (username or "").strip()
    password = password or ""
    if not username or not password:
        return None

    # The built-in admin/admin and staff/staff logins are local conveniences.
    # They must never authenticate outside auth dev mode or in production,
    # otherwise they act as a permanent backdoor.
    if (
        username == password
        and username in ("admin", "staff")
        and security_config.is_auth_dev_mode()
        and not _production_mode()
    ):
        return {"username": username, "role": username}

    for user in _load_users():
        if user.get("username") != username:
            continue
        stored_hash = user.get("password_hash") or ""
        stored_plain = user.get("password") or ""
        if stored_hash:
            if _pwd.verify(password, stored_hash):
                return {"username": username, "role": user.get("role", "staff")}
        elif stored_plain:
            if _production_mode():
                raise RuntimeError(
                    f"Plaintext password for user '{username}' rejected in production"
                )
            # Constant-time comparison on bytes: avoids leaking the match
            # position through timing and works for non-ASCII passwords.
            if isinstance(stored_plain, str) and secrets.compare_digest(
                stored_plain.encode("utf-8"), password.encode("utf-8")
            ):
                return {"username": username, "role": user.get("role", "staff")}
    return None
def validate_production_users() -> None:
    """Fail fast when production auth users rely on plaintext passwords.

    Does nothing outside production mode. In production, raises
    ``RuntimeError`` for the first user record that has a ``password`` but
    no ``password_hash``.
    """
    if _production_mode():
        for record in _load_users():
            plaintext_only = record.get("password") and not record.get("password_hash")
            if plaintext_only:
                name = record.get("username", "")
                raise RuntimeError(
                    f"Plaintext password for user '{name}' rejected in production"
                )
def has_role(principal: dict, *allowed: str) -> bool:
    """Decide whether *principal* holds one of the *allowed* roles.

    A missing or empty ``role`` is treated as ``"service"``. The ``"admin"``
    role is always accepted, whatever *allowed* contains.
    """
    effective = principal.get("role") or "service"
    return effective == "admin" or effective in allowed
def create_ws_ticket(username: str, role: str, ttl_seconds: int | None = None) -> str:
    """Issue a short-lived JWT for the WebSocket handshake.

    The ticket avoids passing an api_key in the query string. When
    *ttl_seconds* is ``None`` the lifetime comes from
    ``CEPHEUS_WS_TICKET_TTL`` (default ``900``).

    Raises:
        RuntimeError: If no signing secret is configured.
    """
    lifetime = (
        int(os.getenv("CEPHEUS_WS_TICKET_TTL", "900"))
        if ttl_seconds is None
        else ttl_seconds
    )
    signing_key = _signing_secret()
    if not signing_key:
        raise RuntimeError("JWT secret not configured")
    issued_at = int(time.time())
    claims = dict(
        sub=username,
        role=role,
        type="ws_ticket",
        iat=issued_at,
        exp=issued_at + lifetime,
        jti=str(uuid.uuid4()),
    )
    return jwt.encode(claims, signing_key, algorithm="HS256")
def decode_ws_ticket(token: Optional[str]) -> Optional[dict[str, Any]]:
    """Validate a WebSocket ticket and return its ``sub`` and ``role``.

    Returns ``None`` for a missing token, a missing signing secret, any
    PyJWT decoding failure, or a token whose ``type`` is not ``"ws_ticket"``.
    """
    if not token:
        return None
    signing_key = _signing_secret()
    if not signing_key:
        return None
    try:
        claims = jwt.decode(token, signing_key, algorithms=["HS256"])
    except jwt.PyJWTError:
        return None
    if claims.get("type") == "ws_ticket":
        return {"sub": claims.get("sub"), "role": claims.get("role")}
    return None
def create_token_pair(username: str, role: str) -> dict[str, Any]:
    """Mint an access token and a refresh token for *username*.

    The refresh token's ``jti`` is recorded both in the in-memory
    ``_refresh_tokens`` map and through ``refresh_token_store.set_entry``
    before the tokens are encoded.

    Returns:
        A dict with ``access_token``, ``refresh_token``, ``token_type``,
        ``expires_in`` (the access TTL) and a ``user`` summary.

    Raises:
        RuntimeError: If no signing secret is configured.
    """
    signing_key = _signing_secret()
    if not signing_key:
        raise RuntimeError("JWT secret not configured")

    issued_at = int(time.time())
    identity = {"sub": username, "role": role}

    access_claims = {
        **identity,
        "type": "access",
        "iat": issued_at,
        "exp": issued_at + ACCESS_TTL,
        "jti": str(uuid.uuid4()),
    }

    refresh_jti = str(uuid.uuid4())
    refresh_expiry = issued_at + REFRESH_TTL
    refresh_claims = {
        **identity,
        "type": "refresh",
        "iat": issued_at,
        "exp": refresh_expiry,
        "jti": refresh_jti,
    }

    # Register the refresh token so it can later be looked up and rotated.
    stored = {**identity, "exp": refresh_expiry}
    _refresh_tokens[refresh_jti] = stored
    refresh_token_store.set_entry(refresh_jti, stored, REFRESH_TTL)

    access_token = jwt.encode(access_claims, signing_key, algorithm="HS256")
    refresh_token = jwt.encode(refresh_claims, signing_key, algorithm="HS256")
    return {
        "access_token": access_token,
        "refresh_token": refresh_token,
        "token_type": "bearer",
        "expires_in": ACCESS_TTL,
        "user": {"username": username, "role": role},
    }
def decode_access_token(token: str) -> dict[str, Any]:
    """Decode *token* and return its claims if it is an access token.

    Raises:
        ValueError: If no signing secret is configured.
        jwt.InvalidTokenError: If the token's ``type`` is not ``"access"``.
            Decoding failures from ``jwt.decode`` propagate unchanged.
    """
    signing_key = _signing_secret()
    if not signing_key:
        raise ValueError("Auth not configured")
    claims = jwt.decode(token, signing_key, algorithms=["HS256"])
    if claims.get("type") == "access":
        return claims
    raise jwt.InvalidTokenError("Not an access token")
def refresh_access_token(refresh_token: str) -> dict[str, Any]:
    """Exchange a refresh token for a new access/refresh pair.

    The presented refresh token is single-use: once its ``jti`` is found it
    is removed from both the in-memory map and the persistent store, whether
    or not it turns out to be expired.

    Args:
        refresh_token: Encoded refresh JWT.

    Returns:
        The dict produced by ``create_token_pair`` for the stored subject
        and role.

    Raises:
        jwt.InvalidTokenError: If auth is not configured, the token is not a
            refresh token, its ``jti`` is missing/unknown/revoked, or the
            stored entry has expired. Other PyJWT decoding errors propagate
            from ``jwt.decode``.
    """
    with _refresh_lock:
        secret = _signing_secret()
        if not secret:
            # Never verify signatures against an empty key.
            raise jwt.InvalidTokenError("Auth not configured")
        payload = jwt.decode(refresh_token, secret, algorithms=["HS256"])
        if payload.get("type") != "refresh":
            raise jwt.InvalidTokenError("Not a refresh token")

        # Validate the jti before touching either store so that a token
        # without one never triggers a lookup keyed on None.
        jti = payload.get("jti")
        if not jti:
            raise jwt.InvalidTokenError("Refresh token revoked or unknown")
        entry = _refresh_tokens.get(jti) or refresh_token_store.get_entry(jti)
        if not entry:
            raise jwt.InvalidTokenError("Refresh token revoked or unknown")

        # The token is consumed either way: expired entries are purged and
        # valid ones are rotated.
        _refresh_tokens.pop(jti, None)
        refresh_token_store.delete_entry(jti)
        if entry["exp"] < int(time.time()):
            raise jwt.InvalidTokenError("Refresh token expired")
        return create_token_pair(entry["sub"], entry["role"])
def revoke_refresh_token(refresh_token: str) -> None:
    """Forget the given refresh token, if it decodes successfully.

    Tokens that fail PyJWT decoding are ignored silently. Otherwise the
    token's ``jti`` is removed from the in-memory map and the persistent
    store.
    """
    signing_key = _signing_secret()
    try:
        claims = jwt.decode(refresh_token, signing_key, algorithms=["HS256"])
    except jwt.PyJWTError:
        # Best-effort: an undecodable token has nothing to revoke.
        return
    token_id = claims.get("jti")
    if not token_id:
        return
    _refresh_tokens.pop(token_id, None)
    refresh_token_store.delete_entry(token_id)
def decode_ws_token(token: Optional[str]) -> Optional[dict[str, Any]]:
    """Decode an access token presented on a WebSocket connection.

    Args:
        token: Encoded access JWT, or ``None``/empty when absent.

    Returns:
        The access-token claims, or ``None`` when the token is missing,
        fails validation, or auth is not configured.
    """
    if not token:
        return None
    try:
        return decode_access_token(token)
    except (jwt.PyJWTError, ValueError):
        # decode_access_token raises ValueError when no signing secret is
        # configured; treat that as "not authenticated" rather than letting
        # it escape a function whose contract is an Optional result.
        return None
def generate_bootstrap_password() -> str:
    """Generate a random URL-safe password from 16 bytes of entropy."""
    entropy_bytes = 16
    return secrets.token_urlsafe(entropy_bytes)