| | import uvicorn |
| | from fastapi import FastAPI, Depends |
| | from starlette.responses import RedirectResponse |
| | from starlette.middleware.sessions import SessionMiddleware |
| | from authlib.integrations.starlette_client import OAuth, OAuthError |
| | from fastapi import Request |
| | import os |
| | from starlette.config import Config |
| | import gradio as gr |
| |
|
app = FastAPI()

# Google OAuth client credentials and the session-cookie signing key are read
# from the environment.
GOOGLE_CLIENT_ID = os.environ.get("GOOGLE_CLIENT_ID")
GOOGLE_CLIENT_SECRET = os.environ.get("GOOGLE_CLIENT_SECRET")
SECRET_KEY = os.environ.get("SECRET_KEY")

# Fail fast instead of starting a half-configured app. In particular, Starlette's
# SessionMiddleware passes its key through str(), so a missing SECRET_KEY would
# sign session cookies with the well-known string "None" and let anyone forge a
# signed-in session.
_missing_settings = [
    setting_name
    for setting_name, setting_value in (
        ("GOOGLE_CLIENT_ID", GOOGLE_CLIENT_ID),
        ("GOOGLE_CLIENT_SECRET", GOOGLE_CLIENT_SECRET),
        ("SECRET_KEY", SECRET_KEY),
    )
    if not setting_value
]
if _missing_settings:
    raise RuntimeError(
        "Missing required environment variable(s): " + ", ".join(_missing_settings)
    )

# authlib reads <NAME>_CLIENT_ID / <NAME>_CLIENT_SECRET from this config for the
# client registered below as name='google'.
config_data = {'GOOGLE_CLIENT_ID': GOOGLE_CLIENT_ID, 'GOOGLE_CLIENT_SECRET': GOOGLE_CLIENT_SECRET}
starlette_config = Config(environ=config_data)
oauth = OAuth(starlette_config)
# Endpoints are discovered from Google's OpenID metadata; the scopes request an
# ID token carrying the user's email and profile claims.
oauth.register(
    name='google',
    server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
    client_kwargs={'scope': 'openid email profile'},
)

# Signed-cookie sessions; the OAuth flow and the routes read/write request.session.
app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY)
| |
|
| | |
# Email suffix an account must have to be allowed into the app.
_ALLOWED_EMAIL_SUFFIX = "@zalando.de"


def get_user(request: Request):
    """Return the display name of an authorized signed-in user, else None.

    Used as a FastAPI dependency and as Gradio's ``auth_dependency``. It reads
    the OpenID userinfo claims that the ``/auth`` callback stores in
    ``request.session['user']``.

    A user is authorized only when both conditions hold:

    * their email ends with ``_ALLOWED_EMAIL_SUFFIX``, compared case-insensitively;
    * Google marked the address as verified.

    The suffix check alone could be satisfied by an address Google has not
    verified.

    Returns:
        The user's ``name`` claim, or their email when the name is absent.
        None when nobody is signed in or the user is not allowed.
    """
    user = request.session.get('user')
    if not user:
        return None
    # .get(): a missing claim means "not authorized", not a 500 error.
    email = (user.get('email') or '').lower()
    if not (user.get('email_verified') and email.endswith(_ALLOWED_EMAIL_SUFFIX)):
        return None
    return user.get('name') or email
| |
|
@app.get('/')
def public(request: Request, user = Depends(get_user)):
    """Route the visitor to the protected app or to the login page.

    ``user`` is the value of the get_user dependency; any truthy value counts
    as signed in and goes to ``/gradio/``, anything else goes to ``/main/``.
    Both targets are built from the app's root URL so they stay correct when
    it is served under a path prefix.
    """
    base_url = gr.route_utils.get_root_url(request, "/", None)
    destination = f'{base_url}/gradio/' if user else f'{base_url}/main/'
    return RedirectResponse(url=destination)
| |
|
@app.route('/logout')
async def logout(request: Request):
    """Drop the signed-in user from the session and send them back to '/'."""
    session = request.session
    if 'user' in session:
        del session['user']
    return RedirectResponse(url='/')
| |
|
@app.route('/login')
async def login(request: Request):
    """Start the Google OAuth flow with ``<root>/auth`` as the callback URL."""
    app_root = gr.route_utils.get_root_url(request, "/login", None)
    callback_url = f"{app_root}/auth"
    print("Redirecting to", callback_url)
    return await oauth.google.authorize_redirect(request, callback_url)
| |
|
@app.route('/auth')
async def auth(request: Request):
    """OAuth callback: exchange the authorization response and record the user.

    On success, the OpenID userinfo claims from the token are stored in
    ``request.session['user']`` and the browser is redirected to ``/gradio``.
    If the exchange raises ``OAuthError``, or the token carries no userinfo,
    the session is left untouched and the browser is sent back to ``/``.
    """
    try:
        token = await oauth.google.authorize_access_token(request)
    except OAuthError as exc:
        # Report the caught exception instance, not the OAuthError class itself.
        print("Error getting access token", str(exc))
        return RedirectResponse(url='/')
    userinfo = dict(token).get("userinfo")
    if not userinfo:
        # Without the ID-token claims there is no user to store.
        print("Error getting access token", "token contains no userinfo")
        return RedirectResponse(url='/')
    request.session['user'] = userinfo
    print("Redirecting to /gradio")
    return RedirectResponse(url='/gradio')
| |
|
# Public landing page mounted at /main: one button that starts the login flow.
with gr.Blocks() as login_demo:
    btn = gr.Button("Login")
    # Client-side handler: open /login, preserving the current query string, in
    # a new tab. NOTE(review): the new tab is presumably used so the OAuth
    # redirect also works when the app is embedded in an iframe — confirm.
    # `const` keeps `url` local instead of leaking an implicit global onto window.
    _js_redirect = """
    () => {
        const url = '/login' + window.location.search;
        window.open(url, '_blank');
    }
    """
    # No Python callback: the click is handled entirely by the JS above.
    btn.click(None, js=_js_redirect)

app = gr.mount_gradio_app(app, login_demo, path="/main")
| |
|
def greet(request: gr.Request):
    """Return the greeting shown once the protected app loads.

    ``request.username`` is filled in by Gradio. Presumably it holds the value
    the mount's auth dependency returned; confirm against the Gradio docs.
    """
    return "Welcome to Gradio, {}".format(request.username)
| |
|
# Protected app mounted at /gradio. get_user is passed as Gradio's
# auth_dependency; per Gradio's docs, requests for which it returns None
# are rejected.
with gr.Blocks() as main_demo:
    greeting = gr.Markdown("Welcome to Gradio!")
    gr.Button("Logout", link="/logout")
    # Replace the placeholder text with a personalised greeting on page load.
    main_demo.load(fn=greet, inputs=None, outputs=greeting)

app = gr.mount_gradio_app(
    app,
    main_demo,
    path="/gradio",
    auth_dependency=get_user,
)


if __name__ == "__main__":
    # Serve the combined FastAPI + Gradio app with uvicorn's default host/port.
    uvicorn.run(app)
| |
|