# wenjiao's picture
# refactor repo code
# b15b21e
"""Monkey-patch Gradio's OAuth callback to handle user denial /
errors gracefully (instead of returning HTTP 500).
Importing this module applies the patch as a side effect.
"""
from __future__ import annotations
import logging
from starlette.responses import RedirectResponse as _RedirectResponse
logger = logging.getLogger(__name__)
# Attribute set on our wrapper so a second ``apply_patch()`` call can detect
# that ``gradio.oauth._add_oauth_routes`` is already patched.
_PATCH_MARKER = "_safe_oauth_callback_patch"


def _safe_redirect_target(raw: str | None) -> str:
    """Return *raw* if it is a local absolute path, otherwise ``"/"``.

    ``_target_url`` comes straight from the callback's query string, so it is
    attacker-controllable. Only values beginning with a single ``/`` are
    followed. The following all fall back to the site root, which closes an
    open redirect on the error paths:

    * absolute URLs (``https://...``);
    * scheme-relative URLs (``//host``);
    * the backslash variant that browsers may also treat as scheme-relative
      (``/\\host``).
    """
    if not raw or not raw.startswith("/"):
        return "/"
    if raw.startswith(("//", "/\\")):
        return "/"
    return raw


def _clear_oauth_state(request) -> None:
    """Remove session entries whose key starts with ``_state_huggingface``.

    These keys presumably hold the pending OAuth ``state`` written by the
    Hugging Face OAuth client when the login started. Dropping them lets the
    next login attempt start clean.

    Clearing is best-effort: if the session cannot be accessed or modified,
    the failure is logged at debug level and otherwise ignored.
    """
    try:
        session = request.session
        stale_keys = [
            key
            for key in session
            if isinstance(key, str) and key.startswith("_state_huggingface")
        ]
        for key in stale_keys:
            session.pop(key, None)
    except Exception:
        logger.debug("[Auth] Could not clear OAuth state from session", exc_info=True)


def apply_patch() -> None:
    """Wrap ``gradio.oauth._add_oauth_routes`` to clear OAuth state and
    redirect on error instead of failing the ``/login/callback`` request.

    Once Gradio has registered its OAuth routes, the wrapper replaces the
    endpoint of the ``/login/callback`` route with a wrapper that:

    * sends the browser to ``_target_url`` when the provider reports an
      ``error`` query parameter (e.g. the user declined). The target is
      sanitised to a local path, and pending OAuth state is cleared first;
    * otherwise delegates to the original endpoint. If that raises, the
      wrapper logs the failure, clears pending OAuth state and redirects the
      same way.

    The patch is installed at most once; later calls are no-ops. Any failure
    while installing it (e.g. Gradio is not importable, or its internals have
    changed) is logged and swallowed, so importing this module never breaks
    the app.
    """
    try:
        import inspect

        import gradio.oauth as _gr_oauth
        from starlette.routing import request_response as _starlette_request_response

        _orig_add_oauth_routes = _gr_oauth._add_oauth_routes
        if getattr(_orig_add_oauth_routes, _PATCH_MARKER, False):
            # Already wrapped by an earlier call. Wrapping again would nest
            # one more callback wrapper per call.
            return

        def _patched_add_oauth_routes(app):
            _orig_add_oauth_routes(app)
            for _route in app.routes:
                if getattr(_route, "path", None) != "/login/callback":
                    continue
                _orig_endpoint = _route.endpoint

                # ``_orig`` is bound as a default argument so the closure keeps
                # the endpoint captured at patch time.
                async def _safe_oauth_callback(request, _orig=_orig_endpoint):
                    target = _safe_redirect_target(
                        request.query_params.get("_target_url")
                    )
                    err = request.query_params.get("error")
                    if err:
                        # %r keeps user-supplied values from injecting log lines.
                        logger.info(
                            "[Auth] OAuth declined (error=%r) — redirecting to %r",
                            err,
                            target,
                        )
                        _clear_oauth_state(request)
                        return _RedirectResponse(target)
                    try:
                        result = _orig(request)
                        # Gradio's endpoint is a coroutine function, but also
                        # accept a plain function returning a response.
                        if inspect.isawaitable(result):
                            result = await result
                        return result
                    except Exception as exc:  # pragma: no cover - defensive
                        logger.warning(
                            "[Auth] OAuth callback failed: %s — redirecting to %r",
                            exc,
                            target,
                            exc_info=True,
                        )
                        _clear_oauth_state(request)
                        return _RedirectResponse(target)

                _route.endpoint = _safe_oauth_callback
                # The route dispatches through ``.app``, so it must be rebuilt
                # around the new endpoint. Replacing ``.endpoint`` alone has no
                # effect on request handling.
                _route.app = _starlette_request_response(_safe_oauth_callback)
                break

        setattr(_patched_add_oauth_routes, _PATCH_MARKER, True)
        _gr_oauth._add_oauth_routes = _patched_add_oauth_routes
    except Exception as _patch_exc:  # pragma: no cover - defensive
        logger.warning("[Auth] Could not patch Gradio OAuth callback: %s", _patch_exc)
# Install the patch at import time: as the module docstring states, merely
# importing this module applies it.
apply_patch()