"""Postgres repository functions — typed, async, subreddit_id-scoped. Every public function takes `subreddit_id` as a mandatory positional arg (invariant I-7 enforced at the API). Internal helpers add the predicate to SELECT/UPDATE/DELETE. Spec: docs/Specs.md §9, docs/07-DataLayer.md. """ from __future__ import annotations import json from contextlib import asynccontextmanager from datetime import UTC, datetime from typing import TYPE_CHECKING import structlog from sqlalchemy import cast, literal, select, update from sqlalchemy.dialects.postgresql import JSONB as JSONB_TYPE from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker from sqlalchemy.orm import selectinload from store import models as m if TYPE_CHECKING: from collections.abc import AsyncIterator, Sequence from store.types import ( EvidenceRowInput, FeedbackInput, FinalizeInvestigationInput, StartInvestigationInput, SubredditProfileRow, ThreadMemoryRow, UserMemoryRow, ) logger = structlog.get_logger(__name__) def make_sessionmaker(engine: AsyncEngine) -> async_sessionmaker[AsyncSession]: """Build the per-engine session factory. Held on app.state in the lifespan.""" return async_sessionmaker(engine, expire_on_commit=False, autoflush=False) @asynccontextmanager async def with_session( factory: async_sessionmaker[AsyncSession], ) -> AsyncIterator[AsyncSession]: """Open a session, commit on clean exit, rollback on exception.""" async with factory() as session: try: yield session await session.commit() except Exception: await session.rollback() raise # === Subreddit profile ================================================= async def ensure_subreddit_profile( session: AsyncSession, *, subreddit_id: str, name: str, personality: str = "balanced", ) -> SubredditProfileRow: """Idempotent on AppInstall. Returns the existing row if present.""" stmt = ( pg_insert(m.SubredditProfile) .values(subreddit_id=subreddit_id, name=name, personality=personality) .on_conflict_do_nothing(index_elements=["subreddit_id"]) ) await session.execute(stmt) row = ( await session.execute( select(m.SubredditProfile).where(m.SubredditProfile.subreddit_id == subreddit_id) ) ).scalar_one() return SubredditProfileRow.model_validate(row) async def get_subreddit_profile( session: AsyncSession, *, subreddit_id: str ) -> SubredditProfileRow | None: row = ( await session.execute( select(m.SubredditProfile).where(m.SubredditProfile.subreddit_id == subreddit_id) ) ).scalar_one_or_none() return SubredditProfileRow.model_validate(row) if row else None # === User memory ======================================================= async def get_user_memory( session: AsyncSession, *, subreddit_id: str, user_id: str ) -> UserMemoryRow | None: row = ( await session.execute( select(m.UserMemory).where( m.UserMemory.subreddit_id == subreddit_id, m.UserMemory.user_id == user_id, ) ) ).scalar_one_or_none() return UserMemoryRow.model_validate(row) if row else None async def upsert_user_memory( # noqa: PLR0913 — kwarg-only delta-style API, intentional surface session: AsyncSession, *, subreddit_id: str, user_id: str, risk_tier: str | None = None, prior_violations_delta: int = 0, prior_approvals_delta: int = 0, ) -> UserMemoryRow: """Get-or-create, then optionally bump counters and refresh last_seen. Used by `feedback`-batch ingest (I-3.4) and by `user_history` tool reads. Idempotent on (subreddit_id, user_id). """ base = ( pg_insert(m.UserMemory) .values( subreddit_id=subreddit_id, user_id=user_id, last_seen_at=datetime.now(UTC), ) .on_conflict_do_update( index_elements=["subreddit_id", "user_id"], set_={"last_seen_at": datetime.now(UTC)}, ) ) await session.execute(base) if risk_tier or prior_violations_delta or prior_approvals_delta: updates: dict[str, object] = {} if risk_tier: updates["risk_tier"] = risk_tier if prior_violations_delta: updates["prior_violations"] = ( m.UserMemory.prior_violations + prior_violations_delta ) if prior_approvals_delta: updates["prior_approvals"] = ( m.UserMemory.prior_approvals + prior_approvals_delta ) await session.execute( update(m.UserMemory) .where( m.UserMemory.subreddit_id == subreddit_id, m.UserMemory.user_id == user_id, ) .values(**updates) ) row = ( await session.execute( select(m.UserMemory).where( m.UserMemory.subreddit_id == subreddit_id, m.UserMemory.user_id == user_id, ) ) ).scalar_one() return UserMemoryRow.model_validate(row) # === Thread memory ===================================================== async def get_thread_memory( session: AsyncSession, *, subreddit_id: str, post_id: str ) -> ThreadMemoryRow | None: row = ( await session.execute( select(m.ThreadMemory).where( m.ThreadMemory.subreddit_id == subreddit_id, m.ThreadMemory.post_id == post_id, ) ) ).scalar_one_or_none() return ThreadMemoryRow.model_validate(row) if row else None async def upsert_thread_memory( session: AsyncSession, *, subreddit_id: str, post_id: str, mod_action_entry: dict[str, object] | None = None, ) -> ThreadMemoryRow: """Get-or-create thread memory; optionally append a mod action entry.""" base = ( pg_insert(m.ThreadMemory) .values(subreddit_id=subreddit_id, post_id=post_id) .on_conflict_do_nothing(index_elements=["subreddit_id", "post_id"]) ) await session.execute(base) if mod_action_entry is not None: # Append to the JSONB array using Postgres jsonb || jsonb. new_entry = cast( literal(json.dumps([mod_action_entry])), JSONB_TYPE, ) await session.execute( update(m.ThreadMemory) .where( m.ThreadMemory.subreddit_id == subreddit_id, m.ThreadMemory.post_id == post_id, ) .values( mod_actions_taken=m.ThreadMemory.mod_actions_taken + new_entry ) ) row = ( await session.execute( select(m.ThreadMemory).where( m.ThreadMemory.subreddit_id == subreddit_id, m.ThreadMemory.post_id == post_id, ) ) ).scalar_one() return ThreadMemoryRow.model_validate(row) # === Subreddit cold-start counter ====================================== async def increment_cold_start_count( session: AsyncSession, *, subreddit_id: str ) -> int: """Increment the feedback counter used for cold-start detection. Returns the new count. Spec: docs/05-Memory.md - cold_start transitions at 50 feedback events. """ await session.execute( update(m.SubredditProfile) .where(m.SubredditProfile.subreddit_id == subreddit_id) .values(cold_start_count=m.SubredditProfile.cold_start_count + 1) ) row = ( await session.execute( select(m.SubredditProfile.cold_start_count).where( m.SubredditProfile.subreddit_id == subreddit_id ) ) ).scalar_one() return int(row) # === Investigation ===================================================== async def start_investigation( session: AsyncSession, *, input_: StartInvestigationInput ) -> m.Investigation: """Create a `pending` investigation row and return the ORM object.""" row = m.Investigation( correlation_id=input_.correlation_id, subreddit_id=input_.subreddit_id, target_kind=input_.target_kind, target_id=input_.target_id, target_body=input_.target_body, target_author_id=input_.target_author_id, tier=input_.tier, status="pending", ) session.add(row) await session.flush() return row async def append_evidence( session: AsyncSession, *, investigation: m.Investigation, subreddit_id: str, evidence: EvidenceRowInput, ) -> None: """Persist one Evidence Accumulator entry. subreddit_id must match the investigation.""" if investigation.subreddit_id != subreddit_id: raise ValueError( f"subreddit_id mismatch: investigation={investigation.subreddit_id} call={subreddit_id}" ) row = m.Evidence( investigation_id=investigation.id, subreddit_id=subreddit_id, evidence_id=evidence.evidence_id, tool=evidence.tool, summary=evidence.summary, detail=evidence.detail, status=evidence.status, latency_ms=evidence.latency_ms, ) session.add(row) await session.flush() async def finalize_investigation( session: AsyncSession, *, correlation_id: str, subreddit_id: str, verdict: FinalizeInvestigationInput, ) -> None: """Stamp the verdict columns + flip status='completed'.""" completed_at = datetime.now(UTC) stmt = ( update(m.Investigation) .where( m.Investigation.correlation_id == correlation_id, m.Investigation.subreddit_id == subreddit_id, ) .values( status="completed", risk_tier=verdict.risk_tier, recommendation=verdict.recommendation, calibrated_confidence=verdict.calibrated_confidence, rationale=verdict.rationale, confidence_breakdown=verdict.confidence_breakdown, model_reasoner=verdict.model_reasoner, model_summarizer=verdict.model_summarizer, cost_usd=verdict.cost_usd, latency_ms=verdict.latency_ms, input_tokens=verdict.input_tokens, output_tokens=verdict.output_tokens, validation_flag=verdict.validation_flag, degraded=verdict.degraded, cold_start=verdict.cold_start, completed_at=completed_at, ) .execution_options(synchronize_session=False) ) result = await session.execute(stmt) # CursorResult.rowcount is set on UPDATE; the typed `Result` superclass omits it. rowcount = getattr(result, "rowcount", 0) or 0 if rowcount == 0: raise LookupError( "no pending investigation for " f"correlation_id={correlation_id} subreddit_id={subreddit_id}" ) async def get_investigation_by_correlation( session: AsyncSession, *, correlation_id: str, subreddit_id: str ) -> m.Investigation | None: """Eagerly loads `.evidence` so callers can iterate it after the session closes.""" return ( await session.execute( select(m.Investigation) .where( m.Investigation.correlation_id == correlation_id, m.Investigation.subreddit_id == subreddit_id, ) .options(selectinload(m.Investigation.evidence)) ) ).scalar_one_or_none() # === Feedback ========================================================== async def record_feedback(session: AsyncSession, *, feedback: FeedbackInput) -> None: session.add( m.Feedback( correlation_id=feedback.correlation_id, subreddit_id=feedback.subreddit_id, target_id=feedback.target_id, mod_action=feedback.mod_action, raw_action=feedback.raw_action, moderator_id=feedback.moderator_id, moderator_name=feedback.moderator_name, source=feedback.source, aligned=feedback.aligned, ) ) await session.flush() async def list_recent_feedback_for_subreddit( session: AsyncSession, *, subreddit_id: str, limit: int = 100 ) -> Sequence[m.Feedback]: """Used by the nightly calibration batch (post-MVP).""" return ( await session.execute( select(m.Feedback) .where(m.Feedback.subreddit_id == subreddit_id) .order_by(m.Feedback.created_at.desc()) .limit(limit) ) ).scalars().all() # === Audit log ========================================================= async def append_audit( # noqa: PLR0913 — append-only audit takes the full event shape session: AsyncSession, *, subreddit_id: str, event_type: str, actor: str = "system", correlation_id: str | None = None, detail: dict[str, object] | None = None, ) -> None: session.add( m.AuditLog( subreddit_id=subreddit_id, event_type=event_type, actor=actor, correlation_id=correlation_id, detail=detail or {}, ) ) await session.flush() async def list_prior_actions_on_user( session: AsyncSession, *, subreddit_id: str, author_id: str, limit: int = 5, ) -> Sequence[m.Investigation]: """Return last N completed investigations on `author_id` in this subreddit. Used by the `prior_actions` tool (I-3.2, docs/04-InvestigationEngine.md §5.3.4). Ordered newest-first. Only completed investigations with a verdict. """ return ( await session.execute( select(m.Investigation) .where( m.Investigation.subreddit_id == subreddit_id, m.Investigation.target_author_id == author_id, m.Investigation.status == "completed", ) .order_by(m.Investigation.completed_at.desc()) .limit(limit) ) ).scalars().all()