| """Postgres repository functions — typed, async, subreddit_id-scoped. |
| |
| Every public function takes `subreddit_id` as a mandatory keyword-only arg |
| (invariant I-7 enforced at the API). Internal helpers add the predicate to |
| SELECT/UPDATE/DELETE. |
| |
| Spec: docs/Specs.md §9, docs/07-DataLayer.md. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| from contextlib import asynccontextmanager |
| from datetime import UTC, datetime |
| from typing import TYPE_CHECKING |
|
|
| import structlog |
| from sqlalchemy import cast, literal, select, update |
| from sqlalchemy.dialects.postgresql import JSONB as JSONB_TYPE |
| from sqlalchemy.dialects.postgresql import insert as pg_insert |
| from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker |
| from sqlalchemy.orm import selectinload |
|
|
| from store import models as m |
|
|
| if TYPE_CHECKING: |
| from collections.abc import AsyncIterator, Sequence |
| from store.types import ( |
| EvidenceRowInput, |
| FeedbackInput, |
| FinalizeInvestigationInput, |
| StartInvestigationInput, |
| SubredditProfileRow, |
| ThreadMemoryRow, |
| UserMemoryRow, |
| ) |
|
|
| logger = structlog.get_logger(__name__) |
|
|
|
|
def make_sessionmaker(engine: AsyncEngine) -> async_sessionmaker[AsyncSession]:
    """Create the session factory bound to ``engine``.

    Sessions built by the factory are configured with
    ``expire_on_commit=False`` and ``autoflush=False``. The factory is meant
    to be created once per engine and held on ``app.state`` in the lifespan.
    """
    session_factory = async_sessionmaker(
        engine,
        autoflush=False,
        expire_on_commit=False,
    )
    return session_factory
|
|
|
|
@asynccontextmanager
async def with_session(
    factory: async_sessionmaker[AsyncSession],
) -> AsyncIterator[AsyncSession]:
    """Yield a new session and settle its transaction when the caller is done.

    The session is committed once the caller's ``async with`` body finishes
    without raising. If the body, or the commit itself, raises an
    ``Exception``, the session is rolled back and the error is re-raised.
    """
    db_session = factory()
    async with db_session:
        try:
            yield db_session
            # Only reached when the caller's block completed normally. The
            # commit stays inside the ``try`` so a failed commit is rolled back.
            await db_session.commit()
        except Exception:
            await db_session.rollback()
            raise
|
|
|
|
| |
|
|
|
|
async def ensure_subreddit_profile(
    session: AsyncSession,
    *,
    subreddit_id: str,
    name: str,
    personality: str = "balanced",
) -> SubredditProfileRow:
    """Create the subreddit profile if it is missing, then return it.

    Idempotent on AppInstall: the insert is ``ON CONFLICT DO NOTHING`` keyed on
    ``subreddit_id``, so an existing row is returned untouched and ``name`` /
    ``personality`` only take effect when the row is first created.

    Args:
        session: Open session; nothing is committed here.
        subreddit_id: Key of the profile row to create or fetch.
        name: Name stored on a newly created row.
        personality: Personality stored on a newly created row.

    Returns:
        The stored profile validated into a ``SubredditProfileRow``.
    """
    # Imported at call time: the module-level import of ``store.types`` is
    # guarded by TYPE_CHECKING, so the name is not bound when this function
    # runs and ``model_validate`` below would otherwise raise NameError.
    from store.types import SubredditProfileRow

    create_if_missing = (
        pg_insert(m.SubredditProfile)
        .values(subreddit_id=subreddit_id, name=name, personality=personality)
        .on_conflict_do_nothing(index_elements=["subreddit_id"])
    )
    await session.execute(create_if_missing)

    # Re-read the row: after a conflict the insert itself produced nothing.
    lookup = select(m.SubredditProfile).where(
        m.SubredditProfile.subreddit_id == subreddit_id
    )
    row = (await session.execute(lookup)).scalar_one()
    return SubredditProfileRow.model_validate(row)
|
|
|
|
async def get_subreddit_profile(
    session: AsyncSession, *, subreddit_id: str
) -> SubredditProfileRow | None:
    """Fetch the profile row for ``subreddit_id``.

    Args:
        session: Open session used for the read.
        subreddit_id: Key of the profile row to look up.

    Returns:
        The profile validated into a ``SubredditProfileRow``, or ``None`` when
        no row matches.
    """
    # Imported at call time: the module-level import of ``store.types`` is
    # guarded by TYPE_CHECKING, so the name is not bound when this function
    # runs and ``model_validate`` below would otherwise raise NameError.
    from store.types import SubredditProfileRow

    lookup = select(m.SubredditProfile).where(
        m.SubredditProfile.subreddit_id == subreddit_id
    )
    row = (await session.execute(lookup)).scalar_one_or_none()
    if row is None:
        return None
    return SubredditProfileRow.model_validate(row)
|
|
|
|
| |
|
|
|
|
async def get_user_memory(
    session: AsyncSession, *, subreddit_id: str, user_id: str
) -> UserMemoryRow | None:
    """Fetch the memory row for one user within one subreddit.

    Args:
        session: Open session used for the read.
        subreddit_id: Subreddit scope of the lookup.
        user_id: User whose memory row is requested.

    Returns:
        The row validated into a ``UserMemoryRow``, or ``None`` when no row
        matches both ``subreddit_id`` and ``user_id``.
    """
    # Imported at call time: the module-level import of ``store.types`` is
    # guarded by TYPE_CHECKING, so the name is not bound when this function
    # runs and ``model_validate`` below would otherwise raise NameError.
    from store.types import UserMemoryRow

    lookup = select(m.UserMemory).where(
        m.UserMemory.subreddit_id == subreddit_id,
        m.UserMemory.user_id == user_id,
    )
    row = (await session.execute(lookup)).scalar_one_or_none()
    if row is None:
        return None
    return UserMemoryRow.model_validate(row)
|
|
|
|
async def upsert_user_memory(
    session: AsyncSession,
    *,
    subreddit_id: str,
    user_id: str,
    risk_tier: str | None = None,
    prior_violations_delta: int = 0,
    prior_approvals_delta: int = 0,
) -> UserMemoryRow:
    """Get-or-create a user's memory row, refresh ``last_seen_at``, bump counters.

    Used by `feedback`-batch ingest (I-3.4) and by `user_history` tool reads.
    The row is keyed on (subreddit_id, user_id); calling this for an existing
    user never creates a second row.

    Args:
        session: Open session; nothing is committed here.
        subreddit_id: Subreddit scope of the row.
        user_id: User the row belongs to.
        risk_tier: New risk tier. ``None`` or an empty string leaves the stored
            value unchanged.
        prior_violations_delta: Amount added to ``prior_violations``; ``0``
            leaves it unchanged.
        prior_approvals_delta: Amount added to ``prior_approvals``; ``0``
            leaves it unchanged.

    Returns:
        The row as stored after all writes, validated into a ``UserMemoryRow``.
    """
    # Imported at call time: the module-level import of ``store.types`` is
    # guarded by TYPE_CHECKING, so the name is not bound when this function
    # runs and ``model_validate`` below would otherwise raise NameError.
    from store.types import UserMemoryRow

    same_user = (
        m.UserMemory.subreddit_id == subreddit_id,
        m.UserMemory.user_id == user_id,
    )

    # Take the timestamp once so the insert branch and the conflict branch
    # record exactly the same instant.
    now = datetime.now(UTC)
    touch = (
        pg_insert(m.UserMemory)
        .values(subreddit_id=subreddit_id, user_id=user_id, last_seen_at=now)
        .on_conflict_do_update(
            index_elements=["subreddit_id", "user_id"],
            set_={"last_seen_at": now},
        )
    )
    await session.execute(touch)

    # Counters are incremented in SQL (column + delta) rather than in Python,
    # so concurrent writers do not overwrite each other's increments.
    updates: dict[str, object] = {}
    if risk_tier:
        updates["risk_tier"] = risk_tier
    if prior_violations_delta:
        updates["prior_violations"] = (
            m.UserMemory.prior_violations + prior_violations_delta
        )
    if prior_approvals_delta:
        updates["prior_approvals"] = (
            m.UserMemory.prior_approvals + prior_approvals_delta
        )
    if updates:
        await session.execute(
            update(m.UserMemory).where(*same_user).values(**updates)
        )

    # populate_existing: the upsert above does not refresh an instance that is
    # already in this session's identity map, so without it a previously
    # loaded row could be returned with its old ``last_seen_at``.
    reread = (
        select(m.UserMemory)
        .where(*same_user)
        .execution_options(populate_existing=True)
    )
    row = (await session.execute(reread)).scalar_one()
    return UserMemoryRow.model_validate(row)
|
|
|
|
| |
|
|
|
|
async def get_thread_memory(
    session: AsyncSession, *, subreddit_id: str, post_id: str
) -> ThreadMemoryRow | None:
    """Fetch the memory row for one post within one subreddit.

    Args:
        session: Open session used for the read.
        subreddit_id: Subreddit scope of the lookup.
        post_id: Post whose thread memory is requested.

    Returns:
        The row validated into a ``ThreadMemoryRow``, or ``None`` when no row
        matches both ``subreddit_id`` and ``post_id``.
    """
    # Imported at call time: the module-level import of ``store.types`` is
    # guarded by TYPE_CHECKING, so the name is not bound when this function
    # runs and ``model_validate`` below would otherwise raise NameError.
    from store.types import ThreadMemoryRow

    lookup = select(m.ThreadMemory).where(
        m.ThreadMemory.subreddit_id == subreddit_id,
        m.ThreadMemory.post_id == post_id,
    )
    row = (await session.execute(lookup)).scalar_one_or_none()
    if row is None:
        return None
    return ThreadMemoryRow.model_validate(row)
|
|
|
|
async def upsert_thread_memory(
    session: AsyncSession,
    *,
    subreddit_id: str,
    post_id: str,
    mod_action_entry: dict[str, object] | None = None,
) -> ThreadMemoryRow:
    """Get-or-create thread memory; optionally append a mod action entry.

    Args:
        session: Open session; nothing is committed here.
        subreddit_id: Subreddit scope of the row.
        post_id: Post the row belongs to.
        mod_action_entry: Entry appended to ``mod_actions_taken``. ``None``
            skips the append; an empty dict is still appended.

    Returns:
        The row as stored after all writes, validated into a
        ``ThreadMemoryRow``.

    Raises:
        TypeError: If ``mod_action_entry`` is not JSON-serializable.
    """
    # Imported at call time: the module-level import of ``store.types`` is
    # guarded by TYPE_CHECKING, so the name is not bound when this function
    # runs and ``model_validate`` below would otherwise raise NameError.
    from store.types import ThreadMemoryRow

    same_thread = (
        m.ThreadMemory.subreddit_id == subreddit_id,
        m.ThreadMemory.post_id == post_id,
    )

    create_if_missing = (
        pg_insert(m.ThreadMemory)
        .values(subreddit_id=subreddit_id, post_id=post_id)
        .on_conflict_do_nothing(index_elements=["subreddit_id", "post_id"])
    )
    await session.execute(create_if_missing)

    if mod_action_entry is not None:
        # Wrap the entry in a one-element JSON array and concatenate it onto
        # the stored array in SQL, so the append needs no read-modify-write.
        # NOTE(review): in Postgres ``NULL || jsonb`` is NULL, so this assumes
        # ``mod_actions_taken`` is never NULL - confirm the column default.
        new_entry = cast(literal(json.dumps([mod_action_entry])), JSONB_TYPE)
        await session.execute(
            update(m.ThreadMemory)
            .where(*same_thread)
            .values(mod_actions_taken=m.ThreadMemory.mod_actions_taken + new_entry)
        )

    row = (
        await session.execute(select(m.ThreadMemory).where(*same_thread))
    ).scalar_one()
    return ThreadMemoryRow.model_validate(row)
|
|
|
|
| |
|
|
|
|
async def increment_cold_start_count(
    session: AsyncSession, *, subreddit_id: str
) -> int:
    """Increment the feedback counter used for cold-start detection.

    Spec: docs/05-Memory.md - cold_start transitions at 50 feedback events.

    Args:
        session: Open session; nothing is committed here.
        subreddit_id: Subreddit whose profile counter is incremented.

    Returns:
        The counter value after the increment.

    Raises:
        sqlalchemy.exc.NoResultFound: If no profile row exists for
            ``subreddit_id`` (the update then matches nothing).
    """
    # UPDATE ... RETURNING hands back the value this statement wrote, so the
    # new count arrives in one round trip instead of an update plus a re-read.
    bump = (
        update(m.SubredditProfile)
        .where(m.SubredditProfile.subreddit_id == subreddit_id)
        .values(cold_start_count=m.SubredditProfile.cold_start_count + 1)
        .returning(m.SubredditProfile.cold_start_count)
    )
    new_count = (await session.execute(bump)).scalar_one()
    return int(new_count)
|
|
|
|
| |
|
|
|
|
async def start_investigation(
    session: AsyncSession, *, input_: StartInvestigationInput
) -> m.Investigation:
    """Insert a new investigation in status ``pending`` and return the ORM row.

    The identifying and target fields are copied from ``input_``. The row is
    added to the session and flushed (not committed) before it is returned.
    """
    copied_fields = (
        "correlation_id",
        "subreddit_id",
        "target_kind",
        "target_id",
        "target_body",
        "target_author_id",
        "tier",
    )
    column_values = {field: getattr(input_, field) for field in copied_fields}
    investigation = m.Investigation(status="pending", **column_values)
    session.add(investigation)
    await session.flush()
    return investigation
|
|
|
|
async def append_evidence(
    session: AsyncSession,
    *,
    investigation: m.Investigation,
    subreddit_id: str,
    evidence: EvidenceRowInput,
) -> None:
    """Store one Evidence Accumulator entry under ``investigation``.

    The caller's ``subreddit_id`` must equal the investigation's own; the new
    row is added to the session and flushed (not committed).

    Raises:
        ValueError: If ``subreddit_id`` differs from
            ``investigation.subreddit_id``.
    """
    owner_subreddit = investigation.subreddit_id
    if owner_subreddit != subreddit_id:
        raise ValueError(
            f"subreddit_id mismatch: investigation={owner_subreddit} call={subreddit_id}"
        )

    copied_fields = ("evidence_id", "tool", "summary", "detail", "status", "latency_ms")
    column_values = {field: getattr(evidence, field) for field in copied_fields}
    session.add(
        m.Evidence(
            investigation_id=investigation.id,
            subreddit_id=subreddit_id,
            **column_values,
        )
    )
    await session.flush()
|
|
|
|
async def finalize_investigation(
    session: AsyncSession,
    *,
    correlation_id: str,
    subreddit_id: str,
    verdict: FinalizeInvestigationInput,
) -> None:
    """Write the verdict onto an investigation and mark it ``completed``.

    The row is selected by ``correlation_id`` and ``subreddit_id``; the verdict
    fields are copied from ``verdict`` and ``completed_at`` is set to the
    current UTC time.

    Raises:
        LookupError: If the update matched no row.
    """
    verdict_fields = (
        "risk_tier",
        "recommendation",
        "calibrated_confidence",
        "rationale",
        "confidence_breakdown",
        "model_reasoner",
        "model_summarizer",
        "cost_usd",
        "latency_ms",
        "input_tokens",
        "output_tokens",
        "validation_flag",
        "degraded",
        "cold_start",
    )
    new_values: dict[str, object] = {
        "status": "completed",
        "completed_at": datetime.now(UTC),
    }
    new_values.update({field: getattr(verdict, field) for field in verdict_fields})

    # NOTE(review): the error below says "pending", but the WHERE clause does
    # not filter on status, so an already completed row would be overwritten.
    # Confirm which of the two is intended.
    stamp_verdict = (
        update(m.Investigation)
        .where(
            m.Investigation.correlation_id == correlation_id,
            m.Investigation.subreddit_id == subreddit_id,
        )
        .values(**new_values)
        .execution_options(synchronize_session=False)
    )
    outcome = await session.execute(stamp_verdict)

    # Read rowcount defensively: treat a missing or None value as zero.
    matched = getattr(outcome, "rowcount", 0) or 0
    if matched == 0:
        raise LookupError(
            "no pending investigation for "
            f"correlation_id={correlation_id} subreddit_id={subreddit_id}"
        )
|
|
|
|
async def get_investigation_by_correlation(
    session: AsyncSession, *, correlation_id: str, subreddit_id: str
) -> m.Investigation | None:
    """Look up one investigation by correlation id within a subreddit.

    ``.evidence`` is loaded eagerly (``selectinload``) so callers can iterate
    it after the session closes. Returns ``None`` when nothing matches.
    """
    lookup = (
        select(m.Investigation)
        .where(
            m.Investigation.correlation_id == correlation_id,
            m.Investigation.subreddit_id == subreddit_id,
        )
        .options(selectinload(m.Investigation.evidence))
    )
    found = await session.execute(lookup)
    return found.scalar_one_or_none()
|
|
|
|
| |
|
|
|
|
async def record_feedback(session: AsyncSession, *, feedback: FeedbackInput) -> None:
    """Insert one feedback row copied field-by-field from ``feedback``.

    The row is added to the session and flushed (not committed).
    """
    copied_fields = (
        "correlation_id",
        "subreddit_id",
        "target_id",
        "mod_action",
        "raw_action",
        "moderator_id",
        "moderator_name",
        "source",
        "aligned",
    )
    column_values = {field: getattr(feedback, field) for field in copied_fields}
    session.add(m.Feedback(**column_values))
    await session.flush()
|
|
|
|
async def list_recent_feedback_for_subreddit(
    session: AsyncSession, *, subreddit_id: str, limit: int = 100
) -> Sequence[m.Feedback]:
    """Return up to ``limit`` feedback rows for a subreddit, newest first.

    Rows are ordered by ``created_at`` descending. Used by the nightly
    calibration batch (post-MVP).
    """
    newest_first = (
        select(m.Feedback)
        .where(m.Feedback.subreddit_id == subreddit_id)
        .order_by(m.Feedback.created_at.desc())
        .limit(limit)
    )
    fetched = await session.execute(newest_first)
    return fetched.scalars().all()
|
|
|
|
| |
|
|
|
|
async def append_audit(
    session: AsyncSession,
    *,
    subreddit_id: str,
    event_type: str,
    actor: str = "system",
    correlation_id: str | None = None,
    detail: dict[str, object] | None = None,
) -> None:
    """Add one audit-log row for ``subreddit_id`` and flush it.

    A missing or empty ``detail`` is stored as an empty dict. The row is
    flushed but not committed.
    """
    stored_detail = detail if detail else {}
    entry = m.AuditLog(
        subreddit_id=subreddit_id,
        event_type=event_type,
        actor=actor,
        correlation_id=correlation_id,
        detail=stored_detail,
    )
    session.add(entry)
    await session.flush()
|
|
|
|
async def list_prior_actions_on_user(
    session: AsyncSession,
    *,
    subreddit_id: str,
    author_id: str,
    limit: int = 5,
) -> Sequence[m.Investigation]:
    """Return the latest completed investigations targeting ``author_id``.

    Used by the `prior_actions` tool (I-3.2, docs/04-InvestigationEngine.md §5.3.4).
    Only rows in this subreddit with status ``completed`` are returned, ordered
    by ``completed_at`` descending and capped at ``limit``.
    """
    completed_for_author = (
        m.Investigation.subreddit_id == subreddit_id,
        m.Investigation.target_author_id == author_id,
        m.Investigation.status == "completed",
    )
    newest_first = (
        select(m.Investigation)
        .where(*completed_for_author)
        .order_by(m.Investigation.completed_at.desc())
        .limit(limit)
    )
    fetched = await session.execute(newest_first)
    return fetched.scalars().all()
|
|