"""
WebSocket helpers for reverse interfaces.

Provides a proxy-aware (HTTP or SOCKS) aiohttp WebSocket client.
"""
import ssl
import certifi
import aiohttp
from aiohttp_socks import ProxyConnector
from typing import Mapping, Optional, Any
from urllib.parse import urlparse
from app.core.logger import logger
from app.core.config import get_config
def _default_ssl_context() -> ssl.SSLContext:
    """Build a client-side SSL context trusting the platform CAs plus certifi's bundle.

    Returns:
        ssl.SSLContext: A verifying context from ``ssl.create_default_context()``
        with the certifi CA file loaded on top of the default trust store.
    """
    ctx = ssl.create_default_context()
    # Extra CAs are added, not substituted: the defaults loaded above stay in place.
    ctx.load_verify_locations(cafile=certifi.where())
    return ctx
def _normalize_socks_proxy(proxy_url: str) -> tuple[str, Optional[bool]]:
scheme = urlparse(proxy_url).scheme.lower()
rdns: Optional[bool] = None
base_scheme = scheme
if scheme == "socks5h":
base_scheme = "socks5"
rdns = True
elif scheme == "socks4a":
base_scheme = "socks4"
rdns = True
if base_scheme != scheme:
proxy_url = proxy_url.replace(f"{scheme}://", f"{base_scheme}://", 1)
return proxy_url, rdns
def resolve_proxy(
    proxy_url: Optional[str] = None,
    ssl_context: Optional[ssl.SSLContext] = None,
) -> tuple[aiohttp.BaseConnector, Optional[str]]:
    """Resolve proxy connector.

    SOCKS proxies (any scheme starting with ``socks``) are handled by an
    ``aiohttp_socks.ProxyConnector``, so no per-request proxy URL is returned
    for them. Every other scheme is treated as an HTTP proxy. For those, a plain
    ``TCPConnector`` is returned together with the URL, which the caller must
    pass as ``proxy=`` on each request.

    Args:
        proxy_url: Optional[str], the proxy URL. Defaults to None (direct connection).
        ssl_context: Optional[ssl.SSLContext], the SSL context. Defaults to None,
            in which case ``_default_ssl_context()`` is built for this call.

    Returns:
        tuple[aiohttp.BaseConnector, Optional[str]]: The connector, and the proxy
        URL to pass per request (None for direct or SOCKS connections).
    """
    # Built lazily. A default argument would be evaluated once at import time,
    # and that single mutable SSLContext would be shared by every caller.
    if ssl_context is None:
        ssl_context = _default_ssl_context()
    if not proxy_url:
        return aiohttp.TCPConnector(ssl=ssl_context), None

    parsed = urlparse(proxy_url)
    # Log only scheme://host[:port]; never log user:pass@ credentials.
    safe_url = f"{parsed.scheme}://{parsed.netloc.rpartition('@')[2]}"
    if parsed.scheme.lower().startswith("socks"):
        normalized, rdns = _normalize_socks_proxy(proxy_url)
        logger.info(f"Using SOCKS proxy: {safe_url}")
        if rdns is not None:
            try:
                return (
                    ProxyConnector.from_url(normalized, rdns=rdns, ssl=ssl_context),
                    None,
                )
            except TypeError:
                # Presumably an aiohttp_socks version that does not accept
                # ``rdns``. Fall through and retry without it.
                pass
        return ProxyConnector.from_url(normalized, ssl=ssl_context), None

    logger.info(f"Using HTTP proxy: {safe_url}")
    return aiohttp.TCPConnector(ssl=ssl_context), proxy_url
class WebSocketConnection:
    """Owns an aiohttp session together with the WebSocket opened on it.

    ``close()`` closes the WebSocket (if still open) and then the session.
    When used as an async context manager, it yields the raw
    ``ClientWebSocketResponse`` rather than this wrapper.
    """

    def __init__(self, session: aiohttp.ClientSession, ws: aiohttp.ClientWebSocketResponse) -> None:
        self.session = session
        self.ws = ws

    async def close(self) -> None:
        """Close the WebSocket, then the session.

        The session close runs in ``finally``, so the session is released
        even when closing the WebSocket raises or the task is cancelled.
        """
        try:
            if not self.ws.closed:
                await self.ws.close()
        finally:
            await self.session.close()

    async def __aenter__(self) -> aiohttp.ClientWebSocketResponse:
        """Return the underlying WebSocket response (not this wrapper)."""
        return self.ws

    async def __aexit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc: Optional[BaseException],
        tb: Any,
    ) -> None:
        """Close everything on context exit; exceptions are not suppressed."""
        await self.close()
class WebSocketClient:
    """WebSocket client with HTTP/SOCKS proxy support.

    Uses the constructor's proxy override when one is given. Otherwise it reads
    ``proxy.base_proxy_url`` from config on every :meth:`connect` call, so
    config changes apply to new connections.
    """

    def __init__(self, proxy: Optional[str] = None) -> None:
        """Create a client.

        Args:
            proxy: Optional[str], proxy URL overriding the configured
                ``proxy.base_proxy_url``. Defaults to None (use config).
        """
        self._proxy_override = proxy
        # One SSL context, reused for every connection made by this client.
        self._ssl_context = _default_ssl_context()

    async def connect(
        self,
        url: str,
        headers: Optional[Mapping[str, str]] = None,
        timeout: Optional[float] = None,
        ws_kwargs: Optional[Mapping[str, object]] = None,
    ) -> WebSocketConnection:
        """Connect to the WebSocket.

        Args:
            url: str, the URL to connect to.
            headers: Optional[Mapping[str, str]], the headers to send. Defaults to None.
            timeout: Optional[float], total session timeout in seconds. Defaults to
                the ``voice.timeout`` config value, or 120 when that is unset/falsy.
            ws_kwargs: Optional[Mapping[str, object]], extra ws_connect kwargs. Defaults to None.

        Returns:
            WebSocketConnection: The WebSocket connection; the caller must close it.

        Raises:
            Whatever ``session.ws_connect`` raises (including cancellation). The
            session is closed before the exception propagates.
        """
        # Compute the timeout first. A bad config value then fails before any
        # connector or session exists, so nothing is leaked.
        total_timeout = (
            float(timeout)
            if timeout is not None
            else float(get_config("voice.timeout") or 120)
        )
        client_timeout = aiohttp.ClientTimeout(total=total_timeout)

        # Resolve proxy dynamically from config if not overridden.
        proxy_url = self._proxy_override or get_config("proxy.base_proxy_url")
        connector, resolved_proxy = resolve_proxy(proxy_url, self._ssl_context)
        logger.debug(
            f"WebSocket connect: proxy={self._describe_proxy(proxy_url)}, "
            f"http_proxy={resolved_proxy is not None}, "
            f"connector={type(connector).__name__}"
        )

        # By default (connector_owner=True) ClientSession closes this connector
        # when the session itself is closed.
        session = aiohttp.ClientSession(connector=connector, timeout=client_timeout)
        try:
            # Cast to Any to avoid Pylance errors with **extra_kwargs
            extra_kwargs: dict[str, Any] = dict(ws_kwargs or {})
            ws = await session.ws_connect(
                url,
                headers=headers,
                proxy=resolved_proxy,
                ssl=self._ssl_context,
                **extra_kwargs,
            )
        except BaseException:
            # BaseException so that asyncio.CancelledError (not an Exception
            # subclass since Python 3.8) also releases the session.
            await session.close()
            raise
        return WebSocketConnection(session, ws)

    @staticmethod
    def _describe_proxy(proxy_url: Optional[str]) -> Optional[str]:
        """Return ``scheme://host[:port]`` for *proxy_url*, without credentials, for logging."""
        if not proxy_url:
            return None
        parsed = urlparse(proxy_url)
        return f"{parsed.scheme}://{parsed.netloc.rpartition('@')[2]}"
# Names exported by ``from <module> import *``.
__all__ = [
    "WebSocketClient",
    "WebSocketConnection",
    "resolve_proxy",
]
|