| package db
|
|
|
| import (
|
| "crypto/rand"
|
| "encoding/hex"
|
| "fmt"
|
| "log"
|
| "math/big"
|
| "os"
|
| "path/filepath"
|
| "sync"
|
| "time"
|
|
|
| "gorm.io/driver/sqlite"
|
|
|
| "gorm.io/driver/postgres"
|
| "gorm.io/gorm"
|
| "gorm.io/gorm/logger"
|
| )
|
|
|
|
|
// Credential is an email/token pair persisted in the credentials table and
// managed by GetAllCredentials, AddCredential, UpdateCredential,
// GetCredentialByID and DeleteCredential.
type Credential struct {
	// ID is the primary key.
	ID uint `gorm:"primarykey"`

	// Email is required and must be unique across all credentials.
	Email string `gorm:"uniqueIndex;not null"`

	// Token is the secret associated with Email. It is required.
	// NOTE(review): stored exactly as given; nothing in this file hashes or
	// encrypts it — confirm that is intended.
	Token string `gorm:"not null"`
}
|
|
|
|
|
// APIToken is a bearer token accepted by ValidateAPIToken. GenerateAPIToken
// replaces all existing rows with a single new token, so normally at most one
// row exists.
type APIToken struct {
	// ID is the primary key.
	ID uint `gorm:"primarykey"`

	// Token is the token value (GenerateAPIToken produces "sk-" followed by
	// 64 hex characters). It is required and unique.
	Token string `gorm:"uniqueIndex;not null"`

	// CreatedAt records when the token was created; GenerateAPIToken sets it
	// explicitly to time.Now().
	CreatedAt time.Time
}
|
|
|
|
|
// AdminPassword stores the hash of the administrator password.
// SetAdminPassword replaces all existing rows with a single new one, so
// normally at most one row exists.
type AdminPassword struct {
	// ID is the primary key.
	ID uint `gorm:"primarykey"`

	// PasswordHash is the stored password hash (the hashing scheme is chosen
	// by the caller of SetAdminPassword). It is required.
	PasswordHash string `gorm:"not null"`

	// IsInitial reports whether the password is still the initial one. The
	// column defaults to true. Presumably a pointer so that an explicit false
	// is written instead of being replaced by the default — confirm against
	// GORM's zero-value/default handling.
	IsInitial *bool `gorm:"default:true"`

	// CreatedAt records when this password was set.
	CreatedAt time.Time
}
|
|
|
| var (
|
| db *gorm.DB
|
| dbOnce sync.Once
|
| )
|
|
|
|
|
| func InitDB() (*gorm.DB, error) {
|
| var err error
|
| dbOnce.Do(func() {
|
|
|
| dsn := os.Getenv("DATABASE_URL")
|
| if dsn == "" {
|
| log.Println("DATABASE_URL environment variable not set. Using default SQLite for local development.")
|
|
|
| dbPath := "./credentials_dev.db"
|
| config := &gorm.Config{
|
| Logger: logger.Default.LogMode(logger.Silent),
|
| }
|
| db, err = gorm.Open(sqlite.Open(dbPath), config)
|
| if err != nil {
|
| log.Printf("Failed to connect to local SQLite database: %v", err)
|
| return
|
| }
|
| } else {
|
|
|
| config := &gorm.Config{
|
| Logger: logger.Default.LogMode(logger.Silent),
|
| }
|
|
|
|
|
| db, err = gorm.Open(postgres.Open(dsn), config)
|
| if err != nil {
|
| log.Printf("Failed to connect to PostgreSQL database: %v", err)
|
| return
|
| }
|
| }
|
|
|
|
|
| err = db.AutoMigrate(&Credential{}, &APIToken{}, &AdminPassword{})
|
| if err != nil {
|
| log.Printf("Failed to migrate table structure: %v", err)
|
| return
|
| }
|
| })
|
|
|
| return db, err
|
| }
|
|
|
|
|
// ensureDBDir creates the directory that will contain dbPath, including any
// missing parents, with permissions 0755. An already-existing directory is
// not an error.
//
// NOTE(review): nothing in this file calls ensureDBDir.
func ensureDBDir(dbPath string) error {
	parent := filepath.Dir(dbPath)
	if err := os.MkdirAll(parent, 0o755); err != nil {
		return err
	}
	return nil
}
|
|
|
|
|
| func GetDB() *gorm.DB {
|
| if db == nil {
|
| var err error
|
| db, err = InitDB()
|
| if err != nil {
|
| log.Fatalf("Failed to get database connection: %v", err)
|
| }
|
| }
|
| return db
|
| }
|
|
|
|
|
| func GetAllCredentials() ([]Credential, error) {
|
| var credentials []Credential
|
| result := GetDB().Find(&credentials)
|
| return credentials, result.Error
|
| }
|
|
|
|
|
| func AddCredential(email, token string) error {
|
| credential := Credential{
|
| Email: email,
|
| Token: token,
|
| }
|
| result := GetDB().Create(&credential)
|
| return result.Error
|
| }
|
|
|
|
|
| func DeleteCredential(id uint) error {
|
| result := GetDB().Delete(&Credential{}, id)
|
| return result.Error
|
| }
|
|
|
|
|
| func GetCredentialByID(id uint) (Credential, error) {
|
| var credential Credential
|
| result := GetDB().First(&credential, id)
|
| return credential, result.Error
|
| }
|
|
|
|
|
| func UpdateCredential(id uint, email, token string) error {
|
| result := GetDB().Model(&Credential{}).Where("id = ?", id).Updates(map[string]interface{}{
|
| "email": email,
|
| "token": token,
|
| })
|
| return result.Error
|
| }
|
|
|
|
|
| func GetAPIToken() (string, error) {
|
| var token APIToken
|
| result := GetDB().First(&token)
|
| if result.Error != nil {
|
| return "", result.Error
|
| }
|
| return token.Token, nil
|
| }
|
|
|
|
|
| func GenerateAPIToken() (string, error) {
|
|
|
| b := make([]byte, 32)
|
| _, err := rand.Read(b)
|
| if err != nil {
|
| return "", err
|
| }
|
| token := fmt.Sprintf("sk-%s", hex.EncodeToString(b))
|
|
|
|
|
| GetDB().Where("1=1").Delete(&APIToken{})
|
|
|
|
|
| apiToken := APIToken{
|
| Token: token,
|
| CreatedAt: time.Now(),
|
| }
|
| result := GetDB().Create(&apiToken)
|
| if result.Error != nil {
|
| return "", result.Error
|
| }
|
|
|
| return token, nil
|
| }
|
|
|
|
|
| func ValidateAPIToken(token string) bool {
|
| var count int64
|
| GetDB().Model(&APIToken{}).Where("token = ?", token).Count(&count)
|
| return count > 0
|
| }
|
|
|
|
|
| func SetAdminPassword(passwordHash string, isInitial bool) error {
|
|
|
| GetDB().Where("1=1").Delete(&AdminPassword{})
|
|
|
|
|
| adminPassword := AdminPassword{
|
| PasswordHash: passwordHash,
|
| IsInitial: &isInitial,
|
| CreatedAt: time.Now(),
|
| }
|
| result := GetDB().Create(&adminPassword)
|
| return result.Error
|
| }
|
|
|
|
|
| func GetAdminPassword() (string, bool, error) {
|
| var adminPassword AdminPassword
|
| result := GetDB().First(&adminPassword)
|
| if result.Error != nil {
|
| return "", false, result.Error
|
| }
|
| return adminPassword.PasswordHash, *adminPassword.IsInitial, nil
|
| }
|
|
|
|
|
| func IsPasswordInitial() (bool, error) {
|
| var adminPassword AdminPassword
|
| result := GetDB().First(&adminPassword)
|
| fmt.Printf("adminPassword: %v\n", adminPassword)
|
| if result.Error != nil {
|
| return true, result.Error
|
| }
|
| return *adminPassword.IsInitial, nil
|
| }
|
|
|
|
|
// GenerateRandomPassword returns a password of length characters, each drawn
// uniformly from ASCII letters, digits and the symbols !@#$%^&*()-_=+, using
// crypto/rand as the entropy source.
//
// A non-positive length yields the empty string. If the secure random source
// fails, GenerateRandomPassword panics with a descriptive message instead of
// returning a weak or truncated password (the signature has no error result).
func GenerateRandomPassword(length int) string {
	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*()-_=+"
	if length <= 0 {
		return ""
	}
	// The upper bound is loop-invariant; allocate it once.
	limit := big.NewInt(int64(len(charset)))
	b := make([]byte, length)
	for i := range b {
		idx, err := rand.Int(rand.Reader, limit)
		if err != nil {
			panic(fmt.Sprintf("GenerateRandomPassword: reading crypto/rand: %v", err))
		}
		b[i] = charset[idx.Int64()]
	}
	return string(b)
}
|
|
|