Spaces:
Sleeping
Sleeping
# app/api/auth.py
"""
Authentication endpoints — production-hardened.

Endpoints:
    POST  /api/auth/register         — create account, returns token pair
    POST  /api/auth/login            — authenticate, returns token pair
    POST  /api/auth/refresh          — rotate refresh token
    POST  /api/auth/logout           — blacklist access token, clear refresh token
    POST  /api/auth/change-password  — change password with old password verification
    GET   /api/auth/profile          — get current user profile
    PATCH /api/auth/profile          — update name / coach_personality
"""
| import logging | |
| from datetime import timezone | |
| from fastapi import APIRouter, Depends, HTTPException, Request, status | |
| from sqlalchemy.orm import Session | |
| from app.core.database import get_db | |
| from app.core.security import ( | |
| blacklist_token, | |
| create_token_pair, | |
| get_current_token, | |
| get_current_user, | |
| hash_password, | |
| hash_refresh_token, | |
| record_failed_login, | |
| reset_failed_logins, | |
| verify_password, | |
| verify_refresh_token, | |
| ) | |
| from app.main import limiter | |
| from app.models.user import User | |
| from app.schemas.auth import ( | |
| ChangePasswordRequest, | |
| LoginRequest, | |
| MessageResponse, | |
| RefreshTokenRequest, | |
| RefreshTokenResponse, | |
| RegisterRequest, | |
| TokenResponse, | |
| UpdateProfileRequest, | |
| UserProfile, | |
| ) | |
| logger = logging.getLogger("hale.api.auth") | |
| router = APIRouter(prefix="/api/auth", tags=["Authentication"]) | |
# ──────────────────────────────────────────────
# Register
# ──────────────────────────────────────────────
def register(
    request: Request,
    payload: RegisterRequest,
    db: Session = Depends(get_db),
) -> TokenResponse:
    """Create a new account and return its first access/refresh token pair.

    Args:
        request: Incoming HTTP request. Not read by this handler
            (presumably required by the rate limiter — confirm).
        payload: Registration data: ``user_id``, ``password`` and the
            optional ``email`` and ``name``.
        db: Database session injected by ``get_db``.

    Returns:
        TokenResponse carrying the new token pair plus the user's
        ``user_id`` and ``name``.

    Raises:
        HTTPException: 409 if the username or email is already in use.
    """
    # Local imports: these names are not part of the module's import block.
    from sqlalchemy import func
    from sqlalchemy.exc import IntegrityError

    # Username uniqueness. Compared case-insensitively because login resolves
    # accounts case-insensitively; two names differing only in case would make
    # that lookup ambiguous.
    username_taken = (
        db.query(User)
        .filter(func.lower(User.user_id) == payload.user_id.lower())
        .first()
    )
    if username_taken:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"Username '{payload.user_id}' is already taken",
        )

    # Email uniqueness (only when an email was supplied), same case rule.
    if payload.email:
        email_taken = (
            db.query(User)
            .filter(func.lower(User.email) == payload.email.lower())
            .first()
        )
        if email_taken:
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail="Email already registered",
            )

    user = User(
        user_id=payload.user_id,
        email=payload.email,
        name=payload.name or payload.user_id,  # fall back to the username
        hashed_password=hash_password(payload.password),
    )
    db.add(user)
    try:
        db.commit()
    except IntegrityError:
        # A concurrent request may have inserted the same username/email
        # between the checks above and this commit. Report it as a conflict
        # instead of letting it surface as a 500.
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Username or email already registered",
        ) from None
    db.refresh(user)

    # Issue the token pair; only the hash of the refresh token is stored.
    tokens = create_token_pair(user.user_id)
    user.refresh_token = hash_refresh_token(tokens["refresh_token"])
    db.commit()

    logger.info("New user registered: %s", user.user_id)
    return TokenResponse(
        access_token=tokens["access_token"],
        refresh_token=tokens["refresh_token"],
        user_id=user.user_id,
        name=user.name,
    )
# ──────────────────────────────────────────────
# Login
# ──────────────────────────────────────────────
def login(
    request: Request,
    payload: LoginRequest,
    db: Session = Depends(get_db),
) -> TokenResponse:
    """Authenticate a user by username or email and issue a token pair.

    Args:
        request: Incoming HTTP request. Not read by this handler
            (presumably required by the rate limiter — confirm).
        payload: Credentials; ``user_id`` may hold either the username or
            the email address.
        db: Database session injected by ``get_db``.

    Returns:
        TokenResponse carrying the new token pair plus the user's
        ``user_id`` and ``name``.

    Raises:
        HTTPException: 401 for an unknown user, a user without a stored
            password hash, or a wrong password; 423 if the account is
            locked; 403 if the account is deactivated.
    """
    # Local import: ``func`` is not part of the module's import block.
    from sqlalchemy import func

    # Find the user by user_id OR email, case-insensitively. An equality test
    # on the lower-cased columns is used instead of ILIKE so that "%" and "_"
    # in caller-supplied input are never interpreted as wildcards (which
    # would let a request match, and lock out, arbitrary accounts).
    identifier = payload.user_id.lower().strip()
    user = (
        db.query(User)
        .filter(
            (func.lower(User.user_id) == identifier)
            | (func.lower(User.email) == identifier)
        )
        .first()
    )

    # Unknown user
    if not user:
        logger.warning("[Login] User not found: %s", payload.user_id)
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid credentials",
        )

    # A row without a password hash cannot be authenticated against.
    if not user.hashed_password:
        logger.error("[Login] User %s exists but has no password hash!", user.user_id)
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid credentials",
        )

    # Account lockout check (runs before the password is verified).
    if user.is_locked:
        logger.warning("[Login] Locked account attempt: %s", user.user_id)
        raise HTTPException(
            status_code=status.HTTP_423_LOCKED,
            detail=(
                "Account temporarily locked due to too many failed login attempts. "
                "Please try again in a few minutes."
            ),
        )

    # Wrong password
    if not verify_password(payload.password, user.hashed_password):
        logger.warning("[Login] Password mismatch for user: %s", user.user_id)
        record_failed_login(db, user)
        # NOTE(review): the literal 5 presumably mirrors the lockout threshold
        # applied by record_failed_login — confirm and share one constant.
        # NOTE(review): this message (and the 423 above) differs from the
        # unknown-user response, so callers can tell existing accounts apart —
        # confirm that is intended.
        remaining = max(0, 5 - (user.failed_login_count or 0))
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"Invalid credentials. {remaining} attempt(s) remaining before lockout.",
        )

    # Inactive account (only reported once the password has been verified).
    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Account is deactivated",
        )

    # Success — reset lockout, issue tokens
    reset_failed_logins(db, user)
    tokens = create_token_pair(user.user_id)
    user.refresh_token = hash_refresh_token(tokens["refresh_token"])
    db.commit()

    logger.info("User logged in: %s", user.user_id)
    return TokenResponse(
        access_token=tokens["access_token"],
        refresh_token=tokens["refresh_token"],
        user_id=user.user_id,
        name=user.name,
    )
# ──────────────────────────────────────────────
# Refresh
# ──────────────────────────────────────────────
def refresh_token(
    request: Request,
    payload: RefreshTokenRequest,
    db: Session = Depends(get_db),
) -> RefreshTokenResponse:
    """Exchange a valid refresh token for a new access/refresh token pair.

    The stored refresh-token hash is overwritten with the hash of the newly
    issued token, so the presented token no longer matches afterwards
    (refresh token rotation).

    Args:
        request: Incoming HTTP request. Not read by this handler.
        payload: Body carrying the refresh token presented by the client.
        db: Database session injected by ``get_db``.

    Returns:
        RefreshTokenResponse with the new access and refresh tokens.

    Raises:
        HTTPException: 401 if no stored hash matches the presented token;
            403 if the matching account is deactivated.
    """
    # Only hashes are stored, so the owner is located by checking the
    # presented token against every user that currently has a refresh token.
    # This is a linear scan, inefficient for large tables.
    # Production: store a token_id alongside the hash to look up directly.
    candidates = db.query(User).filter(User.refresh_token.isnot(None)).all()
    owner = next(
        (
            candidate
            for candidate in candidates
            if verify_refresh_token(payload.refresh_token, candidate.refresh_token)
        ),
        None,
    )

    if not owner:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired refresh token",
        )

    if not owner.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Account is deactivated",
        )

    # Rotate: storing the new hash replaces the old refresh token's hash.
    new_pair = create_token_pair(owner.user_id)
    owner.refresh_token = hash_refresh_token(new_pair["refresh_token"])
    db.commit()

    logger.info("Token refreshed for: %s", owner.user_id)
    return RefreshTokenResponse(
        access_token=new_pair["access_token"],
        refresh_token=new_pair["refresh_token"],
    )
# ──────────────────────────────────────────────
# Logout
# ──────────────────────────────────────────────
def logout(
    current_user: User = Depends(get_current_user),
    token: str = Depends(get_current_token),
    db: Session = Depends(get_db),
) -> MessageResponse:
    """End the caller's session.

    Args:
        current_user: Authenticated user resolved by ``get_current_user``.
        token: The access token presented with this request.
        db: Database session injected by ``get_db``.

    Returns:
        MessageResponse confirming the logout.
    """
    # The presented access token is blacklisted first ...
    blacklist_token(token)

    # ... then the stored refresh token is cleared and persisted.
    current_user.refresh_token = None
    db.commit()

    logger.info("User logged out: %s", current_user.user_id)
    confirmation = MessageResponse(message="Logged out successfully")
    return confirmation
# ──────────────────────────────────────────────
# Change password
# ──────────────────────────────────────────────
def change_password(
    request: Request,
    payload: ChangePasswordRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
    token: str = Depends(get_current_token),
) -> MessageResponse:
    """Replace the caller's password after verifying the current one.

    On success the stored refresh token is cleared and the access token used
    for this request is blacklisted, so the client has to log in again, as
    the response message states.

    Args:
        request: Incoming HTTP request. Not read by this handler.
        payload: Body carrying ``old_password`` and ``new_password``.
        current_user: Authenticated user resolved by ``get_current_user``.
        db: Database session injected by ``get_db``.
        token: The access token presented with this request.

    Returns:
        MessageResponse confirming the change.

    Raises:
        HTTPException: 401 if the old password does not verify (or the
            account has no stored hash); 400 if the new password equals
            the old one.
    """
    # A user without a stored hash cannot pass verification; reject it the
    # same way as a wrong password instead of handing None to verify_password.
    if not current_user.hashed_password or not verify_password(
        payload.old_password, current_user.hashed_password
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Old password is incorrect",
        )

    if payload.old_password == payload.new_password:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="New password must be different from the old password",
        )

    current_user.hashed_password = hash_password(payload.new_password)
    # Clearing the refresh token prevents the old session from being renewed.
    current_user.refresh_token = None
    db.commit()

    # Also revoke the access token in use, mirroring logout. Done after the
    # commit so a failed write does not cost the caller their session.
    blacklist_token(token)

    logger.info("Password changed for: %s", current_user.user_id)
    return MessageResponse(message="Password changed successfully. Please log in again.")
# ──────────────────────────────────────────────
# Profile
# ──────────────────────────────────────────────
def get_profile(current_user: User = Depends(get_current_user)) -> User:
    """Return the authenticated user's record unchanged.

    Args:
        current_user: Authenticated user resolved by ``get_current_user``.

    Returns:
        The same ``User`` object that the dependency supplied.
    """
    return current_user
def update_profile(
    updates: UpdateProfileRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> User:
    """Apply the supplied profile changes to the authenticated user.

    Only ``name`` and ``coach_personality`` can be changed here, and a field
    is written only when the request provides a non-None value for it.

    Args:
        updates: Requested changes; fields left as None are skipped.
        current_user: Authenticated user resolved by ``get_current_user``.
        db: Database session injected by ``get_db``.

    Returns:
        The user record, reloaded from the database after the commit.
    """
    for field_name in ("name", "coach_personality"):
        new_value = getattr(updates, field_name)
        if new_value is not None:
            setattr(current_user, field_name, new_value)

    db.commit()
    db.refresh(current_user)
    return current_user