Spaces:
Build error
Build error
| package common | |
| import ( | |
| "fmt" | |
| "net" | |
| "net/url" | |
| "strconv" | |
| "strings" | |
| ) | |
// SSRFProtection holds the SSRF (server-side request forgery) protection
// settings used to vet outbound URLs.
//
// Both filter-mode flags use the same convention: true selects whitelist
// semantics (only listed entries pass), false selects blacklist semantics
// (listed entries are rejected, everything else passes).
type SSRFProtection struct {
	AllowPrivateIp         bool     // when false, addresses classified by isPrivateIP are rejected
	DomainFilterMode       bool     // true: whitelist, false: blacklist
	DomainList             []string // domain format, e.g. example.com, *.example.com
	IpFilterMode           bool     // true: whitelist, false: blacklist
	IpList                 []string // CIDR or single IP
	AllowedPorts           []int    // individual permitted destination ports; empty means every port is permitted
	ApplyIPFilterForDomain bool     // also resolve domain hosts and apply the IP checks to every resolved address
}
| // DefaultSSRFProtection 默认SSRF防护配置 | |
| var DefaultSSRFProtection = &SSRFProtection{ | |
| AllowPrivateIp: false, | |
| DomainFilterMode: true, | |
| DomainList: []string{}, | |
| IpFilterMode: true, | |
| IpList: []string{}, | |
| AllowedPorts: []int{}, | |
| } | |
// ssrfReservedIPNets lists the address ranges that isPrivateIP treats as
// non-public. It is parsed once at package initialisation rather than being
// rebuilt on every call. The literals are constants, so the panic below can
// only fire if this table itself is edited incorrectly.
var ssrfReservedIPNets = func() []*net.IPNet {
	cidrs := []string{
		// IPv4
		"0.0.0.0/8",      // "this network"; 0.0.0.0 is commonly routed to the local host
		"10.0.0.0/8",     // private (RFC 1918)
		"100.64.0.0/10",  // shared address space / CGNAT (RFC 6598)
		"127.0.0.0/8",    // loopback
		"169.254.0.0/16", // link-local
		"172.16.0.0/12",  // private (RFC 1918)
		"192.168.0.0/16", // private (RFC 1918)
		"224.0.0.0/4",    // multicast
		"240.0.0.0/4",    // reserved, including 255.255.255.255
		// IPv6
		"::/128",    // unspecified
		"::1/128",   // loopback
		"fc00::/7",  // unique local
		"fe80::/10", // link-local
		"ff00::/8",  // multicast (IPv4 multicast is blocked above, keep parity)
	}
	nets := make([]*net.IPNet, 0, len(cidrs))
	for _, c := range cidrs {
		_, n, err := net.ParseCIDR(c)
		if err != nil {
			panic("ssrf: invalid built-in CIDR " + c + ": " + err.Error())
		}
		nets = append(nets, n)
	}
	return nets
}()

// isPrivateIP reports whether ip falls in a private, loopback, link-local,
// unspecified, shared, multicast or reserved range (see ssrfReservedIPNets).
//
// IPv4-mapped IPv6 addresses (e.g. ::ffff:127.0.0.1) are matched against the
// IPv4 ranges because net.IPNet.Contains normalises them to 4-byte form.
// Range membership is decided by CIDR containment, not string prefixes, so
// addresses such as fc::1 (00fc::1) are not mistaken for unique-local ones.
// A nil ip matches nothing and is reported as not private.
func isPrivateIP(ip net.IP) bool {
	for _, block := range ssrfReservedIPNets {
		if block.Contains(ip) {
			return true
		}
	}
	return false
}
// parsePortRanges expands port specifications into a flat list of ports.
//
// Each entry is either a single port ("80", "443") or an inclusive range
// ("8000-9000"). Surrounding whitespace is ignored and blank entries are
// skipped. Every port must lie in 1-65535. The first malformed or
// out-of-range entry aborts parsing with an error. Ranges are expanded one
// port at a time and nothing is de-duplicated, so overlapping entries yield
// repeated ports. A nil or all-blank input yields a nil slice.
func parsePortRanges(portConfigs []string) ([]int, error) {
	var result []int
	for _, raw := range portConfigs {
		spec := strings.TrimSpace(raw)
		switch {
		case spec == "":
			continue
		case !strings.Contains(spec, "-"):
			// Single port, e.g. "80".
			single, convErr := strconv.Atoi(spec)
			if convErr != nil {
				return nil, fmt.Errorf("invalid port number: %s", spec)
			}
			if single < 1 || single > 65535 {
				return nil, fmt.Errorf("invalid port number %d (must be 1-65535)", single)
			}
			result = append(result, single)
		default:
			// Inclusive range, e.g. "8000-9000".
			bounds := strings.Split(spec, "-")
			if len(bounds) != 2 {
				return nil, fmt.Errorf("invalid port range format: %s", spec)
			}
			lo, convErr := strconv.Atoi(strings.TrimSpace(bounds[0]))
			if convErr != nil {
				return nil, fmt.Errorf("invalid start port in range %s: %v", spec, convErr)
			}
			hi, convErr := strconv.Atoi(strings.TrimSpace(bounds[1]))
			if convErr != nil {
				return nil, fmt.Errorf("invalid end port in range %s: %v", spec, convErr)
			}
			if lo > hi {
				return nil, fmt.Errorf("invalid port range %s: start port cannot be greater than end port", spec)
			}
			if lo < 1 || lo > 65535 || hi < 1 || hi > 65535 {
				return nil, fmt.Errorf("port range %s contains invalid port numbers (must be 1-65535)", spec)
			}
			for p := lo; p <= hi; p++ {
				result = append(result, p)
			}
		}
	}
	return result, nil
}
| // isAllowedPort 检查端口是否被允许 | |
| func (p *SSRFProtection) isAllowedPort(port int) bool { | |
| if len(p.AllowedPorts) == 0 { | |
| return true // 如果没有配置端口限制,则允许所有端口 | |
| } | |
| for _, allowedPort := range p.AllowedPorts { | |
| if port == allowedPort { | |
| return true | |
| } | |
| } | |
| return false | |
| } | |
// isDomainListed reports whether domain matches any entry in list.
//
// Matching is case-insensitive and ignores a single trailing dot on both the
// domain and the entries. The fully-qualified form "example.com." names the
// same DNS host as "example.com", so without this normalisation a blacklist
// could be bypassed by appending a dot. Entries are whitespace-trimmed, and
// blank entries are skipped. A wildcard entry "*.example.com" matches every
// subdomain of example.com and also example.com itself. An empty list never
// matches.
func isDomainListed(domain string, list []string) bool {
	if len(list) == 0 {
		return false
	}
	domain = strings.TrimSuffix(strings.ToLower(domain), ".")
	for _, item := range list {
		item = strings.TrimSuffix(strings.ToLower(strings.TrimSpace(item)), ".")
		if item == "" {
			continue
		}
		// Exact match.
		if domain == item {
			return true
		}
		// Wildcard match (*.example.com), including the bare parent domain.
		if suffix, ok := strings.CutPrefix(item, "*."); ok {
			if domain == suffix || strings.HasSuffix(domain, "."+suffix) {
				return true
			}
		}
	}
	return false
}
| func (p *SSRFProtection) isDomainAllowed(domain string) bool { | |
| listed := isDomainListed(domain, p.DomainList) | |
| if p.DomainFilterMode { // 白名单 | |
| return listed | |
| } | |
| // 黑名单 | |
| return !listed | |
| } | |
// isIPListed reports whether ip matches any entry in list.
//
// Each entry is either a CIDR block ("10.0.0.0/8", "2001:db8::/32") or a
// single address ("192.168.1.1"). Entries are whitespace-trimmed, as
// isDomainListed does for domain entries. A padded entry such as
// " 10.0.0.1" therefore still takes effect instead of being silently
// dropped, which in blacklist mode would let the address through. Blank and
// unparseable entries are skipped. An empty list never matches.
func isIPListed(ip net.IP, list []string) bool {
	for _, entry := range list {
		entry = strings.TrimSpace(entry)
		if entry == "" {
			continue
		}
		if _, network, err := net.ParseCIDR(entry); err == nil {
			if network.Contains(ip) {
				return true
			}
			continue
		}
		// Not CIDR notation: fall back to an exact single-address match.
		if single := net.ParseIP(entry); single != nil && single.Equal(ip) {
			return true
		}
	}
	return false
}
| // IsIPAccessAllowed 检查IP是否允许访问 | |
| func (p *SSRFProtection) IsIPAccessAllowed(ip net.IP) bool { | |
| // 私有IP限制 | |
| if isPrivateIP(ip) && !p.AllowPrivateIp { | |
| return false | |
| } | |
| listed := isIPListed(ip, p.IpList) | |
| if p.IpFilterMode { // 白名单 | |
| return listed | |
| } | |
| // 黑名单 | |
| return !listed | |
| } | |
| // ValidateURL 验证URL是否安全 | |
| func (p *SSRFProtection) ValidateURL(urlStr string) error { | |
| // 解析URL | |
| u, err := url.Parse(urlStr) | |
| if err != nil { | |
| return fmt.Errorf("invalid URL format: %v", err) | |
| } | |
| // 只允许HTTP/HTTPS协议 | |
| if u.Scheme != "http" && u.Scheme != "https" { | |
| return fmt.Errorf("unsupported protocol: %s (only http/https allowed)", u.Scheme) | |
| } | |
| // 解析主机和端口 | |
| host, portStr, err := net.SplitHostPort(u.Host) | |
| if err != nil { | |
| // 没有端口,使用默认端口 | |
| host = u.Hostname() | |
| if u.Scheme == "https" { | |
| portStr = "443" | |
| } else { | |
| portStr = "80" | |
| } | |
| } | |
| // 验证端口 | |
| port, err := strconv.Atoi(portStr) | |
| if err != nil { | |
| return fmt.Errorf("invalid port: %s", portStr) | |
| } | |
| if !p.isAllowedPort(port) { | |
| return fmt.Errorf("port %d is not allowed", port) | |
| } | |
| // 如果 host 是 IP,则跳过域名检查 | |
| if ip := net.ParseIP(host); ip != nil { | |
| if !p.IsIPAccessAllowed(ip) { | |
| if isPrivateIP(ip) { | |
| return fmt.Errorf("private IP address not allowed: %s", ip.String()) | |
| } | |
| if p.IpFilterMode { | |
| return fmt.Errorf("ip not in whitelist: %s", ip.String()) | |
| } | |
| return fmt.Errorf("ip in blacklist: %s", ip.String()) | |
| } | |
| return nil | |
| } | |
| // 先进行域名过滤 | |
| if !p.isDomainAllowed(host) { | |
| if p.DomainFilterMode { | |
| return fmt.Errorf("domain not in whitelist: %s", host) | |
| } | |
| return fmt.Errorf("domain in blacklist: %s", host) | |
| } | |
| // 若未启用对域名应用IP过滤,则到此通过 | |
| if !p.ApplyIPFilterForDomain { | |
| return nil | |
| } | |
| // 解析域名对应IP并检查 | |
| ips, err := net.LookupIP(host) | |
| if err != nil { | |
| return fmt.Errorf("DNS resolution failed for %s: %v", host, err) | |
| } | |
| for _, ip := range ips { | |
| if !p.IsIPAccessAllowed(ip) { | |
| if isPrivateIP(ip) && !p.AllowPrivateIp { | |
| return fmt.Errorf("private IP address not allowed: %s resolves to %s", host, ip.String()) | |
| } | |
| if p.IpFilterMode { | |
| return fmt.Errorf("ip not in whitelist: %s resolves to %s", host, ip.String()) | |
| } | |
| return fmt.Errorf("ip in blacklist: %s resolves to %s", host, ip.String()) | |
| } | |
| } | |
| return nil | |
| } | |
| // ValidateURLWithFetchSetting 使用FetchSetting配置验证URL | |
| func ValidateURLWithFetchSetting(urlStr string, enableSSRFProtection, allowPrivateIp bool, domainFilterMode bool, ipFilterMode bool, domainList, ipList, allowedPorts []string, applyIPFilterForDomain bool) error { | |
| // 如果SSRF防护被禁用,直接返回成功 | |
| if !enableSSRFProtection { | |
| return nil | |
| } | |
| // 解析端口范围配置 | |
| allowedPortInts, err := parsePortRanges(allowedPorts) | |
| if err != nil { | |
| return fmt.Errorf("request reject - invalid port configuration: %v", err) | |
| } | |
| protection := &SSRFProtection{ | |
| AllowPrivateIp: allowPrivateIp, | |
| DomainFilterMode: domainFilterMode, | |
| DomainList: domainList, | |
| IpFilterMode: ipFilterMode, | |
| IpList: ipList, | |
| AllowedPorts: allowedPortInts, | |
| ApplyIPFilterForDomain: applyIPFilterForDomain, | |
| } | |
| return protection.ValidateURL(urlStr) | |
| } | |