| """Monkey-patch Gradio's OAuth callback to handle user denial / |
| errors gracefully (instead of returning HTTP 500). |
| |
| Importing this module applies the patch as a side effect. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import logging |
|
|
| from starlette.responses import RedirectResponse as _RedirectResponse |
|
|
# Module-scoped logger; records are emitted under this module's dotted name.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
# Attribute set on our replacement hook so repeated apply_patch() calls can
# detect that the hook is already wrapped.
_PATCH_MARKER = "_oauth_error_patch_applied"


def _safe_redirect_target(raw: str | None) -> str:
    """Return *raw* if it is a local absolute path, otherwise ``"/"``.

    ``_target_url`` is read from the callback's query string, so it is
    attacker-controlled. Only paths on this host (``/...``) are accepted.
    Empty values, absolute URLs (``https://...``), scheme-relative URLs
    (``//host``) and the backslash variant (``/\\host``) all fall back to the
    root page. This prevents the error paths below from being an open redirect.
    """
    if not raw or not raw.startswith("/") or raw.startswith(("//", "/\\")):
        return "/"
    return raw


def _clear_oauth_state(request) -> None:
    """Best-effort removal of pending OAuth state keys from the session.

    Drops every session key that starts with ``_state_huggingface`` so that a
    retried login starts clean. Any error while reading or mutating the
    session is logged at DEBUG and otherwise ignored, because the caller
    still wants to redirect rather than fail.
    """
    try:
        session = request.session
        # Materialise the key list first: we mutate the session while looping.
        for key in [k for k in session if k.startswith("_state_huggingface")]:
            session.pop(key, None)
    except Exception:
        logger.debug("[Auth] Could not clear OAuth state from session", exc_info=True)


def _make_safe_callback(orig_endpoint):
    """Wrap an async OAuth callback endpoint so that errors redirect instead of raising.

    Args:
        orig_endpoint: The original async ``/login/callback`` endpoint that
            takes a request.

    Returns:
        An async endpoint with the same call shape.
    """

    async def _safe_oauth_callback(request):
        target = _safe_redirect_target(request.query_params.get("_target_url"))
        err = request.query_params.get("error")

        if err:
            # The provider reported an error (e.g. the user declined), so skip
            # the original callback entirely.
            logger.info(
                "[Auth] OAuth declined (error=%s) — redirecting to %s",
                err,
                target,
            )
            _clear_oauth_state(request)
            return _RedirectResponse(target)

        try:
            return await orig_endpoint(request)
        except Exception as exc:
            # Keep the traceback in the logs: the user now sees a redirect
            # instead of a 500, so this is the only record of the failure.
            logger.warning(
                "[Auth] OAuth callback failed: %s — redirecting to %s",
                exc,
                target,
                exc_info=True,
            )
            _clear_oauth_state(request)
            return _RedirectResponse(target)

    return _safe_oauth_callback


def apply_patch() -> None:
    """Wrap ``gradio.oauth._add_oauth_routes`` to clear OAuth state and redirect on error.

    After Gradio registers its OAuth routes, the endpoint of the
    ``/login/callback`` route is replaced with a wrapper that:

    * redirects without running the original callback when the request carries
      an ``error`` query parameter;
    * redirects instead of propagating any exception the original callback raises.

    In both cases, session keys starting with ``_state_huggingface`` are
    cleared. The redirect goes to ``_target_url`` when it is a local path,
    otherwise to ``/``.

    Calling this more than once is safe: the hook is only wrapped once. Any
    failure while installing the patch (Gradio not installed, private API
    changed) is logged as a warning and swallowed.
    """
    try:
        import gradio.oauth as _gr_oauth
        from starlette.routing import request_response as _starlette_request_response

        _orig_add_oauth_routes = _gr_oauth._add_oauth_routes
        if getattr(_orig_add_oauth_routes, _PATCH_MARKER, False):
            # Already patched; wrapping again would double-wrap the callback.
            return

        def _patched_add_oauth_routes(app):
            _orig_add_oauth_routes(app)
            for route in app.routes:
                if getattr(route, "path", None) != "/login/callback":
                    continue
                safe_endpoint = _make_safe_callback(route.endpoint)
                route.endpoint = safe_endpoint
                # Starlette's Route.handle() dispatches through ``route.app``,
                # so the ASGI app must be rebuilt from the new endpoint too.
                route.app = _starlette_request_response(safe_endpoint)
                break
            else:
                logger.warning(
                    "[Auth] No /login/callback route found; OAuth error handling not patched"
                )

        setattr(_patched_add_oauth_routes, _PATCH_MARKER, True)
        _gr_oauth._add_oauth_routes = _patched_add_oauth_routes
    except Exception as _patch_exc:
        logger.warning("[Auth] Could not patch Gradio OAuth callback: %s", _patch_exc)
|
|
|
|
# Import-time side effect: importing this module installs the patch
# (as stated in the module docstring).
apply_patch()
|
|