ishaq101's picture
feat/Catalog Retrieval System (#1)
6bff5d9
"""CatalogStore — persists per-user catalogs as Postgres jsonb rows.
Storage shape: one row per user in a `catalogs` table with columns
(user_id PK, data jsonb, schema_version, generated_at, updated_at).
"""
from sqlalchemy import case, delete, func, select
from sqlalchemy.dialects.postgresql import insert
from src.db.postgres.connection import AsyncSessionLocal
from src.db.postgres.models import Catalog as CatalogRow
from src.middlewares.logging import get_logger
from .models import Catalog
logger = get_logger("catalog_store")
class CatalogStore:
    """Read/write catalogs keyed by user_id.

    Each method opens its own AsyncSession. Callers needing transactional
    coordination across multiple stores can be refactored to accept an
    explicit AsyncSession in a later PR.
    """

    async def get(self, user_id: str) -> Catalog | None:
        """Return the stored catalog for ``user_id``, or ``None`` if absent.

        Only the ``data`` jsonb column is read; it is validated back into a
        ``Catalog`` model.
        """
        async with AsyncSessionLocal() as session:
            result = await session.execute(
                select(CatalogRow.data).where(CatalogRow.user_id == user_id)
            )
            data = result.scalar_one_or_none()
        if data is None:
            return None
        return Catalog.model_validate(data)

    async def upsert(self, catalog: Catalog) -> None:
        """Insert or replace the catalog row for ``catalog.user_id``.

        On conflict, ``data``, ``schema_version`` and ``generated_at`` are
        overwritten from the new catalog. ``updated_at`` is bumped to ``now()``
        only when the stored ``data`` actually changes, so re-upserting an
        identical catalog leaves the timestamp untouched.
        """
        payload = catalog.model_dump(mode="json")
        async with AsyncSessionLocal() as session:
            stmt = insert(CatalogRow).values(
                user_id=catalog.user_id,
                data=payload,
                schema_version=catalog.schema_version,
                generated_at=catalog.generated_at,
                updated_at=func.now(),
            )
            excluded = stmt.excluded
            stmt = stmt.on_conflict_do_update(
                index_elements=[CatalogRow.user_id],
                set_={
                    "data": excluded.data,
                    "schema_version": excluded.schema_version,
                    # Mirror the insert branch so the column tracks the latest
                    # generation instead of the first one ever stored.
                    "generated_at": excluded.generated_at,
                    # IS DISTINCT FROM is NULL-safe; a plain != would yield
                    # NULL (and skip the bump) if either side were NULL.
                    "updated_at": case(
                        (
                            excluded.data.is_distinct_from(CatalogRow.data),
                            func.now(),
                        ),
                        else_=CatalogRow.updated_at,
                    ),
                },
            )
            await session.execute(stmt)
            await session.commit()
        logger.info(
            "catalog upserted",
            user_id=catalog.user_id,
            sources=len(catalog.sources),
        )

    async def remove_source(self, user_id: str, source_id: str) -> None:
        """Remove every source whose ``source_id`` matches from the user's catalog.

        This is a logged no-op when the user has no catalog or the source is
        not in it. The read-modify-write runs in one transaction that holds a
        row lock (SELECT ... FOR UPDATE). Concurrent removals for the same user
        are serialized instead of overwriting each other's changes.
        """
        async with AsyncSessionLocal() as session:
            result = await session.execute(
                select(CatalogRow)
                .where(CatalogRow.user_id == user_id)
                .with_for_update()
            )
            row = result.scalar_one_or_none()
            if row is None:
                # Leaving the session context rolls back and releases the lock.
                logger.info("remove_source: no catalog found", user_id=user_id, source_id=source_id)
                return
            existing = Catalog.model_validate(row.data)
            filtered = [s for s in existing.sources if s.source_id != source_id]
            if len(filtered) == len(existing.sources):
                logger.info("remove_source: source not in catalog", user_id=user_id, source_id=source_id)
                return
            updated = existing.model_copy(update={"sources": filtered})
            row.data = updated.model_dump(mode="json")
            # The data definitely changed here, so always bump updated_at
            # (matches upsert's change-detection semantics).
            row.updated_at = func.now()
            await session.commit()
        logger.info("remove_source: source removed", user_id=user_id, source_id=source_id)

    async def delete(self, user_id: str) -> None:
        """Delete the catalog row for ``user_id``; a missing row is not an error.

        ``delete`` inside the body is the module-level SQLAlchemy construct;
        method bodies do not see the class namespace, so this method does not
        shadow it.
        """
        async with AsyncSessionLocal() as session:
            await session.execute(delete(CatalogRow).where(CatalogRow.user_id == user_id))
            await session.commit()
        logger.info("catalog deleted", user_id=user_id)