ccpoad / internal /app /admin_csv.go
anyalerob's picture
Upload folder using huggingface_hub
2986042 verified
Raw
History Blame Contribute Delete
17 kB
package app
import (
"bytes"
"encoding/csv"
"fmt"
"io"
"log"
"net/http"
"strconv"
"strings"
"time"
"ccLoad/internal/model"
"ccLoad/internal/util"
"github.com/bytedance/sonic"
"github.com/gin-gonic/gin"
)
// ==================== CSV导入导出 ====================
// 从admin.go拆分CSV功能,遵循SRP原则
// HandleExportChannelsCSV streams every channel configuration as a CSV
// attachment.
//
// Route: GET /admin/channels/export
//
// The body starts with a UTF-8 BOM (for Excel and similar tools) followed by a
// header row and one row per channel. Within a row, API keys and models are
// each joined with ",", model redirects are encoded as a JSON object ("{}"
// when there are none), and key_strategy is taken from the channel's first API
// key, falling back to model.KeyStrategySequential.
//
// Failure handling:
//   - ListConfigs or CSV write errors are reported with HTTP 500.
//   - A failing bulk API-key lookup is only logged; the export continues with
//     empty api_key cells (deliberate best-effort).
//   - A failing redirect serialization is logged and exported as "{}".
func (s *Server) HandleExportChannelsCSV(c *gin.Context) {
	ctx := c.Request.Context()

	cfgs, err := s.store.ListConfigs(ctx)
	if err != nil {
		RespondError(c, http.StatusInternalServerError, err)
		return
	}

	// Fetch all API keys in one call instead of one lookup per channel (N+1).
	allAPIKeys, err := s.store.GetAllAPIKeys(ctx)
	if err != nil {
		log.Printf("[WARN] 批量查询API Keys失败: %v", err)
		allAPIKeys = make(map[int64][]*model.APIKey) // degrade: export without keys
	}

	buf := &bytes.Buffer{}
	// UTF-8 BOM so spreadsheet tools detect the encoding.
	buf.WriteString("\ufeff")

	// The writer is flushed explicitly once all rows are written; on the early
	// error returns below the buffer is discarded, so no deferred flush is needed.
	writer := csv.NewWriter(buf)

	header := []string{"id", "name", "api_key", "url", "priority", "rpm_limit", "models", "model_redirects", "channel_type", "protocol_transforms", "protocol_transform_mode", "key_strategy", "enabled", "scheduled_check_enabled", "scheduled_check_model"}
	if err := writer.Write(header); err != nil {
		RespondError(c, http.StatusInternalServerError, err)
		return
	}

	for _, cfg := range cfgs {
		// O(1) lookup in the preloaded map.
		apiKeys := allAPIKeys[cfg.ID]

		// Flatten the keys into a single comma-separated cell.
		apiKeyStrs := make([]string, 0, len(apiKeys))
		for _, key := range apiKeys {
			apiKeyStrs = append(apiKeyStrs, key.APIKey)
		}
		apiKeyStr := strings.Join(apiKeyStrs, ",")

		// The key strategy is read from the first key only.
		keyStrategy := model.KeyStrategySequential
		if len(apiKeys) > 0 && apiKeys[0].KeyStrategy != "" {
			keyStrategy = apiKeys[0].KeyStrategy
		}

		// Models are comma-separated (readable, spreadsheet friendly);
		// redirects are structured data and therefore exported as JSON.
		models := make([]string, 0, len(cfg.ModelEntries))
		redirects := make(map[string]string)
		for _, entry := range cfg.ModelEntries {
			models = append(models, entry.Model)
			if entry.RedirectModel != "" {
				redirects[entry.Model] = entry.RedirectModel
			}
		}

		modelRedirectsJSON := "{}"
		if len(redirects) > 0 {
			jsonBytes, marshalErr := sonic.Marshal(redirects)
			if marshalErr != nil {
				// Keep exporting, but make the dropped redirects visible
				// instead of silently writing "{}".
				log.Printf("[WARN] 渠道 %d 模型重定向序列化失败: %v", cfg.ID, marshalErr)
			} else {
				modelRedirectsJSON = string(jsonBytes)
			}
		}

		record := []string{
			strconv.FormatInt(cfg.ID, 10),
			cfg.Name,
			apiKeyStr,
			cfg.URL,
			strconv.Itoa(cfg.Priority),
			strconv.Itoa(cfg.RPMLimit),
			strings.Join(models, ","),
			modelRedirectsJSON,
			cfg.GetChannelType(), // getter is used so the exported value is never the raw empty field
			strings.Join(cfg.GetProtocolTransforms(), ","),
			cfg.GetProtocolTransformMode(),
			keyStrategy,
			strconv.FormatBool(cfg.Enabled),
			strconv.FormatBool(cfg.ScheduledCheckEnabled),
			cfg.ScheduledCheckModel,
		}
		if err := writer.Write(record); err != nil {
			RespondError(c, http.StatusInternalServerError, err)
			return
		}
	}

	writer.Flush()
	if err := writer.Error(); err != nil {
		RespondError(c, http.StatusInternalServerError, err)
		return
	}

	filename := fmt.Sprintf("channels-%s.csv", time.Now().Format("20060102-150405"))
	c.Header("Content-Type", "text/csv; charset=utf-8")
	c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", filename))
	c.Header("Cache-Control", "no-cache")
	// Send the raw bytes: the CSV must never be interpreted as a format string,
	// and this avoids copying the whole buffer into a string first.
	c.Data(http.StatusOK, "text/csv; charset=utf-8", buf.Bytes())
}
// HandleImportChannelsCSV imports channels from an uploaded CSV file.
//
// Route: POST /admin/channels/import
//
// The file is read from the multipart form field "file". The header row must
// contain the columns name, api_key, url and models (aliases are resolved by
// buildCSVColumnIndex). Every following row is parsed by
// parseChannelImportRow; blank and invalid rows are counted as skipped and
// invalid rows add a message to the summary. All valid rows are then imported
// in a single ImportChannelBatch call.
//
// When the scheduled_check_enabled or scheduled_check_model column is absent,
// the existing configs are loaded so those settings can be carried over by
// channel name.
//
// Responses:
//   - 400 when the file is missing, empty, has an unreadable header, or lacks
//     a required column.
//   - 500 when opening the upload, listing existing configs, or the batch
//     import fails (the batch failure also returns the summary so far).
//   - 200 with a ChannelImportSummary otherwise.
func (s *Server) HandleImportChannelsCSV(c *gin.Context) {
	ctx := c.Request.Context()

	fileHeader, err := c.FormFile("file")
	if err != nil {
		RespondErrorMsg(c, http.StatusBadRequest, "缺少上传文件")
		return
	}

	src, err := fileHeader.Open()
	if err != nil {
		RespondError(c, http.StatusInternalServerError, err)
		return
	}
	// The upload is only read, so a Close error carries no useful information.
	defer func() { _ = src.Close() }()

	reader := csv.NewReader(src)
	reader.TrimLeadingSpace = true

	headerRow, err := reader.Read()
	if err == io.EOF {
		RespondErrorMsg(c, http.StatusBadRequest, "CSV内容为空")
		return
	}
	if err != nil {
		RespondError(c, http.StatusBadRequest, err)
		return
	}

	columnIndex := buildCSVColumnIndex(headerRow)
	required := []string{"name", "api_key", "url", "models"}
	for _, key := range required {
		if _, ok := columnIndex[key]; !ok {
			RespondErrorMsg(c, http.StatusBadRequest, fmt.Sprintf("缺少必需列: %s", key))
			return
		}
	}

	// Older exports may lack the scheduled-check columns. In that case keep
	// whatever the existing channel with the same name currently has.
	_, hasScheduledCheckColumn := columnIndex["scheduled_check_enabled"]
	_, hasScheduledCheckModelColumn := columnIndex["scheduled_check_model"]
	existingScheduledCheckByName := make(map[string]bool)
	existingScheduledCheckModelByName := make(map[string]string)
	if !hasScheduledCheckColumn || !hasScheduledCheckModelColumn {
		existingConfigs, err := s.store.ListConfigs(ctx)
		if err != nil {
			RespondError(c, http.StatusInternalServerError, err)
			return
		}
		for _, cfg := range existingConfigs {
			existingScheduledCheckByName[cfg.Name] = cfg.ScheduledCheckEnabled
			existingScheduledCheckModelByName[cfg.Name] = cfg.ScheduledCheckModel
		}
	}

	summary := ChannelImportSummary{}
	lineNo := 1 // the header row was record 1

	// Collect the valid rows first and import them in one batch to reduce
	// database round trips. The capacity is only an initial guess.
	validChannels := make([]*model.ChannelWithKeys, 0, 100)

	for {
		record, err := reader.Read()
		if err == io.EOF {
			break
		}

		lineNo++

		if err != nil {
			summary.Errors = append(summary.Errors, fmt.Sprintf("第%d行读取失败: %v", lineNo, err))
			summary.Skipped++
			// encoding/csv reports a malformed row as *csv.ParseError and
			// resumes with the next record on the following Read. Any other
			// error comes from the underlying reader and would be returned
			// again on every call without ever reaching io.EOF, so stop
			// reading instead of looping forever. Rows parsed so far are
			// still imported.
			if _, recoverable := err.(*csv.ParseError); !recoverable {
				break
			}
			continue
		}

		channel, errMsg, skip := s.parseChannelImportRow(
			record,
			columnIndex,
			lineNo,
			hasScheduledCheckColumn,
			hasScheduledCheckModelColumn,
			existingScheduledCheckByName,
			existingScheduledCheckModelByName,
		)
		if skip {
			if errMsg != "" {
				summary.Errors = append(summary.Errors, errMsg)
			}
			summary.Skipped++
			continue
		}

		validChannels = append(validChannels, channel)
	}

	// Import every valid row in one batch call.
	if len(validChannels) > 0 {
		created, updated, err := s.store.ImportChannelBatch(ctx, validChannels)
		if err != nil {
			summary.Errors = append(summary.Errors, fmt.Sprintf("批量导入失败: %v", err))
			RespondErrorWithData(c, http.StatusInternalServerError, err.Error(), summary)
			return
		}
		summary.Created = created
		summary.Updated = updated

		// An import can change channel URLs. Drop URL selector state for URLs
		// that no longer exist so stale entries do not linger.
		if s.urlSelector != nil {
			seenIDs := make(map[int64]struct{}, len(validChannels))
			for _, channel := range validChannels {
				if channel == nil || channel.Config == nil || channel.Config.ID <= 0 {
					continue
				}
				seenIDs[channel.Config.ID] = struct{}{}
			}
			for channelID := range seenIDs {
				// Best-effort cleanup: a channel that cannot be reloaded is skipped.
				cfg, getErr := s.store.GetConfig(ctx, channelID)
				if getErr != nil || cfg == nil {
					continue
				}
				s.urlSelector.PruneChannel(channelID, cfg.GetURLs())
				// Also remove persisted disabled-state records of removed URLs.
				s.cleanupOrphanedURLStates(ctx, channelID, cfg.GetURLs())
			}
		}
	}

	summary.Processed = summary.Created + summary.Updated + summary.Skipped

	if len(validChannels) > 0 {
		s.InvalidateChannelListCache()
		s.InvalidateAllAPIKeysCache()
		s.invalidateCooldownCache()
	}
	RespondJSON(c, http.StatusOK, summary)
}
// parseChannelImportRow converts one CSV record into a channel configuration
// together with its API keys.
//
// Parameters:
//   - record: the raw cells of the row.
//   - columnIndex: normalized column name -> cell position. A column that is
//     unknown, or whose position lies beyond the end of a short record, reads
//     as the empty string.
//   - lineNo: row number, used only inside error messages.
//   - hasScheduledCheckColumn / hasScheduledCheckModelColumn: whether the
//     header declares scheduled_check_enabled / scheduled_check_model.
//   - existingScheduledCheckByName / existingScheduledCheckModelByName:
//     fallback values looked up by channel name when the matching cell is
//     empty and the column is absent from the header.
//
// The result is one of three states:
//   - skip=true, errMsg=="": blank row; the caller only increments Skipped.
//   - skip=true, errMsg!="": invalid row; the caller records errMsg and
//     increments Skipped.
//   - skip=false, channel!=nil: the row parsed successfully and should be
//     appended to the batch.
//
// Checks run in a fixed order and the function returns at the first failure,
// so at most one problem is reported per row. The receiver s is not used by
// the body.
func (s *Server) parseChannelImportRow(
record []string,
columnIndex map[string]int,
lineNo int,
hasScheduledCheckColumn bool,
hasScheduledCheckModelColumn bool,
existingScheduledCheckByName map[string]bool,
existingScheduledCheckModelByName map[string]string,
) (channel *model.ChannelWithKeys, errMsg string, skip bool) {
if isCSVRecordEmpty(record) {
return nil, "", true
}
// fetch returns the trimmed cell of a named column, or "" when the column
// is not in the header or the record is too short to contain it.
fetch := func(key string) string {
idx, ok := columnIndex[key]
if !ok || idx >= len(record) {
return ""
}
return strings.TrimSpace(record[idx])
}
name := fetch("name")
rawID := fetch("id")
apiKey := fetch("api_key")
url := fetch("url")
modelsRaw := fetch("models")
modelRedirectsRaw := fetch("model_redirects")
channelType := fetch("channel_type")
protocolTransformsRaw := fetch("protocol_transforms")
protocolTransformMode := model.NormalizeProtocolTransformMode(fetch("protocol_transform_mode"))
keyStrategy := fetch("key_strategy")
// Report all missing required fields of the row in a single message.
var missing []string
if name == "" {
missing = append(missing, "name")
}
if apiKey == "" {
missing = append(missing, "api_key")
}
if url == "" {
missing = append(missing, "url")
}
if modelsRaw == "" {
missing = append(missing, "models")
}
if len(missing) > 0 {
return nil, fmt.Sprintf("第%d行缺少必填字段: %s", lineNo, strings.Join(missing, ", ")), true
}
// An empty id cell parses to 0 without error (see parseImportChannelID).
channelID, err := parseImportChannelID(rawID)
if err != nil {
return nil, fmt.Sprintf("第%d行渠道ID格式错误: %v", lineNo, err), true
}
// The validated URL value replaces the raw cell content.
normalizedURL, err := validateChannelURLs(url)
if err != nil {
return nil, fmt.Sprintf("第%d行URL无效: %v", lineNo, err), true
}
url = normalizedURL
// Normalize and validate the channel type (original author's note:
// openai → codex, empty → anthropic).
channelType = util.NormalizeChannelType(channelType)
if !util.IsValidChannelType(channelType) {
return nil, fmt.Sprintf("第%d行渠道类型无效: %s(仅支持anthropic/codex/gemini)", lineNo, channelType), true
}
// Key strategy is optional; an empty cell selects the sequential default.
if keyStrategy == "" {
keyStrategy = model.KeyStrategySequential // default
} else if !model.IsValidKeyStrategy(keyStrategy) {
return nil, fmt.Sprintf("第%d行Key使用策略无效: %s(仅支持sequential/round_robin)", lineNo, keyStrategy), true
}
// An empty result from NormalizeProtocolTransformMode is treated as invalid;
// the message echoes the raw cell value.
if protocolTransformMode == "" {
return nil, fmt.Sprintf("第%d行 protocol_transform_mode 无效: %s", lineNo, fetch("protocol_transform_mode")), true
}
// The raw transform list is validated first and normalized afterwards.
rawProtocolTransforms := parseProtocolTransformsCSV(protocolTransformsRaw)
if err := validateProtocolTransforms(channelType, protocolTransformMode, rawProtocolTransforms); err != nil {
return nil, fmt.Sprintf("第%d行 protocol_transforms 无效: %v", lineNo, err), true
}
protocolTransforms := normalizeProtocolTransforms(channelType, protocolTransformMode, rawProtocolTransforms)
// A non-empty models cell can still yield no models (e.g. only separators).
models := parseImportModels(modelsRaw)
if len(models) == 0 {
return nil, fmt.Sprintf("第%d行模型格式无效", lineNo), true
}
// Model redirects are optional; an empty cell or "{}" leaves the map nil.
var modelRedirects map[string]string
if modelRedirectsRaw != "" && modelRedirectsRaw != "{}" {
if err := sonic.Unmarshal([]byte(modelRedirectsRaw), &modelRedirects); err != nil {
return nil, fmt.Sprintf("第%d行模型重定向格式错误: %v", lineNo, err), true
}
}
// Priority is optional and defaults to 0.
priority := 0
if pRaw := fetch("priority"); pRaw != "" {
p, err := strconv.Atoi(pRaw)
if err != nil {
return nil, fmt.Sprintf("第%d行优先级格式错误: %v", lineNo, err), true
}
priority = p
}
// RPM limit is optional, defaults to 0, and must not be negative.
rpmLimit := 0
if rpmRaw := fetch("rpm_limit"); rpmRaw != "" {
parsed, err := strconv.Atoi(rpmRaw)
if err != nil || parsed < 0 {
return nil, fmt.Sprintf("第%d行RPM限制格式错误: %s", lineNo, rpmRaw), true
}
rpmLimit = parsed
}
// Channels are enabled unless the cell says otherwise.
enabled := true
if eRaw := fetch("enabled"); eRaw != "" {
if val, ok := parseImportEnabled(eRaw); ok {
enabled = val
} else {
return nil, fmt.Sprintf("第%d行启用状态格式错误: %s", lineNo, eRaw), true
}
}
// Scheduled check switch: an explicit cell value wins; an empty cell in a
// declared column means false; with no such column the value of the existing
// channel with the same name is kept (false when there is none).
scheduledCheckEnabled := existingScheduledCheckByName[name]
if raw := fetch("scheduled_check_enabled"); raw != "" {
if val, ok := parseImportEnabled(raw); ok {
scheduledCheckEnabled = val
} else {
return nil, fmt.Sprintf("第%d行定时检测开关格式错误: %s", lineNo, raw), true
}
} else if hasScheduledCheckColumn {
scheduledCheckEnabled = false
}
// Scheduled check model follows the same precedence. Only a value given
// explicitly in this row is validated strictly further below.
rawScheduledCheckModel := fetch("scheduled_check_model")
scheduledCheckModel := existingScheduledCheckModelByName[name]
shouldValidateScheduledCheckModel := false
if rawScheduledCheckModel != "" {
scheduledCheckModel = rawScheduledCheckModel
shouldValidateScheduledCheckModel = true
} else if hasScheduledCheckModelColumn {
scheduledCheckModel = ""
}
// Build the model entries by attaching a redirect to each model that has one.
// Redirect keys that do not match any listed model are ignored.
modelEntries := make([]model.ModelEntry, 0, len(models))
for _, m := range models {
entry := model.ModelEntry{Model: m}
if redirect, ok := modelRedirects[m]; ok {
entry.RedirectModel = redirect
}
modelEntries = append(modelEntries, entry)
}
// The scheduled check model must be one of the row's models. A value typed
// into this row is rejected when it is not; an inherited value is silently
// cleared instead.
if scheduledCheckModel != "" {
declared := false
for _, entry := range modelEntries {
if entry.Model == scheduledCheckModel {
declared = true
break
}
}
if !declared {
if shouldValidateScheduledCheckModel {
return nil, fmt.Sprintf("第%d行 scheduled_check_model 无效: %s", lineNo, scheduledCheckModel), true
}
scheduledCheckModel = ""
}
}
// Assemble the channel configuration.
cfg := &model.Config{
ID: channelID,
Name: name,
URL: url,
Priority: priority,
RPMLimit: rpmLimit,
ModelEntries: modelEntries,
ChannelType: channelType,
ProtocolTransformMode: protocolTransformMode,
ProtocolTransforms: protocolTransforms,
Enabled: enabled,
ScheduledCheckEnabled: scheduledCheckEnabled,
ScheduledCheckModel: scheduledCheckModel,
}
// Split the api_key cell into individual keys; every key receives its
// position as KeyIndex and the row's key strategy.
// NOTE(review): if util.ParseAPIKeys can return no keys for a non-empty cell
// (for example a cell holding only separators), the channel is built without
// any API key — confirm whether such rows should be rejected.
apiKeyList := util.ParseAPIKeys(apiKey)
apiKeys := make([]model.APIKey, len(apiKeyList))
for i, key := range apiKeyList {
apiKeys[i] = model.APIKey{
KeyIndex: i,
APIKey: key,
KeyStrategy: keyStrategy,
}
}
return &model.ChannelWithKeys{
Config: cfg,
APIKeys: apiKeys,
}, "", false
}
// parseProtocolTransformsCSV splits a comma-separated protocol_transforms cell
// into its items. Each item is trimmed and empty items are dropped. A blank
// cell yields nil; a cell made up only of separators yields an empty, non-nil
// slice.
func parseProtocolTransformsCSV(raw string) []string {
	cell := strings.TrimSpace(raw)
	if len(cell) == 0 {
		return nil
	}

	pieces := strings.Split(cell, ",")
	result := make([]string, 0, len(pieces))
	for i := range pieces {
		if item := strings.TrimSpace(pieces[i]); item != "" {
			result = append(result, item)
		}
	}
	return result
}
// ==================== CSV辅助函数 ====================
// buildCSVColumnIndex maps every normalized header name to the position of its
// column. Headers that normalize to the empty string are ignored. When two
// headers normalize to the same name, the later column wins.
func buildCSVColumnIndex(header []string) map[string]int {
	positions := make(map[string]int, len(header))
	for pos := range header {
		if key := normalizeCSVHeader(header[pos]); key != "" {
			positions[key] = pos
		}
	}
	return positions
}
// normalizeCSVHeader maps a raw CSV header cell to its canonical column name.
//
// The cell is trimmed, a leading UTF-8 BOM is removed, the remainder is
// trimmed again and lower-cased, and known aliases (for example "api-key" or
// "status") are translated to the canonical name. Unknown headers are returned
// in their trimmed, lower-cased form.
func normalizeCSVHeader(name string) string {
	trimmed := strings.TrimSpace(name)
	trimmed = strings.TrimPrefix(trimmed, "\ufeff")
	// U+FEFF is not whitespace, so the first TrimSpace cannot reach blanks
	// that follow the BOM (e.g. "\ufeff name"). Trim again, otherwise the
	// first column of such a file would not be recognized.
	trimmed = strings.TrimSpace(trimmed)
	lower := strings.ToLower(trimmed)
	switch lower {
	case "apikey", "api-key", "api key":
		return "api_key"
	case "model", "model_list", "model(s)":
		return "models"
	case "model_redirect", "model-redirects", "modelredirects", "redirects":
		return "model_redirects"
	case "key_strategy", "key-strategy", "keystrategy", "策略", "使用策略":
		return "key_strategy"
	case "rpm-limit", "rpmlimit", "rpm limit":
		return "rpm_limit"
	case "scheduled-check-enabled", "scheduledcheckenabled", "scheduled check enabled":
		return "scheduled_check_enabled"
	case "scheduled-check-model", "scheduledcheckmodel", "scheduled check model":
		return "scheduled_check_model"
	case "status":
		return "enabled"
	default:
		return lower
	}
}
// isCSVRecordEmpty reports whether a record has no content, i.e. every cell is
// empty or whitespace only. A record without cells counts as empty.
func isCSVRecordEmpty(record []string) bool {
	for i := 0; i < len(record); i++ {
		if len(strings.TrimSpace(record[i])) > 0 {
			return false
		}
	}
	return true
}
// parseImportModels splits the models cell of a CSV row into model names.
//
// Comma, semicolon, pipe, newline, carriage return and tab all act as
// separators. Names are trimmed, blanks are dropped, and duplicates are removed
// while keeping the order of first appearance. An empty cell, or one made up
// only of separators, yields nil.
func parseImportModels(raw string) []string {
	if len(raw) == 0 {
		return nil
	}

	const separators = ",;|\n\r\t"
	fields := strings.FieldsFunc(raw, func(r rune) bool {
		return strings.ContainsRune(separators, r)
	})
	if len(fields) == 0 {
		return nil
	}

	unique := make([]string, 0, len(fields))
	visited := make(map[string]struct{}, len(fields))
	for _, field := range fields {
		name := strings.TrimSpace(field)
		if name == "" {
			continue
		}
		if _, dup := visited[name]; dup {
			continue
		}
		visited[name] = struct{}{}
		unique = append(unique, name)
	}
	return unique
}
// parseImportEnabled parses a boolean-like CSV cell (used for the enabled and
// scheduled_check_enabled columns) by delegating to util.ParseBool. Callers in
// this file use the first result as the parsed value and the second as the
// "parsed successfully" flag.
func parseImportEnabled(raw string) (bool, bool) {
return util.ParseBool(raw)
}
// parseImportChannelID parses the optional id cell of a CSV row.
//
// A blank cell returns 0 with a nil error. Otherwise the cell must be a base-10
// integer greater than zero; a value that is not an integer returns the
// strconv error, and zero or a negative value returns an error.
func parseImportChannelID(raw string) (int64, error) {
	trimmed := strings.TrimSpace(raw)
	if len(trimmed) == 0 {
		return 0, nil
	}

	parsed, parseErr := strconv.ParseInt(trimmed, 10, 64)
	switch {
	case parseErr != nil:
		return 0, parseErr
	case parsed < 1:
		return 0, fmt.Errorf("must be a positive integer")
	}
	return parsed, nil
}