File size: 5,206 Bytes
aacd162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dba1a8e
aacd162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3fe4567
 
dd79664
 
aacd162
 
3fe4567
dba1a8e
 
 
 
 
 
 
 
 
 
 
 
 
aacd162
 
 
dba1a8e
 
aacd162
 
 
 
 
3fe4567
aacd162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
from __future__ import annotations

import os
from dataclasses import dataclass
from typing import Any

from fastapi import Depends, HTTPException, Request, status
from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer
from sqlalchemy.orm import Session
from starlette.middleware.sessions import SessionMiddleware

from data import crud
from data.db import get_db

# Accepted values of the AUTH_MODE environment variable (read by get_auth_mode).
AUTH_MODE_DEV = "dev"
AUTH_MODE_HF = "hf_oauth"
# itsdangerous salt for bridge tokens, so their signatures differ from other
# values signed with the same APP_SESSION_SECRET.
AUTH_BRIDGE_SALT = "streamlit-auth-bridge"
# Fallback used when APP_SESSION_SECRET is unset; configure_session_middleware
# refuses to start in hf_oauth mode while this value is in effect.
DEFAULT_DEV_SESSION_SECRET = "dev-only-session-secret-change-me"


@dataclass(frozen=True, eq=True)
class CurrentUser:
    """Immutable snapshot of the user authenticated for a request.

    Only ``id`` and ``email`` are required; the profile fields default to None.
    Being frozen, instances are hashable and cannot be mutated after creation.
    """

    id: int  # numeric user identifier
    email: str
    display_name: str | None = None  # optional human-readable name
    avatar_url: str | None = None  # optional profile image URL


class AuthBridgeTokenError(ValueError):
    """Raised when a bridge token is expired, has a bad signature, or carries an invalid payload."""


def get_auth_mode() -> str:
    """Return the configured authentication mode.

    Reads ``AUTH_MODE`` case-insensitively, ignoring surrounding whitespace.
    Unset, empty, or unrecognised values fall back to ``AUTH_MODE_DEV``.

    Returns:
        Either ``AUTH_MODE_DEV`` or ``AUTH_MODE_HF``.

    Warns:
        RuntimeWarning: when ``AUTH_MODE`` is set to a non-empty value that is
            not recognised.
    """
    configured = os.getenv("AUTH_MODE", AUTH_MODE_DEV).strip().lower()
    if configured in {AUTH_MODE_DEV, AUTH_MODE_HF}:
        return configured
    if configured:
        # Dev mode makes require_current_user auto-provision and log in a dev
        # user, so silently landing here on a typo (e.g. "hf-oauth") would
        # disable real authentication. Keep the fallback but make it visible.
        import warnings

        warnings.warn(
            f"Unrecognised AUTH_MODE {configured!r}; falling back to {AUTH_MODE_DEV!r}.",
            RuntimeWarning,
            stacklevel=2,
        )
    return AUTH_MODE_DEV


def configure_session_middleware(app) -> None:
    """Attach Starlette session middleware once during app setup.

    Environment variables:
        APP_SESSION_SECRET: cookie signing secret (whitespace stripped). Must be
            non-empty and differ from the development default in hf_oauth mode.
        SESSION_COOKIE_SAMESITE: ``lax`` (default), ``strict`` or ``none``;
            any other value falls back to ``lax``.
        SESSION_COOKIE_SECURE: ``1``/``true``/``yes``/``on`` marks the cookie
            Secure. Defaults on in hf_oauth mode and off otherwise. It is forced
            on when SameSite is ``none``.

    Raises:
        RuntimeError: in hf_oauth mode when APP_SESSION_SECRET is missing,
            blank, or still the development default.
    """
    secret = os.getenv("APP_SESSION_SECRET", DEFAULT_DEV_SESSION_SECRET).strip()
    auth_mode = get_auth_mode()
    if auth_mode == AUTH_MODE_HF and (not secret or secret == DEFAULT_DEV_SESSION_SECRET):
        raise RuntimeError("APP_SESSION_SECRET must be set to a non-default value in hf_oauth mode.")
    same_site = os.getenv("SESSION_COOKIE_SAMESITE", "lax").strip().lower()
    if same_site not in {"lax", "strict", "none"}:
        same_site = "lax"
    secure_default = "1" if auth_mode == AUTH_MODE_HF else "0"
    https_only = os.getenv("SESSION_COOKIE_SECURE", secure_default).strip().lower() in {
        "1",
        "true",
        "yes",
        "on",
    }
    if same_site == "none" and not https_only:
        # Browsers reject SameSite=None cookies that lack the Secure attribute,
        # so the session would silently never persist. Force Secure instead.
        https_only = True
    app.add_middleware(
        SessionMiddleware,
        secret_key=secret,
        same_site=same_site,
        https_only=https_only,
        max_age=60 * 60 * 24 * 7,  # 7 days
    )


def _bridge_serializer() -> URLSafeTimedSerializer:
    """Build the timestamped serializer that signs and verifies bridge tokens.

    The key is APP_SESSION_SECRET, or the development default when it is unset.
    NOTE(review): unlike configure_session_middleware, the value is NOT stripped
    of whitespace here.
    """
    return URLSafeTimedSerializer(
        salt=AUTH_BRIDGE_SALT,
        secret_key=os.getenv("APP_SESSION_SECRET", DEFAULT_DEV_SESSION_SECRET),
    )


def _session_user_to_current_user(session_user: dict[str, Any]) -> CurrentUser | None:
    """Rebuild a ``CurrentUser`` from its serialized mapping form.

    ``id`` is coerced with ``int`` and ``email`` with ``str``. ``display_name``
    and ``avatar_url`` become ``str`` when truthy and ``None`` otherwise.

    Returns:
        The user, or ``None`` when a required key is missing or a value cannot
        be coerced (KeyError / TypeError / ValueError).
    """

    def optional_text(key: str) -> str | None:
        # Empty strings and other falsy values are normalised to None.
        raw = session_user.get(key)
        return str(raw) if raw else None

    try:
        user_id = int(session_user["id"])
        email = str(session_user["email"])
        display_name = optional_text("display_name")
        avatar_url = optional_text("avatar_url")
        return CurrentUser(
            id=user_id,
            email=email,
            display_name=display_name,
            avatar_url=avatar_url,
        )
    except (KeyError, TypeError, ValueError):
        return None


def get_session_user(request: Request) -> CurrentUser | None:
    """Return the user stored in the request session, or ``None``.

    ``None`` is returned when there is no ``"user"`` entry, when that entry is
    not a dict, or when it cannot be converted to a ``CurrentUser``.
    """
    stored = request.session.get("user")
    return _session_user_to_current_user(stored) if isinstance(stored, dict) else None


def set_session_user(request: Request, user: CurrentUser) -> None:
    """Store ``user`` in the request session under the ``"user"`` key.

    The value is a plain dict with keys id, email, display_name and avatar_url,
    which is the shape ``get_session_user`` reads back.
    """
    snapshot = {
        field_name: getattr(user, field_name)
        for field_name in ("id", "email", "display_name", "avatar_url")
    }
    request.session["user"] = snapshot


def clear_session_user(request: Request) -> None:
    """Remove the stored user from the request session; a no-op if none is stored."""
    session = request.session
    if "user" in session:
        del session["user"]


def generate_auth_bridge_token(user: CurrentUser) -> str:
    """Sign ``user`` into a URL-safe, timestamped bridge token.

    The token is checked by ``decode_auth_bridge_token`` using the same
    serializer. It is signed, not encrypted, so anyone holding it can read the
    payload, including the email.
    """
    claims = dict(
        id=user.id,
        email=user.email,
        display_name=user.display_name,
        avatar_url=user.avatar_url,
    )
    serializer = _bridge_serializer()
    return serializer.dumps(claims)


def decode_auth_bridge_token(token: str) -> CurrentUser:
    """Verify a bridge token and return the user it carries.

    Tokens older than ``AUTH_BRIDGE_TOKEN_TTL_SECONDS`` (default 300) are
    rejected.

    Raises:
        AuthBridgeTokenError: if the token has expired, its signature is
            invalid, or its payload does not describe a user.
        ValueError: if ``AUTH_BRIDGE_TOKEN_TTL_SECONDS`` is not an integer.
    """
    max_age = int(os.getenv("AUTH_BRIDGE_TOKEN_TTL_SECONDS", "300"))
    try:
        payload = _bridge_serializer().loads(token, max_age=max_age)
    except SignatureExpired as exc:
        # Must come before BadSignature: SignatureExpired is a subclass of it.
        raise AuthBridgeTokenError("Bridge token expired. Please sign in again.") from exc
    except BadSignature as exc:
        raise AuthBridgeTokenError("Invalid bridge token.") from exc

    user = _session_user_to_current_user(payload) if isinstance(payload, dict) else None
    if user is None:
        raise AuthBridgeTokenError("Invalid bridge token payload.")
    return user


def _ensure_dev_user(request: Request, db: Session) -> CurrentUser:
    """Load or create the configured development user and store it in the session.

    Identity comes from these environment variables:
    - ``AUTH_DEV_USER_ID`` (default ``1``)
    - ``AUTH_DEV_EMAIL`` (default ``dev@example.com``)
    - ``AUTH_DEV_DISPLAY_NAME`` (default ``Dev User``)

    The row is obtained through ``crud.get_or_create_user``; presumably it
    returns an existing row for that id or inserts one (verify in ``data.crud``).
    NOTE(review): with the default id of 1, dev mode against a shared database
    may resolve to a real account.
    """
    record = crud.get_or_create_user(
        db=db,
        user_id=int(os.getenv("AUTH_DEV_USER_ID", "1")),
        email=os.getenv("AUTH_DEV_EMAIL", "dev@example.com"),
        display_name=os.getenv("AUTH_DEV_DISPLAY_NAME", "Dev User"),
    )
    dev_user = CurrentUser(
        id=record.id,
        email=record.email,
        display_name=record.display_name,
        avatar_url=record.avatar_url,
    )
    set_session_user(request, dev_user)
    return dev_user


def require_current_user(request: Request, db: Session = Depends(get_db)) -> CurrentUser:
    """Resolve current user from session; auto-provision in dev mode.

    Order of resolution:
    1. A valid user already stored in the session is returned as-is.
    2. In dev mode, the development user is loaded or created and logged in.
    3. Otherwise the request is rejected.

    Raises:
        HTTPException: 401 when no session user exists and the auth mode is
            not dev.
    """
    if existing := get_session_user(request):
        return existing

    if get_auth_mode() != AUTH_MODE_DEV:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authentication required.",
        )
    return _ensure_dev_user(request, db)