opus / internal /service /auth_service.go
llzai's picture
Upload 42 files
644c352 verified
package service
import (
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"os"
"time"
"opus-api/internal/model"
"github.com/golang-jwt/jwt/v5"
"gorm.io/gorm"
)
var (
ErrInvalidCredentials = errors.New("invalid username or password")
ErrUserNotFound = errors.New("user not found")
ErrInvalidToken = errors.New("invalid token")
)
// AuthService 认证服务
type AuthService struct {
db *gorm.DB
jwtSecret []byte
}
// NewAuthService 创建认证服务
func NewAuthService(db *gorm.DB) *AuthService {
secret := os.Getenv("JWT_SECRET")
if secret == "" {
secret = "default-secret-change-me-in-production"
}
return &AuthService{
db: db,
jwtSecret: []byte(secret),
}
}
// Claims JWT 声明
type Claims struct {
UserID uint `json:"user_id"`
Username string `json:"username"`
jwt.RegisteredClaims
}
// Login 用户登录
func (s *AuthService) Login(username, password string) (*model.User, string, error) {
var user model.User
if err := s.db.Where("username = ?", username).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, "", ErrInvalidCredentials
}
return nil, "", err
}
if !user.CheckPassword(password) {
return nil, "", ErrInvalidCredentials
}
// 生成 JWT token
token, err := s.generateToken(&user)
if err != nil {
return nil, "", err
}
// 保存会话
if err := s.saveSession(&user, token); err != nil {
return nil, "", err
}
return &user, token, nil
}
// generateToken 生成 JWT token
func (s *AuthService) generateToken(user *model.User) (string, error) {
expirationTime := time.Now().Add(24 * time.Hour)
claims := &Claims{
UserID: user.ID,
Username: user.Username,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expirationTime),
IssuedAt: jwt.NewNumericDate(time.Now()),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString(s.jwtSecret)
}
// saveSession 保存会话
func (s *AuthService) saveSession(user *model.User, token string) error {
tokenHash := hashToken(token)
session := &model.UserSession{
UserID: user.ID,
TokenHash: tokenHash,
ExpiresAt: time.Now().Add(24 * time.Hour),
}
return s.db.Create(session).Error
}
// ValidateToken 验证 token,返回用户 ID
func (s *AuthService) ValidateToken(tokenString string) (uint, error) {
claims := &Claims{}
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return 0, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return s.jwtSecret, nil
})
if err != nil {
return 0, ErrInvalidToken
}
if !token.Valid {
return 0, ErrInvalidToken
}
// 检查会话是否存在且未过期
tokenHash := hashToken(tokenString)
var session model.UserSession
if err := s.db.Where("token_hash = ? AND expires_at > ?", tokenHash, time.Now()).First(&session).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return 0, ErrInvalidToken
}
return 0, err
}
return claims.UserID, nil
}
// Logout 用户登出(删除用户的所有会话)
func (s *AuthService) Logout(userID uint) error {
return s.db.Where("user_id = ?", userID).Delete(&model.UserSession{}).Error
}
// GetUserByID 根据 ID 获取用户
func (s *AuthService) GetUserByID(userID uint) (*model.User, error) {
var user model.User
if err := s.db.First(&user, userID).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, err
}
return &user, nil
}
// hashToken 对 token 进行哈希
func hashToken(token string) string {
hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:])
}
// CleanExpiredSessions 清理过期会话
func (s *AuthService) CleanExpiredSessions() error {
return s.db.Where("expires_at < ?", time.Now()).Delete(&model.UserSession{}).Error
}
// ChangePassword 修改用户密码
func (s *AuthService) ChangePassword(userID uint, oldPassword, newPassword string) error {
var user model.User
// 查找用户
if err := s.db.First(&user, userID).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrUserNotFound
}
return err
}
// 验证旧密码
if !user.CheckPassword(oldPassword) {
return ErrInvalidCredentials
}
// 设置新密码
if err := user.SetPassword(newPassword); err != nil {
return fmt.Errorf("failed to hash password: %w", err)
}
// 保存到数据库
if err := s.db.Save(&user).Error; err != nil {
return fmt.Errorf("failed to update password: %w", err)
}
return nil
}