llm-ready-data / app /services /auth_service.py
light-infer-chat's picture
ok
08919be
Raw
History Blame Contribute Delete
13.4 kB
from __future__ import annotations

import hashlib
import logging
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional

import jwt
from argon2 import PasswordHasher
from argon2.exceptions import InvalidHashError, VerificationError, VerifyMismatchError
from fastapi import HTTPException, Request, status
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import get_settings
from app.core.auth.models import RefreshSession, Role, User
from app.core.auth.schemas import (
    ChangePasswordSchema,
    ForgotPasswordSchema,
    LoginSchema,
    RegisterSchema,
    ResetPasswordSchema,
    TokenData,
    UpdateProfileSchema,
    UserProfile,
)
# Module-wide singletons shared by every function and AuthService method below.
logger: logging.Logger = logging.getLogger("auth_service")  # fixed name, not __name__
_settings = get_settings()  # application settings, read once at import time
ph: PasswordHasher = PasswordHasher()  # Argon2 hasher constructed with library defaults
def _now() -> datetime:
return datetime.now(timezone.utc)
def _token_key(raw: str) -> str:
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
class AuthService:
    """Authentication, account-management and refresh-session operations.

    Every method is a ``@staticmethod`` that receives the ``AsyncSession`` to
    work with; methods that modify data commit before returning. Failures are
    reported to the API layer by raising ``fastapi.HTTPException``.
    """

    @staticmethod
    async def register(db: AsyncSession, schema: RegisterSchema) -> User:
        """Create a user and attach the ``"User"`` role when that role exists.

        The password is stored as an Argon2 hash.

        Returns:
            The committed and refreshed ``User``.

        Raises:
            HTTPException: 409 if the email, or the optional username, is taken.
        """
        # NOTE(review): check-then-insert can race with a concurrent sign-up;
        # presumably unique constraints on these columns catch that — confirm.
        result = await db.execute(select(User).where(User.email == schema.email))
        if result.scalars().first():
            raise HTTPException(status_code=409, detail="Email already registered")
        if schema.username:
            result = await db.execute(select(User).where(User.username == schema.username))
            if result.scalars().first():
                raise HTTPException(status_code=409, detail="Username already taken")
        user = User(
            email=schema.email,
            username=schema.username,
            full_name=schema.full_name,
            password_hash=ph.hash(schema.password),
        )
        role_result = await db.execute(select(Role).where(Role.name == "User"))
        default_role = role_result.scalars().first()
        if default_role:
            user.roles.append(default_role)
        db.add(user)
        await db.commit()
        await db.refresh(user)
        return user

    @staticmethod
    async def login(db: AsyncSession, request: Request, schema: LoginSchema) -> TokenData:
        """Verify email/password and open a new refresh session.

        Each wrong password increments ``failed_login_attempts``. Once the count
        reaches ``max_login_attempts``, the account is locked for
        ``lockout_minutes``. On success the counter and lock are cleared,
        ``last_login`` is updated, and a ``RefreshSession`` is stored. The session
        records the request's User-Agent and client IP.

        Returns:
            ``TokenData`` with a new access token and the raw refresh token.

        Raises:
            HTTPException: 401 for an unknown, soft-deleted or inactive account or
                a wrong password; 403 while the account is locked.
        """
        result = await db.execute(
            select(User).where(
                User.email == schema.email,
                User.deleted_at.is_(None),
                # Same rule as refresh(): inactive accounts must not obtain tokens
                # (previously they got an access token they could never refresh).
                User.is_active.is_(True),
            )
        )
        user = result.scalars().first()
        if not user:
            raise HTTPException(status_code=401, detail="Incorrect email or password")
        if user.locked_until and user.locked_until > _now():
            raise HTTPException(
                status_code=403,
                detail=f"Account locked until {user.locked_until.isoformat()}",
            )
        if not AuthService._verify_password(schema.password, user.password_hash):
            # NOTE(review): the counter is not reset when a lock expires, so the
            # next wrong password re-locks immediately — confirm this is intended.
            user.failed_login_attempts += 1
            if user.failed_login_attempts >= _settings.max_login_attempts:
                user.locked_until = _now() + timedelta(minutes=_settings.lockout_minutes)
            await db.commit()
            raise HTTPException(status_code=401, detail="Incorrect email or password")
        user.failed_login_attempts = 0
        user.locked_until = None
        user.last_login = _now()
        access_token = AuthService._create_access_token(user.id)
        raw_refresh, refresh_hash, token_key, expires_at = AuthService._create_refresh_token()
        session = RefreshSession(
            user_id=user.id,
            token_key=token_key,
            token_hash=refresh_hash,
            expires_at=expires_at,
            device_info=request.headers.get("User-Agent", "Unknown"),
            ip_address=request.client.host if request.client else "Unknown",
        )
        db.add(session)
        await db.commit()
        return TokenData(access_token=access_token, refresh_token=raw_refresh)

    @staticmethod
    async def refresh(db: AsyncSession, raw_refresh_token: str) -> TokenData:
        """Rotate a refresh token: revoke its session and issue a new token pair.

        Reuse detection: if the presented token belongs to a session that was
        revoked while it was still unexpired (rotated earlier, or logged out),
        every active session of that user is revoked.

        Returns:
            ``TokenData`` with a new access token and a new raw refresh token.

        Raises:
            HTTPException: 401 for an unknown, expired, revoked or non-matching
                token, or when the owning user is deleted or inactive.
        """
        now = _now()
        key = _token_key(raw_refresh_token)
        result = await db.execute(
            select(RefreshSession).where(
                RefreshSession.token_key == key,
                RefreshSession.revoked_at.is_(None),
                RefreshSession.expires_at > now,
            )
        )
        session = result.scalars().first()
        if not session:
            # Only *unexpired* revoked sessions count as reuse. The
            # cleanup_expired_sessions() method also sets revoked_at on expired
            # sessions, so without this filter, replaying a merely expired token
            # would log the user out on every device.
            reused_result = await db.execute(
                select(RefreshSession).where(
                    RefreshSession.token_key == key,
                    RefreshSession.revoked_at.is_not(None),
                    RefreshSession.expires_at > now,
                )
            )
            reused = reused_result.scalars().first()
            if reused is not None:
                await db.execute(
                    update(RefreshSession)
                    .where(
                        RefreshSession.user_id == reused.user_id,
                        RefreshSession.revoked_at.is_(None),
                    )
                    .values(revoked_at=_now())
                )
                await db.commit()
                raise HTTPException(
                    status_code=401,
                    detail="Session compromised. All sessions revoked. Please login again.",
                )
            raise HTTPException(status_code=401, detail="Invalid or expired refresh token")
        if not AuthService._verify_token(raw_refresh_token, session.token_hash):
            raise HTTPException(status_code=401, detail="Invalid refresh token")
        user_result = await db.execute(
            select(User).where(
                User.id == session.user_id, User.deleted_at.is_(None), User.is_active.is_(True)
            )
        )
        user = user_result.scalars().first()
        if not user:
            raise HTTPException(status_code=401, detail="User not found or inactive")
        session.revoked_at = _now()
        new_access = AuthService._create_access_token(user.id)
        new_raw_refresh, new_hash, new_key, new_expires = AuthService._create_refresh_token()
        new_session = RefreshSession(
            user_id=user.id,
            token_key=new_key,
            token_hash=new_hash,
            expires_at=new_expires,
            # The rotated session inherits the device/IP recorded at login.
            device_info=session.device_info,
            ip_address=session.ip_address,
        )
        db.add(new_session)
        user.last_login = _now()
        await db.commit()
        return TokenData(access_token=new_access, refresh_token=new_raw_refresh)

    @staticmethod
    async def logout(db: AsyncSession, user: User, raw_refresh_token: str):
        """Revoke the caller's active session identified by ``raw_refresh_token``.

        Raises:
            HTTPException: 404 if no unrevoked session of ``user`` has that token.
        """
        key = _token_key(raw_refresh_token)
        result = await db.execute(
            select(RefreshSession).where(
                RefreshSession.token_key == key,
                RefreshSession.user_id == user.id,
                RefreshSession.revoked_at.is_(None),
            )
        )
        session = result.scalars().first()
        if not session:
            raise HTTPException(status_code=404, detail="Session not found")
        session.revoked_at = _now()
        await db.commit()

    @staticmethod
    async def logout_all(db: AsyncSession, user: User):
        """Revoke every unrevoked refresh session belonging to ``user``."""
        now = _now()
        await db.execute(
            update(RefreshSession)
            .where(
                RefreshSession.user_id == user.id,
                RefreshSession.revoked_at.is_(None),
            )
            .values(revoked_at=now)
        )
        await db.commit()

    @staticmethod
    async def change_password(db: AsyncSession, user: User, schema: ChangePasswordSchema):
        """Replace ``user``'s password after verifying the current one.

        Existing refresh sessions are left untouched. Outstanding reset tokens
        stop working, because they are bound to the old password hash.

        Raises:
            HTTPException: 400 if ``current_password`` does not match.
        """
        if not AuthService._verify_password(schema.current_password, user.password_hash):
            raise HTTPException(status_code=400, detail="Incorrect current password")
        user.password_hash = ph.hash(schema.new_password)
        user.password_changed_at = _now()
        await db.commit()

    @staticmethod
    async def forgot_password(db: AsyncSession, schema: ForgotPasswordSchema):
        """Create a one-hour password-reset JWT for a non-deleted user.

        For an unknown email, nothing is returned and nothing is raised, so the
        caller can respond identically either way. The token carries a
        fingerprint (``pwf``) of the current password hash. That makes it stop
        working as soon as the password changes, including through the reset
        it was issued for.
        """
        result = await db.execute(
            select(User).where(User.email == schema.email, User.deleted_at.is_(None))
        )
        user = result.scalars().first()
        if user is None:
            return
        reset_token = jwt.encode(
            {
                "sub": user.id,
                "type": "reset_password",
                "pwf": AuthService._password_fingerprint(user.password_hash),
                "exp": _now() + timedelta(hours=1),
            },
            _settings.jwt_secret_key,
            algorithm=_settings.jwt_algorithm,
        )
        # NOTE(review): this writes a live reset credential to the application log.
        # It appears to stand in for e-mail delivery; replace it with a real
        # delivery channel (and drop the token from the log) before production.
        logger.info("Password reset token for %s: %s", user.email, reset_token)

    @staticmethod
    async def reset_password(db: AsyncSession, schema: ResetPasswordSchema):
        """Set a new password from a reset token and revoke all refresh sessions.

        Raises:
            HTTPException: 400 if the token is invalid, expired, of the wrong type,
                or was issued for a password that has since changed; 404 if the
                user no longer exists or is soft-deleted.
        """
        try:
            payload = jwt.decode(
                schema.token,
                _settings.jwt_secret_key,
                algorithms=[_settings.jwt_algorithm],
            )
        except jwt.PyJWTError:
            raise HTTPException(status_code=400, detail="Invalid or expired reset token")
        if payload.get("type") != "reset_password":
            raise HTTPException(status_code=400, detail="Invalid token type")
        user_id = payload.get("sub")
        result = await db.execute(
            select(User).where(User.id == user_id, User.deleted_at.is_(None))
        )
        user = result.scalars().first()
        if not user:
            raise HTTPException(status_code=404, detail="User not found")
        # Single-use: the fingerprint no longer matches once the password hash
        # changes. Tokens without the claim (issued by older code) are rejected.
        if payload.get("pwf") != AuthService._password_fingerprint(user.password_hash):
            raise HTTPException(status_code=400, detail="Invalid or expired reset token")
        user.password_hash = ph.hash(schema.new_password)
        user.password_changed_at = _now()
        await db.execute(
            update(RefreshSession)
            .where(RefreshSession.user_id == user.id, RefreshSession.revoked_at.is_(None))
            .values(revoked_at=_now())
        )
        await db.commit()

    @staticmethod
    async def update_profile(db: AsyncSession, user: User, schema: UpdateProfileSchema) -> User:
        """Apply the non-``None`` fields of ``schema`` to ``user``.

        Returns:
            The committed and refreshed ``User``.

        Raises:
            HTTPException: 409 if the new username belongs to another user.
        """
        if schema.username is not None:
            result = await db.execute(
                select(User).where(User.username == schema.username, User.id != user.id)
            )
            if result.scalars().first():
                raise HTTPException(status_code=409, detail="Username already taken")
            user.username = schema.username
        if schema.full_name is not None:
            user.full_name = schema.full_name
        await db.commit()
        await db.refresh(user)
        return user

    @staticmethod
    async def soft_delete(db: AsyncSession, user: User):
        """Mark ``user`` deleted and inactive, and revoke all of their sessions."""
        user.deleted_at = _now()
        user.is_active = False
        await db.execute(
            update(RefreshSession)
            .where(RefreshSession.user_id == user.id, RefreshSession.revoked_at.is_(None))
            .values(revoked_at=_now())
        )
        await db.commit()

    @staticmethod
    async def list_sessions(db: AsyncSession, user: User) -> list[RefreshSession]:
        """Return ``user``'s unrevoked refresh sessions, newest first.

        Expired sessions are included until cleanup_expired_sessions() revokes
        them, because this query does not filter on ``expires_at``.
        """
        result = await db.execute(
            select(RefreshSession)
            .where(
                RefreshSession.user_id == user.id,
                RefreshSession.revoked_at.is_(None),
            )
            .order_by(RefreshSession.created_at.desc())
        )
        return list(result.scalars().all())

    @staticmethod
    async def revoke_session(db: AsyncSession, user: User, session_id: str):
        """Revoke one of ``user``'s unrevoked sessions by its id.

        Raises:
            HTTPException: 404 if no such unrevoked session belongs to ``user``.
        """
        result = await db.execute(
            select(RefreshSession).where(
                RefreshSession.id == session_id,
                RefreshSession.user_id == user.id,
                RefreshSession.revoked_at.is_(None),
            )
        )
        session = result.scalars().first()
        if not session:
            raise HTTPException(status_code=404, detail="Session not found")
        session.revoked_at = _now()
        await db.commit()

    @staticmethod
    async def cleanup_expired_sessions(db: AsyncSession):
        """Mark every expired, still-unrevoked refresh session as revoked.

        Uses one bulk UPDATE rather than loading every expired row (including
        ones already revoked) into memory and updating them one by one.
        """
        now = _now()
        result = await db.execute(
            update(RefreshSession)
            .where(RefreshSession.expires_at < now, RefreshSession.revoked_at.is_(None))
            .values(revoked_at=now)
        )
        count = result.rowcount
        await db.commit()
        logger.info("Cleaned up %s expired sessions", count)

    @staticmethod
    def _create_access_token(user_id: str) -> str:
        """Return a signed access JWT for ``user_id``.

        The token carries ``sub``, ``type="access"``, ``iat``, ``exp`` and ``iss``
        claims and expires after ``access_token_expire_minutes``.
        """
        now = _now()
        payload = {
            "sub": user_id,
            "type": "access",
            "iat": now,
            "exp": now + timedelta(minutes=_settings.access_token_expire_minutes),
            "iss": _settings.jwt_issuer,
        }
        return jwt.encode(payload, _settings.jwt_secret_key, algorithm=_settings.jwt_algorithm)

    @staticmethod
    def _create_refresh_token() -> tuple[str, str, str, datetime]:
        """Generate a new refresh token.

        Returns:
            ``(raw_token, argon2_hash, lookup_key, expires_at)``. Only the hash
            and the SHA-256 lookup key are stored in ``RefreshSession``; the raw
            token goes back to the client.
        """
        token = secrets.token_urlsafe(64)
        token_hash = ph.hash(token)
        token_key = _token_key(token)
        expires_at = _now() + timedelta(days=_settings.refresh_token_expire_days)
        return token, token_hash, token_key, expires_at

    @staticmethod
    def _password_fingerprint(password_hash: str) -> str:
        """Return a short digest of a stored password hash for reset tokens.

        It is embedded in reset JWTs. Any password change produces a new Argon2
        hash (fresh salt), so older reset tokens no longer match.
        """
        return _token_key(password_hash)[:32]

    @staticmethod
    def _argon2_matches(hashed: str, candidate: str) -> bool:
        """Return whether ``candidate`` verifies against the Argon2 ``hashed`` value.

        A plain mismatch returns ``False``. A stored hash that is malformed, or
        that fails verification for another reason, is logged and also treated
        as a non-match, so callers answer with their normal 4xx instead of an
        unhandled 500.
        """
        try:
            ph.verify(hashed, candidate)
        except VerifyMismatchError:
            return False
        except (InvalidHashError, VerificationError) as exc:
            logger.warning("Stored Argon2 hash could not be verified (%s)", type(exc).__name__)
            return False
        return True

    @staticmethod
    def _verify_password(plain: str, hashed: str) -> bool:
        """Return whether ``plain`` matches the stored password hash."""
        return AuthService._argon2_matches(hashed, plain)

    @staticmethod
    def _verify_token(raw: str, hashed: str) -> bool:
        """Return whether ``raw`` matches a stored refresh-token hash."""
        return AuthService._argon2_matches(hashed, raw)

    @staticmethod
    def user_to_profile(user: User) -> UserProfile:
        """Build the public ``UserProfile`` view of ``user``, including role names."""
        return UserProfile(
            id=user.id,
            email=user.email,
            username=user.username,
            full_name=user.full_name,
            is_active=user.is_active,
            is_verified=user.is_verified,
            roles=[r.name for r in user.roles],
            created_at=user.created_at,
        )