# Authentication helpers: password hashing/verification, JWT creation, and the
# FastAPI dependency that resolves the current user from a bearer token.
from datetime import datetime, timedelta, timezone
import hashlib
import logging
from typing import Any, Dict, Optional
from uuid import uuid4

import bcrypt
from bson import ObjectId
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt

from config.database import db_manager
from config.settings import settings
from validation.validation import UserPublic

logger = logging.getLogger(__name__)

# Extracts the bearer token from the Authorization header; tokenUrl is the
# login route advertised in the OpenAPI schema.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")

JWT_SECRET_KEY = settings.jwt_secret_key
JWT_ALGORITHM = settings.jwt_algorithm
ACCESS_TOKEN_EXPIRE_MINUTES = settings.access_token_expire_minutes
EMAIL_VERIFICATION_EXPIRE_MINUTES = settings.email_verification_expire_minutes

# Marks hashes produced by hash_password(); stored hashes without this prefix
# are treated as legacy plain-bcrypt hashes by verify_password().
PASSWORD_HASH_PREFIX = "bcrypt_sha256$"


def _bcrypt_sha256_input(password: str) -> bytes:
    """Return the SHA-256 hex digest of *password* as ASCII bytes.

    bcrypt only uses the first 72 bytes of its input. Pre-hashing produces a
    fixed 64-byte value, so every character of a long password still
    contributes to the final hash.
    """
    return hashlib.sha256(password.encode("utf-8")).hexdigest().encode("ascii")


def hash_password(plain_password: str) -> str:
    """Hash *plain_password* for storage.

    Returns:
        ``PASSWORD_HASH_PREFIX`` followed by a bcrypt hash (fresh random salt)
        of the SHA-256 pre-hashed password.

    Raises:
        HTTPException: 500 if hashing fails for any reason.
    """
    try:
        password_bytes = _bcrypt_sha256_input(plain_password)
        hashed = bcrypt.hashpw(password_bytes, bcrypt.gensalt()).decode("utf-8")
    except Exception as exc:
        logger.exception("Error hashing password: %s", exc)
        raise HTTPException(status_code=500, detail="Internal server error") from exc
    return f"{PASSWORD_HASH_PREFIX}{hashed}"


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check *plain_password* against a stored hash.

    Supports both prefixed hashes from hash_password() and legacy bcrypt
    hashes of the raw password. Never raises: any error (for example a
    malformed stored hash) is logged and reported as a non-match.
    """
    try:
        if hashed_password.startswith(PASSWORD_HASH_PREFIX):
            stored_hash = hashed_password[len(PASSWORD_HASH_PREFIX):]
            candidate = _bcrypt_sha256_input(plain_password)
        else:
            # Backward compatibility for bcrypt hashes created before the
            # SHA-256 pre-hash was introduced.
            stored_hash = hashed_password
            candidate = plain_password.encode("utf-8")
        return bcrypt.checkpw(candidate, stored_hash.encode("utf-8"))
    except Exception as exc:
        logger.error("Error verifying password: %s", exc)
        return False


def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed access JWT carrying *data* plus ``exp``, ``type`` and ``jti``.

    Args:
        data: Claims to embed (callers are expected to set ``sub``). The dict
            is copied, not mutated. ``exp``/``type``/``jti`` keys are overwritten.
        expires_delta: Token lifetime; ``None`` means
            ``ACCESS_TOKEN_EXPIRE_MINUTES``. An explicit ``timedelta(0)`` is
            honoured rather than replaced by the default.
    """
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode = data.copy()
    to_encode.update({"exp": expire, "type": "access", "jti": str(uuid4())})
    return jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)


def create_email_verification_token(user_id: str, email: str) -> str:
    """Create a signed ``verify_email`` JWT for *user_id*/*email*.

    The token expires after ``EMAIL_VERIFICATION_EXPIRE_MINUTES``.
    """
    expire = datetime.now(timezone.utc) + timedelta(minutes=EMAIL_VERIFICATION_EXPIRE_MINUTES)
    payload = {
        "sub": user_id,
        "email": email,
        "exp": expire,
        "type": "verify_email",
        "jti": str(uuid4()),
    }
    return jwt.encode(payload, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)


async def is_token_blacklisted(jti: str) -> bool:
    """Return True if a token with this *jti* has been revoked.

    Fails closed: if the blacklist lookup errors, the token is treated as
    revoked rather than silently accepted.
    """
    try:
        blacklisted = await db_manager.blacklisted_tokens_collection.find_one({"jti": jti})
    except Exception as exc:
        logger.error("Error checking token blacklist: %s", exc)
        return True
    return blacklisted is not None


async def get_current_user(token: str = Depends(oauth2_scheme)) -> UserPublic:
    """FastAPI dependency: resolve the bearer *token* to an active, verified user.

    Raises:
        HTTPException: 401 "Token has been revoked" if the token's jti is
            blacklisted (or the blacklist check fails). Otherwise 401 "Could not
            validate credentials" for any other problem: bad signature or
            expiry, wrong token type, missing claims, unknown, inactive or
            unverified user, or an unexpected error while loading the user.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
    except JWTError as exc:
        logger.warning("JWT decode error: %s", exc)
        raise credentials_exception from exc

    subject: Optional[str] = payload.get("sub")
    token_type: Optional[str] = payload.get("type")
    jti: Optional[str] = payload.get("jti")
    # Only access tokens authenticate requests; e.g. verify_email tokens are rejected.
    if subject is None or token_type != "access" or jti is None:
        raise credentials_exception
    if await is_token_blacklisted(jti):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token has been revoked",
            headers={"WWW-Authenticate": "Bearer"},
        )

    try:
        if not ObjectId.is_valid(subject):
            raise credentials_exception
        user_doc = await db_manager.users_collection.find_one({"_id": ObjectId(subject)})
        if not user_doc:
            raise credentials_exception
        # Missing is_active defaults to active; missing is_email_verified to unverified.
        if not user_doc.get("is_active", True) or not user_doc.get("is_email_verified", False):
            raise credentials_exception
        return UserPublic(
            id=str(user_doc["_id"]),
            email=user_doc["email"],
            name=user_doc["name"],
            created_at=user_doc["created_at"],
        )
    except HTTPException:
        # Deliberate rejections above are expected outcomes, not errors.
        raise
    except Exception as exc:
        logger.exception("Error fetching user: %s", exc)
        raise credentials_exception from exc