# memora / utils / utils.py
# 167AliRaza — commit d3fd803: "add email verification"
from datetime import datetime, timedelta, timezone
import hashlib
import logging
from typing import Any, Dict, Optional
from uuid import uuid4
import bcrypt
from bson import ObjectId
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from config.database import db_manager
from config.settings import settings
from validation.validation import UserPublic
# Signing parameters shared by every jwt.encode / jwt.decode call in this module.
JWT_SECRET_KEY = settings.jwt_secret_key
JWT_ALGORITHM = settings.jwt_algorithm

# Lifetimes, in minutes, of the two token types issued by this module.
ACCESS_TOKEN_EXPIRE_MINUTES = settings.access_token_expire_minutes
EMAIL_VERIFICATION_EXPIRE_MINUTES = settings.email_verification_expire_minutes

# Marker prepended by hash_password; verify_password uses it to distinguish
# SHA-256-prehashed bcrypt hashes from legacy plain-bcrypt hashes.
PASSWORD_HASH_PREFIX = "bcrypt_sha256$"

logger = logging.getLogger(__name__)

# Dependency that supplies the `token` argument of get_current_user;
# tokenUrl points clients at the login route.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")
def _bcrypt_sha256_input(password: str) -> bytes:
    """Turn *password* into the byte string that is actually fed to bcrypt.

    The password is UTF-8 encoded and hashed with SHA-256. The 64-character
    lowercase hex digest is returned as ASCII bytes, so the output length is
    fixed no matter how long the password is. (Presumably this is meant to stay
    inside bcrypt's 72-byte input limit.)
    """
    digest = hashlib.sha256()
    digest.update(password.encode("utf-8"))
    return digest.hexdigest().encode("ascii")
def hash_password(plain_password: str) -> str:
    """Hash *plain_password* for storage.

    The password is first reduced with ``_bcrypt_sha256_input`` (SHA-256 hex
    digest). That digest is then hashed with bcrypt using a fresh salt from
    ``bcrypt.gensalt()``.

    Returns:
        ``PASSWORD_HASH_PREFIX`` followed by the bcrypt hash string. The prefix
        lets ``verify_password`` recognise this format.

    Raises:
        HTTPException: 500 "Internal server error" if hashing fails. The
            original error is logged with its traceback and chained as the
            cause.
    """
    try:
        password_bytes = _bcrypt_sha256_input(plain_password)
        hashed_password = bcrypt.hashpw(password_bytes, bcrypt.gensalt()).decode("utf-8")
    except Exception as e:
        # logger.exception keeps the traceback that logger.error(f"...{e}") dropped.
        logger.exception("Error hashing password: %s", e)
        raise HTTPException(status_code=500, detail="Internal server error") from e
    return f"{PASSWORD_HASH_PREFIX}{hashed_password}"
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check *plain_password* against a stored password hash.

    If the stored value starts with ``PASSWORD_HASH_PREFIX``, the prefix is
    stripped and the candidate is the SHA-256 pre-hash of the password.
    Otherwise the stored value is treated as a legacy bcrypt hash of the raw
    UTF-8 password.

    Returns:
        True on a match. False on a mismatch or on any error; errors are
        logged and never propagated.
    """
    try:
        if hashed_password.startswith(PASSWORD_HASH_PREFIX):
            candidate = _bcrypt_sha256_input(plain_password)
            reference = hashed_password[len(PASSWORD_HASH_PREFIX):]
        else:
            # Legacy hashes were created before the SHA-256 pre-hash was introduced.
            candidate = plain_password.encode("utf-8")
            reference = hashed_password
        return bcrypt.checkpw(candidate, reference.encode("utf-8"))
    except Exception as e:
        logger.error(f"Error verifying password: {e}")
        return False
def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
    """Encode a signed access JWT containing *data* plus bookkeeping claims.

    Args:
        data: Claims to embed. The dict is copied, not mutated. Callers
            presumably supply ``sub``, since get_current_user requires it.
        expires_delta: Token lifetime measured from now (UTC). ``None`` means
            ``ACCESS_TOKEN_EXPIRE_MINUTES``. An explicit value, including
            ``timedelta(0)``, is used as given.

    Returns:
        The encoded token. ``exp``, ``type`` ("access") and a random ``jti``
        override any same-named keys in *data*.
    """
    to_encode = data.copy()
    # Compare against None explicitly: timedelta(0) is falsy, so `or` would
    # silently replace an explicit zero lifetime with the default.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire, "type": "access", "jti": str(uuid4())})
    return jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
def create_email_verification_token(user_id: str, email: str) -> str:
    """Encode a signed JWT used to verify *email* for the user *user_id*.

    The token carries ``sub`` (the user id), ``email``, an ``exp`` that is
    ``EMAIL_VERIFICATION_EXPIRE_MINUTES`` from now (UTC), ``type`` set to
    "verify_email", and a random ``jti``.
    """
    lifetime = timedelta(minutes=EMAIL_VERIFICATION_EXPIRE_MINUTES)
    claims: Dict[str, Any] = dict(
        sub=user_id,
        email=email,
        exp=datetime.now(timezone.utc) + lifetime,
        type="verify_email",
        jti=str(uuid4()),
    )
    return jwt.encode(claims, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
async def is_token_blacklisted(jti: str) -> bool:
    """Return True if a document with this ``jti`` exists in the blacklist collection.

    Fails closed: if the lookup raises, the error is logged and the token is
    reported as blacklisted.
    """
    try:
        entry = await db_manager.blacklisted_tokens_collection.find_one({"jti": jti})
    except Exception as e:
        logger.error(f"Error checking token blacklist: {e}")
        return True
    return entry is not None
async def get_current_user(token: str = Depends(oauth2_scheme)) -> UserPublic:
    """Resolve the request's bearer token to the authenticated user.

    The token must decode with ``JWT_SECRET_KEY`` / ``JWT_ALGORITHM``, must have
    ``type == "access"``, and must carry ``sub`` and ``jti`` claims. Its
    ``jti`` must not be blacklisted. ``sub`` must be a valid ObjectId of a user
    document that is active (missing flag counts as active) and whose email is
    verified (missing flag counts as unverified).

    Args:
        token: Bearer token supplied by ``oauth2_scheme``.

    Returns:
        The ``UserPublic`` projection of the user document.

    Raises:
        HTTPException: 401 "Token has been revoked" when ``is_token_blacklisted``
            returns True. Otherwise 401 "Could not validate credentials" for any
            other failure, including database errors.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
    except JWTError as e:
        logger.warning("JWT decode error: %s", e)
        raise credentials_exception from e

    subject: Optional[str] = payload.get("sub")
    token_type: Optional[str] = payload.get("type")
    jti: Optional[str] = payload.get("jti")
    if subject is None or token_type != "access" or jti is None:
        raise credentials_exception
    if await is_token_blacklisted(jti):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token has been revoked",
            headers={"WWW-Authenticate": "Bearer"},
        )

    try:
        if not ObjectId.is_valid(subject):
            raise credentials_exception
        user_doc = await db_manager.users_collection.find_one({"_id": ObjectId(subject)})
        if not user_doc:
            raise credentials_exception
        if not user_doc.get("is_active", True) or not user_doc.get(
            "is_email_verified", False
        ):
            raise credentials_exception
        return UserPublic(
            id=str(user_doc["_id"]),
            email=user_doc["email"],
            name=user_doc["name"],
            created_at=user_doc["created_at"],
        )
    except HTTPException:
        # Deliberate auth rejections from above: propagate them without
        # logging them as server errors.
        raise
    except Exception as e:
        # Unexpected failures (DB error, malformed user document) still map to 401,
        # but are logged with a traceback and chained as the cause.
        logger.exception("Error fetching user: %s", e)
        raise credentials_exception from e