File size: 2,724 Bytes
b15b21e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
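"""Monkey-patch Gradio's OAuth callback so that a user denying consent, or
any other error raised while handling the callback, results in a redirect
to the requested target page (or ``/``) instead of an HTTP 500 response.

Importing this module applies the patch as a side effect.
"""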

from __future__ import annotations

import logging

from starlette.responses import RedirectResponse as _RedirectResponse

logger = logging.getLogger(__name__)

def safe_redirect_target(raw: str | None, default: str = "/") -> str:
    """Return ``raw`` if it is a same-origin relative path, else ``default``.

    The OAuth callback's ``_target_url`` query parameter is controlled by
    whoever crafted the link, so redirecting to it verbatim is an open
    redirect. Only absolute paths on the current origin (e.g. ``/?x=1``,
    ``/app/page``) are accepted. Values with a scheme (``https://evil``),
    a protocol-relative prefix (``//evil``) or the backslash variant
    (``/\\evil``, which browsers treat like ``//evil``) fall back to
    ``default``.

    Args:
        raw: Candidate redirect target, typically taken straight from the
            query string. ``None`` or an empty string yields ``default``.
        default: Target used when ``raw`` is missing or unsafe.

    Returns:
        A target string that stays on the current origin.
    """
    if not raw or not raw.startswith("/"):
        return default
    if raw.startswith(("//", "/\\")):
        return default
    return raw


def apply_patch() -> None:
    """Wrap ``gradio.oauth._add_oauth_routes`` so ``/login/callback`` fails soft.

    After Gradio registers its OAuth routes, the endpoint mounted at
    ``/login/callback`` is replaced by a wrapper that:

    * if the request carries an ``error`` query parameter (e.g. the user
      declined consent), removes every ``_state_huggingface*`` session key
      and redirects to the target page;
    * otherwise awaits the original endpoint, and if that raises, logs the
      failure (with traceback), removes the same session keys and
      redirects to the target page.

    The target page is the ``_target_url`` query parameter filtered through
    :func:`safe_redirect_target`, i.e. ``/`` unless it is a same-origin
    relative path.

    Calling this again after a successful patch is a no-op. Any exception
    while installing the patch (e.g. Gradio not importable) is logged as a
    warning and otherwise ignored.
    """
    try:
        import gradio.oauth as _gr_oauth
        from starlette.routing import request_response as _starlette_request_response

        _orig_add_oauth_routes = _gr_oauth._add_oauth_routes
        if getattr(_orig_add_oauth_routes, "_safe_oauth_callback_patch", False):
            # Already wrapped by an earlier call; wrapping again would nest
            # the callback wrapper inside itself.
            return

        def _patched_add_oauth_routes(app):
            _orig_add_oauth_routes(app)
            for _route in app.routes:
                if getattr(_route, "path", None) != "/login/callback":
                    continue
                _orig_endpoint = _route.endpoint

                # ``_orig`` is bound as a default argument so the wrapper
                # keeps the endpoint found on this iteration.
                async def _safe_oauth_callback(request, _orig=_orig_endpoint):
                    target = safe_redirect_target(
                        request.query_params.get("_target_url")
                    )
                    err = request.query_params.get("error")

                    def _clear_oauth_state():
                        # Best effort: drop the ``_state_huggingface*`` keys.
                        # Accessing ``request.session`` can itself raise
                        # (presumably when no session middleware is
                        # installed), so failures are only logged.
                        try:
                            for k in list(request.session.keys()):
                                if k.startswith("_state_huggingface"):
                                    request.session.pop(k, None)
                        except Exception:
                            logger.debug(
                                "[Auth] Could not clear OAuth session state",
                                exc_info=True,
                            )

                    if err:
                        # Query values are untrusted; %r keeps embedded
                        # newlines from forging extra log lines.
                        logger.info(
                            "[Auth] OAuth declined (error=%r) — redirecting to %r",
                            err,
                            target,
                        )
                        _clear_oauth_state()
                        return _RedirectResponse(target)
                    try:
                        return await _orig(request)
                    except Exception as exc:  # pragma: no cover - defensive
                        logger.warning(
                            "[Auth] OAuth callback failed: %s — redirecting to %r",
                            exc,
                            target,
                            exc_info=True,
                        )
                        _clear_oauth_state()
                        return _RedirectResponse(target)

                _route.endpoint = _safe_oauth_callback
                # Rebuild the ASGI app for the route; replacing ``endpoint``
                # alone would not change what the router actually calls.
                _route.app = _starlette_request_response(_safe_oauth_callback)
                break

        _patched_add_oauth_routes._safe_oauth_callback_patch = True
        _gr_oauth._add_oauth_routes = _patched_add_oauth_routes
    except Exception as _patch_exc:  # pragma: no cover - defensive
        logger.warning("[Auth] Could not patch Gradio OAuth callback: %s", _patch_exc)


# Install the patch at import time, as the module docstring states;
# importing this module is the intended way to activate it.
apply_patch()