| """DataAccessToolInvoker — the data-access tool family (KM-465 / KM-630). |
| |
| Implements the `ToolInvoker` Protocol (src/agents/slow_path/invoker.py) for the |
| data-access family. Unlike the stateless `AnalyticsToolInvoker`, these tools |
| read the user's catalog / sources, so the invoker is constructed per-request |
| with the authenticated `user_id` and its dependencies (dependency injection — |
| the runtime/Coordinator supplies them; INV-7 keeps the agent layer |
| tool-agnostic). |
| |
| Tools implemented here: |
| - `check_data` — structured data sources (DB + tabular). No `source_id` |
| → list sources (id, name, type, table count); with a |
| `source_id` → that source's tables/columns (one row per |
| column, metadata only — exposes `pii_flag`, never |
| sample values). |
| - `check_knowledge` — the user's unstructured sources / documents (id, name, |
| type). |
| - `retrieve_data` — runs a pre-built `QueryIR` (validate -> dispatch -> |
| execute, skipping the planner) and returns rows as |
| `ToolOutput(kind="table")` — the Pattern A handoff the |
| `analyze_*` tools consume. |
| - `retrieve_knowledge` — dense retrieval over unstructured sources, returns |
| `ToolOutput(kind="documents")`. |
| |
| Frozen guarantee (§8.4): **never throws.** Any failure returns |
| `ToolOutput(kind="error", error=...)`. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from collections.abc import Callable |
| from decimal import Decimal |
| from typing import Any, Protocol |
|
|
| from pydantic import ValidationError |
|
|
| from src.catalog.models import Catalog |
| from src.catalog.reader import CatalogReader |
| from src.query.executor.dispatcher import ExecutorDispatcher |
| from src.query.ir.models import QueryIR |
| from src.query.ir.validator import IRValidationError, IRValidator |
| from src.retrieval.base import RetrievalResult |
| from src.tools.contracts import ToolOutput |
|
|
# Callable that builds an ExecutorDispatcher for a user's structured Catalog.
# DataAccessToolInvoker accepts one as an injectable dependency and falls back
# to the ExecutorDispatcher class itself when none is supplied.
DispatcherFactory = Callable[[Catalog], ExecutorDispatcher]
|
|
| |
| |
| |
| |
| |
| |
# Every tool name DataAccessToolInvoker.invoke knows how to dispatch.
DATA_ACCESS_TOOLS: frozenset[str] = frozenset(
    (
        "check_data",
        "check_knowledge",
        "retrieve_data",
        "retrieve_knowledge",
    )
)
|
|
|
|
class Retriever(Protocol):
    """Structural (static-typing) interface for the retrieval dependency.

    Any object exposing an async ``retrieve`` method with this signature can be
    passed as ``document_retriever``; nothing else is required of it.
    """

    async def retrieve(
        self,
        query: str,
        user_id: str,
        k: int = 5,
    ) -> list[RetrievalResult]:
        """Return retrieval results for ``query`` within ``user_id``'s sources.

        ``k`` is the requested number of results.
        """
|
|
|
|
class DataAccessToolInvoker:
    """Never-throwing invoker for the data-access tool family (implements ToolInvoker).

    Serves both catalog introspection (``check_data``, ``check_knowledge``) and
    retrieval (``retrieve_data``, ``retrieve_knowledge``) for a single user.
    Failures surface as ``ToolOutput(kind="error", error=...)`` rather than as
    exceptions.
    """

    # Chunk count `retrieve_knowledge` falls back to when `top_k` is absent,
    # unparseable, non-finite, or not a positive integer.
    _DEFAULT_TOP_K = 5

    def __init__(
        self,
        user_id: str,
        catalog_reader: CatalogReader,
        *,
        ir_validator: IRValidator | None = None,
        dispatcher_factory: DispatcherFactory | None = None,
        document_retriever: Retriever | None = None,
    ) -> None:
        """Bind the invoker to one authenticated user and its dependencies.

        Args:
            user_id: Authenticated user; every catalog read and retrieval call
                is made with it.
            catalog_reader: Reads the user's structured / unstructured catalogs.
            ir_validator: Validates a QueryIR against the structured catalog;
                defaults to a fresh ``IRValidator()``.
            dispatcher_factory: Builds an ``ExecutorDispatcher`` for a catalog;
                defaults to the ``ExecutorDispatcher`` class itself.
            document_retriever: Backend for ``retrieve_knowledge``. When None,
                the module-level ``retrieval_router`` is imported lazily at
                call time.
        """
        self._user_id = user_id
        self._reader = catalog_reader
        self._validator = ir_validator or IRValidator()
        self._dispatcher_factory: DispatcherFactory = (
            dispatcher_factory or ExecutorDispatcher
        )
        self._retriever = document_retriever

    async def invoke(self, tool_name: str, args: dict[str, Any]) -> ToolOutput:
        """Run ``tool_name`` with ``args`` and return its ToolOutput.

        Unknown tool names and any ``Exception`` raised while handling the call
        are converted to ``ToolOutput(kind="error")``. Exceptions outside the
        ``Exception`` hierarchy (e.g. ``asyncio.CancelledError`` on Python
        3.8+) are not caught and still propagate.
        """
        # Treat a missing payload as "no arguments": check_data/check_knowledge
        # are valid without any, and the other tools then report the missing
        # key as an error output instead of failing on `None.get`.
        if args is None:
            args = {}
        try:
            if tool_name == "check_data":
                return await self._check_data(args)
            if tool_name == "check_knowledge":
                return await self._check_knowledge()
            if tool_name == "retrieve_data":
                return await self._retrieve_data(args)
            if tool_name == "retrieve_knowledge":
                return await self._retrieve_knowledge(args)
            return ToolOutput(
                tool=tool_name, kind="error", error=f"unknown tool {tool_name!r}"
            )
        except Exception as exc:
            return ToolOutput(
                tool=tool_name, kind="error", error=f"{type(exc).__name__}: {exc}"
            )

    async def _check_data(self, args: dict[str, Any]) -> ToolOutput:
        """Inspect the user's structured data sources (DB + tabular).

        No `source_id` → an overview: one row per structured source (id, name,
        type, table count). With a `source_id` → that source's schema: one row
        per column across its tables.

        Pattern A note: schema is catalog metadata only — never returns row
        data or PII sample values (only the `pii_flag` boolean per column).
        Unstructured documents are covered by `check_knowledge`.
        """
        structured = await self._reader.read(self._user_id, "structured")
        source_id = args.get("source_id")

        # Overview mode: missing or empty source_id lists every source.
        if not source_id:
            rows = [
                [s.source_id, s.name, s.source_type, len(s.tables)]
                for s in structured.sources
            ]
            return ToolOutput(
                tool="check_data",
                kind="table",
                columns=["source_id", "name", "source_type", "table_count"],
                rows=rows,
                meta={"source_count": len(structured.sources)},
            )

        source = next(
            (s for s in structured.sources if s.source_id == source_id), None
        )
        if source is None:
            return ToolOutput(
                tool="check_data",
                kind="error",
                error=f"structured source {source_id!r} not found",
            )

        # Schema mode: flatten tables x columns into one row per column.
        rows = [
            [
                t.table_id,
                t.name,
                t.row_count,
                c.column_id,
                c.name,
                c.data_type,
                c.nullable,
                c.pii_flag,
            ]
            for t in source.tables
            for c in t.columns
        ]
        return ToolOutput(
            tool="check_data",
            kind="table",
            columns=[
                "table_id",
                "table_name",
                "table_row_count",
                "column_id",
                "column_name",
                "data_type",
                "nullable",
                "pii_flag",
            ],
            rows=rows,
            meta={
                "source_id": source.source_id,
                "source_name": source.name,
                "source_type": source.source_type,
                "table_count": len(source.tables),
                "column_count": len(rows),
            },
        )

    async def _check_knowledge(self) -> ToolOutput:
        """List the user's unstructured sources (documents).

        Documents have no column schema to drill into, so there is no
        `source_id` mode — reading document content is `retrieve_knowledge`'s
        job.
        """
        unstructured = await self._reader.read(self._user_id, "unstructured")
        rows = [[s.source_id, s.name, s.source_type] for s in unstructured.sources]
        return ToolOutput(
            tool="check_knowledge",
            kind="table",
            columns=["source_id", "name", "source_type"],
            rows=rows,
            meta={"source_count": len(unstructured.sources)},
        )

    async def _retrieve_data(self, args: dict[str, Any]) -> ToolOutput:
        """Run one validated, single-table QueryIR and return rows as a table.

        This is the spine of the slow path (Pattern A): the `analyze_*` tools
        take this output as their `data` arg. We receive an already-built `ir`
        from the Planner (never SQL, never an NL question), so we skip the
        planner and run validate -> dispatch -> execute directly (the tail of
        QueryService.run). Output is `kind="table"` with `columns` + `rows`
        (rows are list[list], converted from the executor's list[dict]).
        """
        raw = args.get("ir")
        if raw is None:
            return ToolOutput(
                tool="retrieve_data", kind="error", error="missing 'ir' argument"
            )

        # Accept either a ready QueryIR or its dict form.
        try:
            ir = raw if isinstance(raw, QueryIR) else QueryIR.model_validate(raw)
        except ValidationError as exc:
            return ToolOutput(
                tool="retrieve_data", kind="error", error=f"invalid IR: {exc}"
            )

        catalog = await self._reader.read(self._user_id, "structured")

        try:
            self._validator.validate(ir, catalog)
        except IRValidationError as exc:
            return ToolOutput(
                tool="retrieve_data",
                kind="error",
                error=f"IR validation failed: {exc}",
            )

        dispatcher = self._dispatcher_factory(catalog)
        executor = dispatcher.pick(ir)
        result = await executor.run(ir)

        if result.error:
            return ToolOutput(
                tool="retrieve_data", kind="error", error=result.error
            )

        # Executor rows are dicts keyed by column name; project them into
        # positional lists in `result.columns` order, coercing Decimal -> float
        # via _json_safe so downstream consumers get plain numbers.
        rows = [
            [_json_safe(row.get(c)) for c in result.columns] for row in result.rows
        ]
        return ToolOutput(
            tool="retrieve_data",
            kind="table",
            columns=result.columns,
            rows=rows,
            meta={
                "source_id": result.source_id,
                "source_name": result.source_name,
                "table_id": result.table_id,
                "table_name": result.table_name,
                "backend": result.backend,
                "row_count": result.row_count,
                "truncated": result.truncated,
                "elapsed_ms": result.elapsed_ms,
            },
        )

    async def _retrieve_knowledge(self, args: dict[str, Any]) -> ToolOutput:
        """Dense-retrieve relevant chunks from the user's unstructured sources.

        Pulls qualitative context (PDF/DOCX/TXT) for a natural-language `query`
        via the retrieval router. `top_k` caps the number of chunks (a missing,
        unparseable, non-finite or non-positive value falls back to the
        default); optional `source_id` scopes to one source (best-effort
        metadata filter — the router itself does not yet scope by source, so
        this prunes the results).

        TODO(retrieval scoping): the Planner few-shot has no `retrieve_knowledge`
        example, so `source_id` is rarely emitted today and this post-filter is
        adequate. If source-scoped retrieval becomes common, push scoping down
        into RetrievalRouter.retrieve()/DocumentRetriever (WHERE
        cmetadata->>'source_id' = :source_id) and drop this post-filter — more
        correct than pruning an already-top_k'd unscoped result set.
        """
        query = args.get("query")
        if not isinstance(query, str) or not query.strip():
            return ToolOutput(
                tool="retrieve_knowledge",
                kind="error",
                error="missing 'query' argument",
            )

        top_k = self._parse_top_k(args.get("top_k"))
        source_id = args.get("source_id")

        retriever = self._retriever
        if retriever is None:
            # Imported lazily so constructing the invoker does not pull in the
            # retrieval stack when an explicit retriever is injected.
            from src.retrieval.router import retrieval_router

            retriever = retrieval_router

        results = await retriever.retrieve(query, self._user_id, top_k)
        if source_id:
            results = [r for r in results if _result_source_id(r) == source_id]

        documents = [
            {
                "content": r.content,
                "score": r.score,
                "source_type": r.source_type,
                "metadata": r.metadata,
            }
            for r in results
        ]
        return ToolOutput(
            tool="retrieve_knowledge",
            kind="documents",
            value=documents,
            meta={
                "count": len(documents),
                "query": query,
                "top_k": top_k,
                "source_id": source_id,
            },
        )

    @classmethod
    def _parse_top_k(cls, raw: Any) -> int:
        """Coerce a tool-supplied `top_k` into a positive int.

        Returns ``cls._DEFAULT_TOP_K`` when ``raw`` is None, cannot be
        converted by ``int()`` (TypeError/ValueError, including NaN), is
        infinite (OverflowError), or converts to zero or a negative number.
        """
        if raw is None:
            return cls._DEFAULT_TOP_K
        try:
            top_k = int(raw)
        except (TypeError, ValueError, OverflowError):
            return cls._DEFAULT_TOP_K
        # Zero would silently return nothing and a negative count is rejected
        # by some backends, so neither is forwarded to the retriever.
        return top_k if top_k > 0 else cls._DEFAULT_TOP_K
|
|
|
|
def _json_safe(value: Any) -> Any:
    """Turn a ``decimal.Decimal`` cell value into a plain ``float``.

    Database drivers hand back NUMERIC/DECIMAL columns as ``Decimal``, which
    ``json`` cannot serialize and which does not mix with ``float`` arithmetic.
    Any value that is not a ``Decimal`` is returned as-is (same object).
    """
    return float(value) if isinstance(value, Decimal) else value
|
|
|
|
def _result_source_id(result: RetrievalResult) -> str | None:
    """Return the source_id recorded in a retrieval result's metadata, if any.

    The chunk metadata schema is owned by the Go ingestion service, which may
    put the id at the top level or nest it under a "data" mapping. A string
    top-level value wins; otherwise the nested value is tried. Non-string
    values are ignored, and None is returned when no string id is present.
    """
    metadata = result.metadata or {}
    candidate = metadata.get("source_id")
    if not isinstance(candidate, str):
        payload = metadata.get("data")
        candidate = (
            payload.get("source_id") if isinstance(payload, dict) else None
        )
    return candidate if isinstance(candidate, str) else None
|
|