Spaces:
Running
Running
| import datetime | |
| import hashlib | |
| import logging | |
| import os | |
| import time | |
| import urllib.parse | |
| import warnings | |
| from dataclasses import dataclass | |
| from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union | |
| from . import constants | |
| from .hf_api import whoami | |
| from .utils import experimental, get_token | |
| logger = logging.getLogger(__name__) | |
| if TYPE_CHECKING: | |
| import fastapi | |
@dataclass
class OAuthOrgInfo:
    """
    Information about an organization linked to a user logged in with OAuth.

    Declared as a dataclass so that instances can be built with keyword arguments
    (one per attribute below) and compared by value.

    Attributes:
        sub (`str`):
            Unique identifier for the org. OpenID Connect field.
        name (`str`):
            The org's full name. OpenID Connect field.
        preferred_username (`str`):
            The org's username. OpenID Connect field.
        picture (`str`):
            The org's profile picture URL. OpenID Connect field.
        is_enterprise (`bool`):
            Whether the org is an enterprise org. Hugging Face field.
        can_pay (`Optional[bool]`, *optional*):
            Whether the org has a payment method set up. Hugging Face field.
        role_in_org (`Optional[str]`, *optional*):
            The user's role in the org. Hugging Face field.
        security_restrictions (`Optional[List[Literal["ip", "token-policy", "mfa", "sso"]]]`, *optional*):
            Array of security restrictions that the user hasn't completed for this org. Possible values: "ip",
            "token-policy", "mfa", "sso". Hugging Face field.
    """

    # Required fields.
    sub: str
    name: str
    preferred_username: str
    picture: str
    is_enterprise: bool
    # Optional fields, default to `None` when not provided.
    can_pay: Optional[bool] = None
    role_in_org: Optional[str] = None
    security_restrictions: Optional[List[Literal["ip", "token-policy", "mfa", "sso"]]] = None
@dataclass
class OAuthUserInfo:
    """
    Information about a user logged in with OAuth.

    Declared as a dataclass so that instances can be built with keyword arguments
    (one per attribute below). None of the fields has a default value: every
    attribute must be passed explicitly, `None` included.

    Attributes:
        sub (`str`):
            Unique identifier for the user, even in case of rename. OpenID Connect field.
        name (`str`):
            The user's full name. OpenID Connect field.
        preferred_username (`str`):
            The user's username. OpenID Connect field.
        email_verified (`Optional[bool]`, *optional*):
            Indicates if the user's email is verified. OpenID Connect field.
        email (`Optional[str]`, *optional*):
            The user's email address. OpenID Connect field.
        picture (`str`):
            The user's profile picture URL. OpenID Connect field.
        profile (`str`):
            The user's profile URL. OpenID Connect field.
        website (`Optional[str]`, *optional*):
            The user's website URL. OpenID Connect field.
        is_pro (`bool`):
            Whether the user is a pro user. Hugging Face field.
        can_pay (`Optional[bool]`, *optional*):
            Whether the user has a payment method set up. Hugging Face field.
        orgs (`Optional[List[OAuthOrgInfo]]`, *optional*):
            List of organizations the user is part of. Hugging Face field.
    """

    sub: str
    name: str
    preferred_username: str
    email_verified: Optional[bool]
    email: Optional[str]
    picture: str
    profile: str
    website: Optional[str]
    is_pro: bool
    can_pay: Optional[bool]
    orgs: Optional[List[OAuthOrgInfo]]
@dataclass
class OAuthInfo:
    """
    Information about the OAuth login.

    Declared as a dataclass so that instances can be built with keyword arguments
    (one per attribute below). None of the fields has a default value.

    Attributes:
        access_token (`str`):
            The access token.
        access_token_expires_at (`datetime.datetime`):
            The expiration date of the access token.
        user_info ([`OAuthUserInfo`]):
            The user information.
        state (`str`, *optional*):
            State passed to the OAuth provider in the original request to the OAuth provider.
        scope (`str`):
            Granted scope.
    """

    access_token: str
    access_token_expires_at: datetime.datetime
    user_info: OAuthUserInfo
    state: Optional[str]
    scope: str
def attach_huggingface_oauth(app: "fastapi.FastAPI", route_prefix: str = "/"):
    """
    Add OAuth endpoints to a FastAPI app to enable OAuth login with Hugging Face.

    How to use:
    - Call this method on your FastAPI app to add the OAuth endpoints.
    - Inside your route handlers, call `parse_huggingface_oauth(request)` to retrieve the OAuth info.
    - If user is logged in, an [`OAuthInfo`] object is returned with the user's info. If not, `None` is returned.
    - In your app, make sure to add links to `/oauth/huggingface/login` and `/oauth/huggingface/logout` for the user
      to log in and out.

    Example:
    ```py
    from fastapi import FastAPI, Request
    from huggingface_hub import attach_huggingface_oauth, parse_huggingface_oauth

    # Create a FastAPI app
    app = FastAPI()

    # Add OAuth endpoints to the FastAPI app
    attach_huggingface_oauth(app)

    # Add a route that greets the user if they are logged in
    @app.get("/")
    def greet_json(request: Request):
        # Retrieve the OAuth info from the request
        oauth_info = parse_huggingface_oauth(request)  # e.g. OAuthInfo dataclass
        if oauth_info is None:
            return {"msg": "Not logged in!"}
        return {"msg": f"Hello, {oauth_info.user_info.preferred_username}!"}
    ```

    Args:
        app (`fastapi.FastAPI`):
            The app to which the session middleware and the OAuth routes are added.
        route_prefix (`str`, *optional*):
            Prefix under which the `oauth/huggingface/*` routes are mounted. Leading and trailing slashes are
            stripped. Defaults to `"/"` (no prefix).

    Raises:
        `ImportError`:
            If `starlette` (or another dependency needed by the route helpers) is not installed.
        `ValueError`:
            Raised by the route helpers when running in a Space without the OAuth settings, or when running
            locally on a machine that is not logged in to Hugging Face with a personal account.
    """
    # TODO: handle generic case (handling OAuth in a non-Space environment with custom dev values) (low priority)

    # Add SessionMiddleware to the FastAPI app to store the OAuth info in the session.
    # Session Middleware requires a secret key to sign the cookies. Let's use a hash
    # of the OAuth secret key to make it unique to the Space + updated in case OAuth
    # config gets updated. When run locally (no OAuth client secret), the hash is
    # computed from the "-v1" suffix alone.
    try:
        from starlette.middleware.sessions import SessionMiddleware
    except ImportError as e:
        raise ImportError(
            "Cannot initialize OAuth due to a missing library. Please run `pip install huggingface_hub[oauth]` or add "
            "`huggingface_hub[oauth]` to your requirements.txt file in order to install the required dependencies."
        ) from e

    session_secret = (constants.OAUTH_CLIENT_SECRET or "") + "-v1"
    # NOTE(review): `same_site="none"` + `https_only=True` is presumably required for the cookie to be sent when
    # the Space is displayed inside an iframe - confirm before changing.
    app.add_middleware(
        SessionMiddleware,  # type: ignore[arg-type]
        secret_key=hashlib.sha256(session_secret.encode()).hexdigest(),
        same_site="none",
        https_only=True,
    )  # type: ignore

    # Add OAuth endpoints to the FastAPI app:
    # - {route_prefix}/oauth/huggingface/login
    # - {route_prefix}/oauth/huggingface/callback
    # - {route_prefix}/oauth/huggingface/logout
    # If the app is running in a Space, OAuth is enabled normally.
    # Otherwise, we mock the endpoints to make the user log in with a fake user profile - without any calls to hf.co.
    route_prefix = route_prefix.strip("/")
    if os.getenv("SPACE_ID") is not None:
        logger.info("OAuth is enabled in the Space. Adding OAuth routes.")
        _add_oauth_routes(app, route_prefix=route_prefix)
    else:
        logger.info("App is not running in a Space. Adding mocked OAuth routes.")
        _add_mocked_oauth_routes(app, route_prefix=route_prefix)
def parse_huggingface_oauth(request: "fastapi.Request") -> Optional[OAuthInfo]:
    """
    Build an [`OAuthInfo`] object from the `"oauth_info"` entry of the request session.

    Parsing is deliberately lax: user and org fields are read with `.get(...)`, so a missing field ends up as
    `None` without any warning. One value is not optional in practice: `expires_at` is passed as-is to
    `datetime.datetime.fromtimestamp`, which yields a naive datetime in local time and fails if the value is absent.

    Returns `None` when the session holds no `"oauth_info"` entry (user not logged in).

    See [`attach_huggingface_oauth`] for an example on how to use this method.
    """
    session = request.session
    if "oauth_info" not in session:
        logger.debug("No OAuth info in session.")
        return None

    logger.debug("Parsing OAuth info from session.")
    oauth_data = session["oauth_info"]
    user_data = oauth_data.get("userinfo", {})
    raw_orgs = user_data.get("orgs", [])

    # An absent or empty org list is reported as `None`, not as an empty list.
    orgs = None
    if raw_orgs:
        orgs = []
        for raw_org in raw_orgs:
            orgs.append(
                OAuthOrgInfo(
                    sub=raw_org.get("sub"),
                    name=raw_org.get("name"),
                    preferred_username=raw_org.get("preferred_username"),
                    picture=raw_org.get("picture"),
                    is_enterprise=raw_org.get("isEnterprise"),
                    can_pay=raw_org.get("canPay"),
                    role_in_org=raw_org.get("roleInOrg"),
                    security_restrictions=raw_org.get("securityRestrictions"),
                )
            )

    user_info = OAuthUserInfo(
        sub=user_data.get("sub"),
        name=user_data.get("name"),
        preferred_username=user_data.get("preferred_username"),
        email_verified=user_data.get("email_verified"),
        email=user_data.get("email"),
        picture=user_data.get("picture"),
        profile=user_data.get("profile"),
        website=user_data.get("website"),
        is_pro=user_data.get("isPro"),
        can_pay=user_data.get("canPay"),
        orgs=orgs,
    )

    expires_at = datetime.datetime.fromtimestamp(oauth_data.get("expires_at"))
    return OAuthInfo(
        access_token=oauth_data.get("access_token"),
        access_token_expires_at=expires_at,
        user_info=user_info,
        state=oauth_data.get("state"),
        scope=oauth_data.get("scope"),
    )
def _add_oauth_routes(app: "fastapi.FastAPI", route_prefix: str) -> None:
    """Add OAuth routes to the FastAPI app (login, callback handler and logout).

    Three GET endpoints are registered on `app`, at the paths returned by `_get_oauth_uris(route_prefix)`:
    - login: redirects the user to the Hugging Face authorization page.
    - callback: exchanges the authorization response for a token, stores it in the session under `"oauth_info"`
      and redirects to the target URL.
    - logout: removes `"oauth_info"` from the session and redirects to the target URL.

    The callback handler must keep the name `oauth_redirect_callback`: `_generate_redirect_uri` resolves it with
    `request.url_for("oauth_redirect_callback")`.

    Args:
        app (`fastapi.FastAPI`):
            The app on which the routes are registered.
        route_prefix (`str`):
            Prefix under which the routes are mounted.

    Raises:
        `ImportError`:
            If `fastapi` or `authlib` is not installed.
        `ValueError`:
            If one of `OAUTH_CLIENT_ID`, `OAUTH_CLIENT_SECRET`, `OAUTH_SCOPES` or `OPENID_PROVIDER_URL` is not set.
    """
    try:
        import fastapi
        from authlib.integrations.base_client.errors import MismatchingStateError
        from authlib.integrations.starlette_client import OAuth
        from fastapi.responses import RedirectResponse
    except ImportError as e:
        raise ImportError(
            "Cannot initialize OAuth due to a missing library. Please run `pip install huggingface_hub[oauth]` or add "
            "`huggingface_hub[oauth]` to your requirements.txt file."
        ) from e

    # Check environment variables
    msg = (
        "OAuth is required but '{}' environment variable is not set. Make sure you've enabled OAuth in your Space by"
        " setting `hf_oauth: true` in the Space metadata."
    )
    if constants.OAUTH_CLIENT_ID is None:
        raise ValueError(msg.format("OAUTH_CLIENT_ID"))
    if constants.OAUTH_CLIENT_SECRET is None:
        raise ValueError(msg.format("OAUTH_CLIENT_SECRET"))
    if constants.OAUTH_SCOPES is None:
        raise ValueError(msg.format("OAUTH_SCOPES"))
    if constants.OPENID_PROVIDER_URL is None:
        raise ValueError(msg.format("OPENID_PROVIDER_URL"))

    # Register OAuth server
    oauth = OAuth()
    oauth.register(
        name="huggingface",
        client_id=constants.OAUTH_CLIENT_ID,
        client_secret=constants.OAUTH_CLIENT_SECRET,
        client_kwargs={"scope": constants.OAUTH_SCOPES},
        server_metadata_url=constants.OPENID_PROVIDER_URL + "/.well-known/openid-configuration",
    )

    login_uri, callback_uri, logout_uri = _get_oauth_uris(route_prefix)

    # Register OAuth endpoints
    @app.get(login_uri)
    async def oauth_login(request: fastapi.Request) -> RedirectResponse:
        """Endpoint that redirects to HF OAuth page."""
        redirect_uri = _generate_redirect_uri(request)
        return await oauth.huggingface.authorize_redirect(request, redirect_uri)  # type: ignore

    @app.get(callback_uri)
    async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse:
        """Endpoint that handles the OAuth callback."""
        try:
            oauth_info = await oauth.huggingface.authorize_access_token(request)  # type: ignore
        except MismatchingStateError:
            # Parse query params
            nb_redirects = int(request.query_params.get("_nb_redirects", 0))
            target_url = request.query_params.get("_target_url")

            # Build redirect URI with the same query params as before and bump nb_redirects count
            query_params: Dict[str, Union[int, str]] = {"_nb_redirects": nb_redirects + 1}
            if target_url:
                query_params["_target_url"] = target_url

            redirect_uri = f"{login_uri}?{urllib.parse.urlencode(query_params)}"

            # If the user has been redirected more than `constants.OAUTH_MAX_REDIRECTS` times, it is very likely
            # that the cookie is not working properly (e.g. browser is blocking third-party cookies in iframe).
            # In this case, redirect the user in the non-iframe view.
            if nb_redirects > constants.OAUTH_MAX_REDIRECTS:
                host = os.environ.get("SPACE_HOST")
                if host is None:  # cannot happen in a Space
                    raise RuntimeError(
                        "App is not running in a Space (SPACE_HOST environment variable is not set). Cannot redirect to non-iframe view."
                    ) from None
                host_url = "https://" + host.rstrip("/")
                return RedirectResponse(host_url + redirect_uri)

            # Redirect the user to the login page again
            return RedirectResponse(redirect_uri)

        # OAuth login worked => store the user info in the session and redirect
        logger.debug("Successfully logged in with OAuth. Storing user info in session.")
        request.session["oauth_info"] = oauth_info
        return RedirectResponse(_get_redirect_target(request))

    @app.get(logout_uri)
    async def oauth_logout(request: fastapi.Request) -> RedirectResponse:
        """Endpoint that logs out the user (e.g. delete info from cookie session)."""
        logger.debug("Logged out with OAuth. Removing user info from session.")
        request.session.pop("oauth_info", None)
        return RedirectResponse(_get_redirect_target(request))
def _add_mocked_oauth_routes(app: "fastapi.FastAPI", route_prefix: str = "/") -> None:
    """Add fake oauth routes if app is run locally and OAuth is enabled.

    Using OAuth will have the same behavior as in a Space but instead of authenticating with HF, a mocked user profile
    is added to the session.

    Three GET endpoints are registered on `app`, at the paths returned by `_get_oauth_uris(route_prefix)`:
    - login: redirects straight to the callback endpoint.
    - callback: stores the mocked OAuth info in the session under `"oauth_info"` and redirects to the target URL.
    - logout: removes `"oauth_info"` from the session and redirects to `/`, keeping the query params.

    The callback handler must keep the name `oauth_redirect_callback`: `_generate_redirect_uri` resolves it with
    `request.url_for("oauth_redirect_callback")`.

    Args:
        app (`fastapi.FastAPI`):
            The app on which the routes are registered.
        route_prefix (`str`, *optional*):
            Prefix under which the routes are mounted. Defaults to `"/"` (no prefix).

    Raises:
        `ImportError`:
            If `fastapi` or `starlette` is not installed.
        `ValueError`:
            Raised by `_get_mocked_oauth_info` when the machine is not logged in with a personal account.
    """
    try:
        import fastapi
        from fastapi.responses import RedirectResponse
        from starlette.datastructures import URL
    except ImportError as e:
        raise ImportError(
            "Cannot initialize OAuth due to a missing library. Please run `pip install huggingface_hub[oauth]` or add "
            "`huggingface_hub[oauth]` to your requirements.txt file."
        ) from e

    warnings.warn(
        "OAuth is not supported outside of a Space environment. To help you debug your app locally, the oauth endpoints"
        " are mocked to return your profile and token. To make it work, your machine must be logged in to Huggingface."
    )
    # Computed once, when the routes are added: every mocked login stores this same dict in the session.
    mocked_oauth_info = _get_mocked_oauth_info()

    login_uri, callback_uri, logout_uri = _get_oauth_uris(route_prefix)

    # Define OAuth routes
    @app.get(login_uri)
    async def oauth_login(request: fastapi.Request) -> RedirectResponse:
        """Fake endpoint that redirects to HF OAuth page."""
        # Define target (where to redirect after login)
        redirect_uri = _generate_redirect_uri(request)
        return RedirectResponse(callback_uri + "?" + urllib.parse.urlencode({"_target_url": redirect_uri}))

    @app.get(callback_uri)
    async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse:
        """Endpoint that handles the OAuth callback."""
        request.session["oauth_info"] = mocked_oauth_info
        return RedirectResponse(_get_redirect_target(request))

    @app.get(logout_uri)
    async def oauth_logout(request: fastapi.Request) -> RedirectResponse:
        """Endpoint that logs out the user (e.g. delete cookie session)."""
        request.session.pop("oauth_info", None)
        logout_url = URL("/").include_query_params(**request.query_params)
        return RedirectResponse(url=logout_url, status_code=302)  # see https://github.com/gradio-app/gradio/pull/9659
def _generate_redirect_uri(request: "fastapi.Request") -> str:
    """Return the URL of the `oauth_redirect_callback` route, carrying a `_target_url` query param.

    `_target_url` is taken from the incoming query params when present. Otherwise it is `"/?"` followed by the
    url-encoded incoming query params. When the callback host ends with `.hf.space`, `http://` is rewritten to
    `https://` in the returned string.
    """
    params = request.query_params

    if "_target_url" not in params:
        # No explicit target => go back to the root, keeping the current query params
        target = "/?" + urllib.parse.urlencode(params)
    else:
        # Explicit target => respect it
        target = params["_target_url"]

    callback_url = request.url_for("oauth_redirect_callback").include_query_params(_target_url=target)
    callback_as_str = str(callback_url)

    if not callback_url.netloc.endswith(".hf.space"):
        return callback_as_str

    # In Space, FastAPI redirect as http but we want https
    return callback_as_str.replace("http://", "https://")
def _get_redirect_target(request: "fastapi.Request", default_target: str = "/") -> str:
    """Return the `_target_url` query param of the request, or `default_target` when it is absent."""
    # NOTE(review): the value comes straight from the query string and is used as a redirect location by the
    # callers in this file - confirm that redirecting to arbitrary targets is acceptable.
    params = request.query_params
    return params.get("_target_url", default_target)
def _get_mocked_oauth_info() -> Dict:
    """Build a fake OAuth payload from the locally saved Hugging Face token.

    The returned dict contains the local token as `access_token`, an 8-hour validity (`expires_in` / `expires_at`)
    and a `userinfo` entry filled from the `whoami()` response. All other values are hard-coded placeholders.

    Raises:
        `ValueError`:
            If no token is available locally, or if `whoami()` reports an account whose type is not `"user"`.
    """
    token = get_token()
    if token is None:
        raise ValueError(
            "Your machine must be logged in to HF to debug an OAuth app locally. Please"
            " run `hf auth login` or set `HF_TOKEN` as environment variable "
            "with one of your access token. You can generate a new token in your "
            "settings page (https://huggingface.co/settings/tokens)."
        )

    user = whoami()
    if user["type"] != "user":
        raise ValueError(
            "Your machine is not logged in with a personal account. Please use a "
            "personal access token. You can generate a new token in your settings page"
            " (https://huggingface.co/settings/tokens)."
        )

    validity_in_seconds = 8 * 60 * 60  # 8 hours
    expires_at = int(time.time()) + validity_in_seconds
    username = user["name"]

    mocked_userinfo = {
        "sub": "0123456789",
        "name": user["fullname"],
        "preferred_username": username,
        "profile": f"https://huggingface.co/{username}",
        "picture": user["avatarUrl"],
        "website": "",
        "aud": "00000000-0000-0000-0000-000000000000",
        "auth_time": 1691672844,
        "nonce": "aaaaaaaaaaaaaaaaaaa",
        "iat": 1691672844,
        "exp": 1691676444,
        "iss": "https://huggingface.co",
    }
    return {
        "access_token": token,
        "token_type": "bearer",
        "expires_in": validity_in_seconds,
        "id_token": "FOOBAR",
        "scope": "openid profile",
        "refresh_token": "hf_oauth__refresh_token",
        "expires_at": expires_at,
        "userinfo": mocked_userinfo,
    }
def _get_oauth_uris(route_prefix: str = "/") -> Tuple[str, str, str]:
    """Return the `(login, callback, logout)` route paths for the given prefix.

    Leading and trailing slashes of `route_prefix` are stripped. A non-empty prefix is then prepended, with a
    single leading slash, to `/oauth/huggingface/{login,callback,logout}`.
    """
    prefix = route_prefix.strip("/")
    base = f"/{prefix}/oauth/huggingface" if prefix else "/oauth/huggingface"
    login_uri = f"{base}/login"
    callback_uri = f"{base}/callback"
    logout_uri = f"{base}/logout"
    return (login_uri, callback_uri, logout_uri)