"""Monkey-patch Gradio's OAuth callback to handle user denial / errors gracefully (instead of returning HTTP 500). Importing this module applies the patch as a side effect. """ from __future__ import annotations import logging from starlette.responses import RedirectResponse as _RedirectResponse logger = logging.getLogger(__name__) def apply_patch() -> None: """Wrap ``gradio.oauth._add_oauth_routes`` to clear OAuth state and redirect on error.""" try: import gradio.oauth as _gr_oauth from starlette.routing import request_response as _starlette_request_response _orig_add_oauth_routes = _gr_oauth._add_oauth_routes def _patched_add_oauth_routes(app): _orig_add_oauth_routes(app) for _route in app.routes: if getattr(_route, "path", None) != "/login/callback": continue _orig_endpoint = _route.endpoint async def _safe_oauth_callback(request, _orig=_orig_endpoint): target = request.query_params.get("_target_url") or "/" err = request.query_params.get("error") def _clear_oauth_state(): try: for k in list(request.session.keys()): if k.startswith("_state_huggingface"): request.session.pop(k, None) except Exception: pass if err: logger.info( "[Auth] OAuth declined (error=%s) — redirecting to %s", err, target, ) _clear_oauth_state() return _RedirectResponse(target) try: return await _orig(request) except Exception as exc: # pragma: no cover - defensive logger.warning( "[Auth] OAuth callback failed: %s — redirecting to %s", exc, target, ) _clear_oauth_state() return _RedirectResponse(target) _route.endpoint = _safe_oauth_callback _route.app = _starlette_request_response(_safe_oauth_callback) break _gr_oauth._add_oauth_routes = _patched_add_oauth_routes except Exception as _patch_exc: # pragma: no cover - defensive logger.warning("[Auth] Could not patch Gradio OAuth callback: %s", _patch_exc) apply_patch()