"""Web-fetch tool: HTTP GET with SSRF protections.

URLs are checked against a host blocklist and private/internal IP ranges,
and each request is pinned to the IP address that passed validation.
"""

import ipaddress
import socket
from urllib.parse import urljoin, urlparse

import httpx
from cli_textual.tools.base import ToolResult

# Maximum number of characters of a response body that web_fetch returns;
# longer bodies are cut to this length and marked as truncated.
MAX_CHARS = 8192

_BLOCKED_HOSTS = {
    "metadata.google.internal",
    "metadata.goog",
    "169.254.169.254",      # AWS/Azure IMDS
    "fd00:ec2::254",        # AWS IPv6 IMDS
    "168.63.129.16",        # Azure Wireserver
}


def _check_url(url: str) -> tuple[str | None, str | None]:
    """Validate *url* and return ``(error, safe_ip)``.

    Returns an error string if the URL is unsafe, otherwise returns
    ``(None, resolved_ip)`` so the caller can pin the connection to the
    already-validated IP (prevents DNS-rebinding / TOCTOU attacks).
    """
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https"):
        return f"Error: unsupported scheme '{parsed.scheme}'", None
    hostname = parsed.hostname
    if not hostname:
        return "Error: no hostname in URL", None
    if hostname in _BLOCKED_HOSTS:
        return f"Error: access denied — blocked host '{hostname}'", None
    try:
        safe_ip = None
        for info in socket.getaddrinfo(hostname, None):
            addr = ipaddress.ip_address(info[4][0])
            if addr.is_private or addr.is_loopback or addr.is_link_local or addr.is_reserved:
                return "Error: access denied — private/internal IP", None
            if safe_ip is None:
                safe_ip = str(addr)
        if safe_ip is None:
            return f"Error: cannot resolve hostname '{hostname}'", None
        return None, safe_ip
    except socket.gaierror:
        return f"Error: cannot resolve hostname '{hostname}'", None


# Keep the old name as an alias for tests that import it directly
def _is_url_safe(url: str) -> str | None:
    """Return the validation error for *url*, or ``None`` if it passed.

    Thin wrapper around ``_check_url`` that discards the resolved IP and
    keeps only the error element.
    """
    return _check_url(url)[0]


# Upper bound on the number of requests _safe_get issues for a single fetch
# (the initial request plus any redirects it follows).
_MAX_REDIRECTS = 5


async def _safe_get(url: str) -> httpx.Response:
    """GET *url* with SSRF checks on every redirect hop.

    Each hop resolves DNS, validates the target via ``_check_url``, and pins
    the connection to the resolved IP with the correct ``sni_hostname`` for
    TLS.  Redirects are followed manually so that every new target is
    validated before it is contacted.

    Returns the first response that is not a followable redirect; a 3xx
    response without a ``Location`` header (e.g. 304) is returned as-is.

    Raises:
        _SSRFBlocked: if any hop fails validation, or if the hop limit
            (``_MAX_REDIRECTS`` requests) is exhausted.
        ValueError: if the URL carries an invalid port.
    """
    default_ports = {"http": 80, "https": 443}

    for _ in range(_MAX_REDIRECTS):
        err, safe_ip = _check_url(url)
        if err:
            raise _SSRFBlocked(err)

        parsed = urlparse(url)
        original_host = parsed.hostname
        port = parsed.port

        # Build a URL that connects to the pinned IP but preserves scheme/path/query.
        # IPv6 addresses need square brackets in the netloc.
        ip_host = f"[{safe_ip}]" if ":" in safe_ip else safe_ip
        pinned_netloc = f"{ip_host}:{port}" if port else ip_host
        pinned_url = parsed._replace(netloc=pinned_netloc).geturl()

        # The Host header must name the original authority: bracket IPv6
        # literals and keep a non-default port, otherwise virtual-host
        # routing and server-generated redirects can go wrong.
        host_header = f"[{original_host}]" if ":" in original_host else original_host
        if port and port != default_ports.get(parsed.scheme):
            host_header = f"{host_header}:{port}"

        # sni_hostname tells httpcore to use the original hostname for TLS SNI
        # and certificate verification instead of the pinned IP.
        extensions = {"sni_hostname": original_host} if parsed.scheme == "https" else {}

        async with httpx.AsyncClient(timeout=30) as client:
            response = await client.get(
                pinned_url,
                headers={"Host": host_header},
                extensions=extensions,
                follow_redirects=False,
            )

        location = response.headers.get("location", "")
        if not response.is_redirect or not location:
            # Either a final response, or a 3xx with nowhere to go; hand it
            # back rather than mis-reporting it as a redirect loop.
            return response

        # Resolve relative redirects against the current URL; the new target
        # is validated at the top of the next iteration.
        url = urljoin(url, location)

    raise _SSRFBlocked("Error: too many redirects")


class _SSRFBlocked(Exception):
    """Raised when a fetch is refused by URL validation or hits the redirect limit."""


async def web_fetch(url: str) -> ToolResult:
    """Fetch a URL via HTTP GET and return the status code and response body.

    The body is capped at ``MAX_CHARS`` characters; a ``[truncated]`` marker
    is appended when it was cut.  Private/internal URLs are blocked, and DNS
    is resolved and pinned per hop to prevent rebinding attacks.

    Never raises: a blocked URL or any fetch failure is returned as a
    ``ToolResult`` with ``is_error=True``.
    """
    try:
        response = await _safe_get(url)
        body = response.text
        truncated = ""
        if len(body) > MAX_CHARS:
            body = body[:MAX_CHARS]
            truncated = "\n[truncated]"
        return ToolResult(output=f"HTTP {response.status_code}\n{body}{truncated}")
    except _SSRFBlocked as exc:
        # The exception text is already a complete, user-facing message.
        return ToolResult(output=str(exc), is_error=True)
    except Exception as exc:
        # Tool boundary: report every failure instead of propagating.  Some
        # exceptions (notably timeouts) carry an empty message, so fall back
        # to the class name to keep the report informative.
        detail = str(exc) or type(exc).__name__
        return ToolResult(output=f"Error fetching URL: {detail}", is_error=True)