James Lindsay
feat: add safe mode for public hosting (SAFE_MODE=1)
7bf2d3e unverified
import asyncio
import ipaddress
import socket
from urllib.parse import urljoin, urlparse

import httpx

from cli_textual.tools.base import ToolResult
# Upper bound on the number of response-body characters handed back to the
# caller; longer bodies are cut off at this length.
MAX_CHARS = 8 * 1024
_BLOCKED_HOSTS = {
"metadata.google.internal",
"metadata.goog",
"169.254.169.254", # AWS/Azure IMDS
"fd00:ec2::254", # AWS IPv6 IMDS
"168.63.129.16", # Azure Wireserver
}
def _check_url(url: str) -> tuple[str | None, str | None]:
"""Validate *url* and return ``(error, safe_ip)``.
Returns an error string if the URL is unsafe, otherwise returns
``(None, resolved_ip)`` so the caller can pin the connection to the
already-validated IP (prevents DNS-rebinding / TOCTOU attacks).
"""
parsed = urlparse(url)
if parsed.scheme not in ("http", "https"):
return f"Error: unsupported scheme '{parsed.scheme}'", None
hostname = parsed.hostname
if not hostname:
return "Error: no hostname in URL", None
if hostname in _BLOCKED_HOSTS:
return f"Error: access denied — blocked host '{hostname}'", None
try:
safe_ip = None
for info in socket.getaddrinfo(hostname, None):
addr = ipaddress.ip_address(info[4][0])
if addr.is_private or addr.is_loopback or addr.is_link_local or addr.is_reserved:
return "Error: access denied — private/internal IP", None
if safe_ip is None:
safe_ip = str(addr)
if safe_ip is None:
return f"Error: cannot resolve hostname '{hostname}'", None
return None, safe_ip
except socket.gaierror:
return f"Error: cannot resolve hostname '{hostname}'", None
# Backward-compatible alias: the old name is kept for tests that import it directly.
def _is_url_safe(url: str) -> str | None:
    """Return only the error half of ``_check_url(url)`` (``None`` when the URL passes)."""
    return _check_url(url)[0]
# Maximum number of requests issued per fetch (the initial request counts).
_MAX_REDIRECTS = 5


async def _safe_get(url: str) -> httpx.Response:
    """GET *url* with SSRF checks on every redirect hop.

    Each hop resolves DNS, validates the target, and pins the connection
    to the resolved IP with the correct ``sni_hostname`` for TLS.
    Redirects are followed manually so that every ``Location`` target goes
    through the same validation as the original URL.

    Returns the first response that is not a followable redirect.

    Raises:
        _SSRFBlocked: if any hop fails validation, or the request budget
            (``_MAX_REDIRECTS``) is used up while still being redirected.
    """
    for _ in range(_MAX_REDIRECTS):
        # _check_url performs a blocking DNS lookup; run it in a worker
        # thread so the event loop stays responsive while it resolves.
        err, safe_ip = await asyncio.to_thread(_check_url, url)
        if err:
            raise _SSRFBlocked(err)

        parsed = urlparse(url)
        original_host = parsed.hostname
        port_suffix = f":{parsed.port}" if parsed.port else ""

        # Build a URL that connects to the pinned IP but preserves scheme/path/query.
        # IPv6 addresses need square brackets in the netloc.
        ip_host = f"[{safe_ip}]" if ":" in safe_ip else safe_ip
        pinned_url = parsed._replace(netloc=f"{ip_host}{port_suffix}").geturl()

        # The Host header must name the origin exactly as the URL did:
        # bracketed when it is an IPv6 literal, and with the explicit port
        # so servers on non-default ports route the request correctly.
        host_literal = f"[{original_host}]" if ":" in original_host else original_host
        host_header = f"{host_literal}{port_suffix}"

        # sni_hostname tells httpcore to use the original hostname for TLS SNI
        # and certificate verification instead of the pinned IP.
        extensions = {"sni_hostname": original_host} if parsed.scheme == "https" else {}

        async with httpx.AsyncClient(timeout=30) as client:
            response = await client.get(
                pinned_url,
                headers={"Host": host_header},
                extensions=extensions,
                follow_redirects=False,
            )

        if response.is_redirect:
            location = response.headers.get("location", "")
            if not location:
                # A 3xx without a Location header (e.g. 304) cannot be
                # followed; hand it back rather than misreporting it as
                # "too many redirects".
                return response
            # Resolve relative redirects against the current URL
            url = urljoin(url, location)
            continue
        return response

    raise _SSRFBlocked("Error: too many redirects")
class _SSRFBlocked(Exception):
    """Raised when a fetch is refused: a hop failed URL validation or redirects ran out."""
async def web_fetch(url: str) -> ToolResult:
    """Fetch *url* with an HTTP GET and report the status line plus body.

    The body is cut to at most ``MAX_CHARS`` characters, with a
    ``[truncated]`` marker appended when anything was dropped.  The request
    goes through ``_safe_get``, so private/internal targets are refused and
    DNS is resolved and pinned on every hop.

    Never raises: a blocked request or any other failure is returned as a
    ``ToolResult`` with ``is_error=True``.
    """
    try:
        resp = await _safe_get(url)
        text = resp.text
        overflow = len(text) > MAX_CHARS
        shown = text[:MAX_CHARS] if overflow else text
        marker = "\n[truncated]" if overflow else ""
        return ToolResult(output=f"HTTP {resp.status_code}\n{shown}{marker}")
    except _SSRFBlocked as blocked:
        # Validation failures already carry a user-facing message.
        return ToolResult(output=str(blocked), is_error=True)
    except Exception as failure:
        # Tool boundary: surface any other failure as an error result.
        return ToolResult(output=f"Error fetching URL: {failure}", is_error=True)