| package service
|
|
|
| import (
|
| "crypto/sha256"
|
| "encoding/hex"
|
| "errors"
|
| "fmt"
|
| "os"
|
| "time"
|
|
|
| "opus-api/internal/model"
|
|
|
| "github.com/golang-jwt/jwt/v5"
|
| "gorm.io/gorm"
|
| )
|
|
|
| var (
|
| ErrInvalidCredentials = errors.New("invalid username or password")
|
| ErrUserNotFound = errors.New("user not found")
|
| ErrInvalidToken = errors.New("invalid token")
|
| )
|
|
|
|
|
| type AuthService struct {
|
| db *gorm.DB
|
| jwtSecret []byte
|
| }
|
|
|
|
|
| func NewAuthService(db *gorm.DB) *AuthService {
|
| secret := os.Getenv("JWT_SECRET")
|
| if secret == "" {
|
| secret = "default-secret-change-me-in-production"
|
| }
|
| return &AuthService{
|
| db: db,
|
| jwtSecret: []byte(secret),
|
| }
|
| }
|
|
|
|
|
| type Claims struct {
|
| UserID uint `json:"user_id"`
|
| Username string `json:"username"`
|
| jwt.RegisteredClaims
|
| }
|
|
|
|
|
| func (s *AuthService) Login(username, password string) (*model.User, string, error) {
|
| var user model.User
|
| if err := s.db.Where("username = ?", username).First(&user).Error; err != nil {
|
| if errors.Is(err, gorm.ErrRecordNotFound) {
|
| return nil, "", ErrInvalidCredentials
|
| }
|
| return nil, "", err
|
| }
|
|
|
| if !user.CheckPassword(password) {
|
| return nil, "", ErrInvalidCredentials
|
| }
|
|
|
|
|
| token, err := s.generateToken(&user)
|
| if err != nil {
|
| return nil, "", err
|
| }
|
|
|
|
|
| if err := s.saveSession(&user, token); err != nil {
|
| return nil, "", err
|
| }
|
|
|
| return &user, token, nil
|
| }
|
|
|
|
|
| func (s *AuthService) generateToken(user *model.User) (string, error) {
|
| expirationTime := time.Now().Add(24 * time.Hour)
|
| claims := &Claims{
|
| UserID: user.ID,
|
| Username: user.Username,
|
| RegisteredClaims: jwt.RegisteredClaims{
|
| ExpiresAt: jwt.NewNumericDate(expirationTime),
|
| IssuedAt: jwt.NewNumericDate(time.Now()),
|
| },
|
| }
|
|
|
| token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
| return token.SignedString(s.jwtSecret)
|
| }
|
|
|
|
|
| func (s *AuthService) saveSession(user *model.User, token string) error {
|
| tokenHash := hashToken(token)
|
| session := &model.UserSession{
|
| UserID: user.ID,
|
| TokenHash: tokenHash,
|
| ExpiresAt: time.Now().Add(24 * time.Hour),
|
| }
|
| return s.db.Create(session).Error
|
| }
|
|
|
|
|
| func (s *AuthService) ValidateToken(tokenString string) (uint, error) {
|
| claims := &Claims{}
|
| token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
|
| if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
| return 0, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
| }
|
| return s.jwtSecret, nil
|
| })
|
|
|
| if err != nil {
|
| return 0, ErrInvalidToken
|
| }
|
|
|
| if !token.Valid {
|
| return 0, ErrInvalidToken
|
| }
|
|
|
|
|
| tokenHash := hashToken(tokenString)
|
| var session model.UserSession
|
| if err := s.db.Where("token_hash = ? AND expires_at > ?", tokenHash, time.Now()).First(&session).Error; err != nil {
|
| if errors.Is(err, gorm.ErrRecordNotFound) {
|
| return 0, ErrInvalidToken
|
| }
|
| return 0, err
|
| }
|
|
|
| return claims.UserID, nil
|
| }
|
|
|
|
|
| func (s *AuthService) Logout(userID uint) error {
|
| return s.db.Where("user_id = ?", userID).Delete(&model.UserSession{}).Error
|
| }
|
|
|
|
|
| func (s *AuthService) GetUserByID(userID uint) (*model.User, error) {
|
| var user model.User
|
| if err := s.db.First(&user, userID).Error; err != nil {
|
| if errors.Is(err, gorm.ErrRecordNotFound) {
|
| return nil, ErrUserNotFound
|
| }
|
| return nil, err
|
| }
|
| return &user, nil
|
| }
|
|
|
|
|
| func hashToken(token string) string {
|
| hash := sha256.Sum256([]byte(token))
|
| return hex.EncodeToString(hash[:])
|
| }
|
|
|
|
|
| func (s *AuthService) CleanExpiredSessions() error {
|
| return s.db.Where("expires_at < ?", time.Now()).Delete(&model.UserSession{}).Error
|
| }
|
|
|
|
|
| func (s *AuthService) ChangePassword(userID uint, oldPassword, newPassword string) error {
|
| var user model.User
|
|
|
|
|
| if err := s.db.First(&user, userID).Error; err != nil {
|
| if errors.Is(err, gorm.ErrRecordNotFound) {
|
| return ErrUserNotFound
|
| }
|
| return err
|
| }
|
|
|
|
|
| if !user.CheckPassword(oldPassword) {
|
| return ErrInvalidCredentials
|
| }
|
|
|
|
|
| if err := user.SetPassword(newPassword); err != nil {
|
| return fmt.Errorf("failed to hash password: %w", err)
|
| }
|
|
|
|
|
| if err := s.db.Save(&user).Error; err != nil {
|
| return fmt.Errorf("failed to update password: %w", err)
|
| }
|
|
|
| return nil
|
| } |