ishaq101's picture
feat/Knowledge & Data Tools (#3)
0721bb4
Raw
History Blame
14 kB
"""DataAccessToolInvoker — the data-access tool family (KM-465 / KM-630).
Implements the `ToolInvoker` Protocol (src/agents/slow_path/invoker.py) for the
data-access family. Unlike the stateless `AnalyticsToolInvoker`, these tools
read the user's catalog / sources, so the invoker is constructed per-request
with the authenticated `user_id` and its dependencies (dependency injection —
the runtime/Coordinator supplies them; INV-7 keeps the agent layer
tool-agnostic).
Tools implemented here:

- `check_data` — structured data sources (DB + tabular). No `source_id`
  → list sources (id, name, type, table count); with a `source_id` → that
  source's tables/columns (one row per column, metadata only — exposes
  `pii_flag`, never sample values).
- `check_knowledge` — the user's unstructured sources / documents (id, name,
  type).
- `retrieve_data` — runs a pre-built `QueryIR` (validate -> dispatch ->
  execute, skipping the planner) and returns rows as
  `ToolOutput(kind="table")` — the Pattern A handoff the `analyze_*` tools
  consume.
- `retrieve_knowledge` — dense retrieval over unstructured sources, returns
  `ToolOutput(kind="documents")`.

Frozen guarantee (§8.4): **never throws.** Any failure returns
`ToolOutput(kind="error", error=...)`.
"""
from __future__ import annotations
from collections.abc import Callable
from decimal import Decimal
from typing import Any, Protocol
from pydantic import ValidationError
from src.catalog.models import Catalog
from src.catalog.reader import CatalogReader
from src.query.executor.dispatcher import ExecutorDispatcher
from src.query.ir.models import QueryIR
from src.query.ir.validator import IRValidationError, IRValidator
from src.retrieval.base import RetrievalResult
from src.tools.contracts import ToolOutput
# Callable that turns a `Catalog` into an `ExecutorDispatcher`. The invoker
# calls it once per `retrieve_data` request with the catalog it has just read,
# which is why it is injected as a factory rather than a ready-made dispatcher.
DispatcherFactory = Callable[[Catalog], ExecutorDispatcher]
# The one authoritative list of tool names served by the data-access invoker.
# `CompositeToolInvoker` imports it to route by name, and the planner registry
# should derive its data-access spec names from it (agent -> tool is the
# correct dependency direction). Keeping the names in a single place means that
# adding or renaming a data-access tool cannot silently leave the router out of
# sync with the registry (R11). The entries must match the names handled by
# `DataAccessToolInvoker.invoke`.
DATA_ACCESS_TOOLS: frozenset[str] = frozenset(
    (
        "check_data",
        "check_knowledge",
        "retrieve_data",
        "retrieve_knowledge",
    )
)
class Retriever(Protocol):
    """Structural type for the retrieval dependency of the data-access invoker.

    Only the single coroutine the invoker actually calls is declared, so any
    object exposing a compatible async `retrieve` can be injected.
    """

    async def retrieve(
        self,
        query: str,
        user_id: str,
        k: int = 5,
    ) -> list[RetrievalResult]:
        """Return retrieval results for `query` on behalf of `user_id`."""
        ...
class DataAccessToolInvoker:
    """Per-request, never-raising invoker for the data-access tools (implements ToolInvoker).

    An instance is bound to one `user_id` and one catalog reader. All tool
    calls enter through `invoke`, which turns any exception raised by a handler
    into `ToolOutput(kind="error")`.
    """

    def __init__(
        self,
        user_id: str,
        catalog_reader: CatalogReader,
        *,
        ir_validator: IRValidator | None = None,
        dispatcher_factory: DispatcherFactory | None = None,
        document_retriever: Retriever | None = None,
    ) -> None:
        """Bind the invoker to a user and wire up its collaborators.

        The keyword-only arguments are optional: a missing validator is
        replaced by a fresh `IRValidator`, a missing dispatcher factory by
        `ExecutorDispatcher`, and a missing retriever is resolved lazily on the
        first `retrieve_knowledge` call.
        """
        self._user_id = user_id
        self._reader = catalog_reader

        # Dependencies of retrieve_data. Both are injectable so tests need no
        # real LLM/DB. The validator carries no state; the dispatcher is
        # created on each call from that request's catalog because executors
        # are chosen by source_type.
        self._validator = ir_validator if ir_validator else IRValidator()
        self._dispatcher_factory: DispatcherFactory = (
            dispatcher_factory if dispatcher_factory else ExecutorDispatcher
        )

        # Dependency of retrieve_knowledge. Left as None here and resolved to
        # the module singleton on first use, so importing this module stays
        # cheap (the real retriever pulls PGVector + Redis). Tests inject their
        # own.
        self._retriever = document_retriever

    async def invoke(self, tool_name: str, args: dict[str, Any]) -> ToolOutput:
        """Run the handler registered for `tool_name` and return its output.

        An unrecognised name yields an error output. Any exception escaping a
        handler is caught and reported as an error output as well, so this
        method does not raise on `Exception` subclasses.
        """
        # (name, zero-argument starter) pairs; the starter is only called for
        # the matching name, so no coroutine is created for the other tools.
        routes = (
            ("check_data", lambda: self._check_data(args)),
            ("check_knowledge", self._check_knowledge),
            ("retrieve_data", lambda: self._retrieve_data(args)),
            ("retrieve_knowledge", lambda: self._retrieve_knowledge(args)),
        )
        try:
            for name, start in routes:
                if tool_name == name:
                    return await start()
            return ToolOutput(
                tool=tool_name,
                kind="error",
                error=f"unknown tool {tool_name!r}",
            )
        except Exception as exc:  # noqa: BLE001 — never-throw seam (§8.4)
            failure = f"{type(exc).__name__}: {exc}"
            return ToolOutput(tool=tool_name, kind="error", error=failure)

    async def _check_data(self, args: dict[str, Any]) -> ToolOutput:
        """Describe the user's structured data sources (DB + tabular).

        Without a `source_id` the result is an overview table: one row per
        structured source (id, name, type, table count). With a `source_id`
        the result is that source's schema: one row per column across all of
        its tables, or an error output when no source has that id.

        Pattern A note: only catalog metadata is returned — no row data and no
        PII sample values, just the boolean `pii_flag` of each column.
        Unstructured documents are handled by `check_knowledge`.
        """
        catalog = await self._reader.read(self._user_id, "structured")
        wanted = args.get("source_id")
        sources = catalog.sources

        if not wanted:
            overview = []
            for src in sources:
                overview.append(
                    [src.source_id, src.name, src.source_type, len(src.tables)]
                )
            return ToolOutput(
                tool="check_data",
                kind="table",
                columns=["source_id", "name", "source_type", "table_count"],
                rows=overview,
                meta={"source_count": len(sources)},
            )

        # First source whose id equals the requested one, if any.
        for candidate in sources:
            if candidate.source_id == wanted:
                source = candidate
                break
        else:
            return ToolOutput(
                tool="check_data",
                kind="error",
                error=f"structured source {wanted!r} not found",
            )

        schema_rows = []
        for table in source.tables:
            for column in table.columns:
                schema_rows.append(
                    [
                        table.table_id,
                        table.name,
                        table.row_count,
                        column.column_id,
                        column.name,
                        column.data_type,
                        column.nullable,
                        column.pii_flag,
                    ]
                )

        header = [
            "table_id",
            "table_name",
            "table_row_count",
            "column_id",
            "column_name",
            "data_type",
            "nullable",
            "pii_flag",
        ]
        summary = {
            "source_id": source.source_id,
            "source_name": source.name,
            "source_type": source.source_type,
            "table_count": len(source.tables),
            "column_count": len(schema_rows),
        }
        return ToolOutput(
            tool="check_data",
            kind="table",
            columns=header,
            rows=schema_rows,
            meta=summary,
        )

    async def _check_knowledge(self) -> ToolOutput:
        """Return a table of the user's unstructured sources (documents).

        There is no `source_id` mode because documents have no column schema
        to drill into; reading document content belongs to
        `retrieve_knowledge`.
        """
        catalog = await self._reader.read(self._user_id, "unstructured")
        sources = catalog.sources

        listing = []
        for src in sources:
            listing.append([src.source_id, src.name, src.source_type])

        return ToolOutput(
            tool="check_knowledge",
            kind="table",
            columns=["source_id", "name", "source_type"],
            rows=listing,
            meta={"source_count": len(sources)},
        )

    async def _retrieve_data(self, args: dict[str, Any]) -> ToolOutput:
        """Execute one validated, single-table QueryIR and return its rows as a table.

        This is the spine of the slow path (Pattern A): the `analyze_*` tools
        consume this output as their `data` argument. The Planner hands over an
        already-built `ir` (never SQL, never an NL question), so planning is
        skipped and only validate -> dispatch -> execute runs (the tail of
        QueryService.run). The output is `kind="table"` with `columns` and
        `rows`, where rows are list[list] converted from the executor's
        list[dict].

        A missing `ir`, an `ir` that fails model validation, an IR rejected by
        the validator, or an executor result carrying an error each produce an
        error output.
        """
        payload = args.get("ir")
        if payload is None:
            return ToolOutput(
                tool="retrieve_data", kind="error", error="missing 'ir' argument"
            )

        if isinstance(payload, QueryIR):
            ir = payload
        else:
            try:
                ir = QueryIR.model_validate(payload)
            except ValidationError as exc:
                return ToolOutput(
                    tool="retrieve_data", kind="error", error=f"invalid IR: {exc}"
                )

        catalog = await self._reader.read(self._user_id, "structured")
        try:
            self._validator.validate(ir, catalog)
        except IRValidationError as exc:
            return ToolOutput(
                tool="retrieve_data",
                kind="error",
                error=f"IR validation failed: {exc}",
            )

        executor = self._dispatcher_factory(catalog).pick(ir)
        result = await executor.run(ir)
        if result.error:
            return ToolOutput(
                tool="retrieve_data", kind="error", error=result.error
            )

        # QueryResult.rows is list[dict] while ToolOutput.rows is list[list]
        # ordered by `columns`, so downstream materialization is positional.
        # Each cell goes through `_json_safe` because DB NUMERIC columns arrive
        # as `Decimal` (asyncpg): the output has to be JSON-serializable (SSE /
        # analysis_record persistence) and usable in the float math of the
        # analyze_* tools.
        table_rows = []
        for record in result.rows:
            cells = []
            for column in result.columns:
                cells.append(_json_safe(record.get(column)))
            table_rows.append(cells)

        # Every meta key is copied from the result attribute of the same name.
        meta_fields = (
            "source_id",
            "source_name",
            "table_id",
            "table_name",
            "backend",
            "row_count",
            "truncated",
            "elapsed_ms",
        )
        return ToolOutput(
            tool="retrieve_data",
            kind="table",
            columns=result.columns,
            rows=table_rows,
            meta={field: getattr(result, field) for field in meta_fields},
        )

    async def _retrieve_knowledge(self, args: dict[str, Any]) -> ToolOutput:
        """Dense-retrieve relevant chunks from the user's unstructured sources.

        Fetches qualitative context (PDF/DOCX/TXT) for a natural-language
        `query` through the retrieval router. `top_k` caps the number of
        chunks and falls back to 5 when it cannot be converted to an int. An
        optional `source_id` narrows the result to one source; this is a
        best-effort metadata filter applied after retrieval, because the
        router itself does not yet scope by source.

        TODO(retrieval scoping): the Planner few-shot has no
        `retrieve_knowledge` example, so `source_id` is rarely emitted today
        and this post-filter is adequate. If source-scoped retrieval becomes
        common, push scoping down into
        RetrievalRouter.retrieve()/DocumentRetriever (WHERE
        cmetadata->>'source_id' = :source_id) and drop this post-filter — more
        correct than pruning an already-top_k'd unscoped result set.
        """
        query = args.get("query")
        if not (isinstance(query, str) and query.strip()):
            return ToolOutput(
                tool="retrieve_knowledge",
                kind="error",
                error="missing 'query' argument",
            )

        try:
            top_k = int(args.get("top_k", 5))
        except (TypeError, ValueError):
            top_k = 5
        source_id = args.get("source_id")

        retriever = self._retriever
        if retriever is None:
            # Deferred import: only pay for the retrieval stack when no
            # retriever was injected and the tool is actually used.
            from src.retrieval.router import retrieval_router

            retriever = retrieval_router

        hits = await retriever.retrieve(query, self._user_id, top_k)

        if source_id:
            scoped = []
            for hit in hits:
                if _result_source_id(hit) == source_id:
                    scoped.append(hit)
            hits = scoped

        documents = []
        for hit in hits:
            documents.append(
                dict(
                    content=hit.content,
                    score=hit.score,
                    source_type=hit.source_type,
                    metadata=hit.metadata,
                )
            )

        summary = {
            "count": len(documents),
            "query": query,
            "top_k": top_k,
            "source_id": source_id,
        }
        return ToolOutput(
            tool="retrieve_knowledge",
            kind="documents",
            value=documents,
            meta=summary,
        )
def _json_safe(value: Any) -> Any:
"""Coerce DB scalar types that JSON can't represent into plain Python.
DB drivers return NUMERIC/DECIMAL as `decimal.Decimal`, which is neither
JSON-serializable nor mixable with `float` math. Convert those to `float`;
everything else passes through unchanged.
"""
if isinstance(value, Decimal):
return float(value)
return value
def _result_source_id(result: RetrievalResult) -> str | None:
"""Best-effort extraction of a source_id from a retrieval result's metadata.
The chunk metadata schema is owned by the Go ingestion service; the key may
live at the top level or nested under "data". Returns None if absent.
"""
meta = result.metadata or {}
top = meta.get("source_id")
if isinstance(top, str):
return top
data = meta.get("data")
if isinstance(data, dict):
nested = data.get("source_id")
if isinstance(nested, str):
return nested
return None