// Package utils 工具集合 package utils import ( "archive/tar" "archive/zip" "bufio" "bytes" "compress/gzip" "context" "encoding/base64" "fmt" "io" "net" "os" "os/exec" "path/filepath" "runtime" "strconv" "strings" "time" "github.com/Tencent/AI-Infra-Guard/internal/gologger" "github.com/spaolacci/murmur3" ) // Duration2String 将时间段转换为可读的字符串格式 // 如果时间超过60秒则返回分钟,否则返回秒 func Duration2String(t time.Duration) string { sceond := t.Seconds() if sceond >= 60 { return fmt.Sprintf("%.2f min", t.Minutes()) } else { return fmt.Sprintf("%.2f s", sceond) } } // InsertInto 在字符串中每隔指定间隔插入分隔符 // s: 源字符串 // interval: 插入间隔 // sep: 分隔符 func InsertInto(s string, interval int, sep rune) string { var buffer bytes.Buffer before := interval - 1 last := len(s) - 1 for i, char := range s { buffer.WriteRune(char) if i%interval == before && i != last { buffer.WriteRune(sep) } } buffer.WriteRune(sep) return buffer.String() } // FaviconHash 计算网站图标的哈希值 // 将数据进行base64编码后使用murmur3哈希算法计算 func FaviconHash(data []byte) int32 { stdBase64 := base64.StdEncoding.EncodeToString(data) stdBase64 = InsertInto(stdBase64, 76, '\n') hasher := murmur3.New32WithSeed(0) hasher.Write([]byte(stdBase64)) return int32(hasher.Sum32()) } // ScanDir 递归扫描目录,返回所有文件的完整路径 // path: 要扫描的目录路径 // 返回文件路径列表和可能的错误 func ScanDir(path string) ([]string, error) { files := make([]string, 0) dir, err := os.ReadDir(path) if err != nil { return nil, err } for _, fi := range dir { if fi.IsDir() { newDir, err := ScanDir(filepath.Join(path, fi.Name())) if err != nil { return files, err } files = append(files, newDir...) } else { files = append(files, filepath.Join(path, fi.Name())) } } return files, nil } // IsCIDR 检查给定的字符串是否为有效的CIDR格式 func IsCIDR(cidr string) bool { _, _, err := net.ParseCIDR(cidr) return err == nil } // IsFileExists 检查文件是否存在 // path: 文件路径 // 返回布尔值表示文件是否存在 func IsFileExists(path string) bool { _, err := os.Stat(path) if err != nil { if os.IsExist(err) { return true } return false } return true } // IsDir 判断给定路径是否为目录 // path: 待检查的路径 // 返回布尔值表示是否为目录 func IsDir(path string) bool { s, err := os.Stat(path) if err != nil { return false } return s.IsDir() } // TrimProtocol 移除URL中的HTTP/HTTPS协议前缀 // targetURL: 目标URL // 返回去除协议前缀后的URL func TrimProtocol(targetURL string) string { URL := strings.TrimSpace(targetURL) if strings.HasPrefix(strings.ToLower(URL), "http://") || strings.HasPrefix(strings.ToLower(URL), "https://") { URL = URL[strings.Index(URL, "//")+2:] } URL = strings.TrimRight(URL, "/") return URL } // CompareVersions 比较两个版本号字符串 // version1, version2: 待比较的版本号 // 返回值: 1 表示 version1 大于 version2 // // -1 表示 version1 小于 version2 // 0 表示两个版本号相等 func CompareVersions(version1, version2 string) int { v1Parts := strings.Split(version1, ".") v2Parts := strings.Split(version2, ".") // Determine the max length to iterate over maxLen := len(v1Parts) if len(v2Parts) > maxLen { maxLen = len(v2Parts) } for i := 0; i < maxLen; i++ { var num1, num2 int if i < len(v1Parts) { num1, _ = strconv.Atoi(v1Parts[i]) } if i < len(v2Parts) { num2, _ = strconv.Atoi(v2Parts[i]) } if num1 > num2 { return 1 } else if num1 < num2 { return -1 } } return 0 } // GetMiddleText 获取两个字符串之间的文本内容 // left: 左边界字符串 // right: 右边界字符串 // html: 源文本 // 返回左右边界之间的文本,如果未找到则返回空字符串 func GetMiddleText(left, right, html string) string { start := strings.Index(html, left) if start == -1 { return "" // 如果找不到 left,返回空字符串 } start += len(left) end := strings.Index(html[start:], right) if end == -1 { return "" // 如果找不到 right,返回空字符串 } end += start return html[start:end] } // PortInfo 存储端口和地址信息 type PortInfo struct { Port int Address string } // GetLocalOpenPorts 获取本地开放的端口及其地址信息 func GetLocalOpenPorts() ([]PortInfo, error) { var portInfos []PortInfo switch runtime.GOOS { case "windows": cmd := exec.Command("netstat", "-an") output, err := cmd.Output() if err != nil { return nil, fmt.Errorf("执行netstat命令失败: %v", err) } scanner := bufio.NewScanner(strings.NewReader(string(output))) for scanner.Scan() { line := scanner.Text() if strings.Contains(line, "LISTENING") { parts := strings.Fields(line) if len(parts) >= 2 { addrPort := strings.Split(parts[1], ":") if len(addrPort) == 2 { port, err := strconv.Atoi(addrPort[1]) if err == nil { addr := addrPort[0] portInfos = append(portInfos, PortInfo{ Port: port, Address: addr, }) } } } } } case "darwin", "linux": cmd := exec.Command("lsof", "-i", "-P", "-n") output, err := cmd.Output() if err != nil { return nil, fmt.Errorf("执行lsof命令失败: %v", err) } scanner := bufio.NewScanner(strings.NewReader(string(output))) for scanner.Scan() { line := scanner.Text() if strings.Contains(line, "LISTEN") { parts := strings.Fields(line) for _, part := range parts { if strings.Contains(part, ":") { addrPort := strings.Split(part, ":") if len(addrPort) == 2 { port, err := strconv.Atoi(addrPort[1]) if err == nil { addr := addrPort[0] if addr == "*" || addr == "0.0.0.0" { addr = "0.0.0.0" } else if addr == "127.0.0.1" || addr == "localhost" { addr = "127.0.0.1" } portInfos = append(portInfos, PortInfo{ Port: port, Address: addr, }) } } } } } } default: return nil, fmt.Errorf("不支持的操作系统: %s", runtime.GOOS) } // 去重 seen := make(map[string]bool) var result []PortInfo for _, info := range portInfos { key := fmt.Sprintf("%s:%d", info.Address, info.Port) if !seen[key] { seen[key] = true result = append(result, info) } } return result, nil } // ExtractZipFile 解压ZIP文件 func ExtractZipFile(zipFile string, destPath string) error { // 打开ZIP文件 reader, err := zip.OpenReader(zipFile) if err != nil { return fmt.Errorf("打开ZIP文件失败: %v", err) } defer reader.Close() // 确保目标目录存在 if err := os.MkdirAll(destPath, 0755); err != nil { return fmt.Errorf("创建目标目录失败: %v", err) } // 解压文件 for _, file := range reader.File { // 检查文件路径是否安全 filePath := filepath.Join(destPath, file.Name) if !strings.HasPrefix(filePath, filepath.Clean(destPath)+string(os.PathSeparator)) { gologger.Errorln(fmt.Sprintf("不安全的路径: %s", file.Name)) continue } // 创建目录 if file.FileInfo().IsDir() { if err := os.MkdirAll(filePath, 0755); err != nil { return fmt.Errorf("创建目录失败: %v", err) } continue } // 确保文件的父目录存在 if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil { return fmt.Errorf("创建父目录失败: %v", err) } // 创建文件 outFile, err := os.Create(filePath) if err != nil { return fmt.Errorf("创建文件失败: %v", err) } defer outFile.Close() // 打开文件内容 rc, err := file.Open() if err != nil { return fmt.Errorf("打开压缩文件内容失败: %v", err) } defer rc.Close() // 复制内容 if _, err := io.Copy(outFile, rc); err != nil { return fmt.Errorf("复制文件内容失败: %v", err) } } return nil } // ExtractTGZ 文件解压 func ExtractTGZ(src, dest string) error { // 打开 .tgz 文件 file, err := os.Open(src) if err != nil { return err } defer file.Close() // 创建 gzip Reader gzr, err := gzip.NewReader(file) if err != nil { return err } defer gzr.Close() // 创建 tar Reader tr := tar.NewReader(gzr) // 遍历 tar 文件中的每个条目 for { header, err := tr.Next() if err == io.EOF { break // 读取完毕 } if err != nil { return err } // 安全处理目标路径,防止路径穿越攻击 targetPath, err := safePath(dest, header.Name) if err != nil { return err } // 根据文件类型处理 switch header.Typeflag { case tar.TypeDir: // 目录 if err := os.MkdirAll(targetPath, 0755); err != nil { return err } case tar.TypeReg: // 普通文件 if err := writeFile(targetPath, tr, header.Mode); err != nil { return err } // 可选:处理符号链接等其他类型 default: fmt.Printf("未处理类型: %v in %s\n", header.Typeflag, header.Name) } } return nil } // 安全路径检查,防止路径穿越 func safePath(dest, name string) (string, error) { targetPath := filepath.Join(dest, name) cleanedPath := filepath.Clean(targetPath) dest = filepath.Clean(dest) // 检查目标路径是否在目标目录下 if !strings.HasPrefix(cleanedPath, dest+string(os.PathSeparator)) && cleanedPath != dest { return "", fmt.Errorf("非法路径: %s", name) } return targetPath, nil } // 写入文件内容 func writeFile(path string, r io.Reader, mode int64) error { // 确保目录存在 if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { return err } // 创建文件并设置权限 file, err := os.OpenFile(path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.FileMode(mode)) if err != nil { return err } defer file.Close() // 复制内容 if _, err := io.Copy(file, r); err != nil { return err } return nil } // GitClone 克隆Git仓库 func GitClone(repoURL, targetDir string, timeout time.Duration) error { var err error for i := 0; i < 3; i++ { err = func() error { ctx := context.Background() ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() cmd := exec.CommandContext(ctx, "git", "clone", "--", repoURL, targetDir) done := make(chan error) go func() { _, err := cmd.CombinedOutput() done <- err }() select { case <-ctx.Done(): _ = cmd.Process.Kill() return fmt.Errorf("操作超时") case err = <-done: return err } }() if err == nil { return nil } } return err } func RunCmd(dir, name string, arg []string, callback func(line string)) error { // 命令行执行,stdio读取 cmd := exec.Command(name, arg...) cmd.Dir = dir cmd.Env = os.Environ() // 获取命令行 cmdStr := name + " " + strings.Join(arg, " ") gologger.Infof("开始执行命令: %s", cmdStr) // 使用管道获取标准输出 stdout, err := cmd.StdoutPipe() if err != nil { return err } cmd.Stderr = cmd.Stdout // 将错误输出合并到标准输出 // 启动扫描器goroutine scanner := bufio.NewScanner(stdout) // 设置更大的缓冲区以处理超长文本行 // 默认64KB,这里设置为1MB const maxCapacity = 1024 * 1024 * 10 // 1MB buf := make([]byte, 0, 64*1024) scanner.Buffer(buf, maxCapacity) done := make(chan error) // 改为传递错误信息 go func() { defer close(done) for scanner.Scan() { line := scanner.Text() callback(line) } // 检查扫描器是否遇到错误 if err := scanner.Err(); err != nil { // 管道关闭是正常的结束条件,不应视为错误 if strings.Contains(err.Error(), "file already closed") || strings.Contains(err.Error(), "broken pipe") { done <- nil return } done <- fmt.Errorf("读取输出时发生错误: %v", err) return } done <- nil }() // 启动命令 if err = cmd.Start(); err != nil { return err } // 等待命令执行完成 cmdErr := cmd.Wait() // 等待读取完成并检查读取错误 readErr := <-done // 优先返回读取错误,其次返回命令执行错误 if readErr != nil { return readErr } if cmdErr != nil { return cmdErr } return nil } func IsHostname(hostname string) bool { ips := strings.Split(hostname, ":") if len(ips) != 2 { return false } p := net.ParseIP(strings.TrimSpace(ips[0])) if p == nil { return false } return true } // StrInSlice checks if a string is in a slice of strings. func StrInSlice(str string, list []string) bool { for _, v := range list { if v == str { return true } } return false }