LIBRE / src /infrastructure /database /repositories /sqlalchemy_base.py
RyZ
feat: adding full working local ETL Pipeline
e391a84
Raw
History Blame Contribute Delete
7.06 kB
"""
infrastructure/database/repositories/sqlalchemy_base.py
─────────────────────────────────────────────────────────
SQLAlchemyBaseRepository[T, M] — generic CRUD implementation (DRY).
All concrete repositories inherit from this class, which provides the
four standard operations (add, get_by_id, get_all, delete). Subclasses
only need to set ``_model_class`` to their ORM model class and implement:
  • _to_entity(model) → domain entity
  • _to_model(entity) → ORM model
Error translation:
All SQLAlchemy errors (including asyncpg driver errors, which SQLAlchemy
wraps in DBAPIError subclasses) are caught here and re-raised as domain
exceptions, ensuring the Application and Domain layers never receive or
import SQLAlchemy types:
  • IntegrityError (unique / FK violation) → ConflictError
  • OperationalError / DBAPIError          → DatabaseError
  • Any other SQLAlchemyError              → DatabaseError
Generic type parameters:
  T — domain entity type (e.g. PPGSignal)
  M — SQLAlchemy ORM model type (e.g. PPGModel)
"""
from __future__ import annotations
from abc import abstractmethod
from typing import Generic, Optional, Type, TypeVar
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from src.domain.exceptions.domain_exceptions import ConflictError, DatabaseError
from src.domain.interfaces.repositories.base_repository import BaseRepository
from src.infrastructure.database.models.base import Base
from src.shared.logger import get_logger
logger = get_logger(__name__)

# Type parameters of the generic repository defined below.
T = TypeVar("T")  # domain entity type (unconstrained)
M = TypeVar("M", bound=Base)  # ORM model type; must derive from the declarative Base
class SQLAlchemyBaseRepository(BaseRepository[T], Generic[T, M]):
    """
    Generic async CRUD repository backed by SQLAlchemy.

    Subclasses must:
      1. Set ``_model_class`` to the ORM model class.
      2. Implement ``_to_entity(model: M) → T``.
      3. Implement ``_to_model(entity: T) → M``.

    Every CRUD method translates SQLAlchemy exceptions into domain
    exceptions via ``_translate_error``:
      • IntegrityError        → ConflictError
      • OperationalError      → DatabaseError (reason = DBAPI message)
      • other SQLAlchemyError → DatabaseError
    The write methods (``add``, ``delete``) additionally roll the session
    back before re-raising.
    """

    _model_class: Type[M]

    def __init__(self, session: AsyncSession) -> None:
        """Bind the repository to *session*. This class flushes but never calls ``commit()``."""
        self._session = session

    # ── Abstract Mapping Methods ──────────────────────────────────────────────

    @abstractmethod
    def _to_entity(self, model: M) -> T:
        """Convert an ORM model instance to a domain entity."""
        ...

    @abstractmethod
    def _to_model(self, entity: T) -> M:
        """Convert a domain entity to an ORM model instance."""
        ...

    # ── Error Handling Helpers ────────────────────────────────────────────────

    async def _rollback_quietly(self, operation: str) -> None:
        """
        Roll back the session after a failed write.

        If the rollback itself fails (e.g. the connection is already gone),
        that failure is logged and swallowed. This way the caller still
        receives the translated domain exception for the ORIGINAL error
        instead of a raw SQLAlchemy exception leaking out of this layer.
        """
        try:
            await self._session.rollback()
        except SQLAlchemyError as rollback_exc:
            logger.error(
                "Rollback failed after error in %s() [%s]: %s",
                operation,
                self._model_class.__name__,
                rollback_exc,
            )

    def _translate_error(
        self,
        operation: str,
        exc: SQLAlchemyError,
        *,
        conflict_detail: str = "A database integrity constraint was violated.",
    ) -> Exception:
        """
        Log *exc* and return the domain exception that should replace it.

        Args:
            operation: Name of the repository method; used in log lines and
                as ``DatabaseError.operation``.
            exc: The SQLAlchemy exception that was caught.
            conflict_detail: Human-readable prefix of the ``ConflictError``
                detail, used only when *exc* is an ``IntegrityError``.

        Returns:
            A ``ConflictError`` for ``IntegrityError``, otherwise a
            ``DatabaseError``. The caller raises it with ``from exc`` so the
            original exception stays chained for debugging.
        """
        entity_name = self._model_class.__name__
        if isinstance(exc, IntegrityError):
            # ``orig`` holds the underlying DBAPI exception, whose message is
            # usually more specific than SQLAlchemy's wrapper text.
            detail = str(exc.orig) if exc.orig else str(exc)
            logger.warning("IntegrityError in %s() [%s]: %s", operation, entity_name, detail)
            return ConflictError(
                entity_type=entity_name,
                detail=f"{conflict_detail} ({detail})",
            )
        if isinstance(exc, OperationalError):
            detail = str(exc.orig) if exc.orig else str(exc)
            logger.error("OperationalError in %s() [%s]: %s", operation, entity_name, detail)
            return DatabaseError(operation=operation, reason=detail)
        logger.error("SQLAlchemyError in %s() [%s]: %s", operation, entity_name, exc)
        return DatabaseError(operation=operation, reason=str(exc))

    # ── BaseRepository Implementation ─────────────────────────────────────────

    async def add(self, entity: T) -> T:
        """
        Persist a new entity and return it (with server-generated fields).

        Raises:
            ConflictError: an integrity constraint (e.g. a unique key) is violated.
            DatabaseError: any other database failure.
        """
        model = self._to_model(entity)
        try:
            self._session.add(model)
            await self._session.flush()  # write to DB within current transaction
            await self._session.refresh(model)  # load server-generated columns
        except SQLAlchemyError as exc:
            await self._rollback_quietly("add")
            raise self._translate_error(
                "add",
                exc,
                conflict_detail="A record with the same unique key already exists.",
            ) from exc
        logger.debug("add() — %s id=%s", self._model_class.__name__, model.id)
        return self._to_entity(model)

    async def get_by_id(self, entity_id: str) -> Optional[T]:
        """
        Return the entity with the given primary key, or None if absent.

        Raises:
            ConflictError: an autoflush of pending changes hit an integrity error.
            DatabaseError: any other database failure.
        """
        try:
            model = await self._session.get(self._model_class, entity_id)
        except SQLAlchemyError as exc:
            raise self._translate_error("get_by_id", exc) from exc
        return None if model is None else self._to_entity(model)

    async def get_all(self, limit: int = 50, offset: int = 0) -> list[T]:
        """
        Return up to *limit* entities, newest first, skipping the first *offset*.

        Ordering is by ``created_at`` descending, so the ORM model must define
        a ``created_at`` column.

        Raises:
            ConflictError: an autoflush of pending changes hit an integrity error.
            DatabaseError: any other database failure.
        """
        stmt = (
            select(self._model_class)
            .order_by(self._model_class.created_at.desc())
            .limit(limit)
            .offset(offset)
        )
        try:
            result = await self._session.execute(stmt)
        except SQLAlchemyError as exc:
            raise self._translate_error("get_all", exc) from exc
        return [self._to_entity(m) for m in result.scalars().all()]

    async def delete(self, entity_id: str) -> bool:
        """
        Delete an entity by primary key.

        Returns:
            True if a row was found and deleted, False if no row has that key.

        Raises:
            ConflictError: the delete violates an integrity constraint (e.g.
                the row is still referenced through a foreign key).
            DatabaseError: any other database failure.
        """
        try:
            model = await self._session.get(self._model_class, entity_id)
            if model is None:
                return False
            await self._session.delete(model)
            await self._session.flush()
        except SQLAlchemyError as exc:
            await self._rollback_quietly("delete")
            # IntegrityError used to fall through to DatabaseError here, which
            # contradicted the documented IntegrityError → ConflictError mapping.
            raise self._translate_error(
                "delete",
                exc,
                conflict_detail=(
                    "Deleting this record violates an integrity constraint "
                    "(it may still be referenced by other records)."
                ),
            ) from exc
        logger.debug("delete() — %s id=%s", self._model_class.__name__, entity_id)
        return True