from __future__ import annotations import hashlib import logging import secrets from datetime import datetime, timedelta, timezone from typing import Optional import jwt from argon2 import PasswordHasher from argon2.exceptions import VerifyMismatchError from fastapi import HTTPException, Request, status from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession from app.config import get_settings from app.core.auth.models import RefreshSession, Role, User from app.core.auth.schemas import ( ChangePasswordSchema, ForgotPasswordSchema, LoginSchema, RegisterSchema, ResetPasswordSchema, TokenData, UpdateProfileSchema, UserProfile, ) logger = logging.getLogger("auth_service") _settings = get_settings() ph = PasswordHasher() def _now() -> datetime: return datetime.now(timezone.utc) def _token_key(raw: str) -> str: return hashlib.sha256(raw.encode("utf-8")).hexdigest() class AuthService: @staticmethod async def register(db: AsyncSession, schema: RegisterSchema) -> User: result = await db.execute(select(User).where(User.email == schema.email)) if result.scalars().first(): raise HTTPException(status_code=409, detail="Email already registered") if schema.username: result = await db.execute(select(User).where(User.username == schema.username)) if result.scalars().first(): raise HTTPException(status_code=409, detail="Username already taken") user = User( email=schema.email, username=schema.username, full_name=schema.full_name, password_hash=ph.hash(schema.password), ) role_result = await db.execute(select(Role).where(Role.name == "User")) default_role = role_result.scalars().first() if default_role: user.roles.append(default_role) db.add(user) await db.commit() await db.refresh(user) return user @staticmethod async def login(db: AsyncSession, request: Request, schema: LoginSchema) -> TokenData: result = await db.execute( select(User).where(User.email == schema.email, User.deleted_at.is_(None)) ) user = result.scalars().first() if not user: raise HTTPException(status_code=401, detail="Incorrect email or password") if user.locked_until and user.locked_until > _now(): raise HTTPException( status_code=403, detail=f"Account locked until {user.locked_until.isoformat()}", ) if not AuthService._verify_password(schema.password, user.password_hash): user.failed_login_attempts += 1 if user.failed_login_attempts >= _settings.max_login_attempts: user.locked_until = _now() + timedelta(minutes=_settings.lockout_minutes) await db.commit() raise HTTPException(status_code=401, detail="Incorrect email or password") user.failed_login_attempts = 0 user.locked_until = None user.last_login = _now() access_token = AuthService._create_access_token(user.id) raw_refresh, refresh_hash, token_key, expires_at = AuthService._create_refresh_token() session = RefreshSession( user_id=user.id, token_key=token_key, token_hash=refresh_hash, expires_at=expires_at, device_info=request.headers.get("User-Agent", "Unknown"), ip_address=request.client.host if request.client else "Unknown", ) db.add(session) await db.commit() return TokenData(access_token=access_token, refresh_token=raw_refresh) @staticmethod async def refresh(db: AsyncSession, raw_refresh_token: str) -> TokenData: key = _token_key(raw_refresh_token) result = await db.execute( select(RefreshSession).where( RefreshSession.token_key == key, RefreshSession.revoked_at.is_(None), RefreshSession.expires_at > _now(), ) ) session = result.scalars().first() if not session: session = await db.execute( select(RefreshSession).where(RefreshSession.token_key == key) ) existing = session.scalars().first() if existing and existing.revoked_at is not None: await db.execute( update(RefreshSession) .where(RefreshSession.user_id == existing.user_id, RefreshSession.revoked_at.is_(None)) .values(revoked_at=_now()) ) await db.commit() raise HTTPException( status_code=401, detail="Session compromised. All sessions revoked. Please login again.", ) raise HTTPException(status_code=401, detail="Invalid or expired refresh token") if not AuthService._verify_token(raw_refresh_token, session.token_hash): raise HTTPException(status_code=401, detail="Invalid refresh token") user_result = await db.execute( select(User).where( User.id == session.user_id, User.deleted_at.is_(None), User.is_active.is_(True) ) ) user = user_result.scalars().first() if not user: raise HTTPException(status_code=401, detail="User not found or inactive") session.revoked_at = _now() new_access = AuthService._create_access_token(user.id) new_raw_refresh, new_hash, new_key, new_expires = AuthService._create_refresh_token() new_session = RefreshSession( user_id=user.id, token_key=new_key, token_hash=new_hash, expires_at=new_expires, device_info=session.device_info, ip_address=session.ip_address, ) db.add(new_session) user.last_login = _now() await db.commit() return TokenData(access_token=new_access, refresh_token=new_raw_refresh) @staticmethod async def logout(db: AsyncSession, user: User, raw_refresh_token: str): key = _token_key(raw_refresh_token) result = await db.execute( select(RefreshSession).where( RefreshSession.token_key == key, RefreshSession.user_id == user.id, RefreshSession.revoked_at.is_(None), ) ) session = result.scalars().first() if not session: raise HTTPException(status_code=404, detail="Session not found") session.revoked_at = _now() await db.commit() @staticmethod async def logout_all(db: AsyncSession, user: User): now = _now() await db.execute( update(RefreshSession) .where( RefreshSession.user_id == user.id, RefreshSession.revoked_at.is_(None), ) .values(revoked_at=now) ) await db.commit() @staticmethod async def change_password(db: AsyncSession, user: User, schema: ChangePasswordSchema): if not AuthService._verify_password(schema.current_password, user.password_hash): raise HTTPException(status_code=400, detail="Incorrect current password") user.password_hash = ph.hash(schema.new_password) user.password_changed_at = _now() await db.commit() @staticmethod async def forgot_password(db: AsyncSession, schema: ForgotPasswordSchema): result = await db.execute( select(User).where(User.email == schema.email, User.deleted_at.is_(None)) ) user = result.scalars().first() if user: reset_token = jwt.encode( { "sub": user.id, "type": "reset_password", "exp": _now() + timedelta(hours=1), }, _settings.jwt_secret_key, algorithm=_settings.jwt_algorithm, ) logger.info("Password reset token for %s: %s", user.email, reset_token) @staticmethod async def reset_password(db: AsyncSession, schema: ResetPasswordSchema): try: payload = jwt.decode( schema.token, _settings.jwt_secret_key, algorithms=[_settings.jwt_algorithm], ) if payload.get("type") != "reset_password": raise HTTPException(status_code=400, detail="Invalid token type") user_id = payload.get("sub") except jwt.PyJWTError: raise HTTPException(status_code=400, detail="Invalid or expired reset token") result = await db.execute(select(User).where(User.id == user_id)) user = result.scalars().first() if not user: raise HTTPException(status_code=404, detail="User not found") user.password_hash = ph.hash(schema.new_password) user.password_changed_at = _now() await db.execute( update(RefreshSession) .where(RefreshSession.user_id == user.id, RefreshSession.revoked_at.is_(None)) .values(revoked_at=_now()) ) await db.commit() @staticmethod async def update_profile(db: AsyncSession, user: User, schema: UpdateProfileSchema) -> User: if schema.username is not None: result = await db.execute( select(User).where(User.username == schema.username, User.id != user.id) ) if result.scalars().first(): raise HTTPException(status_code=409, detail="Username already taken") user.username = schema.username if schema.full_name is not None: user.full_name = schema.full_name await db.commit() await db.refresh(user) return user @staticmethod async def soft_delete(db: AsyncSession, user: User): user.deleted_at = _now() user.is_active = False await db.execute( update(RefreshSession) .where(RefreshSession.user_id == user.id, RefreshSession.revoked_at.is_(None)) .values(revoked_at=_now()) ) await db.commit() @staticmethod async def list_sessions(db: AsyncSession, user: User) -> list[RefreshSession]: result = await db.execute( select(RefreshSession) .where( RefreshSession.user_id == user.id, RefreshSession.revoked_at.is_(None), ) .order_by(RefreshSession.created_at.desc()) ) return list(result.scalars().all()) @staticmethod async def revoke_session(db: AsyncSession, user: User, session_id: str): result = await db.execute( select(RefreshSession).where( RefreshSession.id == session_id, RefreshSession.user_id == user.id, RefreshSession.revoked_at.is_(None), ) ) session = result.scalars().first() if not session: raise HTTPException(status_code=404, detail="Session not found") session.revoked_at = _now() await db.commit() @staticmethod async def cleanup_expired_sessions(db: AsyncSession): result = await db.execute( select(RefreshSession).where(RefreshSession.expires_at < _now()) ) count = 0 for s in result.scalars().all(): if s.revoked_at is None: s.revoked_at = _now() count += 1 if count: await db.commit() logger.info("Cleaned up %s expired sessions", count) @staticmethod def _create_access_token(user_id: str) -> str: now = _now() payload = { "sub": user_id, "type": "access", "iat": now, "exp": now + timedelta(minutes=_settings.access_token_expire_minutes), "iss": _settings.jwt_issuer, } return jwt.encode(payload, _settings.jwt_secret_key, algorithm=_settings.jwt_algorithm) @staticmethod def _create_refresh_token() -> tuple[str, str, str, datetime]: token = secrets.token_urlsafe(64) token_hash = ph.hash(token) token_key = _token_key(token) expires_at = _now() + timedelta(days=_settings.refresh_token_expire_days) return token, token_hash, token_key, expires_at @staticmethod def _verify_password(plain: str, hashed: str) -> bool: try: ph.verify(hashed, plain) return True except VerifyMismatchError: return False @staticmethod def _verify_token(raw: str, hashed: str) -> bool: try: ph.verify(hashed, raw) return True except VerifyMismatchError: return False @staticmethod def user_to_profile(user: User) -> UserProfile: return UserProfile( id=user.id, email=user.email, username=user.username, full_name=user.full_name, is_active=user.is_active, is_verified=user.is_verified, roles=[r.name for r in user.roles], created_at=user.created_at, )