Spaces:
Sleeping
Sleeping
github-actions[bot] commited on
Commit ·
dd79664
1
Parent(s): 502f69c
Sync from GitHub 247bfa13ed0af2a3e2eaf0193f136a55bee6daef
Browse files- auth/oauth.py +10 -2
- auth/session.py +24 -4
- tests/test_auth_mode.py +13 -0
- tests/test_hf_oauth_settings.py +14 -0
auth/oauth.py
CHANGED
|
@@ -13,6 +13,14 @@ class HFOAuthError(RuntimeError):
|
|
| 13 |
pass
|
| 14 |
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
@dataclass(frozen=True)
|
| 17 |
class HFOAuthSettings:
|
| 18 |
client_id: str
|
|
@@ -24,8 +32,8 @@ class HFOAuthSettings:
|
|
| 24 |
|
| 25 |
|
| 26 |
def get_hf_oauth_settings() -> HFOAuthSettings:
|
| 27 |
-
client_id =
|
| 28 |
-
client_secret =
|
| 29 |
if not client_id or not client_secret:
|
| 30 |
raise HFOAuthError("HF OAuth client configuration is missing.")
|
| 31 |
|
|
|
|
| 13 |
pass
|
| 14 |
|
| 15 |
|
| 16 |
+
def _first_env(*names: str) -> str:
|
| 17 |
+
for name in names:
|
| 18 |
+
value = os.getenv(name, "").strip()
|
| 19 |
+
if value:
|
| 20 |
+
return value
|
| 21 |
+
return ""
|
| 22 |
+
|
| 23 |
+
|
| 24 |
@dataclass(frozen=True)
|
| 25 |
class HFOAuthSettings:
|
| 26 |
client_id: str
|
|
|
|
| 32 |
|
| 33 |
|
| 34 |
def get_hf_oauth_settings() -> HFOAuthSettings:
|
| 35 |
+
client_id = _first_env("HF_OAUTH_CLIENT_ID", "OAUTH_CLIENT_ID")
|
| 36 |
+
client_secret = _first_env("HF_OAUTH_CLIENT_SECRET", "OAUTH_CLIENT_SECRET")
|
| 37 |
if not client_id or not client_secret:
|
| 38 |
raise HFOAuthError("HF OAuth client configuration is missing.")
|
| 39 |
|
auth/session.py
CHANGED
|
@@ -30,6 +30,14 @@ class AuthBridgeTokenError(ValueError):
|
|
| 30 |
pass
|
| 31 |
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
def get_auth_mode() -> str:
|
| 34 |
configured = os.getenv("AUTH_MODE", "").strip().lower()
|
| 35 |
if configured == AUTH_MODE_DEV:
|
|
@@ -37,16 +45,28 @@ def get_auth_mode() -> str:
|
|
| 37 |
if configured in {AUTH_MODE_HF, "oauth"}:
|
| 38 |
return AUTH_MODE_HF
|
| 39 |
if not configured or configured == "auto":
|
| 40 |
-
has_hf_client = bool(
|
| 41 |
-
has_hf_secret = bool(
|
| 42 |
if has_hf_client and has_hf_secret:
|
| 43 |
return AUTH_MODE_HF
|
| 44 |
return AUTH_MODE_DEV
|
| 45 |
|
| 46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
def configure_session_middleware(app) -> None:
|
| 48 |
"""Attach Starlette session middleware once during app setup."""
|
| 49 |
-
secret =
|
| 50 |
auth_mode = get_auth_mode()
|
| 51 |
if auth_mode == AUTH_MODE_HF and (not secret or secret == DEFAULT_DEV_SESSION_SECRET):
|
| 52 |
raise RuntimeError("APP_SESSION_SECRET must be set to a non-default value in hf_oauth mode.")
|
|
@@ -70,7 +90,7 @@ def configure_session_middleware(app) -> None:
|
|
| 70 |
|
| 71 |
|
| 72 |
def _bridge_serializer() -> URLSafeTimedSerializer:
|
| 73 |
-
secret =
|
| 74 |
return URLSafeTimedSerializer(secret_key=secret, salt=AUTH_BRIDGE_SALT)
|
| 75 |
|
| 76 |
|
|
|
|
| 30 |
pass
|
| 31 |
|
| 32 |
|
| 33 |
+
def _first_env(*names: str) -> str:
|
| 34 |
+
for name in names:
|
| 35 |
+
value = os.getenv(name, "").strip()
|
| 36 |
+
if value:
|
| 37 |
+
return value
|
| 38 |
+
return ""
|
| 39 |
+
|
| 40 |
+
|
| 41 |
def get_auth_mode() -> str:
|
| 42 |
configured = os.getenv("AUTH_MODE", "").strip().lower()
|
| 43 |
if configured == AUTH_MODE_DEV:
|
|
|
|
| 45 |
if configured in {AUTH_MODE_HF, "oauth"}:
|
| 46 |
return AUTH_MODE_HF
|
| 47 |
if not configured or configured == "auto":
|
| 48 |
+
has_hf_client = bool(_first_env("HF_OAUTH_CLIENT_ID", "OAUTH_CLIENT_ID"))
|
| 49 |
+
has_hf_secret = bool(_first_env("HF_OAUTH_CLIENT_SECRET", "OAUTH_CLIENT_SECRET"))
|
| 50 |
if has_hf_client and has_hf_secret:
|
| 51 |
return AUTH_MODE_HF
|
| 52 |
return AUTH_MODE_DEV
|
| 53 |
|
| 54 |
|
| 55 |
+
def _resolve_session_secret() -> str:
|
| 56 |
+
configured = os.getenv("APP_SESSION_SECRET", "").strip()
|
| 57 |
+
if configured:
|
| 58 |
+
return configured
|
| 59 |
+
if get_auth_mode() == AUTH_MODE_HF:
|
| 60 |
+
# In HF Spaces OAuth deployments, OAUTH_CLIENT_SECRET is usually injected.
|
| 61 |
+
oauth_secret = _first_env("HF_OAUTH_CLIENT_SECRET", "OAUTH_CLIENT_SECRET")
|
| 62 |
+
if oauth_secret:
|
| 63 |
+
return oauth_secret
|
| 64 |
+
return DEFAULT_DEV_SESSION_SECRET
|
| 65 |
+
|
| 66 |
+
|
| 67 |
def configure_session_middleware(app) -> None:
|
| 68 |
"""Attach Starlette session middleware once during app setup."""
|
| 69 |
+
secret = _resolve_session_secret()
|
| 70 |
auth_mode = get_auth_mode()
|
| 71 |
if auth_mode == AUTH_MODE_HF and (not secret or secret == DEFAULT_DEV_SESSION_SECRET):
|
| 72 |
raise RuntimeError("APP_SESSION_SECRET must be set to a non-default value in hf_oauth mode.")
|
|
|
|
| 90 |
|
| 91 |
|
| 92 |
def _bridge_serializer() -> URLSafeTimedSerializer:
|
| 93 |
+
secret = _resolve_session_secret()
|
| 94 |
return URLSafeTimedSerializer(secret_key=secret, salt=AUTH_BRIDGE_SALT)
|
| 95 |
|
| 96 |
|
tests/test_auth_mode.py
CHANGED
|
@@ -19,6 +19,8 @@ def test_get_auth_mode_auto_switches_to_hf(monkeypatch):
|
|
| 19 |
monkeypatch.setenv("AUTH_MODE", "auto")
|
| 20 |
monkeypatch.setenv("HF_OAUTH_CLIENT_ID", "client-id")
|
| 21 |
monkeypatch.setenv("HF_OAUTH_CLIENT_SECRET", "client-secret")
|
|
|
|
|
|
|
| 22 |
assert get_auth_mode() == AUTH_MODE_HF
|
| 23 |
|
| 24 |
|
|
@@ -26,4 +28,15 @@ def test_get_auth_mode_auto_falls_back_to_dev(monkeypatch):
|
|
| 26 |
monkeypatch.setenv("AUTH_MODE", "auto")
|
| 27 |
monkeypatch.delenv("HF_OAUTH_CLIENT_ID", raising=False)
|
| 28 |
monkeypatch.delenv("HF_OAUTH_CLIENT_SECRET", raising=False)
|
|
|
|
|
|
|
| 29 |
assert get_auth_mode() == AUTH_MODE_DEV
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
monkeypatch.setenv("AUTH_MODE", "auto")
|
| 20 |
monkeypatch.setenv("HF_OAUTH_CLIENT_ID", "client-id")
|
| 21 |
monkeypatch.setenv("HF_OAUTH_CLIENT_SECRET", "client-secret")
|
| 22 |
+
monkeypatch.delenv("OAUTH_CLIENT_ID", raising=False)
|
| 23 |
+
monkeypatch.delenv("OAUTH_CLIENT_SECRET", raising=False)
|
| 24 |
assert get_auth_mode() == AUTH_MODE_HF
|
| 25 |
|
| 26 |
|
|
|
|
| 28 |
monkeypatch.setenv("AUTH_MODE", "auto")
|
| 29 |
monkeypatch.delenv("HF_OAUTH_CLIENT_ID", raising=False)
|
| 30 |
monkeypatch.delenv("HF_OAUTH_CLIENT_SECRET", raising=False)
|
| 31 |
+
monkeypatch.delenv("OAUTH_CLIENT_ID", raising=False)
|
| 32 |
+
monkeypatch.delenv("OAUTH_CLIENT_SECRET", raising=False)
|
| 33 |
assert get_auth_mode() == AUTH_MODE_DEV
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def test_get_auth_mode_auto_switches_to_hf_with_space_oauth_vars(monkeypatch):
|
| 37 |
+
monkeypatch.setenv("AUTH_MODE", "auto")
|
| 38 |
+
monkeypatch.delenv("HF_OAUTH_CLIENT_ID", raising=False)
|
| 39 |
+
monkeypatch.delenv("HF_OAUTH_CLIENT_SECRET", raising=False)
|
| 40 |
+
monkeypatch.setenv("OAUTH_CLIENT_ID", "space-client-id")
|
| 41 |
+
monkeypatch.setenv("OAUTH_CLIENT_SECRET", "space-client-secret")
|
| 42 |
+
assert get_auth_mode() == AUTH_MODE_HF
|
tests/test_hf_oauth_settings.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from auth.oauth import get_hf_oauth_settings
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def test_hf_oauth_settings_accept_space_oauth_vars(monkeypatch):
|
| 7 |
+
monkeypatch.delenv("HF_OAUTH_CLIENT_ID", raising=False)
|
| 8 |
+
monkeypatch.delenv("HF_OAUTH_CLIENT_SECRET", raising=False)
|
| 9 |
+
monkeypatch.setenv("OAUTH_CLIENT_ID", "space-client-id")
|
| 10 |
+
monkeypatch.setenv("OAUTH_CLIENT_SECRET", "space-client-secret")
|
| 11 |
+
|
| 12 |
+
settings = get_hf_oauth_settings()
|
| 13 |
+
assert settings.client_id == "space-client-id"
|
| 14 |
+
assert settings.client_secret == "space-client-secret"
|