|
|
package model |
|
|
|
|
|
import ( |
|
|
"fmt" |
|
|
"log" |
|
|
"os" |
|
|
"strings" |
|
|
"sync" |
|
|
"time" |
|
|
|
|
|
"github.com/QuantumNous/new-api/common" |
|
|
"github.com/QuantumNous/new-api/constant" |
|
|
|
|
|
"github.com/glebarez/sqlite" |
|
|
"gorm.io/driver/mysql" |
|
|
"gorm.io/driver/postgres" |
|
|
"gorm.io/gorm" |
|
|
) |
|
|
|
|
|
// Quoted column identifiers and boolean literals for the main database.
// They are assigned by initCol once the dialect is known.
var (
	commonGroupCol string
	commonKeyCol   string
	commonTrueVal  string
	commonFalseVal string
)

// Quoted column identifiers for the log database, also assigned by initCol.
var (
	logKeyCol   string
	logGroupCol string
)
|
|
|
|
|
func initCol() { |
|
|
|
|
|
if common.UsingPostgreSQL { |
|
|
commonGroupCol = `"group"` |
|
|
commonKeyCol = `"key"` |
|
|
commonTrueVal = "true" |
|
|
commonFalseVal = "false" |
|
|
} else { |
|
|
commonGroupCol = "`group`" |
|
|
commonKeyCol = "`key`" |
|
|
commonTrueVal = "1" |
|
|
commonFalseVal = "0" |
|
|
} |
|
|
if os.Getenv("LOG_SQL_DSN") != "" { |
|
|
switch common.LogSqlType { |
|
|
case common.DatabaseTypePostgreSQL: |
|
|
logGroupCol = `"group"` |
|
|
logKeyCol = `"key"` |
|
|
default: |
|
|
logGroupCol = commonGroupCol |
|
|
logKeyCol = commonKeyCol |
|
|
} |
|
|
} else { |
|
|
|
|
|
if common.UsingPostgreSQL { |
|
|
logGroupCol = `"group"` |
|
|
logKeyCol = `"key"` |
|
|
} else { |
|
|
logGroupCol = commonGroupCol |
|
|
logKeyCol = commonKeyCol |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
var DB *gorm.DB |
|
|
|
|
|
var LOG_DB *gorm.DB |
|
|
|
|
|
func createRootAccountIfNeed() error { |
|
|
var user User |
|
|
|
|
|
if err := DB.First(&user).Error; err != nil { |
|
|
common.SysLog("no user exists, create a root user for you: username is root, password is 123456") |
|
|
hashedPassword, err := common.Password2Hash("123456") |
|
|
if err != nil { |
|
|
return err |
|
|
} |
|
|
rootUser := User{ |
|
|
Username: "root", |
|
|
Password: hashedPassword, |
|
|
Role: common.RoleRootUser, |
|
|
Status: common.UserStatusEnabled, |
|
|
DisplayName: "Root User", |
|
|
AccessToken: nil, |
|
|
Quota: 100000000, |
|
|
} |
|
|
DB.Create(&rootUser) |
|
|
} |
|
|
return nil |
|
|
} |
|
|
|
|
|
func CheckSetup() { |
|
|
setup := GetSetup() |
|
|
if setup == nil { |
|
|
|
|
|
if RootUserExists() { |
|
|
common.SysLog("system is not initialized, but root user exists") |
|
|
|
|
|
newSetup := Setup{ |
|
|
Version: common.Version, |
|
|
InitializedAt: time.Now().Unix(), |
|
|
} |
|
|
err := DB.Create(&newSetup).Error |
|
|
if err != nil { |
|
|
common.SysLog("failed to create setup record: " + err.Error()) |
|
|
} |
|
|
constant.Setup = true |
|
|
} else { |
|
|
common.SysLog("system is not initialized and no root user exists") |
|
|
constant.Setup = false |
|
|
} |
|
|
} else { |
|
|
|
|
|
common.SysLog("system is already initialized at: " + time.Unix(setup.InitializedAt, 0).String()) |
|
|
constant.Setup = true |
|
|
} |
|
|
} |
|
|
|
|
|
func chooseDB(envName string, isLog bool) (*gorm.DB, error) { |
|
|
defer func() { |
|
|
initCol() |
|
|
}() |
|
|
dsn := os.Getenv(envName) |
|
|
if dsn != "" { |
|
|
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") { |
|
|
|
|
|
common.SysLog("using PostgreSQL as database") |
|
|
if !isLog { |
|
|
common.UsingPostgreSQL = true |
|
|
} else { |
|
|
common.LogSqlType = common.DatabaseTypePostgreSQL |
|
|
} |
|
|
return gorm.Open(postgres.New(postgres.Config{ |
|
|
DSN: dsn, |
|
|
PreferSimpleProtocol: true, |
|
|
}), &gorm.Config{ |
|
|
PrepareStmt: true, |
|
|
}) |
|
|
} |
|
|
if strings.HasPrefix(dsn, "local") { |
|
|
common.SysLog("SQL_DSN not set, using SQLite as database") |
|
|
if !isLog { |
|
|
common.UsingSQLite = true |
|
|
} else { |
|
|
common.LogSqlType = common.DatabaseTypeSQLite |
|
|
} |
|
|
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{ |
|
|
PrepareStmt: true, |
|
|
}) |
|
|
} |
|
|
|
|
|
common.SysLog("using MySQL as database") |
|
|
|
|
|
if !strings.Contains(dsn, "parseTime") { |
|
|
if strings.Contains(dsn, "?") { |
|
|
dsn += "&parseTime=true" |
|
|
} else { |
|
|
dsn += "?parseTime=true" |
|
|
} |
|
|
} |
|
|
if !isLog { |
|
|
common.UsingMySQL = true |
|
|
} else { |
|
|
common.LogSqlType = common.DatabaseTypeMySQL |
|
|
} |
|
|
return gorm.Open(mysql.Open(dsn), &gorm.Config{ |
|
|
PrepareStmt: true, |
|
|
}) |
|
|
} |
|
|
|
|
|
common.SysLog("SQL_DSN not set, using SQLite as database") |
|
|
common.UsingSQLite = true |
|
|
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{ |
|
|
PrepareStmt: true, |
|
|
}) |
|
|
} |
|
|
|
|
|
func InitDB() (err error) { |
|
|
db, err := chooseDB("SQL_DSN", false) |
|
|
if err == nil { |
|
|
if common.DebugEnabled { |
|
|
db = db.Debug() |
|
|
} |
|
|
DB = db |
|
|
|
|
|
if common.UsingMySQL { |
|
|
if err := checkMySQLChineseSupport(DB); err != nil { |
|
|
panic(err) |
|
|
} |
|
|
} |
|
|
sqlDB, err := DB.DB() |
|
|
if err != nil { |
|
|
return err |
|
|
} |
|
|
sqlDB.SetMaxIdleConns(common.GetEnvOrDefault("SQL_MAX_IDLE_CONNS", 100)) |
|
|
sqlDB.SetMaxOpenConns(common.GetEnvOrDefault("SQL_MAX_OPEN_CONNS", 1000)) |
|
|
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetEnvOrDefault("SQL_MAX_LIFETIME", 60))) |
|
|
|
|
|
if !common.IsMasterNode { |
|
|
return nil |
|
|
} |
|
|
if common.UsingMySQL { |
|
|
|
|
|
} |
|
|
common.SysLog("database migration started") |
|
|
err = migrateDB() |
|
|
return err |
|
|
} else { |
|
|
common.FatalLog(err) |
|
|
} |
|
|
return err |
|
|
} |
|
|
|
|
|
func InitLogDB() (err error) { |
|
|
if os.Getenv("LOG_SQL_DSN") == "" { |
|
|
LOG_DB = DB |
|
|
return |
|
|
} |
|
|
db, err := chooseDB("LOG_SQL_DSN", true) |
|
|
if err == nil { |
|
|
if common.DebugEnabled { |
|
|
db = db.Debug() |
|
|
} |
|
|
LOG_DB = db |
|
|
|
|
|
if common.LogSqlType == common.DatabaseTypeMySQL { |
|
|
if err := checkMySQLChineseSupport(LOG_DB); err != nil { |
|
|
panic(err) |
|
|
} |
|
|
} |
|
|
sqlDB, err := LOG_DB.DB() |
|
|
if err != nil { |
|
|
return err |
|
|
} |
|
|
sqlDB.SetMaxIdleConns(common.GetEnvOrDefault("SQL_MAX_IDLE_CONNS", 100)) |
|
|
sqlDB.SetMaxOpenConns(common.GetEnvOrDefault("SQL_MAX_OPEN_CONNS", 1000)) |
|
|
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetEnvOrDefault("SQL_MAX_LIFETIME", 60))) |
|
|
|
|
|
if !common.IsMasterNode { |
|
|
return nil |
|
|
} |
|
|
common.SysLog("database migration started") |
|
|
err = migrateLOGDB() |
|
|
return err |
|
|
} else { |
|
|
common.FatalLog(err) |
|
|
} |
|
|
return err |
|
|
} |
|
|
|
|
|
func migrateDB() error { |
|
|
err := DB.AutoMigrate( |
|
|
&Channel{}, |
|
|
&Token{}, |
|
|
&User{}, |
|
|
&PasskeyCredential{}, |
|
|
&Option{}, |
|
|
&Redemption{}, |
|
|
&Ability{}, |
|
|
&Log{}, |
|
|
&Midjourney{}, |
|
|
&TopUp{}, |
|
|
&QuotaData{}, |
|
|
&Task{}, |
|
|
&Model{}, |
|
|
&Vendor{}, |
|
|
&PrefillGroup{}, |
|
|
&Setup{}, |
|
|
&TwoFA{}, |
|
|
&TwoFABackupCode{}, |
|
|
) |
|
|
if err != nil { |
|
|
return err |
|
|
} |
|
|
return nil |
|
|
} |
|
|
|
|
|
func migrateDBFast() error { |
|
|
|
|
|
var wg sync.WaitGroup |
|
|
|
|
|
migrations := []struct { |
|
|
model interface{} |
|
|
name string |
|
|
}{ |
|
|
{&Channel{}, "Channel"}, |
|
|
{&Token{}, "Token"}, |
|
|
{&User{}, "User"}, |
|
|
{&PasskeyCredential{}, "PasskeyCredential"}, |
|
|
{&Option{}, "Option"}, |
|
|
{&Redemption{}, "Redemption"}, |
|
|
{&Ability{}, "Ability"}, |
|
|
{&Log{}, "Log"}, |
|
|
{&Midjourney{}, "Midjourney"}, |
|
|
{&TopUp{}, "TopUp"}, |
|
|
{&QuotaData{}, "QuotaData"}, |
|
|
{&Task{}, "Task"}, |
|
|
{&Model{}, "Model"}, |
|
|
{&Vendor{}, "Vendor"}, |
|
|
{&PrefillGroup{}, "PrefillGroup"}, |
|
|
{&Setup{}, "Setup"}, |
|
|
{&TwoFA{}, "TwoFA"}, |
|
|
{&TwoFABackupCode{}, "TwoFABackupCode"}, |
|
|
} |
|
|
|
|
|
errChan := make(chan error, len(migrations)) |
|
|
|
|
|
for _, m := range migrations { |
|
|
wg.Add(1) |
|
|
go func(model interface{}, name string) { |
|
|
defer wg.Done() |
|
|
if err := DB.AutoMigrate(model); err != nil { |
|
|
errChan <- fmt.Errorf("failed to migrate %s: %v", name, err) |
|
|
} |
|
|
}(m.model, m.name) |
|
|
} |
|
|
|
|
|
|
|
|
wg.Wait() |
|
|
close(errChan) |
|
|
|
|
|
|
|
|
for err := range errChan { |
|
|
if err != nil { |
|
|
return err |
|
|
} |
|
|
} |
|
|
common.SysLog("database migrated") |
|
|
return nil |
|
|
} |
|
|
|
|
|
func migrateLOGDB() error { |
|
|
var err error |
|
|
if err = LOG_DB.AutoMigrate(&Log{}); err != nil { |
|
|
return err |
|
|
} |
|
|
return nil |
|
|
} |
|
|
|
|
|
func closeDB(db *gorm.DB) error { |
|
|
sqlDB, err := db.DB() |
|
|
if err != nil { |
|
|
return err |
|
|
} |
|
|
err = sqlDB.Close() |
|
|
return err |
|
|
} |
|
|
|
|
|
func CloseDB() error { |
|
|
if LOG_DB != DB { |
|
|
err := closeDB(LOG_DB) |
|
|
if err != nil { |
|
|
return err |
|
|
} |
|
|
} |
|
|
return closeDB(DB) |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func checkMySQLChineseSupport(db *gorm.DB) error { |
|
|
|
|
|
|
|
|
|
|
|
var schemaCharset, schemaCollation string |
|
|
err := db.Raw("SELECT DEFAULT_CHARACTER_SET_NAME, DEFAULT_COLLATION_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = DATABASE()").Row().Scan(&schemaCharset, &schemaCollation) |
|
|
if err != nil { |
|
|
return fmt.Errorf("读取当前库默认字符集/排序规则失败 / Failed to read schema default charset/collation: %v", err) |
|
|
} |
|
|
|
|
|
toLower := func(s string) string { return strings.ToLower(s) } |
|
|
|
|
|
allowedCharsets := map[string]string{ |
|
|
"utf8mb4": "utf8mb4_", |
|
|
"utf8": "utf8_", |
|
|
"gbk": "gbk_", |
|
|
"big5": "big5_", |
|
|
"gb18030": "gb18030_", |
|
|
} |
|
|
isChineseCapable := func(cs, cl string) bool { |
|
|
csLower := toLower(cs) |
|
|
clLower := toLower(cl) |
|
|
if prefix, ok := allowedCharsets[csLower]; ok { |
|
|
if clLower == "" { |
|
|
return true |
|
|
} |
|
|
return strings.HasPrefix(clLower, prefix) |
|
|
} |
|
|
|
|
|
for _, prefix := range allowedCharsets { |
|
|
if strings.HasPrefix(clLower, prefix) { |
|
|
return true |
|
|
} |
|
|
} |
|
|
return false |
|
|
} |
|
|
|
|
|
|
|
|
if !isChineseCapable(schemaCharset, schemaCollation) { |
|
|
return fmt.Errorf("当前库默认字符集/排序规则不支持中文:schema(%s/%s)。请将库设置为 utf8mb4/utf8/gbk/big5/gb18030 / Schema default charset/collation is not Chinese-capable: schema(%s/%s). Please set to utf8mb4/utf8/gbk/big5/gb18030", |
|
|
schemaCharset, schemaCollation, schemaCharset, schemaCollation) |
|
|
} |
|
|
|
|
|
|
|
|
type tableInfo struct { |
|
|
Name string |
|
|
Collation *string |
|
|
} |
|
|
var tables []tableInfo |
|
|
if err := db.Raw("SELECT TABLE_NAME, TABLE_COLLATION FROM information_schema.TABLES WHERE TABLE_SCHEMA = DATABASE() AND TABLE_TYPE = 'BASE TABLE'").Scan(&tables).Error; err != nil { |
|
|
return fmt.Errorf("读取表排序规则失败 / Failed to read table collations: %v", err) |
|
|
} |
|
|
|
|
|
var badTables []string |
|
|
for _, t := range tables { |
|
|
|
|
|
if t.Collation == nil || *t.Collation == "" { |
|
|
continue |
|
|
} |
|
|
cl := *t.Collation |
|
|
|
|
|
ok := false |
|
|
lower := strings.ToLower(cl) |
|
|
for _, prefix := range allowedCharsets { |
|
|
if strings.HasPrefix(lower, prefix) { |
|
|
ok = true |
|
|
break |
|
|
} |
|
|
} |
|
|
if !ok { |
|
|
badTables = append(badTables, fmt.Sprintf("%s(%s)", t.Name, cl)) |
|
|
} |
|
|
} |
|
|
|
|
|
if len(badTables) > 0 { |
|
|
|
|
|
maxShow := 20 |
|
|
shown := badTables |
|
|
if len(shown) > maxShow { |
|
|
shown = shown[:maxShow] |
|
|
} |
|
|
return fmt.Errorf( |
|
|
"存在不支持中文的表,请修复其排序规则/字符集。示例(最多展示 %d 项):%v / Found tables not Chinese-capable. Please fix their collation/charset. Examples (showing up to %d): %v", |
|
|
maxShow, shown, maxShow, shown, |
|
|
) |
|
|
} |
|
|
return nil |
|
|
} |
|
|
|
|
|
var ( |
|
|
lastPingTime time.Time |
|
|
pingMutex sync.Mutex |
|
|
) |
|
|
|
|
|
func PingDB() error { |
|
|
pingMutex.Lock() |
|
|
defer pingMutex.Unlock() |
|
|
|
|
|
if time.Since(lastPingTime) < time.Second*10 { |
|
|
return nil |
|
|
} |
|
|
|
|
|
sqlDB, err := DB.DB() |
|
|
if err != nil { |
|
|
log.Printf("Error getting sql.DB from GORM: %v", err) |
|
|
return err |
|
|
} |
|
|
|
|
|
err = sqlDB.Ping() |
|
|
if err != nil { |
|
|
log.Printf("Error pinging DB: %v", err) |
|
|
return err |
|
|
} |
|
|
|
|
|
lastPingTime = time.Now() |
|
|
common.SysLog("Database pinged successfully") |
|
|
return nil |
|
|
} |
|
|
|