Spaces:
Running
Running
File size: 3,287 Bytes
"""Egress policy for user-controlled web_fetch URLs (SSRF guard)."""
from __future__ import annotations
import ipaddress
import socket
from dataclasses import dataclass
from urllib.parse import urlparse
@dataclass(frozen=True, slots=True)
class WebFetchEgressPolicy:
"""Egress rules for user-influenced web_fetch URLs."""
allow_private_network_targets: bool
allowed_schemes: frozenset[str]
class WebFetchEgressViolation(ValueError):
    """Signal that the egress policy (SSRF guard) rejected a web_fetch URL.

    Subclasses ``ValueError``, so handlers for ``ValueError`` also catch it.
    """
def _port_for_url(parsed) -> int:
    """Return the port to connect to for an already-parsed URL.

    An explicit port on ``parsed`` wins (including ``0``). Otherwise the
    default is 443 when the scheme is ``https`` (case-insensitive) and 80 for
    every other scheme, including a missing one.
    """
    explicit_port = parsed.port
    if explicit_port is None:
        uses_tls = (parsed.scheme or "").lower() == "https"
        return 443 if uses_tls else 80
    return explicit_port
def _stream_getaddrinfo_or_raise(host: str, port: int) -> list[tuple]:
    """Resolve ``host``/``port`` to TCP stream ``getaddrinfo`` rows.

    Args:
        host: Hostname or IP literal to resolve.
        port: TCP port passed through to ``socket.getaddrinfo``.

    Returns:
        The rows returned by ``socket.getaddrinfo`` when restricted to
        ``SOCK_STREAM`` / ``IPPROTO_TCP``.

    Raises:
        WebFetchEgressViolation: If resolution fails, or if the host name
            cannot be IDNA-encoded (for example an empty label or a label
            longer than 63 characters).
    """
    try:
        return socket.getaddrinfo(
            host, port, type=socket.SOCK_STREAM, proto=socket.IPPROTO_TCP
        )
    except (OSError, UnicodeError) as exc:
        # getaddrinfo IDNA-encodes str hosts and reports malformed labels as
        # UnicodeError rather than OSError; treat that as a rejected host too
        # so callers only ever see the policy exception for bad hosts.
        raise WebFetchEgressViolation(
            f"Could not resolve host {host!r}: {exc}"
        ) from exc
def get_validated_stream_addrinfos_for_egress(
    url: str, policy: WebFetchEgressPolicy
) -> list[tuple]:
    """Resolve and validate a URL for web_fetch, returning getaddrinfo rows for pinning.

    Each HTTP connect pins to only these `getaddrinfo` results so a malicious DNS
    server cannot rebind to a disallowed address between resolution and the TCP
    connect (used by :func:`api.web_tools.outbound._run_web_fetch`).

    Args:
        url: The user-influenced URL to check.
        policy: Egress rules (allowed schemes, whether private targets are
            permitted).

    Returns:
        The ``getaddrinfo`` rows for the URL's host and port. When
        ``policy.allow_private_network_targets`` is true they are returned
        without the localhost/.local/public-address checks.

    Raises:
        WebFetchEgressViolation: If the URL is malformed, uses a disallowed
            scheme, has no host, has an invalid port, names a localhost or
            ``.local`` host, is or resolves to a non-global address, resolves
            to an address that cannot be parsed, or cannot be resolved.
    """
    # urlparse rejects some malformed netlocs (e.g. unbalanced "[") with a bare
    # ValueError; surface that as a policy rejection instead.
    try:
        parsed = urlparse(url)
    except ValueError as exc:
        raise WebFetchEgressViolation(
            f"Malformed web_fetch URL: {exc}"
        ) from exc

    scheme = (parsed.scheme or "").lower()
    if scheme not in policy.allowed_schemes:
        raise WebFetchEgressViolation(
            f"URL scheme {scheme!r} is not allowed for web_fetch"
        )

    host = parsed.hostname
    if not host:
        raise WebFetchEgressViolation("web_fetch URL must include a host")

    # Reading the port raises ValueError for non-numeric or out-of-range
    # values; report those as policy rejections as well.
    try:
        port = _port_for_url(parsed)
    except ValueError as exc:
        raise WebFetchEgressViolation(
            f"Invalid port in web_fetch URL: {exc}"
        ) from exc

    if policy.allow_private_network_targets:
        return _stream_getaddrinfo_or_raise(host, port)

    # Strip the DNS root dot so "localhost." or "printer.local." cannot slip
    # past the name-based checks below.
    name = host.lower().rstrip(".")
    if name == "localhost" or name.endswith(".localhost"):
        raise WebFetchEgressViolation("localhost targets are not allowed for web_fetch")
    if name.endswith(".local"):
        raise WebFetchEgressViolation(".local hostnames are not allowed for web_fetch")

    try:
        literal_ip = ipaddress.ip_address(host)
    except ValueError:
        literal_ip = None

    if literal_ip is not None:
        # The host is already an IP literal: judge it directly.
        if not literal_ip.is_global:
            raise WebFetchEgressViolation(
                f"Non-public IP host {host!r} is not allowed for web_fetch"
            )
        return _stream_getaddrinfo_or_raise(host, port)

    infos = _stream_getaddrinfo_or_raise(host, port)
    for *_, sockaddr in infos:
        addr = sockaddr[0]
        try:
            resolved = ipaddress.ip_address(addr)
        except ValueError as exc:
            # Fail closed: an address we cannot classify must not be handed
            # back to the caller for pinning.
            raise WebFetchEgressViolation(
                f"Host {host!r} resolved to an unrecognized address ({addr!r})"
            ) from exc
        if not resolved.is_global:
            raise WebFetchEgressViolation(
                f"Host {host!r} resolves to a non-public address ({resolved})"
            )
    return infos
def enforce_web_fetch_egress(url: str, policy: WebFetchEgressPolicy) -> None:
    """Check ``url`` against ``policy`` and discard the resolved addresses.

    Runs the scheme, host, and resolved-address validation performed by
    :func:`get_validated_stream_addrinfos_for_egress`; any exception it raises
    propagates unchanged.
    """
    _ = get_validated_stream_addrinfos_for_egress(url, policy)
    return None
|