lewtun HF Staff OpenAI Codex commited on
Commit
e90638c
·
unverified ·
1 Parent(s): 82bad21

Force OAuth refresh after scope changes (#227)

Browse files

* Force OAuth refresh after scope changes

Co-authored-by: OpenAI Codex <codex@openai.com>

* Address OAuth scope review comments

Co-authored-by: OpenAI Codex <codex@openai.com>

---------

Co-authored-by: OpenAI Codex <codex@openai.com>

backend/dependencies.py CHANGED
@@ -7,6 +7,8 @@
7
  import logging
8
  import os
9
  import time
 
 
10
  from typing import Any
11
 
12
  import httpx
@@ -37,12 +39,60 @@ DEV_USER: dict[str, Any] = {
37
  }
38
 
39
  INTERNAL_HF_TOKEN_KEY = "_hf_token"
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  # Plan field discovery — log the whoami-v2 shape once at DEBUG so we can
42
  # confirm the actual key in production without hammering the HF API.
43
  _WHOAMI_SHAPE_LOGGED = False
44
 
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  async def _validate_token(token: str) -> dict[str, Any] | None:
47
  """Validate a token against HF OAuth userinfo endpoint.
48
 
@@ -213,7 +263,8 @@ async def get_current_user(request: Request) -> dict[str, Any]:
213
  if not AUTH_ENABLED:
214
  return await _dev_user_from_env()
215
 
216
- # Try Authorization header
 
217
  token = bearer_token_from_header(request.headers.get("Authorization", ""))
218
  if token:
219
  user = await _extract_user_from_token(token)
@@ -223,6 +274,15 @@ async def get_current_user(request: Request) -> dict[str, Any]:
223
  # Try cookie
224
  token = request.cookies.get("hf_access_token")
225
  if token:
 
 
 
 
 
 
 
 
 
226
  user = await _extract_user_from_token(token)
227
  if user:
228
  return user
 
7
  import logging
8
  import os
9
  import time
10
+ from collections.abc import Iterable
11
+ from hashlib import sha256
12
  from typing import Any
13
 
14
  import httpx
 
39
  }
40
 
41
  INTERNAL_HF_TOKEN_KEY = "_hf_token"
42
+ OAUTH_SCOPE_COOKIE = "hf_oauth_scope_hash"
43
+ REQUIRED_OAUTH_SCOPES: tuple[str, ...] = (
44
+ "openid",
45
+ "profile",
46
+ "read-repos",
47
+ "write-repos",
48
+ "contribute-repos",
49
+ "manage-repos",
50
+ "write-collections",
51
+ "inference-api",
52
+ "jobs",
53
+ "write-discussions",
54
+ )
55
 
56
  # Plan field discovery — log the whoami-v2 shape once at DEBUG so we can
57
  # confirm the actual key in production without hammering the HF API.
58
  _WHOAMI_SHAPE_LOGGED = False
59
 
60
 
61
+ def normalize_oauth_scopes(scopes: Iterable[str]) -> tuple[str, ...]:
62
+ """Return stable, de-duplicated OAuth scopes preserving declaration order."""
63
+ seen: set[str] = set()
64
+ normalized: list[str] = []
65
+ for scope in scopes:
66
+ value = str(scope).strip()
67
+ if not value or value in seen:
68
+ continue
69
+ seen.add(value)
70
+ normalized.append(value)
71
+ return tuple(normalized)
72
+
73
+
74
+ def configured_oauth_scopes() -> tuple[str, ...]:
75
+ """Return the scopes this backend should request from HF OAuth.
76
+
77
+ Spaces expose README ``hf_oauth_scopes`` through ``OAUTH_SCOPES``. Unioning
78
+ that value with the app-required scopes keeps the local request and Space
79
+ metadata in sync while ensuring new required scopes are never omitted.
80
+ """
81
+ env_scopes = os.environ.get("OAUTH_SCOPES", "").split()
82
+ return normalize_oauth_scopes((*env_scopes, *REQUIRED_OAUTH_SCOPES))
83
+
84
+
85
+ def oauth_scope_fingerprint(scopes: Iterable[str] | None = None) -> str:
86
+ """Return a non-secret fingerprint for the current OAuth scope contract."""
87
+ scope_list = configured_oauth_scopes() if scopes is None else scopes
88
+ payload = " ".join(sorted(normalize_oauth_scopes(scope_list)))
89
+ return sha256(payload.encode("utf-8")).hexdigest()[:16]
90
+
91
+
92
+ def _cookie_has_current_oauth_scope_marker(request: Request) -> bool:
93
+ return request.cookies.get(OAUTH_SCOPE_COOKIE) == oauth_scope_fingerprint()
94
+
95
+
96
  async def _validate_token(token: str) -> dict[str, Any] | None:
97
  """Validate a token against HF OAuth userinfo endpoint.
98
 
 
263
  if not AUTH_ENABLED:
264
  return await _dev_user_from_env()
265
 
266
+ # Bearer callers manage token lifecycle themselves; only browser cookie
267
+ # auth is forced through the scope-freshness marker below.
268
  token = bearer_token_from_header(request.headers.get("Authorization", ""))
269
  if token:
270
  user = await _extract_user_from_token(token)
 
274
  # Try cookie
275
  token = request.cookies.get("hf_access_token")
276
  if token:
277
+ if not _cookie_has_current_oauth_scope_marker(request):
278
+ logger.info(
279
+ "Rejecting stale HF OAuth cookie; current scopes require refresh."
280
+ )
281
+ raise HTTPException(
282
+ status_code=status.HTTP_401_UNAUTHORIZED,
283
+ detail="Authentication scopes changed. Please log in again.",
284
+ headers={"WWW-Authenticate": "Bearer"},
285
+ )
286
  user = await _extract_user_from_token(token)
287
  if user:
288
  return user
backend/routes/auth.py CHANGED
@@ -4,40 +4,47 @@ Handles the OAuth 2.0 authorization code flow with HF as provider.
4
  After successful auth, sets an HttpOnly cookie with the access token.
5
  """
6
 
 
7
  import os
8
  import secrets
9
  import time
10
  from urllib.parse import urlencode
11
 
12
  import httpx
13
- from dependencies import AUTH_ENABLED, get_current_user
 
 
 
 
 
 
 
14
  from fastapi import APIRouter, Depends, HTTPException, Request
15
  from fastapi.responses import RedirectResponse
16
 
17
  router = APIRouter(prefix="/auth", tags=["auth"])
 
18
 
19
  # OAuth configuration from environment
20
  OAUTH_CLIENT_ID = os.environ.get("OAUTH_CLIENT_ID", "")
21
  OAUTH_CLIENT_SECRET = os.environ.get("OAUTH_CLIENT_SECRET", "")
22
  OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co")
23
- OAUTH_SCOPES = (
24
- "openid",
25
- "profile",
26
- "read-repos",
27
- "write-repos",
28
- "contribute-repos",
29
- "manage-repos",
30
- "write-collections",
31
- "inference-api",
32
- "jobs",
33
- "write-discussions",
34
- )
35
 
36
  # In-memory OAuth state store with expiry (5 min TTL)
37
  _OAUTH_STATE_TTL = 300
38
  oauth_states: dict[str, dict] = {}
39
 
40
 
 
 
 
 
 
 
 
 
 
41
  def _cleanup_expired_states() -> None:
42
  """Remove expired OAuth states to prevent memory growth."""
43
  now = time.time()
@@ -131,6 +138,15 @@ async def oauth_callback(
131
  status_code=500,
132
  detail="Token exchange succeeded but no access_token was returned.",
133
  )
 
 
 
 
 
 
 
 
 
134
 
135
  # Fetch user info (optional — failure is not fatal)
136
  async with httpx.AsyncClient() as client:
@@ -156,6 +172,15 @@ async def oauth_callback(
156
  max_age=3600 * 24 * 7, # 7 days
157
  path="/",
158
  )
 
 
 
 
 
 
 
 
 
159
  return response
160
 
161
 
@@ -164,6 +189,7 @@ async def logout() -> RedirectResponse:
164
  """Log out the user by clearing the auth cookie."""
165
  response = RedirectResponse(url="/")
166
  response.delete_cookie(key="hf_access_token", path="/")
 
167
  return response
168
 
169
 
 
4
  After successful auth, sets an HttpOnly cookie with the access token.
5
  """
6
 
7
+ import logging
8
  import os
9
  import secrets
10
  import time
11
  from urllib.parse import urlencode
12
 
13
  import httpx
14
+ from dependencies import (
15
+ AUTH_ENABLED,
16
+ OAUTH_SCOPE_COOKIE,
17
+ REQUIRED_OAUTH_SCOPES,
18
+ configured_oauth_scopes,
19
+ get_current_user,
20
+ oauth_scope_fingerprint,
21
+ )
22
  from fastapi import APIRouter, Depends, HTTPException, Request
23
  from fastapi.responses import RedirectResponse
24
 
25
  router = APIRouter(prefix="/auth", tags=["auth"])
26
+ logger = logging.getLogger(__name__)
27
 
28
  # OAuth configuration from environment
29
  OAUTH_CLIENT_ID = os.environ.get("OAUTH_CLIENT_ID", "")
30
  OAUTH_CLIENT_SECRET = os.environ.get("OAUTH_CLIENT_SECRET", "")
31
  OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co")
32
+ OAUTH_SCOPES = configured_oauth_scopes()
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  # In-memory OAuth state store with expiry (5 min TTL)
35
  _OAUTH_STATE_TTL = 300
36
  oauth_states: dict[str, dict] = {}
37
 
38
 
39
+ def _missing_required_scopes(token_data: dict) -> set[str]:
40
+ raw_scopes = token_data.get("scope")
41
+ if not isinstance(raw_scopes, str) or not raw_scopes.strip():
42
+ logger.debug("OAuth token response omitted a usable scope field")
43
+ return set()
44
+ granted = set(raw_scopes.split())
45
+ return set(REQUIRED_OAUTH_SCOPES) - granted
46
+
47
+
48
  def _cleanup_expired_states() -> None:
49
  """Remove expired OAuth states to prevent memory growth."""
50
  now = time.time()
 
138
  status_code=500,
139
  detail="Token exchange succeeded but no access_token was returned.",
140
  )
141
+ missing_scopes = _missing_required_scopes(token_data)
142
+ if missing_scopes:
143
+ raise HTTPException(
144
+ status_code=403,
145
+ detail=(
146
+ "OAuth token is missing required scopes: "
147
+ + ", ".join(sorted(missing_scopes))
148
+ ),
149
+ )
150
 
151
  # Fetch user info (optional — failure is not fatal)
152
  async with httpx.AsyncClient() as client:
 
172
  max_age=3600 * 24 * 7, # 7 days
173
  path="/",
174
  )
175
+ response.set_cookie(
176
+ key=OAUTH_SCOPE_COOKIE,
177
+ value=oauth_scope_fingerprint(OAUTH_SCOPES),
178
+ httponly=True,
179
+ secure=is_production,
180
+ samesite="lax",
181
+ max_age=3600 * 24 * 7,
182
+ path="/",
183
+ )
184
  return response
185
 
186
 
 
189
  """Log out the user by clearing the auth cookie."""
190
  response = RedirectResponse(url="/")
191
  response.delete_cookie(key="hf_access_token", path="/")
192
+ response.delete_cookie(key=OAUTH_SCOPE_COOKIE, path="/")
193
  return response
194
 
195
 
tests/unit/test_auth_token_propagation.py CHANGED
@@ -6,6 +6,7 @@ from types import SimpleNamespace
6
  from urllib.parse import parse_qs, urlparse
7
 
8
  import pytest
 
9
 
10
  _BACKEND_DIR = Path(__file__).resolve().parent.parent.parent / "backend"
11
  if str(_BACKEND_DIR) not in sys.path:
@@ -44,6 +45,52 @@ async def test_current_user_carries_internal_hf_token(monkeypatch):
44
  assert user[dependencies.INTERNAL_HF_TOKEN_KEY] == "hf-user-token"
45
 
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  @pytest.mark.asyncio
48
  async def test_auth_me_does_not_expose_internal_hf_token():
49
  user = {
@@ -73,3 +120,71 @@ async def test_oauth_login_requests_collection_write_scope(monkeypatch):
73
  scopes = set(params["scope"][0].split())
74
 
75
  assert "write-collections" in scopes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from urllib.parse import parse_qs, urlparse
7
 
8
  import pytest
9
+ from fastapi import HTTPException
10
 
11
  _BACKEND_DIR = Path(__file__).resolve().parent.parent.parent / "backend"
12
  if str(_BACKEND_DIR) not in sys.path:
 
45
  assert user[dependencies.INTERNAL_HF_TOKEN_KEY] == "hf-user-token"
46
 
47
 
48
+ @pytest.mark.asyncio
49
+ async def test_cookie_auth_requires_current_oauth_scope_marker(monkeypatch):
50
+ monkeypatch.setattr(dependencies, "AUTH_ENABLED", True)
51
+
52
+ request = SimpleNamespace(
53
+ headers={},
54
+ cookies={"hf_access_token": "hf-user-token"},
55
+ )
56
+
57
+ with pytest.raises(HTTPException) as exc_info:
58
+ await dependencies.get_current_user(request)
59
+
60
+ assert exc_info.value.status_code == 401
61
+ assert "scopes changed" in exc_info.value.detail
62
+
63
+
64
+ @pytest.mark.asyncio
65
+ async def test_cookie_auth_accepts_current_oauth_scope_marker(monkeypatch):
66
+ monkeypatch.setattr(dependencies, "AUTH_ENABLED", True)
67
+ dependencies._token_cache.clear()
68
+
69
+ async def fake_validate_token(token):
70
+ assert token == "hf-user-token"
71
+ return {"sub": "user-id", "preferred_username": "alice"}
72
+
73
+ async def fake_fetch_user_plan(token):
74
+ assert token == "hf-user-token"
75
+ return "pro"
76
+
77
+ monkeypatch.setattr(dependencies, "_validate_token", fake_validate_token)
78
+ monkeypatch.setattr(dependencies, "_fetch_user_plan", fake_fetch_user_plan)
79
+
80
+ request = SimpleNamespace(
81
+ headers={},
82
+ cookies={
83
+ "hf_access_token": "hf-user-token",
84
+ dependencies.OAUTH_SCOPE_COOKIE: dependencies.oauth_scope_fingerprint(),
85
+ },
86
+ )
87
+
88
+ user = await dependencies.get_current_user(request)
89
+
90
+ assert user["user_id"] == "user-id"
91
+ assert user[dependencies.INTERNAL_HF_TOKEN_KEY] == "hf-user-token"
92
+
93
+
94
  @pytest.mark.asyncio
95
  async def test_auth_me_does_not_expose_internal_hf_token():
96
  user = {
 
120
  scopes = set(params["scope"][0].split())
121
 
122
  assert "write-collections" in scopes
123
+
124
+
125
+ def test_oauth_callback_detects_missing_required_collection_scope():
126
+ granted = [scope for scope in auth.OAUTH_SCOPES if scope != "write-collections"]
127
+
128
+ assert auth._missing_required_scopes({"scope": " ".join(granted)}) == {
129
+ "write-collections"
130
+ }
131
+
132
+
133
+ def test_oauth_callback_treats_absent_scope_as_full_grant():
134
+ assert auth._missing_required_scopes({}) == set()
135
+
136
+
137
+ @pytest.mark.asyncio
138
+ async def test_oauth_callback_sets_scope_marker_cookie(monkeypatch):
139
+ monkeypatch.setenv("SPACE_HOST", "example.hf.space")
140
+ auth.oauth_states.clear()
141
+ auth.oauth_states["state"] = {
142
+ "redirect_uri": "https://example.hf.space/auth/callback",
143
+ "expires_at": 9999999999,
144
+ }
145
+
146
+ class FakeResponse:
147
+ def __init__(self, payload):
148
+ self._payload = payload
149
+
150
+ def raise_for_status(self):
151
+ return None
152
+
153
+ def json(self):
154
+ return self._payload
155
+
156
+ class FakeAsyncClient:
157
+ def __init__(self, *args, **kwargs):
158
+ pass
159
+
160
+ async def __aenter__(self):
161
+ return self
162
+
163
+ async def __aexit__(self, *args):
164
+ return None
165
+
166
+ async def post(self, *args, **kwargs):
167
+ return FakeResponse(
168
+ {
169
+ "access_token": "hf-user-token",
170
+ "scope": " ".join(auth.OAUTH_SCOPES),
171
+ }
172
+ )
173
+
174
+ async def get(self, *args, **kwargs):
175
+ return FakeResponse({})
176
+
177
+ monkeypatch.setattr(auth.httpx, "AsyncClient", FakeAsyncClient)
178
+
179
+ response = await auth.oauth_callback(SimpleNamespace(), code="code", state="state")
180
+ set_cookies = [
181
+ value.decode("latin-1")
182
+ for key, value in response.raw_headers
183
+ if key == b"set-cookie"
184
+ ]
185
+
186
+ expected = (
187
+ f"{dependencies.OAUTH_SCOPE_COOKIE}="
188
+ f"{dependencies.oauth_scope_fingerprint(auth.OAUTH_SCOPES)}"
189
+ )
190
+ assert any(cookie.startswith(expected) for cookie in set_cookies)