"""Sheami web entry point.

Serves two Gradio apps from one FastAPI server, protected by Google OAuth:

* ``/main`` - public landing page whose "Proceed" button opens ``/login``.
* ``/home`` - the authenticated application, gated by ``get_user``.
"""

import asyncio
import os
import secrets

import gradio as gr
import uvicorn
from authlib.integrations.starlette_client import OAuth, OAuthError
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, Request
from starlette.config import Config
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import RedirectResponse

from common import get_db
from home import build_home_page
from modules.models import SheamiUser
from ui import (
    get_app_theme,
    get_app_title,
    get_css,
    render_about_markdowns,
    render_logo,
    render_logo_small,
    render_selected_patient_actions,
)

load_dotenv()

app = FastAPI()

# OAuth settings (environment, optionally populated from a .env file).
GOOGLE_OAUTH_CLIENT_ID = os.environ.get("GOOGLE_OAUTH_CLIENT_ID")
GOOGLE_OAUTH_CLIENT_SECRET = os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET")
SECRET_KEY = os.environ.get("AUTH_SECRET_KEY")
if not SECRET_KEY:
    # SessionMiddleware signs cookies with str(secret_key); a missing key would
    # silently become the guessable string "None" and allow forged sessions.
    # Fall back to a random per-process key (sessions won't survive a restart).
    print("WARNING: AUTH_SECRET_KEY is not set; using a random ephemeral session key.")
    SECRET_KEY = secrets.token_urlsafe(32)

# Set up OAuth. The config keys are named for the client registered as "google".
config_data = {
    "GOOGLE_CLIENT_ID": GOOGLE_OAUTH_CLIENT_ID,
    "GOOGLE_CLIENT_SECRET": GOOGLE_OAUTH_CLIENT_SECRET,
}
starlette_config = Config(environ=config_data)
oauth = OAuth(starlette_config)
oauth.register(
    name="google",
    server_metadata_url="https://accounts.google.com/.well-known/openid-configuration",
    client_kwargs={"scope": "openid email profile"},
)

app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY)


def get_user(request: Request):
    """Return the logged-in user's display name from the session, or None.

    Used as a FastAPI dependency for ``/`` and as Gradio's ``auth_dependency``
    for ``/home``, where a None result means the request is not authorized.
    Falls back to the email when the OAuth userinfo carries no ``name``.
    """
    user = request.session.get("user")
    if user:
        return user.get("name") or user.get("email")
    return None


@app.get("/")
def public(request: Request, user=Depends(get_user)):
    """Send logged-in users to the app and everyone else to the landing page."""
    root_url = gr.route_utils.get_root_url(request, "/", None)
    if user:
        return RedirectResponse(url=f"{root_url}/home/")
    return RedirectResponse(url=f"{root_url}/main/")


@app.route("/logout")
async def logout(request: Request):
    """Drop the session user and go back to the root route."""
    request.session.pop("user", None)
    return RedirectResponse(url="/")


@app.route("/login")
async def login(request: Request):
    """Start the Google OAuth flow; Google calls back to ``/auth``."""
    root_url = gr.route_utils.get_root_url(request, "/login", None)
    redirect_uri = f"{root_url}/auth"
    print("Redirecting to", redirect_uri)
    return await oauth.google.authorize_redirect(request, redirect_uri)


@app.route("/auth")
async def auth(request: Request):
    """OAuth callback: exchange the code and store the userinfo in the session."""
    try:
        access_token = await oauth.google.authorize_access_token(request)
    except OAuthError as error:
        # Report the raised error instance, not the exception class.
        print("Error getting access token", str(error))
        return RedirectResponse(url="/")
    userinfo = dict(access_token).get("userinfo")
    if not userinfo:
        # Without userinfo there is nothing to log in with; avoid a 500.
        print("Error getting access token", "no userinfo in token response")
        return RedirectResponse(url="/")
    request.session["user"] = userinfo
    print("Redirecting to /home")
    return RedirectResponse(url="/home")


# Public landing page mounted at /main.
with gr.Blocks() as login_demo:
    render_logo_small()
    with gr.Row():
        gr.Column()
        render_about_markdowns()
        gr.Column()
    with gr.Row():
        gr.Column()
        btn = gr.Button("Proceed", variant="huggingface", scale=0)
        gr.Column()

    # Client-side: open /login (preserving the query string) in a new tab.
    _js_redirect = """
    () => {
        url = '/login' + window.location.search;
        window.open(url, '_blank');
    }
    """
    btn.click(None, js=_js_redirect)

app = gr.mount_gradio_app(app, login_demo, path="/main")


async def register_user(logged_in_user: SheamiUser):
    """Add the user to the database if no record with their email exists yet."""
    user = await get_db().get_user_by_email(email=logged_in_user.email)
    if not user:
        await get_db().add_user(email=logged_in_user.email, name=logged_in_user.name)


def get_sheami_user(request: gr.Request):
    """Build a SheamiUser from the session userinfo, or None if unavailable.

    ``name`` comes from ``request.username`` (the ``auth_dependency`` result);
    a missing or empty ``picture`` falls back to the bundled default avatar.
    """
    if request is None:
        return None
    session_user = request.request.session.get("user")
    if not session_user:
        return None
    email = session_user.get("email")
    if not email:
        # Email is the user's database key (see register_user); can't proceed.
        return None
    picture = session_user.get("picture") or "assets/user.png"
    return SheamiUser(
        email=f"{email}",
        name=f"{request.username}",
        picture_url=f"{picture}",
    )


def get_loggedin_user_name(request: gr.Request):
    """Return the logged-in user's display name, or None."""
    user = get_sheami_user(request)
    return None if user is None else user.name


def get_loggedin_user_email(request: gr.Request):
    """Return the logged-in user's email, or None."""
    user = get_sheami_user(request)
    return None if user is None else user.email


async def build_securely():
    """Build the authenticated Gradio app (header, patient card, home page)."""
    with gr.Blocks(
        title=get_app_title(), theme=get_app_theme(), css=get_css(), fill_height=True
    ) as demo:
        # Top menu bar with logo and logout button / user details.
        with gr.Row():
            with gr.Column(scale=4):
                render_logo()
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Button("Logout", link="/logout", variant="huggingface")
                    logged_in_user_name = gr.Markdown(elem_classes="text-center")
                    logged_in_user_email = gr.Markdown(elem_classes="text-center")
        gr.Markdown("---")
        logged_in_sheami_user = gr.State()

        with gr.Column(elem_id="patient-card") as patient_card:
            (
                selected_patient_info,
                delete_patient_btn,
                edit_patient_btn,
                upload_reports_btn,
                add_vitals_btn,
            ) = render_selected_patient_actions()

        @gr.render(inputs=logged_in_sheami_user)
        def render_home_page(user: SheamiUser | None):
            # Re-rendered whenever the user State changes; nothing until loaded.
            if not user:
                return
            # Sync render callback: drive the async DB call on a fresh loop.
            # NOTE(review): assumes no event loop is running in this thread.
            asyncio.run(register_user(logged_in_user=user))
            build_home_page(
                logged_in_user=user,
                selected_patient_info=selected_patient_info,
                delete_patient_btn=delete_patient_btn,
                edit_patient_btn=edit_patient_btn,
                upload_reports_btn=upload_reports_btn,
                add_vitals_btn=add_vitals_btn,
            )

        demo.load(get_loggedin_user_name, inputs=None, outputs=logged_in_user_name)
        demo.load(get_sheami_user, inputs=None, outputs=logged_in_sheami_user)
        demo.load(get_loggedin_user_email, inputs=None, outputs=logged_in_user_email)
    return demo


sheami_app = asyncio.run(build_securely()).queue()
app = gr.mount_gradio_app(app, sheami_app, path="/home", auth_dependency=get_user)

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)