File size: 10,374 Bytes
7302343
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
"""ModPilot Investigation Engine — FastAPI entry point.

Spec: docs/Specs.md §10, docs/08-API.md
"""

from __future__ import annotations

from contextlib import asynccontextmanager
from typing import TYPE_CHECKING

from fastapi import FastAPI, Request

from api.config import get_settings
from api.errors import register_error_handlers
from api.middleware import CorrelationIdMiddleware, HmacMiddleware
from api.pipeline import PipelineResult, run_investigation
from api.schemas import InvestigateRequest, InvestigateResponse
from observability.logging import configure_logging, get_logger
from orchestrator.loop import Orchestrator
from orchestrator.prior_actions import PriorActionsTool
from orchestrator.report_velocity import ReportVelocityTool
from orchestrator.thread_context import ThreadContextTool
from orchestrator.tools import ToolRegistry
from orchestrator.user_history import UserHistoryTool
from store.connections import close_postgres, close_redis, open_postgres, open_redis
from store.postgres import (
    append_evidence,
    ensure_subreddit_profile,
    finalize_investigation,
    get_thread_memory,
    get_user_memory,
    make_sessionmaker,
    start_investigation,
    with_session,
)
from store.types import (
    EvidenceRowInput,
    FinalizeInvestigationInput,
    StartInvestigationInput,
)

if TYPE_CHECKING:
    from collections.abc import AsyncIterator


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Own the engine's process-wide resources for the lifetime of ``app``.

    Startup: configure logging, validate settings, open Postgres and Redis,
    build the optional Gemini LLM client, then build the tool registry and
    orchestrator. Everything is published on ``app.state`` for the handlers.

    Shutdown: close Redis, then Postgres, then log ``engine.shutdown``.

    Each connection is guarded by its own ``try/finally`` as soon as it is
    opened. So a failure in any later startup step, or an error while closing
    Redis, still releases whatever was already opened instead of leaking it.
    """
    settings = get_settings()
    configure_logging(level=settings.log_level, env=settings.env)
    logger = get_logger(__name__)
    settings.validate_for_runtime()  # F-0.7 — fail-closed in prod when keys missing
    logger.info(
        "engine.startup",
        env=settings.env,
        model_reasoner=settings.model_reasoner,
        model_summarizer=settings.model_summarizer,
        hmac_enforced=settings.hmac_enforced,
        gemini_configured=bool(settings.gemini_api_key),
    )

    # F-0.6: probe Postgres + Redis at startup.
    app.state.pg = await open_postgres(settings)
    try:
        app.state.redis = await open_redis(settings)
        try:
            app.state.pg_sessions = make_sessionmaker(app.state.pg)

            # LLM client — deferred import to avoid hard google-genai dep at
            # import time. Built before the registry so LLM-using tools
            # (thread_context) can register.
            if settings.gemini_api_key:
                from llm.gemini import GeminiClient  # noqa: PLC0415

                app.state.llm = GeminiClient(settings)
            else:
                app.state.llm = None
                logger.warning("engine.no_llm", reason="GEMINI_API_KEY not set")

            # E-2.11 + I-3.3: build Tool Registry + Orchestrator.
            registry = ToolRegistry()
            registry.register(ReportVelocityTool(app.state.redis))
            registry.register(PriorActionsTool(app.state.pg_sessions))
            registry.register(UserHistoryTool(app.state.pg_sessions))
            if app.state.llm is not None:
                registry.register(ThreadContextTool(app.state.llm, app.state.redis))
            else:
                logger.warning("engine.thread_context_disabled", reason="no LLM client")
            # PolicyMatchTool requires embed + rules_text functions; registered
            # when those are wired (post-MVP). Orchestrator records "skipped"
            # for missing tools.
            app.state.orchestrator = Orchestrator(registry)

            yield
        finally:
            await close_redis(app.state.redis)
    finally:
        # Runs even if close_redis raised, so Postgres is never left open.
        await close_postgres(app.state.pg)
        logger.info("engine.shutdown")


# The ASGI application. ``lifespan`` (defined above) opens Postgres/Redis and
# builds the orchestrator on startup, and closes the connections on shutdown.
app = FastAPI(
    title="ModPilot Investigation Engine",
    version="0.0.1",
    description="Context-aware investigation engine for Reddit moderation",
    lifespan=lifespan,
)

# Middleware order matters: HMAC runs *after* correlation-id is bound,
# so a rejection log carries the request's correlation_id.
# With Starlette's ``add_middleware``, each newly added middleware wraps the
# ones added before it. So the one registered LAST (CorrelationIdMiddleware)
# is outermost and sees the request first. Do not swap these two lines.
app.add_middleware(HmacMiddleware)
app.add_middleware(CorrelationIdMiddleware)

# Install the project's exception handlers on the app; presumably maps engine
# errors to the API's error envelope — see api.errors to confirm.
register_error_handlers(app)


@app.get("/health")
async def health() -> dict[str, object]:
    """Report liveness plus the engine version and configured model identifiers.

    Spec: docs/Specs.md §10.1.

    NOTE(review): no dependency is probed here. ``ok`` is always ``True``,
    so this is a liveness signal only; a readiness check against
    Postgres/Redis would need access to ``app.state``. ``git_sha`` and the
    prompt identifiers are currently fixed placeholders.
    """
    settings = get_settings()
    return {
        "ok": True,
        "data": {
            # Taken from the FastAPI app so the reported version cannot drift
            # from the ``version=`` declared when the app is constructed.
            "engine": app.version,
            "git_sha": "unknown",
            "reasoner_prompt": None,
            "summarizer_prompt": None,
            "model_reasoner": settings.model_reasoner,
            "model_summarizer": settings.model_summarizer,
        },
    }


@app.post("/investigate", response_model=InvestigateResponse)
async def investigate(
    req: InvestigateRequest, request: Request
) -> InvestigateResponse:
    """Investigate one reported item and return the calibrated verdict.

    Pipeline stages: Strategy -> Orchestrator -> Reasoner -> Validator ->
    Calibrator.

    Subreddit, user and thread memory are read first, with cold-start
    defaults when nothing is cached. The pipeline then runs, and the
    investigation plus its evidence rows are written to Postgres before the
    verdict is returned.

    Spec: docs/Specs.md §10.2, docs/04-InvestigationEngine.md §1-§9.
    """
    log = get_logger(__name__).bind(correlation_id=req.correlation_id)
    log.info(
        "investigation.requested",
        subreddit_id=req.subreddit_id,
        target_kind=req.target.kind,
        target_id=req.target.id,
        reporter_count=req.report.reporter_count,
    )

    state = request.app.state
    orchestrator: Orchestrator = state.orchestrator
    llm = state.llm

    # Cold-start defaults; each one is overwritten below only when the
    # corresponding memory row exists.
    personality, region, rules = "balanced", "Global", ""
    cold_start, user_risk_tier = True, "new"
    tier_override, thread_escalated = "auto", False

    sub_id = req.subreddit_id

    # I-3.9: read subreddit_profile, user_memory and thread_memory through
    # one session so the Strategy Selector sees a consistent cached state.
    # The subreddit_profile row is created lazily on first contact: the
    # engine does not receive onAppInstall yet (post-MVP), and without the
    # row the investigation insert would violate its foreign key.
    async with with_session(state.pg_sessions) as session:
        profile = await ensure_subreddit_profile(
            session,
            subreddit_id=sub_id,
            # The real name is unknown engine-side until the AppInstall relay
            # exists, so the id doubles as a safe placeholder.
            name=sub_id,
        )
        if profile is not None:
            personality = profile.personality
            region = profile.region
            rules = profile.rules
            # Fewer than 50 recorded cold-start investigations keeps the
            # subreddit in cold-start mode.
            cold_start = profile.cold_start_count < 50
            tier_override = profile.tier_override

        if author := req.target.author:
            user_mem = await get_user_memory(
                session, subreddit_id=sub_id, user_id=author
            )
            if user_mem is not None:
                user_risk_tier = user_mem.risk_tier

        if thread_id := req.context.thread_id:
            thread_mem = await get_thread_memory(
                session, subreddit_id=sub_id, post_id=thread_id
            )
            if thread_mem is not None:
                # The thread counts as escalated when mods already acted on
                # it, or when an earlier thread_context summary recorded an
                # escalation turn.
                turn = thread_mem.detail.get("escalation_turn")
                had_mod_action = bool(thread_mem.mod_actions_taken)
                thread_escalated = had_mod_action or turn is not None

    result = await run_investigation(
        req=req,
        orchestrator=orchestrator,
        llm=llm,
        personality=personality,
        region=region,
        rules=rules,
        cold_start=cold_start,
        user_risk_tier=user_risk_tier,
        velocity_zscore=0.0,  # TODO(E-3.x): precompute from Redis before pipeline
        rule_match_score=0.0,  # TODO(E-3.x): precompute from embeddings before pipeline
        tier_override=tier_override,
        thread_escalated=thread_escalated,
    )

    # Persist before responding so the stored record matches the verdict sent.
    await _persist(request, req, result)
    return InvestigateResponse(data=result.verdict)


async def _persist(
    request: Request,
    req: InvestigateRequest,
    result: PipelineResult,
) -> None:
    """Store one finished investigation in Postgres using a single session.

    Writes happen in this order:
    1. The investigation row, via ``start_investigation``.
    2. One evidence row per accumulator entry, via ``append_evidence``.
    3. The verdict, via ``finalize_investigation``, identified by the
       request's correlation id and subreddit id.
    """
    sessions = request.app.state.pg_sessions
    sub_id = req.subreddit_id
    target = req.target
    async with with_session(sessions) as session:
        investigation_row = await start_investigation(
            session,
            input_=StartInvestigationInput(
                correlation_id=req.correlation_id,
                subreddit_id=sub_id,
                target_kind=target.kind,
                target_id=target.id,
                target_body=target.body,
                target_author_id=target.author,
                tier=result.tier,
            ),
        )

        for item in result.accumulator.entries():
            row = EvidenceRowInput(
                evidence_id=item.id,
                tool=item.tool,
                summary=item.summary,
                detail=item.detail,
                status=item.status,
                latency_ms=item.latency_ms,
            )
            await append_evidence(
                session,
                investigation=investigation_row,
                subreddit_id=sub_id,
                evidence=row,
            )

        verdict = result.verdict
        breakdown = verdict.confidence_breakdown
        # Flattened into a plain dict whose keys mirror the breakdown's
        # attribute names, in this fixed order.
        breakdown_fields = (
            "llm_self_report",
            "evidence_convergence",
            "subreddit_accuracy",
            "rule_match_strength",
        )
        final_input = FinalizeInvestigationInput(
            risk_tier=verdict.risk_tier,
            recommendation=verdict.recommendation,
            calibrated_confidence=verdict.calibrated_confidence,
            rationale=verdict.rationale,
            confidence_breakdown={
                field: getattr(breakdown, field) for field in breakdown_fields
            },
            model_reasoner=verdict.model_reasoner,
            model_summarizer=verdict.model_summarizer,
            cost_usd=verdict.cost_usd,
            latency_ms=verdict.latency_ms,
            input_tokens=result.input_tokens,
            output_tokens=result.output_tokens,
            validation_flag=verdict.validation_flag,
            degraded=verdict.degraded,
            cold_start=verdict.cold_start,
        )
        await finalize_investigation(
            session,
            correlation_id=req.correlation_id,
            subreddit_id=sub_id,
            verdict=final_input,
        )


# TODO(S-1.6): POST /feedback
# TODO(U-4.7): POST /explain
# TODO(F-0.7): GET /config/{sub_id}