Spaces:
Sleeping
Sleeping
| from datetime import datetime, timedelta, timezone | |
| import hashlib | |
| import logging | |
| from typing import Any, Dict, Optional | |
| from uuid import uuid4 | |
| import bcrypt | |
| from bson import ObjectId | |
| from fastapi import Depends, HTTPException, status | |
| from fastapi.security import OAuth2PasswordBearer | |
| from jose import JWTError, jwt | |
| from config.database import db_manager | |
| from config.settings import settings | |
| from validation.validation import UserPublic | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Bearer-token extractor; clients obtain tokens from the login endpoint.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")

# JWT signing parameters and token lifetimes, read once from settings.
JWT_SECRET_KEY = settings.jwt_secret_key
JWT_ALGORITHM = settings.jwt_algorithm
ACCESS_TOKEN_EXPIRE_MINUTES = settings.access_token_expire_minutes
EMAIL_VERIFICATION_EXPIRE_MINUTES = settings.email_verification_expire_minutes

# Marker prepended to stored hashes produced by the SHA-256 + bcrypt scheme,
# so they can be told apart from plain bcrypt hashes.
PASSWORD_HASH_PREFIX = "bcrypt_sha256$"
| def _bcrypt_sha256_input(password: str) -> bytes: | |
| return hashlib.sha256(password.encode("utf-8")).hexdigest().encode("ascii") | |
def hash_password(plain_password: str) -> str:
    """Hash *plain_password* for storage.

    The password is first reduced to its SHA-256 hex digest, then hashed with
    bcrypt using a freshly generated salt. The returned string is the bcrypt
    hash prefixed with ``PASSWORD_HASH_PREFIX``.

    Raises:
        HTTPException: status 500 if any step of the hashing fails.
    """
    try:
        digest = _bcrypt_sha256_input(plain_password)
        salt = bcrypt.gensalt()
        bcrypt_hash = bcrypt.hashpw(digest, salt)
        return PASSWORD_HASH_PREFIX + bcrypt_hash.decode("utf-8")
    except Exception as e:
        logger.error(f"Error hashing password: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check *plain_password* against a stored password hash.

    Two stored formats are accepted:

    * hashes starting with ``PASSWORD_HASH_PREFIX``: the password is reduced
      to its SHA-256 hex digest and compared against the bcrypt hash that
      follows the prefix;
    * plain bcrypt hashes (legacy): the UTF-8 encoded password is compared
      directly.

    Returns:
        True if the password matches, False if it does not or if any error
        occurs during verification (the error is logged).
    """
    # bcrypt only uses the first 72 bytes of its input. Older releases of the
    # bcrypt package truncated longer input silently, while newer ones raise
    # ValueError instead. Truncating explicitly keeps legacy hashes of long
    # passwords verifiable regardless of the installed bcrypt version.
    bcrypt_max_input_bytes = 72

    try:
        if hashed_password.startswith(PASSWORD_HASH_PREFIX):
            stored_hash = hashed_password[len(PASSWORD_HASH_PREFIX):]
            return bcrypt.checkpw(
                _bcrypt_sha256_input(plain_password),
                stored_hash.encode("utf-8"),
            )

        # Backward compatibility for existing bcrypt hashes created before the
        # prefixed scheme was introduced.
        legacy_secret = plain_password.encode("utf-8")[:bcrypt_max_input_bytes]
        return bcrypt.checkpw(legacy_secret, hashed_password.encode("utf-8"))
    except Exception as e:
        logger.error(f"Error verifying password: {e}")
        return False
def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed in the token. The mapping is copied, not
            mutated. Any ``exp``, ``type`` or ``jti`` keys it contains are
            overwritten.
        expires_delta: Token lifetime. When ``None``, the lifetime defaults to
            ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.

    Returns:
        The encoded token, carrying ``type="access"`` and a unique ``jti``.
    """
    # Compare against None explicitly: a zero timedelta is falsy, so
    # ``expires_delta or default`` would silently replace an explicit
    # ``timedelta(0)`` with the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)

    to_encode = data.copy()
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire, "type": "access", "jti": str(uuid4())})
    return jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
def create_email_verification_token(user_id: str, email: str) -> str:
    """Create a signed JWT used to verify a user's email address.

    Args:
        user_id: Stored in the token's ``sub`` claim.
        email: Stored in the token's ``email`` claim.

    Returns:
        The encoded token. It carries ``type="verify_email"``, a unique
        ``jti``, and expires ``EMAIL_VERIFICATION_EXPIRE_MINUTES`` minutes
        from now (UTC).
    """
    now_utc = datetime.now(timezone.utc)
    expires_at = now_utc + timedelta(minutes=EMAIL_VERIFICATION_EXPIRE_MINUTES)

    claims: Dict[str, Any] = {"sub": user_id, "email": email}
    claims["exp"] = expires_at
    claims["type"] = "verify_email"
    claims["jti"] = str(uuid4())

    return jwt.encode(claims, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
async def is_token_blacklisted(jti: str) -> bool:
    """Report whether a token with the given ``jti`` is blacklisted.

    Looks for a document with a matching ``jti`` in the blacklisted-tokens
    collection.

    Returns:
        True if such a document exists. Also True if the lookup itself fails:
        the error is logged and the token is treated as blacklisted.
    """
    try:
        record = await db_manager.blacklisted_tokens_collection.find_one({"jti": jti})
    except Exception as e:
        logger.error(f"Error checking token blacklist: {e}")
        # Fail closed: if the blacklist cannot be consulted, reject the token.
        return True

    return record is not None
async def get_current_user(token: str = Depends(oauth2_scheme)) -> UserPublic:
    """Resolve a bearer access token to the user it belongs to.

    The token must decode with the configured key and algorithm, carry a
    ``sub`` and a ``jti`` claim, have ``type == "access"``, and not be
    blacklisted. ``sub`` must be a valid ObjectId of an existing user that is
    active and has a verified email address.

    Args:
        token: Bearer token extracted from the request by ``oauth2_scheme``.

    Returns:
        The public view of the authenticated user.

    Raises:
        HTTPException: status 401 with detail "Token has been revoked" when
            the token is blacklisted, and status 401 with detail
            "Could not validate credentials" for every other failure,
            including unexpected errors while loading the user.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Only the decode call can raise JWTError, so keep the try body minimal.
    try:
        payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
    except JWTError as e:
        logger.warning(f"JWT decode error: {e}")
        raise credentials_exception from e

    subject: Optional[str] = payload.get("sub")
    token_type: Optional[str] = payload.get("type")
    jti: Optional[str] = payload.get("jti")
    if subject is None or token_type != "access" or jti is None:
        raise credentials_exception

    if await is_token_blacklisted(jti):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token has been revoked",
            headers={"WWW-Authenticate": "Bearer"},
        )

    try:
        if not ObjectId.is_valid(subject):
            raise credentials_exception

        user_doc = await db_manager.users_collection.find_one({"_id": ObjectId(subject)})
        if not user_doc:
            raise credentials_exception

        # A missing "is_active" counts as active; a missing
        # "is_email_verified" counts as not verified.
        is_active = user_doc.get("is_active", True)
        is_verified = user_doc.get("is_email_verified", False)
        if not is_active or not is_verified:
            raise credentials_exception

        return UserPublic(
            id=str(user_doc["_id"]),
            email=user_doc["email"],
            name=user_doc["name"],
            created_at=user_doc["created_at"],
        )
    except HTTPException:
        # Deliberate rejections raised above must not be caught by the generic
        # handler below, which would log them as errors.
        raise
    except Exception as e:
        logger.error(f"Error fetching user: {e}")
        raise credentials_exception from e