PoraHobe / app /__init__.py
SpreadSheets600's picture
Handle OAuth invalid_grant gracefully behind proxy
02426f0
import os
from typing import Optional
from flask import Flask, flash, redirect, url_for
from oauthlib.oauth2.rfc6749.errors import InvalidGrantError
from werkzeug.middleware.proxy_fix import ProxyFix
# Allow OAuth scope to change (e.g. Discord adding 'guilds.join') without raising an error
os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1"
from flask_dance.consumer import oauth_authorized, oauth_error
from flask_dance.contrib.discord import make_discord_blueprint
from flask_dance.contrib.google import make_google_blueprint
from flask_login import current_user, login_user
from .extensions import db, login_manager, migrate
from .models import OAuth, User
def create_app():
app = Flask(__name__)
app.config.from_object("config.Config")
app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_port=1)
app.config.setdefault("PREFERRED_URL_SCHEME", "https")
db.init_app(app)
migrate.init_app(app, db)
login_manager.init_app(app)
login_manager.login_view = "main.login"
@login_manager.user_loader
def load_user(user_id):
return User.query.get(int(user_id))
from .blueprints.main import main_bp
from .blueprints.notes import notes_bp
from .blueprints.admin import admin_bp
app.register_blueprint(main_bp)
app.register_blueprint(notes_bp, url_prefix="/notes")
app.register_blueprint(admin_bp, url_prefix="/admin")
google_bp = make_google_blueprint(
client_id=app.config["GOOGLE_CLIENT_ID"],
client_secret=app.config["GOOGLE_CLIENT_SECRET"],
scope=[
"https://www.googleapis.com/auth/userinfo.profile",
"https://www.googleapis.com/auth/userinfo.email",
"openid",
],
reprompt_consent=True,
)
app.register_blueprint(google_bp, url_prefix="/login")
discord_bp = make_discord_blueprint(
client_id=app.config["DISCORD_CLIENT_ID"],
client_secret=app.config["DISCORD_CLIENT_SECRET"],
scope=["identify", "email"],
)
app.register_blueprint(discord_bp, url_prefix="/login")
def _finish_login(
provider_name: str,
provider_user_id: str,
email: Optional[str],
name: Optional[str],
token: dict,
):
if not email:
return False
if current_user.is_authenticated:
user = current_user
oauth = OAuth.query.filter_by(
provider=provider_name, provider_user_id=provider_user_id
).first()
if not oauth:
oauth = OAuth(
provider=provider_name,
provider_user_id=provider_user_id,
token=token,
user_id=user.id,
)
db.session.add(oauth)
else:
oauth.token = token
db.session.commit()
return redirect(url_for("main.settings"))
oauth = OAuth.query.filter_by(
provider=provider_name,
provider_user_id=provider_user_id,
).first()
if oauth:
user = oauth.user
oauth.token = token
else:
user = User.query.filter_by(email=email).first()
if not user:
user = User(name=name or email.split("@")[0], email=email)
db.session.add(user)
db.session.flush()
oauth = OAuth(
provider=provider_name,
provider_user_id=provider_user_id,
token=token,
user=user,
)
db.session.add(oauth)
db.session.commit()
login_user(user)
return redirect(url_for("main.index"))
@oauth_authorized.connect_via(google_bp)
def google_logged_in(blueprint, token):
if not token:
return False
resp = blueprint.session.get("/oauth2/v2/userinfo")
if not resp.ok:
return False
info = resp.json()
google_user_id = str(info["id"])
if "refresh_token" in token:
existing = OAuth.query.filter_by(
provider=blueprint.name, provider_user_id=google_user_id
).first()
if existing:
existing.token["refresh_token"] = token["refresh_token"]
return _finish_login(
provider_name=blueprint.name,
provider_user_id=google_user_id,
email=info.get("email"),
name=info.get("name"),
token=token,
)
@oauth_authorized.connect_via(discord_bp)
def discord_logged_in(blueprint, token):
if not token:
return False
resp = blueprint.session.get("/api/users/@me")
if not resp.ok:
return False
info = resp.json()
discord_user_id = str(info["id"])
response = _finish_login(
provider_name=blueprint.name,
provider_user_id=discord_user_id,
email=info.get("email"),
name=info.get("username"),
token=token,
)
return response
@oauth_error.connect_via(google_bp)
@oauth_error.connect_via(discord_bp)
def oauth_error_handler(blueprint, error, error_description=None, error_uri=None):
message = f"{blueprint.name.title()} OAuth error: {error_description or error}"
app.logger.error(message)
flash(
"Login failed. Please try again. If it keeps happening, recheck OAuth callback URL and server time.",
"error",
)
return redirect(url_for("main.login"))
@app.errorhandler(InvalidGrantError)
def handle_invalid_grant(error):
app.logger.warning(f"OAuth invalid_grant: {error}")
flash(
"Google/Discord sign-in expired or mismatched. Please retry login once.",
"error",
)
return redirect(url_for("main.login"))
return app