"""Authentication service: user lookup, registration, login, and JWT validation."""

from functools import cache

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt

from ..models.user import User
from ..schemas.user import UserCreate, UserLogin, Token
from ..database import get_collection
from ..utils.security import verify_password, get_password_hash, create_access_token
from ..config import get_settings

settings = get_settings()

# Bearer-token extractor for protected routes; clients obtain tokens from "auth/login".
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")


async def get_user_by_email(email: str) -> User | None:
    """Retrieve a user by email from the database.

    Args:
        email: Email address to match exactly against the stored ``email`` field.

    Returns:
        The matching ``User`` (built via ``User.from_dict``), or ``None`` if no
        document matches.
    """
    users_collection = get_collection(User.collection_name)
    user_data = await users_collection.find_one({"email": email})
    if user_data:
        return User.from_dict(user_data)
    return None


async def create_user(user_in: UserCreate) -> User:
    """Register a new user with a hashed password.

    Args:
        user_in: Registration payload carrying ``email`` and plaintext ``password``.

    Returns:
        The newly created ``User``. Only the hash of the password is stored.

    Raises:
        HTTPException: 400 if a user with the same email already exists.
    """
    existing_user = await get_user_by_email(user_in.email)
    if existing_user:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered",
        )
    # NOTE(review): the lookup above and the insert below are not atomic, so two
    # concurrent registrations for the same email can both succeed. A unique
    # index on "email" (and translating the driver's duplicate-key error into
    # this same 400) would close the race — confirm how the collection is set up.
    hashed_password = get_password_hash(user_in.password)
    user = User(email=user_in.email, hashed_password=hashed_password)
    users_collection = get_collection(User.collection_name)
    await users_collection.insert_one(user.to_dict())
    return user


@cache
def _dummy_password_hash():
    """Return a throwaway password hash, computed once on first use.

    ``authenticate_user`` verifies against this when the email is unknown, so an
    unknown-email login does roughly the same hashing work as a wrong-password
    login. Presumably ``get_password_hash`` output is always accepted by
    ``verify_password`` (they are used as a pair elsewhere in this module).
    """
    return get_password_hash("timing-equalization-dummy-password")


def _invalid_login_error() -> HTTPException:
    """Build the single 401 error used for every failed login attempt."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Incorrect email or password",
        headers={"WWW-Authenticate": "Bearer"},
    )


async def authenticate_user(user_in: UserLogin) -> Token:
    """Authenticate a user and return a JWT bearer token.

    Args:
        user_in: Login payload carrying ``email`` and plaintext ``password``.

    Returns:
        A ``Token`` whose ``access_token`` has the user's email as its ``sub``
        claim, with ``token_type`` ``"bearer"``.

    Raises:
        HTTPException: 401 when the email is unknown or the password does not
            match. Both cases use the same detail so callers cannot tell them apart.
    """
    user = await get_user_by_email(user_in.email)
    # Always perform a password verification, even for unknown emails, so the
    # response time does not reveal whether the account exists. (The very first
    # unknown-email attempt also pays for computing the cached dummy hash.)
    stored_hash = user.hashed_password if user is not None else _dummy_password_hash()
    password_ok = verify_password(user_in.password, stored_hash)
    if user is None or not password_ok:
        raise _invalid_login_error()
    access_token = create_access_token(data={"sub": user.email})
    return Token(access_token=access_token, token_type="bearer")


async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
    """Validate a JWT bearer token and return the user it identifies.

    Args:
        token: Raw bearer token extracted from the request by ``oauth2_scheme``.

    Returns:
        The ``User`` whose email equals the token's ``sub`` claim.

    Raises:
        HTTPException: 401 if the token fails to decode/verify with the
            configured secret and algorithm, has a missing, empty, or non-string
            ``sub``, or names a user that no longer exists.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, settings.secret_key, algorithms=[settings.algorithm])
    except JWTError:
        # The client only needs the 401; drop the library's exception context.
        raise credentials_exception from None

    email = payload.get("sub")
    # Only a non-empty string is a usable lookup key for get_user_by_email.
    if not isinstance(email, str) or not email:
        raise credentials_exception

    user = await get_user_by_email(email)
    if user is None:
        raise credentials_exception
    return user