Spaces:
Sleeping
Sleeping
| import ipaddress | |
| import socket | |
| from urllib.parse import urljoin, urlparse | |
| import httpx | |
| from cli_textual.tools.base import ToolResult | |
| MAX_CHARS = 8192 | |
| _BLOCKED_HOSTS = { | |
| "metadata.google.internal", | |
| "metadata.goog", | |
| "169.254.169.254", # AWS/Azure IMDS | |
| "fd00:ec2::254", # AWS IPv6 IMDS | |
| "168.63.129.16", # Azure Wireserver | |
| } | |
| def _check_url(url: str) -> tuple[str | None, str | None]: | |
| """Validate *url* and return ``(error, safe_ip)``. | |
| Returns an error string if the URL is unsafe, otherwise returns | |
| ``(None, resolved_ip)`` so the caller can pin the connection to the | |
| already-validated IP (prevents DNS-rebinding / TOCTOU attacks). | |
| """ | |
| parsed = urlparse(url) | |
| if parsed.scheme not in ("http", "https"): | |
| return f"Error: unsupported scheme '{parsed.scheme}'", None | |
| hostname = parsed.hostname | |
| if not hostname: | |
| return "Error: no hostname in URL", None | |
| if hostname in _BLOCKED_HOSTS: | |
| return f"Error: access denied — blocked host '{hostname}'", None | |
| try: | |
| safe_ip = None | |
| for info in socket.getaddrinfo(hostname, None): | |
| addr = ipaddress.ip_address(info[4][0]) | |
| if addr.is_private or addr.is_loopback or addr.is_link_local or addr.is_reserved: | |
| return "Error: access denied — private/internal IP", None | |
| if safe_ip is None: | |
| safe_ip = str(addr) | |
| if safe_ip is None: | |
| return f"Error: cannot resolve hostname '{hostname}'", None | |
| return None, safe_ip | |
| except socket.gaierror: | |
| return f"Error: cannot resolve hostname '{hostname}'", None | |
| # Keep the old name as an alias for tests that import it directly | |
| def _is_url_safe(url: str) -> str | None: | |
| err, _ = _check_url(url) | |
| return err | |
| _MAX_REDIRECTS = 5 | |
| async def _safe_get(url: str) -> httpx.Response: | |
| """GET *url* with SSRF checks on every redirect hop. | |
| Each hop resolves DNS, validates the target, and pins the connection | |
| to the resolved IP with the correct ``sni_hostname`` for TLS. | |
| """ | |
| for _ in range(_MAX_REDIRECTS): | |
| err, safe_ip = _check_url(url) | |
| if err: | |
| raise _SSRFBlocked(err) | |
| parsed = urlparse(url) | |
| original_host = parsed.hostname | |
| # Build a URL that connects to the pinned IP but preserves scheme/path/query. | |
| # IPv6 addresses need square brackets in the netloc. | |
| ip_host = f"[{safe_ip}]" if ":" in safe_ip else safe_ip | |
| pinned_url = parsed._replace(netloc=f"{ip_host}:{parsed.port}" if parsed.port else ip_host).geturl() | |
| # sni_hostname tells httpcore to use the original hostname for TLS SNI | |
| # and certificate verification instead of the pinned IP. | |
| extensions = {"sni_hostname": original_host} if parsed.scheme == "https" else {} | |
| async with httpx.AsyncClient(timeout=30) as client: | |
| response = await client.get( | |
| pinned_url, | |
| headers={"Host": original_host}, | |
| extensions=extensions, | |
| follow_redirects=False, | |
| ) | |
| if response.is_redirect: | |
| location = response.headers.get("location", "") | |
| if not location: | |
| break | |
| # Resolve relative redirects against the current URL | |
| url = urljoin(url, location) | |
| continue | |
| return response | |
| raise _SSRFBlocked("Error: too many redirects") | |
| class _SSRFBlocked(Exception): | |
| pass | |
| async def web_fetch(url: str) -> ToolResult: | |
| """Fetch a URL via HTTP GET and return the response body. | |
| Response body is capped at 8 KB. Private/internal URLs are blocked. | |
| DNS is resolved and pinned per hop to prevent rebinding attacks. | |
| """ | |
| try: | |
| response = await _safe_get(url) | |
| body = response.text | |
| truncated = "" | |
| if len(body) > MAX_CHARS: | |
| body = body[:MAX_CHARS] | |
| truncated = "\n[truncated]" | |
| return ToolResult(output=f"HTTP {response.status_code}\n{body}{truncated}") | |
| except _SSRFBlocked as exc: | |
| return ToolResult(output=str(exc), is_error=True) | |
| except Exception as exc: | |
| return ToolResult(output=f"Error fetching URL: {exc}", is_error=True) | |