raylim Claude Opus 4.6 commited on
Commit
8492a0e
·
1 Parent(s): 98b35f0

feat: implement manual OAuth for HF Spaces Docker SDK

Browse files

On HF Spaces with sdk:docker, Gradio's built-in OAuth injects the Space
owner's identity into every session. This adds custom /api/auth/* routes
that perform the full Authorization Code flow directly with HF's OAuth
provider, writing the correct visitor identity into the Starlette session
so Gradio's OAuthProfile injection works transparently.

- Add src/mosaic/ui/oauth.py with login/callback/logout routes and
server-side session store (24h TTL, mosaic_auth cookie fallback)
- Mount custom OAuth routes on Gradio app when IS_HF_SPACES
- Update login/logout links to /api/auth/login and /api/auth/logout
- Add server-session fallback to extract_user_info() and _get_username()
- Stop hashing usernames in telemetry (use raw HF usernames)
- Add pyproject.toml uv index-strategy for macOS CPU torch compatibility
- Add tests/test_oauth.py with 18 tests covering the full OAuth flow

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

pyproject.toml CHANGED
@@ -42,6 +42,13 @@ disable = [
42
  "unspecified-encoding",
43
  ]
44
 
 
 
 
 
 
 
 
45
  [tool.uv.sources]
46
  # For local dev with SSH: uv pip install -e .
47
  # For Docker builds with token: GH_TOKEN=<token> uv pip install -e .
 
42
  "unspecified-encoding",
43
  ]
44
 
45
+ [tool.uv]
46
+ # Override PyTorch dependencies from mussel[torch-gpu] for macOS compatibility
47
+ override-dependencies = [
48
+ "torch>=2.0.0; sys_platform == 'darwin'",
49
+ "torchvision>=0.15.0; sys_platform == 'darwin'",
50
+ ]
51
+
52
  [tool.uv.sources]
53
  # For local dev with SSH: uv pip install -e .
54
  # For Docker builds with token: GH_TOKEN=<token> uv pip install -e .
src/mosaic/telemetry/__init__.py CHANGED
@@ -47,7 +47,6 @@ from mosaic.telemetry.utils import (
47
  StageTimer,
48
  sanitize_error_message,
49
  hash_session_id,
50
- hash_username,
51
  UserInfo,
52
  extract_user_info,
53
  )
@@ -68,7 +67,6 @@ __all__ = [
68
  "StageTimer",
69
  "sanitize_error_message",
70
  "hash_session_id",
71
- "hash_username",
72
  "UserInfo",
73
  "extract_user_info",
74
  ]
 
47
  StageTimer,
48
  sanitize_error_message,
49
  hash_session_id,
 
50
  UserInfo,
51
  extract_user_info,
52
  )
 
67
  "StageTimer",
68
  "sanitize_error_message",
69
  "hash_session_id",
 
70
  "UserInfo",
71
  "extract_user_info",
72
  ]
src/mosaic/telemetry/tracker.py CHANGED
@@ -22,7 +22,6 @@ from mosaic.telemetry.events import (
22
  from mosaic.telemetry.storage import TelemetryStorage
23
  from mosaic.telemetry.utils import (
24
  hash_session_id,
25
- hash_username,
26
  sanitize_error_message,
27
  )
28
 
@@ -256,7 +255,7 @@ class TelemetryTracker:
256
  success=success,
257
  cached_slide_count=cached_slide_count,
258
  is_logged_in=is_logged_in,
259
- hf_username=hash_username(hf_username),
260
  )
261
  self.storage.write_usage_event(event)
262
 
@@ -332,7 +331,7 @@ class TelemetryTracker:
332
  gpu_type=gpu_type,
333
  peak_gpu_memory_gb=peak_gpu_memory_gb,
334
  is_logged_in=is_logged_in,
335
- hf_username=hash_username(hf_username),
336
  )
337
  self.storage.write_resource_event(event)
338
 
@@ -377,7 +376,7 @@ class TelemetryTracker:
377
  slide_count=slide_count,
378
  gpu_type=gpu_type,
379
  is_logged_in=is_logged_in,
380
- hf_username=hash_username(hf_username),
381
  )
382
  self.storage.write_failure_event(event)
383
 
 
22
  from mosaic.telemetry.storage import TelemetryStorage
23
  from mosaic.telemetry.utils import (
24
  hash_session_id,
 
25
  sanitize_error_message,
26
  )
27
 
 
255
  success=success,
256
  cached_slide_count=cached_slide_count,
257
  is_logged_in=is_logged_in,
258
+ hf_username=hf_username,
259
  )
260
  self.storage.write_usage_event(event)
261
 
 
331
  gpu_type=gpu_type,
332
  peak_gpu_memory_gb=peak_gpu_memory_gb,
333
  is_logged_in=is_logged_in,
334
+ hf_username=hf_username,
335
  )
336
  self.storage.write_resource_event(event)
337
 
 
376
  slide_count=slide_count,
377
  gpu_type=gpu_type,
378
  is_logged_in=is_logged_in,
379
+ hf_username=hf_username,
380
  )
381
  self.storage.write_failure_event(event)
382
 
src/mosaic/telemetry/utils.py CHANGED
@@ -5,7 +5,7 @@ This module provides helper utilities:
5
  - sanitize_error_message: Remove sensitive data from error messages
6
  - hash_session_id: Hash session IDs for privacy
7
  - UserInfo: Dataclass for user information from HF Spaces
8
- - extract_user_info: Extract user info from Gradio request object
9
  """
10
 
11
  import hashlib
@@ -99,25 +99,6 @@ def hash_session_id(session_id: Optional[str]) -> Optional[str]:
99
  return hashlib.sha256(salted.encode()).hexdigest()[:16]
100
 
101
 
102
- def hash_username(username: Optional[str]) -> Optional[str]:
103
- """Hash a username for privacy in telemetry.
104
-
105
- Uses SHA-256 with a different salt than session IDs to create a one-way hash.
106
- This allows distinguishing users in telemetry without storing actual usernames.
107
-
108
- Args:
109
- username: HuggingFace username (can be None for anonymous users)
110
-
111
- Returns:
112
- Hashed username or None if input is None
113
- """
114
- if username is None:
115
- return None
116
-
117
- salted = f"mosaic_user:{username}"
118
- return hashlib.sha256(salted.encode()).hexdigest()[:16]
119
-
120
-
121
  @dataclass
122
  class UserInfo:
123
  """User information extracted from HF Spaces request.
@@ -169,5 +150,19 @@ def extract_user_info(request, is_hf_spaces: bool = False, profile=None) -> User
169
  except Exception as e:
170
  logger.debug(f"Could not extract username from OAuthProfile: {e}")
171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  logger.debug("User not logged in: no OAuthProfile available")
173
  return UserInfo()
 
5
  - sanitize_error_message: Remove sensitive data from error messages
6
  - hash_session_id: Hash session IDs for privacy
7
  - UserInfo: Dataclass for user information from HF Spaces
8
+ - extract_user_info: Extract user info from Gradio request/OAuth profile
9
  """
10
 
11
  import hashlib
 
99
  return hashlib.sha256(salted.encode()).hexdigest()[:16]
100
 
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  @dataclass
103
  class UserInfo:
104
  """User information extracted from HF Spaces request.
 
150
  except Exception as e:
151
  logger.debug(f"Could not extract username from OAuthProfile: {e}")
152
 
153
+ # Fallback: check server-side session store (custom OAuth flow)
154
+ if request is not None:
155
+ try:
156
+ from mosaic.ui.oauth import get_user_from_server_session
157
+
158
+ userinfo = get_user_from_server_session(request)
159
+ if userinfo:
160
+ username = userinfo.get("preferred_username")
161
+ if username:
162
+ logger.info(f"Extracted user from server-side session: {username}")
163
+ return UserInfo(is_logged_in=True, username=username)
164
+ except Exception as e:
165
+ logger.debug(f"Server-side session lookup failed: {e}")
166
+
167
  logger.debug("User not logged in: no OAuthProfile available")
168
  return UserInfo()
src/mosaic/ui/app.py CHANGED
@@ -37,6 +37,7 @@ from mosaic.data_directory import get_tcga_cache_directory
37
  from mosaic.model_manager import load_all_models
38
  from mosaic.hardware import DEFAULT_CONCURRENCY_LIMIT, IS_HF_SPACES, IS_T4_GPU, GPU_TYPE
39
  from mosaic.telemetry import extract_user_info
 
40
  from mosaic.tcga import (
41
  compute_settings_hash,
42
  download_results_from_hf,
@@ -1933,7 +1934,7 @@ This tool is for research purposes only and not approved for clinical diagnosis.
1933
  )
1934
  return (
1935
  gr.update(
1936
- value=f"Signed in as **{username}** \u00b7 [Sign out](/logout)",
1937
  visible=True,
1938
  ), # login_status_md
1939
  gr.update(visible=True), # user_storage_tabs
@@ -1943,7 +1944,7 @@ This tool is for research purposes only and not approved for clinical diagnosis.
1943
  else:
1944
  return (
1945
  gr.update(
1946
- value="[Sign in with HuggingFace](/login/huggingface) to save slides and results",
1947
  visible=True,
1948
  ), # login_status_md
1949
  gr.update(visible=False), # user_storage_tabs
@@ -1967,6 +1968,10 @@ This tool is for research purposes only and not approved for clinical diagnosis.
1967
  # Higher-memory GPUs and ZeroGPU can handle multiple concurrent analyses
1968
  demo.queue(max_size=10, default_concurrency_limit=DEFAULT_CONCURRENCY_LIMIT)
1969
 
 
 
 
 
1970
  # Register cleanup handler for graceful shutdown
1971
  import atexit
1972
 
 
37
  from mosaic.model_manager import load_all_models
38
  from mosaic.hardware import DEFAULT_CONCURRENCY_LIMIT, IS_HF_SPACES, IS_T4_GPU, GPU_TYPE
39
  from mosaic.telemetry import extract_user_info
40
+ from mosaic.ui.oauth import mount_oauth_routes
41
  from mosaic.tcga import (
42
  compute_settings_hash,
43
  download_results_from_hf,
 
1934
  )
1935
  return (
1936
  gr.update(
1937
+ value=f"Signed in as **{username}** \u00b7 [Sign out](/api/auth/logout)",
1938
  visible=True,
1939
  ), # login_status_md
1940
  gr.update(visible=True), # user_storage_tabs
 
1944
  else:
1945
  return (
1946
  gr.update(
1947
+ value="[Sign in with HuggingFace](/api/auth/login) to save slides and results",
1948
  visible=True,
1949
  ), # login_status_md
1950
  gr.update(visible=False), # user_storage_tabs
 
1968
  # Higher-memory GPUs and ZeroGPU can handle multiple concurrent analyses
1969
  demo.queue(max_size=10, default_concurrency_limit=DEFAULT_CONCURRENCY_LIMIT)
1970
 
1971
+ # Mount custom OAuth routes for HF Spaces (sdk:docker)
1972
+ if IS_HF_SPACES:
1973
+ mount_oauth_routes(demo.app)
1974
+
1975
  # Register cleanup handler for graceful shutdown
1976
  import atexit
1977
 
src/mosaic/ui/oauth.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Manual OAuth flow for HF Spaces with Docker SDK.
2
+
3
+ On HF Spaces with sdk:docker, Gradio's built-in OAuth (gr.LoginButton /
4
+ gr.OAuthProfile) doesn't work because the HF reverse proxy injects the
5
+ Space owner's identity into every session. This module implements the
6
+ Authorization Code flow directly against HF's OAuth provider and writes
7
+ the visitor's real identity into the Starlette session so that Gradio's
8
+ existing OAuthProfile injection works transparently.
9
+
10
+ Environment variables (set automatically by HF Spaces):
11
+ OAUTH_CLIENT_ID: OAuth application client ID
12
+ OAUTH_CLIENT_SECRET: OAuth application client secret
13
+ OAUTH_SCOPES: Space-separated scopes (default: "openid profile")
14
+ SPACE_HOST: Public hostname of the Space (e.g. "user-space.hf.space")
15
+
16
+ Routes mounted on the Gradio ASGI app:
17
+ GET /api/auth/login -> Redirect to HF authorize endpoint
18
+ GET /api/auth/callback -> Exchange code for token, set session
19
+ GET /api/auth/logout -> Clear session, redirect to /
20
+ """
21
+
22
+ import json
23
+ import os
24
+ import secrets
25
+ import time
26
+ from typing import Optional
27
+ from urllib.parse import urlencode
28
+
29
+ from loguru import logger
30
+ from starlette.requests import Request
31
+ from starlette.responses import RedirectResponse, JSONResponse
32
+ from starlette.routing import Route
33
+
34
+ # ---------------------------------------------------------------------------
35
+ # Configuration
36
+ # ---------------------------------------------------------------------------
37
+
38
+ OAUTH_CLIENT_ID = os.environ.get("OAUTH_CLIENT_ID", "")
39
+ OAUTH_CLIENT_SECRET = os.environ.get("OAUTH_CLIENT_SECRET", "")
40
+ OAUTH_SCOPES = os.environ.get("OAUTH_SCOPES", "openid profile")
41
+ SPACE_HOST = os.environ.get("SPACE_HOST", "")
42
+
43
+ HF_AUTHORIZE_URL = "https://huggingface.co/oauth/authorize"
44
+ HF_TOKEN_URL = "https://huggingface.co/oauth/token"
45
+
46
+ # ---------------------------------------------------------------------------
47
+ # Server-side session store
48
+ # ---------------------------------------------------------------------------
49
+
50
+ # {cookie_value: {"userinfo": {...}, "created_at": float}}
51
+ _sessions: dict[str, dict] = {}
52
+ _SESSION_TTL_SEC = 24 * 60 * 60 # 24 hours
53
+ _SESSION_COOKIE = "mosaic_auth"
54
+
55
+ # In-memory CSRF state tokens: {state_value: created_at_float}
56
+ _pending_states: dict[str, float] = {}
57
+ _STATE_TTL_SEC = 10 * 60 # 10 minutes
58
+
59
+
60
+ def _prune_expired() -> None:
61
+ """Remove expired sessions and CSRF state tokens."""
62
+ now = time.time()
63
+ expired_sessions = [
64
+ k for k, v in _sessions.items() if now - v["created_at"] > _SESSION_TTL_SEC
65
+ ]
66
+ for k in expired_sessions:
67
+ del _sessions[k]
68
+ expired_states = [k for k, v in _pending_states.items() if now - v > _STATE_TTL_SEC]
69
+ for k in expired_states:
70
+ del _pending_states[k]
71
+
72
+
73
+ def get_user_from_server_session(request) -> Optional[dict]:
74
+ """Look up user info from the server-side session store.
75
+
76
+ Checks the ``mosaic_auth`` cookie in the request and returns the
77
+ stored userinfo dict, or None if the cookie is missing/expired.
78
+
79
+ Args:
80
+ request: Starlette/Gradio request object (needs .cookies or .headers)
81
+
82
+ Returns:
83
+ userinfo dict with at least ``preferred_username`` key, or None
84
+ """
85
+ _prune_expired()
86
+
87
+ cookie_val = None
88
+ # Gradio's gr.Request wraps a Starlette Request; try .cookies first
89
+ if hasattr(request, "cookies"):
90
+ cookies = request.cookies
91
+ if isinstance(cookies, dict):
92
+ cookie_val = cookies.get(_SESSION_COOKIE)
93
+ # Fallback: parse Cookie header
94
+ if cookie_val is None and hasattr(request, "headers"):
95
+ headers = request.headers
96
+ cookie_header = None
97
+ if isinstance(headers, dict):
98
+ cookie_header = headers.get("cookie", "")
99
+ elif hasattr(headers, "get"):
100
+ cookie_header = headers.get("cookie", "")
101
+ if cookie_header:
102
+ for part in cookie_header.split(";"):
103
+ part = part.strip()
104
+ if part.startswith(f"{_SESSION_COOKIE}="):
105
+ cookie_val = part[len(f"{_SESSION_COOKIE}=") :]
106
+ break
107
+
108
+ if not cookie_val:
109
+ return None
110
+
111
+ entry = _sessions.get(cookie_val)
112
+ if entry is None:
113
+ return None
114
+
115
+ if time.time() - entry["created_at"] > _SESSION_TTL_SEC:
116
+ del _sessions[cookie_val]
117
+ return None
118
+
119
+ return entry.get("userinfo")
120
+
121
+
122
+ # ---------------------------------------------------------------------------
123
+ # Route handlers
124
+ # ---------------------------------------------------------------------------
125
+
126
+
127
+ async def _login(request: Request):
128
+ """Redirect to HF OAuth authorize endpoint."""
129
+ if not OAUTH_CLIENT_ID or not SPACE_HOST:
130
+ return JSONResponse(
131
+ {"error": "OAuth not configured (missing OAUTH_CLIENT_ID or SPACE_HOST)"},
132
+ status_code=500,
133
+ )
134
+
135
+ _prune_expired()
136
+ state = secrets.token_urlsafe(32)
137
+ _pending_states[state] = time.time()
138
+
139
+ # Build redirect URI pointing back to our callback
140
+ redirect_uri = f"https://{SPACE_HOST}/api/auth/callback"
141
+
142
+ params = {
143
+ "client_id": OAUTH_CLIENT_ID,
144
+ "redirect_uri": redirect_uri,
145
+ "response_type": "code",
146
+ "scope": OAUTH_SCOPES,
147
+ "state": state,
148
+ }
149
+ authorize_url = f"{HF_AUTHORIZE_URL}?{urlencode(params)}"
150
+ logger.info(f"OAuth login: redirecting to HF authorize endpoint")
151
+ return RedirectResponse(authorize_url, status_code=302)
152
+
153
+
154
+ async def _callback(request: Request):
155
+ """Handle OAuth callback: exchange code for token, set session."""
156
+ import httpx
157
+
158
+ code = request.query_params.get("code")
159
+ state = request.query_params.get("state")
160
+
161
+ if not code or not state:
162
+ return JSONResponse(
163
+ {"error": "Missing code or state parameter"}, status_code=400
164
+ )
165
+
166
+ # Validate CSRF state
167
+ _prune_expired()
168
+ created_at = _pending_states.pop(state, None)
169
+ if created_at is None:
170
+ return JSONResponse(
171
+ {"error": "Invalid or expired state parameter"}, status_code=400
172
+ )
173
+
174
+ # Exchange authorization code for access token
175
+ redirect_uri = f"https://{SPACE_HOST}/api/auth/callback"
176
+ token_data = {
177
+ "grant_type": "authorization_code",
178
+ "code": code,
179
+ "redirect_uri": redirect_uri,
180
+ "client_id": OAUTH_CLIENT_ID,
181
+ "client_secret": OAUTH_CLIENT_SECRET,
182
+ }
183
+
184
+ try:
185
+ async with httpx.AsyncClient() as client:
186
+ resp = await client.post(
187
+ HF_TOKEN_URL,
188
+ data=token_data,
189
+ headers={"Content-Type": "application/x-www-form-urlencoded"},
190
+ )
191
+ resp.raise_for_status()
192
+ token_response = resp.json()
193
+ except httpx.HTTPStatusError as e:
194
+ logger.error(
195
+ f"OAuth token exchange failed: {e.response.status_code} {e.response.text}"
196
+ )
197
+ return JSONResponse({"error": "Token exchange failed"}, status_code=502)
198
+ except Exception as e:
199
+ logger.error(f"OAuth token exchange error: {e}")
200
+ return JSONResponse({"error": "Token exchange failed"}, status_code=502)
201
+
202
+ # Extract userinfo from the id_token (JWT) or use the userinfo endpoint
203
+ access_token = token_response.get("access_token")
204
+ userinfo = None
205
+
206
+ # Try to decode the id_token (JWT) for userinfo
207
+ id_token = token_response.get("id_token")
208
+ if id_token:
209
+ try:
210
+ # JWT is base64url-encoded: header.payload.signature
211
+ # We only need the payload (claims) — no signature verification
212
+ # since we just received this directly from HF's token endpoint
213
+ import base64
214
+
215
+ payload_b64 = id_token.split(".")[1]
216
+ # Add padding if needed
217
+ padding = 4 - len(payload_b64) % 4
218
+ if padding != 4:
219
+ payload_b64 += "=" * padding
220
+ payload_bytes = base64.urlsafe_b64decode(payload_b64)
221
+ userinfo = json.loads(payload_bytes)
222
+ except Exception as e:
223
+ logger.warning(f"Failed to decode id_token: {e}")
224
+
225
+ # Fallback: call userinfo endpoint
226
+ if userinfo is None and access_token:
227
+ try:
228
+ async with httpx.AsyncClient() as client:
229
+ resp = await client.get(
230
+ "https://huggingface.co/oauth/userinfo",
231
+ headers={"Authorization": f"Bearer {access_token}"},
232
+ )
233
+ resp.raise_for_status()
234
+ userinfo = resp.json()
235
+ except Exception as e:
236
+ logger.error(f"Failed to fetch userinfo: {e}")
237
+ return JSONResponse({"error": "Failed to get user info"}, status_code=502)
238
+
239
+ if not userinfo:
240
+ return JSONResponse({"error": "No user info received"}, status_code=502)
241
+
242
+ # Extract username
243
+ username = userinfo.get("preferred_username") or userinfo.get("sub", "unknown")
244
+ logger.info(f"OAuth callback: authenticated user '{username}'")
245
+
246
+ # Write into Gradio's Starlette session so OAuthProfile picks it up
247
+ # Gradio expects session["oauth_info"]["userinfo"] with at least
248
+ # "preferred_username" and optionally "name", "picture", etc.
249
+ oauth_info = {
250
+ "userinfo": userinfo,
251
+ "access_token": access_token,
252
+ }
253
+ request.session["oauth_info"] = oauth_info
254
+
255
+ # Also store in our server-side session as fallback
256
+ session_id = secrets.token_urlsafe(32)
257
+ _sessions[session_id] = {
258
+ "userinfo": userinfo,
259
+ "created_at": time.time(),
260
+ }
261
+
262
+ # Build redirect response with cookie
263
+ response = RedirectResponse("/", status_code=302)
264
+ response.set_cookie(
265
+ _SESSION_COOKIE,
266
+ session_id,
267
+ httponly=True,
268
+ secure=True,
269
+ samesite="lax",
270
+ max_age=_SESSION_TTL_SEC,
271
+ )
272
+ return response
273
+
274
+
275
+ async def _logout(request: Request):
276
+ """Clear session and redirect to /."""
277
+ # Clear Gradio's session
278
+ if "oauth_info" in request.session:
279
+ del request.session["oauth_info"]
280
+
281
+ # Clear server-side session
282
+ cookie_val = request.cookies.get(_SESSION_COOKIE)
283
+ if cookie_val and cookie_val in _sessions:
284
+ del _sessions[cookie_val]
285
+
286
+ response = RedirectResponse("/", status_code=302)
287
+ response.delete_cookie(_SESSION_COOKIE)
288
+ return response
289
+
290
+
291
+ # ---------------------------------------------------------------------------
292
+ # Mount helper
293
+ # ---------------------------------------------------------------------------
294
+
295
+ _oauth_routes = [
296
+ Route("/api/auth/login", _login, methods=["GET"]),
297
+ Route("/api/auth/callback", _callback, methods=["GET"]),
298
+ Route("/api/auth/logout", _logout, methods=["GET"]),
299
+ ]
300
+
301
+
302
+ def mount_oauth_routes(app) -> None:
303
+ """Mount custom OAuth routes on the Gradio ASGI app.
304
+
305
+ Should be called after ``demo.queue()`` and before ``demo.launch()``.
306
+
307
+ Args:
308
+ app: The Starlette/FastAPI app (``demo.app``)
309
+ """
310
+ if not OAUTH_CLIENT_ID:
311
+ logger.warning(
312
+ "OAuth routes not mounted: OAUTH_CLIENT_ID not set. "
313
+ "Custom login will not work."
314
+ )
315
+ return
316
+
317
+ # Insert our routes at the beginning so they take priority
318
+ app.routes[:0] = _oauth_routes
319
+ logger.info(
320
+ f"Mounted custom OAuth routes: /api/auth/login, /api/auth/callback, /api/auth/logout"
321
+ )
src/mosaic/ui/user_tabs.py CHANGED
@@ -54,6 +54,19 @@ def _get_username(request: gr.Request = None, profile=None) -> tuple[str, bool]:
54
  except Exception:
55
  pass
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  return (None, False)
58
  else:
59
  # Local mode - use debug username
 
54
  except Exception:
55
  pass
56
 
57
+ # Fallback: check server-side session store (custom OAuth flow)
58
+ if request is not None:
59
+ try:
60
+ from mosaic.ui.oauth import get_user_from_server_session
61
+
62
+ userinfo = get_user_from_server_session(request)
63
+ if userinfo:
64
+ username = userinfo.get("preferred_username")
65
+ if username:
66
+ return (username, False)
67
+ except Exception:
68
+ pass
69
+
70
  return (None, False)
71
  else:
72
  # Local mode - use debug username
tests/telemetry/test_tracker.py CHANGED
@@ -8,7 +8,6 @@ from pathlib import Path
8
  import pytest
9
 
10
  from mosaic.telemetry import TelemetryTracker, TelemetryConfig
11
- from mosaic.telemetry.utils import hash_username
12
 
13
 
14
  @pytest.fixture
@@ -153,7 +152,7 @@ class TestUsageEvents:
153
  event = json.loads(f.read().strip())
154
 
155
  assert event["is_logged_in"] is True
156
- assert event["hf_username"] == hash_username("testuser")
157
 
158
  def test_log_analysis_complete(self, tracker, temp_dir):
159
  """Test logging analysis complete event."""
@@ -287,7 +286,7 @@ class TestResourceEvents:
287
  event = json.loads(f.read().strip())
288
 
289
  assert event["is_logged_in"] is True
290
- assert event["hf_username"] == hash_username("testuser")
291
 
292
 
293
  class TestFailureEvents:
 
8
  import pytest
9
 
10
  from mosaic.telemetry import TelemetryTracker, TelemetryConfig
 
11
 
12
 
13
  @pytest.fixture
 
152
  event = json.loads(f.read().strip())
153
 
154
  assert event["is_logged_in"] is True
155
+ assert event["hf_username"] == "testuser"
156
 
157
  def test_log_analysis_complete(self, tracker, temp_dir):
158
  """Test logging analysis complete event."""
 
286
  event = json.loads(f.read().strip())
287
 
288
  assert event["is_logged_in"] is True
289
+ assert event["hf_username"] == "testuser"
290
 
291
 
292
  class TestFailureEvents:
tests/telemetry/test_utils.py CHANGED
@@ -8,7 +8,6 @@ from mosaic.telemetry.utils import (
8
  StageTimer,
9
  sanitize_error_message,
10
  hash_session_id,
11
- hash_username,
12
  UserInfo,
13
  extract_user_info,
14
  )
@@ -151,40 +150,6 @@ class TestHashSessionId:
151
  assert len(set(hashes)) == 1 # All hashes should be identical
152
 
153
 
154
- class TestHashUsername:
155
- """Tests for username hashing."""
156
-
157
- def test_hash_username(self):
158
- """Test basic username hashing."""
159
- hashed = hash_username("testuser")
160
- assert hashed is not None
161
- assert hashed != "testuser"
162
- assert len(hashed) == 16
163
-
164
- def test_hash_none_returns_none(self):
165
- """Test that None input returns None."""
166
- assert hash_username(None) is None
167
-
168
- def test_hash_is_deterministic(self):
169
- """Test that same input produces same hash."""
170
- hash1 = hash_username("alice")
171
- hash2 = hash_username("alice")
172
- assert hash1 == hash2
173
-
174
- def test_different_inputs_different_hashes(self):
175
- """Test that different usernames produce different hashes."""
176
- hash1 = hash_username("alice")
177
- hash2 = hash_username("bob")
178
- assert hash1 != hash2
179
-
180
- def test_different_salt_from_session_id(self):
181
- """Test that username hash uses different salt than session hash."""
182
- value = "same_value"
183
- username_hash = hash_username(value)
184
- session_hash = hash_session_id(value)
185
- assert username_hash != session_hash
186
-
187
-
188
  class TestUserInfo:
189
  """Tests for UserInfo dataclass."""
190
 
 
8
  StageTimer,
9
  sanitize_error_message,
10
  hash_session_id,
 
11
  UserInfo,
12
  extract_user_info,
13
  )
 
150
  assert len(set(hashes)) == 1 # All hashes should be identical
151
 
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  class TestUserInfo:
154
  """Tests for UserInfo dataclass."""
155
 
tests/test_oauth.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for the custom OAuth module (mosaic.ui.oauth)."""
2
+
3
+ import time
4
+ from unittest.mock import AsyncMock, MagicMock, patch
5
+
6
+ import pytest
7
+
8
+ from mosaic.ui.oauth import (
9
+ _prune_expired,
10
+ _sessions,
11
+ _pending_states,
12
+ _SESSION_COOKIE,
13
+ _SESSION_TTL_SEC,
14
+ _STATE_TTL_SEC,
15
+ get_user_from_server_session,
16
+ mount_oauth_routes,
17
+ )
18
+
19
+
20
+ @pytest.fixture(autouse=True)
21
+ def clean_sessions():
22
+ """Clear session and state stores before each test."""
23
+ _sessions.clear()
24
+ _pending_states.clear()
25
+ yield
26
+ _sessions.clear()
27
+ _pending_states.clear()
28
+
29
+
30
+ # ---------------------------------------------------------------------------
31
+ # Session store tests
32
+ # ---------------------------------------------------------------------------
33
+
34
+
35
+ class TestSessionStore:
36
+ """Tests for the server-side session store."""
37
+
38
+ def test_store_and_retrieve_session(self):
39
+ """Test basic session CRUD."""
40
+ _sessions["abc123"] = {
41
+ "userinfo": {"preferred_username": "alice"},
42
+ "created_at": time.time(),
43
+ }
44
+
45
+ request = MagicMock()
46
+ request.cookies = {_SESSION_COOKIE: "abc123"}
47
+ request.headers = {}
48
+
49
+ userinfo = get_user_from_server_session(request)
50
+ assert userinfo is not None
51
+ assert userinfo["preferred_username"] == "alice"
52
+
53
+ def test_missing_cookie_returns_none(self):
54
+ """Test that missing cookie returns None."""
55
+ request = MagicMock()
56
+ request.cookies = {}
57
+ request.headers = {}
58
+
59
+ assert get_user_from_server_session(request) is None
60
+
61
+ def test_unknown_session_id_returns_none(self):
62
+ """Test that unknown session ID returns None."""
63
+ request = MagicMock()
64
+ request.cookies = {_SESSION_COOKIE: "unknown"}
65
+ request.headers = {}
66
+
67
+ assert get_user_from_server_session(request) is None
68
+
69
+ def test_expired_session_returns_none(self):
70
+ """Test that expired sessions are pruned."""
71
+ _sessions["expired"] = {
72
+ "userinfo": {"preferred_username": "bob"},
73
+ "created_at": time.time() - _SESSION_TTL_SEC - 1,
74
+ }
75
+
76
+ request = MagicMock()
77
+ request.cookies = {_SESSION_COOKIE: "expired"}
78
+ request.headers = {}
79
+
80
+ assert get_user_from_server_session(request) is None
81
+ assert "expired" not in _sessions
82
+
83
+ def test_cookie_from_header_fallback(self):
84
+ """Test parsing cookie from raw Cookie header."""
85
+ _sessions["header_val"] = {
86
+ "userinfo": {"preferred_username": "carol"},
87
+ "created_at": time.time(),
88
+ }
89
+
90
+ request = MagicMock()
91
+ request.cookies = {} # Empty dict, fallback to header
92
+ request.headers = {"cookie": f"other=x; {_SESSION_COOKIE}=header_val; foo=bar"}
93
+
94
+ userinfo = get_user_from_server_session(request)
95
+ assert userinfo is not None
96
+ assert userinfo["preferred_username"] == "carol"
97
+
98
+ def test_none_request_returns_none(self):
99
+ """Test that None request returns None gracefully."""
100
+ # get_user_from_server_session checks hasattr, so pass an object
101
+ # without cookies or headers
102
+ assert get_user_from_server_session(None) is None
103
+
104
+
105
+ # ---------------------------------------------------------------------------
106
+ # Prune tests
107
+ # ---------------------------------------------------------------------------
108
+
109
+
110
+ class TestPruneExpired:
111
+ """Tests for _prune_expired."""
112
+
113
+ def test_prune_expired_sessions(self):
114
+ """Test that expired sessions are removed."""
115
+ _sessions["old"] = {
116
+ "userinfo": {"preferred_username": "old_user"},
117
+ "created_at": time.time() - _SESSION_TTL_SEC - 100,
118
+ }
119
+ _sessions["fresh"] = {
120
+ "userinfo": {"preferred_username": "new_user"},
121
+ "created_at": time.time(),
122
+ }
123
+
124
+ _prune_expired()
125
+
126
+ assert "old" not in _sessions
127
+ assert "fresh" in _sessions
128
+
129
+ def test_prune_expired_states(self):
130
+ """Test that expired CSRF states are removed."""
131
+ _pending_states["old_state"] = time.time() - _STATE_TTL_SEC - 100
132
+ _pending_states["fresh_state"] = time.time()
133
+
134
+ _prune_expired()
135
+
136
+ assert "old_state" not in _pending_states
137
+ assert "fresh_state" in _pending_states
138
+
139
+
140
+ # ---------------------------------------------------------------------------
141
+ # mount_oauth_routes tests
142
+ # ---------------------------------------------------------------------------
143
+
144
+
145
+ class TestMountOAuthRoutes:
146
+ """Tests for mount_oauth_routes."""
147
+
148
+ @patch("mosaic.ui.oauth.OAUTH_CLIENT_ID", "")
149
+ def test_no_routes_when_client_id_missing(self):
150
+ """Test that routes are not mounted when OAUTH_CLIENT_ID is empty."""
151
+ app = MagicMock()
152
+ app.routes = []
153
+
154
+ mount_oauth_routes(app)
155
+
156
+ assert len(app.routes) == 0
157
+
158
+ @patch("mosaic.ui.oauth.OAUTH_CLIENT_ID", "test-client-id")
159
+ def test_routes_mounted_when_configured(self):
160
+ """Test that routes are mounted when OAuth is configured."""
161
+ app = MagicMock()
162
+ app.routes = []
163
+
164
+ mount_oauth_routes(app)
165
+
166
+ assert len(app.routes) == 3
167
+ paths = [r.path for r in app.routes]
168
+ assert "/api/auth/login" in paths
169
+ assert "/api/auth/callback" in paths
170
+ assert "/api/auth/logout" in paths
171
+
172
+
173
+ # ---------------------------------------------------------------------------
174
+ # Login route tests
175
+ # ---------------------------------------------------------------------------
176
+
177
+
178
+ class TestLoginRoute:
179
+ """Tests for the /api/auth/login route."""
180
+
181
+ @pytest.mark.asyncio
182
+ @patch("mosaic.ui.oauth.OAUTH_CLIENT_ID", "test-client-id")
183
+ @patch("mosaic.ui.oauth.SPACE_HOST", "user-space.hf.space")
184
+ async def test_login_redirects_to_hf(self):
185
+ """Test that login redirects to HF authorize URL."""
186
+ from mosaic.ui.oauth import _login
187
+
188
+ request = MagicMock()
189
+ response = await _login(request)
190
+
191
+ assert response.status_code == 302
192
+ location = response.headers["location"]
193
+ assert "huggingface.co/oauth/authorize" in location
194
+ assert "client_id=test-client-id" in location
195
+ assert "redirect_uri=" in location
196
+ assert "state=" in location
197
+ # Should have stored a CSRF state
198
+ assert len(_pending_states) == 1
199
+
200
+ @pytest.mark.asyncio
201
+ @patch("mosaic.ui.oauth.OAUTH_CLIENT_ID", "")
202
+ async def test_login_returns_error_when_not_configured(self):
203
+ """Test that login returns error when OAuth is not configured."""
204
+ from mosaic.ui.oauth import _login
205
+
206
+ request = MagicMock()
207
+ response = await _login(request)
208
+
209
+ assert response.status_code == 500
210
+
211
+
212
+ # ---------------------------------------------------------------------------
213
+ # Callback route tests
214
+ # ---------------------------------------------------------------------------
215
+
216
+
217
+ class TestCallbackRoute:
218
+ """Tests for the /api/auth/callback route."""
219
+
220
+ @pytest.mark.asyncio
221
+ async def test_callback_missing_params(self):
222
+ """Test callback with missing code/state parameters."""
223
+ from mosaic.ui.oauth import _callback
224
+
225
+ request = MagicMock()
226
+ request.query_params = {}
227
+
228
+ response = await _callback(request)
229
+ assert response.status_code == 400
230
+
231
+ @pytest.mark.asyncio
232
+ async def test_callback_invalid_state(self):
233
+ """Test callback with invalid CSRF state."""
234
+ from mosaic.ui.oauth import _callback
235
+
236
+ request = MagicMock()
237
+ request.query_params = {"code": "test-code", "state": "invalid-state"}
238
+
239
+ response = await _callback(request)
240
+ assert response.status_code == 400
241
+
242
+ @pytest.mark.asyncio
243
+ @patch("mosaic.ui.oauth.OAUTH_CLIENT_ID", "test-client-id")
244
+ @patch("mosaic.ui.oauth.OAUTH_CLIENT_SECRET", "test-secret")
245
+ @patch("mosaic.ui.oauth.SPACE_HOST", "user-space.hf.space")
246
+ async def test_callback_exchanges_code_for_token(self):
247
+ """Test successful callback with mocked token exchange."""
248
+ import json
249
+ import base64
250
+
251
+ from mosaic.ui.oauth import _callback
252
+
253
+ # Set up a valid CSRF state
254
+ valid_state = "valid-state-token"
255
+ _pending_states[valid_state] = time.time()
256
+
257
+ # Create a mock id_token JWT
258
+ header = base64.urlsafe_b64encode(b'{"alg":"RS256"}').rstrip(b"=").decode()
259
+ payload_data = json.dumps(
260
+ {"preferred_username": "visitor123", "sub": "visitor123"}
261
+ )
262
+ payload = base64.urlsafe_b64encode(payload_data.encode()).rstrip(b"=").decode()
263
+ mock_id_token = f"{header}.{payload}.fake_signature"
264
+
265
+ # Mock httpx — json() is a regular method (not async) on httpx.Response
266
+ mock_response = AsyncMock()
267
+ mock_response.json = MagicMock(
268
+ return_value={
269
+ "access_token": "test-access-token",
270
+ "id_token": mock_id_token,
271
+ }
272
+ )
273
+ mock_response.raise_for_status = MagicMock()
274
+
275
+ mock_client = AsyncMock()
276
+ mock_client.post.return_value = mock_response
277
+ mock_client.__aenter__ = AsyncMock(return_value=mock_client)
278
+ mock_client.__aexit__ = AsyncMock(return_value=False)
279
+
280
+ request = MagicMock()
281
+ request.query_params = {"code": "auth-code", "state": valid_state}
282
+ request.session = {}
283
+
284
+ with patch("httpx.AsyncClient", return_value=mock_client):
285
+ response = await _callback(request)
286
+
287
+ # Should redirect to /
288
+ assert response.status_code == 302
289
+
290
+ # Session should have oauth_info
291
+ assert "oauth_info" in request.session
292
+ assert (
293
+ request.session["oauth_info"]["userinfo"]["preferred_username"]
294
+ == "visitor123"
295
+ )
296
+
297
+ # Server-side session should be stored
298
+ assert len(_sessions) == 1
299
+
300
+ # CSRF state should be consumed
301
+ assert valid_state not in _pending_states
302
+
303
+
304
+ # ---------------------------------------------------------------------------
305
+ # Logout route tests
306
+ # ---------------------------------------------------------------------------
307
+
308
+
309
+ class TestLogoutRoute:
310
+ """Tests for the /api/auth/logout route."""
311
+
312
+ @pytest.mark.asyncio
313
+ async def test_logout_clears_session(self):
314
+ """Test that logout clears session data."""
315
+ from mosaic.ui.oauth import _logout
316
+
317
+ # Set up server-side session
318
+ _sessions["session_id"] = {
319
+ "userinfo": {"preferred_username": "user"},
320
+ "created_at": time.time(),
321
+ }
322
+
323
+ request = MagicMock()
324
+ request.session = {"oauth_info": {"userinfo": {"preferred_username": "user"}}}
325
+ request.cookies = {_SESSION_COOKIE: "session_id"}
326
+
327
+ response = await _logout(request)
328
+
329
+ assert response.status_code == 302
330
+ assert "oauth_info" not in request.session
331
+ assert "session_id" not in _sessions
332
+
333
+
334
+ # ---------------------------------------------------------------------------
335
+ # get_user_from_server_session integration
336
+ # ---------------------------------------------------------------------------
337
+
338
+
339
+ class TestGetUserFromServerSession:
340
+ """Integration tests for the server-session fallback."""
341
+
342
+ def test_works_with_gradio_request_like_object(self):
343
+ """Test with an object mimicking gr.Request."""
344
+ _sessions["gr_session"] = {
345
+ "userinfo": {"preferred_username": "gradio_user", "name": "Test User"},
346
+ "created_at": time.time(),
347
+ }
348
+
349
+ class MockGradioRequest:
350
+ cookies = {_SESSION_COOKIE: "gr_session"}
351
+ headers = {}
352
+
353
+ result = get_user_from_server_session(MockGradioRequest())
354
+ assert result is not None
355
+ assert result["preferred_username"] == "gradio_user"
356
+ assert result["name"] == "Test User"
357
+
358
+ def test_returns_none_for_object_without_cookies(self):
359
+ """Test graceful handling of objects without cookies."""
360
+
361
+ class Bare:
362
+ pass
363
+
364
+ assert get_user_from_server_session(Bare()) is None
uv.lock CHANGED
The diff for this file is too large to render. See raw diff