Spaces:
Sleeping
Sleeping
| import hashlib | |
| import secrets | |
| from datetime import datetime, timedelta, timezone | |
| from typing import Optional | |
| from jose import JWTError, jwt | |
| from fastapi import Depends, HTTPException, status | |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
| from bson import ObjectId | |
| from app.config import settings | |
| from app.database import get_db | |
| import bcrypt | |
# Module-wide bearer-token security scheme; get_token_payload receives the
# parsed "Authorization: Bearer <token>" credentials through it.
# auto_error=True is HTTPBearer's default, spelled out here so the
# reject-on-missing-credentials behavior is visible at the definition.
security = HTTPBearer(auto_error=True)
def hash_password(password: str) -> str:
    """Return a bcrypt hash of *password*, decoded to a UTF-8 string.

    A new random salt is generated on every call, so hashing the same password
    twice produces different strings; compare with ``verify_password`` instead
    of string equality.

    NOTE(review): bcrypt only uses the first 72 bytes of its input (recent
    bcrypt releases raise ValueError instead of truncating) — confirm callers
    bound password length.
    """
    # Arguments evaluate left to right: encode the password, then draw a salt.
    hashed = bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt())
    return hashed.decode("utf-8")
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return whether *plain_password* matches the stored bcrypt *hashed_password*.

    Any exception raised while encoding the inputs or running the bcrypt check
    (e.g. a malformed stored hash or a non-string argument) is reported as a
    non-match, i.e. ``False``.
    """
    try:
        candidate = plain_password.encode("utf-8")
        stored = hashed_password.encode("utf-8")
        matched = bcrypt.checkpw(candidate, stored)
    except Exception:
        # Fail closed: an unverifiable hash never authenticates.
        matched = False
    return matched
def hash_token(token: str) -> str:
    """Return the lowercase hex SHA-256 digest of *token*'s UTF-8 encoding."""
    hasher = hashlib.sha256()
    hasher.update(token.encode("utf-8"))
    return hasher.hexdigest()
def _encode_token(data: dict, token_type: str, lifetime: timedelta, jti_nbytes: int) -> str:
    """Sign a copy of *data* as a JWT with ``exp``, ``token_type`` and ``jti`` claims.

    *data* itself is not mutated. The three added claims overwrite any
    same-named keys already present in *data*.
    """
    claims = data.copy()
    claims.update(
        {
            "exp": datetime.now(timezone.utc) + lifetime,
            "token_type": token_type,
            # Random per-token id; revoke_access_token stores it and
            # get_current_user checks it against the revocation collection.
            "jti": secrets.token_urlsafe(jti_nbytes),
        }
    )
    return jwt.encode(claims, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed access JWT carrying the claims in *data*.

    Args:
        data: Claims to embed (e.g. ``sub``); copied, not mutated.
        expires_delta: Token lifetime. ``None`` selects
            ``settings.ACCESS_TOKEN_EXPIRE_MINUTES`` minutes. Any explicit
            value, including ``timedelta(0)``, is honoured as given.

    Returns:
        The encoded JWT, with ``token_type`` set to ``"access"``.
    """
    # `is None` rather than `or`: timedelta(0) is falsy and would otherwise be
    # silently replaced by the full default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    return _encode_token(data, "access", expires_delta, jti_nbytes=16)


def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed refresh JWT carrying the claims in *data*.

    Args:
        data: Claims to embed (e.g. ``sub``); copied, not mutated.
        expires_delta: Token lifetime. ``None`` selects
            ``settings.REFRESH_TOKEN_EXPIRE_DAYS`` days. Any explicit value,
            including ``timedelta(0)``, is honoured as given.

    Returns:
        The encoded JWT, with ``token_type`` set to ``"refresh"``.
    """
    if expires_delta is None:
        expires_delta = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
    return _encode_token(data, "refresh", expires_delta, jti_nbytes=24)
def decode_jwt(token: str) -> dict:
    """Verify *token* with the configured secret and return its claims.

    Only ``settings.JWT_ALGORITHM`` is accepted. Nothing is caught here:
    exceptions from ``jwt.decode`` propagate to the caller.
    """
    secret = settings.JWT_SECRET
    allowed_algorithms = [settings.JWT_ALGORITHM]
    return jwt.decode(token, secret, algorithms=allowed_algorithms)
async def get_token_payload(credentials: HTTPAuthorizationCredentials = Depends(security)) -> dict:
    """FastAPI dependency: decode the bearer token and return its claims.

    Returns:
        The decoded payload of a valid access token.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            token fails to decode or verify, is not an ``"access"`` token, or
            has a missing or empty ``sub`` claim.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    # Keep the try body to the one call that can raise JWTError.
    try:
        payload = decode_jwt(credentials.credentials)
    except JWTError as exc:
        raise credentials_exception from exc
    # Refresh tokens are signed with the same secret, so this claim is what
    # prevents a refresh token from being accepted as an access token.
    if payload.get("token_type") != "access":
        raise credentials_exception
    # An empty subject cannot identify a user; reject it here rather than
    # letting the downstream user lookup fail.
    if not payload.get("sub"):
        raise credentials_exception
    return payload
async def get_current_user(
    payload: dict = Depends(get_token_payload),
):
    """FastAPI dependency: load the authenticated, approved user for a token.

    Args:
        payload: Decoded access-token claims; ``sub`` is the user's ObjectId
            as a string, ``tv`` the token version, ``jti`` the token id.

    Returns:
        The document from ``db.users``, with an added ``"id"`` key holding
        ``str(user["_id"])``.

    Raises:
        HTTPException: 401 if ``sub`` is not a valid ObjectId, the user is not
            found, ``tv`` (default 0) differs from the user's
            ``token_version`` (default 0), or the ``jti`` is listed in
            ``db.revoked_access_tokens``; 403 if the user's ``is_approved``
            flag is not truthy.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    user_id = payload.get("sub")
    # ObjectId() raises InvalidId on malformed input, which would surface as a
    # 500; a subject that cannot be an ObjectId is just an invalid credential.
    if not ObjectId.is_valid(user_id):
        raise credentials_exception
    token_jti = payload.get("jti")
    token_version = payload.get("tv", 0)
    db = get_db()
    user = await db.users.find_one({"_id": ObjectId(user_id)})
    if user is None:
        raise credentials_exception
    # NOTE(review): presumably token_version is bumped elsewhere to invalidate
    # all outstanding tokens for this user — confirm.
    if user.get("token_version", 0) != token_version:
        raise credentials_exception
    # Tokens without a jti cannot be looked up, so they skip this check.
    if token_jti:
        revoked = await db.revoked_access_tokens.find_one({"jti": token_jti})
        if revoked:
            raise credentials_exception
    if not user.get("is_approved", False):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Account is pending admin approval",
        )
    user["id"] = str(user["_id"])
    return user
async def require_role(required_role: str, user: dict = Depends(get_current_user)):
    """Return *user* if their role equals *required_role* or is ``"admin"``.

    Raises:
        HTTPException: 403 when the user's role is neither.

    NOTE(review): if this function is mounted directly as
    ``Depends(require_role)``, FastAPI will treat ``required_role`` as a
    request parameter supplied by the client. Confirm it is only ever awaited
    with a server-chosen role.
    """
    role = user.get("role")
    if role in (required_role, "admin"):
        return user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail=f"Role '{required_role}' required",
    )
# Roles admitted by the reviewer gate (admin included). Single source of truth
# for require_expert and require_reviewer so the two cannot drift apart.
_REVIEWER_ROLES = ("expert", "linguist", "translator", "admin")


def _ensure_reviewer(user: dict) -> dict:
    """Return *user* if their role is in ``_REVIEWER_ROLES``; otherwise raise 403."""
    if user.get("role") not in _REVIEWER_ROLES:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Reviewer role required",
        )
    return user


async def require_expert(user: dict = Depends(get_current_user)):
    """FastAPI dependency: allow expert, linguist, translator or admin users.

    Applies exactly the same check (and error message) as ``require_reviewer``.

    Raises:
        HTTPException: 403 "Reviewer role required" for any other role.
    """
    return _ensure_reviewer(user)


async def require_reviewer(user: dict = Depends(get_current_user)):
    """FastAPI dependency: allow expert, linguist, translator or admin users.

    Raises:
        HTTPException: 403 "Reviewer role required" for any other role.
    """
    return _ensure_reviewer(user)
async def require_admin(user: dict = Depends(get_current_user)):
    """FastAPI dependency: allow only users whose role is exactly ``"admin"``.

    Raises:
        HTTPException: 403 for any other role, including a missing one.
    """
    is_admin = user.get("role") == "admin"
    if not is_admin:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Admin role required",
        )
    return user
async def revoke_access_token(payload: dict) -> None:
    """Record the access token described by *payload* as revoked.

    Upserts a document keyed by the token's ``jti`` into
    ``db.revoked_access_tokens``, the collection ``get_current_user`` checks.
    A payload without a ``jti`` is ignored.

    ``expires_at`` is taken from the ``exp`` claim (interpreted as a UTC Unix
    timestamp) when it is present and convertible; otherwise the current time
    is used.

    NOTE(review): presumably a TTL index on ``expires_at`` purges old entries —
    confirm in the database setup.
    """
    jti = payload.get("jti")
    exp = payload.get("exp")
    if not jti:
        return
    db = get_db()
    expires_at = datetime.now(timezone.utc)
    if exp:
        try:
            expires_at = datetime.fromtimestamp(exp, tz=timezone.utc)
        except (TypeError, ValueError, OverflowError, OSError):
            # Non-numeric or out-of-range exp: keep the "now" fallback rather
            # than failing the revocation itself.
            pass
    await db.revoked_access_tokens.update_one(
        {"jti": jti},
        {
            "$set": {
                "jti": jti,
                "expires_at": expires_at,
                "created_at": datetime.now(timezone.utc),
            }
        },
        upsert=True,
    )