|
|
package db
|
|
|
|
|
|
import (
|
|
|
"crypto/rand"
|
|
|
"encoding/hex"
|
|
|
"fmt"
|
|
|
"log"
|
|
|
"math/big"
|
|
|
"os"
|
|
|
"path/filepath"
|
|
|
"sync"
|
|
|
"time"
|
|
|
|
|
|
"gorm.io/driver/sqlite"
|
|
|
|
|
|
"gorm.io/driver/postgres"
|
|
|
"gorm.io/gorm"
|
|
|
"gorm.io/gorm/logger"
|
|
|
)
|
|
|
|
|
|
|
|
|
// Persistent models managed by GORM. Field order and struct tags define the
// migrated schema, so they must stay in sync with existing databases.
type (
	// Credential is a stored email/token pair. Email carries a unique index,
	// so at most one row exists per address.
	Credential struct {
		ID    uint   `gorm:"primarykey"`
		Email string `gorm:"uniqueIndex;not null"`
		Token string `gorm:"not null"`
	}

	// APIToken is an issued API key. Token carries a unique index.
	APIToken struct {
		ID        uint   `gorm:"primarykey"`
		Token     string `gorm:"uniqueIndex;not null"`
		CreatedAt time.Time
	}

	// AdminPassword stores the admin password hash and whether it is still
	// the initial (not yet user-chosen) password.
	AdminPassword struct {
		ID           uint   `gorm:"primarykey"`
		PasswordHash string `gorm:"not null"`
		// IsInitial is a pointer because GORM replaces a zero-valued field
		// that has a default tag with that default; *bool lets an explicit
		// false be stored instead of the column default (true).
		IsInitial *bool `gorm:"default:true"`
		CreatedAt time.Time
	}
)
|
|
|
|
|
|
var (
|
|
|
db *gorm.DB
|
|
|
dbOnce sync.Once
|
|
|
)
|
|
|
|
|
|
|
|
|
func InitDB() (*gorm.DB, error) {
|
|
|
var err error
|
|
|
dbOnce.Do(func() {
|
|
|
|
|
|
dsn := os.Getenv("DATABASE_URL")
|
|
|
if dsn == "" {
|
|
|
log.Println("DATABASE_URL environment variable not set. Using default SQLite for local development.")
|
|
|
|
|
|
dbPath := "./credentials_dev.db"
|
|
|
config := &gorm.Config{
|
|
|
Logger: logger.Default.LogMode(logger.Silent),
|
|
|
}
|
|
|
db, err = gorm.Open(sqlite.Open(dbPath), config)
|
|
|
if err != nil {
|
|
|
log.Printf("Failed to connect to local SQLite database: %v", err)
|
|
|
return
|
|
|
}
|
|
|
} else {
|
|
|
|
|
|
config := &gorm.Config{
|
|
|
Logger: logger.Default.LogMode(logger.Silent),
|
|
|
}
|
|
|
|
|
|
|
|
|
db, err = gorm.Open(postgres.Open(dsn), config)
|
|
|
if err != nil {
|
|
|
log.Printf("Failed to connect to PostgreSQL database: %v", err)
|
|
|
return
|
|
|
}
|
|
|
}
|
|
|
|
|
|
|
|
|
err = db.AutoMigrate(&Credential{}, &APIToken{}, &AdminPassword{})
|
|
|
if err != nil {
|
|
|
log.Printf("Failed to migrate table structure: %v", err)
|
|
|
return
|
|
|
}
|
|
|
})
|
|
|
|
|
|
return db, err
|
|
|
}
|
|
|
|
|
|
|
|
|
// ensureDBDir creates the directory that will contain the database file at
// dbPath, including any missing parents, with permissions 0755. It succeeds
// without doing anything when the directory already exists.
func ensureDBDir(dbPath string) error {
	parent := filepath.Dir(dbPath)
	if err := os.MkdirAll(parent, 0o755); err != nil {
		return err
	}
	return nil
}
|
|
|
|
|
|
|
|
|
func GetDB() *gorm.DB {
|
|
|
if db == nil {
|
|
|
var err error
|
|
|
db, err = InitDB()
|
|
|
if err != nil {
|
|
|
log.Fatalf("Failed to get database connection: %v", err)
|
|
|
}
|
|
|
}
|
|
|
return db
|
|
|
}
|
|
|
|
|
|
|
|
|
func GetAllCredentials() ([]Credential, error) {
|
|
|
var credentials []Credential
|
|
|
result := GetDB().Find(&credentials)
|
|
|
return credentials, result.Error
|
|
|
}
|
|
|
|
|
|
|
|
|
func AddCredential(email, token string) error {
|
|
|
credential := Credential{
|
|
|
Email: email,
|
|
|
Token: token,
|
|
|
}
|
|
|
result := GetDB().Create(&credential)
|
|
|
return result.Error
|
|
|
}
|
|
|
|
|
|
|
|
|
func DeleteCredential(id uint) error {
|
|
|
result := GetDB().Delete(&Credential{}, id)
|
|
|
return result.Error
|
|
|
}
|
|
|
|
|
|
|
|
|
func GetCredentialByID(id uint) (Credential, error) {
|
|
|
var credential Credential
|
|
|
result := GetDB().First(&credential, id)
|
|
|
return credential, result.Error
|
|
|
}
|
|
|
|
|
|
|
|
|
func UpdateCredential(id uint, email, token string) error {
|
|
|
result := GetDB().Model(&Credential{}).Where("id = ?", id).Updates(map[string]interface{}{
|
|
|
"email": email,
|
|
|
"token": token,
|
|
|
})
|
|
|
return result.Error
|
|
|
}
|
|
|
|
|
|
|
|
|
func GetAPIToken() (string, error) {
|
|
|
var token APIToken
|
|
|
result := GetDB().First(&token)
|
|
|
if result.Error != nil {
|
|
|
return "", result.Error
|
|
|
}
|
|
|
return token.Token, nil
|
|
|
}
|
|
|
|
|
|
|
|
|
func GenerateAPIToken() (string, error) {
|
|
|
|
|
|
b := make([]byte, 32)
|
|
|
_, err := rand.Read(b)
|
|
|
if err != nil {
|
|
|
return "", err
|
|
|
}
|
|
|
token := fmt.Sprintf("sk-%s", hex.EncodeToString(b))
|
|
|
|
|
|
|
|
|
GetDB().Where("1=1").Delete(&APIToken{})
|
|
|
|
|
|
|
|
|
apiToken := APIToken{
|
|
|
Token: token,
|
|
|
CreatedAt: time.Now(),
|
|
|
}
|
|
|
result := GetDB().Create(&apiToken)
|
|
|
if result.Error != nil {
|
|
|
return "", result.Error
|
|
|
}
|
|
|
|
|
|
return token, nil
|
|
|
}
|
|
|
|
|
|
|
|
|
func ValidateAPIToken(token string) bool {
|
|
|
var count int64
|
|
|
GetDB().Model(&APIToken{}).Where("token = ?", token).Count(&count)
|
|
|
return count > 0
|
|
|
}
|
|
|
|
|
|
|
|
|
func SetAdminPassword(passwordHash string, isInitial bool) error {
|
|
|
|
|
|
GetDB().Where("1=1").Delete(&AdminPassword{})
|
|
|
|
|
|
|
|
|
adminPassword := AdminPassword{
|
|
|
PasswordHash: passwordHash,
|
|
|
IsInitial: &isInitial,
|
|
|
CreatedAt: time.Now(),
|
|
|
}
|
|
|
result := GetDB().Create(&adminPassword)
|
|
|
return result.Error
|
|
|
}
|
|
|
|
|
|
|
|
|
func GetAdminPassword() (string, bool, error) {
|
|
|
var adminPassword AdminPassword
|
|
|
result := GetDB().First(&adminPassword)
|
|
|
if result.Error != nil {
|
|
|
return "", false, result.Error
|
|
|
}
|
|
|
return adminPassword.PasswordHash, *adminPassword.IsInitial, nil
|
|
|
}
|
|
|
|
|
|
|
|
|
func IsPasswordInitial() (bool, error) {
|
|
|
var adminPassword AdminPassword
|
|
|
result := GetDB().First(&adminPassword)
|
|
|
fmt.Printf("adminPassword: %v\n", adminPassword)
|
|
|
if result.Error != nil {
|
|
|
return true, result.Error
|
|
|
}
|
|
|
return *adminPassword.IsInitial, nil
|
|
|
}
|
|
|
|
|
|
|
|
|
// GenerateRandomPassword returns a password of the given length. Each
// character is drawn uniformly, via crypto/rand, from lowercase and uppercase
// ASCII letters, digits and the symbols !@#$%^&*()-_=+.
//
// A length of zero or less yields "" (previously a negative length panicked
// in make). The signature has no error result, so a failure of the system
// CSPRNG causes a panic with a descriptive error. Previously that failure was
// ignored and crashed on a nil *big.Int.
func GenerateRandomPassword(length int) string {
	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*()-_=+"
	if length <= 0 {
		return ""
	}

	limit := big.NewInt(int64(len(charset))) // loop-invariant upper bound
	b := make([]byte, length)
	for i := range b {
		idx, err := rand.Int(rand.Reader, limit)
		if err != nil {
			panic(fmt.Errorf("generating random password: %w", err))
		}
		b[i] = charset[idx.Int64()]
	}
	return string(b)
}
|
|
|
|