package sql import ( "database/sql" "encoding/json" "fmt" "log/slog" "strconv" "strings" "time" "ccLoad/internal/model" ) // WhereBuilder SQL WHERE 子句构建器 type WhereBuilder struct { conditions []string args []any } // NewWhereBuilder 创建新的 WHERE 构建器 func NewWhereBuilder() *WhereBuilder { return &WhereBuilder{ conditions: make([]string, 0), args: make([]any, 0), } } // AddCondition 添加SQL WHERE条件子句 // // 【SQL注入防护约束】 // - condition参数必须是代码中的字符串字面量或常量,禁止拼接用户输入 // - 用户输入必须通过args参数传递,自动参数化为占位符(?) // - 违反约束将导致SQL注入漏洞,必须通过代码审查/静态分析工具检测 // // 正确示例: // // wb.AddCondition("channel_id = ?", userInputChannelID) // ✅ 用户输入通过args传递 // wb.AddCondition("status IN (?, ?)", "active", "pending") // ✅ 多个占位符 // // 错误示例: // // wb.AddCondition("channel_id = " + userInput) // ❌ SQL注入风险! // wb.AddCondition(fmt.Sprintf("name LIKE '%%%s%%'", userInput)) // ❌ SQL注入风险! // // 静态检查建议: 使用gosec/semgrep扫描所有调用点,确保condition参数不包含fmt.Sprintf/字符串拼接 func (wb *WhereBuilder) AddCondition(condition string, args ...any) *WhereBuilder { if condition == "" { return wb } wb.conditions = append(wb.conditions, condition) wb.args = append(wb.args, args...) return wb } // ApplyLogFilter 应用日志过滤器,消除重复的过滤逻辑 func (wb *WhereBuilder) ApplyLogFilter(filter *model.LogFilter) *WhereBuilder { if filter == nil { wb.AddCondition("log_source = ?", model.LogSourceProxy) return wb } if filter.ChannelID != nil { wb.AddCondition("channel_id = ?", *filter.ChannelID) } // 注意:ChannelType/ChannelName/ChannelNameLike 不在此处处理。 // logs 表只有 channel_id;这类过滤应由 SQLStore.applyChannelFilter 先解析出候选 channel_id 集合再 WhereIn。 if filter.Model != "" { wb.AddCondition("model = ?", filter.Model) } if filter.ModelLike != "" { wb.AddCondition("model LIKE ?", "%"+filter.ModelLike+"%") } if filter.StatusCode != nil { wb.AddCondition("status_code = ?", *filter.StatusCode) } if filter.AuthTokenID != nil { wb.AddCondition("auth_token_id = ?", *filter.AuthTokenID) } switch filter.LogSource { case model.LogSourceAll: case model.LogSourceDetection: wb.AddCondition("log_source IN (?, ?)", model.LogSourceScheduledCheck, model.LogSourceManualTest) case "": wb.AddCondition("log_source = ?", model.LogSourceProxy) default: wb.AddCondition("log_source = ?", filter.LogSource) } return wb } // Build 构建最终的 WHERE 子句和参数 func (wb *WhereBuilder) Build() (string, []any) { if len(wb.conditions) == 0 { return "", wb.args } return strings.Join(wb.conditions, " AND "), wb.args } // BuildWithPrefix 构建带前缀的 WHERE 子句 func (wb *WhereBuilder) BuildWithPrefix(prefix string) (string, []any) { whereClause, args := wb.Build() if whereClause == "" { return "", args } return prefix + " " + whereClause, args } // ConfigScanner 统一的 Config 扫描器 type ConfigScanner struct{} // NewConfigScanner 创建新的配置扫描器 func NewConfigScanner() *ConfigScanner { return &ConfigScanner{} } // ScanConfig 扫描单行配置数据(不含模型数据,需要单独查询channel_models表) func (cs *ConfigScanner) ScanConfig(scanner interface { Scan(...any) error }) (*model.Config, error) { var c model.Config var enabledInt int var scheduledCheckEnabledInt int var scheduledCheckModel string var customRequestRules sql.NullString var createdAtRaw, updatedAtRaw any // 使用any接受任意类型(兼容字符串、整数或RFC3339) // 扫描key_count字段(从JOIN查询获取) // 注意:不再包含 models 和 model_redirects 字段 if err := scanner.Scan(&c.ID, &c.Name, &c.URL, &c.Priority, &c.RPMLimit, &c.ChannelType, &c.ProtocolTransformMode, &enabledInt, &scheduledCheckEnabledInt, &scheduledCheckModel, &c.CooldownUntil, &c.CooldownDurationMs, &c.DailyCostLimit, &c.CostMultiplier, &customRequestRules, &c.KeyCount, &createdAtRaw, &updatedAtRaw); err != nil { return nil, err } c.Enabled = enabledInt != 0 c.ScheduledCheckEnabled = scheduledCheckEnabledInt != 0 c.ScheduledCheckModel = scheduledCheckModel c.CustomRequestRules = parseCustomRequestRules(c.ID, customRequestRules) if c.CostMultiplier < 0 { c.CostMultiplier = 1 } // 转换时间戳(支持不同数据库) now := time.Now() c.CreatedAt = model.JSONTime{Time: cs.parseTimestampOrNow(createdAtRaw, now)} c.UpdatedAt = model.JSONTime{Time: cs.parseTimestampOrNow(updatedAtRaw, now)} // ModelEntries 需要通过 LoadModelEntries 方法单独加载 c.ModelEntries = nil return &c, nil } // ScanConfigs 扫描多行配置数据 func (cs *ConfigScanner) ScanConfigs(rows interface { Next() bool Scan(...any) error Err() error }) ([]*model.Config, error) { var configs []*model.Config for rows.Next() { config, err := cs.ScanConfig(rows) if err != nil { return nil, err } configs = append(configs, config) } if err := rows.Err(); err != nil { return nil, err } return configs, nil } // parseTimestampOrNow 解析时间戳或使用当前时间(支持Unix时间戳和RFC3339格式) // 优先级:int64 > int > string(数字) > string(RFC3339) > fallback func (cs *ConfigScanner) parseTimestampOrNow(val any, fallback time.Time) time.Time { switch v := val.(type) { case int64: if v > 0 { return unixToTime(v) } case int: if v > 0 { return unixToTime(int64(v)) } case string: // 1. 尝试解析字符串为Unix时间戳 if ts, err := strconv.ParseInt(v, 10, 64); err == nil && ts > 0 { return unixToTime(ts) } // 2. 尝试解析RFC3339格式 if t, err := time.Parse(time.RFC3339, v); err == nil { return t } // 3. 尝试解析常见ISO8601变体(兼容数据库TIMESTAMP格式) for _, layout := range []string{ time.RFC3339Nano, "2006-01-02T15:04:05.999999999Z07:00", "2006-01-02 15:04:05.999999999 -07:00 MST", } { if t, err := time.Parse(layout, v); err == nil { return t } } } // 非法值:返回fallback return fallback } // QueryBuilder 通用查询构建器 type QueryBuilder struct { baseQuery string wb *WhereBuilder } // NewQueryBuilder 创建新的查询构建器 func NewQueryBuilder(baseQuery string) *QueryBuilder { return &QueryBuilder{ baseQuery: baseQuery, wb: NewWhereBuilder(), } } // Where 添加 WHERE 条件 func (qb *QueryBuilder) Where(condition string, args ...any) *QueryBuilder { qb.wb.AddCondition(condition, args...) return qb } // ApplyFilter 应用过滤器 func (qb *QueryBuilder) ApplyFilter(filter *model.LogFilter) *QueryBuilder { qb.wb.ApplyLogFilter(filter) return qb } // WhereIn 添加 IN 条件,自动生成占位符 func (qb *QueryBuilder) WhereIn(column string, values []any) *QueryBuilder { if len(values) == 0 { // 无值时添加恒为假的条件,确保不返回记录 qb.wb.AddCondition("1=0") return qb } // 生成 ?,?,? 占位符 placeholders := make([]string, len(values)) for i := range values { placeholders[i] = "?" } cond := fmt.Sprintf("%s IN (%s)", column, strings.Join(placeholders, ",")) qb.wb.AddCondition(cond, values...) return qb } // Build 构建最终查询 func (qb *QueryBuilder) Build() (string, []any) { whereClause, args := qb.wb.BuildWithPrefix("WHERE") query := qb.baseQuery if whereClause != "" { query += " " + whereClause } return query, args } // BuildWithSuffix 构建带后缀的查询(如 ORDER BY, LIMIT 等) func (qb *QueryBuilder) BuildWithSuffix(suffix string) (string, []any) { query, args := qb.Build() if suffix != "" { query += " " + suffix } return query, args } // parseCustomRequestRules 将数据库列值解析为 CustomRequestRules,解析失败时返回 nil 并写入警告日志。 func parseCustomRequestRules(channelID int64, raw sql.NullString) *model.CustomRequestRules { if !raw.Valid { return nil } trimmed := strings.TrimSpace(raw.String) if trimmed == "" || trimmed == "null" { return nil } var rules model.CustomRequestRules if err := json.Unmarshal([]byte(trimmed), &rules); err != nil { slog.Warn("custom_request_rules: unmarshal failed, treated as empty", "channel_id", channelID, "error", err.Error()) return nil } if rules.IsEmpty() { return nil } return &rules } // marshalCustomRequestRules 将结构体序列化为数据库存储字符串;空规则返回空字符串(NULL)。 func marshalCustomRequestRules(rules *model.CustomRequestRules) (sql.NullString, error) { if rules == nil || rules.IsEmpty() { return sql.NullString{}, nil } data, err := json.Marshal(rules) if err != nil { return sql.NullString{}, fmt.Errorf("marshal custom_request_rules: %w", err) } return sql.NullString{String: string(data), Valid: true}, nil }