firermsdata-agent / src /entities /base /_repository.py
Aryan Jain
feat: Initialize FastAPI project with database configuration and CRUD structure
047d92b
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Optional, Type, TypeVar
from sqlalchemy import (
JSON,
BigInteger,
Date,
Integer,
Numeric,
String,
Text,
asc,
desc,
func,
)
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.future import select
from sqlalchemy.sql import Select
from src.configs import DatabaseConfig
from ._model import BaseModel_
# Type variable bound to BaseModel_; used below to annotate the model class
# handed to the repository and the instances it persists.
ModelT = TypeVar("ModelT", bound=BaseModel_)
class BaseRepository:
    """Async CRUD helper bound to a single SQLAlchemy model class.

    Every public coroutine opens its own session through ``get_session`` and
    releases it before returning.

    NOTE: the ``list`` method and the ``object`` / ``id`` parameters shadow
    builtins. They are kept as-is because they are part of the public
    interface that callers may use by keyword.
    """

    def __init__(self, model: Type[ModelT]):
        """Remember the model class every operation of this repository targets."""
        self.model = model

    def _column_names(self) -> set:
        """Return the keys of all columns on the model's table."""
        return {column.key for column in self.model.__table__.columns}

    def _clean_filter_by(self, filter_by: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Drop filters that do not name a column, then coerce the rest."""
        column_names = self._column_names()
        known = {k: v for k, v in (filter_by or {}).items() if k in column_names}
        return self._convert_filter_by(known)

    def _convert_filter_by(self, filter_by: Dict[str, Any]) -> Dict[str, Any]:
        """Coerce raw filter values to the Python type of their column.

        Args:
            filter_by: Mapping of column key to raw value. Keys must already
                be valid column keys of the model; ``list`` and ``count``
                filter them before calling this method.

        Returns:
            A new dict with the same keys and converted values. Values for
            column types not handled here are passed through unchanged.

        Raises:
            ValueError: If a value cannot be converted to the column's type.
        """
        # Function-scope imports (as in the original), hoisted out of the loop.
        import json
        from datetime import datetime

        columns = self.model.__table__.columns
        converted: Dict[str, Any] = {}
        for k, v in filter_by.items():
            if v is None:
                # Keep None so the filter can express a NULL match instead of
                # becoming the string "None" or crashing in int()/float().
                converted[k] = None
                continue

            col_type = columns[k].type
            if isinstance(col_type, (BigInteger, Integer)):
                try:
                    converted[k] = int(v)
                except (TypeError, ValueError) as exc:
                    raise ValueError(f"Invalid integer value for {k}: {v}") from exc
            elif isinstance(col_type, (String, Text)):
                converted[k] = str(v)
            elif isinstance(col_type, Numeric):
                try:
                    converted[k] = float(v)
                except (TypeError, ValueError) as exc:
                    raise ValueError(f"Invalid numeric value for {k}: {v}") from exc
            elif isinstance(col_type, Date):
                try:
                    converted[k] = datetime.strptime(str(v), "%Y-%m-%d").date()
                except ValueError as exc:
                    raise ValueError(f"Invalid date value for {k}: {v}") from exc
            elif isinstance(col_type, JSON):
                if isinstance(v, (str, bytes, bytearray)):
                    try:
                        converted[k] = json.loads(v)
                    except json.JSONDecodeError as exc:
                        raise ValueError(f"Invalid JSON value for {k}: {v}") from exc
                else:
                    # Already a decoded value (dict, list, number, bool):
                    # str() would produce a Python repr that is not JSON.
                    converted[k] = v
            else:
                converted[k] = v
        return converted

    @asynccontextmanager
    async def get_session(self):
        """Yield a session obtained from ``DatabaseConfig.async_session()``.

        On ``SQLAlchemyError`` the session is rolled back and the original
        error is re-raised. A failure of the rollback itself is logged and
        does not replace the original error. The session is closed on exit
        in every case.
        """
        import logging

        async with DatabaseConfig.async_session() as session:
            try:
                yield session
            except SQLAlchemyError:
                try:
                    await session.rollback()
                except Exception:
                    # Best effort: the original error is the one callers need.
                    logging.getLogger(__name__).exception("Rollback failed")
                # Bare raise keeps the original traceback intact.
                raise
            finally:
                await session.close()

    async def create(self, object: ModelT) -> ModelT:
        """Add ``object`` to a session, commit, refresh it and return it."""
        async with self.get_session() as session:
            session.add(object)
            await session.commit()
            await session.refresh(object)
            return object

    async def list(
        self,
        page: int = 1,
        page_size: int = 10,
        order_by: Optional[List[str]] = None,
        filter_by: Optional[Dict[str, Any]] = None,
    ):
        """Return one page of model instances.

        Args:
            page: 1-based page number.
            page_size: Maximum number of rows to return.
            order_by: Column keys to sort by; a single leading ``-`` requests
                descending order. Entries that do not name a column are
                ignored.
            filter_by: Equality filters keyed by column key. Unknown keys are
                ignored; values are coerced by ``_convert_filter_by``.

        Returns:
            A list of model instances.

        Raises:
            ValueError: If a filter value cannot be converted.
        """
        column_names = self._column_names()
        filters = self._clean_filter_by(filter_by)

        # Parse the "-" prefix once so the whitelist check and the attribute
        # lookup always agree (previously "--name" passed the check and then
        # failed in getattr with "-name").
        ordering = []
        for field in order_by or []:
            descending = field.startswith("-")
            name = field[1:] if descending else field
            if name not in column_names:
                continue
            column = getattr(self.model, name)
            ordering.append(desc(column) if descending else asc(column))

        offset: int = (page - 1) * page_size
        async with self.get_session() as session:
            query: Select[Any] = select(self.model).filter_by(**filters)
            if ordering:
                query = query.order_by(*ordering)
            query = query.offset(offset).limit(page_size)
            result = await session.execute(query)
            # `list` here is the builtin: method bodies do not see class scope.
            return list(result.scalars().all())

    async def get(self, id: int):
        """Return the instance with primary key ``id``, or ``None`` if absent."""
        async with self.get_session() as session:
            return await session.get(self.model, id)

    async def patch(self, id: int, **kwargs: Any):
        """Update the given fields on the instance with primary key ``id``.

        Args:
            id: Primary key of the instance to update.
            **kwargs: Column key / new value pairs.

        Returns:
            The refreshed instance.

        Raises:
            SQLAlchemyError: If no instance has that id, or a keyword does
                not name a column of the model.
        """
        async with self.get_session() as session:
            instance = await session.get(self.model, id)
            if instance is None:
                raise SQLAlchemyError(f"{self.model.__name__} not found with id {id}")

            # Validate every field first so a bad key never leaves the
            # instance partially modified.
            column_names = self._column_names()
            for key in kwargs:
                if key not in column_names:
                    raise SQLAlchemyError(f"Invalid field: {key}")

            for key, value in kwargs.items():
                setattr(instance, key, value)
            await session.commit()
            await session.refresh(instance)
            return instance

    async def delete(self, id: int) -> None:
        """Delete the instance with primary key ``id``.

        Raises:
            SQLAlchemyError: If no instance has that id.
        """
        async with self.get_session() as session:
            instance = await session.get(self.model, id)
            if instance is None:
                raise SQLAlchemyError(f"{self.model.__name__} not found with id {id}")
            await session.delete(instance)
            await session.commit()

    async def count(self, filter_by: Optional[Dict[str, Any]] = None) -> int:
        """Return the number of rows matching ``filter_by``.

        Filters are sanitised the same way as in ``list``: unknown keys are
        ignored and values are coerced to the column type.

        Raises:
            ValueError: If a filter value cannot be converted.
        """
        filters = self._clean_filter_by(filter_by)
        async with self.get_session() as session:
            query: Select[Any] = (
                select(func.count()).select_from(self.model).filter_by(**filters)
            )
            result = await session.execute(query)
            return result.scalar_one() or 0

    async def execute(self, query: Any):
        """Execute an arbitrary statement in a fresh session and return its result.

        NOTE(review): the session is closed before the result reaches the
        caller; presumably the result is buffered -- confirm for statements
        that stream rows.
        """
        async with self.get_session() as session:
            return await session.execute(query)