File size: 4,810 Bytes
8059bf0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 | package urlvalidator
import (
"context"
"errors"
"fmt"
"net"
"net/url"
"strconv"
"strings"
"time"
)
type ValidationOptions struct {
AllowedHosts []string
RequireAllowlist bool
AllowPrivate bool
}
// ValidateHTTPURL validates an outbound HTTP/HTTPS URL.
//
// It provides a single validation entry point that supports:
// - scheme 校验(https 或可选允许 http)
// - 可选 allowlist(支持 *.example.com 通配)
// - allow_private_hosts 策略(阻断 localhost/私网字面量 IP)
//
// 注意:DNS Rebinding 防护(解析后 IP 校验)应在实际发起请求时执行,避免 TOCTOU。
func ValidateHTTPURL(raw string, allowInsecureHTTP bool, opts ValidationOptions) (string, error) {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return "", errors.New("url is required")
}
parsed, err := url.Parse(trimmed)
if err != nil || parsed.Scheme == "" || parsed.Host == "" {
return "", fmt.Errorf("invalid url: %s", trimmed)
}
scheme := strings.ToLower(parsed.Scheme)
if scheme != "https" && (!allowInsecureHTTP || scheme != "http") {
return "", fmt.Errorf("invalid url scheme: %s", parsed.Scheme)
}
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
if host == "" {
return "", errors.New("invalid host")
}
if !opts.AllowPrivate && isBlockedHost(host) {
return "", fmt.Errorf("host is not allowed: %s", host)
}
if port := parsed.Port(); port != "" {
num, err := strconv.Atoi(port)
if err != nil || num <= 0 || num > 65535 {
return "", fmt.Errorf("invalid port: %s", port)
}
}
allowlist := normalizeAllowlist(opts.AllowedHosts)
if opts.RequireAllowlist && len(allowlist) == 0 {
return "", errors.New("allowlist is not configured")
}
if len(allowlist) > 0 && !isAllowedHost(host, allowlist) {
return "", fmt.Errorf("host is not allowed: %s", host)
}
parsed.Path = strings.TrimRight(parsed.Path, "/")
parsed.RawPath = ""
return strings.TrimRight(parsed.String(), "/"), nil
}
func ValidateURLFormat(raw string, allowInsecureHTTP bool) (string, error) {
// 最小格式校验:仅保证 URL 可解析且 scheme 合规,不做白名单/私网/SSRF 校验
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return "", errors.New("url is required")
}
parsed, err := url.Parse(trimmed)
if err != nil || parsed.Scheme == "" || parsed.Host == "" {
return "", fmt.Errorf("invalid url: %s", trimmed)
}
scheme := strings.ToLower(parsed.Scheme)
if scheme != "https" && (!allowInsecureHTTP || scheme != "http") {
return "", fmt.Errorf("invalid url scheme: %s", parsed.Scheme)
}
host := strings.TrimSpace(parsed.Hostname())
if host == "" {
return "", errors.New("invalid host")
}
if port := parsed.Port(); port != "" {
num, err := strconv.Atoi(port)
if err != nil || num <= 0 || num > 65535 {
return "", fmt.Errorf("invalid port: %s", port)
}
}
return strings.TrimRight(trimmed, "/"), nil
}
func ValidateHTTPSURL(raw string, opts ValidationOptions) (string, error) {
return ValidateHTTPURL(raw, false, opts)
}
// ValidateResolvedIP 验证 DNS 解析后的 IP 地址是否安全
// 用于防止 DNS Rebinding 攻击:在实际 HTTP 请求时调用此函数验证解析后的 IP
func ValidateResolvedIP(host string) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host)
if err != nil {
return fmt.Errorf("dns resolution failed: %w", err)
}
for _, ip := range ips {
if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() || ip.IsUnspecified() {
return fmt.Errorf("resolved ip %s is not allowed", ip.String())
}
}
return nil
}
func normalizeAllowlist(values []string) []string {
if len(values) == 0 {
return nil
}
normalized := make([]string, 0, len(values))
for _, v := range values {
entry := strings.ToLower(strings.TrimSpace(v))
if entry == "" {
continue
}
if host, _, err := net.SplitHostPort(entry); err == nil {
entry = host
}
normalized = append(normalized, entry)
}
return normalized
}
func isAllowedHost(host string, allowlist []string) bool {
for _, entry := range allowlist {
if entry == "" {
continue
}
if strings.HasPrefix(entry, "*.") {
suffix := strings.TrimPrefix(entry, "*.")
if host == suffix || strings.HasSuffix(host, "."+suffix) {
return true
}
continue
}
if host == entry {
return true
}
}
return false
}
func isBlockedHost(host string) bool {
if host == "localhost" || strings.HasSuffix(host, ".localhost") {
return true
}
if ip := net.ParseIP(host); ip != nil {
if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsUnspecified() {
return true
}
}
return false
}
|