File size: 5,136 Bytes
d3a26e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13a5685
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d3a26e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13a5685
 
d3a26e1
 
 
 
 
 
 
 
13a5685
 
d3a26e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13a5685
 
 
d3a26e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""Authentication helpers for HF OAuth-backed requests.

Spec references:
- `specs/04_interfaces.md`: implements `get_current_user()`.
- `specs/07_security.md`: authentication is required and user identity scopes storage access.
- `specs/10_test_plan.md`: behavior is explicit and unit-testable.
"""

from __future__ import annotations

from typing import Any


class AuthError(Exception):
    """Root of this module's authentication error hierarchy.

    Callers that only care whether authentication failed, not why, can catch
    this single type; more specific failures subclass it.
    """


class NotAuthenticatedError(AuthError):
    """Signals that no authenticated user could be found for the current request.

    Raised by ``get_current_user`` when the supplied request context carries no
    usable username.
    """


# Private sentinel meaning "attribute absent or unreadable"; distinct from
# ``None`` so that an attribute whose real value is ``None`` is not confused
# with a failed lookup.
_MISSING = object()


def _safe_getattr(container: object, attribute_name: str) -> Any:
    """Fetch ``attribute_name`` from ``container``, or ``_MISSING`` on failure.

    Two lookups are tried in order:

    1. ``object.__getattribute__`` -- the generic instance/class lookup, which
       bypasses any ``__getattribute__`` override and does not consult
       ``__getattr__``.
    2. Built-in ``getattr`` -- the full attribute protocol, including
       ``__getattr__`` hooks.

    The second lookup runs only when the first raises ``AttributeError``. Any
    other ``Exception`` raised by either lookup (for example a property whose
    accessor fails) ends the search, so framework-specific accessor errors
    never propagate to the caller.

    Returns:
        The attribute value, or ``_MISSING`` when neither lookup produced one.
    """

    for lookup in (object.__getattribute__, getattr):
        try:
            return lookup(container, attribute_name)
        except AttributeError:
            # Not found by this lookup; fall through to the broader one.
            continue
        except Exception:
            # Accessor blew up: treat the attribute as unavailable.
            break
    return _MISSING


def _extract_mapping_value(
    container: dict[str, Any],
    _seen: dict[int, Any] | None = None,
) -> str | None:
    """Extract a username from a mapping-based request context.

    Values stored under the user-like keys (``username``, ``user``,
    ``hf_user``, ``current_user``) are treated as user candidates. A non-blank
    string is returned stripped. Any other non-``None`` value -- a nested
    mapping *or* an object such as a user model -- is inspected further.

    Values stored under ``request``, ``state`` and ``session`` are treated as
    containers to search, never as usernames themselves. Nested dicts are
    walked with this function; other objects are walked attribute-wise.

    Args:
        container: The mapping to inspect.
        _seen: Private bookkeeping shared by the recursive helpers. It maps the
            ``id()`` of every container already walked during this lookup to
            the container itself; the stored reference keeps the object alive
            so its ``id`` cannot be reused mid-walk.

    Returns:
        The stripped username, or ``None`` if none was found or this mapping
        was already walked during the current lookup.
    """

    seen: dict[int, Any] = {} if _seen is None else _seen
    if id(container) in seen:
        # Already walked in this lookup: a reference cycle or a shared
        # sub-context. Re-walking it cannot find anything new.
        return None
    seen[id(container)] = container

    for key in ("username", "user", "hf_user", "current_user"):
        username: str | None = _extract_user_from_candidate(container.get(key), seen)
        if username is not None:
            return username

    for key in ("request", "state", "session"):
        nested: Any = container.get(key)
        if nested is None or isinstance(nested, str):
            # A bare string here is not a container; never treat it as a username.
            continue
        if isinstance(nested, dict):
            username = _extract_mapping_value(nested, seen)
        else:
            # e.g. a framework request/state object stored inside a dict.
            username = _extract_object_value(nested, seen)
        if username is not None:
            return username

    return None


def _extract_object_value(
    container: object,
    _seen: dict[int, Any] | None = None,
) -> str | None:
    """Extract a username from an attribute-based request context.

    The user-like attributes (``username``, ``user``, ``hf_user``,
    ``current_user``) are checked first, then the container attributes
    (``request``, ``state``, ``session``). Each value found goes through
    ``_extract_user_from_candidate``.

    NOTE(review): unlike the mapping path, container attributes also go
    through the full candidate check, so a non-blank string stored in e.g.
    ``obj.request`` is returned as the username. This preserves existing
    behaviour; confirm it is intended.

    Args:
        container: The object to inspect. Attribute reads go through
            ``_safe_getattr``, so accessor errors are treated as "absent".
        _seen: Private visited-container map shared by the recursive helpers
            (see ``_extract_mapping_value``).

    Returns:
        The stripped username, or ``None`` if none was found or this object
        was already walked during the current lookup.
    """

    seen: dict[int, Any] = {} if _seen is None else _seen
    if id(container) in seen:
        # Cycle / shared sub-context guard.
        return None
    seen[id(container)] = container

    for attribute_name in ("username", "user", "hf_user", "current_user"):
        value: Any = _safe_getattr(container, attribute_name)
        if value is _MISSING:
            continue
        username: str | None = _extract_user_from_candidate(value, seen)
        if username is not None:
            return username

    for attribute_name in ("request", "state", "session"):
        nested_container: Any = _safe_getattr(container, attribute_name)
        if nested_container is _MISSING:
            continue
        username = _extract_user_from_candidate(nested_container, seen)
        if username is not None:
            return username

    return None


def _extract_user_from_candidate(
    candidate: Any,
    _seen: dict[int, Any] | None = None,
) -> str | None:
    """Extract an authenticated username from one candidate context value.

    Resolution order:

    * ``str``: the stripped string, or ``None`` if it is blank.
    * ``None``: ``None``.
    * ``dict``: ``_extract_mapping_value``, then the fallback keys
      ``preferred_username``, ``name``, ``login``, ``sub`` (first non-blank
      string wins).
    * any other object: ``_extract_object_value``, then the same fallback
      names read as attributes.

    Args:
        candidate: The value to inspect.
        _seen: Private visited-container map shared by the recursive helpers.
            Callers outside this module should omit it; a fresh map is created
            per top-level lookup.

    Returns:
        The stripped username, or ``None`` if none could be extracted.
    """

    if isinstance(candidate, str):
        return candidate.strip() or None
    if candidate is None:
        return None

    seen: dict[int, Any] = {} if _seen is None else _seen
    fallback_names: tuple[str, ...] = ("preferred_username", "name", "login", "sub")

    if isinstance(candidate, dict):
        username: str | None = _extract_mapping_value(candidate, seen)
        if username is not None:
            return username
        for key in fallback_names:
            value: Any = candidate.get(key)
            if isinstance(value, str) and value.strip():
                return value.strip()
        return None

    username = _extract_object_value(candidate, seen)
    if username is not None:
        return username

    for attribute_name in fallback_names:
        value = _safe_getattr(candidate, attribute_name)
        if isinstance(value, str) and value.strip():
            return value.strip()

    return None


def get_current_user(request_ctx: Any) -> str:
    """Return the authenticated HF OAuth username from the current request context.

    Spec references:
    - `specs/04_interfaces.md`: implements `get_current_user()`.
    - `specs/07_security.md`: rejects unauthenticated access.

    Args:
        request_ctx: Framework-specific request or auth context object.

    Returns:
        The authenticated username string used for per-user storage isolation.

    Raises:
        NotAuthenticatedError: If no authenticated user can be extracted. This
            includes contexts whose search never bottoms out and exhausts
            Python's recursion limit; the original ``RecursionError`` is
            chained as the cause.
    """

    message: str = "Authenticated user not found in request context."
    try:
        username: str | None = _extract_user_from_candidate(request_ctx)
    except RecursionError as err:
        # Contexts that synthesize a fresh object for every attribute access
        # (e.g. ``unittest.mock.Mock``) make the recursive search endless.
        # Honour the documented contract instead of leaking RecursionError.
        raise NotAuthenticatedError(message) from err
    if username is None:
        raise NotAuthenticatedError(message)
    return username