| package handler |
|
|
import (
	"errors"
	"log/slog"
	"strings"

	"github.com/Wei-Shaw/sub2api/internal/config"
	"github.com/Wei-Shaw/sub2api/internal/handler/dto"
	"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
	"github.com/Wei-Shaw/sub2api/internal/pkg/response"
	middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
	"github.com/Wei-Shaw/sub2api/internal/service"

	"github.com/gin-gonic/gin"
)
|
|
| |
| type AuthHandler struct { |
| cfg *config.Config |
| authService *service.AuthService |
| userService *service.UserService |
| settingSvc *service.SettingService |
| promoService *service.PromoService |
| redeemService *service.RedeemService |
| totpService *service.TotpService |
| } |
|
|
| |
| func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, redeemService *service.RedeemService, totpService *service.TotpService) *AuthHandler { |
| return &AuthHandler{ |
| cfg: cfg, |
| authService: authService, |
| userService: userService, |
| settingSvc: settingService, |
| promoService: promoService, |
| redeemService: redeemService, |
| totpService: totpService, |
| } |
| } |
|
|
| |
// RegisterRequest is the JSON body accepted by the Register handler.
type RegisterRequest struct {
	// Email is required and must pass the "email" binding validator.
	Email string `json:"email" binding:"required,email"`
	// Password is required and must be at least 6 characters long.
	Password string `json:"password" binding:"required,min=6"`
	// VerifyCode is the optional email verification code.
	VerifyCode string `json:"verify_code"`
	// TurnstileToken is the optional Turnstile challenge token.
	TurnstileToken string `json:"turnstile_token"`
	// PromoCode is an optional promo code forwarded to registration.
	PromoCode string `json:"promo_code"`
	// InvitationCode is an optional invitation code forwarded to registration.
	InvitationCode string `json:"invitation_code"`
}
|
|
| |
// SendVerifyCodeRequest is the JSON body accepted by the SendVerifyCode handler.
type SendVerifyCodeRequest struct {
	// Email is the required, validated address the code is sent to.
	Email string `json:"email" binding:"required,email"`
	// TurnstileToken is the optional Turnstile challenge token.
	TurnstileToken string `json:"turnstile_token"`
}
|
|
| |
// SendVerifyCodeResponse is the payload returned by the SendVerifyCode handler.
type SendVerifyCodeResponse struct {
	// Message is a human-readable confirmation.
	Message string `json:"message"`
	// Countdown is copied from the service result; presumably the number of
	// seconds before another code may be requested (confirm with the service).
	Countdown int `json:"countdown"`
}
|
|
| |
// LoginRequest is the JSON body accepted by the Login handler.
type LoginRequest struct {
	// Email is required and must pass the "email" binding validator.
	Email string `json:"email" binding:"required,email"`
	// Password is required; no length rule is applied at login.
	Password string `json:"password" binding:"required"`
	// TurnstileToken is the optional Turnstile challenge token.
	TurnstileToken string `json:"turnstile_token"`
}
|
|
| |
| type AuthResponse struct { |
| AccessToken string `json:"access_token"` |
| RefreshToken string `json:"refresh_token,omitempty"` |
| ExpiresIn int `json:"expires_in,omitempty"` |
| TokenType string `json:"token_type"` |
| User *dto.User `json:"user"` |
| } |
|
|
| |
| |
| func (h *AuthHandler) respondWithTokenPair(c *gin.Context, user *service.User) { |
| tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "") |
| if err != nil { |
| slog.Error("failed to generate token pair", "error", err, "user_id", user.ID) |
| |
| token, tokenErr := h.authService.GenerateToken(user) |
| if tokenErr != nil { |
| response.InternalError(c, "Failed to generate token") |
| return |
| } |
| response.Success(c, AuthResponse{ |
| AccessToken: token, |
| TokenType: "Bearer", |
| User: dto.UserFromService(user), |
| }) |
| return |
| } |
| response.Success(c, AuthResponse{ |
| AccessToken: tokenPair.AccessToken, |
| RefreshToken: tokenPair.RefreshToken, |
| ExpiresIn: tokenPair.ExpiresIn, |
| TokenType: "Bearer", |
| User: dto.UserFromService(user), |
| }) |
| } |
|
|
| |
| |
| func (h *AuthHandler) Register(c *gin.Context) { |
| var req RegisterRequest |
| if err := c.ShouldBindJSON(&req); err != nil { |
| response.BadRequest(c, "Invalid request: "+err.Error()) |
| return |
| } |
|
|
| |
| if err := h.authService.VerifyTurnstileForRegister(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c), req.VerifyCode); err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
|
|
| _, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode) |
| if err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
|
|
| h.respondWithTokenPair(c, user) |
| } |
|
|
| |
| |
| func (h *AuthHandler) SendVerifyCode(c *gin.Context) { |
| var req SendVerifyCodeRequest |
| if err := c.ShouldBindJSON(&req); err != nil { |
| response.BadRequest(c, "Invalid request: "+err.Error()) |
| return |
| } |
|
|
| |
| if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
|
|
| result, err := h.authService.SendVerifyCodeAsync(c.Request.Context(), req.Email) |
| if err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
|
|
| response.Success(c, SendVerifyCodeResponse{ |
| Message: "Verification code sent successfully", |
| Countdown: result.Countdown, |
| }) |
| } |
|
|
| |
| |
| func (h *AuthHandler) Login(c *gin.Context) { |
| var req LoginRequest |
| if err := c.ShouldBindJSON(&req); err != nil { |
| response.BadRequest(c, "Invalid request: "+err.Error()) |
| return |
| } |
|
|
| |
| if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
|
|
| token, user, err := h.authService.Login(c.Request.Context(), req.Email, req.Password) |
| if err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
| _ = token |
|
|
| |
| if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled { |
| |
| tempToken, err := h.totpService.CreateLoginSession(c.Request.Context(), user.ID, user.Email) |
| if err != nil { |
| response.InternalError(c, "Failed to create 2FA session") |
| return |
| } |
|
|
| response.Success(c, TotpLoginResponse{ |
| Requires2FA: true, |
| TempToken: tempToken, |
| UserEmailMasked: service.MaskEmail(user.Email), |
| }) |
| return |
| } |
|
|
| |
| if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() { |
| response.Forbidden(c, "Backend mode is active. Only admin login is allowed.") |
| return |
| } |
|
|
| h.respondWithTokenPair(c, user) |
| } |
|
|
| |
// TotpLoginResponse is returned by Login when a second factor is required
// before tokens can be issued.
type TotpLoginResponse struct {
	// Requires2FA tells the client that a TOTP code must be submitted.
	Requires2FA bool `json:"requires_2fa"`
	// TempToken identifies the pending login session; omitted when empty.
	TempToken string `json:"temp_token,omitempty"`
	// UserEmailMasked is the masked form of the user's email; omitted when empty.
	UserEmailMasked string `json:"user_email_masked,omitempty"`
}
|
|
| |
// Login2FARequest is the JSON body accepted by the Login2FA handler.
type Login2FARequest struct {
	// TempToken is the required temporary token handed out by Login.
	TempToken string `json:"temp_token" binding:"required"`
	// TotpCode is the required TOTP code; it must be exactly 6 characters.
	TotpCode string `json:"totp_code" binding:"required,len=6"`
}
|
|
| |
| |
| func (h *AuthHandler) Login2FA(c *gin.Context) { |
| var req Login2FARequest |
| if err := c.ShouldBindJSON(&req); err != nil { |
| response.BadRequest(c, "Invalid request: "+err.Error()) |
| return |
| } |
|
|
| slog.Debug("login_2fa_request", |
| "temp_token_len", len(req.TempToken), |
| "totp_code_len", len(req.TotpCode)) |
|
|
| |
| session, err := h.totpService.GetLoginSession(c.Request.Context(), req.TempToken) |
| if err != nil || session == nil { |
| tokenPrefix := "" |
| if len(req.TempToken) >= 8 { |
| tokenPrefix = req.TempToken[:8] |
| } |
| slog.Debug("login_2fa_session_invalid", |
| "temp_token_prefix", tokenPrefix, |
| "error", err) |
| response.BadRequest(c, "Invalid or expired 2FA session") |
| return |
| } |
|
|
| slog.Debug("login_2fa_session_found", |
| "user_id", session.UserID, |
| "email", session.Email) |
|
|
| |
| if err := h.totpService.VerifyCode(c.Request.Context(), session.UserID, req.TotpCode); err != nil { |
| slog.Debug("login_2fa_verify_failed", |
| "user_id", session.UserID, |
| "error", err) |
| response.ErrorFrom(c, err) |
| return |
| } |
|
|
| |
| user, err := h.userService.GetByID(c.Request.Context(), session.UserID) |
| if err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
|
|
| |
| if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() { |
| response.Forbidden(c, "Backend mode is active. Only admin login is allowed.") |
| return |
| } |
|
|
| |
| _ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken) |
|
|
| h.respondWithTokenPair(c, user) |
| } |
|
|
| |
| |
| func (h *AuthHandler) GetCurrentUser(c *gin.Context) { |
| subject, ok := middleware2.GetAuthSubjectFromContext(c) |
| if !ok { |
| response.Unauthorized(c, "User not authenticated") |
| return |
| } |
|
|
| user, err := h.userService.GetByID(c.Request.Context(), subject.UserID) |
| if err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
|
|
| type UserResponse struct { |
| *dto.User |
| RunMode string `json:"run_mode"` |
| } |
|
|
| runMode := config.RunModeStandard |
| if h.cfg != nil { |
| runMode = h.cfg.RunMode |
| } |
|
|
| response.Success(c, UserResponse{User: dto.UserFromService(user), RunMode: runMode}) |
| } |
|
|
| |
// ValidatePromoCodeRequest is the JSON body accepted by ValidatePromoCode.
type ValidatePromoCodeRequest struct {
	// Code is the required promo code to check.
	Code string `json:"code" binding:"required"`
}
|
|
| |
// ValidatePromoCodeResponse is the payload returned by ValidatePromoCode.
type ValidatePromoCodeResponse struct {
	// Valid reports whether the code can be used.
	Valid bool `json:"valid"`
	// BonusAmount is the promo code's bonus; omitted when zero.
	BonusAmount float64 `json:"bonus_amount,omitempty"`
	// ErrorCode is a machine-readable reason for rejection; omitted when empty.
	ErrorCode string `json:"error_code,omitempty"`
	// Message is an optional human-readable note; omitted when empty.
	Message string `json:"message,omitempty"`
}
|
|
| |
| |
| func (h *AuthHandler) ValidatePromoCode(c *gin.Context) { |
| |
| if h.settingSvc != nil && !h.settingSvc.IsPromoCodeEnabled(c.Request.Context()) { |
| response.Success(c, ValidatePromoCodeResponse{ |
| Valid: false, |
| ErrorCode: "PROMO_CODE_DISABLED", |
| }) |
| return |
| } |
|
|
| var req ValidatePromoCodeRequest |
| if err := c.ShouldBindJSON(&req); err != nil { |
| response.BadRequest(c, "Invalid request: "+err.Error()) |
| return |
| } |
|
|
| promoCode, err := h.promoService.ValidatePromoCode(c.Request.Context(), req.Code) |
| if err != nil { |
| |
| errorCode := "PROMO_CODE_INVALID" |
| switch err { |
| case service.ErrPromoCodeNotFound: |
| errorCode = "PROMO_CODE_NOT_FOUND" |
| case service.ErrPromoCodeExpired: |
| errorCode = "PROMO_CODE_EXPIRED" |
| case service.ErrPromoCodeDisabled: |
| errorCode = "PROMO_CODE_DISABLED" |
| case service.ErrPromoCodeMaxUsed: |
| errorCode = "PROMO_CODE_MAX_USED" |
| case service.ErrPromoCodeAlreadyUsed: |
| errorCode = "PROMO_CODE_ALREADY_USED" |
| } |
|
|
| response.Success(c, ValidatePromoCodeResponse{ |
| Valid: false, |
| ErrorCode: errorCode, |
| }) |
| return |
| } |
|
|
| if promoCode == nil { |
| response.Success(c, ValidatePromoCodeResponse{ |
| Valid: false, |
| ErrorCode: "PROMO_CODE_INVALID", |
| }) |
| return |
| } |
|
|
| response.Success(c, ValidatePromoCodeResponse{ |
| Valid: true, |
| BonusAmount: promoCode.BonusAmount, |
| }) |
| } |
|
|
| |
// ValidateInvitationCodeRequest is the JSON body accepted by
// ValidateInvitationCode.
type ValidateInvitationCodeRequest struct {
	// Code is the required invitation code to check.
	Code string `json:"code" binding:"required"`
}
|
|
| |
// ValidateInvitationCodeResponse is the payload returned by
// ValidateInvitationCode.
type ValidateInvitationCodeResponse struct {
	// Valid reports whether the invitation code can be used.
	Valid bool `json:"valid"`
	// ErrorCode is a machine-readable reason for rejection; omitted when empty.
	ErrorCode string `json:"error_code,omitempty"`
}
|
|
| |
| |
| func (h *AuthHandler) ValidateInvitationCode(c *gin.Context) { |
| |
| if h.settingSvc == nil || !h.settingSvc.IsInvitationCodeEnabled(c.Request.Context()) { |
| response.Success(c, ValidateInvitationCodeResponse{ |
| Valid: false, |
| ErrorCode: "INVITATION_CODE_DISABLED", |
| }) |
| return |
| } |
|
|
| var req ValidateInvitationCodeRequest |
| if err := c.ShouldBindJSON(&req); err != nil { |
| response.BadRequest(c, "Invalid request: "+err.Error()) |
| return |
| } |
|
|
| |
| redeemCode, err := h.redeemService.GetByCode(c.Request.Context(), req.Code) |
| if err != nil { |
| response.Success(c, ValidateInvitationCodeResponse{ |
| Valid: false, |
| ErrorCode: "INVITATION_CODE_NOT_FOUND", |
| }) |
| return |
| } |
|
|
| |
| if redeemCode.Type != service.RedeemTypeInvitation { |
| response.Success(c, ValidateInvitationCodeResponse{ |
| Valid: false, |
| ErrorCode: "INVITATION_CODE_INVALID", |
| }) |
| return |
| } |
|
|
| if redeemCode.Status != service.StatusUnused { |
| response.Success(c, ValidateInvitationCodeResponse{ |
| Valid: false, |
| ErrorCode: "INVITATION_CODE_USED", |
| }) |
| return |
| } |
|
|
| response.Success(c, ValidateInvitationCodeResponse{ |
| Valid: true, |
| }) |
| } |
|
|
| |
// ForgotPasswordRequest is the JSON body accepted by the ForgotPassword handler.
type ForgotPasswordRequest struct {
	// Email is the required, validated address of the account.
	Email string `json:"email" binding:"required,email"`
	// TurnstileToken is the optional Turnstile challenge token.
	TurnstileToken string `json:"turnstile_token"`
}
|
|
| |
// ForgotPasswordResponse is the payload returned by the ForgotPassword handler.
type ForgotPasswordResponse struct {
	// Message is a human-readable confirmation.
	Message string `json:"message"`
}
|
|
| |
| |
| func (h *AuthHandler) ForgotPassword(c *gin.Context) { |
| var req ForgotPasswordRequest |
| if err := c.ShouldBindJSON(&req); err != nil { |
| response.BadRequest(c, "Invalid request: "+err.Error()) |
| return |
| } |
|
|
| |
| if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
|
|
| frontendBaseURL := strings.TrimSpace(h.settingSvc.GetFrontendURL(c.Request.Context())) |
| if frontendBaseURL == "" { |
| slog.Error("frontend_url not configured in settings or config; cannot build password reset link") |
| response.InternalError(c, "Password reset is not configured") |
| return |
| } |
|
|
| |
| |
| if err := h.authService.RequestPasswordResetAsync(c.Request.Context(), req.Email, frontendBaseURL); err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
|
|
| response.Success(c, ForgotPasswordResponse{ |
| Message: "If your email is registered, you will receive a password reset link shortly.", |
| }) |
| } |
|
|
| |
// ResetPasswordRequest is the JSON body accepted by the ResetPassword handler.
type ResetPasswordRequest struct {
	// Email is the required, validated address of the account.
	Email string `json:"email" binding:"required,email"`
	// Token is the required password-reset token.
	Token string `json:"token" binding:"required"`
	// NewPassword is required and must be at least 6 characters long.
	NewPassword string `json:"new_password" binding:"required,min=6"`
}
|
|
| |
// ResetPasswordResponse is the payload returned by the ResetPassword handler.
type ResetPasswordResponse struct {
	// Message is a human-readable confirmation.
	Message string `json:"message"`
}
|
|
| |
| |
| func (h *AuthHandler) ResetPassword(c *gin.Context) { |
| var req ResetPasswordRequest |
| if err := c.ShouldBindJSON(&req); err != nil { |
| response.BadRequest(c, "Invalid request: "+err.Error()) |
| return |
| } |
|
|
| |
| if err := h.authService.ResetPassword(c.Request.Context(), req.Email, req.Token, req.NewPassword); err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
|
|
| response.Success(c, ResetPasswordResponse{ |
| Message: "Your password has been reset successfully. You can now log in with your new password.", |
| }) |
| } |
|
|
| |
|
|
| |
// RefreshTokenRequest is the JSON body accepted by the RefreshToken handler.
type RefreshTokenRequest struct {
	// RefreshToken is the required refresh token to exchange.
	RefreshToken string `json:"refresh_token" binding:"required"`
}
|
|
| |
// RefreshTokenResponse is the payload returned by the RefreshToken handler.
type RefreshTokenResponse struct {
	// AccessToken is the newly issued access token.
	AccessToken string `json:"access_token"`
	// RefreshToken is the refresh token returned by the refresh operation.
	RefreshToken string `json:"refresh_token"`
	// ExpiresIn is copied from the service result; presumably the access
	// token lifetime in seconds (confirm with the service).
	ExpiresIn int `json:"expires_in"`
	// TokenType names the token scheme; the handler sets "Bearer".
	TokenType string `json:"token_type"`
}
|
|
| |
| |
| func (h *AuthHandler) RefreshToken(c *gin.Context) { |
| var req RefreshTokenRequest |
| if err := c.ShouldBindJSON(&req); err != nil { |
| response.BadRequest(c, "Invalid request: "+err.Error()) |
| return |
| } |
|
|
| result, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken) |
| if err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
|
|
| |
| if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && result.UserRole != "admin" { |
| response.Forbidden(c, "Backend mode is active. Only admin login is allowed.") |
| return |
| } |
|
|
| response.Success(c, RefreshTokenResponse{ |
| AccessToken: result.AccessToken, |
| RefreshToken: result.RefreshToken, |
| ExpiresIn: result.ExpiresIn, |
| TokenType: "Bearer", |
| }) |
| } |
|
|
| |
// LogoutRequest is the optional JSON body accepted by the Logout handler.
type LogoutRequest struct {
	// RefreshToken, when present, is revoked during logout.
	RefreshToken string `json:"refresh_token,omitempty"`
}
|
|
| |
// LogoutResponse is the payload returned by the Logout handler.
type LogoutResponse struct {
	// Message is a human-readable confirmation.
	Message string `json:"message"`
}
|
|
| |
| |
| func (h *AuthHandler) Logout(c *gin.Context) { |
| var req LogoutRequest |
| |
| _ = c.ShouldBindJSON(&req) |
|
|
| |
| if req.RefreshToken != "" { |
| if err := h.authService.RevokeRefreshToken(c.Request.Context(), req.RefreshToken); err != nil { |
| slog.Debug("failed to revoke refresh token", "error", err) |
| |
| } |
| } |
|
|
| response.Success(c, LogoutResponse{ |
| Message: "Logged out successfully", |
| }) |
| } |
|
|
| |
// RevokeAllSessionsResponse is the payload returned by the RevokeAllSessions
// handler.
type RevokeAllSessionsResponse struct {
	// Message is a human-readable confirmation.
	Message string `json:"message"`
}
|
|
| |
| |
| func (h *AuthHandler) RevokeAllSessions(c *gin.Context) { |
| subject, ok := middleware2.GetAuthSubjectFromContext(c) |
| if !ok { |
| response.Unauthorized(c, "User not authenticated") |
| return |
| } |
|
|
| if err := h.authService.RevokeAllUserSessions(c.Request.Context(), subject.UserID); err != nil { |
| slog.Error("failed to revoke all sessions", "user_id", subject.UserID, "error", err) |
| response.InternalError(c, "Failed to revoke sessions") |
| return |
| } |
|
|
| response.Success(c, RevokeAllSessionsResponse{ |
| Message: "All sessions have been revoked. Please log in again.", |
| }) |
| } |
|
|