feat: implement manual OAuth for HF Spaces Docker SDK
Browse files
On HF Spaces with sdk:docker, Gradio's built-in OAuth injects the Space
owner's identity into every session. This adds custom /api/auth/* routes
that perform the full Authorization Code flow directly with HF's OAuth
provider, writing the correct visitor identity into the Starlette session
so Gradio's OAuthProfile injection works transparently.
- Add src/mosaic/ui/oauth.py with login/callback/logout routes and
server-side session store (24h TTL, mosaic_auth cookie fallback)
- Mount custom OAuth routes on Gradio app when IS_HF_SPACES
- Update login/logout links to /api/auth/login and /api/auth/logout
- Add server-session fallback to extract_user_info() and _get_username()
- Stop hashing usernames in telemetry (use raw HF usernames)
- Add pyproject.toml [tool.uv] override-dependencies for macOS CPU torch compatibility
- Add tests/test_oauth.py with 18 tests covering the full OAuth flow
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- pyproject.toml +7 -0
- src/mosaic/telemetry/__init__.py +0 -2
- src/mosaic/telemetry/tracker.py +3 -4
- src/mosaic/telemetry/utils.py +15 -20
- src/mosaic/ui/app.py +7 -2
- src/mosaic/ui/oauth.py +321 -0
- src/mosaic/ui/user_tabs.py +13 -0
- tests/telemetry/test_tracker.py +2 -3
- tests/telemetry/test_utils.py +0 -35
- tests/test_oauth.py +364 -0
- uv.lock +0 -0
|
@@ -42,6 +42,13 @@ disable = [
|
|
| 42 |
"unspecified-encoding",
|
| 43 |
]
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
[tool.uv.sources]
|
| 46 |
# For local dev with SSH: uv pip install -e .
|
| 47 |
# For Docker builds with token: GH_TOKEN=<token> uv pip install -e .
|
|
|
|
| 42 |
"unspecified-encoding",
|
| 43 |
]
|
| 44 |
|
| 45 |
+
[tool.uv]
|
| 46 |
+
# Override PyTorch dependencies from mussel[torch-gpu] for macOS compatibility
|
| 47 |
+
override-dependencies = [
|
| 48 |
+
"torch>=2.0.0; sys_platform == 'darwin'",
|
| 49 |
+
"torchvision>=0.15.0; sys_platform == 'darwin'",
|
| 50 |
+
]
|
| 51 |
+
|
| 52 |
[tool.uv.sources]
|
| 53 |
# For local dev with SSH: uv pip install -e .
|
| 54 |
# For Docker builds with token: GH_TOKEN=<token> uv pip install -e .
|
|
@@ -47,7 +47,6 @@ from mosaic.telemetry.utils import (
|
|
| 47 |
StageTimer,
|
| 48 |
sanitize_error_message,
|
| 49 |
hash_session_id,
|
| 50 |
-
hash_username,
|
| 51 |
UserInfo,
|
| 52 |
extract_user_info,
|
| 53 |
)
|
|
@@ -68,7 +67,6 @@ __all__ = [
|
|
| 68 |
"StageTimer",
|
| 69 |
"sanitize_error_message",
|
| 70 |
"hash_session_id",
|
| 71 |
-
"hash_username",
|
| 72 |
"UserInfo",
|
| 73 |
"extract_user_info",
|
| 74 |
]
|
|
|
|
| 47 |
StageTimer,
|
| 48 |
sanitize_error_message,
|
| 49 |
hash_session_id,
|
|
|
|
| 50 |
UserInfo,
|
| 51 |
extract_user_info,
|
| 52 |
)
|
|
|
|
| 67 |
"StageTimer",
|
| 68 |
"sanitize_error_message",
|
| 69 |
"hash_session_id",
|
|
|
|
| 70 |
"UserInfo",
|
| 71 |
"extract_user_info",
|
| 72 |
]
|
|
@@ -22,7 +22,6 @@ from mosaic.telemetry.events import (
|
|
| 22 |
from mosaic.telemetry.storage import TelemetryStorage
|
| 23 |
from mosaic.telemetry.utils import (
|
| 24 |
hash_session_id,
|
| 25 |
-
hash_username,
|
| 26 |
sanitize_error_message,
|
| 27 |
)
|
| 28 |
|
|
@@ -256,7 +255,7 @@ class TelemetryTracker:
|
|
| 256 |
success=success,
|
| 257 |
cached_slide_count=cached_slide_count,
|
| 258 |
is_logged_in=is_logged_in,
|
| 259 |
-
hf_username=
|
| 260 |
)
|
| 261 |
self.storage.write_usage_event(event)
|
| 262 |
|
|
@@ -332,7 +331,7 @@ class TelemetryTracker:
|
|
| 332 |
gpu_type=gpu_type,
|
| 333 |
peak_gpu_memory_gb=peak_gpu_memory_gb,
|
| 334 |
is_logged_in=is_logged_in,
|
| 335 |
-
hf_username=
|
| 336 |
)
|
| 337 |
self.storage.write_resource_event(event)
|
| 338 |
|
|
@@ -377,7 +376,7 @@ class TelemetryTracker:
|
|
| 377 |
slide_count=slide_count,
|
| 378 |
gpu_type=gpu_type,
|
| 379 |
is_logged_in=is_logged_in,
|
| 380 |
-
hf_username=
|
| 381 |
)
|
| 382 |
self.storage.write_failure_event(event)
|
| 383 |
|
|
|
|
| 22 |
from mosaic.telemetry.storage import TelemetryStorage
|
| 23 |
from mosaic.telemetry.utils import (
|
| 24 |
hash_session_id,
|
|
|
|
| 25 |
sanitize_error_message,
|
| 26 |
)
|
| 27 |
|
|
|
|
| 255 |
success=success,
|
| 256 |
cached_slide_count=cached_slide_count,
|
| 257 |
is_logged_in=is_logged_in,
|
| 258 |
+
hf_username=hf_username,
|
| 259 |
)
|
| 260 |
self.storage.write_usage_event(event)
|
| 261 |
|
|
|
|
| 331 |
gpu_type=gpu_type,
|
| 332 |
peak_gpu_memory_gb=peak_gpu_memory_gb,
|
| 333 |
is_logged_in=is_logged_in,
|
| 334 |
+
hf_username=hf_username,
|
| 335 |
)
|
| 336 |
self.storage.write_resource_event(event)
|
| 337 |
|
|
|
|
| 376 |
slide_count=slide_count,
|
| 377 |
gpu_type=gpu_type,
|
| 378 |
is_logged_in=is_logged_in,
|
| 379 |
+
hf_username=hf_username,
|
| 380 |
)
|
| 381 |
self.storage.write_failure_event(event)
|
| 382 |
|
|
@@ -5,7 +5,7 @@ This module provides helper utilities:
|
|
| 5 |
- sanitize_error_message: Remove sensitive data from error messages
|
| 6 |
- hash_session_id: Hash session IDs for privacy
|
| 7 |
- UserInfo: Dataclass for user information from HF Spaces
|
| 8 |
-
- extract_user_info: Extract user info from Gradio request
|
| 9 |
"""
|
| 10 |
|
| 11 |
import hashlib
|
|
@@ -99,25 +99,6 @@ def hash_session_id(session_id: Optional[str]) -> Optional[str]:
|
|
| 99 |
return hashlib.sha256(salted.encode()).hexdigest()[:16]
|
| 100 |
|
| 101 |
|
| 102 |
-
def hash_username(username: Optional[str]) -> Optional[str]:
|
| 103 |
-
"""Hash a username for privacy in telemetry.
|
| 104 |
-
|
| 105 |
-
Uses SHA-256 with a different salt than session IDs to create a one-way hash.
|
| 106 |
-
This allows distinguishing users in telemetry without storing actual usernames.
|
| 107 |
-
|
| 108 |
-
Args:
|
| 109 |
-
username: HuggingFace username (can be None for anonymous users)
|
| 110 |
-
|
| 111 |
-
Returns:
|
| 112 |
-
Hashed username or None if input is None
|
| 113 |
-
"""
|
| 114 |
-
if username is None:
|
| 115 |
-
return None
|
| 116 |
-
|
| 117 |
-
salted = f"mosaic_user:{username}"
|
| 118 |
-
return hashlib.sha256(salted.encode()).hexdigest()[:16]
|
| 119 |
-
|
| 120 |
-
|
| 121 |
@dataclass
|
| 122 |
class UserInfo:
|
| 123 |
"""User information extracted from HF Spaces request.
|
|
@@ -169,5 +150,19 @@ def extract_user_info(request, is_hf_spaces: bool = False, profile=None) -> User
|
|
| 169 |
except Exception as e:
|
| 170 |
logger.debug(f"Could not extract username from OAuthProfile: {e}")
|
| 171 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
logger.debug("User not logged in: no OAuthProfile available")
|
| 173 |
return UserInfo()
|
|
|
|
| 5 |
- sanitize_error_message: Remove sensitive data from error messages
|
| 6 |
- hash_session_id: Hash session IDs for privacy
|
| 7 |
- UserInfo: Dataclass for user information from HF Spaces
|
| 8 |
+
- extract_user_info: Extract user info from Gradio request/OAuth profile
|
| 9 |
"""
|
| 10 |
|
| 11 |
import hashlib
|
|
|
|
| 99 |
return hashlib.sha256(salted.encode()).hexdigest()[:16]
|
| 100 |
|
| 101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
@dataclass
|
| 103 |
class UserInfo:
|
| 104 |
"""User information extracted from HF Spaces request.
|
|
|
|
| 150 |
except Exception as e:
|
| 151 |
logger.debug(f"Could not extract username from OAuthProfile: {e}")
|
| 152 |
|
| 153 |
+
# Fallback: check server-side session store (custom OAuth flow)
|
| 154 |
+
if request is not None:
|
| 155 |
+
try:
|
| 156 |
+
from mosaic.ui.oauth import get_user_from_server_session
|
| 157 |
+
|
| 158 |
+
userinfo = get_user_from_server_session(request)
|
| 159 |
+
if userinfo:
|
| 160 |
+
username = userinfo.get("preferred_username")
|
| 161 |
+
if username:
|
| 162 |
+
logger.info(f"Extracted user from server-side session: {username}")
|
| 163 |
+
return UserInfo(is_logged_in=True, username=username)
|
| 164 |
+
except Exception as e:
|
| 165 |
+
logger.debug(f"Server-side session lookup failed: {e}")
|
| 166 |
+
|
| 167 |
logger.debug("User not logged in: no OAuthProfile available")
|
| 168 |
return UserInfo()
|
|
@@ -37,6 +37,7 @@ from mosaic.data_directory import get_tcga_cache_directory
|
|
| 37 |
from mosaic.model_manager import load_all_models
|
| 38 |
from mosaic.hardware import DEFAULT_CONCURRENCY_LIMIT, IS_HF_SPACES, IS_T4_GPU, GPU_TYPE
|
| 39 |
from mosaic.telemetry import extract_user_info
|
|
|
|
| 40 |
from mosaic.tcga import (
|
| 41 |
compute_settings_hash,
|
| 42 |
download_results_from_hf,
|
|
@@ -1933,7 +1934,7 @@ This tool is for research purposes only and not approved for clinical diagnosis.
|
|
| 1933 |
)
|
| 1934 |
return (
|
| 1935 |
gr.update(
|
| 1936 |
-
value=f"Signed in as **{username}** \u00b7 [Sign out](/logout)",
|
| 1937 |
visible=True,
|
| 1938 |
), # login_status_md
|
| 1939 |
gr.update(visible=True), # user_storage_tabs
|
|
@@ -1943,7 +1944,7 @@ This tool is for research purposes only and not approved for clinical diagnosis.
|
|
| 1943 |
else:
|
| 1944 |
return (
|
| 1945 |
gr.update(
|
| 1946 |
-
value="[Sign in with HuggingFace](/
|
| 1947 |
visible=True,
|
| 1948 |
), # login_status_md
|
| 1949 |
gr.update(visible=False), # user_storage_tabs
|
|
@@ -1967,6 +1968,10 @@ This tool is for research purposes only and not approved for clinical diagnosis.
|
|
| 1967 |
# Higher-memory GPUs and ZeroGPU can handle multiple concurrent analyses
|
| 1968 |
demo.queue(max_size=10, default_concurrency_limit=DEFAULT_CONCURRENCY_LIMIT)
|
| 1969 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1970 |
# Register cleanup handler for graceful shutdown
|
| 1971 |
import atexit
|
| 1972 |
|
|
|
|
| 37 |
from mosaic.model_manager import load_all_models
|
| 38 |
from mosaic.hardware import DEFAULT_CONCURRENCY_LIMIT, IS_HF_SPACES, IS_T4_GPU, GPU_TYPE
|
| 39 |
from mosaic.telemetry import extract_user_info
|
| 40 |
+
from mosaic.ui.oauth import mount_oauth_routes
|
| 41 |
from mosaic.tcga import (
|
| 42 |
compute_settings_hash,
|
| 43 |
download_results_from_hf,
|
|
|
|
| 1934 |
)
|
| 1935 |
return (
|
| 1936 |
gr.update(
|
| 1937 |
+
value=f"Signed in as **{username}** \u00b7 [Sign out](/api/auth/logout)",
|
| 1938 |
visible=True,
|
| 1939 |
), # login_status_md
|
| 1940 |
gr.update(visible=True), # user_storage_tabs
|
|
|
|
| 1944 |
else:
|
| 1945 |
return (
|
| 1946 |
gr.update(
|
| 1947 |
+
value="[Sign in with HuggingFace](/api/auth/login) to save slides and results",
|
| 1948 |
visible=True,
|
| 1949 |
), # login_status_md
|
| 1950 |
gr.update(visible=False), # user_storage_tabs
|
|
|
|
| 1968 |
# Higher-memory GPUs and ZeroGPU can handle multiple concurrent analyses
|
| 1969 |
demo.queue(max_size=10, default_concurrency_limit=DEFAULT_CONCURRENCY_LIMIT)
|
| 1970 |
|
| 1971 |
+
# Mount custom OAuth routes for HF Spaces (sdk:docker)
|
| 1972 |
+
if IS_HF_SPACES:
|
| 1973 |
+
mount_oauth_routes(demo.app)
|
| 1974 |
+
|
| 1975 |
# Register cleanup handler for graceful shutdown
|
| 1976 |
import atexit
|
| 1977 |
|
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Manual OAuth flow for HF Spaces with Docker SDK.
|
| 2 |
+
|
| 3 |
+
On HF Spaces with sdk:docker, Gradio's built-in OAuth (gr.LoginButton /
|
| 4 |
+
gr.OAuthProfile) doesn't work because the HF reverse proxy injects the
|
| 5 |
+
Space owner's identity into every session. This module implements the
|
| 6 |
+
Authorization Code flow directly against HF's OAuth provider and writes
|
| 7 |
+
the visitor's real identity into the Starlette session so that Gradio's
|
| 8 |
+
existing OAuthProfile injection works transparently.
|
| 9 |
+
|
| 10 |
+
Environment variables (set automatically by HF Spaces):
|
| 11 |
+
OAUTH_CLIENT_ID: OAuth application client ID
|
| 12 |
+
OAUTH_CLIENT_SECRET: OAuth application client secret
|
| 13 |
+
OAUTH_SCOPES: Space-separated scopes (default: "openid profile")
|
| 14 |
+
SPACE_HOST: Public hostname of the Space (e.g. "user-space.hf.space")
|
| 15 |
+
|
| 16 |
+
Routes mounted on the Gradio ASGI app:
|
| 17 |
+
GET /api/auth/login -> Redirect to HF authorize endpoint
|
| 18 |
+
GET /api/auth/callback -> Exchange code for token, set session
|
| 19 |
+
GET /api/auth/logout -> Clear session, redirect to /
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import json
|
| 23 |
+
import os
|
| 24 |
+
import secrets
|
| 25 |
+
import time
|
| 26 |
+
from typing import Optional
|
| 27 |
+
from urllib.parse import urlencode
|
| 28 |
+
|
| 29 |
+
from loguru import logger
|
| 30 |
+
from starlette.requests import Request
|
| 31 |
+
from starlette.responses import RedirectResponse, JSONResponse
|
| 32 |
+
from starlette.routing import Route
|
| 33 |
+
|
| 34 |
+
# ---------------------------------------------------------------------------
|
| 35 |
+
# Configuration
|
| 36 |
+
# ---------------------------------------------------------------------------
|
| 37 |
+
|
| 38 |
+
# OAuth application credentials and settings, read once at import time.
# Missing variables fall back to an empty string (or the default scopes).
OAUTH_CLIENT_ID = os.getenv("OAUTH_CLIENT_ID", "")
OAUTH_CLIENT_SECRET = os.getenv("OAUTH_CLIENT_SECRET", "")
OAUTH_SCOPES = os.getenv("OAUTH_SCOPES", "openid profile")
# Public hostname used to build the callback redirect URI.
SPACE_HOST = os.getenv("SPACE_HOST", "")

# Hugging Face OAuth provider endpoints.
HF_AUTHORIZE_URL = "https://huggingface.co/oauth/authorize"
HF_TOKEN_URL = "https://huggingface.co/oauth/token"
|
| 45 |
+
|
| 46 |
+
# ---------------------------------------------------------------------------
|
| 47 |
+
# Server-side session store
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
+
|
| 50 |
+
# Name of the cookie whose value keys into ``_sessions``.
_SESSION_COOKIE = "mosaic_auth"
# Lifetime of a server-side session entry, in seconds (24 hours).
_SESSION_TTL_SEC = 24 * 60 * 60
# Lifetime of a pending CSRF state token, in seconds (10 minutes).
_STATE_TTL_SEC = 10 * 60

# In-memory session table:
# {cookie_value: {"userinfo": {...}, "created_at": float}}
_sessions: dict[str, dict] = {}
# In-memory CSRF state tokens: {state_value: created_at_float}
_pending_states: dict[str, float] = {}
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _prune_expired() -> None:
    """Drop stale entries from the session table and the CSRF state table.

    A session is stale once it is older than ``_SESSION_TTL_SEC``; a pending
    state token is stale once it is older than ``_STATE_TTL_SEC``.
    """
    current = time.time()

    # Collect the keys first so neither dict is mutated while being iterated.
    stale_tokens = [
        token
        for token, record in _sessions.items()
        if current - record["created_at"] > _SESSION_TTL_SEC
    ]
    for token in stale_tokens:
        del _sessions[token]

    stale_states = [
        state
        for state, issued_at in _pending_states.items()
        if current - issued_at > _STATE_TTL_SEC
    ]
    for state in stale_states:
        del _pending_states[state]
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def get_user_from_server_session(request) -> Optional[dict]:
    """Resolve the ``mosaic_auth`` cookie on *request* to stored user info.

    The cookie value is read from ``request.cookies`` when that attribute is
    a plain dict; otherwise it is parsed out of the raw ``cookie`` header.

    Args:
        request: Request-like object exposing ``.cookies`` and/or ``.headers``.

    Returns:
        The ``userinfo`` value stored for the session, or None when the
        cookie is absent, unknown, or the session has expired.
    """
    _prune_expired()

    token = None

    # Preferred source: an already-parsed cookie mapping on the request.
    jar = getattr(request, "cookies", None)
    if isinstance(jar, dict):
        token = jar.get(_SESSION_COOKIE)

    # Fallback: scan the raw Cookie header for our cookie name.
    if token is None and hasattr(request, "headers"):
        headers = request.headers
        raw_header = headers.get("cookie", "") if hasattr(headers, "get") else None
        if raw_header:
            prefix = f"{_SESSION_COOKIE}="
            for piece in raw_header.split(";"):
                piece = piece.strip()
                if piece.startswith(prefix):
                    token = piece[len(prefix) :]
                    break

    if not token:
        return None

    record = _sessions.get(token)
    if record is None:
        return None

    # Re-check the age of this specific entry before trusting it.
    if time.time() - record["created_at"] > _SESSION_TTL_SEC:
        del _sessions[token]
        return None

    return record.get("userinfo")
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
# ---------------------------------------------------------------------------
|
| 123 |
+
# Route handlers
|
| 124 |
+
# ---------------------------------------------------------------------------
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
async def _login(request: Request):
    """Start the Authorization Code flow by redirecting to HF's authorize URL.

    A fresh random ``state`` token is recorded in ``_pending_states`` with the
    current time so that the callback handler can validate it later.

    Args:
        request: Incoming Starlette request (not otherwise inspected here).

    Returns:
        A 302 redirect to ``HF_AUTHORIZE_URL`` carrying the client id,
        redirect URI, scopes and state; or a 500 JSON error when
        ``OAUTH_CLIENT_ID`` or ``SPACE_HOST`` is empty.
    """
    if not OAUTH_CLIENT_ID or not SPACE_HOST:
        return JSONResponse(
            {"error": "OAuth not configured (missing OAUTH_CLIENT_ID or SPACE_HOST)"},
            status_code=500,
        )

    _prune_expired()
    state = secrets.token_urlsafe(32)
    # NOTE(review): the state is remembered only in a process-wide dict and is
    # not tied to the requesting browser (no cookie/session binding) — confirm
    # this is sufficient protection against login CSRF.
    _pending_states[state] = time.time()

    # Build redirect URI pointing back to our callback
    redirect_uri = f"https://{SPACE_HOST}/api/auth/callback"

    params = {
        "client_id": OAUTH_CLIENT_ID,
        "redirect_uri": redirect_uri,
        "response_type": "code",
        "scope": OAUTH_SCOPES,
        "state": state,
    }
    authorize_url = f"{HF_AUTHORIZE_URL}?{urlencode(params)}"
    logger.info("OAuth login: redirecting to HF authorize endpoint")
    return RedirectResponse(authorize_url, status_code=302)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def _decode_id_token_claims(id_token) -> Optional[dict]:
    """Return the claims object from a JWT's payload segment, or None on failure."""
    import base64

    try:
        # JWT is base64url-encoded: header.payload.signature
        # We only need the payload (claims) — no signature verification
        # since we just received this directly from HF's token endpoint
        payload_b64 = id_token.split(".")[1]
        # Restore the "=" padding that base64url encoding strips.
        payload_b64 += "=" * (-len(payload_b64) % 4)
        claims = json.loads(base64.urlsafe_b64decode(payload_b64))
    except Exception as e:
        logger.warning(f"Failed to decode id_token: {e}")
        return None

    # Valid JSON that is not an object (list, number, ...) cannot serve as
    # userinfo; treat it like a decode failure so the caller can fall back.
    if not isinstance(claims, dict):
        logger.warning("Failed to decode id_token: payload is not a JSON object")
        return None
    return claims


async def _callback(request: Request):
    """Handle the OAuth callback: exchange the code and establish a session.

    Steps:
      1. Require ``code`` and ``state`` query parameters.
      2. Consume ``state`` from ``_pending_states`` (single use).
      3. POST the code to ``HF_TOKEN_URL``.
      4. Obtain userinfo from the ``id_token`` payload, falling back to the
         HF userinfo endpoint when no claims could be decoded.
      5. Store the result in ``request.session["oauth_info"]`` and in the
         ``_sessions`` table, keyed by a new ``mosaic_auth`` cookie value.

    Args:
        request: Incoming Starlette request for ``/api/auth/callback``.

    Returns:
        A 302 redirect to ``/`` that sets the session cookie on success;
        a 400 JSON error for missing/invalid parameters; a 502 JSON error
        when the token exchange or userinfo retrieval fails.
    """
    import httpx

    code = request.query_params.get("code")
    state = request.query_params.get("state")

    if not code or not state:
        return JSONResponse(
            {"error": "Missing code or state parameter"}, status_code=400
        )

    # Validate CSRF state (popping makes each state single-use).
    # NOTE(review): the state is checked only against the process-wide
    # ``_pending_states`` store, not against anything tied to the requesting
    # browser — confirm this is sufficient protection against login CSRF.
    _prune_expired()
    if _pending_states.pop(state, None) is None:
        return JSONResponse(
            {"error": "Invalid or expired state parameter"}, status_code=400
        )

    # Exchange authorization code for access token
    redirect_uri = f"https://{SPACE_HOST}/api/auth/callback"
    token_data = {
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": redirect_uri,
        "client_id": OAUTH_CLIENT_ID,
        "client_secret": OAUTH_CLIENT_SECRET,
    }

    try:
        async with httpx.AsyncClient() as client:
            resp = await client.post(
                HF_TOKEN_URL,
                data=token_data,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
            )
            resp.raise_for_status()
            token_response = resp.json()
    except httpx.HTTPStatusError as e:
        logger.error(
            f"OAuth token exchange failed: {e.response.status_code} {e.response.text}"
        )
        return JSONResponse({"error": "Token exchange failed"}, status_code=502)
    except Exception as e:
        logger.error(f"OAuth token exchange error: {e}")
        return JSONResponse({"error": "Token exchange failed"}, status_code=502)

    # A non-object JSON body would otherwise crash on ``.get`` below.
    if not isinstance(token_response, dict):
        logger.error("OAuth token exchange error: token response is not a JSON object")
        return JSONResponse({"error": "Token exchange failed"}, status_code=502)

    access_token = token_response.get("access_token")

    # Prefer the claims embedded in the id_token (JWT).
    id_token = token_response.get("id_token")
    userinfo = _decode_id_token_claims(id_token) if id_token else None

    # Fallback: call userinfo endpoint
    if userinfo is None and access_token:
        try:
            async with httpx.AsyncClient() as client:
                resp = await client.get(
                    "https://huggingface.co/oauth/userinfo",
                    headers={"Authorization": f"Bearer {access_token}"},
                )
                resp.raise_for_status()
                userinfo = resp.json()
        except Exception as e:
            logger.error(f"Failed to fetch userinfo: {e}")
            return JSONResponse({"error": "Failed to get user info"}, status_code=502)

    # Empty or non-object userinfo is unusable.
    if not userinfo or not isinstance(userinfo, dict):
        return JSONResponse({"error": "No user info received"}, status_code=502)

    # Extract username
    username = userinfo.get("preferred_username") or userinfo.get("sub", "unknown")
    logger.info(f"OAuth callback: authenticated user '{username}'")

    # Write into Gradio's Starlette session so OAuthProfile picks it up
    # Gradio expects session["oauth_info"]["userinfo"] with at least
    # "preferred_username" and optionally "name", "picture", etc.
    request.session["oauth_info"] = {
        "userinfo": userinfo,
        "access_token": access_token,
    }

    # Also store in our server-side session as fallback
    session_id = secrets.token_urlsafe(32)
    _sessions[session_id] = {
        "userinfo": userinfo,
        "created_at": time.time(),
    }

    # Build redirect response with cookie
    response = RedirectResponse("/", status_code=302)
    response.set_cookie(
        _SESSION_COOKIE,
        session_id,
        httponly=True,
        secure=True,
        samesite="lax",
        max_age=_SESSION_TTL_SEC,
    )
    return response
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
async def _logout(request: Request):
    """Log the visitor out and send them back to ``/``.

    Removes ``oauth_info`` from the Starlette session, forgets the
    server-side session named by the ``mosaic_auth`` cookie (if any), and
    instructs the browser to delete that cookie.
    """
    # Forget the identity held in Gradio's session, if present.
    request.session.pop("oauth_info", None)

    # Forget the matching server-side session entry, if present.
    token = request.cookies.get(_SESSION_COOKIE)
    if token:
        _sessions.pop(token, None)

    redirect = RedirectResponse("/", status_code=302)
    redirect.delete_cookie(_SESSION_COOKIE)
    return redirect
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
# ---------------------------------------------------------------------------
|
| 292 |
+
# Mount helper
|
| 293 |
+
# ---------------------------------------------------------------------------
|
| 294 |
+
|
| 295 |
+
# GET-only routes for the custom OAuth flow, in the order they are mounted.
_oauth_routes = [
    Route(path, handler, methods=["GET"])
    for path, handler in (
        ("/api/auth/login", _login),
        ("/api/auth/callback", _callback),
        ("/api/auth/logout", _logout),
    )
]
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def mount_oauth_routes(app) -> None:
    """Mount custom OAuth routes on the Gradio ASGI app.

    Should be called after ``demo.queue()`` and before ``demo.launch()``.
    When ``OAUTH_CLIENT_ID`` is empty, a warning is logged and nothing is
    mounted.

    Args:
        app: The Starlette/FastAPI app (``demo.app``); its ``routes`` list
            is modified in place.
    """
    if not OAUTH_CLIENT_ID:
        logger.warning(
            "OAuth routes not mounted: OAUTH_CLIENT_ID not set. "
            "Custom login will not work."
        )
        return

    # Insert our routes at the beginning so they take priority
    app.routes[:0] = _oauth_routes
    logger.info(
        "Mounted custom OAuth routes: /api/auth/login, /api/auth/callback, /api/auth/logout"
    )
|
|
@@ -54,6 +54,19 @@ def _get_username(request: gr.Request = None, profile=None) -> tuple[str, bool]:
|
|
| 54 |
except Exception:
|
| 55 |
pass
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
return (None, False)
|
| 58 |
else:
|
| 59 |
# Local mode - use debug username
|
|
|
|
| 54 |
except Exception:
|
| 55 |
pass
|
| 56 |
|
| 57 |
+
# Fallback: check server-side session store (custom OAuth flow)
|
| 58 |
+
if request is not None:
|
| 59 |
+
try:
|
| 60 |
+
from mosaic.ui.oauth import get_user_from_server_session
|
| 61 |
+
|
| 62 |
+
userinfo = get_user_from_server_session(request)
|
| 63 |
+
if userinfo:
|
| 64 |
+
username = userinfo.get("preferred_username")
|
| 65 |
+
if username:
|
| 66 |
+
return (username, False)
|
| 67 |
+
except Exception:
|
| 68 |
+
pass
|
| 69 |
+
|
| 70 |
return (None, False)
|
| 71 |
else:
|
| 72 |
# Local mode - use debug username
|
|
@@ -8,7 +8,6 @@ from pathlib import Path
|
|
| 8 |
import pytest
|
| 9 |
|
| 10 |
from mosaic.telemetry import TelemetryTracker, TelemetryConfig
|
| 11 |
-
from mosaic.telemetry.utils import hash_username
|
| 12 |
|
| 13 |
|
| 14 |
@pytest.fixture
|
|
@@ -153,7 +152,7 @@ class TestUsageEvents:
|
|
| 153 |
event = json.loads(f.read().strip())
|
| 154 |
|
| 155 |
assert event["is_logged_in"] is True
|
| 156 |
-
assert event["hf_username"] ==
|
| 157 |
|
| 158 |
def test_log_analysis_complete(self, tracker, temp_dir):
|
| 159 |
"""Test logging analysis complete event."""
|
|
@@ -287,7 +286,7 @@ class TestResourceEvents:
|
|
| 287 |
event = json.loads(f.read().strip())
|
| 288 |
|
| 289 |
assert event["is_logged_in"] is True
|
| 290 |
-
assert event["hf_username"] ==
|
| 291 |
|
| 292 |
|
| 293 |
class TestFailureEvents:
|
|
|
|
| 8 |
import pytest
|
| 9 |
|
| 10 |
from mosaic.telemetry import TelemetryTracker, TelemetryConfig
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
@pytest.fixture
|
|
|
|
| 152 |
event = json.loads(f.read().strip())
|
| 153 |
|
| 154 |
assert event["is_logged_in"] is True
|
| 155 |
+
assert event["hf_username"] == "testuser"
|
| 156 |
|
| 157 |
def test_log_analysis_complete(self, tracker, temp_dir):
|
| 158 |
"""Test logging analysis complete event."""
|
|
|
|
| 286 |
event = json.loads(f.read().strip())
|
| 287 |
|
| 288 |
assert event["is_logged_in"] is True
|
| 289 |
+
assert event["hf_username"] == "testuser"
|
| 290 |
|
| 291 |
|
| 292 |
class TestFailureEvents:
|
|
@@ -8,7 +8,6 @@ from mosaic.telemetry.utils import (
|
|
| 8 |
StageTimer,
|
| 9 |
sanitize_error_message,
|
| 10 |
hash_session_id,
|
| 11 |
-
hash_username,
|
| 12 |
UserInfo,
|
| 13 |
extract_user_info,
|
| 14 |
)
|
|
@@ -151,40 +150,6 @@ class TestHashSessionId:
|
|
| 151 |
assert len(set(hashes)) == 1 # All hashes should be identical
|
| 152 |
|
| 153 |
|
| 154 |
-
class TestHashUsername:
|
| 155 |
-
"""Tests for username hashing."""
|
| 156 |
-
|
| 157 |
-
def test_hash_username(self):
|
| 158 |
-
"""Test basic username hashing."""
|
| 159 |
-
hashed = hash_username("testuser")
|
| 160 |
-
assert hashed is not None
|
| 161 |
-
assert hashed != "testuser"
|
| 162 |
-
assert len(hashed) == 16
|
| 163 |
-
|
| 164 |
-
def test_hash_none_returns_none(self):
|
| 165 |
-
"""Test that None input returns None."""
|
| 166 |
-
assert hash_username(None) is None
|
| 167 |
-
|
| 168 |
-
def test_hash_is_deterministic(self):
|
| 169 |
-
"""Test that same input produces same hash."""
|
| 170 |
-
hash1 = hash_username("alice")
|
| 171 |
-
hash2 = hash_username("alice")
|
| 172 |
-
assert hash1 == hash2
|
| 173 |
-
|
| 174 |
-
def test_different_inputs_different_hashes(self):
|
| 175 |
-
"""Test that different usernames produce different hashes."""
|
| 176 |
-
hash1 = hash_username("alice")
|
| 177 |
-
hash2 = hash_username("bob")
|
| 178 |
-
assert hash1 != hash2
|
| 179 |
-
|
| 180 |
-
def test_different_salt_from_session_id(self):
|
| 181 |
-
"""Test that username hash uses different salt than session hash."""
|
| 182 |
-
value = "same_value"
|
| 183 |
-
username_hash = hash_username(value)
|
| 184 |
-
session_hash = hash_session_id(value)
|
| 185 |
-
assert username_hash != session_hash
|
| 186 |
-
|
| 187 |
-
|
| 188 |
class TestUserInfo:
|
| 189 |
"""Tests for UserInfo dataclass."""
|
| 190 |
|
|
|
|
| 8 |
StageTimer,
|
| 9 |
sanitize_error_message,
|
| 10 |
hash_session_id,
|
|
|
|
| 11 |
UserInfo,
|
| 12 |
extract_user_info,
|
| 13 |
)
|
|
|
|
| 150 |
assert len(set(hashes)) == 1 # All hashes should be identical
|
| 151 |
|
| 152 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
class TestUserInfo:
|
| 154 |
"""Tests for UserInfo dataclass."""
|
| 155 |
|
|
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for the custom OAuth module (mosaic.ui.oauth)."""
|
| 2 |
+
|
| 3 |
+
import time
|
| 4 |
+
from unittest.mock import AsyncMock, MagicMock, patch
|
| 5 |
+
|
| 6 |
+
import pytest
|
| 7 |
+
|
| 8 |
+
from mosaic.ui.oauth import (
|
| 9 |
+
_prune_expired,
|
| 10 |
+
_sessions,
|
| 11 |
+
_pending_states,
|
| 12 |
+
_SESSION_COOKIE,
|
| 13 |
+
_SESSION_TTL_SEC,
|
| 14 |
+
_STATE_TTL_SEC,
|
| 15 |
+
get_user_from_server_session,
|
| 16 |
+
mount_oauth_routes,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@pytest.fixture(autouse=True)
|
| 21 |
+
def clean_sessions():
|
| 22 |
+
"""Clear session and state stores before each test."""
|
| 23 |
+
_sessions.clear()
|
| 24 |
+
_pending_states.clear()
|
| 25 |
+
yield
|
| 26 |
+
_sessions.clear()
|
| 27 |
+
_pending_states.clear()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# ---------------------------------------------------------------------------
|
| 31 |
+
# Session store tests
|
| 32 |
+
# ---------------------------------------------------------------------------
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class TestSessionStore:
|
| 36 |
+
"""Tests for the server-side session store."""
|
| 37 |
+
|
| 38 |
+
def test_store_and_retrieve_session(self):
|
| 39 |
+
"""Test basic session CRUD."""
|
| 40 |
+
_sessions["abc123"] = {
|
| 41 |
+
"userinfo": {"preferred_username": "alice"},
|
| 42 |
+
"created_at": time.time(),
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
request = MagicMock()
|
| 46 |
+
request.cookies = {_SESSION_COOKIE: "abc123"}
|
| 47 |
+
request.headers = {}
|
| 48 |
+
|
| 49 |
+
userinfo = get_user_from_server_session(request)
|
| 50 |
+
assert userinfo is not None
|
| 51 |
+
assert userinfo["preferred_username"] == "alice"
|
| 52 |
+
|
| 53 |
+
def test_missing_cookie_returns_none(self):
|
| 54 |
+
"""Test that missing cookie returns None."""
|
| 55 |
+
request = MagicMock()
|
| 56 |
+
request.cookies = {}
|
| 57 |
+
request.headers = {}
|
| 58 |
+
|
| 59 |
+
assert get_user_from_server_session(request) is None
|
| 60 |
+
|
| 61 |
+
def test_unknown_session_id_returns_none(self):
|
| 62 |
+
"""Test that unknown session ID returns None."""
|
| 63 |
+
request = MagicMock()
|
| 64 |
+
request.cookies = {_SESSION_COOKIE: "unknown"}
|
| 65 |
+
request.headers = {}
|
| 66 |
+
|
| 67 |
+
assert get_user_from_server_session(request) is None
|
| 68 |
+
|
| 69 |
+
def test_expired_session_returns_none(self):
|
| 70 |
+
"""Test that expired sessions are pruned."""
|
| 71 |
+
_sessions["expired"] = {
|
| 72 |
+
"userinfo": {"preferred_username": "bob"},
|
| 73 |
+
"created_at": time.time() - _SESSION_TTL_SEC - 1,
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
request = MagicMock()
|
| 77 |
+
request.cookies = {_SESSION_COOKIE: "expired"}
|
| 78 |
+
request.headers = {}
|
| 79 |
+
|
| 80 |
+
assert get_user_from_server_session(request) is None
|
| 81 |
+
assert "expired" not in _sessions
|
| 82 |
+
|
| 83 |
+
def test_cookie_from_header_fallback(self):
|
| 84 |
+
"""Test parsing cookie from raw Cookie header."""
|
| 85 |
+
_sessions["header_val"] = {
|
| 86 |
+
"userinfo": {"preferred_username": "carol"},
|
| 87 |
+
"created_at": time.time(),
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
request = MagicMock()
|
| 91 |
+
request.cookies = {} # Empty dict, fallback to header
|
| 92 |
+
request.headers = {"cookie": f"other=x; {_SESSION_COOKIE}=header_val; foo=bar"}
|
| 93 |
+
|
| 94 |
+
userinfo = get_user_from_server_session(request)
|
| 95 |
+
assert userinfo is not None
|
| 96 |
+
assert userinfo["preferred_username"] == "carol"
|
| 97 |
+
|
| 98 |
+
def test_none_request_returns_none(self):
|
| 99 |
+
"""Test that None request returns None gracefully."""
|
| 100 |
+
# get_user_from_server_session checks hasattr, so pass an object
|
| 101 |
+
# without cookies or headers
|
| 102 |
+
assert get_user_from_server_session(None) is None
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# ---------------------------------------------------------------------------
|
| 106 |
+
# Prune tests
|
| 107 |
+
# ---------------------------------------------------------------------------
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class TestPruneExpired:
|
| 111 |
+
"""Tests for _prune_expired."""
|
| 112 |
+
|
| 113 |
+
def test_prune_expired_sessions(self):
|
| 114 |
+
"""Test that expired sessions are removed."""
|
| 115 |
+
_sessions["old"] = {
|
| 116 |
+
"userinfo": {"preferred_username": "old_user"},
|
| 117 |
+
"created_at": time.time() - _SESSION_TTL_SEC - 100,
|
| 118 |
+
}
|
| 119 |
+
_sessions["fresh"] = {
|
| 120 |
+
"userinfo": {"preferred_username": "new_user"},
|
| 121 |
+
"created_at": time.time(),
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
_prune_expired()
|
| 125 |
+
|
| 126 |
+
assert "old" not in _sessions
|
| 127 |
+
assert "fresh" in _sessions
|
| 128 |
+
|
| 129 |
+
def test_prune_expired_states(self):
|
| 130 |
+
"""Test that expired CSRF states are removed."""
|
| 131 |
+
_pending_states["old_state"] = time.time() - _STATE_TTL_SEC - 100
|
| 132 |
+
_pending_states["fresh_state"] = time.time()
|
| 133 |
+
|
| 134 |
+
_prune_expired()
|
| 135 |
+
|
| 136 |
+
assert "old_state" not in _pending_states
|
| 137 |
+
assert "fresh_state" in _pending_states
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
# ---------------------------------------------------------------------------
|
| 141 |
+
# mount_oauth_routes tests
|
| 142 |
+
# ---------------------------------------------------------------------------
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class TestMountOAuthRoutes:
|
| 146 |
+
"""Tests for mount_oauth_routes."""
|
| 147 |
+
|
| 148 |
+
@patch("mosaic.ui.oauth.OAUTH_CLIENT_ID", "")
|
| 149 |
+
def test_no_routes_when_client_id_missing(self):
|
| 150 |
+
"""Test that routes are not mounted when OAUTH_CLIENT_ID is empty."""
|
| 151 |
+
app = MagicMock()
|
| 152 |
+
app.routes = []
|
| 153 |
+
|
| 154 |
+
mount_oauth_routes(app)
|
| 155 |
+
|
| 156 |
+
assert len(app.routes) == 0
|
| 157 |
+
|
| 158 |
+
@patch("mosaic.ui.oauth.OAUTH_CLIENT_ID", "test-client-id")
|
| 159 |
+
def test_routes_mounted_when_configured(self):
|
| 160 |
+
"""Test that routes are mounted when OAuth is configured."""
|
| 161 |
+
app = MagicMock()
|
| 162 |
+
app.routes = []
|
| 163 |
+
|
| 164 |
+
mount_oauth_routes(app)
|
| 165 |
+
|
| 166 |
+
assert len(app.routes) == 3
|
| 167 |
+
paths = [r.path for r in app.routes]
|
| 168 |
+
assert "/api/auth/login" in paths
|
| 169 |
+
assert "/api/auth/callback" in paths
|
| 170 |
+
assert "/api/auth/logout" in paths
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
# ---------------------------------------------------------------------------
|
| 174 |
+
# Login route tests
|
| 175 |
+
# ---------------------------------------------------------------------------
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class TestLoginRoute:
|
| 179 |
+
"""Tests for the /api/auth/login route."""
|
| 180 |
+
|
| 181 |
+
@pytest.mark.asyncio
|
| 182 |
+
@patch("mosaic.ui.oauth.OAUTH_CLIENT_ID", "test-client-id")
|
| 183 |
+
@patch("mosaic.ui.oauth.SPACE_HOST", "user-space.hf.space")
|
| 184 |
+
async def test_login_redirects_to_hf(self):
|
| 185 |
+
"""Test that login redirects to HF authorize URL."""
|
| 186 |
+
from mosaic.ui.oauth import _login
|
| 187 |
+
|
| 188 |
+
request = MagicMock()
|
| 189 |
+
response = await _login(request)
|
| 190 |
+
|
| 191 |
+
assert response.status_code == 302
|
| 192 |
+
location = response.headers["location"]
|
| 193 |
+
assert "huggingface.co/oauth/authorize" in location
|
| 194 |
+
assert "client_id=test-client-id" in location
|
| 195 |
+
assert "redirect_uri=" in location
|
| 196 |
+
assert "state=" in location
|
| 197 |
+
# Should have stored a CSRF state
|
| 198 |
+
assert len(_pending_states) == 1
|
| 199 |
+
|
| 200 |
+
@pytest.mark.asyncio
|
| 201 |
+
@patch("mosaic.ui.oauth.OAUTH_CLIENT_ID", "")
|
| 202 |
+
async def test_login_returns_error_when_not_configured(self):
|
| 203 |
+
"""Test that login returns error when OAuth is not configured."""
|
| 204 |
+
from mosaic.ui.oauth import _login
|
| 205 |
+
|
| 206 |
+
request = MagicMock()
|
| 207 |
+
response = await _login(request)
|
| 208 |
+
|
| 209 |
+
assert response.status_code == 500
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
# ---------------------------------------------------------------------------
|
| 213 |
+
# Callback route tests
|
| 214 |
+
# ---------------------------------------------------------------------------
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class TestCallbackRoute:
|
| 218 |
+
"""Tests for the /api/auth/callback route."""
|
| 219 |
+
|
| 220 |
+
@pytest.mark.asyncio
|
| 221 |
+
async def test_callback_missing_params(self):
|
| 222 |
+
"""Test callback with missing code/state parameters."""
|
| 223 |
+
from mosaic.ui.oauth import _callback
|
| 224 |
+
|
| 225 |
+
request = MagicMock()
|
| 226 |
+
request.query_params = {}
|
| 227 |
+
|
| 228 |
+
response = await _callback(request)
|
| 229 |
+
assert response.status_code == 400
|
| 230 |
+
|
| 231 |
+
@pytest.mark.asyncio
|
| 232 |
+
async def test_callback_invalid_state(self):
|
| 233 |
+
"""Test callback with invalid CSRF state."""
|
| 234 |
+
from mosaic.ui.oauth import _callback
|
| 235 |
+
|
| 236 |
+
request = MagicMock()
|
| 237 |
+
request.query_params = {"code": "test-code", "state": "invalid-state"}
|
| 238 |
+
|
| 239 |
+
response = await _callback(request)
|
| 240 |
+
assert response.status_code == 400
|
| 241 |
+
|
| 242 |
+
@pytest.mark.asyncio
|
| 243 |
+
@patch("mosaic.ui.oauth.OAUTH_CLIENT_ID", "test-client-id")
|
| 244 |
+
@patch("mosaic.ui.oauth.OAUTH_CLIENT_SECRET", "test-secret")
|
| 245 |
+
@patch("mosaic.ui.oauth.SPACE_HOST", "user-space.hf.space")
|
| 246 |
+
async def test_callback_exchanges_code_for_token(self):
|
| 247 |
+
"""Test successful callback with mocked token exchange."""
|
| 248 |
+
import json
|
| 249 |
+
import base64
|
| 250 |
+
|
| 251 |
+
from mosaic.ui.oauth import _callback
|
| 252 |
+
|
| 253 |
+
# Set up a valid CSRF state
|
| 254 |
+
valid_state = "valid-state-token"
|
| 255 |
+
_pending_states[valid_state] = time.time()
|
| 256 |
+
|
| 257 |
+
# Create a mock id_token JWT
|
| 258 |
+
header = base64.urlsafe_b64encode(b'{"alg":"RS256"}').rstrip(b"=").decode()
|
| 259 |
+
payload_data = json.dumps(
|
| 260 |
+
{"preferred_username": "visitor123", "sub": "visitor123"}
|
| 261 |
+
)
|
| 262 |
+
payload = base64.urlsafe_b64encode(payload_data.encode()).rstrip(b"=").decode()
|
| 263 |
+
mock_id_token = f"{header}.{payload}.fake_signature"
|
| 264 |
+
|
| 265 |
+
# Mock httpx — json() is a regular method (not async) on httpx.Response
|
| 266 |
+
mock_response = AsyncMock()
|
| 267 |
+
mock_response.json = MagicMock(
|
| 268 |
+
return_value={
|
| 269 |
+
"access_token": "test-access-token",
|
| 270 |
+
"id_token": mock_id_token,
|
| 271 |
+
}
|
| 272 |
+
)
|
| 273 |
+
mock_response.raise_for_status = MagicMock()
|
| 274 |
+
|
| 275 |
+
mock_client = AsyncMock()
|
| 276 |
+
mock_client.post.return_value = mock_response
|
| 277 |
+
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
| 278 |
+
mock_client.__aexit__ = AsyncMock(return_value=False)
|
| 279 |
+
|
| 280 |
+
request = MagicMock()
|
| 281 |
+
request.query_params = {"code": "auth-code", "state": valid_state}
|
| 282 |
+
request.session = {}
|
| 283 |
+
|
| 284 |
+
with patch("httpx.AsyncClient", return_value=mock_client):
|
| 285 |
+
response = await _callback(request)
|
| 286 |
+
|
| 287 |
+
# Should redirect to /
|
| 288 |
+
assert response.status_code == 302
|
| 289 |
+
|
| 290 |
+
# Session should have oauth_info
|
| 291 |
+
assert "oauth_info" in request.session
|
| 292 |
+
assert (
|
| 293 |
+
request.session["oauth_info"]["userinfo"]["preferred_username"]
|
| 294 |
+
== "visitor123"
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
# Server-side session should be stored
|
| 298 |
+
assert len(_sessions) == 1
|
| 299 |
+
|
| 300 |
+
# CSRF state should be consumed
|
| 301 |
+
assert valid_state not in _pending_states
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
# ---------------------------------------------------------------------------
|
| 305 |
+
# Logout route tests
|
| 306 |
+
# ---------------------------------------------------------------------------
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
class TestLogoutRoute:
|
| 310 |
+
"""Tests for the /api/auth/logout route."""
|
| 311 |
+
|
| 312 |
+
@pytest.mark.asyncio
|
| 313 |
+
async def test_logout_clears_session(self):
|
| 314 |
+
"""Test that logout clears session data."""
|
| 315 |
+
from mosaic.ui.oauth import _logout
|
| 316 |
+
|
| 317 |
+
# Set up server-side session
|
| 318 |
+
_sessions["session_id"] = {
|
| 319 |
+
"userinfo": {"preferred_username": "user"},
|
| 320 |
+
"created_at": time.time(),
|
| 321 |
+
}
|
| 322 |
+
|
| 323 |
+
request = MagicMock()
|
| 324 |
+
request.session = {"oauth_info": {"userinfo": {"preferred_username": "user"}}}
|
| 325 |
+
request.cookies = {_SESSION_COOKIE: "session_id"}
|
| 326 |
+
|
| 327 |
+
response = await _logout(request)
|
| 328 |
+
|
| 329 |
+
assert response.status_code == 302
|
| 330 |
+
assert "oauth_info" not in request.session
|
| 331 |
+
assert "session_id" not in _sessions
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
# ---------------------------------------------------------------------------
|
| 335 |
+
# get_user_from_server_session integration
|
| 336 |
+
# ---------------------------------------------------------------------------
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
class TestGetUserFromServerSession:
|
| 340 |
+
"""Integration tests for the server-session fallback."""
|
| 341 |
+
|
| 342 |
+
def test_works_with_gradio_request_like_object(self):
|
| 343 |
+
"""Test with an object mimicking gr.Request."""
|
| 344 |
+
_sessions["gr_session"] = {
|
| 345 |
+
"userinfo": {"preferred_username": "gradio_user", "name": "Test User"},
|
| 346 |
+
"created_at": time.time(),
|
| 347 |
+
}
|
| 348 |
+
|
| 349 |
+
class MockGradioRequest:
|
| 350 |
+
cookies = {_SESSION_COOKIE: "gr_session"}
|
| 351 |
+
headers = {}
|
| 352 |
+
|
| 353 |
+
result = get_user_from_server_session(MockGradioRequest())
|
| 354 |
+
assert result is not None
|
| 355 |
+
assert result["preferred_username"] == "gradio_user"
|
| 356 |
+
assert result["name"] == "Test User"
|
| 357 |
+
|
| 358 |
+
def test_returns_none_for_object_without_cookies(self):
|
| 359 |
+
"""Test graceful handling of objects without cookies."""
|
| 360 |
+
|
| 361 |
+
class Bare:
|
| 362 |
+
pass
|
| 363 |
+
|
| 364 |
+
assert get_user_from_server_session(Bare()) is None
|
|
The diff for this file is too large to render.
See raw diff
|
|
|