"""Egress policy for user-controlled web_fetch URLs (SSRF guard).""" from __future__ import annotations import ipaddress import socket from dataclasses import dataclass from urllib.parse import urlparse @dataclass(frozen=True, slots=True) class WebFetchEgressPolicy: """Egress rules for user-influenced web_fetch URLs.""" allow_private_network_targets: bool allowed_schemes: frozenset[str] class WebFetchEgressViolation(ValueError): """Raised when a web_fetch URL is rejected by egress policy (SSRF guard).""" def _port_for_url(parsed) -> int: if parsed.port is not None: return parsed.port return 443 if (parsed.scheme or "").lower() == "https" else 80 def _stream_getaddrinfo_or_raise(host: str, port: int) -> list[tuple]: try: return socket.getaddrinfo( host, port, type=socket.SOCK_STREAM, proto=socket.IPPROTO_TCP ) except OSError as exc: raise WebFetchEgressViolation( f"Could not resolve host {host!r}: {exc}" ) from exc def get_validated_stream_addrinfos_for_egress( url: str, policy: WebFetchEgressPolicy ) -> list[tuple]: """Resolve and validate a URL for web_fetch, returning getaddrinfo rows for pinning. Each HTTP connect pins to only these `getaddrinfo` results so a malicious DNS server cannot rebind to a disallowed address between resolution and the TCP connect (used by :func:`api.web_tools.outbound._run_web_fetch`). """ parsed = urlparse(url) scheme = (parsed.scheme or "").lower() if scheme not in policy.allowed_schemes: raise WebFetchEgressViolation( f"URL scheme {scheme!r} is not allowed for web_fetch" ) host = parsed.hostname if host is None or host == "": raise WebFetchEgressViolation("web_fetch URL must include a host") port = _port_for_url(parsed) if policy.allow_private_network_targets: return _stream_getaddrinfo_or_raise(host, port) host_lower = host.lower() if host_lower == "localhost" or host_lower.endswith(".localhost"): raise WebFetchEgressViolation("localhost targets are not allowed for web_fetch") if host_lower.endswith(".local"): raise WebFetchEgressViolation(".local hostnames are not allowed for web_fetch") try: parsed_ip = ipaddress.ip_address(host) except ValueError: parsed_ip = None if parsed_ip is not None: if not parsed_ip.is_global: raise WebFetchEgressViolation( f"Non-public IP host {host!r} is not allowed for web_fetch" ) return _stream_getaddrinfo_or_raise(host, port) infos = _stream_getaddrinfo_or_raise(host, port) for *_, sockaddr in infos: addr = sockaddr[0] try: resolved = ipaddress.ip_address(addr) except ValueError: continue if not resolved.is_global: raise WebFetchEgressViolation( f"Host {host!r} resolves to a non-public address ({resolved})" ) return infos def enforce_web_fetch_egress(url: str, policy: WebFetchEgressPolicy) -> None: """Validate ``url`` (scheme, host, and resolved addresses) for web_fetch.""" get_validated_stream_addrinfos_for_egress(url, policy)