"""DataAccessToolInvoker — the data-access tool family (KM-465 / KM-630). Implements the `ToolInvoker` Protocol (src/agents/slow_path/invoker.py) for the data-access family. Unlike the stateless `AnalyticsToolInvoker`, these tools read the user's catalog / sources, so the invoker is constructed per-request with the authenticated `user_id` and its dependencies (dependency injection — the runtime/Coordinator supplies them; INV-7 keeps the agent layer tool-agnostic). Tools implemented here: - `check_data` — structured data sources (DB + tabular). No `source_id` → list sources (id, name, type, table count); with a `source_id` → that source's tables/columns (one row per column, metadata only — exposes `pii_flag`, never sample values). - `check_knowledge` — the user's unstructured sources / documents (id, name, type). - `retrieve_data` — runs a pre-built `QueryIR` (validate -> dispatch -> execute, skipping the planner) and returns rows as `ToolOutput(kind="table")` — the Pattern A handoff the `analyze_*` tools consume. - `retrieve_knowledge` — dense retrieval over unstructured sources, returns `ToolOutput(kind="documents")`. Frozen guarantee (§8.4): **never throws.** Any failure returns `ToolOutput(kind="error", error=...)`. """ from __future__ import annotations from collections.abc import Callable from decimal import Decimal from typing import Any, Protocol from pydantic import ValidationError from src.catalog.models import Catalog from src.catalog.reader import CatalogReader from src.query.executor.dispatcher import ExecutorDispatcher from src.query.ir.models import QueryIR from src.query.ir.validator import IRValidationError, IRValidator from src.retrieval.base import RetrievalResult from src.tools.contracts import ToolOutput DispatcherFactory = Callable[[Catalog], ExecutorDispatcher] # Canonical set of data-access tool names — the single source of truth for which # tools this invoker serves. `CompositeToolInvoker` imports it to route by name; # the planner registry should derive its data-access spec names from it (agent -> # tool is the correct dependency direction). Defining it once here means # adding/renaming a data-access tool can't silently drift the router out of sync # from the registry (R11). Must match the names in `DataAccessToolInvoker.invoke`. DATA_ACCESS_TOOLS: frozenset[str] = frozenset( {"check_data", "check_knowledge", "retrieve_data", "retrieve_knowledge"} ) class Retriever(Protocol): """Minimal interface this invoker needs from the retrieval layer.""" async def retrieve( self, query: str, user_id: str, k: int = 5 ) -> list[RetrievalResult]: ... class DataAccessToolInvoker: """Never-throwing invoker for catalog-introspection tools (implements ToolInvoker).""" def __init__( self, user_id: str, catalog_reader: CatalogReader, *, ir_validator: IRValidator | None = None, dispatcher_factory: DispatcherFactory | None = None, document_retriever: Retriever | None = None, ) -> None: self._user_id = user_id self._reader = catalog_reader # retrieve_data deps — injectable so tests need no real LLM/DB. The # validator is stateless; the dispatcher is built per-call from the # request's catalog (executors are picked by source_type). self._validator = ir_validator or IRValidator() self._dispatcher_factory: DispatcherFactory = ( dispatcher_factory or ExecutorDispatcher ) # retrieve_knowledge dep — the module singleton by default, injectable # for tests (the real one pulls PGVector + Redis). Lazy-imported on first # use so importing this module stays cheap. self._retriever = document_retriever async def invoke(self, tool_name: str, args: dict[str, Any]) -> ToolOutput: try: if tool_name == "check_data": return await self._check_data(args) if tool_name == "check_knowledge": return await self._check_knowledge() if tool_name == "retrieve_data": return await self._retrieve_data(args) if tool_name == "retrieve_knowledge": return await self._retrieve_knowledge(args) return ToolOutput( tool=tool_name, kind="error", error=f"unknown tool {tool_name!r}" ) except Exception as exc: # noqa: BLE001 — never-throw seam (§8.4) return ToolOutput( tool=tool_name, kind="error", error=f"{type(exc).__name__}: {exc}" ) async def _check_data(self, args: dict[str, Any]) -> ToolOutput: """Inspect the user's structured data sources (DB + tabular). No `source_id` → an overview: one row per structured source (id, name, type, table count). With a `source_id` → that source's schema: one row per column across its tables. Pattern A note: schema is catalog metadata only — never returns row data or PII sample values (only the `pii_flag` boolean per column). Unstructured documents are covered by `check_knowledge`. """ structured = await self._reader.read(self._user_id, "structured") source_id = args.get("source_id") if not source_id: rows = [ [s.source_id, s.name, s.source_type, len(s.tables)] for s in structured.sources ] return ToolOutput( tool="check_data", kind="table", columns=["source_id", "name", "source_type", "table_count"], rows=rows, meta={"source_count": len(structured.sources)}, ) source = next( (s for s in structured.sources if s.source_id == source_id), None ) if source is None: return ToolOutput( tool="check_data", kind="error", error=f"structured source {source_id!r} not found", ) rows = [ [ t.table_id, t.name, t.row_count, c.column_id, c.name, c.data_type, c.nullable, c.pii_flag, ] for t in source.tables for c in t.columns ] return ToolOutput( tool="check_data", kind="table", columns=[ "table_id", "table_name", "table_row_count", "column_id", "column_name", "data_type", "nullable", "pii_flag", ], rows=rows, meta={ "source_id": source.source_id, "source_name": source.name, "source_type": source.source_type, "table_count": len(source.tables), "column_count": len(rows), }, ) async def _check_knowledge(self) -> ToolOutput: """List the user's unstructured sources (documents). Documents have no column schema to drill into, so there is no `source_id` mode — reading document content is `retrieve_knowledge`'s job. """ unstructured = await self._reader.read(self._user_id, "unstructured") rows = [[s.source_id, s.name, s.source_type] for s in unstructured.sources] return ToolOutput( tool="check_knowledge", kind="table", columns=["source_id", "name", "source_type"], rows=rows, meta={"source_count": len(unstructured.sources)}, ) async def _retrieve_data(self, args: dict[str, Any]) -> ToolOutput: """Run one validated, single-table QueryIR and return rows as a table. This is the spine of the slow path (Pattern A): the `analyze_*` tools take this output as their `data` arg. We receive an already-built `ir` from the Planner (never SQL, never an NL question), so we skip the planner and run validate -> dispatch -> execute directly (the tail of QueryService.run). Output is `kind="table"` with `columns` + `rows` (rows are list[list], converted from the executor's list[dict]). """ raw = args.get("ir") if raw is None: return ToolOutput( tool="retrieve_data", kind="error", error="missing 'ir' argument" ) try: ir = raw if isinstance(raw, QueryIR) else QueryIR.model_validate(raw) except ValidationError as exc: return ToolOutput( tool="retrieve_data", kind="error", error=f"invalid IR: {exc}" ) catalog = await self._reader.read(self._user_id, "structured") try: self._validator.validate(ir, catalog) except IRValidationError as exc: return ToolOutput( tool="retrieve_data", kind="error", error=f"IR validation failed: {exc}", ) dispatcher = self._dispatcher_factory(catalog) executor = dispatcher.pick(ir) result = await executor.run(ir) if result.error: return ToolOutput( tool="retrieve_data", kind="error", error=result.error ) # QueryResult.rows is list[dict]; ToolOutput.rows is list[list] ordered # by `columns` so downstream materialization is positional. DB NUMERIC # columns arrive as `Decimal` (asyncpg) — coerce to float here so the # output is JSON-serializable (SSE / analysis_record persistence) and # plays nicely with the float math in the analyze_* tools. rows = [ [_json_safe(row.get(c)) for c in result.columns] for row in result.rows ] return ToolOutput( tool="retrieve_data", kind="table", columns=result.columns, rows=rows, meta={ "source_id": result.source_id, "source_name": result.source_name, "table_id": result.table_id, "table_name": result.table_name, "backend": result.backend, "row_count": result.row_count, "truncated": result.truncated, "elapsed_ms": result.elapsed_ms, }, ) async def _retrieve_knowledge(self, args: dict[str, Any]) -> ToolOutput: """Dense-retrieve relevant chunks from the user's unstructured sources. Pulls qualitative context (PDF/DOCX/TXT) for a natural-language `query` via the retrieval router. `top_k` caps the number of chunks; optional `source_id` scopes to one source (best-effort metadata filter — the router itself does not yet scope by source, so this prunes the results). TODO(retrieval scoping): the Planner few-shot has no `retrieve_knowledge` example, so `source_id` is rarely emitted today and this post-filter is adequate. If source-scoped retrieval becomes common, push scoping down into RetrievalRouter.retrieve()/DocumentRetriever (WHERE cmetadata->>'source_id' = :source_id) and drop this post-filter — more correct than pruning an already-top_k'd unscoped result set. """ query = args.get("query") if not isinstance(query, str) or not query.strip(): return ToolOutput( tool="retrieve_knowledge", kind="error", error="missing 'query' argument", ) try: top_k = int(args.get("top_k", 5)) except (TypeError, ValueError): top_k = 5 source_id = args.get("source_id") retriever = self._retriever if retriever is None: from src.retrieval.router import retrieval_router retriever = retrieval_router results = await retriever.retrieve(query, self._user_id, top_k) if source_id: results = [r for r in results if _result_source_id(r) == source_id] documents = [ { "content": r.content, "score": r.score, "source_type": r.source_type, "metadata": r.metadata, } for r in results ] return ToolOutput( tool="retrieve_knowledge", kind="documents", value=documents, meta={ "count": len(documents), "query": query, "top_k": top_k, "source_id": source_id, }, ) def _json_safe(value: Any) -> Any: """Coerce DB scalar types that JSON can't represent into plain Python. DB drivers return NUMERIC/DECIMAL as `decimal.Decimal`, which is neither JSON-serializable nor mixable with `float` math. Convert those to `float`; everything else passes through unchanged. """ if isinstance(value, Decimal): return float(value) return value def _result_source_id(result: RetrievalResult) -> str | None: """Best-effort extraction of a source_id from a retrieval result's metadata. The chunk metadata schema is owned by the Go ingestion service; the key may live at the top level or nested under "data". Returns None if absent. """ meta = result.metadata or {} top = meta.get("source_id") if isinstance(top, str): return top data = meta.get("data") if isinstance(data, dict): nested = data.get("source_id") if isinstance(nested, str): return nested return None