File size: 5,623 Bytes
bef585f
 
 
 
 
 
 
 
c4b5267
 
bef585f
 
 
6ce1c19
bef585f
 
 
 
 
 
 
 
 
c4b5267
bef585f
 
 
908a76b
bef585f
 
 
c4b5267
bef585f
 
1845cc4
 
 
 
2bfd9f0
1845cc4
bef585f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4046997
3b4df92
 
 
feb1f78
bef585f
4046997
 
bef585f
6ce1c19
 
 
 
908a76b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bef585f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6cdb208
ad43a34
bef585f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""OAuth support for AutoTrain.
Taken from: https://github.com/gradio-app/gradio/blob/main/gradio/oauth.py
"""

from __future__ import annotations

import hashlib
import os
import random
import string
import urllib.parse

import fastapi
from authlib.integrations.base_client.errors import MismatchingStateError
from authlib.integrations.starlette_client import OAuth
from fastapi.responses import RedirectResponse
from starlette.middleware.sessions import SessionMiddleware


# OAuth configuration, read from the environment once at import time. Each value is
# None when the corresponding variable is unset.
OAUTH_CLIENT_ID = os.environ.get("OAUTH_CLIENT_ID")
OAUTH_CLIENT_SECRET = os.environ.get("OAUTH_CLIENT_SECRET")
OAUTH_SCOPES = os.environ.get("OAUTH_SCOPES")
OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL")
# 20 random alphanumeric characters, regenerated on every import of this module.
# This value is mixed into the session-cookie signing secret, so it is drawn from the
# OS entropy source (SystemRandom) instead of the default pseudo-random generator,
# which is not intended for security-sensitive values.
RANDOM_STRING = "".join(random.SystemRandom().choices(string.ascii_letters + string.digits, k=20))


def attach_oauth(app: fastapi.FastAPI):
    """Enable OAuth on ``app``: register the OAuth routes, then install the session middleware."""
    # Routes go first: registering them also validates that the OAuth environment
    # variables (including the client secret used below) are set.
    _add_oauth_routes(app)

    # The session cookie is signed with a SHA-256 digest of the OAuth client secret
    # concatenated with the module's random string. The key therefore changes when
    # the OAuth config is updated, and whenever this module is re-imported.
    signing_material = (OAUTH_CLIENT_SECRET + RANDOM_STRING).encode()
    cookie_signing_key = hashlib.sha256(signing_material).hexdigest()

    app.add_middleware(
        SessionMiddleware,
        secret_key=cookie_signing_key,
        https_only=True,
        same_site="none",
    )


def _add_oauth_routes(app: fastapi.FastAPI) -> None:
    """Register the Hugging Face OAuth endpoints on ``app``.

    Two GET routes are added: ``/login/huggingface`` starts the OAuth flow and
    ``/auth`` handles the provider's callback.

    Raises:
        ValueError: if ``OAUTH_CLIENT_ID``, ``OAUTH_CLIENT_SECRET``, ``OAUTH_SCOPES``
            or ``OPENID_PROVIDER_URL`` is not set (checked in that order).
    """
    missing_var_template = (
        "OAuth is required but {} environment variable is not set. Make sure you've enabled OAuth in your Space by"
        " setting `hf_oauth: true` in the Space metadata."
    )
    required_settings = (
        ("OAUTH_CLIENT_ID", OAUTH_CLIENT_ID),
        ("OAUTH_CLIENT_SECRET", OAUTH_CLIENT_SECRET),
        ("OAUTH_SCOPES", OAUTH_SCOPES),
        ("OPENID_PROVIDER_URL", OPENID_PROVIDER_URL),
    )
    # Fail on the first missing setting, naming it in the error message.
    for var_name, var_value in required_settings:
        if var_value is None:
            raise ValueError(missing_var_template.format(var_name))

    # Declare the "huggingface" OAuth client; endpoints are discovered from the
    # provider's OpenID configuration document.
    oauth_registry = OAuth()
    oauth_registry.register(
        name="huggingface",
        client_id=OAUTH_CLIENT_ID,
        client_secret=OAUTH_CLIENT_SECRET,
        client_kwargs={"scope": OAUTH_SCOPES},
        server_metadata_url=f"{OPENID_PROVIDER_URL}/.well-known/openid-configuration",
    )

    @app.get("/login/huggingface")
    async def oauth_login(request: fastapi.Request):
        """Endpoint that redirects to HF OAuth page."""
        # The callback is the `auth` route defined below.
        callback_url = request.url_for("auth")
        callback = str(callback_url)
        # On *.hf.space hosts the callback must be advertised as https.
        if callback_url.netloc.endswith(".hf.space"):
            callback = callback.replace("http://", "https://")
        return await oauth_registry.huggingface.authorize_redirect(request, callback)  # type: ignore

    @app.get("/auth")
    async def auth(request: fastapi.Request) -> RedirectResponse:
        """Endpoint that handles the OAuth callback."""
        try:
            token_payload = await oauth_registry.huggingface.authorize_access_token(request)  # type: ignore
        except MismatchingStateError:
            # A state mismatch most likely means the session cookie is corrupted.
            # authlib has a reported bug where the stored state grows on every repeated
            # login attempt; cookies are capped at 4kb, so the cookie eventually gets
            # truncated and the state is lost. Workaround: drop the stored OAuth state
            # and send the user through the login page again.
            # See https://github.com/lepture/authlib/issues/622 for more details.
            retry_url = "/login/huggingface"
            if "_target_url" in request.query_params:
                # Carry the original `_target_url` over to the new login attempt.
                forwarded = {"_target_url": request.query_params["_target_url"]}
                retry_url = retry_url + "?" + urllib.parse.urlencode(forwarded)
            stale_state_keys = [
                session_key for session_key in request.session.keys() if session_key.startswith("_state_huggingface")
            ]
            for session_key in stale_state_keys:
                del request.session[session_key]
            return RedirectResponse(retry_url)

        # Successful login: keep the OAuth payload in the session and leave the callback URL.
        request.session["oauth_info"] = token_payload
        return _redirect_to_target(request)


def _generate_redirect_uri(request: fastapi.Request) -> str:
    """Build the OAuth callback URL for ``request``, carrying a ``_target_url`` query parameter.

    Args:
        request: The incoming request. Its query parameters decide the target: an
            explicit ``_target_url`` is reused as-is, otherwise the target is ``/?``
            followed by the request's own url-encoded query parameters.

    Returns:
        The absolute callback URL as a string, with ``http://`` rewritten to
        ``https://`` when the host ends with ``.hf.space``.
    """
    if "_target_url" in request.query_params:
        # if `_target_url` already in query params => respect it
        target = request.query_params["_target_url"]
    else:
        # otherwise => keep query params
        target = "/?" + urllib.parse.urlencode(request.query_params)

    # The callback handler registered by this module is the `auth` route (the same
    # name `oauth_login` resolves). No route called `oauth_redirect_callback` is
    # registered here, so resolving that name would fail.
    redirect_uri = request.url_for("auth").include_query_params(_target_url=target)
    redirect_uri_as_str = str(redirect_uri)
    if redirect_uri.netloc.endswith(".hf.space"):
        # In Space, FastAPI redirect as http but we want https
        redirect_uri_as_str = redirect_uri_as_str.replace("http://", "https://")
    return redirect_uri_as_str


def _redirect_to_target(request: fastapi.Request, default_target: str = "/") -> RedirectResponse:
    """Redirect the user away from the OAuth callback after a successful login.

    Args:
        request: The callback request. Currently unused: any ``_target_url`` query
            parameter is deliberately ignored in favour of the Space page.
        default_target: Where to redirect when the ``SPACE_ID`` environment variable
            is unset or empty.

    Returns:
        A redirect to ``https://huggingface.co/spaces/<SPACE_ID>`` when ``SPACE_ID``
        is set, otherwise a redirect to ``default_target``.
    """
    space_id = os.environ.get("SPACE_ID")
    if not space_id:
        # Without this guard the string concatenation below raises TypeError on a
        # missing SPACE_ID, turning a successful login into an HTTP 500.
        return RedirectResponse(default_target)
    return RedirectResponse("https://huggingface.co/spaces/" + space_id)