| """ |
| infrastructure/database/repositories/sqlalchemy_base.py |
| ─────────────────────────────────────────────────────── |
| SQLAlchemyBaseRepository[T, M] — generic CRUD implementation (DRY). |
| |
| All concrete repositories inherit from this class, which provides the |
| four standard operations. Subclasses only need to implement: |
| • _to_entity(model) → domain entity |
| • _to_model(entity) → ORM model |
| |
| Error translation: |
| All SQLAlchemy / asyncpg errors are caught here and re-raised as domain |
| exceptions, ensuring the Application and Domain layers never receive or |
| import SQLAlchemy types: |
| • IntegrityError (unique / FK violation) → ConflictError |
| • OperationalError / DBAPIError → DatabaseError |
| • Any other SQLAlchemyError → DatabaseError |
| |
| Generic type parameters: |
| T — domain entity type (e.g. PPGSignal) |
| M — SQLAlchemy ORM model type (e.g. PPGModel) |
| """ |
| from __future__ import annotations |
|
|
| from abc import abstractmethod |
| from typing import Generic, Optional, Type, TypeVar |
|
|
| from sqlalchemy import select |
| from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError |
| from sqlalchemy.ext.asyncio import AsyncSession |
|
|
| from src.domain.exceptions.domain_exceptions import ConflictError, DatabaseError |
| from src.domain.interfaces.repositories.base_repository import BaseRepository |
| from src.infrastructure.database.models.base import Base |
| from src.shared.logger import get_logger |
|
|
# Type parameter for the domain entity a repository works with (unconstrained).
| T = TypeVar("T") |
# Type parameter for the ORM model; bounded to subclasses of the declarative Base.
| M = TypeVar("M", bound=Base) |
|
|
# Module-level logger, obtained from the shared get_logger factory with this module's __name__.
| logger = get_logger(__name__) |
|
|
|
|
class SQLAlchemyBaseRepository(BaseRepository[T], Generic[T, M]):
    """
    Generic async CRUD repository backed by SQLAlchemy.

    Subclasses must:
      1. Set ``_model_class`` to the ORM model class.
      2. Implement ``_to_entity(model: M) -> T``.
      3. Implement ``_to_model(entity: T) -> M``.

    Error translation:
      Every public method catches ``SQLAlchemyError`` (and its subclasses) and
      re-raises a domain exception (``ConflictError`` or ``DatabaseError``)
      chained to the original error. The shared translation logic lives in
      ``_driver_detail`` and ``_as_database_error``.
    """

    # ORM model class this repository operates on; concrete subclasses set it.
    _model_class: Type[M]

    def __init__(self, session: AsyncSession) -> None:
        """Store the async session used for every operation of this repository."""
        self._session = session

    # ------------------------------------------------------------------
    # Mapping hooks (implemented by subclasses)
    # ------------------------------------------------------------------

    @abstractmethod
    def _to_entity(self, model: M) -> T:
        """Convert an ORM model instance to a domain entity."""
        ...

    @abstractmethod
    def _to_model(self, entity: T) -> M:
        """Convert a domain entity to an ORM model instance."""
        ...

    # ------------------------------------------------------------------
    # Error-translation helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _driver_detail(exc: SQLAlchemyError) -> str:
        """Return the wrapped driver error's text if present, else ``str(exc)``."""
        orig = getattr(exc, "orig", None)
        return str(orig) if orig else str(exc)

    def _as_database_error(self, operation: str, exc: SQLAlchemyError) -> DatabaseError:
        """
        Log *exc* at ERROR level and build the ``DatabaseError`` to raise for it.

        ``OperationalError`` is reported with the underlying driver message;
        any other ``SQLAlchemyError`` is reported with ``str(exc)``.
        The caller is responsible for rolling back (where appropriate) and for
        raising the returned exception ``from exc``.
        """
        model_name = self._model_class.__name__
        if isinstance(exc, OperationalError):
            label = "OperationalError"
            reason = self._driver_detail(exc)
        else:
            label = "SQLAlchemyError"
            reason = str(exc)
        logger.error("%s in %s() [%s]: %s", label, operation, model_name, reason)
        return DatabaseError(operation=operation, reason=reason)

    # ------------------------------------------------------------------
    # CRUD operations
    # ------------------------------------------------------------------

    async def add(self, entity: T) -> T:
        """
        Persist a new entity and return it (with server-generated fields).

        The model is added, flushed, and then refreshed from the database
        before being mapped back to a domain entity.

        Raises:
            ConflictError: the flush/refresh raised ``IntegrityError``.
            DatabaseError: any other ``SQLAlchemyError`` was raised.

        The session is rolled back before either exception is raised.
        """
        model = self._to_model(entity)
        try:
            self._session.add(model)
            await self._session.flush()
            await self._session.refresh(model)
        except IntegrityError as exc:
            await self._session.rollback()
            entity_name = self._model_class.__name__
            detail = self._driver_detail(exc)
            logger.warning("IntegrityError in add() [%s]: %s", entity_name, detail)
            raise ConflictError(
                entity_type=entity_name,
                detail=f"A record with the same unique key already exists. ({detail})",
            ) from exc
        except SQLAlchemyError as exc:
            await self._session.rollback()
            raise self._as_database_error("add", exc) from exc

        # NOTE(review): the "β" below looks like a mis-encoded arrow/dash; it is
        # runtime log text, so it is left untouched here — confirm and fix at source.
        logger.debug("add() β %s id=%s", self._model_class.__name__, model.id)
        return self._to_entity(model)

    async def get_by_id(self, entity_id: str) -> Optional[T]:
        """
        Return the entity with the given ID, or ``None`` if no row matches.

        Raises:
            DatabaseError: any ``SQLAlchemyError`` was raised by the lookup.

        No rollback is issued on failure (unlike ``add`` / ``delete``).
        """
        try:
            found = await self._session.get(self._model_class, entity_id)
        except SQLAlchemyError as exc:
            raise self._as_database_error("get_by_id", exc) from exc

        return None if found is None else self._to_entity(found)

    async def get_all(self, limit: int = 50, offset: int = 0) -> list[T]:
        """
        Return one page of entities, ordered by ``created_at`` descending.

        Args:
            limit: maximum number of rows to return.
            offset: number of rows to skip before the page starts.

        Raises:
            DatabaseError: any ``SQLAlchemyError`` was raised by the query.

        No rollback is issued on failure (unlike ``add`` / ``delete``).
        """
        # Relies on the model class exposing a ``created_at`` column.
        stmt = (
            select(self._model_class)
            .order_by(self._model_class.created_at.desc())
            .limit(limit)
            .offset(offset)
        )
        try:
            result = await self._session.execute(stmt)
        except SQLAlchemyError as exc:
            raise self._as_database_error("get_all", exc) from exc

        return [self._to_entity(row) for row in result.scalars().all()]

    async def delete(self, entity_id: str) -> bool:
        """
        Delete an entity by ID.

        Returns:
            ``True`` if the row was found and deleted (flushed), ``False`` if
            no row with that ID exists.

        Raises:
            DatabaseError: any ``SQLAlchemyError`` was raised; the session is
                rolled back first.
        """
        try:
            model = await self._session.get(self._model_class, entity_id)
            if model is None:
                return False
            await self._session.delete(model)
            await self._session.flush()
        except SQLAlchemyError as exc:
            # NOTE(review): an IntegrityError raised here (e.g. FK violation on
            # delete) is reported as DatabaseError, whereas the module header
            # documents IntegrityError -> ConflictError. Left as-is so callers'
            # exception handling does not change — confirm intended behaviour.
            await self._session.rollback()
            raise self._as_database_error("delete", exc) from exc

        # NOTE(review): "β" is likely a mis-encoded character; runtime text kept as-is.
        logger.debug("delete() β %s id=%s", self._model_class.__name__, entity_id)
        return True
|
|