File size: 3,261 Bytes
6bff5d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""CatalogStore — persists per-user catalogs as Postgres jsonb rows.

Storage shape: one row per user in a `catalogs` table with columns
(user_id PK, data jsonb, schema_version, generated_at, updated_at).
"""

from sqlalchemy import case, delete, func, select
from sqlalchemy.dialects.postgresql import insert

from src.db.postgres.connection import AsyncSessionLocal
from src.db.postgres.models import Catalog as CatalogRow
from src.middlewares.logging import get_logger

from .models import Catalog

# Module logger for this store; the methods below pass structured context
# as keyword arguments (e.g. user_id=..., source_id=...).
logger = get_logger("catalog_store")


class CatalogStore:
    """Read/write catalogs keyed by user_id.

    Each public method opens its own AsyncSession. Callers needing
    transactional coordination across multiple stores can be refactored to
    accept an explicit AsyncSession in a later PR.
    """

    @staticmethod
    def _upsert_statement(catalog: Catalog):
        """Build the INSERT ... ON CONFLICT (user_id) DO UPDATE for *catalog*.

        The denormalized columns (schema_version, generated_at) are refreshed
        together with ``data`` so they never disagree with the jsonb payload
        that ``get`` returns. ``updated_at`` only advances when the stored
        payload actually changes.
        """
        stmt = insert(CatalogRow).values(
            user_id=catalog.user_id,
            data=catalog.model_dump(mode="json"),
            schema_version=catalog.schema_version,
            generated_at=catalog.generated_at,
            updated_at=func.now(),
        )
        excluded = stmt.excluded
        return stmt.on_conflict_do_update(
            index_elements=[CatalogRow.user_id],
            set_={
                "data": excluded.data,
                "schema_version": excluded.schema_version,
                # Keep the column in sync with the generated_at inside `data`;
                # omitting it left the first generation's timestamp forever.
                "generated_at": excluded.generated_at,
                # IS DISTINCT FROM is NULL-safe, unlike `!=` which yields NULL
                # (and thus the else branch) if either side is NULL.
                "updated_at": case(
                    (excluded.data.is_distinct_from(CatalogRow.data), func.now()),
                    else_=CatalogRow.updated_at,
                ),
            },
        )

    @staticmethod
    def _log_upserted(catalog: Catalog) -> None:
        """Emit the standard 'catalog upserted' log line for *catalog*."""
        logger.info(
            "catalog upserted",
            user_id=catalog.user_id,
            sources=len(catalog.sources),
        )

    async def get(self, user_id: str) -> Catalog | None:
        """Return the stored catalog for *user_id*, or None if no row exists."""
        async with AsyncSessionLocal() as session:
            result = await session.execute(
                select(CatalogRow.data).where(CatalogRow.user_id == user_id)
            )
            data = result.scalar_one_or_none()
        if data is None:
            return None
        return Catalog.model_validate(data)

    async def upsert(self, catalog: Catalog) -> None:
        """Insert or replace the catalog row for ``catalog.user_id``."""
        async with AsyncSessionLocal() as session:
            await session.execute(self._upsert_statement(catalog))
            await session.commit()
        self._log_upserted(catalog)

    async def remove_source(self, user_id: str, source_id: str) -> None:
        """Drop every source with *source_id* from the user's catalog.

        The read and the write run in a single transaction with the row locked
        (SELECT ... FOR UPDATE), so a concurrent ``upsert`` for the same user
        cannot land between them and be silently overwritten. Logs and returns
        without writing when the user has no catalog or the source is absent.
        """
        async with AsyncSessionLocal() as session:
            result = await session.execute(
                select(CatalogRow.data)
                .where(CatalogRow.user_id == user_id)
                .with_for_update()
            )
            data = result.scalar_one_or_none()
            if data is None:
                logger.info("remove_source: no catalog found", user_id=user_id, source_id=source_id)
                return
            existing = Catalog.model_validate(data)
            filtered = [s for s in existing.sources if s.source_id != source_id]
            if len(filtered) == len(existing.sources):
                # Nothing to change; leaving the block rolls back and releases the lock.
                logger.info("remove_source: source not in catalog", user_id=user_id, source_id=source_id)
                return
            updated = existing.model_copy(update={"sources": filtered})
            await session.execute(self._upsert_statement(updated))
            await session.commit()
        self._log_upserted(updated)
        logger.info("remove_source: source removed", user_id=user_id, source_id=source_id)

    async def delete(self, user_id: str) -> None:
        """Delete the catalog row for *user_id* (no error if none exists)."""
        # `delete` here resolves to the module-level sqlalchemy import, not this
        # method: class-body names are not in scope inside method bodies.
        async with AsyncSessionLocal() as session:
            await session.execute(delete(CatalogRow).where(CatalogRow.user_id == user_id))
            await session.commit()
        logger.info("catalog deleted", user_id=user_id)