# Spaces: Running
from __future__ import annotations

import hashlib
import logging
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional

import jwt
from argon2 import PasswordHasher
from argon2.exceptions import VerificationError, VerifyMismatchError
from fastapi import HTTPException, Request, status
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import get_settings
from app.core.auth.models import RefreshSession, Role, User
from app.core.auth.schemas import (
    ChangePasswordSchema,
    ForgotPasswordSchema,
    LoginSchema,
    RegisterSchema,
    ResetPasswordSchema,
    TokenData,
    UpdateProfileSchema,
    UserProfile,
)
# Logger used for authentication events in this module.
logger = logging.getLogger("auth_service")
# Application settings, read once at import time. This module reads JWT
# options, lockout limits and token lifetimes from it.
_settings = get_settings()
# Shared Argon2 hasher (library default parameters), used both for user
# passwords and for hashing refresh tokens before they are stored.
ph = PasswordHasher()
def _now() -> datetime:
    """Return the current instant as a timezone-aware UTC ``datetime``."""
    utc = timezone.utc
    return datetime.now(tz=utc)
def _token_key(raw: str) -> str:
    """Return the hex SHA-256 digest of *raw* encoded as UTF-8.

    The digest is deterministic, so it can serve as a lookup key for a secret
    without storing the secret itself.
    """
    digest = hashlib.sha256()
    digest.update(raw.encode("utf-8"))
    return digest.hexdigest()
class AuthService:
    """Authentication, password and refresh-session operations.

    Every operation is a ``@staticmethod`` that receives the ``AsyncSession``
    to work with, and is invoked on the class itself
    (e.g. ``AuthService.login(db, request, schema)``).

    Refresh tokens are random strings returned to the client once. Each is
    persisted as a ``RefreshSession`` row holding a SHA-256 lookup key
    (``token_key``) and an Argon2 hash (``token_hash``) of the raw token.
    """

    @staticmethod
    async def register(db: AsyncSession, schema: RegisterSchema) -> User:
        """Create a user account and attach the ``"User"`` role if it exists.

        Raises:
            HTTPException(409): the email or the username is already in use.
        """
        result = await db.execute(select(User).where(User.email == schema.email))
        if result.scalars().first():
            raise HTTPException(status_code=409, detail="Email already registered")
        if schema.username:
            result = await db.execute(select(User).where(User.username == schema.username))
            if result.scalars().first():
                raise HTTPException(status_code=409, detail="Username already taken")

        user = User(
            email=schema.email,
            username=schema.username,
            full_name=schema.full_name,
            password_hash=ph.hash(schema.password),
        )
        role_result = await db.execute(select(Role).where(Role.name == "User"))
        default_role = role_result.scalars().first()
        if default_role:
            user.roles.append(default_role)
        db.add(user)
        await db.commit()
        await db.refresh(user)
        return user

    @staticmethod
    async def login(db: AsyncSession, request: Request, schema: LoginSchema) -> TokenData:
        """Authenticate by email and password and open a new refresh session.

        Each wrong password increments ``failed_login_attempts``. Once it
        reaches ``max_login_attempts``, the account is locked for
        ``lockout_minutes``.

        Returns:
            TokenData holding a new access token and the raw refresh token.

        Raises:
            HTTPException(401): unknown (or soft-deleted) email, or wrong password.
            HTTPException(403): account currently locked, or inactive.
        """
        result = await db.execute(
            select(User).where(User.email == schema.email, User.deleted_at.is_(None))
        )
        user = result.scalars().first()
        if not user:
            raise HTTPException(status_code=401, detail="Incorrect email or password")

        now = _now()
        if user.locked_until is not None:
            if user.locked_until > now:
                raise HTTPException(
                    status_code=403,
                    detail=f"Account locked until {user.locked_until.isoformat()}",
                )
            # The lockout window has passed. Start counting failures from zero,
            # otherwise the counter is still at the limit and a single wrong
            # password would re-lock the account immediately.
            user.failed_login_attempts = 0
            user.locked_until = None

        if not AuthService._verify_password(schema.password, user.password_hash):
            user.failed_login_attempts += 1
            if user.failed_login_attempts >= _settings.max_login_attempts:
                user.locked_until = now + timedelta(minutes=_settings.lockout_minutes)
            await db.commit()
            raise HTTPException(status_code=401, detail="Incorrect email or password")

        # Same requirement refresh() enforces. It is checked only after the
        # password, so the inactive state is not revealed to anyone who
        # merely knows the email.
        if not user.is_active:
            raise HTTPException(status_code=403, detail="Account is inactive")

        user.failed_login_attempts = 0
        user.locked_until = None
        user.last_login = now

        access_token = AuthService._create_access_token(user.id)
        raw_refresh, refresh_hash, token_key, expires_at = AuthService._create_refresh_token()
        db.add(
            RefreshSession(
                user_id=user.id,
                token_key=token_key,
                token_hash=refresh_hash,
                expires_at=expires_at,
                device_info=request.headers.get("User-Agent", "Unknown"),
                ip_address=request.client.host if request.client else "Unknown",
            )
        )
        await db.commit()
        return TokenData(access_token=access_token, refresh_token=raw_refresh)

    @staticmethod
    async def refresh(db: AsyncSession, raw_refresh_token: str) -> TokenData:
        """Rotate a refresh token: revoke it and issue a new access/refresh pair.

        If a token that was already revoked is presented again, every live
        session of its owner is revoked (reuse detection).

        Raises:
            HTTPException(401): unknown, expired, revoked, reused or
                concurrently-rotated token, or the owner is missing/inactive.
        """
        key = _token_key(raw_refresh_token)
        now = _now()
        result = await db.execute(
            select(RefreshSession).where(
                RefreshSession.token_key == key,
                RefreshSession.revoked_at.is_(None),
                RefreshSession.expires_at > now,
            )
        )
        current = result.scalars().first()

        if current is None:
            lookup = await db.execute(
                select(RefreshSession).where(RefreshSession.token_key == key)
            )
            known = lookup.scalars().first()
            if known is not None and known.revoked_at is not None:
                # A revoked token was replayed: treat the session as
                # compromised and revoke everything the user still has open.
                await db.execute(
                    update(RefreshSession)
                    .where(
                        RefreshSession.user_id == known.user_id,
                        RefreshSession.revoked_at.is_(None),
                    )
                    .values(revoked_at=now)
                )
                await db.commit()
                raise HTTPException(
                    status_code=401,
                    detail="Session compromised. All sessions revoked. Please login again.",
                )
            raise HTTPException(status_code=401, detail="Invalid or expired refresh token")

        if not AuthService._verify_token(raw_refresh_token, current.token_hash):
            raise HTTPException(status_code=401, detail="Invalid refresh token")

        user_result = await db.execute(
            select(User).where(
                User.id == current.user_id, User.deleted_at.is_(None), User.is_active.is_(True)
            )
        )
        user = user_result.scalars().first()
        if not user:
            raise HTTPException(status_code=401, detail="User not found or inactive")

        # Revoke with a conditional UPDATE. It only matches while the row is
        # still unrevoked, so two concurrent requests carrying the same token
        # cannot both rotate it into two new sessions.
        revoked = await db.execute(
            update(RefreshSession)
            .where(RefreshSession.id == current.id, RefreshSession.revoked_at.is_(None))
            .values(revoked_at=now)
        )
        if revoked.rowcount != 1:
            raise HTTPException(status_code=401, detail="Invalid or expired refresh token")

        new_access = AuthService._create_access_token(user.id)
        new_raw_refresh, new_hash, new_key, new_expires = AuthService._create_refresh_token()
        db.add(
            RefreshSession(
                user_id=user.id,
                token_key=new_key,
                token_hash=new_hash,
                expires_at=new_expires,
                device_info=current.device_info,
                ip_address=current.ip_address,
            )
        )
        user.last_login = now
        await db.commit()
        return TokenData(access_token=new_access, refresh_token=new_raw_refresh)

    @staticmethod
    async def logout(db: AsyncSession, user: User, raw_refresh_token: str):
        """Revoke the caller's live session that matches *raw_refresh_token*.

        Raises:
            HTTPException(404): no unrevoked session of *user* has that token.
        """
        key = _token_key(raw_refresh_token)
        result = await db.execute(
            select(RefreshSession).where(
                RefreshSession.token_key == key,
                RefreshSession.user_id == user.id,
                RefreshSession.revoked_at.is_(None),
            )
        )
        session = result.scalars().first()
        if not session:
            raise HTTPException(status_code=404, detail="Session not found")
        session.revoked_at = _now()
        await db.commit()

    @staticmethod
    async def logout_all(db: AsyncSession, user: User):
        """Revoke every unrevoked refresh session belonging to *user*."""
        await db.execute(
            update(RefreshSession)
            .where(
                RefreshSession.user_id == user.id,
                RefreshSession.revoked_at.is_(None),
            )
            .values(revoked_at=_now())
        )
        await db.commit()

    @staticmethod
    async def change_password(db: AsyncSession, user: User, schema: ChangePasswordSchema):
        """Replace *user*'s password after verifying the current one.

        NOTE(review): unlike reset_password, existing refresh sessions are left
        active here. Confirm that is intended.

        Raises:
            HTTPException(400): the current password is wrong.
        """
        if not AuthService._verify_password(schema.current_password, user.password_hash):
            raise HTTPException(status_code=400, detail="Incorrect current password")
        user.password_hash = ph.hash(schema.new_password)
        user.password_changed_at = _now()
        await db.commit()

    @staticmethod
    async def forgot_password(db: AsyncSession, schema: ForgotPasswordSchema):
        """Issue a one-hour password-reset JWT for the account with this email.

        Unknown or soft-deleted emails are ignored without error, so the result
        never differs by whether the address is registered.

        The token embeds a fingerprint of the current password hash. Once the
        password changes (including via the reset itself), the token no longer
        verifies, which makes it single-use.
        """
        result = await db.execute(
            select(User).where(User.email == schema.email, User.deleted_at.is_(None))
        )
        user = result.scalars().first()
        if not user:
            return

        reset_token = jwt.encode(
            {
                "sub": user.id,
                "type": "reset_password",
                "exp": _now() + timedelta(hours=1),
                "pwd": AuthService._password_fingerprint(user.password_hash),
            },
            _settings.jwt_secret_key,
            algorithm=_settings.jwt_algorithm,
        )
        # TODO: deliver the token out-of-band (e.g. email). A live reset token
        # grants account takeover, so it is no longer written at INFO level.
        # It goes to DEBUG only, for local development.
        logger.info("Password reset requested for user %s", user.id)
        logger.debug("Password reset token for %s: %s", user.email, reset_token)

    @staticmethod
    async def reset_password(db: AsyncSession, schema: ResetPasswordSchema):
        """Set a new password using a token from forgot_password.

        On success, every unrevoked refresh session of the user is revoked.

        Raises:
            HTTPException(400): token is invalid, expired, of the wrong type,
                or already used (the password changed since it was issued).
            HTTPException(404): the user no longer exists or is soft-deleted.
        """
        try:
            payload = jwt.decode(
                schema.token,
                _settings.jwt_secret_key,
                algorithms=[_settings.jwt_algorithm],
            )
        except jwt.PyJWTError:
            raise HTTPException(
                status_code=400, detail="Invalid or expired reset token"
            ) from None
        if payload.get("type") != "reset_password":
            raise HTTPException(status_code=400, detail="Invalid token type")
        user_id = payload.get("sub")
        if user_id is None:
            raise HTTPException(status_code=400, detail="Invalid or expired reset token")

        result = await db.execute(
            select(User).where(User.id == user_id, User.deleted_at.is_(None))
        )
        user = result.scalars().first()
        if not user:
            raise HTTPException(status_code=404, detail="User not found")

        # Reject tokens minted for an older password (already-used tokens, or
        # tokens issued before this check existed).
        fingerprint = payload.get("pwd")
        expected = AuthService._password_fingerprint(user.password_hash)
        if not isinstance(fingerprint, str) or not secrets.compare_digest(fingerprint, expected):
            raise HTTPException(status_code=400, detail="Invalid or expired reset token")

        now = _now()
        user.password_hash = ph.hash(schema.new_password)
        user.password_changed_at = now
        await db.execute(
            update(RefreshSession)
            .where(RefreshSession.user_id == user.id, RefreshSession.revoked_at.is_(None))
            .values(revoked_at=now)
        )
        await db.commit()

    @staticmethod
    async def update_profile(db: AsyncSession, user: User, schema: UpdateProfileSchema) -> User:
        """Apply the non-None fields of *schema* (username, full_name) to *user*.

        Raises:
            HTTPException(409): another user already has the requested username.
        """
        if schema.username is not None:
            result = await db.execute(
                select(User).where(User.username == schema.username, User.id != user.id)
            )
            if result.scalars().first():
                raise HTTPException(status_code=409, detail="Username already taken")
            user.username = schema.username
        if schema.full_name is not None:
            user.full_name = schema.full_name
        await db.commit()
        await db.refresh(user)
        return user

    @staticmethod
    async def soft_delete(db: AsyncSession, user: User):
        """Mark *user* deleted and inactive, and revoke all their live sessions."""
        now = _now()
        user.deleted_at = now
        user.is_active = False
        await db.execute(
            update(RefreshSession)
            .where(RefreshSession.user_id == user.id, RefreshSession.revoked_at.is_(None))
            .values(revoked_at=now)
        )
        await db.commit()

    @staticmethod
    async def list_sessions(db: AsyncSession, user: User) -> list[RefreshSession]:
        """Return *user*'s unrevoked sessions, newest first by ``created_at``."""
        result = await db.execute(
            select(RefreshSession)
            .where(
                RefreshSession.user_id == user.id,
                RefreshSession.revoked_at.is_(None),
            )
            .order_by(RefreshSession.created_at.desc())
        )
        return list(result.scalars().all())

    @staticmethod
    async def revoke_session(db: AsyncSession, user: User, session_id: str):
        """Revoke one of *user*'s unrevoked sessions by id.

        Raises:
            HTTPException(404): no such unrevoked session owned by *user*.
        """
        result = await db.execute(
            select(RefreshSession).where(
                RefreshSession.id == session_id,
                RefreshSession.user_id == user.id,
                RefreshSession.revoked_at.is_(None),
            )
        )
        session = result.scalars().first()
        if not session:
            raise HTTPException(status_code=404, detail="Session not found")
        session.revoked_at = _now()
        await db.commit()

    @staticmethod
    async def cleanup_expired_sessions(db: AsyncSession):
        """Mark every expired, still-unrevoked session as revoked.

        Only unrevoked rows are loaded. Previously each run reloaded every
        expired row ever created, including those already revoked.
        """
        now = _now()
        result = await db.execute(
            select(RefreshSession).where(
                RefreshSession.expires_at < now,
                RefreshSession.revoked_at.is_(None),
            )
        )
        expired = result.scalars().all()
        for stale in expired:
            stale.revoked_at = now
        if expired:
            await db.commit()
        logger.info("Cleaned up %s expired sessions", len(expired))

    @staticmethod
    def _create_access_token(user_id: str) -> str:
        """Return a signed access JWT for *user_id* (claims: sub, type, iat, exp, iss)."""
        now = _now()
        payload = {
            "sub": user_id,
            "type": "access",
            "iat": now,
            "exp": now + timedelta(minutes=_settings.access_token_expire_minutes),
            "iss": _settings.jwt_issuer,
        }
        return jwt.encode(payload, _settings.jwt_secret_key, algorithm=_settings.jwt_algorithm)

    @staticmethod
    def _create_refresh_token() -> tuple[str, str, str, datetime]:
        """Generate a refresh token.

        Returns:
            ``(raw_token, argon2_hash, sha256_lookup_key, expires_at)``. Only
            the raw token is handed to the client.
        """
        token = secrets.token_urlsafe(64)
        token_hash = ph.hash(token)
        token_key = _token_key(token)
        expires_at = _now() + timedelta(days=_settings.refresh_token_expire_days)
        return token, token_hash, token_key, expires_at

    @staticmethod
    def _password_fingerprint(password_hash: Optional[str]) -> str:
        """Return a SHA-256 fingerprint of a stored password hash.

        It changes whenever the password changes. It is safe to embed in a
        signed token because it reveals nothing about the hash itself.
        """
        return _token_key(password_hash or "")

    @staticmethod
    def _verify_hash(plain: str, hashed: Optional[str]) -> bool:
        """Return True iff *plain* matches the Argon2 hash *hashed*.

        A missing or malformed stored hash counts as a non-match instead of
        escaping as a 500 error.
        """
        if not hashed:
            return False
        try:
            ph.verify(hashed, plain)
        except VerifyMismatchError:
            return False
        except (VerificationError, ValueError):
            # argon2's InvalidHash(Error) subclasses ValueError. Both cases
            # here mean the stored value could not be checked.
            logger.warning("Stored Argon2 hash could not be verified")
            return False
        return True

    @staticmethod
    def _verify_password(plain: str, hashed: str) -> bool:
        """Check a plaintext password against its stored Argon2 hash."""
        return AuthService._verify_hash(plain, hashed)

    @staticmethod
    def _verify_token(raw: str, hashed: str) -> bool:
        """Check a raw refresh token against its stored Argon2 hash."""
        return AuthService._verify_hash(raw, hashed)
def user_to_profile(user: User) -> UserProfile:
    """Build the public ``UserProfile`` view of a ``User``.

    Role objects are flattened to their ``name`` strings. All other fields
    are copied straight across.
    """
    profile_fields = dict(
        id=user.id,
        email=user.email,
        username=user.username,
        full_name=user.full_name,
        is_active=user.is_active,
        is_verified=user.is_verified,
        roles=[role.name for role in user.roles],
        created_at=user.created_at,
    )
    return UserProfile(**profile_fields)