# sheami / app.py
# Last commit 8b14696 by vikramvasudevan: "Update app.py" (verified)
import asyncio
import os
import secrets

import gradio as gr
import uvicorn
from authlib.integrations.starlette_client import OAuth, OAuthError
from dotenv import load_dotenv
from fastapi import FastAPI, Depends
from fastapi import Request
from starlette.config import Config
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import RedirectResponse

from common import get_db
from home import build_home_page
from modules.models import SheamiUser
from ui import (
    get_app_theme,
    get_app_title,
    get_css,
    render_about_markdowns,
    render_logo,
    render_logo_small,
    render_selected_patient_actions,
)
load_dotenv()
app = FastAPI()

# OAuth settings (read from the environment, optionally populated from .env).
GOOGLE_OAUTH_CLIENT_ID = os.environ.get("GOOGLE_OAUTH_CLIENT_ID")
GOOGLE_OAUTH_CLIENT_SECRET = os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET")
SECRET_KEY = os.environ.get("AUTH_SECRET_KEY")
if not SECRET_KEY:
    # SessionMiddleware stringifies its key, so a missing value would sign
    # session cookies with the literal "None" and let anyone forge a login.
    # Fall back to a random per-process key instead: sessions remain
    # unforgeable but do not survive a restart (or span multiple workers).
    print("AUTH_SECRET_KEY is not set; using a random per-process session key")
    SECRET_KEY = secrets.token_urlsafe(32)

# Set up OAuth. Authlib looks up "<NAME>_CLIENT_ID" / "<NAME>_CLIENT_SECRET"
# in this config for the client registered below as "google".
config_data = {'GOOGLE_CLIENT_ID': GOOGLE_OAUTH_CLIENT_ID, 'GOOGLE_CLIENT_SECRET': GOOGLE_OAUTH_CLIENT_SECRET}
starlette_config = Config(environ=config_data)
oauth = OAuth(starlette_config)
oauth.register(
    name='google',
    server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
    client_kwargs={'scope': 'openid email profile'},
)

# Signed-cookie sessions; /auth stores the Google userinfo here.
app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY)
# Dependency to get the current user
def get_user(request: Request):
    """Return the logged-in user's display name from the session, or None.

    Used as a FastAPI dependency for ``/`` and as gradio's
    ``auth_dependency`` for the ``/home`` app. Reads the ``user`` dict that
    the ``/auth`` callback stores in the session.
    """
    user = request.session.get('user')
    if not user:
        return None
    # Fall back to the email when the claims carry no "name", rather than
    # raising KeyError (a 500) on every request for that session.
    return user.get('name') or user.get('email')
@app.get('/')
def public(request: Request, user = Depends(get_user)):
    """Route the site root to ``/home/`` for logged-in users, else ``/main/``."""
    base = gr.route_utils.get_root_url(request, "/", None)
    target = "home" if user else "main"
    return RedirectResponse(url=f'{base}/{target}/')
@app.route('/logout')
async def logout(request: Request):
    """Drop the user from the session (if present) and go back to ``/``."""
    session = request.session
    session.pop('user', None)
    return RedirectResponse(url='/')
@app.route('/login')
async def login(request: Request):
    """Start the Google OAuth flow with ``<root>/auth`` as the callback URL."""
    callback_url = "{}/auth".format(
        gr.route_utils.get_root_url(request, "/login", None)
    )
    print("Redirecting to", callback_url)
    return await oauth.google.authorize_redirect(request, callback_url)
@app.route('/auth')
async def auth(request: Request):
    """OAuth callback: exchange the code for tokens and record the user.

    On success the token's ``userinfo`` claims are stored in
    ``request.session['user']`` and the browser is sent to ``/home``. On any
    failure the browser is sent back to ``/``.
    """
    try:
        access_token = await oauth.google.authorize_access_token(request)
    except OAuthError as err:
        # Log the caught exception instance (its message), not the class.
        print("Error getting access token", str(err))
        return RedirectResponse(url='/')
    userinfo = dict(access_token).get("userinfo")
    if not userinfo:
        # Without userinfo claims there is no identity to log in with;
        # avoid a KeyError (500) and restart from the landing page.
        print("No userinfo in token response; redirecting to /")
        return RedirectResponse(url='/')
    request.session['user'] = userinfo
    print("Redirecting to /home")
    return RedirectResponse(url='/home')
# Public landing page mounted at /main: small logo, "about" text and a
# "Proceed" button that starts the login flow.
with gr.Blocks() as login_demo:
    render_logo_small()
    # Empty gr.Column()s flank the middle content, presumably as spacers to
    # centre it.
    with gr.Row():
        gr.Column()
        render_about_markdowns()
        gr.Column()
    with gr.Row():
        gr.Column()
        btn = gr.Button("Proceed", variant="huggingface", scale=0)
        gr.Column()
    # Client-side click handler: opens /login (carrying over the current
    # query string) in a new browser tab.
    _js_redirect = """
() => {
url = '/login' + window.location.search;
window.open(url, '_blank');
}
"""
    btn.click(None, js=_js_redirect)
app = gr.mount_gradio_app(app, login_demo, path="/main")
async def register_user(logged_in_user: SheamiUser):
    """Add *logged_in_user* to the database unless their email already exists."""
    existing = await get_db().get_user_by_email(email=logged_in_user.email)
    if existing:
        return
    await get_db().add_user(email=logged_in_user.email, name=logged_in_user.name)
def get_sheami_user(request: gr.Request):
    """Build a SheamiUser for the current gradio request.

    Reads the Google userinfo dict that ``/auth`` stored in the session. The
    name comes from ``request.username`` (set by gradio from the app's
    auth dependency).

    Returns:
        None when gradio supplies no request, otherwise a SheamiUser.

    Raises:
        KeyError: if the session has no ``user`` entry or it has no ``email``.
    """
    if request is None:
        return None
    session_user = request.request.session["user"]
    try:
        picture = str(session_user["picture"])
    except (KeyError, TypeError):
        # No avatar in the claims: use the bundled placeholder image.
        picture = "assets/user.png"
    # str() instead of f"{...["..."]}" keeps this parseable before Python 3.12.
    return SheamiUser(
        email=str(session_user["email"]),
        name=str(request.username),
        picture_url=picture,
    )
def get_loggedin_user_name(request: gr.Request):
    """Gradio load handler: the logged-in user's display name, or None."""
    sheami_user = get_sheami_user(request)
    return None if sheami_user is None else sheami_user.name
def get_loggedin_user_email(request: gr.Request):
    """Gradio load handler: the logged-in user's email address, or None."""
    sheami_user = get_sheami_user(request)
    return None if sheami_user is None else sheami_user.email
async def build_securely():
    """Build the authenticated home app (mounted at /home) and return it.

    The layout is a top bar (logo, logout button, user name/email), a
    selected-patient action card, and a dynamically rendered home page that
    appears once the logged-in user has been loaded into state.
    """
    with gr.Blocks(
        title=get_app_title(), theme=get_app_theme(), css=get_css(), fill_height=True
    ) as demo:
        # Top bar: logo, plus a logout button and the logged-in user's name/email
        with gr.Row():
            with gr.Column(scale=4):
                render_logo()
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Button("Logout", link="/logout", variant="huggingface")
                    # Filled by the demo.load handlers below.
                    logged_in_user_name = gr.Markdown(elem_classes="text-center")
                    logged_in_user_email = gr.Markdown(elem_classes="text-center")
        gr.Markdown("---")
        # Holds the SheamiUser returned by get_sheami_user on page load.
        logged_in_sheami_user = gr.State()
        with gr.Column(elem_id="patient-card") as patient_card:
            # gr.Markdown("### Selected Patient")
            (
                selected_patient_info,
                delete_patient_btn,
                edit_patient_btn,
                upload_reports_btn,
                add_vitals_btn,
            ) = render_selected_patient_actions()

        # Re-rendered whenever logged_in_sheami_user changes.
        @gr.render(inputs=logged_in_sheami_user)
        def render_home_page(user: SheamiUser | None):
            if user:
                # Ensure the user has a DB record before building the page.
                # NOTE(review): asyncio.run only works if no event loop is
                # running in this thread; presumably gradio runs this sync
                # render function in a worker thread — confirm.
                asyncio.run(register_user(logged_in_user=user))
                build_home_page(
                    logged_in_user=user,
                    selected_patient_info=selected_patient_info,
                    delete_patient_btn=delete_patient_btn,
                    edit_patient_btn=edit_patient_btn,
                    upload_reports_btn=upload_reports_btn,
                    add_vitals_btn=add_vitals_btn,
                )
            else:
                pass

        # On page load, populate the header fields and the user state (each
        # handler reads the session via get_sheami_user).
        demo.load(get_loggedin_user_name, inputs=None, outputs=logged_in_user_name)
        demo.load(get_sheami_user, inputs=None, outputs=logged_in_sheami_user)
        demo.load(get_loggedin_user_email, inputs=None, outputs=logged_in_user_email)
        # demo.load(list_organizations, inputs=None, outputs=m2)
    return demo
# Build the authenticated app once at import time and enable its queue.
sheami_app = asyncio.run(build_securely())
sheami_app = sheami_app.queue()

# Mount it under /home; gradio consults get_user for each request to decide
# whether (and as whom) the visitor is logged in.
app = gr.mount_gradio_app(
    app,
    sheami_app,
    path="/home",
    auth_dependency=get_user,
)

if __name__ == '__main__':
    uvicorn.run(app, host="0.0.0.0", port=7860)