| from datetime import datetime, timedelta |
| from typing import Optional |
|
|
| from fastapi import APIRouter, Depends, HTTPException, status |
| from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm |
| from jose import JWTError, jwt |
| from passlib.context import CryptContext |
|
|
| from src.models.user import User |
| from src.data.user_repository import UserRepository |
| from src.core.config import settings |
|
|
# Router that the authentication endpoints in this module are registered on.
router = APIRouter()

# Password hashing context. passlib's "bcrypt_sha256" pre-hashes with SHA-256
# so inputs longer than bcrypt's 72-byte limit are not silently truncated.
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt_sha256"])

# Reads the bearer token from the Authorization header; tokenUrl is the
# (relative) URL of the /token login route that issues those tokens.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
|
|
|
| |
def get_user_repository() -> UserRepository:
    """Dependency provider for the user repository.

    Builds a new ``UserRepository`` each time it is invoked; the route
    handlers and ``get_current_user`` receive it through ``Depends``.
    """
    repository = UserRepository()
    return repository
|
|
|
|
def verify_password(plain_password, hashed_password):
    """Check a plaintext password against a stored hash.

    Delegates to the module's passlib ``CryptContext``.

    Returns:
        True when ``plain_password`` matches ``hashed_password``.
    """
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
|
|
|
|
def get_password_hash(password):
    """Hash a plaintext password with the module's passlib ``CryptContext``.

    Returns:
        The hash string produced by the context's default scheme, suitable
        for storage and later checking with ``verify_password``.
    """
    hashed = pwd_context.hash(password)
    return hashed
|
|
|
|
def create_access_token(
    data: dict,
    expires_delta: Optional[timedelta] = None,
) -> str:
    """Build a signed JWT from ``data`` with an ``exp`` (expiry) claim added.

    Args:
        data: Claims to encode. The dict is copied before ``exp`` is added,
            so the caller's object is left unchanged.
        expires_delta: Token lifetime measured from now. ``None`` selects the
            default of 15 minutes. Any explicit value, including
            ``timedelta(0)``, is used exactly as given.

    Returns:
        The token produced by ``jwt.encode``, signed with
        ``settings.SECRET_KEY`` using ``settings.ALGORITHM``.
    """
    # Only ``datetime`` and ``timedelta`` are imported at module level.
    from datetime import timezone

    # Test for None explicitly: ``timedelta(0)`` is falsy and must not be
    # replaced by the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(minutes=15)

    # Timezone-aware UTC "now"; ``datetime.utcnow()`` returns a naive value
    # and is deprecated since Python 3.12.
    expire = datetime.now(timezone.utc) + expires_delta

    to_encode = data.copy()
    to_encode.update({"exp": expire})

    return jwt.encode(
        to_encode,
        settings.SECRET_KEY,
        algorithm=settings.ALGORITHM,
    )
|
|
|
|
async def get_current_user(
    token: str = Depends(oauth2_scheme),
    user_repo: UserRepository = Depends(get_user_repository),
):
    """Resolve the request's bearer token to the stored user it names.

    The token is decoded with ``settings.SECRET_KEY``, and only
    ``settings.ALGORITHM`` is accepted. Its ``sub`` claim is treated as the
    user's email, matching what ``/token`` puts there.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            the token fails to decode, carries no (or an empty) ``sub``
            claim, or names an email that has no user record.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Keep the try body to the one call that raises JWTError, and chain the
    # cause so server-side tracebacks show why decoding failed.
    try:
        payload = jwt.decode(
            token,
            settings.SECRET_KEY,
            algorithms=[settings.ALGORITHM],
        )
    except JWTError as err:
        raise credentials_exception from err

    email: Optional[str] = payload.get("sub")
    if not email:
        # A missing or empty subject cannot identify a user.
        raise credentials_exception

    # NOTE(review): a synchronous repository call inside an async dependency
    # runs on the event loop; if it does blocking I/O, consider a plain `def`
    # dependency or an async repository — confirm.
    user = user_repo.get_user_by_email(email)
    if user is None:
        raise credentials_exception

    return user
|
|
|
|
@router.post("/signup", response_model=User, tags=["Authentication"])
async def signup(
    user: User,
    user_repo: UserRepository = Depends(get_user_repository),
):
    """Register a new user and return the stored record with its password redacted.

    The submitted password is hashed with ``get_password_hash`` before it is
    passed to the repository.

    Raises:
        HTTPException: 400 if a user with this email already exists; 500 if
            the repository returns a falsy value from ``create_user``.
    """
    # NOTE(review): the existence check and the insert are separate calls, so
    # two concurrent signups for the same email could both pass this check
    # unless the repository/database enforces uniqueness — confirm.
    if user_repo.get_user_by_email(user.email):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered",
        )

    new_user = user_repo.create_user(user, get_password_hash(user.password))
    if not new_user:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to create user",
        )

    # The response model is User, which carries the password field; overwrite
    # it so the stored hash is never sent back to the client.
    # NOTE(review): this mutates the object create_user returned; if the
    # repository hands back a shared/cached instance, the stored hash would be
    # overwritten too — confirm.
    new_user.password = "[PASSWORD_REDACTED]"
    return new_user
|
|
|
|
@router.post("/token", tags=["Authentication"])
async def login_for_access_token(
    form_data: OAuth2PasswordRequestForm = Depends(),
    user_repo: UserRepository = Depends(get_user_repository),
):
    """Exchange an email/password pair for a bearer access token.

    The OAuth2 password form calls its identifier field ``username``; here
    it is looked up as the user's email. On success the token's ``sub`` claim
    is the email. The token expires after
    ``settings.ACCESS_TOKEN_EXPIRE_MINUTES``.

    Returns:
        ``{"access_token": <jwt>, "token_type": "bearer"}``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            email is unknown or the password does not match. Both cases get
            the same response.
    """
    invalid_credentials = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Incorrect username or password",
        headers={"WWW-Authenticate": "Bearer"},
    )

    user = user_repo.get_user_by_email(form_data.username)
    if not user:
        # Spend comparable time to a real hash check so response latency does
        # not reveal whether the email is registered (timing-based user
        # enumeration); previously the slow verify was skipped entirely here.
        pwd_context.dummy_verify()
        raise invalid_credentials

    if not verify_password(form_data.password, user.password):
        raise invalid_credentials

    access_token = create_access_token(
        data={"sub": user.email},
        expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
    )

    return {"access_token": access_token, "token_type": "bearer"}
|
|
|
|
@router.get("/users/me", response_model=User, tags=["Authentication"])
async def read_users_me(current_user: User = Depends(get_current_user)):
    """Return the authenticated user, with the password field redacted."""
    # NOTE(review): this overwrites the password on the instance returned by
    # the repository (via get_current_user); if that instance is shared or
    # cached, the stored hash would be clobbered — confirm.
    me = current_user
    me.password = "[PASSWORD_REDACTED]"
    return me