Spaces:
Build error
Build error
| package model | |
| import ( | |
| "context" | |
| "fmt" | |
| "one-api/common" | |
| "one-api/logger" | |
| "one-api/types" | |
| "os" | |
| "strings" | |
| "time" | |
| "github.com/gin-gonic/gin" | |
| "github.com/bytedance/gopkg/util/gopool" | |
| "gorm.io/gorm" | |
| ) | |
| type Log struct { | |
| Id int `json:"id" gorm:"index:idx_created_at_id,priority:1"` | |
| UserId int `json:"user_id" gorm:"index"` | |
| CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"` | |
| Type int `json:"type" gorm:"index:idx_created_at_type"` | |
| Content string `json:"content"` | |
| Username string `json:"username" gorm:"index;index:index_username_model_name,priority:2;default:''"` | |
| TokenName string `json:"token_name" gorm:"index;default:''"` | |
| ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"` | |
| Quota int `json:"quota" gorm:"default:0"` | |
| PromptTokens int `json:"prompt_tokens" gorm:"default:0"` | |
| CompletionTokens int `json:"completion_tokens" gorm:"default:0"` | |
| UseTime int `json:"use_time" gorm:"default:0"` | |
| IsStream bool `json:"is_stream"` | |
| ChannelId int `json:"channel" gorm:"index"` | |
| ChannelName string `json:"channel_name" gorm:"->"` | |
| TokenId int `json:"token_id" gorm:"default:0;index"` | |
| Group string `json:"group" gorm:"index"` | |
| Ip string `json:"ip" gorm:"index;default:''"` | |
| Other string `json:"other"` | |
| } | |
| const ( | |
| LogTypeUnknown = iota | |
| LogTypeTopup | |
| LogTypeConsume | |
| LogTypeManage | |
| LogTypeSystem | |
| LogTypeError | |
| ) | |
| func formatUserLogs(logs []*Log) { | |
| for i := range logs { | |
| logs[i].ChannelName = "" | |
| var otherMap map[string]interface{} | |
| otherMap, _ = common.StrToMap(logs[i].Other) | |
| if otherMap != nil { | |
| // delete admin | |
| delete(otherMap, "admin_info") | |
| } | |
| logs[i].Other = common.MapToJsonStr(otherMap) | |
| logs[i].Id = logs[i].Id % 1024 | |
| } | |
| } | |
| func GetLogByKey(key string) (logs []*Log, err error) { | |
| if os.Getenv("LOG_SQL_DSN") != "" { | |
| var tk Token | |
| if err = DB.Model(&Token{}).Where(logKeyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil { | |
| return nil, err | |
| } | |
| err = LOG_DB.Model(&Log{}).Where("token_id=?", tk.Id).Find(&logs).Error | |
| } else { | |
| err = LOG_DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.TrimPrefix(key, "sk-")).Find(&logs).Error | |
| } | |
| formatUserLogs(logs) | |
| return logs, err | |
| } | |
| func RecordLog(userId int, logType int, content string) { | |
| if logType == LogTypeConsume && !common.LogConsumeEnabled { | |
| return | |
| } | |
| username, _ := GetUsernameById(userId, false) | |
| log := &Log{ | |
| UserId: userId, | |
| Username: username, | |
| CreatedAt: common.GetTimestamp(), | |
| Type: logType, | |
| Content: content, | |
| } | |
| err := LOG_DB.Create(log).Error | |
| if err != nil { | |
| common.SysLog("failed to record log: " + err.Error()) | |
| } | |
| } | |
| func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, tokenName string, content string, tokenId int, useTimeSeconds int, | |
| isStream bool, group string, other map[string]interface{}) { | |
| logger.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content)) | |
| username := c.GetString("username") | |
| otherStr := common.MapToJsonStr(other) | |
| // 判断是否需要记录 IP | |
| needRecordIp := false | |
| if settingMap, err := GetUserSetting(userId, false); err == nil { | |
| if settingMap.RecordIpLog { | |
| needRecordIp = true | |
| } | |
| } | |
| log := &Log{ | |
| UserId: userId, | |
| Username: username, | |
| CreatedAt: common.GetTimestamp(), | |
| Type: LogTypeError, | |
| Content: content, | |
| PromptTokens: 0, | |
| CompletionTokens: 0, | |
| TokenName: tokenName, | |
| ModelName: modelName, | |
| Quota: 0, | |
| ChannelId: channelId, | |
| TokenId: tokenId, | |
| UseTime: useTimeSeconds, | |
| IsStream: isStream, | |
| Group: group, | |
| Ip: func() string { | |
| if needRecordIp { | |
| return c.ClientIP() | |
| } | |
| return "" | |
| }(), | |
| Other: otherStr, | |
| } | |
| err := LOG_DB.Create(log).Error | |
| if err != nil { | |
| logger.LogError(c, "failed to record log: "+err.Error()) | |
| } | |
| } | |
| type RecordConsumeLogParams struct { | |
| ChannelId int `json:"channel_id"` | |
| PromptTokens int `json:"prompt_tokens"` | |
| CompletionTokens int `json:"completion_tokens"` | |
| ModelName string `json:"model_name"` | |
| TokenName string `json:"token_name"` | |
| Quota int `json:"quota"` | |
| Content string `json:"content"` | |
| TokenId int `json:"token_id"` | |
| UseTimeSeconds int `json:"use_time_seconds"` | |
| IsStream bool `json:"is_stream"` | |
| Group string `json:"group"` | |
| Other map[string]interface{} `json:"other"` | |
| } | |
| func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) { | |
| if !common.LogConsumeEnabled { | |
| return | |
| } | |
| logger.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params))) | |
| username := c.GetString("username") | |
| otherStr := common.MapToJsonStr(params.Other) | |
| // 判断是否需要记录 IP | |
| needRecordIp := false | |
| if settingMap, err := GetUserSetting(userId, false); err == nil { | |
| if settingMap.RecordIpLog { | |
| needRecordIp = true | |
| } | |
| } | |
| log := &Log{ | |
| UserId: userId, | |
| Username: username, | |
| CreatedAt: common.GetTimestamp(), | |
| Type: LogTypeConsume, | |
| Content: params.Content, | |
| PromptTokens: params.PromptTokens, | |
| CompletionTokens: params.CompletionTokens, | |
| TokenName: params.TokenName, | |
| ModelName: params.ModelName, | |
| Quota: params.Quota, | |
| ChannelId: params.ChannelId, | |
| TokenId: params.TokenId, | |
| UseTime: params.UseTimeSeconds, | |
| IsStream: params.IsStream, | |
| Group: params.Group, | |
| Ip: func() string { | |
| if needRecordIp { | |
| return c.ClientIP() | |
| } | |
| return "" | |
| }(), | |
| Other: otherStr, | |
| } | |
| err := LOG_DB.Create(log).Error | |
| if err != nil { | |
| logger.LogError(c, "failed to record log: "+err.Error()) | |
| } | |
| if common.DataExportEnabled { | |
| gopool.Go(func() { | |
| LogQuotaData(userId, username, params.ModelName, params.Quota, common.GetTimestamp(), params.PromptTokens+params.CompletionTokens) | |
| }) | |
| } | |
| } | |
| func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int, group string) (logs []*Log, total int64, err error) { | |
| var tx *gorm.DB | |
| if logType == LogTypeUnknown { | |
| tx = LOG_DB | |
| } else { | |
| tx = LOG_DB.Where("logs.type = ?", logType) | |
| } | |
| if modelName != "" { | |
| tx = tx.Where("logs.model_name like ?", modelName) | |
| } | |
| if username != "" { | |
| tx = tx.Where("logs.username = ?", username) | |
| } | |
| if tokenName != "" { | |
| tx = tx.Where("logs.token_name = ?", tokenName) | |
| } | |
| if startTimestamp != 0 { | |
| tx = tx.Where("logs.created_at >= ?", startTimestamp) | |
| } | |
| if endTimestamp != 0 { | |
| tx = tx.Where("logs.created_at <= ?", endTimestamp) | |
| } | |
| if channel != 0 { | |
| tx = tx.Where("logs.channel_id = ?", channel) | |
| } | |
| if group != "" { | |
| tx = tx.Where("logs."+logGroupCol+" = ?", group) | |
| } | |
| err = tx.Model(&Log{}).Count(&total).Error | |
| if err != nil { | |
| return nil, 0, err | |
| } | |
| err = tx.Order("logs.id desc").Limit(num).Offset(startIdx).Find(&logs).Error | |
| if err != nil { | |
| return nil, 0, err | |
| } | |
| channelIds := types.NewSet[int]() | |
| for _, log := range logs { | |
| if log.ChannelId != 0 { | |
| channelIds.Add(log.ChannelId) | |
| } | |
| } | |
| if channelIds.Len() > 0 { | |
| var channels []struct { | |
| Id int `gorm:"column:id"` | |
| Name string `gorm:"column:name"` | |
| } | |
| if err = DB.Table("channels").Select("id, name").Where("id IN ?", channelIds.Items()).Find(&channels).Error; err != nil { | |
| return logs, total, err | |
| } | |
| channelMap := make(map[int]string, len(channels)) | |
| for _, channel := range channels { | |
| channelMap[channel.Id] = channel.Name | |
| } | |
| for i := range logs { | |
| logs[i].ChannelName = channelMap[logs[i].ChannelId] | |
| } | |
| } | |
| return logs, total, err | |
| } | |
| func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int, group string) (logs []*Log, total int64, err error) { | |
| var tx *gorm.DB | |
| if logType == LogTypeUnknown { | |
| tx = LOG_DB.Where("logs.user_id = ?", userId) | |
| } else { | |
| tx = LOG_DB.Where("logs.user_id = ? and logs.type = ?", userId, logType) | |
| } | |
| if modelName != "" { | |
| tx = tx.Where("logs.model_name like ?", modelName) | |
| } | |
| if tokenName != "" { | |
| tx = tx.Where("logs.token_name = ?", tokenName) | |
| } | |
| if startTimestamp != 0 { | |
| tx = tx.Where("logs.created_at >= ?", startTimestamp) | |
| } | |
| if endTimestamp != 0 { | |
| tx = tx.Where("logs.created_at <= ?", endTimestamp) | |
| } | |
| if group != "" { | |
| tx = tx.Where("logs."+logGroupCol+" = ?", group) | |
| } | |
| err = tx.Model(&Log{}).Count(&total).Error | |
| if err != nil { | |
| return nil, 0, err | |
| } | |
| err = tx.Order("logs.id desc").Limit(num).Offset(startIdx).Find(&logs).Error | |
| if err != nil { | |
| return nil, 0, err | |
| } | |
| formatUserLogs(logs) | |
| return logs, total, err | |
| } | |
| func SearchAllLogs(keyword string) (logs []*Log, err error) { | |
| err = LOG_DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error | |
| return logs, err | |
| } | |
| func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { | |
| err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error | |
| formatUserLogs(logs) | |
| return logs, err | |
| } | |
| type Stat struct { | |
| Quota int `json:"quota"` | |
| Rpm int `json:"rpm"` | |
| Tpm int `json:"tpm"` | |
| } | |
| func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int, group string) (stat Stat) { | |
| tx := LOG_DB.Table("logs").Select("sum(quota) quota") | |
| // 为rpm和tpm创建单独的查询 | |
| rpmTpmQuery := LOG_DB.Table("logs").Select("count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm") | |
| if username != "" { | |
| tx = tx.Where("username = ?", username) | |
| rpmTpmQuery = rpmTpmQuery.Where("username = ?", username) | |
| } | |
| if tokenName != "" { | |
| tx = tx.Where("token_name = ?", tokenName) | |
| rpmTpmQuery = rpmTpmQuery.Where("token_name = ?", tokenName) | |
| } | |
| if startTimestamp != 0 { | |
| tx = tx.Where("created_at >= ?", startTimestamp) | |
| } | |
| if endTimestamp != 0 { | |
| tx = tx.Where("created_at <= ?", endTimestamp) | |
| } | |
| if modelName != "" { | |
| tx = tx.Where("model_name like ?", modelName) | |
| rpmTpmQuery = rpmTpmQuery.Where("model_name like ?", modelName) | |
| } | |
| if channel != 0 { | |
| tx = tx.Where("channel_id = ?", channel) | |
| rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel) | |
| } | |
| if group != "" { | |
| tx = tx.Where(logGroupCol+" = ?", group) | |
| rpmTpmQuery = rpmTpmQuery.Where(logGroupCol+" = ?", group) | |
| } | |
| tx = tx.Where("type = ?", LogTypeConsume) | |
| rpmTpmQuery = rpmTpmQuery.Where("type = ?", LogTypeConsume) | |
| // 只统计最近60秒的rpm和tpm | |
| rpmTpmQuery = rpmTpmQuery.Where("created_at >= ?", time.Now().Add(-60*time.Second).Unix()) | |
| // 执行查询 | |
| tx.Scan(&stat) | |
| rpmTpmQuery.Scan(&stat) | |
| return stat | |
| } | |
| func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { | |
| tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") | |
| if username != "" { | |
| tx = tx.Where("username = ?", username) | |
| } | |
| if tokenName != "" { | |
| tx = tx.Where("token_name = ?", tokenName) | |
| } | |
| if startTimestamp != 0 { | |
| tx = tx.Where("created_at >= ?", startTimestamp) | |
| } | |
| if endTimestamp != 0 { | |
| tx = tx.Where("created_at <= ?", endTimestamp) | |
| } | |
| if modelName != "" { | |
| tx = tx.Where("model_name = ?", modelName) | |
| } | |
| tx.Where("type = ?", LogTypeConsume).Scan(&token) | |
| return token | |
| } | |
| func DeleteOldLog(ctx context.Context, targetTimestamp int64, limit int) (int64, error) { | |
| var total int64 = 0 | |
| for { | |
| if nil != ctx.Err() { | |
| return total, ctx.Err() | |
| } | |
| result := LOG_DB.Where("created_at < ?", targetTimestamp).Limit(limit).Delete(&Log{}) | |
| if nil != result.Error { | |
| return total, result.Error | |
| } | |
| total += result.RowsAffected | |
| if result.RowsAffected < int64(limit) { | |
| break | |
| } | |
| } | |
| return total, nil | |
| } | |