from __future__ import annotations
import os
from dataclasses import dataclass
from typing import Any
from fastapi import Depends, HTTPException, Request, status
from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer
from sqlalchemy.orm import Session
from starlette.middleware.sessions import SessionMiddleware
from data import crud
from data.db import get_db
# Auth mode identifiers recognised in the AUTH_MODE environment variable.
AUTH_MODE_DEV = "dev"
AUTH_MODE_HF = "hf_oauth"
# Salt given to the itsdangerous serializer that signs auth bridge tokens.
AUTH_BRIDGE_SALT = "streamlit-auth-bridge"
# Fallback used when APP_SESSION_SECRET is unset; configure_session_middleware
# refuses to start with this value in hf_oauth mode.
DEFAULT_DEV_SESSION_SECRET = "dev-only-session-secret-change-me"
@dataclass(frozen=True)
class CurrentUser:
    """Immutable description of the user attached to the current request.

    Instances are built from the session cookie, from a decoded auth bridge
    token, or from the auto-provisioned dev user. Being frozen, an instance
    cannot be modified after construction.
    """

    # Numeric user identifier.
    id: int
    # Email address of the user.
    email: str
    # Optional human-readable name; None when not provided.
    display_name: str | None = None
    # Optional avatar image URL; None when not provided.
    avatar_url: str | None = None
class AuthBridgeTokenError(ValueError):
    """Raised when an auth bridge token is expired, badly signed, or malformed."""
def get_auth_mode() -> str:
    """Return the auth mode selected by the ``AUTH_MODE`` environment variable.

    The value is trimmed and lower-cased before comparison.

    Returns:
        ``AUTH_MODE_DEV`` or ``AUTH_MODE_HF``. Unset or unrecognised values
        resolve to ``AUTH_MODE_DEV``.
    """
    normalized = os.getenv("AUTH_MODE", AUTH_MODE_DEV).strip().lower()
    if normalized in (AUTH_MODE_DEV, AUTH_MODE_HF):
        return normalized
    # NOTE(review): a typo in AUTH_MODE silently selects dev mode; consider
    # failing loudly if that is not intended.
    return AUTH_MODE_DEV
def configure_session_middleware(app) -> None:
    """Attach Starlette session middleware once during app setup.

    Configuration is read from the environment:

    * ``APP_SESSION_SECRET`` - signing key (whitespace-trimmed); falls back to
      ``DEFAULT_DEV_SESSION_SECRET`` when unset.
    * ``SESSION_COOKIE_SAMESITE`` - ``lax``, ``strict`` or ``none``; anything
      else is treated as ``lax``.
    * ``SESSION_COOKIE_SECURE`` - truthy values are ``1``, ``true``, ``yes``
      and ``on``; defaults to enabled in ``hf_oauth`` mode and disabled
      otherwise.

    Args:
        app: Application object exposing ``add_middleware``.

    Raises:
        RuntimeError: In ``hf_oauth`` mode when the secret is empty or still
            the built-in development default.
    """
    secret = os.getenv("APP_SESSION_SECRET", DEFAULT_DEV_SESSION_SECRET).strip()
    auth_mode = get_auth_mode()
    if auth_mode == AUTH_MODE_HF and (not secret or secret == DEFAULT_DEV_SESSION_SECRET):
        raise RuntimeError("APP_SESSION_SECRET must be set to a non-default value in hf_oauth mode.")

    same_site = os.getenv("SESSION_COOKIE_SAMESITE", "lax").strip().lower()
    if same_site not in {"lax", "strict", "none"}:
        same_site = "lax"

    secure_default = "1" if auth_mode == AUTH_MODE_HF else "0"
    secure_flag = os.getenv("SESSION_COOKIE_SECURE", secure_default).strip().lower()
    https_only = secure_flag in {"1", "true", "yes", "on"}
    if same_site == "none":
        # Browsers discard SameSite=None cookies that lack the Secure
        # attribute, which would make the session silently fail to persist.
        https_only = True

    app.add_middleware(
        SessionMiddleware,
        secret_key=secret,
        same_site=same_site,
        https_only=https_only,
        max_age=60 * 60 * 24 * 7,  # 7 days
    )
def _bridge_serializer() -> URLSafeTimedSerializer:
    """Build the serializer that signs and verifies auth bridge tokens.

    The key comes from ``APP_SESSION_SECRET`` and is whitespace-trimmed, the
    same way ``configure_session_middleware`` treats it, so both components
    derive their key identically.

    Returns:
        A ``URLSafeTimedSerializer`` salted with ``AUTH_BRIDGE_SALT``.

    Raises:
        RuntimeError: In ``hf_oauth`` mode when the secret is empty or still
            the built-in development default. Without this check a process
            that never calls ``configure_session_middleware`` would accept
            tokens signed with the publicly known default secret.
    """
    secret = os.getenv("APP_SESSION_SECRET", DEFAULT_DEV_SESSION_SECRET).strip()
    if get_auth_mode() == AUTH_MODE_HF and (not secret or secret == DEFAULT_DEV_SESSION_SECRET):
        raise RuntimeError("APP_SESSION_SECRET must be set to a non-default value in hf_oauth mode.")
    return URLSafeTimedSerializer(secret_key=secret, salt=AUTH_BRIDGE_SALT)
def _session_user_to_current_user(session_user: dict[str, Any]) -> CurrentUser | None:
    """Convert a stored user mapping into a ``CurrentUser``.

    Args:
        session_user: Mapping with required ``id`` and ``email`` entries and
            optional ``display_name`` / ``avatar_url`` entries.

    Returns:
        The converted user, or ``None`` when a required key is missing or a
        value cannot be coerced.
    """

    def _text_or_none(key: str) -> str | None:
        # Falsy values (missing, None, empty string) collapse to None.
        value = session_user.get(key)
        return str(value) if value else None

    try:
        user_id = int(session_user["id"])
        email = str(session_user["email"])
        display_name = _text_or_none("display_name")
        avatar_url = _text_or_none("avatar_url")
        return CurrentUser(
            id=user_id,
            email=email,
            display_name=display_name,
            avatar_url=avatar_url,
        )
    except (KeyError, TypeError, ValueError):
        return None
def get_session_user(request: Request) -> CurrentUser | None:
    """Read the user stored under the ``"user"`` session key, if any.

    Args:
        request: Incoming request whose ``session`` mapping is consulted.

    Returns:
        The stored user, or ``None`` when the entry is absent, is not a dict,
        or cannot be converted.
    """
    stored = request.session.get("user")
    return _session_user_to_current_user(stored) if isinstance(stored, dict) else None
def set_session_user(request: Request, user: CurrentUser) -> None:
    """Store ``user`` in the request session under the ``"user"`` key.

    Args:
        request: Request whose ``session`` mapping is updated.
        user: User whose fields are copied into a plain dict.
    """
    field_names = ("id", "email", "display_name", "avatar_url")
    request.session["user"] = {name: getattr(user, name) for name in field_names}
def clear_session_user(request: Request) -> None:
    """Remove the ``"user"`` entry from the request session, if present.

    Args:
        request: Request whose ``session`` mapping is updated.
    """
    session_data = request.session
    # A default is supplied so clearing an already-empty session is a no-op.
    session_data.pop("user", None)
def generate_auth_bridge_token(user: CurrentUser) -> str:
    """Serialize ``user`` into a signed, timestamped bridge token.

    Args:
        user: User whose id, email, display name and avatar URL are embedded.

    Returns:
        The token string produced by the bridge serializer.
    """
    claims = dict(
        id=user.id,
        email=user.email,
        display_name=user.display_name,
        avatar_url=user.avatar_url,
    )
    serializer = _bridge_serializer()
    return serializer.dumps(claims)
def decode_auth_bridge_token(token: str) -> CurrentUser:
    """Verify a bridge token and return the user it carries.

    The maximum token age is read from ``AUTH_BRIDGE_TOKEN_TTL_SECONDS``
    (default 300). A value that is not an integer falls back to the default
    instead of surfacing a bare ``ValueError`` to the caller.

    Args:
        token: Token string produced by ``generate_auth_bridge_token``.

    Returns:
        The user encoded in the token.

    Raises:
        AuthBridgeTokenError: If the token is expired, its signature or
            payload is invalid, or the payload does not describe a user.
    """
    # Local import: BadData is the base class of every itsdangerous error and
    # also covers BadPayload, which is not a BadSignature.
    from itsdangerous import BadData

    default_ttl_seconds = 300
    try:
        ttl_seconds = int(os.getenv("AUTH_BRIDGE_TOKEN_TTL_SECONDS", str(default_ttl_seconds)))
    except ValueError:
        ttl_seconds = default_ttl_seconds

    try:
        payload = _bridge_serializer().loads(token, max_age=ttl_seconds)
    except SignatureExpired as exc:
        # Must precede the BadData handler: SignatureExpired is a subclass.
        raise AuthBridgeTokenError("Bridge token expired. Please sign in again.") from exc
    except BadData as exc:
        raise AuthBridgeTokenError("Invalid bridge token.") from exc

    user = _session_user_to_current_user(payload) if isinstance(payload, dict) else None
    if user is None:
        raise AuthBridgeTokenError("Invalid bridge token payload.")
    return user
def _ensure_dev_user(request: Request, db: Session) -> CurrentUser:
    """Fetch or create the development user and store it in the session.

    Identity comes from ``AUTH_DEV_USER_ID`` (default ``1``),
    ``AUTH_DEV_EMAIL`` (default ``dev@example.com``) and
    ``AUTH_DEV_DISPLAY_NAME`` (default ``Dev User``).

    Args:
        request: Request whose session receives the user.
        db: Database session handed to ``crud.get_or_create_user``.

    Returns:
        The dev user as a ``CurrentUser``.
    """
    record = crud.get_or_create_user(
        db=db,
        user_id=int(os.getenv("AUTH_DEV_USER_ID", "1")),
        email=os.getenv("AUTH_DEV_EMAIL", "dev@example.com"),
        display_name=os.getenv("AUTH_DEV_DISPLAY_NAME", "Dev User"),
    )
    dev_user = CurrentUser(
        id=record.id,
        email=record.email,
        display_name=record.display_name,
        avatar_url=record.avatar_url,
    )
    # Persist in the session so later requests skip the database lookup.
    set_session_user(request, dev_user)
    return dev_user
def require_current_user(request: Request, db: Session = Depends(get_db)) -> CurrentUser:
    """Return the user for this request, or reject the request.

    A user already present in the session is returned as-is. Otherwise, dev
    mode provisions the development user; any other mode is rejected.

    Args:
        request: Incoming request carrying the session.
        db: Database session, injected through ``Depends(get_db)``.

    Returns:
        The resolved user.

    Raises:
        HTTPException: 401 when no session user exists outside dev mode.
    """
    existing = get_session_user(request)
    if existing is not None:
        return existing

    if get_auth_mode() != AUTH_MODE_DEV:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authentication required.",
        )
    return _ensure_dev_user(request, db)
|