Spaces:
Sleeping
Sleeping
import os
import warnings
from datetime import datetime, timedelta, timezone
from typing import Optional

import bcrypt
from dotenv import load_dotenv
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from .connection import get_db
from .models import User
load_dotenv()

# Fallback signing key for local development only. It is public (it lives in
# source control), so any token signed with it can be forged by anyone.
_DEV_SECRET_KEY = "dev-secret-key-change-in-production"
SECRET_KEY: str = os.getenv("SECRET_KEY", _DEV_SECRET_KEY)
if SECRET_KEY in ("", _DEV_SECRET_KEY):
    # Keep the permissive fallback so development still works, but make it
    # loud: silently signing JWTs with a known or empty key is a security hole.
    warnings.warn(
        "SECRET_KEY is unset, empty, or the development default; JWTs are "
        "being signed with an insecure key. Set SECRET_KEY in the environment.",
        RuntimeWarning,
        stacklevel=1,
    )
ALGORITHM: str = os.getenv("ALGORITHM", "HS256")
# Default token lifetime: 144000 minutes = 100 days, so persistent logins
# survive app restarts.
ACCESS_TOKEN_EXPIRE_MINUTES: int = int(
    os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "144000")
)
# Extracts the Bearer token from requests; tokenUrl is the endpoint that
# issues tokens.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
| # --------------------------------------------------------------------------- | |
| # Password helpers | |
| # --------------------------------------------------------------------------- | |
| # bcrypt operates on the first 72 bytes of the input only and raises on longer | |
| # inputs in v5+. Encode and truncate explicitly so any password length is safe. | |
| def _to_bcrypt_bytes(password: str) -> bytes: | |
| return password.encode("utf-8")[:72] | |
def hash_password(password: str) -> str:
    """Hash *password* with bcrypt using a freshly generated salt.

    The password is first encoded and clipped by ``_to_bcrypt_bytes``, so
    inputs longer than 72 bytes do not raise. The resulting bcrypt hash is
    returned as a UTF-8 ``str``.
    """
    secret = _to_bcrypt_bytes(password)
    salt = bcrypt.gensalt()
    return bcrypt.hashpw(secret, salt).decode("utf-8")
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return whether *plain_password* matches the stored bcrypt hash.

    Any ``ValueError`` raised while encoding either string or while checking
    yields ``False`` rather than propagating. This includes
    ``UnicodeEncodeError``, a ``ValueError`` subclass, and presumably a
    malformed stored hash.
    """
    try:
        # Encoding stays inside the try so UTF-8 failures also map to False.
        candidate = _to_bcrypt_bytes(plain_password)
        stored = hashed_password.encode("utf-8")
        matched = bcrypt.checkpw(candidate, stored)
    except ValueError:
        matched = False
    return matched
| # --------------------------------------------------------------------------- | |
| # JWT helpers | |
| # --------------------------------------------------------------------------- | |
def create_access_token(
    data: dict,
    expires_delta: Optional[timedelta] = None,
) -> str:
    """Encode *data* as a signed JWT carrying an ``exp`` claim.

    Args:
        data: Claims to embed. A shallow copy is made, so the caller's dict
            is not mutated; any ``exp`` key it contains is overwritten.
        expires_delta: Token lifetime measured from now (UTC). ``None`` uses
            ``ACCESS_TOKEN_EXPIRE_MINUTES``. Any other value, including
            ``timedelta(0)``, is honoured as given.

    Returns:
        The encoded token, signed with ``SECRET_KEY`` using ``ALGORITHM``.
    """
    to_encode = data.copy()
    # Explicit None check: timedelta(0) is falsy, so the previous
    # ``expires_delta or default`` silently turned a zero lifetime into the
    # (very long) default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    to_encode["exp"] = datetime.now(timezone.utc) + expires_delta
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def decode_access_token(token: str) -> Optional[dict]:
    """Verify and decode *token*, returning its claims or ``None``.

    The token is checked against ``SECRET_KEY`` and only ``ALGORITHM`` is
    accepted. Any ``JWTError`` results in ``None`` instead of an exception.
    Examples are a bad signature or, under python-jose's default checks, an
    expired ``exp``.
    """
    claims: Optional[dict]
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        claims = None
    return claims
| # --------------------------------------------------------------------------- | |
| # FastAPI dependency | |
| # --------------------------------------------------------------------------- | |
async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Validate a Bearer token and return the authenticated User.

    Intended as a FastAPI dependency. The token comes from ``oauth2_scheme``
    and the session from ``get_db``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header if any
            of the following holds:
            - ``decode_access_token`` returns ``None``;
            - the claims have no ``sub``;
            - ``sub`` is not an integer user id;
            - no matching user exists;
            - the user is not active.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    payload = decode_access_token(token)
    if payload is None:
        raise credentials_exception
    subject = payload.get("sub")
    if subject is None:
        raise credentials_exception
    try:
        user_id = int(subject)
    except (TypeError, ValueError):
        # A validly signed token whose subject is not an integer id must be
        # rejected as unauthenticated (401), not crash the request (500).
        raise credentials_exception from None
    result = await db.execute(select(User).where(User.id == user_id))
    user: Optional[User] = result.scalar_one_or_none()
    if user is None or not user.is_active:
        raise credentials_exception
    return user