Spaces:
Sleeping
Sleeping
| from fastapi import APIRouter, HTTPException, status, Depends, Response, Request | |
| from sqlmodel import Session, select | |
| from typing import Annotated | |
| from datetime import datetime, timedelta | |
| from uuid import uuid4 | |
| import secrets | |
| import hashlib | |
| from ..models.user import User, UserCreate, UserRead | |
| from ..models.refresh_token import RefreshToken | |
| from ..schemas.auth import RegisterRequest, RegisterResponse, LoginRequest, LoginResponse, ForgotPasswordRequest, ResetPasswordRequest | |
| from ..utils.security import hash_password, create_access_token, verify_password | |
| from ..utils.deps import get_current_user | |
| from ..database import get_session_dep | |
| from ..config import settings | |
# Router for the authentication endpoints; routes registered on it live under /api/auth.
router = APIRouter(prefix="/api/auth", tags=["auth"])

# Refresh token settings
# Lifetime of a refresh token in days; used for the stored expiry and the cookie max_age.
REFRESH_TOKEN_EXPIRE_DAYS = 30
# Name of the cookie that carries the raw (unhashed) refresh token.
REFRESH_TOKEN_COOKIE_NAME = "refresh_token"

# NOTE: For cross-site cookie auth to work (backend on a different domain than the frontend):
# - The frontend must call login/register endpoints with `fetch(..., credentials: 'include')`
# - The backend must have CORS allow_credentials=True and the exact frontend origin in allow_origins
# - Cookies must be set with `Secure=True` and `samesite='none'` in production
# See project README or docs for full details
def register(user_data: RegisterRequest, response: Response, session: Session = Depends(get_session_dep)):
    """Register a new user with email and password.

    Creates the user, stores the SHA-256 hash of a freshly generated refresh
    token, and sets the raw refresh token as an HttpOnly cookie. The response
    body carries only the new user's id, email and a confirmation message; no
    access token is returned by this endpoint.

    Args:
        user_data: Registration payload; ``email`` and ``password`` are read.
        response: Outgoing response on which the refresh cookie is set.
        session: Database session.

    Returns:
        RegisterResponse with the new user's id and email.

    Raises:
        HTTPException: 409 if an account with this email already exists,
            400 if the password is shorter than 8 characters.
    """
    # Check if user already exists
    existing_user = session.exec(select(User).where(User.email == user_data.email)).first()
    if existing_user:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="An account with this email already exists"
        )

    # Validate password length
    if len(user_data.password) < 8:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Password must be at least 8 characters"
        )

    # Create the user with a hashed password; the plain password is never stored
    user = User(
        email=user_data.email,
        password_hash=hash_password(user_data.password)
    )
    session.add(user)
    session.commit()
    session.refresh(user)

    # Create refresh token; only its hash is stored in the DB
    raw_refresh = secrets.token_urlsafe(64)
    token_hash = hashlib.sha256(raw_refresh.encode()).hexdigest()
    expires_at = datetime.utcnow() + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
    session.add(RefreshToken(token_hash=token_hash, user_id=user.id, expires_at=expires_at))
    session.commit()

    # Browsers reject SameSite=None cookies that are not also Secure, so fall
    # back to "lax" when the Secure flag is disabled (e.g. plain-HTTP development).
    cookie_secure = settings.JWT_COOKIE_SECURE
    response.set_cookie(
        key=REFRESH_TOKEN_COOKIE_NAME,
        value=raw_refresh,
        httponly=True,
        secure=cookie_secure,
        samesite="none" if cookie_secure else "lax",
        max_age=REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60,
        path="/api/auth"
    )

    return RegisterResponse(
        id=user.id,
        email=user.email,
        message="Account created successfully"
    )
def login(login_data: LoginRequest, response: Response, session: Session = Depends(get_session_dep)):
    """Authenticate user with email and password.

    Returns a short-lived access token in the response body and sets a HttpOnly refresh cookie.
    Frontend should store access token in memory and call protected APIs with Authorization header,
    or call `/api/auth/refresh` (with `credentials: 'include'`) to obtain a new access token.

    Args:
        login_data: Login payload; ``email`` and ``password`` are read.
        response: Outgoing response on which the refresh cookie is set.
        session: Database session.

    Returns:
        LoginResponse containing the access token, token type and basic user info.

    Raises:
        HTTPException: 401 if the email is unknown or the password does not match.
    """
    import logging

    # Find user by email; the same error is used for unknown email and wrong password
    user = session.exec(select(User).where(User.email == login_data.email)).first()
    if not user or not verify_password(login_data.password, user.password_hash):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid email or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Create short-lived access token (15 minutes)
    access_token = create_access_token(data={"sub": str(user.id)}, expires_delta=timedelta(minutes=15))

    # Create refresh token; only its hash is stored in the DB
    raw_refresh = secrets.token_urlsafe(64)
    token_hash = hashlib.sha256(raw_refresh.encode()).hexdigest()
    expires_at = datetime.utcnow() + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
    session.add(RefreshToken(token_hash=token_hash, user_id=user.id, expires_at=expires_at))
    session.commit()

    # Browsers reject SameSite=None cookies that are not also Secure, so fall
    # back to "lax" when the Secure flag is disabled (e.g. plain-HTTP development).
    cookie_secure = settings.JWT_COOKIE_SECURE
    same_site = "none" if cookie_secure else "lax"
    max_age = REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60

    # Set HttpOnly refresh cookie (sent to frontend with credentials: 'include')
    response.set_cookie(
        key=REFRESH_TOKEN_COOKIE_NAME,
        value=raw_refresh,
        httponly=True,
        secure=cookie_secure,
        samesite=same_site,
        max_age=max_age,
        path="/api/auth"
    )
    # Log cookie attributes only, never the token value
    logging.getLogger(__name__).debug(
        "Refresh cookie set (httponly=%s, secure=%s, samesite=%s, max_age=%s)",
        True, cookie_secure, same_site, max_age,
    )

    # Return short-lived access token in response body for immediate use
    return LoginResponse(
        access_token=access_token,
        token_type="bearer",
        user=RegisterResponse(
            id=user.id,
            email=user.email,
            message="Login successful"
        )
    )
def refresh_token(request: Request, response: Response, session: Session = Depends(get_session_dep)):
    """Rotate refresh token and return a new short-lived access token.

    Reads the HttpOnly refresh cookie, looks up its SHA-256 hash, marks the
    stored record as revoked, stores a new refresh token, sets it as the new
    cookie and returns a fresh access token.
    Frontend must call this with `credentials: 'include'`.

    Args:
        request: Incoming request; the refresh cookie is read from it.
        response: Outgoing response on which the rotated cookie is set.
        session: Database session.

    Returns:
        Dict with ``access_token`` and ``token_type``.

    Raises:
        HTTPException: 401 if the cookie is missing, the token is unknown or
            revoked, the token has expired, or its user no longer exists.
    """
    raw_refresh = request.cookies.get(REFRESH_TOKEN_COOKIE_NAME)
    if not raw_refresh:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="No refresh token provided")

    # Tokens are stored hashed, so hash the presented value before the lookup
    token_hash = hashlib.sha256(raw_refresh.encode()).hexdigest()
    token_record = session.exec(select(RefreshToken).where(RefreshToken.token_hash == token_hash)).first()
    if not token_record or token_record.revoked:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")
    if token_record.expires_at < datetime.utcnow():
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Refresh token expired")

    # Get user
    user = session.get(User, token_record.user_id)
    if not user:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")

    # Revoke old token and rotate
    token_record.revoked = True
    session.add(token_record)

    # Create new refresh token; revocation and the new record share one commit
    new_raw_refresh = secrets.token_urlsafe(64)
    new_token_hash = hashlib.sha256(new_raw_refresh.encode()).hexdigest()
    new_expires_at = datetime.utcnow() + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
    session.add(RefreshToken(token_hash=new_token_hash, user_id=user.id, expires_at=new_expires_at))
    session.commit()

    # Browsers reject SameSite=None cookies that are not also Secure, so fall
    # back to "lax" when the Secure flag is disabled (e.g. plain-HTTP development).
    cookie_secure = settings.JWT_COOKIE_SECURE
    response.set_cookie(
        key=REFRESH_TOKEN_COOKIE_NAME,
        value=new_raw_refresh,
        httponly=True,
        secure=cookie_secure,
        samesite="none" if cookie_secure else "lax",
        max_age=REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60,
        path="/api/auth"
    )

    # Create new short-lived access token and return it
    access_token = create_access_token(data={"sub": str(user.id)}, expires_delta=timedelta(minutes=15))
    return {"access_token": access_token, "token_type": "bearer"}
def logout(request: Request, response: Response, session: Session = Depends(get_session_dep)):
    """Logout user by revoking the refresh token and clearing the cookie.

    If the request carries a refresh cookie whose hash matches a stored
    record, that record is marked revoked. The cookie is cleared in every
    case, so calling this without a valid cookie still succeeds.

    Args:
        request: Incoming request; the refresh cookie is read from it.
        response: Outgoing response on which the cookie is cleared.
        session: Database session.

    Returns:
        Dict with a confirmation ``message``.
    """
    raw_refresh = request.cookies.get(REFRESH_TOKEN_COOKIE_NAME)
    if raw_refresh:
        token_hash = hashlib.sha256(raw_refresh.encode()).hexdigest()
        token_record = session.exec(select(RefreshToken).where(RefreshToken.token_hash == token_hash)).first()
        if token_record:
            token_record.revoked = True
            session.add(token_record)
            session.commit()

    # Clear the refresh cookie. Browsers reject SameSite=None cookies that are
    # not also Secure (which would leave the old cookie in place), so fall back
    # to "lax" when the Secure flag is disabled.
    cookie_secure = settings.JWT_COOKIE_SECURE
    response.set_cookie(
        key=REFRESH_TOKEN_COOKIE_NAME,
        value="",
        httponly=True,
        secure=cookie_secure,
        samesite="none" if cookie_secure else "lax",
        max_age=0,  # Expire immediately
        path="/api/auth"
    )
    return {"message": "Logged out successfully"}
def get_current_user_profile(request: Request, current_user: User = Depends(get_current_user)):
    """Get the current authenticated user's profile.

    Args:
        request: Incoming request; only its cookie names are logged for debugging.
        current_user: User resolved by the ``get_current_user`` dependency.

    Returns:
        RegisterResponse with the user's id and email.
    """
    import logging

    # Debug aid: log which cookies arrived, with values masked so tokens never reach the logs
    logging.getLogger(__name__).debug(
        "Received cookies: %s", {name: "***" for name in request.cookies.keys()}
    )
    return RegisterResponse(
        id=current_user.id,
        email=current_user.email,
        message="User profile retrieved successfully"
    )
def forgot_password(forgot_data: ForgotPasswordRequest, session: Session = Depends(get_session_dep)):
    """Start the (simulated) password reset flow for an email address.

    The email is looked up, but the reply is the same whether or not an
    account exists, so the endpoint does not reveal which emails are
    registered. No email is actually sent.
    """
    lookup = select(User).where(User.email == forgot_data.email)
    # The lookup result does not influence the reply; a real implementation
    # would send the reset email here when an account is found.
    session.exec(lookup).first()
    return {"message": "If the email exists, a reset link would be sent"}
def reset_password(reset_data: ResetPasswordRequest, session: Session = Depends(get_session_dep)):
    """Set a new password for the user with the given email.

    After the password is changed, every refresh token still active for the
    user is revoked so sessions opened with the old password cannot be used
    to obtain new access tokens.

    NOTE(review): this handler checks only ``email`` and ``new_password``; no
    reset token or other proof of ownership is verified here. Confirm that
    verification happens elsewhere, otherwise anyone who knows an email can
    change that account's password.

    Args:
        reset_data: Reset payload; ``email`` and ``new_password`` are read.
        session: Database session.

    Returns:
        Dict with a confirmation ``message``.

    Raises:
        HTTPException: 404 if no user has this email, 400 if the new
            password is shorter than 8 characters.
    """
    # Check if user exists
    # NOTE(review): the 404 reveals whether an email is registered, unlike forgot_password.
    user = session.exec(select(User).where(User.email == reset_data.email)).first()
    if not user:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found"
        )

    # Validate password length
    if len(reset_data.new_password) < 8:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Password must be at least 8 characters"
        )

    # Hash the new password
    user.password_hash = hash_password(reset_data.new_password)
    session.add(user)

    # Revoke outstanding refresh tokens in the same transaction as the password change
    user_tokens = session.exec(select(RefreshToken).where(RefreshToken.user_id == user.id)).all()
    for token in user_tokens:
        if not token.revoked:
            token.revoked = True
            session.add(token)

    session.commit()
    return {"message": "Password reset successfully"}