| package app |
|
|
| import ( |
| "bytes" |
| "encoding/csv" |
| "fmt" |
| "io" |
| "log" |
| "net/http" |
| "strconv" |
| "strings" |
| "time" |
|
|
| "ccLoad/internal/model" |
| "ccLoad/internal/util" |
|
|
| "github.com/bytedance/sonic" |
| "github.com/gin-gonic/gin" |
| ) |
|
|
| |
| |
|
|
| |
| |
| func (s *Server) HandleExportChannelsCSV(c *gin.Context) { |
| cfgs, err := s.store.ListConfigs(c.Request.Context()) |
| if err != nil { |
| RespondError(c, http.StatusInternalServerError, err) |
| return |
| } |
|
|
| |
| allAPIKeys, err := s.store.GetAllAPIKeys(c.Request.Context()) |
| if err != nil { |
| log.Printf("[WARN] 批量查询API Keys失败: %v", err) |
| allAPIKeys = make(map[int64][]*model.APIKey) |
| } |
|
|
| buf := &bytes.Buffer{} |
| |
| buf.WriteString("\ufeff") |
|
|
| writer := csv.NewWriter(buf) |
| defer writer.Flush() |
|
|
| header := []string{"id", "name", "api_key", "url", "priority", "rpm_limit", "models", "model_redirects", "channel_type", "protocol_transforms", "protocol_transform_mode", "key_strategy", "enabled", "scheduled_check_enabled", "scheduled_check_model"} |
| if err := writer.Write(header); err != nil { |
| RespondError(c, http.StatusInternalServerError, err) |
| return |
| } |
|
|
| for _, cfg := range cfgs { |
| |
| apiKeys := allAPIKeys[cfg.ID] |
|
|
| |
| apiKeyStrs := make([]string, 0, len(apiKeys)) |
| for _, key := range apiKeys { |
| apiKeyStrs = append(apiKeyStrs, key.APIKey) |
| } |
| apiKeyStr := strings.Join(apiKeyStrs, ",") |
|
|
| |
| keyStrategy := model.KeyStrategySequential |
| if len(apiKeys) > 0 && apiKeys[0].KeyStrategy != "" { |
| keyStrategy = apiKeys[0].KeyStrategy |
| } |
|
|
| |
| |
| models := make([]string, 0, len(cfg.ModelEntries)) |
| redirects := make(map[string]string) |
| for _, entry := range cfg.ModelEntries { |
| models = append(models, entry.Model) |
| if entry.RedirectModel != "" { |
| redirects[entry.Model] = entry.RedirectModel |
| } |
| } |
|
|
| modelRedirectsJSON := "{}" |
| if len(redirects) > 0 { |
| if jsonBytes, err := sonic.Marshal(redirects); err == nil { |
| modelRedirectsJSON = string(jsonBytes) |
| } |
| } |
|
|
| record := []string{ |
| strconv.FormatInt(cfg.ID, 10), |
| cfg.Name, |
| apiKeyStr, |
| cfg.URL, |
| strconv.Itoa(cfg.Priority), |
| strconv.Itoa(cfg.RPMLimit), |
| strings.Join(models, ","), |
| modelRedirectsJSON, |
| cfg.GetChannelType(), |
| strings.Join(cfg.GetProtocolTransforms(), ","), |
| cfg.GetProtocolTransformMode(), |
| keyStrategy, |
| strconv.FormatBool(cfg.Enabled), |
| strconv.FormatBool(cfg.ScheduledCheckEnabled), |
| cfg.ScheduledCheckModel, |
| } |
| if err := writer.Write(record); err != nil { |
| RespondError(c, http.StatusInternalServerError, err) |
| return |
| } |
| } |
|
|
| writer.Flush() |
| if err := writer.Error(); err != nil { |
| RespondError(c, http.StatusInternalServerError, err) |
| return |
| } |
|
|
| filename := fmt.Sprintf("channels-%s.csv", time.Now().Format("20060102-150405")) |
| c.Header("Content-Type", "text/csv; charset=utf-8") |
| c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", filename)) |
| c.Header("Cache-Control", "no-cache") |
| c.String(http.StatusOK, buf.String()) |
| } |
|
|
| |
| |
| func (s *Server) HandleImportChannelsCSV(c *gin.Context) { |
| fileHeader, err := c.FormFile("file") |
| if err != nil { |
| RespondErrorMsg(c, http.StatusBadRequest, "缺少上传文件") |
| return |
| } |
|
|
| src, err := fileHeader.Open() |
| if err != nil { |
| RespondError(c, http.StatusInternalServerError, err) |
| return |
| } |
| defer func() { _ = src.Close() }() |
|
|
| reader := csv.NewReader(src) |
| reader.TrimLeadingSpace = true |
|
|
| headerRow, err := reader.Read() |
| if err == io.EOF { |
| RespondErrorMsg(c, http.StatusBadRequest, "CSV内容为空") |
| return |
| } |
| if err != nil { |
| RespondError(c, http.StatusBadRequest, err) |
| return |
| } |
|
|
| columnIndex := buildCSVColumnIndex(headerRow) |
| required := []string{"name", "api_key", "url", "models"} |
| for _, key := range required { |
| if _, ok := columnIndex[key]; !ok { |
| RespondErrorMsg(c, http.StatusBadRequest, fmt.Sprintf("缺少必需列: %s", key)) |
| return |
| } |
| } |
|
|
| _, hasScheduledCheckColumn := columnIndex["scheduled_check_enabled"] |
| _, hasScheduledCheckModelColumn := columnIndex["scheduled_check_model"] |
| existingScheduledCheckByName := make(map[string]bool) |
| existingScheduledCheckModelByName := make(map[string]string) |
| if !hasScheduledCheckColumn || !hasScheduledCheckModelColumn { |
| existingConfigs, err := s.store.ListConfigs(c.Request.Context()) |
| if err != nil { |
| RespondError(c, http.StatusInternalServerError, err) |
| return |
| } |
| for _, cfg := range existingConfigs { |
| existingScheduledCheckByName[cfg.Name] = cfg.ScheduledCheckEnabled |
| existingScheduledCheckModelByName[cfg.Name] = cfg.ScheduledCheckModel |
| } |
| } |
|
|
| summary := ChannelImportSummary{} |
| lineNo := 1 |
|
|
| |
| validChannels := make([]*model.ChannelWithKeys, 0, 100) |
|
|
| for { |
| record, err := reader.Read() |
| if err == io.EOF { |
| break |
| } |
| lineNo++ |
|
|
| if err != nil { |
| summary.Errors = append(summary.Errors, fmt.Sprintf("第%d行读取失败: %v", lineNo, err)) |
| summary.Skipped++ |
| continue |
| } |
|
|
| channel, errMsg, skip := s.parseChannelImportRow( |
| record, |
| columnIndex, |
| lineNo, |
| hasScheduledCheckColumn, |
| hasScheduledCheckModelColumn, |
| existingScheduledCheckByName, |
| existingScheduledCheckModelByName, |
| ) |
| if skip { |
| if errMsg != "" { |
| summary.Errors = append(summary.Errors, errMsg) |
| } |
| summary.Skipped++ |
| continue |
| } |
|
|
| |
| validChannels = append(validChannels, channel) |
| } |
|
|
| |
| if len(validChannels) > 0 { |
| created, updated, err := s.store.ImportChannelBatch(c.Request.Context(), validChannels) |
| if err != nil { |
| summary.Errors = append(summary.Errors, fmt.Sprintf("批量导入失败: %v", err)) |
| RespondErrorWithData(c, http.StatusInternalServerError, err.Error(), summary) |
| return |
| } |
| summary.Created = created |
| summary.Updated = updated |
|
|
| |
| if s.urlSelector != nil { |
| seenIDs := make(map[int64]struct{}, len(validChannels)) |
| for _, channel := range validChannels { |
| if channel == nil || channel.Config == nil || channel.Config.ID <= 0 { |
| continue |
| } |
| seenIDs[channel.Config.ID] = struct{}{} |
| } |
| for channelID := range seenIDs { |
| cfg, getErr := s.store.GetConfig(c.Request.Context(), channelID) |
| if getErr != nil || cfg == nil { |
| continue |
| } |
| s.urlSelector.PruneChannel(channelID, cfg.GetURLs()) |
| |
| s.cleanupOrphanedURLStates(c.Request.Context(), channelID, cfg.GetURLs()) |
| } |
| } |
| } |
|
|
| summary.Processed = summary.Created + summary.Updated + summary.Skipped |
|
|
| if len(validChannels) > 0 { |
| s.InvalidateChannelListCache() |
| s.InvalidateAllAPIKeysCache() |
| s.invalidateCooldownCache() |
| } |
|
|
| RespondJSON(c, http.StatusOK, summary) |
| } |
|
|
| |
| |
| |
| |
| |
| func (s *Server) parseChannelImportRow( |
| record []string, |
| columnIndex map[string]int, |
| lineNo int, |
| hasScheduledCheckColumn bool, |
| hasScheduledCheckModelColumn bool, |
| existingScheduledCheckByName map[string]bool, |
| existingScheduledCheckModelByName map[string]string, |
| ) (channel *model.ChannelWithKeys, errMsg string, skip bool) { |
| if isCSVRecordEmpty(record) { |
| return nil, "", true |
| } |
|
|
| fetch := func(key string) string { |
| idx, ok := columnIndex[key] |
| if !ok || idx >= len(record) { |
| return "" |
| } |
| return strings.TrimSpace(record[idx]) |
| } |
|
|
| name := fetch("name") |
| rawID := fetch("id") |
| apiKey := fetch("api_key") |
| url := fetch("url") |
| modelsRaw := fetch("models") |
| modelRedirectsRaw := fetch("model_redirects") |
| channelType := fetch("channel_type") |
| protocolTransformsRaw := fetch("protocol_transforms") |
| protocolTransformMode := model.NormalizeProtocolTransformMode(fetch("protocol_transform_mode")) |
| keyStrategy := fetch("key_strategy") |
|
|
| var missing []string |
| if name == "" { |
| missing = append(missing, "name") |
| } |
| if apiKey == "" { |
| missing = append(missing, "api_key") |
| } |
| if url == "" { |
| missing = append(missing, "url") |
| } |
| if modelsRaw == "" { |
| missing = append(missing, "models") |
| } |
| if len(missing) > 0 { |
| return nil, fmt.Sprintf("第%d行缺少必填字段: %s", lineNo, strings.Join(missing, ", ")), true |
| } |
|
|
| channelID, err := parseImportChannelID(rawID) |
| if err != nil { |
| return nil, fmt.Sprintf("第%d行渠道ID格式错误: %v", lineNo, err), true |
| } |
|
|
| normalizedURL, err := validateChannelURLs(url) |
| if err != nil { |
| return nil, fmt.Sprintf("第%d行URL无效: %v", lineNo, err), true |
| } |
| url = normalizedURL |
|
|
| |
| channelType = util.NormalizeChannelType(channelType) |
| if !util.IsValidChannelType(channelType) { |
| return nil, fmt.Sprintf("第%d行渠道类型无效: %s(仅支持anthropic/codex/gemini)", lineNo, channelType), true |
| } |
|
|
| |
| if keyStrategy == "" { |
| keyStrategy = model.KeyStrategySequential |
| } else if !model.IsValidKeyStrategy(keyStrategy) { |
| return nil, fmt.Sprintf("第%d行Key使用策略无效: %s(仅支持sequential/round_robin)", lineNo, keyStrategy), true |
| } |
| if protocolTransformMode == "" { |
| return nil, fmt.Sprintf("第%d行 protocol_transform_mode 无效: %s", lineNo, fetch("protocol_transform_mode")), true |
| } |
| rawProtocolTransforms := parseProtocolTransformsCSV(protocolTransformsRaw) |
| if err := validateProtocolTransforms(channelType, protocolTransformMode, rawProtocolTransforms); err != nil { |
| return nil, fmt.Sprintf("第%d行 protocol_transforms 无效: %v", lineNo, err), true |
| } |
| protocolTransforms := normalizeProtocolTransforms(channelType, protocolTransformMode, rawProtocolTransforms) |
|
|
| models := parseImportModels(modelsRaw) |
| if len(models) == 0 { |
| return nil, fmt.Sprintf("第%d行模型格式无效", lineNo), true |
| } |
|
|
| |
| var modelRedirects map[string]string |
| if modelRedirectsRaw != "" && modelRedirectsRaw != "{}" { |
| if err := sonic.Unmarshal([]byte(modelRedirectsRaw), &modelRedirects); err != nil { |
| return nil, fmt.Sprintf("第%d行模型重定向格式错误: %v", lineNo, err), true |
| } |
| } |
|
|
| priority := 0 |
| if pRaw := fetch("priority"); pRaw != "" { |
| p, err := strconv.Atoi(pRaw) |
| if err != nil { |
| return nil, fmt.Sprintf("第%d行优先级格式错误: %v", lineNo, err), true |
| } |
| priority = p |
| } |
|
|
| rpmLimit := 0 |
| if rpmRaw := fetch("rpm_limit"); rpmRaw != "" { |
| parsed, err := strconv.Atoi(rpmRaw) |
| if err != nil || parsed < 0 { |
| return nil, fmt.Sprintf("第%d行RPM限制格式错误: %s", lineNo, rpmRaw), true |
| } |
| rpmLimit = parsed |
| } |
|
|
| enabled := true |
| if eRaw := fetch("enabled"); eRaw != "" { |
| if val, ok := parseImportEnabled(eRaw); ok { |
| enabled = val |
| } else { |
| return nil, fmt.Sprintf("第%d行启用状态格式错误: %s", lineNo, eRaw), true |
| } |
| } |
|
|
| scheduledCheckEnabled := existingScheduledCheckByName[name] |
| if raw := fetch("scheduled_check_enabled"); raw != "" { |
| if val, ok := parseImportEnabled(raw); ok { |
| scheduledCheckEnabled = val |
| } else { |
| return nil, fmt.Sprintf("第%d行定时检测开关格式错误: %s", lineNo, raw), true |
| } |
| } else if hasScheduledCheckColumn { |
| scheduledCheckEnabled = false |
| } |
|
|
| rawScheduledCheckModel := fetch("scheduled_check_model") |
| scheduledCheckModel := existingScheduledCheckModelByName[name] |
| shouldValidateScheduledCheckModel := false |
| if rawScheduledCheckModel != "" { |
| scheduledCheckModel = rawScheduledCheckModel |
| shouldValidateScheduledCheckModel = true |
| } else if hasScheduledCheckModelColumn { |
| scheduledCheckModel = "" |
| } |
|
|
| |
| modelEntries := make([]model.ModelEntry, 0, len(models)) |
| for _, m := range models { |
| entry := model.ModelEntry{Model: m} |
| if redirect, ok := modelRedirects[m]; ok { |
| entry.RedirectModel = redirect |
| } |
| modelEntries = append(modelEntries, entry) |
| } |
| if scheduledCheckModel != "" { |
| declared := false |
| for _, entry := range modelEntries { |
| if entry.Model == scheduledCheckModel { |
| declared = true |
| break |
| } |
| } |
| if !declared { |
| if shouldValidateScheduledCheckModel { |
| return nil, fmt.Sprintf("第%d行 scheduled_check_model 无效: %s", lineNo, scheduledCheckModel), true |
| } |
| scheduledCheckModel = "" |
| } |
| } |
|
|
| |
| cfg := &model.Config{ |
| ID: channelID, |
| Name: name, |
| URL: url, |
| Priority: priority, |
| RPMLimit: rpmLimit, |
| ModelEntries: modelEntries, |
| ChannelType: channelType, |
| ProtocolTransformMode: protocolTransformMode, |
| ProtocolTransforms: protocolTransforms, |
| Enabled: enabled, |
| ScheduledCheckEnabled: scheduledCheckEnabled, |
| ScheduledCheckModel: scheduledCheckModel, |
| } |
|
|
| |
| apiKeyList := util.ParseAPIKeys(apiKey) |
| apiKeys := make([]model.APIKey, len(apiKeyList)) |
| for i, key := range apiKeyList { |
| apiKeys[i] = model.APIKey{ |
| KeyIndex: i, |
| APIKey: key, |
| KeyStrategy: keyStrategy, |
| } |
| } |
|
|
| return &model.ChannelWithKeys{ |
| Config: cfg, |
| APIKeys: apiKeys, |
| }, "", false |
| } |
|
|
// parseProtocolTransformsCSV splits a comma-separated protocol_transforms cell
// into trimmed, non-empty names in their original order.
//
// A blank cell yields nil; a cell made only of separators/whitespace yields an
// empty, non-nil slice.
func parseProtocolTransformsCSV(raw string) []string {
	cell := strings.TrimSpace(raw)
	if len(cell) == 0 {
		return nil
	}

	pieces := strings.Split(cell, ",")
	names := make([]string, 0, strings.Count(cell, ",")+1)
	for i := range pieces {
		if name := strings.TrimSpace(pieces[i]); name != "" {
			names = append(names, name)
		}
	}
	return names
}
|
|
| |
|
|
| |
| func buildCSVColumnIndex(header []string) map[string]int { |
| index := make(map[string]int, len(header)) |
| for i, col := range header { |
| norm := normalizeCSVHeader(col) |
| if norm == "" { |
| continue |
| } |
| index[norm] = i |
| } |
| return index |
| } |
|
|
| |
// normalizeCSVHeader maps a raw CSV header cell to the canonical column key used
// by the importer (e.g. "API Key" -> "api_key", "status" -> "enabled").
//
// A leading UTF-8 BOM and surrounding whitespace are removed and the name is
// lower-cased. '-' and ' ' separators are then folded to '_', so variants such
// as "Channel Type" or "protocol-transforms" resolve to the same key as their
// snake_case spelling. Unknown names are returned in that folded form, and a
// blank cell yields "".
func normalizeCSVHeader(name string) string {
	// Strip the BOM before the final trim: a cell like "\ufeff id" would
	// otherwise keep its inner space and never match "id".
	cleaned := strings.TrimPrefix(strings.TrimSpace(name), "\ufeff")
	cleaned = strings.ToLower(strings.TrimSpace(cleaned))
	key := strings.Map(func(r rune) rune {
		if r == '-' || r == ' ' {
			return '_'
		}
		return r
	}, cleaned)

	// Aliases that separator folding alone cannot resolve.
	switch key {
	case "apikey":
		return "api_key"
	case "model", "model_list", "model(s)":
		return "models"
	case "model_redirect", "modelredirects", "redirects":
		return "model_redirects"
	case "keystrategy", "策略", "使用策略":
		return "key_strategy"
	case "rpmlimit":
		return "rpm_limit"
	case "scheduledcheckenabled":
		return "scheduled_check_enabled"
	case "scheduledcheckmodel":
		return "scheduled_check_model"
	case "status":
		return "enabled"
	default:
		return key
	}
}
|
|
| |
// isCSVRecordEmpty reports whether every cell of record is blank after
// trimming whitespace. A nil or zero-length record counts as empty.
func isCSVRecordEmpty(record []string) bool {
	for i := range record {
		if trimmed := strings.TrimSpace(record[i]); len(trimmed) != 0 {
			return false
		}
	}
	return true
}
|
|
| |
// parseImportModels splits a models cell on ',', ';', '|', newline, carriage
// return or tab, trims each name, and removes duplicates while keeping the
// first occurrence's order.
//
// Returns nil for an empty cell or one containing only separators. Other
// whitespace-only input yields an empty, non-nil slice.
func parseImportModels(raw string) []string {
	if len(raw) == 0 {
		return nil
	}

	isSeparator := func(r rune) bool {
		return r == ',' || r == ';' || r == '|' || r == '\n' || r == '\r' || r == '\t'
	}
	fields := strings.FieldsFunc(raw, isSeparator)
	if len(fields) == 0 {
		return nil
	}

	result := make([]string, 0, len(fields))
	seen := make(map[string]bool, len(fields))
	for _, field := range fields {
		name := strings.TrimSpace(field)
		if name == "" || seen[name] {
			continue
		}
		seen[name] = true
		result = append(result, name)
	}
	return result
}
|
|
| |
| func parseImportEnabled(raw string) (bool, bool) { |
| return util.ParseBool(raw) |
| } |
|
|
// parseImportChannelID parses the optional id cell of an import row.
//
// A blank cell returns 0 with no error (presumably meaning "no explicit ID";
// how 0 is treated is decided by the store — confirm). Otherwise the trimmed
// value must be a base-10 int64 greater than zero.
func parseImportChannelID(raw string) (int64, error) {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return 0, nil
	}

	id, err := strconv.ParseInt(trimmed, 10, 64)
	switch {
	case err != nil:
		return 0, err
	case id <= 0:
		return 0, fmt.Errorf("must be a positive integer")
	default:
		return id, nil
	}
}
|
|