diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md new file mode 100644 index 0000000000000000000000000000000000000000..47f51336c29d5cee0593d608149a4a0a001f8a46 --- /dev/null +++ b/ARCHITECTURE.md @@ -0,0 +1,340 @@ +# Architecture — Data Eyond Agentic Service + +**Last updated**: 2026-05-07 +**Status**: Design phase — folder skeleton in place, implementation in progress + +--- + +## TL;DR + +A catalog-driven AI service for data analysis. Users upload documents and register databases or tabular files; they ask natural-language questions and get answers grounded in their data. + +The architecture has two paths: + +- **Unstructured** (PDF, DOCX, TXT) — dense similarity over prose chunks (the right primitive for free-form text). +- **Structured** (databases, XLSX, CSV, Parquet) — a per-user **data catalog** describes what tables/columns exist; an LLM produces a structured **JSON intermediate representation (IR)** of the user's intent; a deterministic compiler turns the IR into SQL or pandas operations. + +The LLM produces *intent*, not query syntax. Deterministic code does the rest. + +--- + +## 1. Why catalog-driven design + +For a database or spreadsheet, a user's question maps to *known tables and columns* — not to *similar text fragments*. Treating structured data with the same retrieval primitive as prose (chunk + embed + rank top-K) forces the right column to survive a probabilistic ranking lottery before the planner ever sees it. Catalog-based **lookup** is the right primitive instead. + +A central per-user catalog also means: + +- One place to keep table/column descriptions and profiling metadata (built by LLM-free introspection, refreshed when the source changes). +- The query planner sees the user's full data landscape in a single prompt. +- Schema stays stable across user sessions without hitting the source DB on every query. +- New sources auto-update the catalog without re-embedding chunks. + +--- + +## 2. 
Source taxonomy + +``` +Sources +├── Unstructured (pdf, docx, txt) → Cu (prose chunks via DocumentRetriever) +└── Structured + ├── Schema (DB) → Cs (DB tables + columns) + └── Tabular (xlsx, csv, parquet) → Ct (sheets + columns) + Cs ∪ Ct = Data Catalog Context +``` + +- **Cu** = unstructured prose context. Retrieval primitive: dense similarity over chunks. +- **Cs** = DB schema context (tables, columns, descriptions, sample values). +- **Ct** = tabular file context (sheets, columns, descriptions, sample values). +- **Data Catalog Context** = `Cs ∪ Ct`. Passed to the query planner as a single unified view. + +DB vs tabular is **not** a routing concern — it's a per-source attribute (`source_type`) on each catalog entry. The split only matters at execution time (SQL vs pandas). + +--- + +## 3. Routing model + +``` +source_hint ∈ { chat, unstructured, structured } +``` + +- `chat` — no search, conversational reply only +- `unstructured` — DocumentRetriever path (Cu) +- `structured` — catalog-driven path (Cs ∪ Ct → planner → compiler → executor) + +The router commits to one path. Cross-source questions ("compare DB sales vs uploaded customer file") are handled inside the structured path because the planner sees both Cs and Ct in one prompt. + +--- + +## 4. Core architectural decisions + +### 4.1 Catalog as primary context, not retrieval + +For most users (≤50 tables), the entire catalog fits in ~3-5k tokens and is passed verbatim to the planner. No vector search, no BM25, no chunk retrieval. The LLM reads the whole catalog and picks the right table. + +When a user has hundreds of tables, **catalog-level retrieval** (BM25 + table-level vectors with RRF) can be added as a slicer between `CatalogReader` and `Planner`. Deferred until measurably needed. + +### 4.2 JSON IR over raw SQL + +The planner LLM emits a structured JSON IR describing query intent — not a SQL string. A deterministic compiler turns the IR into SQL (per dialect) or pandas/polars operations. 
+ +Benefits: + +- Validatable with Pydantic before execution +- Compiler whitelists allowed operations (no DROP, DELETE, etc.) +- Portable: same IR → SQL (any dialect) / pandas / polars +- Cheaper tokens, easier to debug, trivially testable without an LLM +- LLM cannot emit valid-but-wrong SQL syntax + +### 4.3 Deterministic compiler, not LLM SQL writer + +The LLM produces *intent* (the IR). All actual query construction is deterministic Python. Compiler bugs are reproducible and fixable. Same IR always produces the same query. + +### 4.4 Pipeline stage isolation + +Each stage is its own module with typed input and typed output. No god classes. Stages: `IntentRouter`, `CatalogReader`, `QueryPlanner`, `IRValidator`, `QueryCompiler`, `QueryExecutor`, `ChatbotAgent`. Each is testable in isolation. + +### 4.5 Minimal LLM surface + +LLM calls happen in exactly three places (KM-557 removed `CatalogEnricher`; ingestion is now LLM-free — the planner reads column names, stats, and sample rows directly): + +1. **`IntentRouter`** — once per user message +2. **`QueryPlanner`** — once per structured query (produces the IR) +3. **`ChatbotAgent`** — once per answer (formats the response) + +Compiler and executors are pure code. No LLM in the hot path of query construction. + +--- + +## 5. End-to-end flow + +### Ingestion (when user uploads a file or connects a DB) + +``` +source upload / DB connect + ↓ +introspect schema (DB: information_schema; tabular: file headers + sample rows) + ↓ +validate (Pydantic) + ↓ +write to catalog store (Postgres jsonb in `data_catalog`, keyed by user_id) +``` + +For unstructured files: chunk + embed → PGVector. + +### Query (per user message) + +``` +User message + ↓ +Chat cache check (Redis, 24h TTL) + ↓ miss +Load chat history + ↓ +IntentRouter LLM → needs_search? source_hint? 
+ ↓ + ├── chat → ChatbotAgent → SSE stream + ├── unstructured → DocumentRetriever → ChatbotAgent → SSE stream + └── structured → + CatalogReader (load full Cs ∪ Ct for user) + ↓ + QueryPlanner LLM → JSON IR + ↓ + IRValidator (Pydantic + columns-exist + ops whitelist) + ↓ + QueryCompiler → SQL (schema source) or pandas (tabular source) + ↓ + QueryExecutor (DbExecutor or TabularExecutor) + ↓ + QueryResult + ↓ + ChatbotAgent → SSE stream +``` + +--- + +## 6. Data catalog + +### Storage + +Per-user JSON document, stored as a `jsonb` row in Postgres keyed by `user_id`. + +### Schema (initial scope) + +``` +Catalog +├── user_id, schema_version, generated_at +└── sources[] + └── Source + ├── source_id, source_type, name, description, location_ref, updated_at + └── tables[] + └── Table + ├── table_id, name, description, row_count + └── columns[] + └── Column + ├── column_id, name, data_type, description + ├── nullable + ├── pii_flag + ├── sample_values[] + └── stats: { min, max, distinct_count } | null +``` + +### Best-practice fields deferred + +`description_human`, `synonyms[]`, `tags[]`, `primary_key`, `foreign_keys`, `unit`, `semantic_type`, `example_questions[]`, `schema_hash`, `enrichment_status`. Add when justified by user need. + +### Stable IDs + +`source_id`, `table_id`, `column_id` are stable internal references. `name` fields can change (e.g. column rename in source DB) without invalidating cached IRs. + +### PII handling + +Columns with `pii_flag: true` have `sample_values: null` — real values never enter LLM prompts. Auto-detected at ingestion via name patterns + value regex. + +--- + +## 7. JSON IR + +### Schema (initial scope) + +``` +QueryIR +├── ir_version : "1.0" +├── source_id : str (references catalog) +├── table_id : str (references catalog) +├── select[] : SelectItem +│ ├── { kind: "column", column_id, alias? } +│ └── { kind: "agg", fn, column_id?, alias? 
} +├── filters[] : { column_id, op, value, value_type } +├── group_by[] : column_id +├── order_by[] : { column_id | alias, dir } +└── limit : int | null +``` + +### Whitelisted operators + +``` +Filter ops: = != < <= > >= in not_in is_null is_not_null like between +Agg fns: count count_distinct sum avg min max +``` + +### Validation rules (enforced before execution) + +- `source_id` exists in catalog for this user +- `table_id` belongs to that source +- Every `column_id` exists in that table +- Every `agg.fn` and `filter.op` is whitelisted +- `value_type` consistent with column's `data_type` +- `limit` positive int, ≤ hard cap (e.g. 10000) + +If any rule fails → reject IR → re-prompt planner with error context (max 3 retries). + +### Deferred features + +`having`, `offset`, boolean tree filters (OR/NOT), `distinct`, joins, window functions. Add as user demand proves the limitation. + +--- + +## 8. Executors + +Same input (validated IR), same output (`QueryResult`), different backends. + +### DbExecutor (schema sources) + +``` +IR → SqlCompiler → SQL string + params + ↓ +sqlglot validation (SELECT-only, whitelist tables/columns, LIMIT enforced) + ↓ +asyncpg / pymysql in read-only transaction with timeout (30s) + ↓ +QueryResult +``` + +Identifiers come from catalog (verified at validation time, safe to inline as quoted identifiers). Values are always parameterized — never inlined as strings. + +### TabularExecutor (tabular sources) + +``` +IR → PandasCompiler → operation chain + ↓ +choose strategy by file size: + ≤ 100 MB → eager pandas + 100 MB-1 GB → pyarrow with predicate pushdown + > 1 GB → polars lazy scan + ↓ +execute in asyncio.to_thread (CPU work off the event loop) + ↓ +QueryResult +``` + +Initially eager pandas is sufficient. Add the others when a real file is too big. + +### Shared safety guarantees + +1. IR validated before reaching compiler +2. Compiler is deterministic (no LLM) +3. Identifiers from catalog (trusted) +4. Values parameterized +5. 
sqlglot second-line defence for SQL +6. Read-only at every layer +7. Timeouts and row caps + +--- + +## 9. Implementation scope + +### Initial PR — what ships first + +| Item | Folder | +|---|---| +| Data catalog Pydantic models | `src/catalog/models.py` | +| Catalog ingestion (introspect → validate → store) | `src/catalog/`, `src/pipeline/` | +| `IntentRouter` with 3-way source_hint | `src/agents/` | +| `CatalogReader` (loads full catalog) | `src/catalog/reader.py` | +| `QueryPlanner` LLM call | `src/query/planner/` | +| JSON IR Pydantic models | `src/query/ir/models.py` | +| IR validator | `src/query/ir/validator.py` | + +**Output**: a validated JSON IR object. Execution lands in a follow-up PR. + +### Follow-up PRs + +| PR | Scope | +|---|---| +| 2 | `QueryCompiler` (IR → SQL / pandas) | +| 3 | `QueryExecutor` split: `DbExecutor` + `TabularExecutor` | +| 4 | Retry / self-correction loop on execution failure | +| 5 | Eval harness (golden question→IR→result examples) | +| 6 | Auto PII tagging in catalog | +| Later | Joins in IR, schema drift detection, hybrid catalog search | + +--- + +## 10. Open questions + +| # | Question | Why it matters | +|---|---|---| +| 1 | Catalog storage: JSON file per user vs Postgres `jsonb` row? | Affects ingestion + read performance | +| 2 | Should the catalog also list unstructured files (with descriptions only)? | Gives router unified view of all user sources | +| 3 | Catalog refresh trigger: explicit "rebuild" button, on every upload, or background TTL? | Staleness vs latency tradeoff | +| 4 | Confirm joins are out of initial IR scope? | Limits what user questions can be answered | +| 5 | PII handling for sample_values: mask, synthesize, or skip? | Affects what gets sent to LLM prompts | + +--- + +## 11. 
References + +- `docs/flowchart.html` — interactive end-to-end diagram (open in browser) +- `docs/flowchart.mmd` — mermaid source for the diagram + +--- + +## Glossary + +- **Cu** — unstructured context (prose chunks) +- **Cs** — schema context (DB tables/columns from catalog) +- **Ct** — tabular context (file sheets/columns from catalog) +- **IR** — intermediate representation (the JSON query shape) +- **PR** — pull request (a unit of code change) +- **PII** — personally identifiable information (names, emails, etc.) +- **ABC** — abstract base class (Python contract for subclasses) diff --git a/PHASE1_TO_PHASE2_REPORT.md b/PHASE1_TO_PHASE2_REPORT.md new file mode 100644 index 0000000000000000000000000000000000000000..93041c8e9a0b21579fa863cdc18023055d48a2bb --- /dev/null +++ b/PHASE1_TO_PHASE2_REPORT.md @@ -0,0 +1,260 @@ +# Phase 1 → Phase 2 Migration Report + +A walkthrough of what changed between the original retrieval-style backend (Phase 1) and the current catalog-driven backend (Phase 2). Intended as a hand-off for the lead. + +--- + +## 1. The conceptual change + +**Phase 1** was a single retrieval-style RAG pipeline. Every question — whether it pointed at a database, a spreadsheet, or a PDF — went through the same primitive: **chunk + embed + top-K** over PGVector. Schema and tabular columns were embedded as chunks and ranked alongside prose. When the question needed SQL, the LLM **wrote the SQL string directly** (via `query_executor`). 
+ +**Phase 2** splits the system into two paths governed by an LLM router: + +| Path | Primitive | Why | +|---|---|---| +| Unstructured (PDF / DOCX / TXT) | Dense similarity over prose chunks (PGVector) | Right primitive for free text | +| Structured (DB / CSV / XLSX / Parquet) | **Per-user data catalog** → LLM emits a **JSON IR** of intent → deterministic **compiler** → **executor** (SQL or pandas) | A column lookup shouldn't go through a similarity ranking lottery; the LLM emits intent, never SQL syntax | + +Three explicit LLM call sites only: + +1. **Intent router** (classifies the user message into `chat` / `unstructured` / `structured`) +2. **Query planner** (turns the question + catalog into a Pydantic-validated `QueryIR`) +3. **Chatbot agent** (formats the final answer, streamed over SSE) + +Everything else — IR validation, SQL/pandas compilation, execution — is deterministic Python. + +--- + +## 2. File-by-file changes + +### 2.1 Deleted (Phase 1 only) + +| Phase 1 path | Reason it was removed | +|---|---| +| `src/rag/base.py`, `src/rag/retriever.py`, `src/rag/router.py` | Replaced by `src/retrieval/` | +| `src/rag/retrievers/baseline.py`, `schema.py`, `document.py` | Schema retrieval gone (catalog replaces it); document retriever rewritten in `src/retrieval/document.py` | +| `src/tools/search.py` (whole `tools/` folder) | Only consumer was `rag/router.py` | +| `src/query/base.py` | Duplicate of `query/executor/base.py` | +| `src/query/query_executor.py` | Replaced by `src/query/service.py` | +| `src/query/executors/db_executor.py` | Replaced by `src/query/executor/db.py` | +| `src/query/executors/tabular.py` | Replaced by `src/query/executor/tabular.py` | +| `src/agents/chatbot.py` (Phase 1 LangChain chatbot) | Phase 2 `ChatbotAgent` lives at the same path now — see §2.2 | +| `src/api/v1/knowledge.py` | Fake `/knowledge/rebuild` endpoint, never wired | +| `src/config/agents/system_prompt.md`, `guardrails_prompt.md` | Replaced by 
`src/config/prompts/{chatbot_system,guardrails}.md` | +| `src/models/structured_output.py` (`IntentClassification`) | Replaced by `IntentRouterDecision` Pydantic model inside `agents/orchestration.py` | +| `src/models/sql_query.py` | LLM no longer emits SQL; IR replaces it | +| `src/pipeline/orchestrator.py` (empty stub) | Redundant — `StructuredPipeline` takes the introspector at `run()` time | + +### 2.2 Renamed / moved (same role, new home) + +| Phase 1 location | Phase 2 location | Notes | +|---|---|---| +| `src/agents/chatbot.py` (Phase 1) → deleted, then `src/agents/answer_agent.py` (`AnswerAgent`) → renamed | `src/agents/chatbot.py::ChatbotAgent` | Final answer formation; streams via `astream` | +| `src/knowledge/parquet_service.py` | `src/storage/parquet.py` | Parquet upload/download helper | +| `src/pipeline/document_pipeline/document_pipeline.py` (folder) | `src/pipeline/document_pipeline.py` (flat) | Single module | +| `src/rag/retrievers/document.py` | `src/retrieval/document.py` | `DocumentRetriever` migrated; tabular file types filtered out of results | +| `src/rag/router.py` | `src/retrieval/router.py` | `RetrievalRouter`, Redis-cached, unstructured-only; dead `db: AsyncSession` + `source_hint` params removed | +| `src/rag/base.py` (`RetrievalResult`, `BaseRetriever`) | `src/retrieval/base.py` | Same dataclass + ABC | + +> **Heads-up on the intent router**: the Phase 1 file `src/agents/orchestration.py` and its class `OrchestratorAgent` were **kept in place** for Phase 2 — but the body was fully rewritten. The class now emits `IntentRouterDecision(needs_search, source_hint ∈ {chat, unstructured, structured}, rewritten_query)`. The prompt file and test file use the `intent_router` name (`config/prompts/intent_router.md`, `tests/agents/test_intent_router.py`), but **the source module is still `orchestration.py` and the class is still `OrchestratorAgent`**. Existing imports continue to work; only the behavior changed. 
+ +### 2.3 Added (Phase 2 new) + +**Catalog subsystem (whole new concept)** + +| Path | Role | +|---|---| +| `src/catalog/models.py` | Pydantic: `Catalog → Source[] → Table[] → Column[]`, `ForeignKey`, `ColumnStats.top_values` | +| `src/catalog/introspect/base.py` | `BaseIntrospector` ABC | +| `src/catalog/introspect/database.py` | DB introspector — wraps Phase 1 `db_pipeline/extractor.py` (`get_schema`, `profile_column`, `get_row_count`) | +| `src/catalog/introspect/tabular.py` | CSV / XLSX / Parquet introspector — one `Table` per XLSX sheet | +| `src/catalog/render.py` | Renders a `Source` for the planner prompt | +| `src/catalog/validator.py` | Unique-ID + foreign-key-ref invariants | +| `src/catalog/store.py` | Postgres `jsonb` upsert keyed by `user_id` (table `data_catalog`) | +| `src/catalog/reader.py` | Loads + filters catalog by `source_hint` | +| `src/catalog/pii_detector.py` | Flags PII columns at ingestion → suppresses `sample_values` | +| `src/security/pii_patterns.py` | Name patterns + value regex used by the detector | + +**JSON IR + query subsystem** + +| Path | Role | +|---|---| +| `src/query/ir/models.py` | `QueryIR` Pydantic schema | +| `src/query/ir/operators.py` | `ALLOWED_FILTER_OPS`, `ALLOWED_AGG_FNS`, `LIMIT_HARD_CAP`, `TYPE_COMPATIBILITY` | +| `src/query/ir/validator.py` | Catalog-aware IR validation (rejects unknown column ids, bad ops, type mismatches, oversize limits) | +| `src/query/planner/service.py` | `QueryPlannerService.plan(question, catalog, previous_error)` — Azure OpenAI structured output → `QueryIR` | +| `src/query/planner/prompt.py` | Builds the planner prompt from catalog text | +| `src/query/compiler/base.py` | Compiler ABC | +| `src/query/compiler/sql.py` | `SqlCompiler` (Postgres) — all 12 filter ops, params as a dict | +| `src/query/compiler/pandas.py` | `PandasCompiler` — returns `CompiledPandas(apply, output_columns)` | +| `src/query/executor/base.py` | `BaseExecutor` + `QueryResult` | +| `src/query/executor/db.py` | 
`DbExecutor` — sqlglot SELECT-only guard, RO txn, 30 s `statement_timeout`, 10 k row cap | +| `src/query/executor/tabular.py` | `TabularExecutor` — Parquet via blob, `asyncio.to_thread`, 10 k cap | +| `src/query/executor/dispatcher.py` | `ExecutorDispatcher.pick(ir)` — picks by `source.source_type` | +| `src/query/service.py` | `QueryService.run(user_id, question, catalog)` — plan → validate → retry (max 3) → dispatch → execute | + +**Agents** + +| Path | Role | +|---|---| +| `src/agents/orchestration.py` | `OrchestratorAgent` — Phase 1 file/class name preserved; Phase 2 body. Emits `IntentRouterDecision` | +| `src/agents/chatbot.py` | `ChatbotAgent` — formerly `AnswerAgent` in `agents/answer_agent.py`; renamed in Cleanup PR | +| `src/agents/chat_handler.py` | `ChatHandler.handle(...)` — top-level orchestrator; yields `intent` / `chunk` / `done` / `error` SSE events | + +**Pipelines & API** + +| Path | Role | +|---|---| +| `src/pipeline/structured_pipeline.py` | DB / tabular ingestion: introspect → merge → validate → upsert | +| `src/pipeline/triggers.py` | `on_db_registered`, `on_tabular_uploaded`, `on_document_uploaded`, `on_catalog_rebuild_requested` | +| `src/api/v1/data_catalog.py` | `GET /api/v1/data-catalog/{user_id}` + `POST /api/v1/data-catalog/rebuild` | +| `src/models/api/catalog.py` | Catalog request/response models | +| `src/config/prompts/intent_router.md`, `query_planner.md`, `chatbot_system.md`, `guardrails.md` | New prompts. `guardrails.md` is appended to `chatbot_system.md` at load time | +| `src/db/postgres/models.py` (added `Catalog` SQLAlchemy class) | Stores the per-user jsonb document in `data_catalog` | + +### 2.4 Rewired API endpoints + +| Endpoint | Phase 1 wiring | Phase 2 wiring | +|---|---|---| +| `POST /api/v1/chat/stream` | Inline in `chat.py`: `OrchestratorAgent` → `retriever` → `query_executor` → `chatbot` | Delegates to `ChatHandler.handle()`. 
Redis cache, fast intent, history load, and message persistence stay in the endpoint | +| `POST /api/v1/database-clients/{id}/ingest` | Called `db_pipeline_service.run()` and dual-wrote vectors | Calls **only** `on_db_registered` (catalog build). Failure → HTTP 500 | +| `POST /api/v1/document/process` | Always pushed to vector store | PDF/DOCX/TXT → `knowledge_processor` (vectors); CSV/XLSX → `on_tabular_uploaded` (catalog only, **no vector embedding**) | +| `POST /api/v1/document/upload` | Storage + DB row | Same, plus `on_document_uploaded` trigger | +| `POST /api/v1/data-catalog/rebuild` | — | New: iterates all sources, re-runs per-source trigger | +| `GET /api/v1/data-catalog/{user_id}` | — | New: returns `list[CatalogIndexEntry]` | + +### 2.5 Phase 1 files still in production use + +These were **not rewritten** — Phase 2 imports them directly: + +- `src/database_client/database_client_service.py` +- `src/utils/db_credential_encryption.py` (`decrypt_credentials_dict`) — `src/security/credentials.py` is still a stub +- `src/pipeline/db_pipeline/db_pipeline_service.py` (`engine_scope` context manager — used by both the introspector and `DbExecutor`) +- `src/pipeline/db_pipeline/extractor.py` (`get_schema`, `profile_column`, `get_row_count`) +- `src/knowledge/processing_service.py` (PDF / DOCX / TXT extraction + embedding) +- `src/db/postgres/{connection,init_db,vector_store}.py`, `src/storage/az_blob/`, `src/middlewares/`, `src/security/auth.py` + +--- + +## 3. 
End-to-end flow (current state) + +### 3.1 Ingestion + +``` +User action Pipeline Storage +────────────── ──────────────────────────── ───────────────── +upload PDF/DOCX/TXT → DocumentPipeline → Azure Blob + PGVector + (extract → chunk → embed) (table: langchain_pg_embedding) + + on_document_uploaded + retrieval cache invalidate + +upload CSV/XLSX → TabularIntrospector → Azure Blob (Parquet) + (sheets / columns + sample + stats) + data_catalog jsonb row + → CatalogValidator → CatalogStore (NO vector store — catalog only) + via on_tabular_uploaded + +register DB → DatabaseIntrospector → data_catalog jsonb row + (information_schema + sample + FKs) + → validate → store + via on_db_registered +``` + +### 3.2 Query (per user message → SSE stream) + +``` +POST /api/v1/chat/stream + │ + ├── Redis cache check (24h TTL) — hit returns cached stream + ├── _fast_intent (greetings / goodbyes) — bypass LLM + ├── load history from chat_messages + │ + └── ChatHandler.handle(message, user_id, history) [src/agents/chat_handler.py] + │ + ├─ OrchestratorAgent.classify() [agents/orchestration.py] + │ → needs_search, source_hint, rewritten_query + │ + ├── source_hint == "chat" + │ → ChatbotAgent.astream() → yield chunk events + │ + ├── source_hint == "unstructured" + │ → RetrievalRouter.retrieve() [retrieval/router.py, Redis-cached] + │ → DocumentRetriever (PGVector MMR/cosine/etc.) + │ → ChatbotAgent.astream(chunks=...) + │ + └── source_hint == "structured" + → CatalogReader.read(user_id, "structured") [catalog/reader.py] + → QueryService.run(user_id, question, catalog) [query/service.py] + │ + ├─ QueryPlannerService.plan(...) [query/planner/service.py] + │ LLM(catalog, question, prev_error?) 
→ QueryIR + │ + ├─ IRValidator.validate(ir, catalog) [query/ir/validator.py] + │ fail → loop back to planner with error context (max 3) + │ + ├─ ExecutorDispatcher.pick(ir) [query/executor/dispatcher.py] + │ schema source → DbExecutor + │ tabular source → TabularExecutor + │ + ├─ DbExecutor.run(ir): [query/executor/db.py] + │ SqlCompiler → (sql, params) + │ → sqlglot SELECT-only guard + │ → engine_scope (Phase 1 utility) in asyncio.to_thread + │ → RO txn + statement_timeout=30s + 10k cap + │ + ├─ TabularExecutor.run(ir): [query/executor/tabular.py] + │ resolve Parquet blob path + │ → download → PandasCompiler.apply(df) + │ → asyncio.to_thread → 10k cap + │ + └─ QueryResult { rows, columns, row_count, + truncated, source_id, error?, elapsed_ms } + → + ChatbotAgent.astream(query_result=...) + → yield chunk events + │ + └── final events: done / error + │ + └── persist user + assistant messages to chat_messages + └── populate Redis cache +``` + +**Safety invariants for the structured path** (read-only at every layer): + +1. IR validated against the catalog before reaching the compiler +2. Identifiers come from the catalog (trusted; inlined as quoted identifiers) +3. Values from `IR.filters` are always parameterized +4. Compiler is deterministic — no LLM in the hot path +5. sqlglot rejects anything that isn't a pure SELECT +6. DB connection is read-only with a 30 s `statement_timeout` +7. Hard 10 000 row cap on both executors; neither raises — errors go in `QueryResult.error` + +--- + +## 4. 
Summary table for review + +| Concern | Phase 1 — where it lived | Phase 2 — where it lives | Change type | +|---|---|---|---| +| Intent classification | `agents/orchestration.py::OrchestratorAgent` (free-text intent) | **Same path + same class name** — body rewritten to emit `IntentRouterDecision` | Body rewrite only | +| Top-level chat orchestration | Inline in `api/v1/chat.py` | `agents/chat_handler.py::ChatHandler` | Extracted to a reusable module | +| Final answer formation | `agents/chatbot.py` (Phase 1 LangChain) | `agents/chatbot.py::ChatbotAgent` (was `AnswerAgent` in `answer_agent.py` mid-cycle) | Rewritten + renamed | +| Schema retrieval (DB / tabular) | `rag/retrievers/schema.py` + PGVector chunks | **Removed**. Replaced by catalog (`catalog/store.py` jsonb) loaded verbatim into planner prompt | Whole concept replaced | +| Doc retrieval (PDF / DOCX / TXT) | `rag/retrievers/document.py`, `rag/router.py` | `retrieval/document.py`, `retrieval/router.py` | Moved; Redis cache restored; tabular files filtered | +| Query writing | `query/query_executor.py` + `models/sql_query.py` (LLM writes SQL) | `query/planner/service.py` (LLM writes IR) + `query/compiler/sql.py` (deterministic) | LLM emits intent, not SQL | +| DB execution | `query/executors/db_executor.py` | `query/executor/db.py::DbExecutor` | Folder renamed (`executors` → `executor`); sqlglot guard + RO txn + 30 s timeout kept | +| Tabular execution | `query/executors/tabular.py` | `query/executor/tabular.py::TabularExecutor` | Parquet-only; pandas compiler split out | +| Executor selection | Hard-coded in `query_executor.py` | `query/executor/dispatcher.py::ExecutorDispatcher` | New; routes by `source.source_type` | +| Catalog (NEW) | — | `catalog/` (models, introspect/, validator, store, reader, pii_detector, render) | New subsystem | +| Catalog persistence | (data was embedded in PGVector) | Postgres jsonb table `data_catalog`, keyed by `user_id` | New table | +| Ingestion triggers | Inline in API 
endpoints | `pipeline/triggers.py` (`on_db_registered`, `on_tabular_uploaded`, `on_document_uploaded`, `on_catalog_rebuild_requested`) | Centralized event entry points | +| Structured pipeline | `pipeline/db_pipeline/db_pipeline_service.py` (still present for `engine_scope` + extractor reuse) | `pipeline/structured_pipeline.py` (orchestrator) — reuses Phase 1 extractor | New orchestrator wraps Phase 1 introspection helpers | +| Document pipeline | `pipeline/document_pipeline/document_pipeline.py` (folder) | `pipeline/document_pipeline.py` (file) | Flattened; CSV / XLSX now skip the vector store | +| Parquet helper | `knowledge/parquet_service.py` | `storage/parquet.py` | Moved into `storage/` | +| Prompts | `config/agents/system_prompt.md`, `guardrails_prompt.md` | `config/prompts/{intent_router,query_planner,chatbot_system,guardrails}.md` | Folder renamed; split into four files; guardrails appended to `chatbot_system` at load | +| PII detection | — | `catalog/pii_detector.py` + `security/pii_patterns.py` | New. 
Columns flagged `pii_flag=true` get `sample_values: null` so PII never enters prompts | +| Chat endpoint | `api/v1/chat.py` (does everything inline) | `api/v1/chat.py` (cache + history + persistence) → delegates to `ChatHandler` | Slimmed; SSE event shape is `intent` / `chunk` / `done` / `error` | +| DB ingest endpoint | `api/v1/db_client.py::ingest` (Phase 1 `db_pipeline_service.run()`) | `api/v1/db_client.py::ingest` (calls `on_db_registered` only) | Phase 1 dual-write removed | +| Document process endpoint | `api/v1/document.py::process` (always vectorize) | `api/v1/document.py::process` (PDF/DOCX/TXT → vectors; CSV/XLSX → catalog via `on_tabular_uploaded`) | Routing by file type | +| Catalog management API | — | `api/v1/data_catalog.py` (GET index + POST rebuild) | New | + +**Bottom line.** Every Phase 1 file under `src/rag/`, `src/tools/`, `src/query/executors/`, `src/query/query_executor.py`, `src/query/base.py`, `src/api/v1/knowledge.py`, and `src/config/agents/` is gone. Phase 1 introspection helpers under `src/pipeline/db_pipeline/` and `src/database_client/` are still imported by Phase 2 — they were not rewritten, just wrapped. The three LLM call sites are now explicit and the SQL-writing one no longer exists; the planner emits a Pydantic-validated `QueryIR` instead. + +The one filename gotcha to remember: the **intent router** still lives at `src/agents/orchestration.py` as class `OrchestratorAgent` (Phase 1 name kept for import-site compatibility, Phase 2 body). The matching prompt and tests use the `intent_router` name, but the source module does not. diff --git a/PROGRESS.md b/PROGRESS.md new file mode 100644 index 0000000000000000000000000000000000000000..9d8361343924c7f035873f9cd19f600411de2329 --- /dev/null +++ b/PROGRESS.md @@ -0,0 +1,381 @@ +# Progress — Phase 2 catalog-driven build + +Persistent tracker mirroring the 42-item ownership table in `REPO_CONTEXT.md` "Team — division of work". Update as PRs land. 
Future Claude Code sessions read this to know what's already done. + +**Last updated**: 2026-05-12 ([NOTICKET] Cleanup PR landed: ChatHandler wired to chat.py, Phase 1 dual-write dropped from /ingest, on_catalog_rebuild_requested implemented, dead modules deleted, answer_agent→chatbot renamed, retrieval cache restored via RetrievalRouter, top_values added to ColumnStats, lifespan migration, knowledge_router removed) +**Current open PR**: `pr/1` — active. Cleanup PR committed and pushed. + +--- + +## Legend + +- `[x]` done and merged +- `[~]` in progress (open PR or active branch) +- `[ ]` not started +- **DB** / **TAB** / **B** — ownership (from REPO_CONTEXT.md) + +--- + +## PR sequence + +| PR | Status | Owner(s) | Scope | +|---|---|---|---| +| PR1 | `[x]` merged | DB | Contract locks + catalog plumbing + DB introspector + IR validator + tests | +| PR1-tab | `[x]` shipped | TAB | Tabular introspector + on_tabular_uploaded trigger + 31 unit tests | +| PR2a | `[x]` merged | DB | CatalogEnricher + StructuredPipeline + on_db_registered trigger + FK extension on Table (enricher later removed in KM-557) | +| KM-557 | `[x]` shipped | DB | Drop CatalogEnricher entirely (cost cut — planner uses stats + sample rows directly); rename jsonb table `catalogs` → `data_catalog`; add `GET /api/v1/data-catalog/{user_id}` index endpoint for catalog refresher | +| PR2b | `[x]` shipped | DB-solo (B-review) | IntentRouter + planner prompt + planner LLM service | +| PR3-DB | `[x]` shipped | DB | SqlCompiler (Postgres) + DbExecutor (sqlglot guard, RO + statement_timeout, asyncio.to_thread) + 36 golden IR→SQL tests | +| PR3-TAB | `[x]` shipped | TAB | PandasCompiler + TabularExecutor + 43+12 golden IR→DataFrame tests | +| PR4 | `[x]` | DB-solo (B-review) | ExecutorDispatcher + QueryService + ChatHandler module. 
**API rewired in Cleanup PR.** | +| PR5 | `[x]` shipped | DB-solo (B-review) | Retry/self-correction loop on validation failure (lives in QueryService, max 3 attempts, planner re-prompted with prior error) | +| PR6 | `[~]` scaffold | DB-solo (B-review) | Eval harness scaffold + 3 DB-targeting golden cases. Skipped without `RUN_PLANNER_EVAL=1` env. TAB extends with tabular cases. | +| PR7 | `[x]` | DB-solo (B-review) | `ChatbotAgent` (renamed from `AnswerAgent`) + chatbot_system + guardrails prompts. `answer_agent.py` → `chatbot.py`, `AnswerAgent` → `ChatbotAgent`. API rewired in Cleanup PR. | +| Cleanup | `[x]` | B | ChatHandler wired to chat.py; Phase 1 dual-write dropped from /ingest; on_catalog_rebuild_requested + POST /data-catalog/rebuild; dead modules deleted (chatbot Phase 1, orchestrator, query/base, knowledge.py, config/agents/); retrieval cache restored via RetrievalRouter; top_values added to ColumnStats; lifespan migration; knowledge_router removed. | + +--- + +## All items + +### Contracts (B — shared) + +| # | Item | Status | Notes | +|---|---|---|---| +| 1 | Catalog Pydantic models (`catalog/models.py`) | `[x]` | PR1 added `location_ref` URI-scheme docstring; PR2a added `ForeignKey` model + `Table.foreign_keys` field | +| 2 | IR Pydantic models (`query/ir/models.py`) | `[x]` | Pre-existing scaffold | +| 3 | IR operator whitelists (`query/ir/operators.py`) | `[x]` | PR1 filled `TYPE_COMPATIBILITY` matrix | +| 4 | PII patterns / regex (`security/pii_patterns.py`) | `[x]` | Pre-existing | +| — | `data_catalog` Postgres jsonb table (`db/postgres/models.py`) | `[x]` | PR1 added `Catalog` SQLAlchemy class + `init_db.py` import. KM-557 renamed `__tablename__` from `catalogs` → `data_catalog`; created fresh (no migration) | +| — | `QueryResult` shape (`query/executor/base.py`) | `[x]` | Pre-existing scaffold; `columns: list[str]` added (TAB owner, PR1-tab) — DbExecutor updated to populate it. 
| +| — | `Source.location_ref` URI scheme | `[x]` | PR1 documented in `catalog/models.py` docstring | + +### Ingestion — introspection + +| # | Item | Owner | Status | Notes | +|---|---|---|---|---| +| 5 | DB introspector (`catalog/introspect/database.py`) | DB | `[x]` | PR1 — reuses Phase 1 `database_client_service`, `db_credential_encryption`, `db_pipeline_service.engine_scope`, `extractor.get_schema/profile_column/get_row_count`. PR2a wired FK extraction (was discarded before). | +| 6 | Tabular introspector (`catalog/introspect/tabular.py`) | TAB | `[~]` | PR1-tab — downloads original blob (CSV/XLSX/Parquet), one Table per sheet (XLSX) or one Table (CSV/Parquet). `source_id = document_id`. `fetch_doc`/`fetch_blob` injectable for unit tests (no Settings). | +| 7 | `BaseIntrospector` ABC (`catalog/introspect/base.py`) | B | `[x]` | Pre-existing; signature locked | + +### Ingestion — shared catalog plumbing + +| # | Item | Owner | Status | Notes | +|---|---|---|---|---| +| 8 | ~~Catalog enricher + prompt~~ | B | **REMOVED in KM-557** | Cost optimization — planner reads stats + sample rows + column names directly. `catalog/enricher.py` + `config/prompts/catalog_enricher.md` deleted. `render_source` (the only piece still needed) moved to `src/catalog/render.py`. Tests moved to `tests/catalog/test_render.py`. | +| 9 | Catalog validator (`catalog/validator.py`) | B | `[x]` | PR1 (DB owner picked up) — uniqueness invariants | +| 10 | Catalog store — Postgres jsonb (`catalog/store.py`) | B | `[x]` | PR1 (DB owner picked up) — `INSERT ... 
ON CONFLICT` | +| 11 | Catalog reader (`catalog/reader.py`) | B | `[x]` | PR1 (DB owner picked up) — filters by source_hint, empty on miss | +| 12 | PII detector (`catalog/pii_detector.py`) | B | `[x]` | PR1 (DB owner picked up) — name + value matching, bias toward over-flag | + +### Ingestion — pipelines + +| # | Item | Owner | Status | Notes | +|---|---|---|---|---| +| 13 | Structured pipeline (`pipeline/structured_pipeline.py`) | B | `[x]` | PR2a (DB owner) — Source-type-agnostic: caller supplies the introspector. `default_structured_pipeline()` factory wires production deps lazily so tests can inject mocks without `Settings()` construction. **KM-557**: enrich step removed; pipeline is now `introspect → merge with existing → validate → upsert`. Constructor no longer takes `enricher`. | +| 14 | Triggers (`pipeline/triggers.py`) | B | `[x]` | PR2a — `on_db_registered` implemented (DB owner). PR1-tab — `on_tabular_uploaded` implemented (TAB owner). **2026-05-11** — `on_document_uploaded` implemented. **2026-05-12** — `on_catalog_rebuild_requested` implemented: iterates all Sources in current catalog, re-runs `on_db_registered` (schema) or `on_tabular_uploaded` (tabular) per source; per-source errors logged but don't abort. | +| 15 | Ingestion orchestrator (`pipeline/orchestrator.py`) | B | **DELETED** | Redundant stub — `StructuredPipeline` already takes introspector at run() time. Deleted in Cleanup PR. | +| 16 | Document pipeline (`pipeline/document_pipeline.py`) | TAB | `[x]` | Flattened `pipeline/document_pipeline/document_pipeline.py` (folder) → `pipeline/document_pipeline.py` (file). Updated import in `api/v1/document.py`. 
| + +### Query — shared spine + +| # | Item | Owner | Status | Notes | +|---|---|---|---|---| +| 17 | IR validator (`query/ir/validator.py`) | B | `[x]` | PR1 (DB owner) — full rule set; descriptive errors for planner retry | +| 18 | Planner LLM service (`query/planner/service.py`) | B | `[x]` | PR2b — Azure OpenAI structured output → `QueryIR`. Injectable chain. Supports retry via `previous_error` argument. | +| 19 | Planner prompt (`query/planner/prompt.py`, `config/prompts/query_planner.md`) | B | `[x]` | PR2b — system prompt with hard constraints + few-shot for DB and tabular sources. `build_planner_prompt(question, catalog, previous_error)` calls `catalog.render.render_source` (renamed from `catalog.enricher.render_source` in KM-557). | +| 20 | Intent router (`agents/orchestration.py` — class `OrchestratorAgent`; `config/prompts/intent_router.md`) | B | `[x]` | PR2b — single LLM call → `IntentRouterDecision(needs_search, source_hint, rewritten_query)`. Supports conversation history. **NOTE**: source filename + class name were kept from Phase 1 for import-site compatibility; only the body is Phase 2. Prompt file and test file use the `intent_router` name. | +| 21 | Executor base + `QueryResult` (`query/executor/base.py`) | B | `[x]` | Pre-existing scaffold | +| 22 | Executor dispatcher (`query/executor/dispatcher.py`) | B | `[x]` | PR4 — picks DbExecutor / TabularExecutor by `source.source_type`. Lazy imports of production executors keep import side-effect-free for tests. Caches per source_type. | +| 23 | Compiler base ABC (`query/compiler/base.py`) | B | `[x]` | Pre-existing scaffold | +| 24 | Top-level QueryService (`query/service.py`) | B | `[x]` | PR4+5 — `plan → validate → dispatch → execute → QueryResult`. Retry loop on validation failure (max 3, planner re-prompted with prior error). Catches NotImplementedError from TabularExecutor placeholder gracefully. Never raises. 
| + +### Query — DB path + +| # | Item | Status | Notes | +|---|---|---|---| +| 25 | SQL compiler (`query/compiler/sql.py`) | `[x]` | PR3-DB — Postgres dialect (Supabase reuses); deterministic IR → (sql, named-params dict); double-quoted identifiers from catalog; all whitelisted ops (=, !=, <, <=, >, >=, in, not_in, is_null, is_not_null, like, between); alias-aware order_by; `CompiledSql.params: dict[str, Any]` (changed from `list`). MySQL/BigQuery/Snowflake compilers later. | +| 26 | DB executor (`query/executor/db.py`) | `[x]` | PR3-DB — sync engine via `db_pipeline_service.engine_scope` inside `asyncio.to_thread`. sqlglot SELECT-only / no-DML guard. Postgres-only session settings: `default_transaction_read_only=on` + `statement_timeout=30000`. asyncio.wait_for backstop. Never raises — populates `QueryResult.error`. 10k row hard cap. | +| 27 | Credential encryption (`security/credentials.py`) | `[ ]` | Stub exists; PR1 reused Phase 1 `utils/db_credential_encryption.py` instead. Move in cleanup PR | +| 28 | User-DB connection management | `[x]` | PR3-DB reused Phase 1 `db_pipeline_service.engine_scope` (same as PR1 introspector); no new helper needed | + +### Query — Tabular path + +| # | Item | Status | Notes | +|---|---|---|---| +| 29 | Pandas compiler (`query/compiler/pandas.py`) | `[~]` | PR3-TAB — `CompiledPandas` dataclass; all 12 filter ops; all 6 aggs; group_by via `pd.concat` of Series; alias-aware order_by; `_like_to_regex` (`%`→`.*`, `_`→`.`); pure module-level helpers | +| 30 | Tabular executor (`query/executor/tabular.py`) | `[~]` | PR3-TAB — `fetch_blob` injectable for tests; blob path: single-table → `{uid}/{did}.parquet`, multi-table → `{uid}/{did}__{table.name}.parquet`; `asyncio.to_thread`; 10k row hard cap; errors → `QueryResult.error` | +| 31 | Parquet upload/download wrapper | `[x]` | Moved `knowledge/parquet_service.py` → `storage/parquet.py`. 
Updated 4 import sites: `pipeline/document_pipeline.py`, `knowledge/processing_service.py`, `query/executor/tabular.py`, `query/executors/tabular.py`. | + +### Agents + chat + +| # | Item | Status | Notes | +|---|---|---|---| +| 32 | Chatbot agent + prompt (`agents/chatbot.py`, `config/prompts/chatbot_system.md`) | `[x]` | PR7-bundle — `ChatbotAgent` (was `AnswerAgent`) streams tokens, accepts `QueryResult` or list[`DocumentChunk`] or neither. **Cleanup PR**: renamed `answer_agent.py` → `chatbot.py`, `AnswerAgent` → `ChatbotAgent`; Phase 1 `agents/chatbot.py` deleted. | +| 33 | Guardrails prompt (`config/prompts/guardrails.md`) | `[x]` | PR7-bundle — appended to `chatbot_system.md` so guardrails take precedence in conflict. | +| — | Chat handler / orchestrator (`agents/chat_handler.py`) | `[x]` | PR4-bundle — top-level Phase 2 orchestrator. Routes by `source_hint`: chat → `ChatbotAgent` (was `AnswerAgent`) direct; structured → CatalogReader + QueryService; unstructured → document retrieval + `ChatbotAgent`. Yields `intent` / `chunk` / `done` / `error` SSE-style events. **Cleanup PR**: `api/v1/chat.py` now delegates to `ChatHandler.handle()`; unstructured retrieval goes through `RetrievalRouter` (Redis-cached) instead of `DocumentRetriever` directly. | + +### API surface + +| # | Item | Owner | Status | Notes | +|---|---|---|---|---| +| 34 | DB client endpoints (`api/v1/db_client.py`) | DB | `[x]` | **Cleanup PR** — `/ingest` now calls only `on_db_registered`. Phase 1 `db_pipeline_service.run()` + `decrypt_credentials_dict` removed. Error from catalog build now raises HTTP 500 (was silent log). Response simplified to `{"status": "success", "client_id": ...}`. | +| 35 | Document/tabular upload endpoints (`api/v1/document.py`) | TAB | `[x]` | Rewired `/document/process` — after processing CSV/XLSX, calls `on_tabular_uploaded(document_id, user_id)`. Catalog ingestion failure is logged but does not fail the request. **2026-05-11** — CSV/XLSX no longer ingested to vector store (`knowledge_processor` skipped for tabular types in `document_pipeline.py`); they go to catalog only. 
| +| 36 | Chat stream endpoint (`api/v1/chat.py`) | B | `[x]` | Rewired `/chat/stream` — replaced `query_executor.execute()` (Phase 1) with `CatalogReader + QueryService` (Phase 2). **Cleanup PR**: fully rewired to `ChatHandler.handle()`. Inline intent routing, retrieval, and answer generation removed. Redis cache, fast intent, history loading, and message persistence remain in chat.py. Sources event emits `[]` (retrieval not yet exposed by ChatHandler). | +| 37 | Room / users endpoints (`api/v1/room.py`, `api/v1/users.py`) | B | `[ ]` | No catalog work; only touch if auth flow changes | +| — | Data catalog index endpoint (`api/v1/data_catalog.py`) | DB | `[x]` | **KM-557** — `GET /api/v1/data-catalog/{user_id}` → `list[CatalogIndexEntry]`. **Cleanup PR** — added `POST /api/v1/data-catalog/rebuild?user_id=` → calls `on_catalog_rebuild_requested`; per-source errors logged but don't fail the request. | + +### Tests + eval + +| # | Item | Owner | Status | Notes | +|---|---|---|---|---| +| 38 | DB compiler golden tests (`tests/query/compiler/test_sql.py`) | DB | `[x]` | PR3-DB — 36 tests across all whitelisted ops, identifier quoting, agg / count_distinct / count(*), order_by alias resolution, parameter sequencing, error paths. Pure-Python, no LLM, no DB. | +| 39 | Pandas compiler golden tests (`tests/unit/query/compiler/test_pandas_compiler.py`) | TAB | `[~]` | PR3-TAB — 43 tests: all 12 filter ops, all 6 aggs, group_by, order_by, limit, aliases, empty DataFrame, error paths. `test_tabular_executor.py` adds 12 more (blob name resolution + happy path + error paths). 
| +| 40 | IR validator tests (`tests/query/ir/test_validator.py`) | B | `[x]` | PR1 — 19 tests, all rules covered | +| — | PII detector tests (`tests/catalog/test_pii_detector.py`) | B | `[x]` | PR1 — 26 tests (parametrized) | +| — | Catalog validator tests (`tests/catalog/test_validator.py`) | B | `[x]` | PR1 — 5 tests | +| — | Catalog render tests (`tests/catalog/test_render.py`) | B | `[x]` | **KM-557** — 5 tests (renamed from `test_enricher.py`; LLM enrichment tests dropped, render-only tests kept). | +| — | Catalog store integration test (`tests/catalog/test_store.py`) | DB | `[x]` | PR1 — module-level skip without `RUN_INTEGRATION_TESTS=1` | +| — | DB introspector test | DB | `[ ]` | Deferred to PR2 — needs Postgres testcontainer or fixture infra | +| — | Tabular introspector test | TAB | `[x]` | PR1-tab — 31 unit tests (CSV/XLSX/Parquet, stats, PII, error paths). No DB/blob I/O — mocks injected via constructor. | +| 41 | Planner eval (`tests/query/planner/`) | B | `[x]` | PR6-scaffold — `test_golden_questions.py` with 3 DB-targeting cases. TAB added `test_golden_tabular.py` with 4 tabular cases (group_by+sum, top-N+limit, date range filter, XLSX sheet selection). All 4 passed against real Azure OpenAI. Fix shipped alongside: `query/planner/service.py` replaced `("system", text)` tuple with `SystemMessage` — without this, `{...}` in `query_planner.md` was parsed as f-string variables and crashed on every real invocation. | +| 42 | E2E smoke tests (`tests/e2e/`) | B | `[ ]` | Was deferred until the Phase 2 endpoints were wired; the Cleanup PR (2026-05-12) has since wired them, so this item is now unblocked but not yet started. Component-level orchestration is already covered by `test_chat_handler.py` + `test_service.py`. 
| +| — | Golden IR fixtures (`tests/fixtures/golden_irs.json`) | B | `[~]` | PR1 seeded with 5 DB-targeting examples; TAB extends in PR1-tab | +| — | Shared `sample_catalog` fixture (`tests/conftest.py`) | B | `[x]` | PR1 — DB-shaped; TAB may add tabular sibling | + +--- + +## What just shipped (2026-05-12 — Cleanup PR) + +**Phase 1 removal + Phase 2 API rewiring:** +- `src/api/v1/chat.py` — fully rewired to `ChatHandler.handle()`. Removed inline IntentRouter, retrieval, and ChatbotAgent calls. Redis cache, fast intent, load_history, save_messages stay in chat.py. +- `src/api/v1/db_client.py` — `/ingest` now calls only `on_db_registered`. Phase 1 `db_pipeline_service.run()` block removed. Catalog build failure now raises HTTP 500. +- `src/api/v1/data_catalog.py` — added `POST /api/v1/data-catalog/rebuild` endpoint. +- `src/pipeline/triggers.py` — `on_catalog_rebuild_requested` implemented: iterates catalog sources, re-runs the appropriate trigger per source type, per-source errors logged. + +**Dead modules deleted:** +- `src/agents/chatbot.py` (Phase 1 LangChain chatbot) +- `src/pipeline/orchestrator.py` (empty stub) +- `src/query/base.py` (old duplicate of `executor/base.py`) +- `src/api/v1/knowledge.py` (fake `/knowledge/rebuild` endpoint) +- `src/config/agents/` (folder — prompts only used by deleted Phase 1 chatbot) + +**Renames:** +- `src/agents/answer_agent.py` → `src/agents/chatbot.py`; `AnswerAgent` → `ChatbotAgent`; updated all import sites (`chat_handler.py`, `chat.py`) + +**Fixes + improvements:** +- `src/agents/chat_handler.py` — `_get_document_retriever()` now returns `RetrievalRouter` (Redis-cached) instead of `DocumentRetriever` directly; retrieval-level cache restored. +- `src/retrieval/router.py` — removed dead `db: AsyncSession` and `source_hint` parameters + `_UNSTRUCTURED_HINTS` constant from `retrieve()`. Cache key simplified. 
+- `src/knowledge/processing_service.py` — removed dead `_build_csv_documents`, `_build_excel_documents`, `_profile_dataframe`, `_to_sheet_document` methods + `pandas` and `upload_parquet` imports. +- `src/catalog/models.py` — added `top_values: list[Any] | None` to `ColumnStats`. +- `src/catalog/introspect/tabular.py` — `_to_column` now populates `top_values` for columns with ≤10 distinct values; useful for query planner WHERE clause generation. +- `main.py` — replaced deprecated `@app.on_event("startup")` with `lifespan` context manager; removed `knowledge_router`. + +--- + +## What just shipped (KM-557 — DB owner) + +After lead review of the catalog ingestion cost: dropped LLM enrichment, +renamed the storage table, and exposed a lightweight index endpoint for +the upcoming catalog refresher. + +**Files deleted**: +- `src/catalog/enricher.py` — entire CatalogEnricher + EnrichmentResponse + apply_descriptions removed +- `src/config/prompts/catalog_enricher.md` — dead prompt +- `tests/catalog/test_enricher.py` — replaced by `test_render.py` + +**Files added**: +- `src/catalog/render.py` — new home for `render_source` (the only piece of the old enricher still needed; consumed by `query/planner/prompt.py`) +- `src/api/v1/data_catalog.py` — `GET /api/v1/data-catalog/{user_id}` returns `list[CatalogIndexEntry]` +- `tests/catalog/test_render.py` — 5 tests (same coverage as the old render block) + +**Files modified**: +- `src/db/postgres/models.py` — `__tablename__ = "data_catalog"` (was `"catalogs"`). 
Class name unchanged +- `src/pipeline/structured_pipeline.py` — `StructuredPipeline(validator, store)` (was `(enricher, validator, store)`); pipeline is now `introspect → merge → validate → upsert`; `default_structured_pipeline()` no longer constructs an enricher +- `src/pipeline/triggers.py` — docstrings updated; `on_catalog_rebuild_requested` docstring rewritten for the refresher use case +- `src/query/planner/prompt.py` — import now `from ...catalog.render import render_source` +- `src/catalog/introspect/{base,database,tabular}.py` — docstring scrubs (no behavior changes) +- `src/models/api/catalog.py` — added `CatalogIndexEntry`; simplified `CatalogRebuildResponse` to `sources_rebuilt` +- `main.py` — registered `data_catalog_router` +- `src/security/README.md` — one stale wording fix + +**No migration**: the `data_catalog` table is created from scratch on first `init_db()`. The old `catalogs` table was never deployed against production data, so no rename SQL is needed. + +**Tests**: all 4 `test_structured_pipeline.py` tests reworked to construct `StructuredPipeline(validator=, store=)` without `enricher`. 5 `test_render.py` tests cover render_source standalone. + +**Lint**: `ruff check` clean on modified Phase 2 paths. + +**Open follow-ups left for the lead** (as of KM-557; both were resolved in the 2026-05-12 Cleanup PR): +- `on_catalog_rebuild_requested` body — the refresher will iterate the index endpoint and call this trigger per source +- `api/v1/db_client.py` `/ingest` did not yet call `on_db_registered` at this point — same blocker as before, untouched by KM-557 + +--- + +## What just shipped (2026-05-11 — retrieval migration + bug fixes) + +**Files implemented / migrated**: +- `src/retrieval/base.py` — `RetrievalResult` dataclass + `BaseRetriever` ABC (was in `src/rag/base.py`) +- `src/retrieval/document.py` — full `DocumentRetriever` migrated from `src/rag/retrievers/document.py`; all retrieval methods (MMR/cosine/euclidean/inner_product/manhattan). Tabular file types filtered out from results. 
+- `src/retrieval/router.py` — `RetrievalRouter` (Redis-cached, unstructured-only). `invalidate_cache(user_id)` clears all `retrieval:{user_id}:*` keys. + +**Deleted** (no longer used): +- `src/rag/` — entire folder (base.py, retriever.py, router.py, retrievers/) +- `src/tools/` — entire folder (search.py was the only real file; only called by deleted rag/ router) + +**Bug fixes**: +- `src/pipeline/document_pipeline.py` — `retrieval_router.invalidate_cache(user_id)` called after `process()` and `delete()`. Redis failure is caught and logged (does not fail the document op). +- `src/pipeline/document_pipeline.py` — CSV/XLSX now skips `knowledge_processor` (vector store). Tabular files go to catalog only; no duplicate embeddings. +- `src/pipeline/triggers.py` — `on_document_uploaded` implemented (was `raise NotImplementedError`). +- `src/agents/chat_handler.py` — `_normalize_chunks` now handles `RetrievalResult` objects. Previously they were silently dropped, causing empty context for unstructured queries through ChatHandler. + +**Import updates** (all changed from `src.rag.*` → `src.retrieval.*`): +- `src/api/v1/chat.py`, `src/query/base.py`, `src/query/query_executor.py`, `src/query/executors/db_executor.py`, `src/query/executors/tabular.py` + +--- + +## What shipped previously (PR2b/4/5/6/7-bundle — DB owner solo, teammate reviews) + +**Files implemented**: +- `src/agents/orchestration.py` — `OrchestratorAgent.classify(message, history) → IntentRouterDecision`. Pydantic model for structured output. History-aware query rewriting. Phase 1 filename + class name preserved; body fully rewritten for Phase 2. +- `src/agents/answer_agent.py` — `AnswerAgent.astream(...)` streams answer tokens; accepts `QueryResult` and/or `list[DocumentChunk]`. Renames to `chatbot.py` in cleanup PR. +- `src/agents/chat_handler.py` — `ChatHandler.handle(message, user_id, history)` returns `AsyncIterator[dict]` of `intent` / `chunk` / `done` / `error` SSE events. 
All deps injectable; lazy default builders. +- `src/query/planner/prompt.py` — `render_catalog(catalog)` + `build_planner_prompt(question, catalog, previous_error)`. Reuses `catalog.enricher.render_source` for consistency across LLM call sites. +- `src/query/planner/service.py` — `QueryPlannerService.plan(question, catalog, previous_error)` Azure OpenAI structured output → `QueryIR`. +- `src/query/executor/dispatcher.py` — `ExecutorDispatcher.pick(ir) → BaseExecutor` by `source.source_type`. Lazy executor imports + per-source-type cache. +- `src/query/service.py` — `QueryService.run(user_id, question, catalog) → QueryResult`. Plan→validate→retry-on-failure (max 3)→dispatch→execute. Catches NotImplementedError from TabularExecutor placeholder gracefully. + +**Prompts written** (filled in placeholders): +- `src/config/prompts/intent_router.md` +- `src/config/prompts/query_planner.md` +- `src/config/prompts/chatbot_system.md` +- `src/config/prompts/guardrails.md` + +**Tests added** (46 new — total now 146 + 2 skipped): +- `tests/agents/test_intent_router.py` (4) +- `tests/agents/test_answer_agent.py` (12) +- `tests/agents/test_chat_handler.py` (6) +- `tests/query/planner/test_prompt.py` (7) +- `tests/query/planner/test_service.py` (3) +- `tests/query/executor/test_dispatcher.py` (5) +- `tests/query/test_service.py` (8) +- `tests/query/planner/test_golden_questions.py` (3 — skipped by default; eval harness scaffold) + +**Lint**: `ruff check` clean on all Phase 2 paths. Phase 1 files have pre-existing E501/S608 issues — out of scope for this PR. + +**Placeholders / blockers for teammate** (status as of DB owner's commit, before merge): +- `src/query/executor/tabular.py` (TAB) — DB owner's note: "still raises NotImplementedError". **Post-merge**: TAB shipped this in PR3-TAB; dispatcher now routes to the real `TabularExecutor`. The `NotImplementedError` catch in `QueryService` stays as a safety net. +- `src/retrieval/document.py` — **implemented** (2026-05-11). 
Full `DocumentRetriever` migrated from `src/rag/retrievers/document.py`; supports MMR/cosine/euclidean/manhattan/inner_product. `_normalize_chunks` in `chat_handler.py` now handles `RetrievalResult` → `DocumentChunk` conversion correctly. +- `src/api/v1/chat.py` (Phase 1) — NOT touched. Cleanup PR rewires the SSE endpoint to call `ChatHandler.handle(...)`. +- `src/api/v1/db_client.py` (Phase 1) — NOT touched. Cleanup PR rewires `/database-clients/{id}/ingest` to call `pipeline.triggers.on_db_registered`. + +--- + +## What shipped previously (PR3-TAB — TAB owner) + +**Files implemented**: +- `src/query/compiler/pandas.py` — `PandasCompiler` + `CompiledPandas(apply, output_columns)` dataclass. Pure helper functions (easier to test in isolation): `_apply_filters` (all 12 ops, `_like_to_regex` for LIKE), `_apply_select` (column pick + rename), `_apply_agg` (scalar + group_by via `pd.concat` of Series → `reset_index`), `_apply_orderby` (alias-aware via `_resolve_order_col`). Closure captures all IR fields explicitly so `apply(df)` is self-contained. +- `src/query/executor/tabular.py` — `TabularExecutor` with injectable `fetch_blob` (same testability pattern as `TabularIntrospector`). Resolves Parquet blob path from `az_blob://{uid}/{did}` + table: single-table → `{uid}/{did}.parquet`, multi-table → `{uid}/{did}__{table.name}.parquet`. Runs compile → download → `asyncio.to_thread(_load_and_apply)` → 10k hard cap. Never raises; errors populate `QueryResult.error`. Uses `compiled.output_columns` for column labels (safe on empty DataFrame). + +**Tests added** (55 new — total suite was 86 all passing at PR3-TAB time): +- `tests/unit/query/compiler/test_pandas_compiler.py` — 43 tests across all 12 filter ops (including `is_null`, `not_in`, `like`, `between`), all 6 agg fns, group_by, order_by asc/desc, limit-after-order, alias round-trip, empty DataFrame, error paths. 
+- `tests/unit/query/executor/test_tabular_executor.py` — 12 tests: `_resolve_blob_name` (single/multi-table, bad prefix), happy-path `QueryResult` shape (columns, rows, backend, truncated, source_id), wrong source_type → error, blob fetch failure → error, unknown source → error. + +**Lint**: `ruff check` clean on both files. + +--- + +## What shipped previously (PR1-tab — TAB owner) + +**Files implemented**: +- `src/catalog/introspect/tabular.py` — `TabularIntrospector` reads original blob (CSV/XLSX/Parquet), profiles each column (dtype, stats, sample values), runs PIIDetector. For XLSX: one `Table` per sheet (`Table.name = sheet_name`); for CSV/Parquet: one `Table` (`Table.name = filename stem`). `fetch_doc`/`fetch_blob` are constructor-injectable for unit tests — no `Settings` or DB required at import time. +- `src/pipeline/triggers.py` — `on_tabular_uploaded` wired (mirrors `on_db_registered` pattern). + +**Tests added** (31 new): +- `tests/unit/catalog/test_introspect_tabular.py` — CSV / XLSX / Parquet shapes, per-column stats, nullable detection, PII name + value matching, sample capping, all error paths. Pure Python, no network I/O. + +**Executor contract note**: introspector downloads the *original* blob for schema reading. The tabular executor (PR3-TAB) downloads *Parquet* blobs for query execution. For CSV/Parquet sources (single table), the executor must call `parquet_blob_name(uid, did, sheet_name=None)`; for XLSX (multi-table), `parquet_blob_name(uid, did, table.name)`. 
+ +--- + +## What shipped previously (PR3-DB — DB owner) + +**Files implemented**: +- `src/query/compiler/sql.py` — `SqlCompiler` for Postgres dialect; `CompiledSql(sql, params)` dataclass with `params: dict[str, Any]` (changed from `list`); supports all 12 whitelisted filter ops, all 6 aggs, alias-aware order_by; `_qident` escapes embedded double-quotes +- `src/query/executor/db.py` — `DbExecutor` with sqlglot SELECT-only guard, Postgres session-level read-only + 30s `statement_timeout`, `asyncio.wait_for` backstop, 10k row hard cap; rejects non-`schema` source_type and `dbclient://` URI mismatch; never raises (populates `QueryResult.error`) + +**Files extended**: +- `src/query/compiler/pandas.py` — fixed pre-existing UP035 (Callable import) +- `pyproject.toml` — added `S608` to `tests/**` ruff ignore (false positive: tests assert literal SQL strings) + +**Tests added** (36 new, all passing — total now 100): +- `tests/query/compiler/test_sql.py` — every filter op, every agg, count(*), count_distinct, order_by alias vs column, multi-filter AND, identifier quoting escape, error paths + +**Lint**: `ruff check` clean on Phase 2 paths. + +**Hand-off note for teammate**: `CompiledSql.params` is now `dict[str, Any]` not `list`. The pandas compiler will follow the same convention (or document its own) — coordinate when PR3-TAB lands. 
+ +--- + +## What shipped previously (PR2a — DB owner) + +**Files implemented**: +- `src/catalog/enricher.py` — Azure OpenAI GPT-4o + structured output (`EnrichmentResponse`), `render_source` (reusable by planner prompt later), `apply_descriptions` merger, injectable `structured_chain` for tests +- `src/pipeline/structured_pipeline.py` — `StructuredPipeline` orchestrator + `default_structured_pipeline()` factory with lazy production-dep imports +- `src/pipeline/triggers.py` — `on_db_registered` wired; tabular/document/rebuild stubs preserved with implementation notes + +**Files extended**: +- `src/catalog/models.py` — added `ForeignKey` model, `Table.foreign_keys: list[ForeignKey] = []` +- `src/catalog/introspect/database.py` — `_extract_foreign_keys` populates `Table.foreign_keys` from extractor data +- `src/config/prompts/catalog_enricher.md` — full system prompt with style rules and one few-shot example + +**Tests added** (14 new, all passing — total now 64): +- `tests/catalog/test_enricher.py` — render / apply / end-to-end with fake chain (10 tests) +- `tests/pipeline/test_structured_pipeline.py` — orchestration with stub deps (4 tests) + +**Lint**: `ruff check` clean on all Phase 2 paths. Phase 1 files (`pipeline/db_pipeline/`, `pipeline/document_pipeline/`) have pre-existing ruff issues — out of scope for this PR. + +--- + +## What shipped previously (PR1 — DB owner's first chunk) + +**Files implemented** (was `NotImplementedError`): +- `src/catalog/pii_detector.py`, `src/catalog/validator.py`, `src/catalog/store.py`, `src/catalog/reader.py` +- `src/catalog/introspect/database.py` (FK extraction added in PR2a) +- `src/query/ir/validator.py` + +**Files extended**: +- `src/query/ir/operators.py` — `TYPE_COMPATIBILITY` matrix +- `src/catalog/models.py` — `location_ref` URI-scheme docstring +- `src/db/postgres/models.py` — `Catalog` SQLAlchemy table; `init_db.py` imports it + +**Tests**: 50 unit tests + 1 integration (gated on `RUN_INTEGRATION_TESTS=1`). 
+ +**Reused Phase 1 utilities** (cleanup deferred): +- `src/database_client/database_client_service.py:get` +- `src/utils/db_credential_encryption.py:decrypt_credentials_dict` +- `src/pipeline/db_pipeline/db_pipeline_service.py:engine_scope` +- `src/pipeline/db_pipeline/extractor.py:get_schema/profile_column/get_row_count` + +--- + +## Open contract items (not yet locked) + +- **Joins in IR** — currently single-table only (ARCHITECTURE.md §7); DB owner accepted the constraint for v1, will revisit in PR3 if it's blocking real queries +- **`updated_at` on Source vs `generated_at` on Catalog** — Pydantic models have both; introspector sets per-Source; CatalogStore preserves both +- **Catalog refresh trigger** (open question §3) — default policy is rebuild-on-upload-or-connect; auto-refresh deferred +- **Unstructured catalog entries** (open question §2) — currently empty filter for `source_hint="unstructured"`; revisit when adding doc descriptions +- **PII handling for `sample_values`** (open question §5) — currently nulls them out (skip); mask/synthesize deferred +- **Dialect priority for SQL compiler** — PR3 will land Postgres first, MySQL second; BigQuery/Snowflake/SQL Server later + +--- + +## How to update this file + +When a PR lands: +1. Flip status from `[ ]` or `[~]` to `[x]` +2. Add a short note (file paths, scope cuts, surprises) +3. Bump "Last updated" at the top +4. If a new contract decision lands, move it from "Open contract items" to the relevant inline note + +When opening a PR: +1. Flip status to `[~]` and add yourself as the active owner in the PR row +2. Don't promise items in the PR description that aren't in the table diff --git a/REPO_CONTEXT.md b/REPO_CONTEXT.md new file mode 100644 index 0000000000000000000000000000000000000000..404a573f22f8a45035343c640cbdff0f6577293a --- /dev/null +++ b/REPO_CONTEXT.md @@ -0,0 +1,474 @@ +# Repo Context — Agentic Service Data Eyond Catalog + +Orientation file for future Claude Code sessions. 
Cross-reference `ARCHITECTURE.md` for the full design rationale and decision log. + +--- + +## TL;DR + +FastAPI multi-agent backend for data analysis. Users upload documents and register databases / tabular files; they ask natural-language questions and get answers grounded in their data, streamed via SSE. + +The architecture has two paths: + +- **Unstructured** (PDF, DOCX, TXT) — dense similarity over prose chunks (PGVector). +- **Structured** (databases, XLSX, CSV, Parquet) — a per-user **data catalog** describes what tables/columns exist; an LLM produces a **JSON IR** of intent; a deterministic Python compiler turns the IR into SQL or pandas; the executor runs it. + +The LLM produces *intent*, not query syntax. Deterministic code does the rest. + +The Phase 2 end-to-end flow is **wired and runnable** as of 2026-05-12. See *Implementation status* below for the per-file matrix. `PROGRESS.md` is the authoritative line-by-line tracker; this file is the orientation. + +--- + +## Stack + +- Python 3.12, FastAPI 0.115, uvicorn, sse-starlette +- Async SQLAlchemy 2.0 + asyncpg (Postgres), psycopg3 (PGVector multi-statement workaround) +- LangChain 0.3 + langchain-postgres (PGVector) + langchain-openai (Azure OpenAI GPT-4o + embeddings) +- LangGraph 0.2 + langgraph-checkpoint-postgres +- Redis 5 (response + retrieval cache) +- Azure Blob Storage (uploads + Parquet) +- pandas, pyarrow, polars-ready (deferred), sqlglot, pydantic v2, structlog, slowapi, langfuse +- presidio-analyzer + spaCy `en_core_web_lg` (PII), pytesseract + pdf2image (PDF OCR) +- DB connectors: psycopg2, pymysql, pymssql, sqlalchemy-bigquery, snowflake-sqlalchemy + +Run: `uv run --no-sync uvicorn main:app --host 0.0.0.0 --port 7860`. On Windows use `uv run --no-sync python run.py` (sets `WindowsSelectorEventLoopPolicy` for psycopg3 async). 
+ +--- + +## Top-level layout + +``` +main.py — FastAPI app + middleware + router wiring + init_db() on startup +run.py — Windows-safe local entry point +ARCHITECTURE.md — design intent (source of truth for shape + invariants) +README.md +Dockerfile — python:3.12-slim, installs spaCy en_core_web_lg, tesseract, poppler +pyproject.toml / uv.lock +scripts/ — backfill scripts (build_initial_catalogs, enrich_all_sources) +src/ — all application code +``` + +--- + +## src/ map + +### Core data shapes (only files with real content) + +| Path | Role | +|---|---| +| `catalog/models.py` | Pydantic: `Catalog → Source[] → Table[] → Column[]` | +| `query/ir/models.py` | `QueryIR` (select / filters / group_by / order_by / limit) | +| `query/ir/operators.py` | `ALLOWED_FILTER_OPS`, `ALLOWED_AGG_FNS`, `LIMIT_HARD_CAP=10000` | +| `security/pii_patterns.py` | name patterns + email/phone regex for PII detection | + +### Catalog — identity layer for structured sources (Cs ∪ Ct) + +| Path | Role | +|---|---| +| `catalog/introspect/base.py` | `BaseIntrospector.introspect(location_ref) -> Source` | +| `catalog/introspect/database.py` | `information_schema` + ~100 row sample → draft Source | +| `catalog/introspect/tabular.py` | Parquet/CSV/XLSX header reader + sample (one Table per sheet for XLSX) | +| `catalog/render.py` | renders a `Source` as the canonical text block consumed by the planner (KM-557; LLM enrichment removed — planner reads stats + samples directly) | +| `catalog/validator.py` | invariants beyond Pydantic shape (unique IDs, FK refs) | +| `catalog/store.py` | persist as Postgres `jsonb` row keyed by user_id (`get/upsert/delete`) — table `data_catalog` | +| `catalog/reader.py` | load + filter catalog by source_hint (returns full catalog for ≤50 tables) | +| `catalog/pii_detector.py` | flag PII columns at ingestion → suppresses `sample_values` | + +### Query — catalog-driven structured path + +| Path | Role | +|---|---| +| `query/service.py` | `QueryService.run(user_id, 
question, catalog) -> QueryResult` (top-level) | +| `query/planner/service.py` | LLM call: question + catalog → QueryIR (structured output) | +| `query/planner/prompt.py` | renders catalog into the planner prompt | +| `query/ir/validator.py` | catalog-aware IR validation: column_ids exist, ops whitelisted, value_type matches data_type, limit ≤ cap | +| `query/compiler/base.py` | `BaseCompiler.compile(ir) -> object` | +| `query/compiler/sql.py` | IR → `(sql, params)`; identifiers from catalog, values parameterized | +| `query/compiler/pandas.py` | IR → callable that runs against a DataFrame | +| `query/executor/base.py` | `BaseExecutor.run(ir) -> QueryResult` (uniform across backends) | +| `query/executor/db.py` | runs compiled SQL via asyncpg/pymysql in read-only txn (sqlglot second-line defence) | +| `query/executor/tabular.py` | runs pandas/polars chain on a Parquet file (eager pandas → pyarrow pushdown → polars lazy by file size) | +| `query/executor/dispatcher.py` | picks DB vs Tabular executor based on `source.source_type` of the IR's source | + +### Retrieval — unstructured path (Cu) + +| Path | Role | +|---|---| +| `retrieval/document.py` | `DocumentRetriever` over PGVector chunks | +| `retrieval/router.py` | dispatches the `unstructured` route (the `chat` and `structured` routes do not pass through here) | + +### Agents — the three LLM call sites + +| Path | Role | +|---|---| +| `agents/orchestration.py` | `OrchestratorAgent` — classifies message → `needs_search`, `source_hint ∈ {chat, unstructured, structured}`, `rewritten_query`. Filename + class name kept from Phase 1; body replaced with Phase 2 logic. 
Output model is `IntentRouterDecision` | +| `agents/chatbot.py` | `ChatbotAgent` — final answer formation (receives Cu chunks or QueryResult); SSE-streamed via `astream` | +| `agents/chat_handler.py` | `ChatHandler` — top-level orchestrator; routes to chat / unstructured / structured and yields SSE-style `intent`/`chunk`/`done`/`error` events | + +(`QueryPlanner` is the third LLM call site, under `query/planner/`. The +fourth — `CatalogEnricher` — was removed in KM-557; ingestion no longer +makes any LLM calls.) + +### Pipelines — ingestion coordinators + +| Path | Role | +|---|---| +| `pipeline/structured_pipeline.py` | DB / tabular: introspect → merge → validate → store (no enrich step since KM-557) | +| `pipeline/document_pipeline.py` | unstructured: extract → chunk → embed → PGVector. CSV/XLSX skip vector store (catalog only). Invalidates retrieval cache on process/delete. | +| `pipeline/triggers.py` | event entry points called by API routes: `on_db_registered`, `on_tabular_uploaded`, `on_document_uploaded`, `on_catalog_rebuild_requested` | + +(`pipeline/orchestrator.py` was deleted in the Cleanup PR — it was a redundant stub; `StructuredPipeline` already takes the introspector at `run()` time.) 
+ +### Security — cross-cutting + +| Path | Role | +|---|---| +| `security/auth.py` | bcrypt password hash/verify, JWT encode/decode, get_user | +| `security/credentials.py` | Fernet encrypt/decrypt for stored DB credentials | +| `security/pii_patterns.py` | (already listed) | + +### API + infra + config + +| Path | Role | +|---|---| +| `api/v1/*.py` | FastAPI routers — thin endpoints delegating to `pipeline/triggers` and `query/service` | +| `models/api/{catalog,chat,document}.py` | request/response Pydantic models | +| `db/postgres/connection.py` | two async engines: `engine` (app) and `_pgvector_engine` (PGVector) | +| `db/postgres/init_db.py` | startup: creates `vector` extension, all tables, HNSW + GIN indexes | +| `db/postgres/models.py` | SQLAlchemy app tables (users, rooms, chat messages, …) | +| `db/postgres/vector_store.py` | shared PGVector instance (collection `document_embeddings`) | +| `db/redis/connection.py` | async Redis client | +| `storage/az_blob/az_blob.py` | Azure Blob async wrapper (uploads + Parquet) | +| `middlewares/{cors,logging,rate_limit}.py` | CORS allow-all (POC), structlog JSON, slowapi | +| `observability/langfuse/langfuse.py` | trace helper | +| `config/settings.py` | pydantic-settings; `.env` uses double-underscore aliases | +| `config/env_constant.py` | env file path constant | +| `config/prompts/*.md` | prompt templates: `intent_router`, `query_planner`, `chatbot_system`, `guardrails` (KM-557 removed `catalog_enricher`) | + +--- + +## Core architectural decisions + +1. **Catalog as primary context, not retrieval.** For ≤50 tables (typical), the entire catalog is rendered into the planner prompt verbatim (~3–5k tokens). No vector search, no BM25, no top-k for structured data. Catalog-level retrieval (BM25 + table-level vectors with RRF) is the *deferred* upgrade for users with hundreds of tables. + +2. **JSON IR over raw SQL.** The planner LLM emits a Pydantic-validated intent, never a SQL string. 
The compiler is deterministic Python. Benefits: validatable before execution, dialect-portable (one IR → SQL of any dialect / pandas / polars), cheaper tokens, trivially testable without an LLM, and the LLM literally cannot emit invalid SQL syntax. + +3. **Deterministic compiler, not LLM SQL writer.** All actual query construction happens in pure code. Compiler bugs are reproducible and fixable. Same IR → same query. + +4. **Pipeline stage isolation.** Each stage (`IntentRouter`, `CatalogReader`, `QueryPlanner`, `IRValidator`, `QueryCompiler`, `QueryExecutor`, `ChatbotAgent`) is its own module with typed input and typed output. No god classes. + +5. **Minimal LLM surface.** Only three LLM call sites in the system (KM-557 dropped `CatalogEnricher` — ingestion is now LLM-free; the planner reads stats + sample rows + column names directly): + - `IntentRouter` — once per user message + - `QueryPlanner` — once per structured query + - `ChatbotAgent` — once per answer (formatting) + +6. **Three-way routing**: `chat` / `unstructured` / `structured`. The router commits to one path. Cross-source questions ("compare DB sales vs uploaded customer file") are handled inside the structured path because the planner sees Cs ∪ Ct in one prompt. **DB vs tabular is not a routing concern** — it's a per-source attribute (`source_type`) that only matters at execution time. + +7. **Stable IDs.** `source_id`, `table_id`, `column_id` are stable internal references. Renaming a column in the source DB does not invalidate cached IRs. + +8. **PII suppression at the boundary.** Columns flagged with `pii_flag=true` have `sample_values: null` — real PII never enters LLM prompts. Auto-detected at ingestion via name patterns + value regex (`security/pii_patterns.py`). When in doubt, flag — false positives cost nothing; false negatives leak data. 
+ +--- + +## End-to-end flows + +### Ingestion (when user uploads a file or connects a DB) + +``` +source upload / DB connect + │ + ├── unstructured (pdf/docx/txt) + │ → DocumentPipeline: extract → chunk → embed → PGVector + │ + └── structured (DB schema or tabular file) + → introspect (information_schema or file headers + sample rows) + → CatalogValidator (Pydantic + unique-IDs + FK refs) + → CatalogStore.upsert(user_id jsonb row in `data_catalog`) +``` + +### Query (per user message) + +``` +user message + │ + → Redis cache check (24h TTL) ── miss ─→ continue + → + → IntentRouter LLM → needs_search? source_hint? + │ + ├── chat → ChatbotAgent → SSE stream + ├── unstructured → DocumentRetriever (Cu) → ChatbotAgent → SSE stream + └── structured → + CatalogReader.read(user_id, "structured") # full Cs ∪ Ct + ↓ + QueryPlanner LLM(question, catalog) → QueryIR + ↓ + IRValidator.validate(ir, catalog) + (source_id ∈ catalog, table_id ∈ source, column_ids ∈ table, + ops/aggs whitelisted, value_type matches data_type, limit ≤ 10000) + fail → re-prompt planner with error context (max 3 retries) + ↓ + ExecutorDispatcher.pick(ir) # by source.source_type + ├─ DbExecutor → SqlCompiler → sqlglot guard → asyncpg/pymysql + │ (read-only txn, 30s timeout) + └─ TabularExecutor → PandasCompiler → eager pandas (≤100 MB) + or pyarrow pushdown (100 MB–1 GB) + or polars lazy scan (>1 GB) + ↓ + QueryResult + ↓ + ChatbotAgent → SSE stream +``` + +--- + +## Catalog schema (per-user `jsonb` row) + +``` +Catalog +├── user_id, schema_version, generated_at +└── sources[] + └── Source { source_id, source_type, name, description, location_ref, updated_at } + └── tables[] + └── Table { table_id, name, description, row_count, foreign_keys[] } + ├── columns[] + │ └── Column { column_id, name, data_type, description, + │ nullable, pii_flag, sample_values[]|null, stats|null } + └── foreign_keys[] + └── ForeignKey { column_id, target_table_id, target_column_id } +``` + +`source_type ∈ {schema, tabular, 
unstructured}`. +`data_type ∈ {int, decimal, string, datetime, date, bool, json}`. +`ForeignKey` references are within the SAME `Source` only; cross-source FKs are not modeled. + +Deferred Column fields (add when justified): `description_human`, `synonyms[]`, `tags[]`, `primary_key`, `unit`, `semantic_type`, `example_questions[]`, `schema_hash`, `enrichment_status`. + +--- + +## JSON IR schema + +```jsonc +{ + "ir_version": "1.0", + "source_id": "...", + "table_id": "...", + "select": [ + {"kind": "column", "column_id": "...", "alias": "..."}, + {"kind": "agg", "fn": "count|count_distinct|sum|avg|min|max", + "column_id": "...?", "alias": "..."} + ], + "filters": [ + {"column_id": "...", + "op": "= | != | < | <= | > | >= | in | not_in | is_null | is_not_null | like | between", + "value": ..., + "value_type": "int|decimal|string|datetime|date|bool"} + ], + "group_by": ["column_id", ...], + "order_by": [{"column_id": "...", "dir": "asc|desc"}], + "limit": 100 +} +``` + +Single-table only in v1. `having`, `offset`, boolean filter trees, `distinct`, joins, window functions are deferred until user demand proves the limitation. + +--- + +## Implementation status + +**As of 2026-05-12 — Phase 2 end-to-end flow is wired.** `PROGRESS.md` has the per-PR line-item table; this section is the high-level snapshot. Stub files (`raise NotImplementedError`) are now the exception, not the rule. + +| Area | Status | Notes | +|---|---|---| +| Catalog Pydantic models | ✅ | `catalog/models.py` — incl. `ForeignKey`, `ColumnStats.top_values` | +| JSON IR Pydantic models | ✅ | `query/ir/models.py` + `operators.py` (TYPE_COMPATIBILITY filled) | +| Catalog ingestion — DB | ✅ | introspect → validate → upsert. `on_db_registered` wired; `/api/v1/db-clients/{id}/ingest` calls it | +| Catalog ingestion — tabular | ✅ | CSV/XLSX/Parquet; `on_tabular_uploaded` wired into `/api/v1/document/process`. XLSX → one Table per sheet. 
CSV/XLSX skip vector store | +| Catalog ingestion — unstructured | ✅ | `on_document_uploaded` implemented; full DocumentPipeline (extract → chunk → embed → PGVector) | +| Catalog store / reader / validator / PII detector | ✅ | `data_catalog` jsonb table (renamed from `catalogs` in KM-557) | +| LLM enrichment | ❌ removed (KM-557) | Cost cut — planner reads `column.stats` + `sample_values` + `top_values` + `column.name` directly. `catalog/render.py` keeps the source-rendering helper | +| `IntentRouter` (lives as `OrchestratorAgent` in `agents/orchestration.py`) | ✅ | 3-way `source_hint`, history-aware query rewriting. Filename + class name kept from Phase 1; Phase 2 body | +| `CatalogReader` | ✅ | Loads full catalog; filters by `source_hint` | +| `QueryPlanner` LLM call | ✅ | Azure OpenAI structured output → `QueryIR`; supports retry with `previous_error` | +| IR validator | ✅ | Catalog-aware; full rule set; descriptive errors | +| SQL compiler (Postgres) | ✅ | All 12 filter ops, all 6 aggs, alias-aware order_by, parameterized values, quoted identifiers | +| DbExecutor | ✅ | sqlglot SELECT-only guard, RO txn, `statement_timeout=30000`, 10k row cap, never raises | +| Pandas compiler | ✅ | Same op coverage as SQL; pure module-level helpers | +| TabularExecutor | ✅ | Parquet blob path resolution, `asyncio.to_thread`, 10k cap, never raises | +| ExecutorDispatcher | ✅ | Routes by `source.source_type`; lazy imports + cache | +| QueryService | ✅ | plan → validate → retry-on-fail (max 3) → dispatch → execute → `QueryResult` | +| `ChatbotAgent` + prompt + guardrails | ✅ | Renamed from `AnswerAgent` in Cleanup PR. 
Guardrails appended to `chatbot_system.md` | +| `ChatHandler` (top-level chat orchestrator) | ✅ | SSE events: `intent` / `chunk` / `done` / `error` | +| `DocumentRetriever` + `RetrievalRouter` (Redis-cached) | ✅ | Migrated from `src/rag/` (now deleted); MMR/cosine/euclidean/manhattan/inner_product | +| `/api/v1/chat/stream` | ✅ | Rewired to `ChatHandler`; Redis cache + fast intent + history + message persistence remain in chat.py | +| `/api/v1/db-clients/{id}/ingest` | ✅ | Calls only `on_db_registered`; Phase 1 dual-write removed | +| `/api/v1/document/{upload,process,delete}` | ✅ | `/process` triggers `on_tabular_uploaded` for CSV/XLSX | +| `GET /api/v1/data-catalog/{user_id}` | ✅ | Index endpoint (KM-557) | +| `POST /api/v1/data-catalog/rebuild` | ✅ | Iterates sources, re-runs per-source trigger | +| Credential encryption | ⚠️ stub | `security/credentials.py` not migrated; runtime reuses Phase 1 `utils/db_credential_encryption.py` | +| Tests | ✅ 146+ unit | Compilers (DB 36, Pandas 43), validators, introspectors, agents, chat handler, dispatcher, planner | +| Planner eval harness | 🟡 scaffold | 3 DB + 4 tabular golden cases. Gated on `RUN_PLANNER_EVAL=1`. Real Azure OpenAI passing | +| E2E smoke tests | ❌ not started | Component-level orchestration is covered | +| DB introspector unit test | ❌ deferred | Needs Postgres testcontainer | +| Sources event in `/chat/stream` | ⚠️ emits `[]` | `ChatHandler` doesn't surface retrieval sources yet; same gap reflected in `save_messages` | + +**Deferred to later phases**: joins in IR, schema drift detection, hybrid catalog search (BM25 + RRF for 100+ table users), polars lazy scan for >1GB tabular files, MySQL/BigQuery/Snowflake SQL dialects, mask/synthesize PII strategies. + +--- + +## Team — division of work + +The service is built by two engineers; many modules are source-type-agnostic and shared. + +- **DB** owns SQL paths: introspection, SQL compiler, DB executor, credential storage. 
+- **TAB** owns tabular paths: CSV/XLSX/Parquet introspection, pandas compiler, tabular executor, blob/Parquet plumbing. +- **B** = both — shared contracts and source-type-agnostic plumbing. Pair-program or split with explicit hand-off. + +### Step-by-step ownership + +| # | Step | File / area | Owner | Notes | +|---|---|---|---|---| +| 0 | **Lock contracts before coding** | — | B | See "Decisions to lock" below; block until aligned | +| 1 | Catalog Pydantic models | `catalog/models.py` | B | Already done; only touch if both agree | +| 2 | IR Pydantic models | `query/ir/models.py` | B | Already done; joins/window fns require joint sign-off | +| 3 | IR operator whitelists | `query/ir/operators.py` | B | Already done; both compilers rely on these | +| 4 | PII patterns / regex | `security/pii_patterns.py` | B | Already done; extend together as gaps appear | +| **Ingestion — introspection** | | | | | +| 5 | DB introspector (information_schema, sample, FKs) | `catalog/introspect/database.py` | DB | Use SQLAlchemy `inspect()`; dialect-aware quoting | +| 6 | Tabular introspector (CSV/XLSX/Parquet headers + sample) | `catalog/introspect/tabular.py` | TAB | Each XLSX sheet → one Table | +| 7 | `BaseIntrospector` ABC | `catalog/introspect/base.py` | B | Confirm signature returns the same `Source` shape | +| **Ingestion — shared catalog plumbing** | | | | | +| 8 | ~~Catalog enricher + prompt~~ | — | — | **REMOVED in KM-557.** Cost optimization — planner reads stats + sample rows directly. `catalog/render.py` keeps the source-rendering helper.
| +| 9 | Catalog validator | `catalog/validator.py` | B | Type-agnostic | +| 10 | Catalog store (Postgres jsonb) | `catalog/store.py` | B | Recommend DB (Postgres expertise) | +| 11 | Catalog reader | `catalog/reader.py` | B | Type-agnostic | +| 12 | PII detector | `catalog/pii_detector.py` | B | Either; uses `pii_patterns.py` | +| **Ingestion — pipelines** | | | | | +| 13 | Structured pipeline (introspect → merge → validate → store; the enrich step was removed in KM-557) | `pipeline/structured_pipeline.py` | B | Pair on this — calls both introspectors via dispatcher | +| 14 | Triggers (`on_db_registered`, `on_tabular_uploaded`) | `pipeline/triggers.py` | B | Each owns their trigger function | +| 15 | Ingestion orchestrator | `pipeline/orchestrator.py` | B | Routes by source_type; pair. **Deleted in the Cleanup PR** — it was a redundant stub; `StructuredPipeline` takes the introspector at `run()` time | +| 16 | Document pipeline (PDF/DOCX/TXT) | `pipeline/document_pipeline.py` | TAB | Tabular-adjacent (file uploads) | +| **Query — shared spine** | | | | | +| 17 | IR validator (catalog-aware) | `query/ir/validator.py` | B | Recommend DB; both must agree on exact error messages so retry-prompt is consistent | +| 18 | Planner LLM service | `query/planner/service.py` | B | Type-agnostic | +| 19 | Planner prompt (catalog → text) | `query/planner/prompt.py`, `config/prompts/query_planner.md` | B | **Pair-program**. Must describe DB tables and tabular files in one consistent format | +| 20 | Intent router (chat/unstructured/structured) | `agents/orchestration.py` (class `OrchestratorAgent` — Phase 1 filename + class name preserved; Phase 2 body), `config/prompts/intent_router.md` | B | Type-agnostic.
The prompt file uses `intent_router.md`, but the source module is still `orchestration.py` | +| 21 | Executor base + `QueryResult` | `query/executor/base.py` | B | Lock the shape before either implements an executor | +| 22 | Executor dispatcher | `query/executor/dispatcher.py` | B | Reads `source.source_type` from catalog; pair | +| 23 | Compiler base ABC | `query/compiler/base.py` | B | Already done | +| 24 | Top-level QueryService | `query/service.py` | B | Wires planner → validator → compiler → executor; pair | +| **Query — DB path** | | | | | +| 25 | SQL compiler (IR → SQL + params, per dialect) | `query/compiler/sql.py` | DB | Identifiers from catalog (quoted), values parameterized | +| 26 | DB executor (asyncpg/pymysql, sqlglot guard, RO txn, 30s timeout) | `query/executor/db.py` | DB | | +| 27 | Credential encryption (Fernet) | `security/credentials.py` | DB | Needed for stored user DB creds | +| 28 | User-DB connection management | helper in pipelines | DB | engine_scope context manager pattern | +| **Query — Tabular path** | | | | | +| 29 | Pandas compiler (IR → callable on DataFrame) | `query/compiler/pandas.py` | TAB | Same IR, different backend | +| 30 | Tabular executor (eager pandas first; pyarrow / polars later) | `query/executor/tabular.py` | TAB | Initial scope: eager pandas only | +| 31 | Parquet upload/download + Azure Blob wrapper | `storage/az_blob/az_blob.py` (+ helper) | TAB | XLSX sheet → one Parquet per sheet (deterministic blob name) | +| **Agents + chat** | | | | | +| 32 | Chatbot agent + prompt | `agents/chatbot.py`, `config/prompts/chatbot_system.md` | B | Receives QueryResult or Cu chunks | +| 33 | Guardrails prompt | `config/prompts/guardrails.md` | B | | +| **API surface** | | | | | +| 34 | DB client endpoints (register/ingest/list/delete) | `api/v1/db_client.py` | DB | | +| 35 | Document/tabular upload endpoints | `api/v1/document.py` | TAB | | +| 36 | Chat stream endpoint (SSE) | `api/v1/chat.py` | B | Dispatches both paths; pair 
| +| 37 | Room / users endpoints | `api/v1/room.py`, `api/v1/users.py` | B | Whoever has bandwidth | +| **Tests + eval** | | | | | +| 38 | DB compiler golden tests (IR → SQL fixtures) | `tests/query/compiler/test_sql.py` | DB | Pure-Python, no LLM | +| 39 | Pandas compiler golden tests (IR → expected DataFrame) | `tests/query/compiler/test_pandas.py` | TAB | Pure-Python, no LLM | +| 40 | IR validator tests (catalog × IR error matrix) | `tests/query/ir/test_validator.py` | B | Each contributes test cases for their source type | +| 41 | Planner eval (golden question → IR examples) | `tests/query/planner/` | B | Each contributes ~10 question→IR examples | +| 42 | E2E smoke tests | `tests/e2e/` | B | Pair | + +### Decisions to lock before coding + +If made unilaterally these create silent contract drift. Lock them in a 30-min sync first. + +| Decision | Why it matters | Recommended call | +|---|---|---| +| `QueryResult` shape (current scaffold: `source_id, backend, rows, row_count, truncated, elapsed_ms, error`) | Both executors return this; chatbot consumes it | Lock as-is unless either side needs more (e.g. `column_types` for formatting) | +| `Source.location_ref` format (`az_blob://...` vs `dbclient://{id}` etc.) | Dispatcher and executors both parse this | Pick a convention now; document in `catalog/models.py` docstring | +| Where do user DB credentials live? | DB executor needs creds to run queries; Source has `location_ref` but creds are encrypted separately | Recommend: `location_ref="dbclient://{client_id}"`; executor looks up creds by ID | +| How does dispatcher pick the executor? | Routes by `source.source_type` — but where does dispatcher get it (catalog reload, or IR carries it)? | Recommend: dispatcher takes `(Catalog, IR)`, looks up source by `IR.source_id` | +| Joins in v1 IR? | Excluded per ARCHITECTURE.md §7. DB path is most affected — real DB use often needs joins. | Recommend: ship single-table; revisit in PR 2. 
**DB owner must accept the constraint or push back early** | +| Planner prompt — render tabular vs DB sources uniformly | If described differently, planner gets confused | Pair-program. Render both as `Table: name (n rows) — Columns: ...` regardless of source_type | +| Error contract — raise or return `QueryResult.error`? | Both executors must behave the same so chatbot branches consistently | Recommend: never raise from `executor.run()`; populate `QueryResult.error` | +| PII handling for tabular `sample_values` | DB samples come from `information_schema`; tabular from file reads. Same `pii_flag` rule must apply both sides | Confirm tabular introspector calls `pii_detector` | +| Catalog refresh trigger (open question §3) | Affects both pipelines symmetrically | Default: rebuild on every upload/connect; defer auto-refresh | +| `updated_at` semantics — per-Source vs per-Catalog | Affects how each pipeline writes | Recommend: per-Source `updated_at` + Catalog-level `generated_at` | +| Dialect support scope for v1 | DB compiler must implement at least one dialect well | Recommend: Postgres first (matches app DB); MySQL second | +| Test-fixture format for golden IRs | Both compilers test against golden IR → expected output | Recommend: shared `tests/fixtures/golden_irs.json`; each side adds expected SQL or DataFrame | +| Logging conventions | structlog is already in place; both should log the same fields | Quick agreement: log `source_id`, `table_id`, `ir_version`, `elapsed_ms` | + +### Working rhythm (suggested) + +1. **Day 1** — 30-min sync to lock the decisions table. PR any contract/docstring changes that fall out. +2. **Week 1** — both build introspectors + agree on the planner prompt format. PR in parallel; review each other's. +3. **Week 2** — DB builds SQL compiler + DB executor; TAB builds pandas compiler + tabular executor. Both write golden tests against shared IR fixtures. +4. **Week 3** — pair on dispatcher, QueryService, and chat endpoint integration. 
End-to-end smoke test. +5. **Ongoing** — short daily standup, mostly to flag IR-shape questions and catalog-field additions *before* either side implements against an unconfirmed contract. + +Biggest risk: **silent contract drift** — one side adds a `QueryResult` field or assumes a new IR op exists, the other ships without it, and integration breaks at the dispatcher. The §0 lock + shared golden-IR fixtures are what prevent that. + +### Onboarding to Claude Code + +If you're new to Claude Code, before you start: + +1. Read `ARCHITECTURE.md` end-to-end (~10 min) — this is the source of truth. +2. Skim this file (`REPO_CONTEXT.md`) — find your section in the ownership table. +3. Read your owned files' docstrings — every stub explains its contract. +4. Open Claude Code in this repo. When you ask Claude to implement a stub: + - Reference the file path + the contract it should follow + - Point it at `ARCHITECTURE.md` section if relevant (e.g. §7 for IR validation) + - Ask it to write the test first (golden IR fixtures), then the implementation + - Always review the diff — don't auto-accept + +Useful slash commands while working: `/review` (PR review), `/security-review` (audit pending changes). + +--- + +## Conventions & gotchas + +- **Async event loop on Windows**: `run.py` sets `WindowsSelectorEventLoopPolicy` because psycopg3 async needs it. Don't call `uvicorn` directly on Windows. +- **Two Postgres engines**: `engine` (app tables) and `_pgvector_engine` (asyncpg with `prepared_statement_cache_size=0`) — the latter is required because PGVector emits `advisory_lock + CREATE EXTENSION` as a multi-statement string and asyncpg rejects multi-statement prepared queries. `init_db.py` creates the extension explicitly so `PGVector(create_extension=False)` skips that path. +- **Read-only at every layer for user DBs**: IR validation + compiler whitelists + sqlglot SELECT-only check + read-only DB credentials + LIMIT enforcement + 30s timeout. 
Six layers; no single point of failure. +- **Identifiers vs values**: identifiers (table/column names) come from the catalog and are inlined as quoted identifiers — they were verified at validation time so this is safe. Values from `IR.filters` are *always* parameterized, never inlined as strings. +- **Credential encryption**: Fernet via `dataeyond__db__credential__key` env var; the intended home is `security/credentials.py`, but that module is still a stub — the runtime currently reuses Phase 1 `utils/db_credential_encryption.py`. Sensitive fields = `{"password", "service_account_json"}`. +- **Settings env-var aliases**: `.env` uses double-underscore names (`azureai__api_key__4o`); `Settings` exposes them as `azureai_api_key_4o` via `Field(alias=...)`. Mind both forms when adding settings. +- **Prompts**: `src/config/prompts/*.md` — `intent_router`, `query_planner`, `chatbot_system`, `guardrails` are all written. `chatbot_system` has `guardrails` appended so guardrails take precedence in conflict. `catalog_enricher.md` was deleted in KM-557. `config/agents/` folder deleted in Cleanup PR. +- **Planner prompt parsing gotcha**: `query/planner/service.py` uses `SystemMessage(content=...)` not `("system", text)`. The tuple form causes LangChain to interpret `{...}` in `query_planner.md` as f-string variables and crash on every real invocation. Don't refactor back to tuples. +- **Tests**: 146+ unit tests in place. Run with `uv run pytest`. Planner eval gated on `RUN_PLANNER_EVAL=1`; catalog store integration test gated on `RUN_INTEGRATION_TESTS=1`. + +--- + +## Recommended reading order + +1. `ARCHITECTURE.md` — design intent (the source of truth) +2. `src/catalog/models.py` + `src/query/ir/models.py` — the two data shapes everything else moves between +3. `src/query/ir/operators.py` + `src/security/pii_patterns.py` — the explicit whitelists / patterns +4. Skim every `__init__.py`-level docstring under `src/catalog/`, `src/query/`, `src/agents/`, `src/pipeline/` — each describes the contract its module enforces +5.
`main.py` + `src/db/postgres/{connection,init_db}.py` — runtime bootstrap +6. `ARCHITECTURE.md §10` — the five original open questions (see *Open questions* below for their current status) + +--- + +## Open questions + +Status as of Phase 2 landing (✅ resolved, 🟡 partially resolved, ❌ still open): + +1. ✅ Catalog storage shape — Postgres `jsonb` row in `data_catalog` table, keyed by `user_id`. +2. ❌ Unstructured files in catalog — still not modeled; router uses `source_hint` from the LLM instead. +3. 🟡 Catalog refresh trigger — rebuild-on-upload-or-connect is the default. Explicit endpoint `POST /api/v1/data-catalog/rebuild` exists. Background TTL deferred. +4. ✅ Joins out of v1 IR — confirmed; single-table only. Revisit when real queries need it. +5. 🟡 PII `sample_values` — currently nulled out (skip). Mask/synthesize deferred. + +--- + +## Glossary + +- **Cu** — unstructured context (prose chunks) +- **Cs** — schema context (DB tables/columns from catalog) +- **Ct** — tabular context (file sheets/columns from catalog) +- **IR** — intermediate representation (the JSON query shape) +- **PII** — personally identifiable information +- **ABC** — abstract base class diff --git a/main.py b/main.py index 7398bf92cc5a96f9e118a3040b22cc257ef37903..6936252e6a4711acf04b7d53d89f899d3b62e054 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,7 @@ """Main application entry point.""" +from contextlib import asynccontextmanager + from fastapi import FastAPI from src.middlewares.logging import configure_logging, get_logger from src.middlewares.cors import add_cors_middleware @@ -9,8 +11,8 @@ from src.api.v1.document import router as document_router from src.api.v1.chat import router as chat_router from src.api.v1.room import router as room_router from src.api.v1.users import router as users_router -from src.api.v1.knowledge import router as knowledge_router from src.api.v1.db_client import router as db_client_router +from src.api.v1.data_catalog import router as data_catalog_router from src.db.postgres.init_db import init_db import uvicorn @@ -18,11 +20,21 @@ import
uvicorn configure_logging() logger = get_logger("main") + +@asynccontextmanager +async def lifespan(app: FastAPI): + logger.info("Starting application...") + await init_db() + logger.info("Database initialized") + yield + + # Create FastAPI app app = FastAPI( title="DataEyond Agentic Service", description="Multi-agent AI backend with RAG capabilities", - version="0.1.0" + version="0.1.0", + lifespan=lifespan, ) # Add middleware @@ -33,18 +45,10 @@ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) # Include routers app.include_router(users_router) app.include_router(document_router) -app.include_router(knowledge_router) app.include_router(room_router) app.include_router(chat_router) app.include_router(db_client_router) - - -@app.on_event("startup") -async def startup_event(): - """Initialize database on startup.""" - logger.info("Starting application...") - await init_db() - logger.info("Database initialized") +app.include_router(data_catalog_router) @app.get("/") diff --git a/pyproject.toml b/pyproject.toml index 42cdaa7b367d7d803e54cc07e85d0d6f9fca1479..031197172202a603b65144b61a847f771cc1fc05 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,7 +121,8 @@ ignore = [ ] [tool.ruff.lint.per-file-ignores] -"tests/**" = ["S101", "S105", "S106"] +# S608 in tests is a false positive — tests assert literal SQL strings as fixtures. +"tests/**" = ["S101", "S105", "S106", "S608"] [tool.mypy] python_version = "3.12" diff --git a/scripts/build_initial_catalogs.py b/scripts/build_initial_catalogs.py new file mode 100644 index 0000000000000000000000000000000000000000..3a6520beed762a590f722a6af361d49d0f735771 --- /dev/null +++ b/scripts/build_initial_catalogs.py @@ -0,0 +1,73 @@ +"""Backfill catalogs for existing users. + +One-off script. For each user that already has registered DB connections or +uploaded tabular files, run the structured pipeline to build their catalog. 
+ +Run once against the live DB after deploying this branch to populate catalog +rows for data registered before the catalog pipeline landed. + +Note: enrich_all_sources.py is not needed — LLM enrichment was removed in +KM-557. The pipeline is now introspect → merge → validate → upsert. + +Usage: + uv run python scripts/build_initial_catalogs.py [--user-id USER_ID] +""" + +import asyncio +import sys +import os + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from sqlalchemy import select +from src.db.postgres.connection import AsyncSessionLocal +from src.db.postgres.models import DatabaseClient, Document +from src.pipeline.triggers import on_db_registered, on_tabular_uploaded + + +async def main() -> None: + user_id_filter = None + if "--user-id" in sys.argv: + idx = sys.argv.index("--user-id") + user_id_filter = sys.argv[idx + 1] + print(f"Filtering to user_id: {user_id_filter}") + + async with AsyncSessionLocal() as db: + # ── 1. DB clients ────────────────────────────────────────────── + query = select(DatabaseClient).where(DatabaseClient.status == "active") + if user_id_filter: + query = query.where(DatabaseClient.user_id == user_id_filter) + result = await db.execute(query) + db_clients = result.scalars().all() + print(f"\nFound {len(db_clients)} active DB client(s)") + + for client in db_clients: + try: + await on_db_registered(client.id, client.user_id) + print(f" ✓ db_client {client.id} ({client.name})") + except Exception as e: + print(f" ✗ db_client {client.id} ({client.name}): {e}") + + # ── 2. 
Tabular files ─────────────────────────────────────────── + query = select(Document).where( + Document.file_type.in_(["csv", "xlsx"]), + Document.status == "completed", + ) + if user_id_filter: + query = query.where(Document.user_id == user_id_filter) + result = await db.execute(query) + docs = result.scalars().all() + print(f"\nFound {len(docs)} completed tabular file(s)") + + for doc in docs: + try: + await on_tabular_uploaded(doc.id, doc.user_id) + print(f" ✓ {doc.file_type} {doc.id} ({doc.filename})") + except Exception as e: + print(f" ✗ {doc.file_type} {doc.id} ({doc.filename}): {e}") + + print("\nDone.") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/scripts/enrich_all_sources.py b/scripts/enrich_all_sources.py new file mode 100644 index 0000000000000000000000000000000000000000..637609266274ac01ecd2132ffbb972409a6bba9d --- /dev/null +++ b/scripts/enrich_all_sources.py @@ -0,0 +1,16 @@ +"""Bulk re-run CatalogEnricher with the current prompt. + +For when src/config/prompts/catalog_enricher.md changes and existing +catalog descriptions need to be regenerated across all users. + +Usage: + uv run python scripts/enrich_all_sources.py [--user-id USER_ID] +""" + + +def main() -> None: + raise NotImplementedError + + +if __name__ == "__main__": + main() diff --git a/src/agents/chat_handler.py b/src/agents/chat_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..ead2f78ce8438de97ba2db05cdd2c1d04b4f6d9e --- /dev/null +++ b/src/agents/chat_handler.py @@ -0,0 +1,274 @@ +"""ChatHandler — top-level Phase 2 chat orchestrator. + +End-to-end flow per user message: + + 1. `IntentRouter.classify` → `chat` / `unstructured` / `structured`. + 2. Route: + - `chat` → no context. Pass straight to ChatbotAgent. + - `structured` → CatalogReader → QueryService → QueryResult. + - `unstructured` → DocumentRetriever (placeholder, raises until TAB + ships) → list[DocumentChunk]. + 3. `ChatbotAgent.astream` → yield text tokens. + 4. 
Wrap each step into an SSE-style event dict so the API endpoint can + stream them as Server-Sent Events. + +Phase 1's chat endpoint (`src/api/v1/chat.py`) is intentionally NOT touched +in this PR. PR7 cleanup will rewire it to call `ChatHandler.handle(...)`. + +All dependencies are injectable for tests. Default constructors lazy-build +production deps (no `Settings()` triggered at import time as long as you +inject mocks). +""" + +from __future__ import annotations + +import json +from collections.abc import AsyncIterator +from typing import TYPE_CHECKING, Any + +from langchain_core.messages import BaseMessage + +from src.middlewares.logging import get_logger +from src.retrieval.base import RetrievalResult + +from .chatbot import ChatbotAgent, DocumentChunk +from .orchestration import OrchestratorAgent + +if TYPE_CHECKING: + from ..catalog.reader import CatalogReader + from ..query.service import QueryService + from ..retrieval.router import RetrievalRouter + +logger = get_logger("chat_handler") + + +class ChatHandler: + """Top-level chat orchestrator. + + Returns an `AsyncIterator[dict]` of SSE-style events with shape + `{"event": str, "data": str}`. 
Event types: + - `intent` — emitted once after classification (JSON-encoded decision) + - `sources` — JSON array of source refs (one per structured table, or + per (document_id, page_label) for unstructured) + - `chunk` — text fragment of the streaming answer (one per token) + - `done` — end of stream (data is empty string) + - `error` — failure; data is a user-facing message + """ + + def __init__( + self, + intent_router: OrchestratorAgent | None = None, + answer_agent: ChatbotAgent | None = None, + catalog_reader: CatalogReader | None = None, + query_service: QueryService | None = None, + document_retriever: RetrievalRouter | None = None, + ) -> None: + self._intent_router = intent_router + self._answer_agent = answer_agent + self._catalog_reader = catalog_reader + self._query_service = query_service + self._document_retriever = document_retriever + + # ------------------------------------------------------------------ + # Lazy default-dep builders + # ------------------------------------------------------------------ + + def _get_intent_router(self) -> OrchestratorAgent: + if self._intent_router is None: + self._intent_router = OrchestratorAgent() + return self._intent_router + + def _get_answer_agent(self) -> ChatbotAgent: + if self._answer_agent is None: + self._answer_agent = ChatbotAgent() + return self._answer_agent + + def _get_catalog_reader(self) -> CatalogReader: + if self._catalog_reader is None: + from ..catalog.reader import CatalogReader + from ..catalog.store import CatalogStore + + self._catalog_reader = CatalogReader(CatalogStore()) + return self._catalog_reader + + def _get_query_service(self) -> QueryService: + if self._query_service is None: + from ..query.service import QueryService + + self._query_service = QueryService() + return self._query_service + + def _get_document_retriever(self) -> RetrievalRouter: + if self._document_retriever is None: + from ..retrieval.router import RetrievalRouter + + self._document_retriever = 
RetrievalRouter() + return self._document_retriever + + # ------------------------------------------------------------------ + # Public entry + # ------------------------------------------------------------------ + + async def handle( + self, + message: str, + user_id: str, + history: list[BaseMessage] | None = None, + ) -> AsyncIterator[dict[str, Any]]: + # ---- 1. Classify intent -------------------------------------- + try: + decision = await self._get_intent_router().classify(message, history) + except Exception as e: + logger.error("intent classification failed", error=str(e)) + yield {"event": "error", "data": f"Could not classify message: {e}"} + return + + yield {"event": "intent", "data": decision.model_dump_json()} + + rewritten = decision.rewritten_query or message + query_result = None + chunks: list[DocumentChunk] | None = None + raw_chunks: Any = None + + # ---- 2. Route ------------------------------------------------ + if decision.source_hint == "structured": + try: + catalog = await self._get_catalog_reader().read(user_id, "structured") + query_result = await self._get_query_service().run( + user_id, rewritten, catalog + ) + except Exception as e: + logger.error( + "structured route failed", + user_id=user_id, + error=str(e), + ) + yield {"event": "error", "data": f"Structured query failed: {e}"} + return + elif decision.source_hint == "unstructured": + try: + raw_chunks = await self._get_document_retriever().retrieve( + rewritten, user_id + ) + chunks = _normalize_chunks(raw_chunks) + except NotImplementedError: + logger.warning("DocumentRetriever placeholder hit", user_id=user_id) + yield { + "event": "error", + "data": "Document retrieval is not yet available — pending implementation.", + } + return + except Exception as e: + logger.error( + "unstructured route failed", user_id=user_id, error=str(e) + ) + yield {"event": "error", "data": f"Document retrieval failed: {e}"} + return + # else: chat path — no context + + # ---- 2b. 
Emit sources --------------------------------------- + sources = _build_sources( + decision.source_hint, user_id, query_result, raw_chunks + ) + yield {"event": "sources", "data": json.dumps(sources)} + + # ---- 3. Stream answer ---------------------------------------- + try: + async for token in self._get_answer_agent().astream( + message, + history=history, + query_result=query_result, + chunks=chunks, + ): + yield {"event": "chunk", "data": token} + except Exception as e: + logger.error("answer streaming failed", user_id=user_id, error=str(e)) + yield {"event": "error", "data": f"Answer generation failed: {e}"} + return + + yield {"event": "done", "data": ""} + + +def _build_sources( + source_hint: str, + user_id: str, + query_result: Any, + raw_chunks: Any, +) -> list[dict[str, Any]]: + """Build the sources payload for the SSE `sources` event. + + - structured: one entry per executed table (table_name only). + - unstructured: deduped by (document_id, page_label), Phase 1 shape. + - chat or error: empty list. 
+ """ + if source_hint == "structured": + if query_result is None or getattr(query_result, "error", None): + return [] + table_name = getattr(query_result, "table_name", "") or "" + if not table_name: + return [] + return [{ + "document_id": f"{user_id}_{table_name}", + "filename": table_name, + "page_label": None, + }] + + if source_hint == "unstructured" and raw_chunks: + seen: set[tuple[Any, Any]] = set() + sources: list[dict[str, Any]] = [] + for item in raw_chunks: + if isinstance(item, RetrievalResult): + data = item.metadata.get("data", {}) + elif isinstance(item, dict): + data = item + else: + continue + key = (data.get("document_id"), data.get("page_label")) + if key in seen or key == (None, None): + continue + seen.add(key) + sources.append({ + "document_id": data.get("document_id"), + "filename": data.get("filename", "Unknown"), + "page_label": data.get("page_label", "Unknown"), + }) + return sources + + return [] + + +def _normalize_chunks(raw: Any) -> list[DocumentChunk]: + """Convert whatever the retriever returns into list[DocumentChunk]. + + The Phase 2 `DocumentRetriever.retrieve` interface is a stub today; + when TAB owner ships it, it should return `list[DocumentChunk]` + directly so this normalizer becomes a no-op. Until then we coerce + common shapes (dict-with-content, plain string) defensively. 
+ """ + if not raw: + return [] + if isinstance(raw, list) and all(isinstance(c, DocumentChunk) for c in raw): + return raw + chunks: list[DocumentChunk] = [] + for item in raw: + if isinstance(item, DocumentChunk): + chunks.append(item) + elif isinstance(item, dict): + chunks.append( + DocumentChunk( + content=str(item.get("content", "")), + filename=item.get("filename"), + page_label=item.get("page_label"), + ) + ) + elif isinstance(item, RetrievalResult): + data = item.metadata.get("data", {}) + page = data.get("page_label") + chunks.append(DocumentChunk( + content=item.content, + filename=data.get("filename"), + page_label=str(page) if page is not None else None, + )) + elif isinstance(item, str): + chunks.append(DocumentChunk(content=item)) + return chunks diff --git a/src/agents/chatbot.py b/src/agents/chatbot.py index 44856d94373589d3890398fb75e2ac6e3b346526..ff1fb129a7a08cd0e25dfbd46ac08a93736db1aa 100644 --- a/src/agents/chatbot.py +++ b/src/agents/chatbot.py @@ -1,85 +1,169 @@ -"""Chatbot agent with RAG capabilities.""" +"""ChatbotAgent — final answer formation. Phase 2 chatbot. -import tiktoken -from langchain_openai import AzureChatOpenAI -from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +Receives one of: + - a `QueryResult` (structured query path), + - a list of document chunks (unstructured path), or + - nothing (chat-only path: greeting, farewell, meta question). + +Streams the answer token-by-token so the chat handler can wrap each token +into an SSE event. Conversation history is supported. 
+""" + +from __future__ import annotations + +from collections.abc import AsyncIterator +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from langchain_core.messages import BaseMessage from langchain_core.output_parsers import StrOutputParser -from src.config.settings import settings +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_core.runnables import Runnable +from langchain_openai import AzureChatOpenAI + from src.middlewares.logging import get_logger -from langchain_core.messages import HumanMessage, AIMessage + +from ..query.executor.base import QueryResult logger = get_logger("chatbot") -_enc = tiktoken.get_encoding("cl100k_base") +_PROMPT_DIR = Path(__file__).resolve().parent.parent / "config" / "prompts" +_SYSTEM_PROMPT_PATH = _PROMPT_DIR / "chatbot_system.md" +_GUARDRAILS_PATH = _PROMPT_DIR / "guardrails.md" -def _count_tokens(messages: list, context: str) -> dict: - msg_tokens = sum(len(_enc.encode(m.content)) for m in messages) - ctx_tokens = len(_enc.encode(context)) - return {"messages_tokens": msg_tokens, "context_tokens": ctx_tokens, "total": msg_tokens + ctx_tokens} +@dataclass +class DocumentChunk: + """One retrieved document chunk for the unstructured path.""" -class ChatbotAgent: - """Chatbot agent with RAG capabilities.""" - - def __init__(self): - self.llm = AzureChatOpenAI( - azure_deployment=settings.azureai_deployment_name_4o, - openai_api_version=settings.azureai_api_version_4o, - azure_endpoint=settings.azureai_endpoint_url_4o, - api_key=settings.azureai_api_key_4o, - temperature=0.7 + content: str + filename: str | None = None + page_label: str | None = None + + +def _load_system_prompt() -> str: + """Compose system prompt = chatbot_system.md + guardrails.md. + + Guardrails appended last so they take precedence in conflict (matches + the docstring at the top of guardrails.md). 
+ """ + chatbot = _SYSTEM_PROMPT_PATH.read_text(encoding="utf-8") + guardrails = _GUARDRAILS_PATH.read_text(encoding="utf-8") + return f"{chatbot}\n\n{guardrails}" + + +def _format_query_result(qr: QueryResult) -> str: + """Render a QueryResult as a compact context block for the LLM.""" + source_label = qr.source_name or "(unknown source)" + table_label = qr.table_name or "(unknown table)" + if qr.error: + return ( + f"[Query result — FAILED]\n" + f"source: {source_label}\n" + f"table: {table_label}\n" + f"error: {qr.error}" ) + lines: list[str] = [ + "[Query result]", + f"source: {source_label}", + f"table: {table_label}", + f"backend: {qr.backend}", + f"row_count: {qr.row_count}" + + (" (truncated)" if qr.truncated else ""), + f"elapsed_ms: {qr.elapsed_ms}", + ] + if qr.rows: + # Cap rendering at 25 rows; the LLM doesn't need the full set + cap = min(len(qr.rows), 25) + columns = list(qr.rows[0].keys()) + lines.append("columns: " + ", ".join(columns)) + lines.append("rows:") + for row in qr.rows[:cap]: + lines.append(" " + ", ".join(f"{k}={row[k]!r}" for k in columns)) + if cap < len(qr.rows): + lines.append(f" ... 
(+{len(qr.rows) - cap} more rows omitted from prompt)") + return "\n".join(lines) + + +def _format_document_chunks(chunks: list[DocumentChunk]) -> str: + if not chunks: + return "" + blocks: list[str] = [] + for c in chunks: + label_parts = [p for p in (c.filename, c.page_label) if p] + label = ", ".join(label_parts) if label_parts else "Unknown source" + blocks.append(f"[Source: {label}]\n{c.content}") + return "\n\n".join(blocks) + + +def _build_context_block( + query_result: QueryResult | None, + chunks: list[DocumentChunk] | None, +) -> str: + parts: list[str] = [] + if query_result is not None: + parts.append(_format_query_result(query_result)) + if chunks: + parts.append(_format_document_chunks(chunks)) + return "\n\n".join(parts) if parts else "(no data context — answer conversationally)" + + +def _build_default_chain() -> Runnable: + from src.config.settings import settings + + llm = AzureChatOpenAI( + azure_deployment=settings.azureai_deployment_name_4o, + openai_api_version=settings.azureai_api_version_4o, + azure_endpoint=settings.azureai_endpoint_url_4o, + api_key=settings.azureai_api_key_4o, + temperature=0.3, + ) + prompt = ChatPromptTemplate.from_messages( + [ + ("system", _load_system_prompt()), + MessagesPlaceholder(variable_name="history", optional=True), + ("human", "{message}"), + ("system", "Data context for this turn:\n\n{context}"), + ] + ) + return prompt | llm | StrOutputParser() + + +class ChatbotAgent: + """Formats and streams the final user-facing answer. - # Read system prompt - try: - with open("src/config/agents/system_prompt.md", "r") as f: - system_prompt = f.read() - except FileNotFoundError: - system_prompt = "You are a helpful AI assistant with access to user's uploaded documents." + `chain` is injectable: tests pass a fake that yields canned tokens. + Default constructs the production Azure OpenAI streaming chain on + first use. 
+ """ - # Create prompt template - self.prompt = ChatPromptTemplate.from_messages([ - ("system", system_prompt), - MessagesPlaceholder(variable_name="messages"), - ("system", "Relevant documents:\n{context}") - ]) + def __init__(self, chain: Runnable | None = None) -> None: + self._chain = chain - # Create chain - self.chain = self.prompt | self.llm | StrOutputParser() + def _ensure_chain(self) -> Runnable: + if self._chain is None: + self._chain = _build_default_chain() + return self._chain - async def generate_response( + async def astream( self, - messages: list, - context: str = "" - ) -> str: - """Generate response with optional RAG context.""" - try: - logger.info("Generating chatbot response") - - # Generate response - response = await self.chain.ainvoke({ - "messages": messages, - "context": context - }) - - logger.info(f"Generated response: {response[:100]}...") - return response - - except Exception as e: - logger.error("Response generation failed", error=str(e)) - raise - - async def astream_response(self, messages: list, context: str = ""): - """Stream response tokens as they are generated.""" - try: - token_counts = _count_tokens(messages, context) - logger.info("LLM input tokens", **token_counts) - async for token in self.chain.astream({"messages": messages, "context": context}): - yield token - except Exception as e: - logger.error("Response streaming failed", error=str(e)) - raise - - -chatbot = ChatbotAgent() + message: str, + history: list[BaseMessage] | None = None, + query_result: QueryResult | None = None, + chunks: list[DocumentChunk] | None = None, + ) -> AsyncIterator[str]: + """Stream tokens of the final answer. + + Caller wraps each token into the SSE format. Empty `history` and + no context = pure chat reply. 
+ """ + chain = self._ensure_chain() + payload: dict[str, Any] = { + "message": message, + "history": history or [], + "context": _build_context_block(query_result, chunks), + } + async for token in chain.astream(payload): + yield token diff --git a/src/agents/orchestration.py b/src/agents/orchestration.py index 44614ede11ef183f1b16749be031ffce3be40f06..61ed7cb40383ee87a6bbc3201f3093fb6d7f8c98 100644 --- a/src/agents/orchestration.py +++ b/src/agents/orchestration.py @@ -1,79 +1,109 @@ -"""Orchestrator agent for intent recognition and planning.""" +"""OrchestratorAgent — classifies a user message and emits source_hint. -from langchain_openai import AzureChatOpenAI +Output: needs_search (bool) + source_hint ∈ { chat, unstructured, structured } ++ rewritten_query (standalone form of the user's question, history-resolved). + +Phase 2 replaces the previous intent-classification body. The class name +is preserved so existing import sites (`from src.agents.orchestration +import OrchestratorAgent`) keep working. The default LLM chain is +constructed lazily so the module is import-safe even without `.env` +populated. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Literal + +from langchain_core.messages import BaseMessage from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -from src.config.settings import settings +from langchain_core.runnables import Runnable +from langchain_openai import AzureChatOpenAI +from pydantic import BaseModel, Field + from src.middlewares.logging import get_logger -from src.models.structured_output import IntentClassification logger = get_logger("orchestrator") +SourceHint = Literal["chat", "unstructured", "structured"] + +_PROMPT_PATH = ( + Path(__file__).resolve().parent.parent + / "config" + / "prompts" + / "intent_router.md" +) + + +class IntentRouterDecision(BaseModel): + """LLM output. 
Pydantic so it can be used with `with_structured_output`.""" + + needs_search: bool = Field( + ..., description="True if we must look at the user's data to answer." + ) + source_hint: SourceHint = Field( + ..., + description="Which downstream path: 'chat' (no lookup), " + "'unstructured' (PDF/DOCX/TXT prose), 'structured' (DB / tabular file).", + ) + rewritten_query: str | None = Field( + None, + description="Standalone version of the question, history-resolved. " + "Null when needs_search=false.", + ) + + +def _load_prompt_text() -> str: + return _PROMPT_PATH.read_text(encoding="utf-8") + + +def _build_default_chain() -> Runnable: + from src.config.settings import settings + + llm = AzureChatOpenAI( + azure_deployment=settings.azureai_deployment_name_4o, + openai_api_version=settings.azureai_api_version_4o, + azure_endpoint=settings.azureai_endpoint_url_4o, + api_key=settings.azureai_api_key_4o, + temperature=0, + ) + prompt = ChatPromptTemplate.from_messages( + [ + ("system", _load_prompt_text()), + MessagesPlaceholder(variable_name="history", optional=True), + ("human", "{message}"), + ] + ) + return prompt | llm.with_structured_output(IntentRouterDecision) + class OrchestratorAgent: - """Orchestrator agent for intent recognition and planning.""" - - def __init__(self): - self.llm = AzureChatOpenAI( - azure_deployment=settings.azureai_deployment_name_4o, - openai_api_version=settings.azureai_api_version_4o, - azure_endpoint=settings.azureai_endpoint_url_4o, - api_key=settings.azureai_api_key_4o, - temperature=0 + """Classifies a user message into chat / unstructured / structured. + + Inject `structured_chain` for tests; default builds the production + Azure OpenAI chain on first use. 
+ """ + + def __init__(self, structured_chain: Runnable | None = None) -> None: + self._chain = structured_chain + + def _ensure_chain(self) -> Runnable: + if self._chain is None: + self._chain = _build_default_chain() + return self._chain + + async def classify( + self, + message: str, + history: list[BaseMessage] | None = None, + ) -> IntentRouterDecision: + chain = self._ensure_chain() + decision: IntentRouterDecision = await chain.ainvoke( + {"message": message, "history": history or []} ) - - self.prompt = ChatPromptTemplate.from_messages([ - ("system", """You are an orchestrator agent. You receive recent conversation history and the user's latest message. - -Your task: -1. Determine intent: question, greeting, goodbye, or other -2. Decide whether to search the user's documents (needs_search) -3. If search is needed, rewrite the user's message into a STANDALONE search query that incorporates necessary context from conversation history. If the user says "tell me more" or "how many papers?", the search_query must spell out the full topic explicitly from history. -4. If no search needed, provide a short direct_response (plain text only, no markdown formatting). - -Intent Routing: -- question -> needs_search=True, search_query= -- greeting -> needs_search=False, direct_response="Hello! How can I assist you today?" -- goodbye -> needs_search=False, direct_response="Goodbye! Have a great day!" 
-- other -> needs_search=True, search_query= - -Source Routing (set source_hint): -- Columns, tables, sheets, data types, schema, row counts, statistics -> source_hint=schema -- Document content, paragraphs, reports, articles, text -> source_hint=document -- Unclear or spans both -> source_hint=both -"""), - MessagesPlaceholder(variable_name="history"), - ("user", "{message}") - ]) - - # with_structured_output uses function calling — guarantees valid schema regardless of LLM response style - self.chain = self.prompt | self.llm.with_structured_output(IntentClassification) - - async def analyze_message(self, message: str, history: list = None) -> dict: - """Analyze user message and determine next actions. - - Args: - message: The current user message. - history: Recent conversation as LangChain BaseMessage objects (oldest-first). - Used to rewrite ambiguous follow-ups into standalone search queries. - """ - try: - logger.info(f"Analyzing message: {message[:50]}...") - - history_messages = history or [] - result: IntentClassification = await self.chain.ainvoke({"message": message, "history": history_messages}) - - logger.info(f"Intent: {result.intent}, Needs search: {result.needs_search}, Search query: {result.search_query[:50] if result.search_query else ''}") - return result.model_dump() - - except Exception as e: - logger.error("Message analysis failed", error=str(e)) - # Fallback to treating everything as a question - return { - "intent": "question", - "needs_search": True, - "search_query": message, - "direct_response": None - } - - -orchestrator = OrchestratorAgent() + logger.info( + "intent classified", + source_hint=decision.source_hint, + needs_search=decision.needs_search, + ) + return decision diff --git a/src/api/v1/chat.py b/src/api/v1/chat.py index c74276ddffab83b195c3c6fe5145839bd4b47193..374626dff3563ab01c1fe2440325e09071d450be 100644 --- a/src/api/v1/chat.py +++ b/src/api/v1/chat.py @@ -1,46 +1,40 @@ """Chat endpoint with streaming support.""" -import 
asyncio import uuid +import json +from typing import List, Dict, Any, Optional + from fastapi import APIRouter, Depends, HTTPException +from langchain_core.messages import HumanMessage, AIMessage +from pydantic import BaseModel +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from sse_starlette.sse import EventSourceResponse + +from src.agents.chat_handler import ChatHandler +from src.config.settings import settings from src.db.postgres.connection import get_db from src.db.postgres.models import ChatMessage, MessageSource -from src.agents.orchestration import orchestrator -from src.agents.chatbot import chatbot -from src.rag.retriever import retriever -from src.rag.base import RetrievalResult -from src.query.query_executor import query_executor -from src.query.base import QueryResult from src.db.redis.connection import get_redis -from src.config.settings import settings from src.middlewares.logging import get_logger, log_execution -from sse_starlette.sse import EventSourceResponse -from langchain_core.messages import HumanMessage, AIMessage -from sqlalchemy import select -from pydantic import BaseModel -from typing import List, Dict, Any, Optional -import json + +logger = get_logger("chat_api") + +router = APIRouter(prefix="/api/v1", tags=["Chat"]) _GREETINGS = frozenset(["hi", "hello", "hey", "halo", "hai", "hei"]) _GOODBYES = frozenset(["bye", "goodbye", "thanks", "thank you", "terima kasih", "sampai jumpa"]) -def _fast_intent(message: str) -> Optional[dict]: - """Bypass LLM orchestrator for obvious greetings and farewells.""" +def _fast_intent(message: str) -> Optional[str]: + """Return a direct response for obvious greetings/farewells, else None.""" lower = message.lower().strip().rstrip("!.,?") if lower in _GREETINGS: - return {"intent": "greeting", "needs_search": False, - "direct_response": "Hello! How can I assist you today?", "search_query": ""} + return "Hello! How can I assist you today?" 
if lower in _GOODBYES: - return {"intent": "goodbye", "needs_search": False, - "direct_response": "Goodbye! Have a great day!", "search_query": ""} + return "Goodbye! Have a great day!" return None -logger = get_logger("chat_api") - -router = APIRouter(prefix="/api/v1", tags=["Chat"]) - class ChatRequest(BaseModel): user_id: str @@ -48,66 +42,6 @@ class ChatRequest(BaseModel): message: str -def _format_context(results: List[RetrievalResult]) -> str: - """Format retrieval results as context string for the LLM.""" - lines = [] - for result in results: - data = result.metadata.get("data", {}) - filename = data.get("filename", "Unknown") - page = data.get("page_label") - source_label = f"{filename}, p.{page}" if page else filename - lines.append(f"[Source: {source_label}]\n{result.content}\n") - return "\n".join(lines) - - -def _extract_sources(results: List[RetrievalResult]) -> List[Dict[str, Any]]: - """Extract deduplicated source references from retrieval results.""" - seen = set() - sources = [] - for result in results: - meta = result.metadata - data = meta.get("data", {}) - if "document_id" in data: - key = (data.get("document_id"), data.get("page_label")) - if key not in seen: - seen.add(key) - sources.append({ - "document_id": data.get("document_id"), - "filename": data.get("filename", "Unknown"), - "page_label": data.get("page_label", "Unknown"), - }) - else: - key = (data.get("table_name"), data.get("column_name")) - if key not in seen: - seen.add(key) - table_name = data.get("table_name") - user_id = meta.get("user_id") - sources.append({ - "document_id": f"{user_id}_{table_name}", - "filename": data.get("table_name", "Unknown"), - "page_label": data.get("column_name", "Unknown"), - }) - - logger.debug(f"Extracted sources: {sources}") - return sources - - -def _format_query_results(results: list[QueryResult]) -> str: - if not results: - return "" - lines = [] - for r in results: - name = r.metadata.get("client_name", r.source_id) - lines.append(f"[Query 
result — {name}, tables: {r.table_or_file}]") - lines.append(f"SQL: {r.metadata.get('sql', '')}") - if r.columns and r.rows: - lines.append(" | ".join(r.columns)) - for row in r.rows[:20]: - lines.append(" | ".join(str(row.get(c, "")) for c in r.columns)) - lines.append(f"({r.row_count} rows total)\n") - return "\n".join(lines) - - async def get_cached_response(redis, cache_key: str) -> Optional[str]: cached = await redis.get(cache_key) if cached: @@ -163,13 +97,15 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)): """Chat endpoint with streaming response. SSE event sequence: - 1. sources — JSON array of {document_id, filename, page_label} + 1. sources — JSON array of source refs from ChatHandler (table for + structured; deduped document_id/page_label for unstructured) 2. chunk — text fragments of the answer 3. done — signals end of stream """ redis = await get_redis() - cache_key = f"{settings.redis_prefix}chat:{request.room_id}:{request.message}" + + # Redis cache hit cached = await get_cached_response(redis, cache_key) if cached: logger.info("Returning cached response") @@ -183,96 +119,43 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)): return EventSourceResponse(stream_cached()) try: - # Step 1: Fast local intent check (skips LLM for greetings/farewells) - intent_result = _fast_intent(request.message) - - context = "" - sources: List[Dict[str, Any]] = [] - - if intent_result is None: - # Step 2: Launch retrieval and history loading in parallel, then run orchestrator. - # k=5 - # tables — db_executor's FK expansion is one-hop and cannot bridge - # 2-hop gaps (e.g. customers -> order_items -> products) on its own. 
- retrieval_task = asyncio.create_task( - retriever.retrieve(request.message, request.user_id, db, k=5) - ) - history_task = asyncio.create_task( - load_history(db, request.room_id, limit=6) # 6 msgs (3 pairs) for orchestrator - ) - history = await history_task # fast DB query (<100ms), done before orchestrator finishes - intent_result = await orchestrator.analyze_message(request.message, history) - - search_query = intent_result.get("search_query", request.message) or request.message - if not intent_result.get("needs_search"): - retrieval_task.cancel() - try: - await retrieval_task - except asyncio.CancelledError: - pass - raw_results = [] - else: - logger.info(f"Searching for: {search_query}") - if search_query != request.message: - retrieval_task.cancel() - try: - await retrieval_task - except asyncio.CancelledError: - pass - raw_results = await retriever.retrieve( - query=search_query, - user_id=request.user_id, - db=db, - k=5, - source_hint=intent_result.get("source_hint", "both"), - ) - else: - raw_results = await retrieval_task - - context = _format_context(raw_results) - sources = _extract_sources(raw_results) - - source_hint = intent_result.get("source_hint", "both") - if source_hint in ("schema", "both"): - # Use search_query (orchestrator's standalone rewrite) so follow-up - # messages like "dive deeper" or "show me last year" resolve correctly. - # For first-turn questions search_query == request.message, so no change. 
- query_results = await query_executor.execute( - results=raw_results, - user_id=request.user_id, - db=db, - question=search_query, - ) - query_context = _format_query_results(query_results) - if query_context: - context = query_context + "\n\n" + context - - # Step 3: Direct response for greetings / non-document intents - if intent_result.get("direct_response"): - response = intent_result["direct_response"] - await cache_response(redis, cache_key, response) - await save_messages(db, request.room_id, request.message, response, sources=[]) + # Fast intent: greetings/farewells bypass LLM entirely + direct = _fast_intent(request.message) + if direct: + await cache_response(redis, cache_key, direct) + await save_messages(db, request.room_id, request.message, direct, sources=[]) async def stream_direct(): yield {"event": "sources", "data": json.dumps([])} - yield {"event": "message", "data": response} + yield {"event": "chunk", "data": direct} + yield {"event": "done", "data": ""} return EventSourceResponse(stream_direct()) - # Step 4: Stream answer token-by-token as LLM generates it - # Load full history (10 msgs) for chatbot — richer context than the 6 used by orchestrator - full_history = await load_history(db, request.room_id, limit=10) - messages = full_history + [HumanMessage(content=request.message)] + history = await load_history(db, request.room_id, limit=10) + handler = ChatHandler() async def stream_response(): full_response = "" - yield {"event": "sources", "data": json.dumps(sources)} - async for token in chatbot.astream_response(messages, context): - full_response += token - yield {"event": "chunk", "data": token} - yield {"event": "done", "data": ""} - await cache_response(redis, cache_key, full_response) - await save_messages(db, request.room_id, request.message, full_response, sources=sources) + sources: List[Dict[str, Any]] = [] + async for event in handler.handle(request.message, request.user_id, history): + if event["event"] == "sources": + try: + 
sources = json.loads(event["data"]) or [] + except (TypeError, ValueError): + sources = [] + yield event + elif event["event"] == "chunk": + full_response += event["data"] + yield event + elif event["event"] == "done": + await cache_response(redis, cache_key, full_response) + await save_messages(db, request.room_id, request.message, full_response, sources=sources) + yield event + elif event["event"] == "error": + yield event + return + # "intent" event: consumed internally, not forwarded to frontend return EventSourceResponse(stream_response()) diff --git a/src/api/v1/data_catalog.py b/src/api/v1/data_catalog.py new file mode 100644 index 0000000000000000000000000000000000000000..188d16d6172211d67a2f90468b56205568676d8c --- /dev/null +++ b/src/api/v1/data_catalog.py @@ -0,0 +1,100 @@ +"""API endpoints for the per-user data catalog index. + +The index is a lightweight summary of every structured source registered +by a user (DB connections and tabular files). It is intended to be +consumed by the catalog refresher and by frontend listings — full +catalog payloads (tables + columns + samples + stats) are not exposed +here on purpose. 
+""" + +from typing import List + +from fastapi import APIRouter, HTTPException, Query, status + +from src.catalog.store import CatalogStore +from src.middlewares.logging import get_logger, log_execution +from src.models.api.catalog import CatalogIndexEntry +from src.pipeline.triggers import on_catalog_rebuild_requested + +logger = get_logger("data_catalog_api") + +router = APIRouter(prefix="/api/v1", tags=["Data Catalog"]) + + +@router.get( + "/data-catalog/{user_id}", + response_model=List[CatalogIndexEntry], + summary="List the user's data catalog index", + response_description="One entry per registered structured source.", + responses={ + 200: {"description": "Returns an empty list if the user has no registered sources."}, + 500: {"description": "Internal server error while reading the catalog."}, + }, +) +@log_execution(logger) +async def list_data_catalog_index(user_id: str): + """ + Return a lightweight index of every structured source registered by the user. + + One entry per source (DB connection or tabular file), including the + `source_id`, `source_type`, display `name`, `location_ref`, current + `table_count`, and `updated_at` timestamp. + + Used by the catalog refresher to decide which sources need to be + rebuilt. Returns an empty list if the user has no catalog yet. 
+ """ + try: + catalog = await CatalogStore().get(user_id) + except Exception as e: + logger.error("Failed to read catalog index", user_id=user_id, error=str(e)) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to read catalog index: {e}", + ) + + if catalog is None: + return [] + + return [ + CatalogIndexEntry( + source_id=s.source_id, + source_type=s.source_type, + name=s.name, + location_ref=s.location_ref, + table_count=len(s.tables), + updated_at=s.updated_at, + ) + for s in catalog.sources + ] + + +@router.post( + "/data-catalog/rebuild", + status_code=status.HTTP_200_OK, + summary="Rebuild the catalog for a user", + response_description="Confirmation that the rebuild was triggered.", + responses={ + 200: {"description": "Rebuild completed. Per-source errors are logged but do not fail this request."}, + 500: {"description": "Unexpected error before the rebuild loop started."}, + }, +) +@log_execution(logger) +async def rebuild_data_catalog( + user_id: str = Query(..., description="ID of the user whose catalog should be rebuilt."), +): + """ + Re-introspect every source in the user's catalog and upsert the results. + + Each source (DB connection or tabular file) is processed independently. + A failure on one source is logged but does not abort the remaining sources. + If the user has no catalog yet, returns success with no-op. 
+ """ + try: + await on_catalog_rebuild_requested(user_id) + except Exception as e: + logger.error("catalog rebuild failed", user_id=user_id, error=str(e)) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Catalog rebuild failed: {e}", + ) + return {"status": "success", "user_id": user_id} diff --git a/src/api/v1/db_client.py b/src/api/v1/db_client.py index 08bd4d75fe74fc2f0537552b1ff783fd8762a2a5..8879b9d3448e3c4a3ef091bc661243d6326655cf 100644 --- a/src/api/v1/db_client.py +++ b/src/api/v1/db_client.py @@ -27,8 +27,7 @@ from src.models.credentials import ( # noqa: F401 — re-exported for Swagger s SqlServerCredentials, SupabaseCredentials, ) -from src.pipeline.db_pipeline import db_pipeline_service -from src.utils.db_credential_encryption import decrypt_credentials_dict +from src.pipeline.triggers import on_db_registered logger = get_logger("database_client_api") @@ -407,20 +406,22 @@ async def delete_database_client( raise HTTPException(status_code=403, detail="Access denied") await database_client_service.delete(db, client_id) + from src.pipeline.triggers import on_db_deleted + await on_db_deleted(client_id, user_id) return {"status": "success", "message": "Database client deleted successfully"} @router.post( "/database-clients/{client_id}/ingest", status_code=status.HTTP_200_OK, - summary="Ingest schema from a registered database into the vector store", - response_description="Count of chunks ingested.", + summary="Build the catalog for a registered database connection", + response_description="Confirmation that the catalog was built.", responses={ - 200: {"description": "Ingestion completed successfully."}, + 200: {"description": "Catalog built successfully."}, 403: {"description": "Access denied — user_id does not own this connection."}, 404: {"description": "Connection not found."}, - 501: {"description": "The connection's db_type is not yet supported by the pipeline."}, - 500: {"description": "Ingestion failed 
(connection error, profiling error, etc.)."}, + 409: {"description": "Connection is inactive."}, + 500: {"description": "Catalog build failed."}, }, ) @limiter.limit("5/minute") @@ -432,11 +433,9 @@ async def ingest_database_client( db: AsyncSession = Depends(get_db), ): """ - Decrypt the stored credentials, connect to the user's database, introspect - its schema, profile each column, embed the descriptions, and store them in - the shared PGVector collection tagged with `source_type="database"`. - - Chunks become retrievable via the same retriever used for document chunks. + Introspect the registered database and build (or rebuild) the catalog entry + for this connection. The catalog is stored in `data_catalog` and used by + the query pipeline to plan structured queries. """ client = await database_client_service.get(db, client_id) @@ -453,21 +452,12 @@ async def ingest_database_client( ) try: - creds = decrypt_credentials_dict(client.credentials) - with db_pipeline_service.engine_scope( - db_type=client.db_type, - credentials=creds, - ) as engine: - total = await db_pipeline_service.run(user_id=user_id, client_id=client_id, engine=engine) - except NotImplementedError as e: - raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail=str(e)) + await on_db_registered(client_id, user_id) except Exception as e: - logger.error( - f"Ingestion failed for client {client_id}", user_id=user_id, error=str(e) - ) + logger.error("catalog build failed", client_id=client_id, error=str(e)) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Ingestion failed: {e}", + detail=f"Catalog build failed: {e}", ) - return {"status": "success", "client_id": client_id, "chunks_ingested": total} + return {"status": "success", "client_id": client_id} diff --git a/src/api/v1/document.py b/src/api/v1/document.py index 92380bb2c20b74fb4cabeb3d5704ae8a29e33477..245abd23ec1299c57564ec4f7ac2de189cd3fd13 100644 --- a/src/api/v1/document.py +++ 
b/src/api/v1/document.py @@ -6,7 +6,7 @@ from src.db.postgres.connection import get_db from src.document.document_service import document_service from src.middlewares.logging import get_logger, log_execution from src.middlewares.rate_limit import limiter -from src.pipeline.document_pipeline.document_pipeline import document_pipeline +from src.pipeline.document_pipeline import document_pipeline from pydantic import BaseModel from typing import List @@ -24,7 +24,7 @@ class DocumentResponse(BaseModel): created_at: str -# NOTE: Keep in sync with SUPPORTED_FILE_TYPES in src/pipeline/document_pipeline/document_pipeline.py +# NOTE: Keep in sync with SUPPORTED_FILE_TYPES in src/pipeline/document_pipeline.py _DOC_TYPES = [ {"doc_type": "pdf", "max_size": 10, "status": "active", "message": None}, {"doc_type": "docx", "max_size": 10, "status": "active", "message": None}, @@ -92,6 +92,8 @@ async def delete_document( ): """Delete a document.""" await document_pipeline.delete(document_id, user_id, db) + from src.pipeline.triggers import on_tabular_deleted + await on_tabular_deleted(document_id, user_id) return {"status": "success", "message": "Document deleted successfully"} @@ -104,5 +106,13 @@ async def process_document( ): """Process document and ingest to vector index.""" data = await document_pipeline.process(document_id, user_id, db) + + if data["file_type"] in ("csv", "xlsx"): + from src.pipeline.triggers import on_tabular_uploaded + try: + await on_tabular_uploaded(document_id, user_id) + except Exception as e: + logger.error("catalog ingestion failed after process", document_id=document_id, error=str(e)) + return {"status": "success", "message": "Document processed successfully", "data": data} \ No newline at end of file diff --git a/src/api/v1/knowledge.py b/src/api/v1/knowledge.py deleted file mode 100644 index 2af2fa3a0b2e7acb1481ffd7bfcd438b267556a3..0000000000000000000000000000000000000000 --- a/src/api/v1/knowledge.py +++ /dev/null @@ -1,25 +0,0 @@ -"""Knowledge 
base management API endpoints.""" - -from fastapi import APIRouter, Depends -from sqlalchemy.ext.asyncio import AsyncSession -from src.db.postgres.connection import get_db -from src.middlewares.logging import get_logger, log_execution - -logger = get_logger("knowledge_api") - -router = APIRouter(prefix="/api/v1", tags=["Knowledge"]) - - -@router.post("/knowledge/rebuild") -@log_execution(logger) -async def rebuild_vector_index( - user_id: str, - db: AsyncSession = Depends(get_db) -): - """Rebuild vector index for a user (admin endpoint).""" - # This would re-process all documents - # For POC, we'll skip this complexity - return { - "status": "success", - "message": "Vector index rebuild initiated" - } diff --git a/src/catalog/README.md b/src/catalog/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6e56eacff732c8e82d7bc3884adca9e319638e5d --- /dev/null +++ b/src/catalog/README.md @@ -0,0 +1,6 @@ +# catalog + +Per-user data catalog: identity layer for structured sources (DB schemas + tabular files). +Holds AI-enriched table/column descriptions, consumed by `query/planner` to generate JSON IR. + +See `ARCHITECTURE.md` (root) for the full design. 
diff --git a/src/catalog/__init__.py b/src/catalog/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..760a0eab6b939e95b0da253181d5d013ae7189de --- /dev/null +++ b/src/catalog/__init__.py @@ -0,0 +1 @@ +"""Catalog domain — per-user data catalog (Cs + Ct).""" diff --git a/src/catalog/introspect/__init__.py b/src/catalog/introspect/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ea56df01f81d5d1cbf364b14830ab7d7ac71f713 --- /dev/null +++ b/src/catalog/introspect/__init__.py @@ -0,0 +1 @@ +"""Source-specific schema introspection (databases, tabular files).""" diff --git a/src/catalog/introspect/base.py b/src/catalog/introspect/base.py new file mode 100644 index 0000000000000000000000000000000000000000..ab120e6a7f5eb48cd91bffae39ef0aaf415a554d --- /dev/null +++ b/src/catalog/introspect/base.py @@ -0,0 +1,18 @@ +"""BaseIntrospector — contract for source-specific schema readers. + +Subclasses produce a Source object with raw schema (names, types, sample +values, stats). The planner consumes this directly — descriptions are not +LLM-generated. +""" + +from abc import ABC, abstractmethod + +from ..models import Source + + +class BaseIntrospector(ABC): + """Abstract base. Subclasses: DatabaseIntrospector, TabularIntrospector.""" + + @abstractmethod + async def introspect(self, location_ref: str) -> Source: + ... diff --git a/src/catalog/introspect/database.py b/src/catalog/introspect/database.py new file mode 100644 index 0000000000000000000000000000000000000000..2376807611bf03557e8c13cad38aecaece213f4b --- /dev/null +++ b/src/catalog/introspect/database.py @@ -0,0 +1,246 @@ +"""Database schema introspection (Postgres / MySQL / Supabase). + +Reads information_schema for tables/columns/types, samples ~100 rows per table +for `sample_values` and basic stats. Description fields are left empty — +the planner relies on names + samples + stats directly. 
+ +Reuses Phase 1 utilities (`database_client_service`, `db_credential_encryption`, +`db_pipeline_service.engine_scope`, `extractor.get_schema/profile_column/get_row_count`) +to avoid reimplementation. The cleanup PR will move those into `security/` and +`pipeline/db_pipeline/` respectively. +""" + +import asyncio +import hashlib +from datetime import UTC, datetime +from decimal import Decimal +from typing import Any + +from src.database_client.database_client_service import database_client_service +from src.db.postgres.connection import AsyncSessionLocal +from src.middlewares.logging import get_logger +from src.pipeline.db_pipeline import db_pipeline_service +from src.pipeline.db_pipeline.extractor import ( + get_row_count, + get_schema, + profile_column, +) +from src.utils.db_credential_encryption import decrypt_credentials_dict + +from ..models import Column, ColumnStats, DataType, ForeignKey, Source, Table +from ..pii_detector import PIIDetector +from .base import BaseIntrospector + +logger = get_logger("db_introspector") + +_DBCLIENT_PREFIX = "dbclient://" + + +def _stable_id(prefix: str, *parts: str) -> str: + """Deterministic short ID from joined parts. Survives renames at the + `name` field while preserving identity for cached IRs. + + Hash is non-cryptographic (identifier only). + """ + h = hashlib.sha1( + "/".join(parts).encode("utf-8"), usedforsecurity=False + ).hexdigest()[:12] + return f"{prefix}{h}" + + +def _map_sql_type(sql_type: str) -> DataType: + """Map a stringified SQLAlchemy type to a Catalog DataType. + + Matches on substring of the SQLAlchemy type repr (e.g. 'INTEGER', + 'TIMESTAMP', 'BOOLEAN'). Conservative — unknowns fall back to "string" + so the column is at least addressable. 
+ """ + s = sql_type.upper() + if "INT" in s: + return "int" + if "FLOAT" in s or "NUMERIC" in s or "DECIMAL" in s or "REAL" in s or "DOUBLE" in s: + return "decimal" + if "BOOL" in s: + return "bool" + if "TIMESTAMP" in s or "DATETIME" in s: + return "datetime" + if "DATE" in s: + return "date" + if "JSON" in s: + return "json" + return "string" + + +def _normalize(v: Any) -> Any: + """Coerce non-JSON-native scalars (Decimal, numpy, datetime) to types + that survive the jsonb round-trip when the catalog is persisted. + """ + if v is None: + return None + if isinstance(v, Decimal): + return float(v) + try: + import numpy as np + + if isinstance(v, np.generic): + return v.item() + except ImportError: + pass + if isinstance(v, datetime): + return v.isoformat() + return v + + +class DatabaseIntrospector(BaseIntrospector): + """Connect to user DB → read information_schema → sample 100 rows/table.""" + + def __init__(self) -> None: + self._pii = PIIDetector() + + async def introspect(self, location_ref: str) -> Source: + if not location_ref.startswith(_DBCLIENT_PREFIX): + raise ValueError( + f"DatabaseIntrospector expects 'dbclient://...' location_ref, " + f"got {location_ref!r}" + ) + client_id = location_ref[len(_DBCLIENT_PREFIX):] + if not client_id: + raise ValueError("location_ref is missing client_id after 'dbclient://'") + + async with AsyncSessionLocal() as session: + client = await database_client_service.get(session, client_id) + if client is None: + raise ValueError(f"DatabaseClient {client_id!r} not found") + + creds = decrypt_credentials_dict(client.credentials) + logger.info( + "introspecting db source", + client_id=client_id, + db_type=client.db_type, + name=client.name, + ) + + # SQLAlchemy inspect() + pandas read_sql are synchronous — run in a + # threadpool so the event loop stays free. 
+ tables: list[Table] = await asyncio.to_thread( + self._introspect_sync, client.db_type, creds + ) + + return Source( + source_id=client_id, + source_type="schema", + name=client.name, + location_ref=location_ref, + updated_at=datetime.now(UTC), + tables=tables, + ) + + def _introspect_sync(self, db_type: str, creds: dict) -> list[Table]: + with db_pipeline_service.engine_scope(db_type, creds) as engine: + schema = get_schema(engine) + tables: list[Table] = [] + for table_name, cols in schema.items(): + try: + row_count = get_row_count(engine, table_name) + except Exception as e: + logger.error( + "row_count failed; skipping table", + table=table_name, + error=str(e), + ) + continue + + columns: list[Column] = [] + for col in cols: + try: + profile = profile_column( + engine, + table_name, + col["name"], + col.get("is_numeric", False), + row_count, + is_temporal=col.get("is_temporal", False), + ) + except Exception as e: + logger.error( + "profile_column failed; skipping column", + table=table_name, + column=col["name"], + error=str(e), + ) + continue + columns.append(self._to_column(table_name, col, profile)) + + foreign_keys = self._extract_foreign_keys(table_name, cols) + + tables.append( + Table( + table_id=_stable_id("t_", table_name), + name=table_name, + row_count=row_count, + columns=columns, + foreign_keys=foreign_keys, + ) + ) + return tables + + @staticmethod + def _extract_foreign_keys( + table_name: str, cols: list[dict[str, Any]] + ) -> list[ForeignKey]: + """Convert extractor's `foreign_key: 'target_table.target_col'` strings + into ForeignKey objects with stable IDs (derived deterministically from + names — same scheme used to generate table_id / column_id elsewhere). 
+ """ + fks: list[ForeignKey] = [] + for col in cols: + fk_str = col.get("foreign_key") + if not fk_str: + continue + target_table, _, target_col = fk_str.partition(".") + if not target_table or not target_col: + continue + fks.append( + ForeignKey( + column_id=_stable_id("c_", table_name, col["name"]), + target_table_id=_stable_id("t_", target_table), + target_column_id=_stable_id("c_", target_table, target_col), + ) + ) + return fks + + def _to_column( + self, table_name: str, col: dict[str, Any], profile: dict[str, Any] + ) -> Column: + name = col["name"] + sample_values: list[Any] | None = [ + _normalize(v) for v in (profile.get("sample_values") or []) + ] or None + + top_raw = profile.get("top_values") or [] + top_values: list[Any] | None = [ + _normalize(v) for v, _cnt in top_raw + ] or None + + column = Column( + column_id=_stable_id("c_", table_name, name), + name=name, + data_type=_map_sql_type(str(col["type"])), + nullable=True, # nullable not surfaced by extractor; default permissive + pii_flag=False, + sample_values=sample_values, + stats=ColumnStats( + min=_normalize(profile.get("min")), + max=_normalize(profile.get("max")), + mean=_normalize(profile.get("mean")), + median=_normalize(profile.get("median")), + distinct_count=profile.get("distinct_count"), + top_values=top_values, + ), + ) + if self._pii.detect(column): + return column.model_copy(update={"pii_flag": True, "sample_values": None}) + return column + + +database_introspector = DatabaseIntrospector() diff --git a/src/catalog/introspect/tabular.py b/src/catalog/introspect/tabular.py new file mode 100644 index 0000000000000000000000000000000000000000..08b205410e01c0119bae2be35ff50c2bf467ee7c --- /dev/null +++ b/src/catalog/introspect/tabular.py @@ -0,0 +1,239 @@ +"""Tabular file schema introspection (Parquet / CSV / XLSX). + +Reads file headers + samples ~100 rows. For XLSX, each sheet becomes a Table. 
+Files are expected to live in Azure Blob (location_ref like az_blob://{user_id}/{document_id}). + +Table.name convention (executor contract) +----------------------------------------- + CSV / Parquet → Table.name = filename stem (e.g. "sales_data"). + Parquet blob was uploaded without a sheet suffix, so the + executor must call parquet_blob_name(uid, did, sheet_name=None). + XLSX → Table.name = sheet_name (e.g. "Sheet1"). + Executor calls parquet_blob_name(uid, did, table.name). +""" + +import asyncio +import hashlib +from collections.abc import Callable, Coroutine +from datetime import UTC, datetime +from io import BytesIO +from pathlib import Path +from typing import Any + +import pandas as pd + +from src.middlewares.logging import get_logger + +from ..models import Column, ColumnStats, DataType, Source, Table +from ..pii_detector import PIIDetector +from .base import BaseIntrospector + +logger = get_logger("tabular_introspector") + +_AZ_BLOB_PREFIX = "az_blob://" + + +def _stable_id(prefix: str, *parts: str) -> str: + h = hashlib.sha1( + "/".join(parts).encode("utf-8"), usedforsecurity=False + ).hexdigest()[:12] + return f"{prefix}{h}" + + +def _map_pandas_type(dtype: Any) -> DataType: + s = str(dtype).lower() + if "int" in s: + return "int" + if "float" in s or "decimal" in s: + return "decimal" + if "bool" in s: + return "bool" + if "datetime" in s: + return "datetime" + if "date" in s: + return "date" + return "string" + + +def _normalize(v: Any) -> Any: + """Coerce non-JSON-native scalars to types that survive the jsonb round-trip.""" + if v is None: + return None + try: + import numpy as np + + if isinstance(v, np.generic): + return v.item() + except ImportError: + pass + if isinstance(v, datetime): + return v.isoformat() + return v + + +class TabularIntrospector(BaseIntrospector): + """Read column names, dtypes, and sample values from Parquet/CSV/XLSX. 
+ + Heavy I/O dependencies (`fetch_doc`, `fetch_blob`) are injectable so unit + tests can pass mocks without triggering Settings or DB construction. + """ + + def __init__( + self, + fetch_doc: Callable[[str], Coroutine[Any, Any, Any]] | None = None, + fetch_blob: Callable[[str], Coroutine[Any, Any, bytes]] | None = None, + ) -> None: + self._pii = PIIDetector() + self._fetch_doc = fetch_doc or self._default_fetch_doc + self._fetch_blob = fetch_blob or self._default_fetch_blob + + @staticmethod + async def _default_fetch_doc(document_id: str) -> Any: + from sqlalchemy import select + + from src.db.postgres.connection import AsyncSessionLocal + from src.db.postgres.models import Document as DBDocument + + async with AsyncSessionLocal() as session: + result = await session.execute( + select(DBDocument).where(DBDocument.id == document_id) + ) + return result.scalar_one_or_none() + + @staticmethod + async def _default_fetch_blob(blob_name: str) -> bytes: + from src.storage.az_blob.az_blob import blob_storage + + return await blob_storage.download_file(blob_name) + + async def introspect(self, location_ref: str) -> Source: + if not location_ref.startswith(_AZ_BLOB_PREFIX): + raise ValueError( + f"TabularIntrospector expects 'az_blob://...' 
location_ref, " + f"got {location_ref!r}" + ) + rest = location_ref[len(_AZ_BLOB_PREFIX):] + user_id, _, document_id = rest.partition("/") + if not user_id or not document_id: + raise ValueError( + f"location_ref must be 'az_blob://{{user_id}}/{{document_id}}', " + f"got {location_ref!r}" + ) + + doc = await self._fetch_doc(document_id) + if doc is None: + raise ValueError(f"Document {document_id!r} not found") + + logger.info( + "introspecting tabular source", + document_id=document_id, + file_type=doc.file_type, + filename=doc.filename, + ) + + content = await self._fetch_blob(doc.blob_name) + + tables: list[Table] = await asyncio.to_thread( + self._introspect_sync, content, doc.file_type, doc.filename, document_id + ) + + return Source( + source_id=document_id, + source_type="tabular", + name=doc.filename, + location_ref=location_ref, + updated_at=datetime.now(UTC), + tables=tables, + ) + + def _introspect_sync( + self, + content: bytes, + file_type: str, + filename: str, + document_id: str, + ) -> list[Table]: + if file_type == "csv": + df = pd.read_csv(BytesIO(content)) + return [self._build_table(df, document_id, Path(filename).stem, sheet_name=None)] + if file_type == "xlsx": + sheets: dict[str, pd.DataFrame] = pd.read_excel(BytesIO(content), sheet_name=None) + return [ + self._build_table(df, document_id, sheet_name, sheet_name=sheet_name) + for sheet_name, df in sheets.items() + ] + if file_type == "parquet": + df = pd.read_parquet(BytesIO(content)) + return [self._build_table(df, document_id, Path(filename).stem, sheet_name=None)] + raise ValueError(f"Unsupported file_type {file_type!r} for tabular introspection") + + def _build_table( + self, + df: pd.DataFrame, + document_id: str, + table_name: str, + sheet_name: str | None, + ) -> Table: + id_parts = (document_id, sheet_name) if sheet_name else (document_id,) + columns = [ + self._to_column(df[col], document_id, sheet_name, col) + for col in df.columns + ] + return Table( + table_id=_stable_id("t_", 
*id_parts), + name=table_name, + row_count=len(df), + columns=columns, + foreign_keys=[], + ) + + def _to_column( + self, + series: pd.Series, + document_id: str, + sheet_name: str | None, + col_name: str, + ) -> Column: + id_parts = ( + (document_id, sheet_name, col_name) if sheet_name else (document_id, col_name) + ) + + sample_raw = series.dropna().head(3).tolist() + sample_values: list[Any] | None = [_normalize(v) for v in sample_raw] or None + + is_numeric = pd.api.types.is_numeric_dtype(series) + is_dt = pd.api.types.is_datetime64_any_dtype(series) + non_null = series.dropna() + distinct_count = int(series.nunique()) + top_values = ( + [_normalize(v) for v in non_null.unique().tolist()] + if distinct_count <= 10 + else None + ) + has_values = len(non_null) > 0 + wants_range = (is_numeric or is_dt) and has_values + wants_mean = is_numeric and has_values + stats = ColumnStats( + min=_normalize(non_null.min()) if wants_range else None, + max=_normalize(non_null.max()) if wants_range else None, + mean=float(non_null.mean()) if wants_mean else None, + median=float(non_null.median()) if wants_mean else None, + distinct_count=distinct_count, + top_values=top_values, + ) + + column = Column( + column_id=_stable_id("c_", *id_parts), + name=col_name, + data_type=_map_pandas_type(series.dtype), + nullable=bool(series.isnull().any()), + pii_flag=False, + sample_values=sample_values, + stats=stats, + ) + if self._pii.detect(column): + return column.model_copy(update={"pii_flag": True, "sample_values": None}) + return column + + +tabular_introspector = TabularIntrospector() diff --git a/src/catalog/models.py b/src/catalog/models.py new file mode 100644 index 0000000000000000000000000000000000000000..3da8202c9b0648d066355133bb39e771b955c1b3 --- /dev/null +++ b/src/catalog/models.py @@ -0,0 +1,86 @@ +"""Pydantic models for the per-user data catalog (Cs + Ct). + +See ARCHITECTURE.md §6 for the full schema definition. 
+ +Source.location_ref URI scheme +------------------------------ +A `Source` is uniquely addressable by `location_ref`; introspectors and +executors parse it to find the underlying data: + + schema sources → "dbclient://{database_client_id}" + Resolves via `database_client_service.get(...)` which + returns a `DatabaseClient` row whose Fernet-encrypted + credentials are decrypted at runtime. + + tabular sources → "az_blob://{user_id}/{document_id}" + The Source aggregates one or more sheets as Tables; + each per-sheet Parquet blob is named via + `parquet_service.parquet_blob_name(user_id, document_id, sheet_name)`, + so executors derive the per-Table blob path from + `Source.location_ref` plus `Table.name`. + + unstructured → reserved (deferred — see ARCHITECTURE.md §10 q2). +""" + +from datetime import datetime +from typing import Any, Literal + +from pydantic import BaseModel, Field + +SourceType = Literal["schema", "tabular", "unstructured"] +DataType = Literal["int", "decimal", "string", "datetime", "date", "bool", "json"] + + +class ColumnStats(BaseModel): + min: Any | None = None + max: Any | None = None + mean: float | None = None + median: float | None = None + distinct_count: int | None = None + top_values: list[Any] | None = None + + +class Column(BaseModel): + column_id: str + name: str + data_type: DataType + nullable: bool + pii_flag: bool = False + sample_values: list[Any] | None = None + stats: ColumnStats | None = None + + +class ForeignKey(BaseModel): + """A FK edge from one column in this table to a column in another table. + + All references use stable IDs derived from source/table/column names so + edges survive renames at the `name` level. The target table must belong + to the SAME `Source` — cross-source FKs are not modeled in v1. 
class PIIDetector:
    """Marks columns as pii_flag=True when name or sampled values look sensitive.

    Bias is intentional: false positives hide harmless sample values,
    false negatives leak data. When unsure, flag.
    """

    def detect(self, column: Column) -> bool:
        """Return True when the column name or any sample value looks like PII."""
        if self._name_matches(column.name):
            return True
        return bool(column.sample_values) and self._values_match(column.sample_values)

    @staticmethod
    def _name_matches(name: str) -> bool:
        """Return True when any known PII name pattern occurs in the lowercased name."""
        lowered = name.lower()
        return any(pat in lowered for pat in PII_NAME_PATTERNS)

    @staticmethod
    def _values_match(values: list) -> bool:
        """Return True when any non-None value matches the email or phone regex.

        Values are stringified and stripped first: `re.match` only matches at
        position 0, so a cell with leading whitespace (" a@b.com") would
        otherwise slip through as a false negative.
        """
        for v in values:
            if v is None:
                continue
            s = str(v).strip()
            if EMAIL_REGEX.match(s) or PHONE_REGEX.match(s):
                return True
        return False
def render_source(source: Source) -> str:
    """Render a Source as the canonical text block consumed by the planner.

    Stable identifiers (source_id / table_id / column_id) are rendered
    alongside names so the planner can copy them verbatim into the IR.

    For each column the block shows the data type, the sample values and
    every stat that is populated (``None`` stats are omitted). For columns
    with ``pii_flag`` set, samples are replaced by ``PII (suppressed)`` and
    only the distinct count is rendered; min / max / mean / median / top
    values are withheld because they are, or are derived from, raw cell
    values. This also protects catalogs stored before the introspector
    stopped keeping those stats.

    Foreign keys are resolved to names, falling back to the raw id when the
    referenced table or column is not part of this source.

    Args:
        source: The catalog source to render.

    Returns:
        The rendered block as a single newline-joined string.
    """
    lines: list[str] = [
        f"Source: {source.name} ({source.source_type})",
        f"Source ID: {source.source_id}",
        "",
        "Tables:",
    ]

    tables_by_id = {t.table_id: t for t in source.tables}
    col_names_by_id = {
        t.table_id: {c.column_id: c.name for c in t.columns} for t in source.tables
    }

    for table in source.tables:
        rc = table.row_count
        rc_str = f" ({rc:,} rows)" if rc is not None else ""
        lines.append("")
        lines.append(f" Table: {table.name}{rc_str} — id={table.table_id}")
        lines.append(" Columns:")
        for col in table.columns:
            samples = "PII (suppressed)" if col.pii_flag else (col.sample_values or [])
            # Raw-value stats are only safe to show for non-PII columns.
            expose_values = not col.pii_flag
            stats_parts: list[str] = []
            if col.stats:
                if expose_values and col.stats.min is not None:
                    stats_parts.append(f"min={col.stats.min}")
                if expose_values and col.stats.max is not None:
                    stats_parts.append(f"max={col.stats.max}")
                if expose_values and col.stats.mean is not None:
                    stats_parts.append(f"mean={col.stats.mean:.4g}")
                if expose_values and col.stats.median is not None:
                    stats_parts.append(f"median={col.stats.median:.4g}")
                if col.stats.distinct_count is not None:
                    stats_parts.append(f"distinct={col.stats.distinct_count}")
                if expose_values and col.stats.top_values:
                    stats_parts.append(f"top={col.stats.top_values}")
            stats_str = (", " + ", ".join(stats_parts)) if stats_parts else ""
            lines.append(
                f" - {col.name} [{col.data_type}]: samples={samples}{stats_str} — id={col.column_id}"
            )
        if table.foreign_keys:
            lines.append(" Foreign keys:")
            cols_in_this_table = {c.column_id: c.name for c in table.columns}
            for fk in table.foreign_keys:
                src_col_name = cols_in_this_table.get(fk.column_id, fk.column_id)
                tgt_table = tables_by_id.get(fk.target_table_id)
                tgt_table_name = tgt_table.name if tgt_table else fk.target_table_id
                tgt_col_name = col_names_by_id.get(fk.target_table_id, {}).get(
                    fk.target_column_id, fk.target_column_id
                )
                lines.append(f" - {src_col_name} -> {tgt_table_name}.{tgt_col_name}")
    return "\n".join(lines)
class CatalogStore:
    """Read/write per-user catalogs, one row per ``user_id``.

    Rows are stored through the ``CatalogRow`` ORM model (table
    ``data_catalog``); the ``data`` column holds the full Pydantic ``Catalog``
    dumped with ``model_dump(mode="json")``.

    Each method opens its own AsyncSession. Callers needing transactional
    coordination across multiple stores can be refactored to accept an
    explicit AsyncSession in a later PR.
    """

    async def get(self, user_id: str) -> Catalog | None:
        """Return the stored catalog for ``user_id``, or ``None`` if no row exists."""
        async with AsyncSessionLocal() as session:
            result = await session.execute(
                select(CatalogRow.data).where(CatalogRow.user_id == user_id)
            )
            row = result.scalar_one_or_none()
            if row is None:
                return None
            return Catalog.model_validate(row)

    async def upsert(self, catalog: Catalog) -> None:
        """Insert the catalog row, or overwrite the existing row for the same user.

        On conflict the payload, ``schema_version`` and ``generated_at`` are
        replaced. ``updated_at`` only moves forward when the jsonb payload
        actually differs from what is stored.
        """
        payload = catalog.model_dump(mode="json")
        async with AsyncSessionLocal() as session:
            stmt = insert(CatalogRow).values(
                user_id=catalog.user_id,
                data=payload,
                schema_version=catalog.schema_version,
                generated_at=catalog.generated_at,
                updated_at=func.now(),
            )
            stmt = stmt.on_conflict_do_update(
                index_elements=[CatalogRow.user_id],
                set_={
                    "data": stmt.excluded.data,
                    "schema_version": stmt.excluded.schema_version,
                    # Keep the column in step with the `generated_at` inside
                    # `data`; without this it kept the first insert's value.
                    "generated_at": stmt.excluded.generated_at,
                    "updated_at": case(
                        (stmt.excluded.data != CatalogRow.data, func.now()),
                        else_=CatalogRow.updated_at,
                    ),
                },
            )
            await session.execute(stmt)
            await session.commit()
        logger.info(
            "catalog upserted",
            user_id=catalog.user_id,
            sources=len(catalog.sources),
        )

    async def remove_source(self, user_id: str, source_id: str) -> None:
        """Drop one source from the user's catalog; no-op if catalog or source is absent.

        NOTE(review): this is a read-modify-write across two separate
        sessions (``get`` then ``upsert``), so a concurrent writer can be
        overwritten. Confirm callers serialize catalog updates per user.
        """
        existing = await self.get(user_id)
        if existing is None:
            logger.info("remove_source: no catalog found", user_id=user_id, source_id=source_id)
            return
        filtered = [s for s in existing.sources if s.source_id != source_id]
        if len(filtered) == len(existing.sources):
            logger.info("remove_source: source not in catalog", user_id=user_id, source_id=source_id)
            return
        await self.upsert(existing.model_copy(update={"sources": filtered}))
        logger.info("remove_source: source removed", user_id=user_id, source_id=source_id)

    async def delete(self, user_id: str) -> None:
        """Delete the user's catalog row; deleting a missing row is not an error."""
        async with AsyncSessionLocal() as session:
            await session.execute(delete(CatalogRow).where(CatalogRow.user_id == user_id))
            await session.commit()
        logger.info("catalog deleted", user_id=user_id)
def validate(self, catalog: Catalog) -> None:
    """Check the cross-field invariants that Pydantic cannot express.

    Rules enforced, in order, per source:
      - source_ids are unique within the catalog
      - table_ids are unique within a source
      - column_ids are unique within a table
      - every foreign key's own column exists in its table, and its target
        table and target column exist in the SAME source

    Args:
        catalog: The catalog to validate.

    Raises:
        CatalogValidationError: On the first violated rule.
    """
    seen_sources: set[str] = set()
    for source in catalog.sources:
        if source.source_id in seen_sources:
            raise CatalogValidationError(
                f"duplicate source_id {source.source_id!r} in catalog "
                f"for user_id={catalog.user_id!r}"
            )
        seen_sources.add(source.source_id)

        seen_tables: set[str] = set()
        for table in source.tables:
            if table.table_id in seen_tables:
                raise CatalogValidationError(
                    f"duplicate table_id {table.table_id!r} in source "
                    f"{source.source_id!r}"
                )
            seen_tables.add(table.table_id)

            seen_columns: set[str] = set()
            for column in table.columns:
                if column.column_id in seen_columns:
                    raise CatalogValidationError(
                        f"duplicate column_id {column.column_id!r} in table "
                        f"{table.table_id!r} (source {source.source_id!r})"
                    )
                seen_columns.add(column.column_id)

        # Foreign keys may point at tables listed later in the source, so they
        # are checked only after every table of the source has been indexed.
        columns_by_table = {
            t.table_id: {c.column_id for c in t.columns} for t in source.tables
        }
        for table in source.tables:
            for fk in table.foreign_keys:
                if fk.column_id not in columns_by_table[table.table_id]:
                    raise CatalogValidationError(
                        f"foreign key column_id {fk.column_id!r} not found in table "
                        f"{table.table_id!r} (source {source.source_id!r})"
                    )
                target_columns = columns_by_table.get(fk.target_table_id)
                if target_columns is None:
                    raise CatalogValidationError(
                        f"foreign key on table {table.table_id!r} references unknown "
                        f"table_id {fk.target_table_id!r} (source {source.source_id!r})"
                    )
                if fk.target_column_id not in target_columns:
                    raise CatalogValidationError(
                        f"foreign key on table {table.table_id!r} references unknown "
                        f"column_id {fk.target_column_id!r} in table "
                        f"{fk.target_table_id!r} (source {source.source_id!r})"
                    )
Don't make up information - admit when you don't know something diff --git a/src/config/agents/system_prompt.md b/src/config/agents/system_prompt.md deleted file mode 100644 index 6c1d11fc7f630465f803bff1ecf0eba836da879b..0000000000000000000000000000000000000000 --- a/src/config/agents/system_prompt.md +++ /dev/null @@ -1,26 +0,0 @@ -You are a helpful AI assistant with access to user's uploaded documents. Your role is to: - -1. Answer questions based on provided document context -2. If no relevant information is found in documents, acknowledge this honestly -3. Be concise and direct in your responses -4. If user's question is unclear, ask for clarification - -When document context is provided: -- Use information from documents to answer accurately -- Reference source document name when appropriate -- If multiple documents contain relevant info, synthesize information - -When no document context is provided: -- Provide general assistance -- Let the user know if you need more context to help better - -When the answer need markdown formating: -- Use valid and tidy formatting -- Avoid over-formating and emoji - -Always be professional, helpful, and accurate. - -You have access to the conversation history provided in the messages above. 
Use it to: -- Maintain context across multiple turns (resolve references like "it", "that", "them" using earlier messages) -- Avoid repeating information already established in the conversation -- Answer follow-up questions coherently without asking the user to restate prior context diff --git a/src/pipeline/document_pipeline/__init__.py b/src/config/prompts/__init__.py similarity index 100% rename from src/pipeline/document_pipeline/__init__.py rename to src/config/prompts/__init__.py diff --git a/src/config/prompts/chatbot_system.md b/src/config/prompts/chatbot_system.md new file mode 100644 index 0000000000000000000000000000000000000000..6f9eab52400a515f2e72b8ac59fd40dced209c85 --- /dev/null +++ b/src/config/prompts/chatbot_system.md @@ -0,0 +1,31 @@ +You are a friendly, precise data assistant for a user who has registered databases and uploaded files. Your job is to answer the user's questions using **only** the data context provided to you in this turn. + +## Rules + +1. **Ground every claim in the provided context.** If the context doesn't contain the answer, say so plainly — do not guess. Never invent numbers, dates, or facts that aren't in the result rows or document chunks. +2. **Be concise and direct in your responses.** +3. **Use the user's terms when possible.** Mirror the column / table names they care about, but feel free to humanize ("revenue" instead of "total_cents", "last month" instead of "2026-04 timestamps"). +4. **Stream coherently.** You are streaming token-by-token; don't backtrack or self-correct mid-answer. Plan the structure mentally before the first token. +5. **Markdown is OK** for emphasis and small tables, but avoid heavy formatting (code fences, headers) unless the question genuinely calls for it. + +## Context shapes you'll see + +- **Query result** — emitted when the user asked a data question that ran successfully. Contains `rows` (a list of dicts), `row_count`, the source/table that was queried, and any error string. 
If `error` is set, explain the failure plainly and suggest a next step. +- **Document chunks** — emitted when the user asked about uploaded prose. Each chunk has source filename and (for PDFs) a page label. +- **No context** — emitted for greetings, farewells, or meta questions. Just respond conversationally. + +## When the query failed + +If `query_result.error` is non-empty: +- Acknowledge the failure briefly. +- Surface the user-actionable part of the error (e.g., "I couldn't find a matching column" → suggest they rephrase). +- Do not paste raw stack traces or internal IDs. + +## What you do NOT do + +- Speculate beyond the data. +- Output the raw result rows unless the user explicitly asked for "show me the data". +- Repeat the user's question back at them. +- Apologize repeatedly. + +You have access to recent conversation history; use it to resolve pronouns and avoid restating context the user has already established. diff --git a/src/config/prompts/guardrails.md b/src/config/prompts/guardrails.md new file mode 100644 index 0000000000000000000000000000000000000000..51ac200884e3ba2e35d4d1e12148dafdbce8b74f --- /dev/null +++ b/src/config/prompts/guardrails.md @@ -0,0 +1,11 @@ +## Guardrails + +These rules apply to every response, regardless of the system prompt above. They take precedence when in conflict with anything else. + +1. **Stay within the user's data scope.** Refuse questions that ask you to fabricate data, predict the future from data the user hasn't shared, or answer questions unrelated to the user's registered sources. Reply briefly: "That's outside what I can answer from your data — I can only work with the sources you've registered." +2. **Do not reveal or extract PII.** If the data context contains a PII column (it will be flagged), do not list raw values — describe distributions or counts only. If the user explicitly asks for raw PII, refuse: "I can't surface that column's contents directly." +3. 
**No code execution, no shell commands, no file writes.** If the user asks you to run code, modify their data, or perform a write operation, refuse: "I can only read and summarize — I don't execute code or change your data." +4. **No credentials, no secrets.** Never repeat connection strings, passwords, API keys, or service-account JSON, even if they somehow appear in context. +5. **No medical / legal / financial advice.** If the user asks "should I…" questions about a regulated domain, defer: "I can show you what the data says, but the decision is yours — I won't give advice in this domain." +6. **Acknowledge limits when relevant.** If a result was truncated, say so. If you're not sure, say so. Avoid the appearance of false certainty. +7. **Be honest about errors.** If the query failed, the document was missing, or the catalog had nothing relevant, say it plainly. Do not paper over with vague answers. diff --git a/src/config/prompts/intent_router.md b/src/config/prompts/intent_router.md new file mode 100644 index 0000000000000000000000000000000000000000..3635782f9bb2497421e22be9fa9994fcb4996b28 --- /dev/null +++ b/src/config/prompts/intent_router.md @@ -0,0 +1,66 @@ +You are the intent router for an AI data assistant. Given a user's latest message (and optionally recent conversation history), decide which downstream path should handle it. + +## Output + +Return three fields: + +- **`needs_search`** — `true` if we must look at the user's data to answer; `false` for greetings, farewells, off-topic chitchat, or meta questions about the assistant itself. +- **`source_hint`** — one of: + - `chat` — no data lookup needed (greetings, farewells, generic small talk). + - `unstructured` — the user is asking about the **content** of an uploaded document (PDF / DOCX / TXT). + - `structured` — the user is asking a **data question** answerable from a database or a tabular file (CSV / XLSX / Parquet). 
This includes counts, sums, top-N, filters, comparisons, trends, joins across registered structured sources. +- **`rewritten_query`** — a **standalone** version of the user's question that incorporates necessary context from history. If the original message is already standalone, return it unchanged. If `needs_search` is `false`, leave this empty/null. + +## Routing rules + +1. If the message is a pure greeting / farewell / thanks / "how are you" / "what can you do" → `chat` + `needs_search=false`. +2. If the message references content that lives in a registered DB or uploaded tabular file (sales numbers, customer counts, order trends, sheet rows, table columns) → `structured` + `needs_search=true`. +3. If the message asks about prose content (a section of a PDF, what a memo says, a quote from a document) → `unstructured` + `needs_search=true`. +4. If the message is ambiguous between structured and unstructured, prefer `structured` — the planner can fall back if the catalog has nothing relevant. +5. Cross-source comparison ("compare DB sales to the customers.csv file") → `structured`. The planner sees both source types in one prompt and can correlate. + +## Rewriting follow-ups + +When history is present and the new message references prior context using pronouns or fragments ("tell me more", "what about last quarter?", "and by region?"), expand the rewritten_query into a fully standalone question. Example: + + History: "What was our top product last month?" → "Pro Plan Annual at $487k" + Message: "How does that compare to Q1?" + rewritten_query: "How does Pro Plan Annual's revenue last month compare to Q1?" + +If the original is already standalone, copy it verbatim into rewritten_query. + +## Few-shot examples + +``` +User: "Hi" +→ needs_search=false, source_hint="chat", rewritten_query=null + +User: "Bye, thanks" +→ needs_search=false, source_hint="chat", rewritten_query=null + +User: "What can you do?" 
+→ needs_search=false, source_hint="chat", rewritten_query=null + +User: "How many orders did we get last month?" +→ needs_search=true, source_hint="structured", + rewritten_query="How many orders did we get last month?" + +User: "What does the Q1 board memo say about churn?" +→ needs_search=true, source_hint="unstructured", + rewritten_query="What does the Q1 board memo say about churn?" + +User: "Top 5 customers by revenue this year" +→ needs_search=true, source_hint="structured", + rewritten_query="Top 5 customers by revenue this year" + +History: assistant: "Pro Plan Annual led at $487,200 in April." +User: "And in March?" +→ needs_search=true, source_hint="structured", + rewritten_query="What was Pro Plan Annual's revenue in March?" +``` + +## Constraints + +- Do not invent data. If you don't know whether a topic exists in the user's data, route to `structured` and let the planner decide. +- Do not refuse — refusal happens later in guardrails. Just classify. +- One JSON object as output; no prose, no markdown. diff --git a/src/config/prompts/query_planner.md b/src/config/prompts/query_planner.md new file mode 100644 index 0000000000000000000000000000000000000000..e2ed549dc1b90946b979042310161aed04e0df68 --- /dev/null +++ b/src/config/prompts/query_planner.md @@ -0,0 +1,168 @@ +You are the **query planner** for an AI data assistant. Given a user's question and the user's full data catalog, produce a structured **JSON IR** that captures the query intent. + +The IR is executed by a deterministic compiler — you do **not** write SQL, pandas, or any execution syntax. You produce intent only. + +## What you receive + +1. The user's question. +2. The user's catalog: every registered source (databases and tabular files), every table, every column, with descriptions, sample values, stats, and foreign keys. Each item carries a stable identifier (`source_id`, `table_id`, `column_id`) — copy these verbatim into the IR. 
+ +## Output schema + +A `QueryIR` object: + +```jsonc +{ + "ir_version": "1.0", + "source_id": "...", // pick from catalog + "table_id": "...", // pick from chosen source + "select": [ + {"kind": "column", "column_id": "...", "alias": "..."}, + {"kind": "agg", "fn": "count|count_distinct|sum|avg|min|max", + "column_id": "...?", "alias": "..."} + ], + "filters": [ + {"column_id": "...", + "op": "= | != | < | <= | > | >= | in | not_in | is_null | is_not_null | like | between", + "value": ..., + "value_type": "int|decimal|string|datetime|date|bool"} + ], + "group_by": ["column_id", ...], + "order_by": [{"column_id": "...", "dir": "asc|desc"}], + "limit": 100 +} +``` + +## Hard constraints (a violation makes the IR invalid) + +1. `source_id`, `table_id`, `column_id` must come **verbatim** from the catalog. Never invent IDs or copy table/column **names** in their place. +2. **Single-table only in v1.** Pick the table whose columns best answer the question. If the question genuinely needs a join, pick the table that yields the most useful answer alone and the user can refine. +3. Use only listed operators / aggregates. No window functions, no `CASE WHEN`, no subqueries — those are not part of v1. +4. `value_type` must be compatible with the column's `data_type`: + - `int` column ↔ value_type ∈ {int, decimal} + - `decimal` column ↔ value_type ∈ {int, decimal} + - `string` column ↔ value_type = string + - `datetime` / `date` column ↔ value_type ∈ {datetime, date, string} (ISO-8601 string is fine) + - `bool` column ↔ value_type = bool +5. `limit` between 1 and 10000 inclusive. +6. For `count` of all rows, omit `column_id` from the agg item. For any other aggregate, `column_id` is required. +7. `order_by.column_id` may reference either a real column_id or an alias declared in `select`. +8. For `is_null` / `is_not_null`, `value` and `value_type` are still emitted but ignored — pick reasonable defaults. +9. For `in` / `not_in`, `value` is a JSON list. 
For `between`, `value` is a JSON list of exactly two elements (low, high). + +## Style guidance + +- Default `limit` to 100 unless the user asked for "top N" (then use N) or said "all" (then leave out `limit`, server will cap at 10000). +- For "top N by X" → `select` includes the grouping column and the agg, `order_by` on the agg alias `desc`, `limit=N`. +- For "how many rows / events / transactions ..." → `fn="count"` (COUNT *), omit `column_id`. +- For "how many unique / distinct X ..." or "how many different X ..." → `fn="count_distinct"` with `column_id` of X's identifier column. +- When ambiguous (e.g. "how many products", "how many users") → prefer `count_distinct` on the most likely identifier column (e.g. product_id, user_id). +- Prefer aliases on aggregates (`alias="total"`, `alias="n"`, etc.) so the answer-formatter has a clean column name. +- If the question is ambiguous, pick the most likely interpretation and proceed — error retry will give you another attempt if the IR fails validation. + +## Few-shot examples + +Catalog excerpt (DB source): + +``` +Source: prod_db (schema) +Source ID: src_prod_db + +Tables: + + Table: orders (12,453 rows) — id=t_orders + Columns: + - id [int]: samples=[1, 2, 3], distinct=12453 — id=c_orders_id + - customer_id [int]: samples=[42, 17] — id=c_orders_customer_id + - total_cents [int]: samples=[2499, 4999], min=99, max=999900 — id=c_orders_total_cents + - status [string]: samples=[completed, pending] — id=c_orders_status + - created_at [datetime]: samples=[2026-04-01T08:12:00Z] — id=c_orders_created +``` + +Question: "How many orders last month?" 
+Output: +```json +{ + "ir_version": "1.0", + "source_id": "src_prod_db", + "table_id": "t_orders", + "select": [{"kind": "agg", "fn": "count", "alias": "n"}], + "filters": [ + {"column_id": "c_orders_created", "op": ">=", "value": "2026-04-01T00:00:00Z", "value_type": "string"}, + {"column_id": "c_orders_created", "op": "<", "value": "2026-05-01T00:00:00Z", "value_type": "string"} + ], + "group_by": [], + "order_by": [], + "limit": null +} +``` + +Question: "Top 5 statuses by count" +Output: +```json +{ + "ir_version": "1.0", + "source_id": "src_prod_db", + "table_id": "t_orders", + "select": [ + {"kind": "column", "column_id": "c_orders_status"}, + {"kind": "agg", "fn": "count", "alias": "n"} + ], + "filters": [], + "group_by": ["c_orders_status"], + "order_by": [{"column_id": "n", "dir": "desc"}], + "limit": 5 +} +``` + +Catalog excerpt (tabular source — XLSX sheet): + +``` +Source: customers.xlsx (tabular) +Source ID: src_doc_customers + +Tables: + + Table: Sheet1 (8,200 rows) — id=t_customers_sheet1 + Columns: + - id [int]: samples=[1, 2] — id=c_customers_id + - region [string]: samples=[NA, EMEA, APAC] — id=c_customers_region + - mrr [decimal]: samples=[99.0, 199.0], min=0.0, max=999.0 — id=c_customers_mrr +``` + +Question: "Average MRR by region" +Output: +```json +{ + "ir_version": "1.0", + "source_id": "src_doc_customers", + "table_id": "t_customers_sheet1", + "select": [ + {"kind": "column", "column_id": "c_customers_region"}, + {"kind": "agg", "fn": "avg", "column_id": "c_customers_mrr", "alias": "avg_mrr"} + ], + "filters": [], + "group_by": ["c_customers_region"], + "order_by": [{"column_id": "avg_mrr", "dir": "desc"}], + "limit": 100 +} +``` + +Question: "How many unique customers?" 
+Output: +```json +{ + "ir_version": "1.0", + "source_id": "src_doc_customers", + "table_id": "t_customers_sheet1", + "select": [{"kind": "agg", "fn": "count_distinct", "column_id": "c_customers_id", "alias": "n"}], + "filters": [], + "group_by": [], + "order_by": [], + "limit": null +} +``` + +## Retry behavior + +If the previous attempt's IR failed validation, the error message will be appended below. Read it carefully and emit a corrected IR — do not repeat the same mistake. diff --git a/src/db/postgres/init_db.py b/src/db/postgres/init_db.py index 04b9832d5b7ffe55a832b1d4ae9fe42f933869ff..1d009d06113d2e570312d605f3ce7b68f3d0bf50 100644 --- a/src/db/postgres/init_db.py +++ b/src/db/postgres/init_db.py @@ -3,6 +3,7 @@ from sqlalchemy import text from src.db.postgres.connection import engine, Base from src.db.postgres.models import ( + Catalog, ChatMessage, DatabaseClient, Document, @@ -21,7 +22,7 @@ async def init_db(): await conn.execute(text("SELECT pg_advisory_xact_lock(1573678846307946496)")) await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) - # Create application tables + # Create application tables (includes `data_catalog`) await conn.run_sync(Base.metadata.create_all) # Schema migrations (idempotent — safe to run on every startup) diff --git a/src/db/postgres/models.py b/src/db/postgres/models.py index 62c542bbffc22be43653ea22a096746ac982e1a4..8224aa8ea59259ba33bb96004bd1c3e71f6db37d 100644 --- a/src/db/postgres/models.py +++ b/src/db/postgres/models.py @@ -96,4 +96,23 @@ class DatabaseClient(Base): status = Column(String, nullable=False, default="active") # active | inactive created_at = Column(DateTime(timezone=True), server_default=func.now()) updated_at = Column(DateTime(timezone=True), onupdate=func.now()) + + +class Catalog(Base): + """Per-user data catalog stored as a single jsonb row. + + `data` holds the full Pydantic Catalog (src/catalog/models.py:Catalog) + serialized via `model_dump(mode="json")`. 
Read path uses + `Catalog.model_validate(...)` to rehydrate. + + Dedicated table — kept separate from `langchain_pg_embedding` so unstructured + embeddings and structured-catalog metadata never share storage. + """ + __tablename__ = "data_catalog" + + user_id = Column(String, primary_key=True) + data = Column(JSONB, nullable=False) + schema_version = Column(String, nullable=False, default="1.0") + generated_at = Column(DateTime(timezone=True), server_default=func.now()) + updated_at = Column(DateTime(timezone=True), onupdate=func.now()) \ No newline at end of file diff --git a/src/knowledge/processing_service.py b/src/knowledge/processing_service.py index 9ae73b658a48c55930a8c6abb8f038c8e9771dcf..a17e1cc5344e2ae0080f6bdfae9af10ed4586535 100644 --- a/src/knowledge/processing_service.py +++ b/src/knowledge/processing_service.py @@ -7,12 +7,10 @@ from src.storage.az_blob.az_blob import blob_storage from src.db.postgres.models import Document as DBDocument from sqlalchemy.ext.asyncio import AsyncSession from src.middlewares.logging import get_logger -from src.knowledge.parquet_service import upload_parquet from typing import List from datetime import datetime, timezone, timedelta import sys import docx -import pandas as pd import pytesseract from pdf2image import convert_from_bytes from io import BytesIO @@ -44,10 +42,6 @@ class KnowledgeProcessingService: if db_doc.file_type == "pdf": documents = await self._build_pdf_documents(content, db_doc) - elif db_doc.file_type == "csv": - documents = await self._build_csv_documents(content, db_doc) - elif db_doc.file_type == "xlsx": - documents = await self._build_excel_documents(content, db_doc) else: text = self._extract_text(content, db_doc.file_type) if not text.strip(): @@ -121,106 +115,6 @@ class KnowledgeProcessingService: return documents - def _profile_dataframe( - self, df: pd.DataFrame, source_name: str, db_doc: DBDocument - ) -> List[LangChainDocument]: - """Profile each column of a dataframe → one chunk per 
column.""" - documents = [] - row_count = len(df) - - for col_name in df.columns: - col = df[col_name] - is_numeric = pd.api.types.is_numeric_dtype(col) - null_count = int(col.isnull().sum()) - distinct_count = int(col.nunique()) - distinct_ratio = distinct_count / row_count if row_count > 0 else 0 - - text = f"Source: {source_name} ({row_count} rows)\n" - text += f"Column: {col_name} ({col.dtype})\n" - text += f"Null count: {null_count}\n" - text += f"Distinct count: {distinct_count} ({distinct_ratio:.1%})\n" - - if is_numeric: - text += f"Min: {col.min()}, Max: {col.max()}\n" - text += f"Mean: {col.mean():.4f}, Median: {col.median():.4f}\n" - - if 0 < distinct_ratio <= 0.05: - top_values = col.value_counts().head(10) - top_str = ", ".join(f"{v} ({c})" for v, c in top_values.items()) - text += f"Top values: {top_str}\n" - - text += f"Sample values: {col.dropna().head(5).tolist()}" - - documents.append(LangChainDocument( - page_content=text, - metadata={ - "user_id": db_doc.user_id, - "source_type": "document", - "chunk_level": "column", - "updated_at": datetime.now(_JAKARTA_TZ).isoformat(), - "data": { - "document_id": db_doc.id, - "filename": db_doc.filename, - "file_type": db_doc.file_type, - "source": source_name, - "column_name": col_name, - "column_type": str(col.dtype), - } - } - )) - return documents - - def _to_sheet_document( - self, df: pd.DataFrame, db_doc: DBDocument, sheet_name: str | None, source_name: str - ) -> LangChainDocument: - col_summary = ", ".join(f"{c} ({df[c].dtype})" for c in df.columns) - text = ( - f"Source: {source_name} ({len(df)} rows)\n" - f"Columns ({len(df.columns)}): {col_summary}" - ) - return LangChainDocument( - page_content=text, - metadata={ - "user_id": db_doc.user_id, - "source_type": "document", - "chunk_level": "sheet", - "updated_at": datetime.now(_JAKARTA_TZ).isoformat(), - "data": { - "document_id": db_doc.id, - "filename": db_doc.filename, - "file_type": db_doc.file_type, - "sheet_name": sheet_name, - 
"column_names": list(df.columns), - "row_count": len(df), - }, - }, - ) - - async def _build_csv_documents(self, content: bytes, db_doc: DBDocument) -> List[LangChainDocument]: - """Profile each column of a CSV file and upload Parquet to Azure Blob.""" - df = pd.read_csv(BytesIO(content)) - await upload_parquet(df, db_doc.user_id, db_doc.id) - logger.info(f"Uploaded Parquet for CSV {db_doc.id}") - docs = self._profile_dataframe(df, db_doc.filename, db_doc) - docs.append(self._to_sheet_document(df, db_doc, sheet_name=None, source_name=db_doc.filename)) - return docs - - async def _build_excel_documents(self, content: bytes, db_doc: DBDocument) -> List[LangChainDocument]: - """Profile each column of every sheet in an Excel file and upload one Parquet per sheet.""" - sheets = pd.read_excel(BytesIO(content), sheet_name=None) - documents = [] - for sheet_name, df in sheets.items(): - source_name = f"{db_doc.filename} / sheet: {sheet_name}" - docs = self._profile_dataframe(df, source_name, db_doc) - for doc in docs: - doc.metadata["data"]["sheet_name"] = sheet_name - doc.metadata["chunk_level"] = "column" - documents.extend(docs) - documents.append(self._to_sheet_document(df, db_doc, sheet_name, source_name)) - await upload_parquet(df, db_doc.user_id, db_doc.id, sheet_name) - logger.info(f"Uploaded Parquet for sheet '{sheet_name}' of {db_doc.id}") - return documents - def _extract_text(self, content: bytes, file_type: str) -> str: """Extract text from DOCX or TXT content.""" if file_type == "docx": diff --git a/src/middlewares/logging.py b/src/middlewares/logging.py index 5403c82103d285a19ad9c583ae19614d4e973b7d..c7bb6a48418d2b94ba9f1822e3b9505061ac44e3 100644 --- a/src/middlewares/logging.py +++ b/src/middlewares/logging.py @@ -1,5 +1,6 @@ """Structured logging middleware with structlog.""" +import logging import structlog from functools import wraps from typing import Callable, Any @@ -8,6 +9,8 @@ import time def configure_logging(): """Configure structured logging.""" 
class CatalogIndexEntry(BaseModel):
    """One row in the per-user catalog index — used by the refresher to decide
    which sources to rebuild and by the UI to list registered sources.

    Every field is required (`Field(...)` declares no default), so an entry
    cannot be constructed with partial information.
    """

    # Identity and classification of the source.
    source_id: str = Field(..., description="Stable internal source identifier.")
    source_type: str = Field(..., description="schema | tabular | unstructured.")
    name: str = Field(..., description="Display name (DB name or filename).")
    # Where the source lives; the scheme distinguishes DB clients from blobs.
    location_ref: str = Field(..., description="URI: dbclient://… or az_blob://…")
    # Summary data shown in listings / used to decide on refreshes.
    table_count: int = Field(..., description="Number of tables/sheets in this source.")
    updated_at: datetime = Field(..., description="Last time this source was (re)introspected.")
| None = None - company_size: str | None = None - function: str | None = None - site: str | None = None - role: str | None = None diff --git a/src/pipeline/db_pipeline/extractor.py b/src/pipeline/db_pipeline/extractor.py index c73c77251f8db05ea831520a81aa353e8edaffa6..4f8140105c78a95ebb2deeb5fdebb75527c60113 100644 --- a/src/pipeline/db_pipeline/extractor.py +++ b/src/pipeline/db_pipeline/extractor.py @@ -9,7 +9,7 @@ not user input. from typing import Optional import pandas as pd -from sqlalchemy import Float, Integer, Numeric, inspect +from sqlalchemy import Date, DateTime, Float, Integer, Numeric, inspect from sqlalchemy.engine import Engine from src.middlewares.logging import get_logger @@ -17,10 +17,16 @@ from src.middlewares.logging import get_logger logger = get_logger("db_extractor") TOP_VALUES_THRESHOLD = 0.05 # show top values if distinct_ratio <= 5% +SAMPLE_LIMIT = 3 # sample N rows per column (down from 5 — token cost) + +# Dialects with a single-statement CTE that survives `pd.read_sql`. On these we +# fold the stats and sample queries into one round-trip per column. MySQL <8 and +# old SQLite are excluded out of caution. +_CTE_DIALECTS = frozenset({"postgresql", "mssql", "snowflake", "bigquery"}) # Dialects where PERCENTILE_CONT(...) WITHIN GROUP is supported as an aggregate. # MySQL has no percentile aggregate; BigQuery has PERCENTILE_CONT only as an -# analytic (window) function — both drop median and keep min/max/mean. +# analytic (window) function — both drop median and keep mean. 
_MEDIAN_DIALECTS = frozenset({"postgresql", "mssql", "snowflake"}) @@ -53,7 +59,7 @@ def _qi(engine: Engine, name: str) -> str: def get_schema( engine: Engine, exclude_tables: Optional[frozenset[str]] = None ) -> dict[str, list[dict]]: - """Returns {table_name: [{name, type, is_numeric, is_primary_key, foreign_key}, ...]}.""" + """Returns {table_name: [{name, type, is_numeric, is_temporal, is_primary_key, foreign_key}, ...]}.""" exclude = exclude_tables or frozenset() inspector = inspect(engine) schema = {} @@ -75,6 +81,7 @@ def get_schema( "name": c["name"], "type": str(c["type"]), "is_numeric": isinstance(c["type"], (Integer, Numeric, Float)), + "is_temporal": isinstance(c["type"], (Date, DateTime)), "is_primary_key": c["name"] in pk_cols, "foreign_key": fk_map.get(c["name"]), } @@ -96,8 +103,14 @@ def profile_column( col_name: str, is_numeric: bool, row_count: int, + is_temporal: bool = False, ) -> dict: - """Returns null_count, distinct_count, min/max, top values, and sample values.""" + """Returns null_count, distinct_count, min/max (numeric+temporal), mean/median (numeric), and sample values. + + Numeric columns compute mean and (where the dialect supports it) median. + Datetime/date get min/max only (no useful mean/median over timestamps). + Strings/bools skip range stats entirely. + """ if row_count == 0: return { "null_count": 0, @@ -108,39 +121,69 @@ def profile_column( qt = _qi(engine, table_name) qc = _qi(engine, col_name) + wants_range = is_numeric or is_temporal + wants_mean = is_numeric + wants_median = is_numeric and _supports_median(engine) + + profile: dict = {} - # Combined stats query: null_count, distinct_count, and min/max (if numeric). - # One round-trip instead of two. - select_cols = [ + # Build the stats SELECT list incrementally — same column set used in both + # the CTE and fallback branches. 
+ stat_cols = [ f"COUNT(*) - COUNT({qc}) AS nulls", f"COUNT(DISTINCT {qc}) AS distincts", ] - if is_numeric: - select_cols.append(f"MIN({qc}) AS min_val") - select_cols.append(f"MAX({qc}) AS max_val") - select_cols.append(f"AVG({qc}) AS mean_val") - if _supports_median(engine): - select_cols.append( - f"PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY {qc}) AS median_val" - ) - stats = pd.read_sql(f"SELECT {', '.join(select_cols)} FROM {qt}", engine) - - null_count = int(stats.iloc[0]["nulls"]) - distinct_count = int(stats.iloc[0]["distincts"]) - distinct_ratio = distinct_count / row_count if row_count > 0 else 0 + if wants_range: + stat_cols += [f"MIN({qc}) AS min_val", f"MAX({qc}) AS max_val"] + if wants_mean: + stat_cols.append(f"AVG({qc}) AS mean_val") + if wants_median: + stat_cols.append( + f"PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY {qc}) AS median_val" + ) - profile = { - "null_count": null_count, - "distinct_count": distinct_count, - "distinct_ratio": round(distinct_ratio, 4), - } - - if is_numeric: - profile["min"] = stats.iloc[0]["min_val"] - profile["max"] = stats.iloc[0]["max_val"] - profile["mean"] = stats.iloc[0]["mean_val"] - if _supports_median(engine): + if engine.dialect.name in _CTE_DIALECTS: + # Single round-trip: stats + sample together via CTE. 
+ stats_select = ", ".join(stat_cols) + passthrough = ", ".join( + f"s.{c.split(' AS ')[-1]}" for c in stat_cols + ) + sql = ( + f"WITH stats AS (SELECT {stats_select} FROM {qt}), " + f"sample AS ({_head_query(engine, qc + ' AS sample_val', qt, SAMPLE_LIMIT)}) " + f"SELECT {passthrough}, sample.sample_val FROM stats s CROSS JOIN sample" + ) + rows = pd.read_sql(sql, engine) + null_count = int(rows.iloc[0]["nulls"]) + distinct_count = int(rows.iloc[0]["distincts"]) + sample_values = rows["sample_val"].tolist() + if wants_range: + profile["min"] = rows.iloc[0]["min_val"] + profile["max"] = rows.iloc[0]["max_val"] + if wants_mean: + profile["mean"] = rows.iloc[0]["mean_val"] + if wants_median: + profile["median"] = rows.iloc[0]["median_val"] + else: + # Two-query fallback (MySQL/SQLite). + stats = pd.read_sql(f"SELECT {', '.join(stat_cols)} FROM {qt}", engine) + null_count = int(stats.iloc[0]["nulls"]) + distinct_count = int(stats.iloc[0]["distincts"]) + if wants_range: + profile["min"] = stats.iloc[0]["min_val"] + profile["max"] = stats.iloc[0]["max_val"] + if wants_mean: + profile["mean"] = stats.iloc[0]["mean_val"] + if wants_median: profile["median"] = stats.iloc[0]["median_val"] + sample = pd.read_sql(_head_query(engine, qc, qt, SAMPLE_LIMIT), engine) + sample_values = sample.iloc[:, 0].tolist() + + distinct_ratio = distinct_count / row_count if row_count > 0 else 0 + profile["null_count"] = null_count + profile["distinct_count"] = distinct_count + profile["distinct_ratio"] = round(distinct_ratio, 4) + profile["sample_values"] = sample_values if 0 < distinct_ratio <= TOP_VALUES_THRESHOLD: top_sql = _head_query( @@ -153,9 +196,6 @@ def profile_column( top = pd.read_sql(top_sql, engine) profile["top_values"] = list(zip(top.iloc[:, 0].tolist(), top["cnt"].tolist())) - sample = pd.read_sql(_head_query(engine, qc, qt, 5), engine) - profile["sample_values"] = sample.iloc[:, 0].tolist() - return profile @@ -273,7 +313,8 @@ def build_text(table_name: str, row_count: int, 
col: dict, profile: dict) -> str text += f"Distinct count: {profile['distinct_count']} ({profile['distinct_ratio']:.1%})\n" if "min" in profile: text += f"Min: {profile['min']}, Max: {profile['max']}\n" - text += f"Mean: {profile['mean']}\n" + if profile.get("mean") is not None: + text += f"Mean: {profile['mean']}\n" if profile.get("median") is not None: text += f"Median: {profile['median']}\n" if "top_values" in profile: diff --git a/src/pipeline/document_pipeline/document_pipeline.py b/src/pipeline/document_pipeline.py similarity index 73% rename from src/pipeline/document_pipeline/document_pipeline.py rename to src/pipeline/document_pipeline.py index 73e74413a70aa5a965fb41ba19eedab668f80814..1c4b1f583e32e6529f6528f759f3305ce6359a28 100644 --- a/src/pipeline/document_pipeline/document_pipeline.py +++ b/src/pipeline/document_pipeline.py @@ -1,13 +1,17 @@ """Document upload and processing pipeline.""" +from io import BytesIO + +import pandas as pd from fastapi import HTTPException, UploadFile from sqlalchemy.ext.asyncio import AsyncSession from src.document.document_service import document_service from src.knowledge.processing_service import knowledge_processor -from src.knowledge.parquet_service import delete_document_parquets +from src.storage.parquet import delete_document_parquets, upload_parquet from src.middlewares.logging import get_logger from src.storage.az_blob.az_blob import blob_storage +from src.retrieval.router import retrieval_router logger = get_logger("document_pipeline") @@ -62,11 +66,19 @@ class DocumentPipeline: try: await document_service.update_document_status(db, document_id, "processing") - chunks_count = await knowledge_processor.process_document(document, db) + if document.file_type not in ("csv", "xlsx"): + chunks_count = await knowledge_processor.process_document(document, db) + else: + await _upload_parquet(document) + chunks_count = 0 await document_service.update_document_status(db, document_id, "completed") + try: + await 
async def _upload_parquet(document) -> None:
    """Download the original blob and upload Parquet(s) without vector embedding.

    A CSV document produces a single Parquet upload. Any other file type is
    treated as an Excel workbook and produces one Parquet upload per sheet
    (the caller only routes ``csv`` and ``xlsx`` documents here).

    Args:
        document: Document record; this function reads ``blob_name``,
            ``file_type``, ``user_id`` and ``id`` from it.

    Raises:
        Whatever the blob download, pandas parsing or Parquet upload raises;
        nothing is caught here.
    """
    # Function-scope import keeps this block self-contained.
    import asyncio

    content = await blob_storage.download_file(document.blob_name)

    if document.file_type == "csv":
        # pandas parsing is synchronous and CPU-bound; run it in a worker
        # thread so a large upload does not stall the event loop.
        df = await asyncio.to_thread(pd.read_csv, BytesIO(content))
        await upload_parquet(df, document.user_id, document.id)
        return

    # xlsx: sheet_name=None loads every sheet as {sheet_name: DataFrame}.
    sheets = await asyncio.to_thread(pd.read_excel, BytesIO(content), sheet_name=None)
    for sheet_name, df in sheets.items():
        await upload_parquet(df, document.user_id, document.id, sheet_name)
validate (catalog/validator.py) + 4. upsert (catalog/store.py) + +LLM-driven enrichment was removed: the planner relies on stats + sample +rows + column names directly. Source/table/column `description` fields stay +in the model but are not populated by this pipeline. + +Source-type-agnostic: the caller picks the introspector. Triggers in +`pipeline/triggers.py` know which one to use based on the upload event. +""" + +from __future__ import annotations + +from datetime import UTC, datetime +from typing import TYPE_CHECKING + +from src.catalog.introspect.base import BaseIntrospector +from src.catalog.models import Catalog, Source +from src.middlewares.logging import get_logger + +if TYPE_CHECKING: + from src.catalog.store import CatalogStore + from src.catalog.validator import CatalogValidator + +logger = get_logger("structured_pipeline") + + +class StructuredPipeline: + """Orchestrates introspect → merge → validate → store. + + Dependencies are injected (no concrete imports at class-definition time) + so tests can pass mocks without constructing Settings or opening DB + connections. 
+ """ + + def __init__( + self, + validator: CatalogValidator, + store: CatalogStore, + ) -> None: + self._validator = validator + self._store = store + + async def run( + self, + introspector: BaseIntrospector, + location_ref: str, + user_id: str, + ) -> Source: + source = await introspector.introspect(location_ref) + merged = await self._merge_with_existing(user_id, source) + self._validator.validate(merged) + await self._store.upsert(merged) + logger.info( + "structured pipeline complete", + user_id=user_id, + source_id=source.source_id, + source_type=source.source_type, + tables=len(source.tables), + ) + return source + + async def _merge_with_existing(self, user_id: str, new_source: Source) -> Catalog: + existing = await self._store.get(user_id) + now = datetime.now(UTC) + if existing is None: + return Catalog(user_id=user_id, generated_at=now, sources=[new_source]) + kept = [s for s in existing.sources if s.source_id != new_source.source_id] + return existing.model_copy( + update={"sources": [*kept, new_source]} + ) + + +def default_structured_pipeline() -> StructuredPipeline: + """Build the production pipeline with default deps. + + Lazy imports keep `from src.pipeline.structured_pipeline import …` cheap + and side-effect-free for tests. + """ + from src.catalog.store import CatalogStore + from src.catalog.validator import CatalogValidator + + return StructuredPipeline( + validator=CatalogValidator(), + store=CatalogStore(), + ) diff --git a/src/pipeline/triggers.py b/src/pipeline/triggers.py new file mode 100644 index 0000000000000000000000000000000000000000..125139bceeafcc8ace000aed9d42fb7c2fc4d23b --- /dev/null +++ b/src/pipeline/triggers.py @@ -0,0 +1,115 @@ +"""Pipeline trigger entry points called from API routes / event handlers. + +These thin functions are what the FastAPI routes invoke; they delegate to the +appropriate pipeline (StructuredPipeline for DB/tabular, DocumentPipeline for +unstructured). 
async def on_db_registered(database_client_id: str, user_id: str) -> None:
    """Run the structured pipeline for a registered database client.

    Builds the `dbclient://<database_client_id>` location reference and hands
    it, together with the database introspector, to the default structured
    pipeline. Nothing is caught here: pipeline errors reach the caller.

    NOTE(review): per the original author's note this is meant to be called by
    `/api/v1/database-clients/{id}/ingest` once that route is rewired, and the
    introspector is expected to resolve the client id to stored credentials —
    neither is visible from this module.
    """
    # Imported lazily so importing this module stays cheap.
    from src.catalog.introspect.database import database_introspector
    from src.pipeline.structured_pipeline import default_structured_pipeline

    logger.info(
        "on_db_registered triggered",
        user_id=user_id,
        database_client_id=database_client_id,
    )
    await default_structured_pipeline().run(
        database_introspector,
        f"dbclient://{database_client_id}",
        user_id,
    )
async def on_catalog_rebuild_requested(user_id: str) -> None:
    """Re-introspect every source in the user's catalog and upsert the result.

    Each source in the stored catalog is re-run through its original trigger:
    `on_db_registered` for `schema` sources and `on_tabular_uploaded` for
    `tabular` sources. Other source types are logged and skipped. A failure
    on one source (including a malformed `location_ref`) is logged and does
    not abort the remaining sources. If the user has no catalog, the function
    logs that and returns.

    Expected `location_ref` shapes, matching what the triggers build:
        schema   dbclient://<database_client_id>
        tabular  az_blob://<user_id>/<document_id>
    """
    from src.catalog.store import CatalogStore

    catalog = await CatalogStore().get(user_id)
    if catalog is None:
        logger.info("no catalog to rebuild", user_id=user_id)
        return

    logger.info("on_catalog_rebuild_requested triggered", user_id=user_id, source_count=len(catalog.sources))
    for source in catalog.sources:
        try:
            if source.source_type == "schema":
                # partition() keeps everything after the scheme; a missing
                # separator or empty id becomes an explicit error rather than
                # an opaque IndexError.
                _, sep, client_id = source.location_ref.partition("://")
                if not sep or not client_id:
                    raise ValueError(f"malformed schema location_ref {source.location_ref!r}")
                await on_db_registered(client_id, user_id)
            elif source.source_type == "tabular":
                _, sep, path = source.location_ref.partition("://")
                # Drop only the leading "<user_id>/" segment so a document id
                # containing "/" is not truncated.
                _, slash, document_id = path.partition("/")
                if not sep or not slash or not document_id:
                    raise ValueError(f"malformed tabular location_ref {source.location_ref!r}")
                await on_tabular_uploaded(document_id, user_id)
            else:
                logger.warning("unsupported source_type for rebuild", source_type=source.source_type, source_id=source.source_id)
        except Exception as e:
            # Deliberate best-effort boundary: one bad source must not stop the rest.
            logger.error("rebuild failed for source", source_id=source.source_id, source_type=source.source_type, error=str(e))
class BaseCompiler(ABC):
    """Contract for turning a `QueryIR` into an executable form.

    Subclasses: SqlCompiler, PandasCompiler.
    """

    @abstractmethod
    def compile(self, ir: QueryIR) -> object:
        """Compile `ir` into the subclass's executable representation.

        The return type is deliberately loose (`object`); each subclass
        returns its own compiled shape.
        """
        ...
class PandasCompiler(BaseCompiler):
    """Turn a QueryIR into a pandas operation chain, deterministically (no LLM).

    The compiler resolves the IR's source/table/column ids against the catalog
    it was constructed with and returns a `CompiledPandas` whose `apply`
    performs filter → select/aggregate → sort → limit on a DataFrame.
    """

    def __init__(self, catalog: Catalog) -> None:
        self._catalog = catalog

    def compile(self, ir: QueryIR) -> CompiledPandas:
        """Compile `ir` into a `CompiledPandas`.

        Raises:
            PandasCompilerError: if the IR's source or table id is not in the catalog.
        """
        columns = self._lookup(ir)[2]
        expected_names = _output_column_names(ir.select, columns)

        # Snapshot the IR fields now so the returned closure does not depend
        # on the `ir` object afterwards.
        filters = ir.filters
        select = ir.select
        group_by = ir.group_by
        order_by = ir.order_by
        limit = ir.limit

        def apply(df: pd.DataFrame) -> pd.DataFrame:
            frame = _apply_filters(df, filters, columns)

            # Any aggregate in the select list switches the whole query to
            # aggregation mode; otherwise it is a plain projection.
            if any(isinstance(item, AggSelect) for item in select):
                frame = _apply_agg(frame, select, group_by, columns)
            else:
                frame = _apply_select(frame, select, columns)

            if order_by:
                frame = _apply_orderby(frame, order_by, select, columns)

            if limit is not None:
                frame = frame.head(limit)

            return frame.reset_index(drop=True)

        return CompiledPandas(apply=apply, output_columns=expected_names)

    # ------------------------------------------------------------------
    # Catalog lookup (mirrors SqlCompiler._lookup)
    # ------------------------------------------------------------------

    def _lookup(self, ir: QueryIR) -> tuple[Source, Table, dict[str, Column]]:
        """Resolve the IR's source and table; return them with a column_id → Column map.

        The first source / table with a matching id wins.
        """
        source = None
        for candidate in self._catalog.sources:
            if candidate.source_id == ir.source_id:
                source = candidate
                break
        if source is None:
            raise PandasCompilerError(f"source_id {ir.source_id!r} not in catalog")

        table = None
        for candidate in source.tables:
            if candidate.table_id == ir.table_id:
                table = candidate
                break
        if table is None:
            raise PandasCompilerError(
                f"table_id {ir.table_id!r} not in source {ir.source_id!r}"
            )

        columns_by_id = {column.column_id: column for column in table.columns}
        return source, table, columns_by_id
[] + for s in select: + if isinstance(s, ColumnSelect): + names.append(s.alias or cols_by_id[s.column_id].name) + else: + names.append(_agg_output_name(s, cols_by_id)) + return names + + +def _agg_output_name(s: AggSelect, cols_by_id: dict[str, Column]) -> str: + if s.alias: + return s.alias + if s.fn == "count" and s.column_id is None: + return "count" + return f"{s.fn}_{cols_by_id[s.column_id].name}" + + +def _like_to_regex(pattern: str) -> str: + """Convert SQL LIKE pattern to Python regex string (no anchors — use fullmatch).""" + parts: list[str] = [] + for ch in pattern: + if ch == "%": + parts.append(".*") + elif ch == "_": + parts.append(".") + else: + parts.append(re.escape(ch)) + return "".join(parts) + + +def _apply_filters( + df: pd.DataFrame, + filters: list[FilterClause], + cols_by_id: dict[str, Column], +) -> pd.DataFrame: + if not filters: + return df + mask = pd.Series(True, index=df.index) + for f in filters: + col_name = cols_by_id[f.column_id].name + series = df[col_name] + op, val = f.op, f.value + if op == "=": + mask &= series == val + elif op == "!=": + mask &= series != val + elif op == "<": + mask &= series < val + elif op == "<=": + mask &= series <= val + elif op == ">": + mask &= series > val + elif op == ">=": + mask &= series >= val + elif op == "in": + mask &= series.isin(val) + elif op == "not_in": + mask &= ~series.isin(val) + elif op == "is_null": + mask &= series.isna() + elif op == "is_not_null": + mask &= series.notna() + elif op == "like": + mask &= series.astype(str).str.fullmatch(_like_to_regex(val), case=True, na=False) + elif op == "between": + mask &= (series >= val[0]) & (series <= val[1]) + return df[mask].copy() + + +def _apply_select( + df: pd.DataFrame, + select: list[SelectItem], + cols_by_id: dict[str, Column], +) -> pd.DataFrame: + col_names = [cols_by_id[s.column_id].name for s in select if isinstance(s, ColumnSelect)] + result = df[col_names].copy() + rename_map = { + cols_by_id[s.column_id].name: s.alias + for s 
def _apply_agg(
    df: pd.DataFrame,
    select: list[SelectItem],
    group_by: list[str],
    cols_by_id: dict[str, Column],
) -> pd.DataFrame:
    """Evaluate the aggregate items of *select*, optionally grouped.

    Args:
        df: Already-filtered frame to aggregate.
        select: IR select list; ``AggSelect`` items are computed and aliased
            ``ColumnSelect`` items rename their group column in the output.
        group_by: Column ids to group on; empty means one scalar result row.
        cols_by_id: Maps column ids to catalog columns (``.name`` is the
            DataFrame column).

    Returns:
        With *group_by*: one row per distinct key combination, holding the
        group columns followed by one column per aggregate. Without it: a
        single-row frame with one column per aggregate.
    """
    agg_items = [s for s in select if isinstance(s, AggSelect)]
    group_col_names = [cols_by_id[col_id].name for col_id in group_by]

    if not group_col_names:
        row = {
            _agg_output_name(s, cols_by_id): _scalar_agg(df, s, cols_by_id)
            for s in agg_items
        }
        return pd.DataFrame([row])

    if agg_items:
        # dropna=False keeps rows whose group key is null as their own group
        # (as SQL GROUP BY does); the pandas default silently discards them.
        grouped = df.groupby(group_col_names, sort=False, dropna=False)
        series_list = [
            _group_agg_series(grouped, s, cols_by_id).rename(_agg_output_name(s, cols_by_id))
            for s in agg_items
        ]
        result = pd.concat(series_list, axis=1).reset_index()
    else:
        # Grouping without aggregates is just the distinct key combinations;
        # pd.concat([]) would raise "No objects to concatenate".
        result = df[group_col_names].drop_duplicates().reset_index(drop=True)

    rename_map = {
        cols_by_id[s.column_id].name: s.alias
        for s in select
        if isinstance(s, ColumnSelect) and s.alias
    }
    if rename_map:
        result = result.rename(columns=rename_map)
    return result
def compile(self, ir: QueryIR) -> CompiledSql:
    """Compile *ir* into one SQL string plus its named-parameter dict.

    SELECT and FROM are always emitted; WHERE, GROUP BY, ORDER BY and LIMIT
    are appended, in that order, only when their builder returns a non-empty
    string. Filter values land in the returned ``params`` under the
    placeholder names used in the SQL text.
    """
    _source, tbl, columns = self._lookup(ir)
    bound: dict[str, Any] = {}
    counter = [0]  # one-element list so helpers can advance the sequence in place

    select_sql, aliases = self._build_select(ir.select, tbl, columns)
    from_sql = self._build_from(tbl)
    optional_clauses = (
        self._build_where(ir.filters, tbl, columns, bound, counter),
        self._build_groupby(ir.group_by, tbl, columns),
        self._build_orderby(ir.order_by, tbl, columns, aliases),
        self._build_limit(ir.limit),
    )

    pieces = [select_sql, from_sql, *(c for c in optional_clauses if c)]
    return CompiledSql(sql=" ".join(pieces), params=bound)
------------------------------------------------------------------ + # Catalog lookup + # ------------------------------------------------------------------ + + def _lookup(self, ir: QueryIR) -> tuple[Source, Table, dict[str, Column]]: + source = next( + (s for s in self._catalog.sources if s.source_id == ir.source_id), None + ) + if source is None: + raise SqlCompilerError(f"source_id {ir.source_id!r} not in catalog") + table = next( + (t for t in source.tables if t.table_id == ir.table_id), None + ) + if table is None: + raise SqlCompilerError( + f"table_id {ir.table_id!r} not in source {ir.source_id!r}" + ) + return source, table, {c.column_id: c for c in table.columns} + + # ------------------------------------------------------------------ + # Identifier quoting + # ------------------------------------------------------------------ + + @staticmethod + def _qident(name: str) -> str: + """Postgres-style double-quoted identifier with embedded-quote escape.""" + return '"' + name.replace('"', '""') + '"' + + def _qcol(self, table: Table, col: Column) -> str: + return f"{self._qident(table.name)}.{self._qident(col.name)}" + + # ------------------------------------------------------------------ + # Clauses + # ------------------------------------------------------------------ + + def _build_select( + self, + items: list[SelectItem], + table: Table, + cols_by_id: dict[str, Column], + ) -> tuple[str, set[str]]: + if not items: + raise SqlCompilerError("select clause cannot be empty") + parts: list[str] = [] + aliases: set[str] = set() + for i, item in enumerate(items): + expr, alias = self._select_item(item, table, cols_by_id, i) + if alias: + parts.append(f"{expr} AS {self._qident(alias)}") + aliases.add(alias) + else: + parts.append(expr) + return "SELECT " + ", ".join(parts), aliases + + def _select_item( + self, + item: SelectItem, + table: Table, + cols_by_id: dict[str, Column], + index: int, + ) -> tuple[str, str | None]: + if isinstance(item, ColumnSelect): + 
col = self._require_col(cols_by_id, item.column_id, f"select[{index}]") + return self._qcol(table, col), item.alias + if not isinstance(item, AggSelect): + raise SqlCompilerError( + f"select[{index}]: unknown SelectItem kind {type(item).__name__}" + ) + return self._compile_agg(item, table, cols_by_id, index), item.alias + + def _compile_agg( + self, + item: AggSelect, + table: Table, + cols_by_id: dict[str, Column], + index: int, + ) -> str: + if item.fn == "count_distinct": + if item.column_id is None: + raise SqlCompilerError( + f"select[{index}].fn=count_distinct requires column_id" + ) + col = self._require_col(cols_by_id, item.column_id, f"select[{index}]") + return f"COUNT(DISTINCT {self._qcol(table, col)})" + if item.column_id is None: + if item.fn != "count": + raise SqlCompilerError( + f"select[{index}].fn={item.fn!r} requires column_id " + "(only 'count' may omit it for COUNT(*))" + ) + return "COUNT(*)" + col = self._require_col(cols_by_id, item.column_id, f"select[{index}]") + return f"{item.fn.upper()}({self._qcol(table, col)})" + + def _build_from(self, table: Table) -> str: + return f"FROM {self._qident(table.name)}" + + def _build_where( + self, + filters: list[FilterClause], + table: Table, + cols_by_id: dict[str, Column], + params: dict[str, Any], + param_seq: list[int], + ) -> str: + if not filters: + return "" + parts = [ + self._compile_filter(f, table, cols_by_id, params, param_seq, index=i) + for i, f in enumerate(filters) + ] + return "WHERE " + " AND ".join(parts) + + def _compile_filter( + self, + f: FilterClause, + table: Table, + cols_by_id: dict[str, Column], + params: dict[str, Any], + param_seq: list[int], + index: int, + ) -> str: + col = self._require_col(cols_by_id, f.column_id, f"filters[{index}]") + col_ref = self._qcol(table, col) + op = f.op + + if op == "is_null": + return f"{col_ref} IS NULL" + if op == "is_not_null": + return f"{col_ref} IS NOT NULL" + + if op in _LIST_OPS: + if not isinstance(f.value, list) or not f.value: 
+ raise SqlCompilerError( + f"filters[{index}]: op {op!r} requires a non-empty list value" + ) + placeholders = [ + ":" + self._next_param(params, param_seq, v) for v in f.value + ] + sql_op = "IN" if op == "in" else "NOT IN" + return f"{col_ref} {sql_op} ({', '.join(placeholders)})" + + if op == "between": + if not isinstance(f.value, list) or len(f.value) != 2: + raise SqlCompilerError( + f"filters[{index}]: op 'between' requires a list of two values" + ) + lo = self._next_param(params, param_seq, f.value[0]) + hi = self._next_param(params, param_seq, f.value[1]) + return f"{col_ref} BETWEEN :{lo} AND :{hi}" + + if op == "like": + p = self._next_param(params, param_seq, f.value) + return f"{col_ref} LIKE :{p}" + + if op in _COMPARISON_OPS: + p = self._next_param(params, param_seq, f.value) + return f"{col_ref} {op} :{p}" + + # Should not reach here — IRValidator already filters disallowed ops + raise SqlCompilerError(f"filters[{index}]: unhandled op {op!r}") + + def _build_groupby( + self, + group_by: list[str], + table: Table, + cols_by_id: dict[str, Column], + ) -> str: + if not group_by: + return "" + parts = [ + self._qcol(table, self._require_col(cols_by_id, col_id, f"group_by[{i}]")) + for i, col_id in enumerate(group_by) + ] + return "GROUP BY " + ", ".join(parts) + + def _build_orderby( + self, + order_by: list[OrderByClause], + table: Table, + cols_by_id: dict[str, Column], + select_aliases: set[str], + ) -> str: + if not order_by: + return "" + parts: list[str] = [] + for i, ob in enumerate(order_by): + if ob.column_id in cols_by_id: + ref = self._qcol(table, cols_by_id[ob.column_id]) + elif ob.column_id in select_aliases: + ref = self._qident(ob.column_id) + else: + raise SqlCompilerError( + f"order_by[{i}].column_id: {ob.column_id!r} not in table " + "columns or select aliases" + ) + parts.append(f"{ref} {ob.dir.upper()}") + return "ORDER BY " + ", ".join(parts) + + def _build_limit(self, limit: int | None) -> str: + if limit is None: + return "" + 
return f"LIMIT {int(limit)}" + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + @staticmethod + def _next_param( + params: dict[str, Any], param_seq: list[int], value: Any + ) -> str: + name = f"p_{param_seq[0]}" + param_seq[0] += 1 + params[name] = value + return name + + @staticmethod + def _require_col( + cols_by_id: dict[str, Column], col_id: str, where: str + ) -> Column: + col = cols_by_id.get(col_id) + if col is None: + raise SqlCompilerError(f"{where}.column_id: {col_id!r} not in table") + return col diff --git a/src/query/executor/__init__.py b/src/query/executor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d7f9f46208fda5aef13c3ce6e8175a8f2bd3481a --- /dev/null +++ b/src/query/executor/__init__.py @@ -0,0 +1 @@ +"""Query executors — run compiled queries against user DBs or tabular files.""" diff --git a/src/query/executor/base.py b/src/query/executor/base.py new file mode 100644 index 0000000000000000000000000000000000000000..12ccb151f67760d1d01b3a10170e326964250523 --- /dev/null +++ b/src/query/executor/base.py @@ -0,0 +1,30 @@ +"""BaseExecutor + QueryResult — uniform return shape across DB and tabular paths.""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any + +from ..ir.models import QueryIR + + +@dataclass +class QueryResult: + source_id: str + backend: str # "sql" | "tabular" + columns: list[str] = field(default_factory=list) + rows: list[dict[str, Any]] = field(default_factory=list) + row_count: int = 0 + truncated: bool = False + elapsed_ms: int = 0 + error: str | None = None + table_id: str = "" + table_name: str = "" + source_name: str = "" + + +class BaseExecutor(ABC): + """Subclasses: DbExecutor, TabularExecutor.""" + + @abstractmethod + async def run(self, ir: QueryIR) -> QueryResult: + ... 
class DbExecutor(BaseExecutor):
    """Executes compiled SQL on the user's registered DB.

    Constructed once per query with the user's catalog. The catalog is the
    source of truth for identifiers; the executor never touches the user's
    DB metadata at execution time.
    """

    def __init__(self, catalog: Catalog) -> None:
        self._catalog = catalog
        self._compiler = SqlCompiler(catalog)

    async def run(self, ir: QueryIR) -> QueryResult:
        """Compile *ir*, execute it on the source's database, return the rows.

        Never raises: any failure (lookup, compile, guard, credentials,
        timeout, driver error) is logged and reported through
        ``QueryResult.error``.

        Returns:
            A ``QueryResult`` with at most ``_ROW_HARD_CAP`` rows;
            ``truncated`` is set when the query produced more than that.
        """
        started = time.perf_counter()
        table_name = ""
        source_name = ""
        try:
            source = self._find_source(ir.source_id)
            source_name = source.name
            table_name = next(
                (t.name for t in source.tables if t.table_id == ir.table_id), ""
            )
            if source.source_type != "schema":
                raise ValueError(
                    f"DbExecutor cannot run on source_type={source.source_type!r}; "
                    "expected 'schema'"
                )

            compiled = self._compiler.compile(ir)
            self._sqlglot_guard(compiled.sql)

            client_id = self._parse_client_id(source.location_ref)
            client = await self._fetch_client(client_id)
            if client.user_id != self._catalog.user_id:
                raise PermissionError(
                    f"DatabaseClient {client_id!r} owner mismatch "
                    f"(client.user_id != catalog.user_id)"
                )
            creds = decrypt_credentials_dict(client.credentials)

            columns, rows = await asyncio.wait_for(
                asyncio.to_thread(self._run_sync, client.db_type, creds, compiled),
                timeout=_QUERY_TIMEOUT_SECONDS,
            )

            # _run_sync fetches at most _ROW_HARD_CAP + 1 rows, so one extra
            # row is enough to know the result was cut short.
            truncated = len(rows) > _ROW_HARD_CAP
            capped = rows[:_ROW_HARD_CAP]
            elapsed_ms = int((time.perf_counter() - started) * 1000)
            logger.info(
                "db query complete",
                source_id=ir.source_id,
                rows=len(capped),
                truncated=truncated,
                elapsed_ms=elapsed_ms,
            )
            return QueryResult(
                source_id=ir.source_id,
                backend="sql",
                columns=columns,
                rows=capped,
                row_count=len(capped),
                truncated=truncated,
                elapsed_ms=elapsed_ms,
                table_id=ir.table_id,
                table_name=table_name,
                source_name=source_name,
            )

        except Exception as e:
            elapsed_ms = int((time.perf_counter() - started) * 1000)
            # Some exceptions stringify to "" (notably the timeout raised by
            # asyncio.wait_for). An empty `error` is falsy and would make the
            # failure look like a successful empty result, so fall back to the
            # exception class name.
            message = str(e) or type(e).__name__
            logger.error(
                "db executor failed",
                source_id=ir.source_id,
                error=message,
                elapsed_ms=elapsed_ms,
            )
            return QueryResult(
                source_id=ir.source_id,
                backend="sql",
                elapsed_ms=elapsed_ms,
                error=message,
                table_id=ir.table_id,
                table_name=table_name,
                source_name=source_name,
            )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _find_source(self, source_id: str) -> Source:
        """Return the catalog source with *source_id*, or raise ``ValueError``."""
        for s in self._catalog.sources:
            if s.source_id == source_id:
                return s
        raise ValueError(f"source_id {source_id!r} not in catalog")

    @staticmethod
    def _parse_client_id(location_ref: str) -> str:
        """Extract the client id from a ``dbclient://<client_id>`` reference.

        Raises:
            ValueError: If the prefix is missing or nothing follows it.
        """
        if not location_ref.startswith(_DBCLIENT_PREFIX):
            raise ValueError(
                f"DbExecutor expects 'dbclient://...' location_ref, got {location_ref!r}"
            )
        client_id = location_ref[len(_DBCLIENT_PREFIX):]
        if not client_id:
            raise ValueError("location_ref is missing client_id after 'dbclient://'")
        return client_id

    @staticmethod
    async def _fetch_client(client_id: str) -> Any:
        """Load the DatabaseClient row for *client_id* and require it to be active.

        Raises:
            ValueError: If the client does not exist or its status is not
                ``"active"``.
        """
        async with AsyncSessionLocal() as session:
            client = await database_client_service.get(session, client_id)
        if client is None:
            raise ValueError(f"DatabaseClient {client_id!r} not found")
        if client.status != "active":
            raise ValueError(
                f"DatabaseClient {client_id!r} is not active "
                f"(status={client.status!r})"
            )
        return client

    @staticmethod
    def _sqlglot_guard(sql: str) -> None:
        """Defense-in-depth: ensure the compiled SQL is one SELECT statement.

        The compiler is already deterministic and only constructs SELECTs from
        validated IR, but this guard catches any future bug that could leak
        DML/DDL through.

        Raises:
            ValueError: If the text does not parse, holds anything other than
                exactly one statement, is not a SELECT, or contains a
                data-modifying node anywhere in its tree.
        """
        try:
            # parse() returns every statement in the text; parse_one() would
            # hand back only the first and let "SELECT ...; DROP ..." through.
            statements = sqlglot.parse(sql, read="postgres")
        except sqlglot.errors.ParseError as e:
            raise ValueError(f"compiled SQL failed to parse: {e}") from e
        # Empty statements (e.g. after a trailing ';') come back as None.
        statements = [s for s in statements if s is not None]
        if len(statements) != 1:
            raise ValueError(
                f"compiled SQL must be exactly one statement (got {len(statements)})"
            )
        parsed = statements[0]
        if not isinstance(parsed, exp.Select):
            raise ValueError(
                f"compiled SQL is not a SELECT (got {type(parsed).__name__})"
            )
        # A SELECT root can still carry writes: data-modifying CTEs and
        # SELECT ... INTO (which creates a table).
        forbidden = (
            exp.Insert,
            exp.Update,
            exp.Delete,
            exp.Merge,
            exp.Into,
            exp.Drop,
            exp.Alter,
        )
        for node in parsed.find_all(*forbidden):
            raise ValueError(
                f"compiled SQL contains forbidden DML/DDL: {type(node).__name__}"
            )

    @staticmethod
    def _run_sync(db_type: str, creds: dict, compiled: CompiledSql) -> tuple[list[str], list[dict]]:
        """Execute *compiled* synchronously and return ``(columns, rows)``.

        Meant to run in a worker thread. At most ``_ROW_HARD_CAP + 1`` rows are
        fetched, so an unbounded result cannot exhaust memory while the caller
        can still detect truncation.
        """
        with db_pipeline_service.engine_scope(db_type, creds) as engine:
            with engine.connect() as conn:
                if db_type in _POSTGRES_LIKE:
                    # session-level read-only + per-statement timeout (ms)
                    conn.execute(text("SET default_transaction_read_only = on"))
                    conn.execute(
                        text(f"SET statement_timeout = {_QUERY_TIMEOUT_SECONDS * 1000}")
                    )
                    # default_transaction_read_only only applies to
                    # transactions started afterwards, so also mark the
                    # current one read-only.
                    # NOTE(review): on an autocommit connection Postgres only
                    # warns here and the session default above does the work —
                    # confirm how engine_scope configures the engine.
                    conn.execute(text("SET TRANSACTION READ ONLY"))
                result = conn.execute(text(compiled.sql), compiled.params)
                columns = list(result.keys())
                rows = [
                    dict(row)
                    for row in result.mappings().fetchmany(_ROW_HARD_CAP + 1)
                ]
        return columns, rows
class ExecutorDispatcher:
    """Chooses the `BaseExecutor` that should run a given IR.

    Executors are built lazily, one per source_type, and reused for every
    later IR whose source has that type.
    """

    def __init__(
        self,
        catalog: Catalog,
        executor_factories: dict[str, ExecutorFactory] | None = None,
    ) -> None:
        self._catalog = catalog
        self._factories = executor_factories
        self._cache: dict[str, BaseExecutor] = {}

    def pick(self, ir: QueryIR) -> BaseExecutor:
        """Return the (cached) executor for the source that *ir* targets."""
        kind = self._find_source(ir.source_id).source_type
        try:
            return self._cache[kind]
        except KeyError:
            pass
        built = self._get_factory(kind)(self._catalog)
        self._cache[kind] = built
        return built

    def _get_factory(self, source_type: str) -> ExecutorFactory:
        """Resolve the factory for *source_type*.

        Injected factories, when given, are the only ones consulted; otherwise
        the production executors are imported on demand.
        """
        injected = self._factories
        if injected is None:
            # Default factories — imported here, not at module import time
            if source_type == "schema":
                from .db import DbExecutor

                return DbExecutor  # type: ignore[return-value]
            if source_type == "tabular":
                from .tabular import TabularExecutor

                return TabularExecutor  # type: ignore[return-value]
            raise ValueError(f"unsupported source_type={source_type!r}")

        chosen = injected.get(source_type)
        if chosen is None:
            raise ValueError(
                f"no executor factory injected for source_type={source_type!r}"
            )
        return chosen

    def _find_source(self, source_id: str) -> Source:
        """Return the catalog source with *source_id*, or raise ``ValueError``."""
        found = next(
            (s for s in self._catalog.sources if s.source_id == source_id), None
        )
        if found is None:
            raise ValueError(f"source_id {source_id!r} not in catalog")
        return found
async def run(self, ir: QueryIR) -> QueryResult:
    """Compile *ir* to a pandas chain and run it against the table's Parquet blob.

    Never raises: any failure is logged and reported through
    ``QueryResult.error``.

    Returns:
        A ``QueryResult`` with at most ``_ROW_HARD_CAP`` rows; ``truncated``
        is set when the computed frame was longer than that.
    """
    started = time.perf_counter()
    table_name = ""
    source_name = ""
    try:
        source, table = self._lookup(ir)
        table_name = table.name
        source_name = source.name
        if source.source_type != "tabular":
            raise ValueError(
                f"TabularExecutor cannot run on source_type={source.source_type!r}; "
                "expected 'tabular'"
            )

        compiled = self._compiler.compile(ir)
        logger.info("pandas query", query=_render_query(ir, {c.column_id: c for c in table.columns}))
        blob_name = _resolve_blob_name(source, table)
        blob_bytes = await self._fetch_blob(blob_name)

        # Parquet decoding and the pandas chain are blocking work; keep them
        # off the event loop.
        result_df = await asyncio.to_thread(_load_and_apply, blob_bytes, compiled)

        truncated = len(result_df) > _ROW_HARD_CAP
        capped = result_df.head(_ROW_HARD_CAP)

        columns = compiled.output_columns
        rows = capped.to_dict(orient="records")
        elapsed_ms = int((time.perf_counter() - started) * 1000)
        logger.info(
            "tabular query complete",
            source_id=ir.source_id,
            rows=len(rows),
            truncated=truncated,
            elapsed_ms=elapsed_ms,
        )
        return QueryResult(
            source_id=ir.source_id,
            backend="tabular",
            columns=columns,
            rows=rows,
            row_count=len(rows),
            truncated=truncated,
            elapsed_ms=elapsed_ms,
            table_id=ir.table_id,
            table_name=table_name,
            source_name=source_name,
        )

    except Exception as e:
        elapsed_ms = int((time.perf_counter() - started) * 1000)
        # Exceptions raised without a message (e.g. a bare
        # NotImplementedError()) stringify to "". An empty `error` is falsy
        # and would make the failure look like a successful empty result, so
        # fall back to the exception class name.
        message = str(e) or type(e).__name__
        logger.error(
            "tabular executor failed",
            source_id=ir.source_id,
            error=message,
            elapsed_ms=elapsed_ms,
        )
        return QueryResult(
            source_id=ir.source_id,
            backend="tabular",
            elapsed_ms=elapsed_ms,
            error=message,
            table_id=ir.table_id,
            table_name=table_name,
            source_name=source_name,
        )
def _render_query(ir: QueryIR, cols_by_id: dict) -> str:
    """Render *ir* as a pandas-flavoured one-liner for log output.

    The string is descriptive only and is never executed. It shows filters,
    then either the group-by/aggregate chain or the plain column projection,
    then the limit.

    Args:
        ir: Query to describe.
        cols_by_id: Maps column ids to catalog columns (``.name`` is shown).
    """
    from ..ir.models import AggSelect, ColumnSelect

    # Hoisted out of the per-aggregate loop below: the mapping never changes.
    fn_map = {
        "count": "count()",
        "count_distinct": "nunique()",
        "sum": "sum()",
        "avg": "mean()",
        "min": "min()",
        "max": "max()",
    }

    parts = ["df"]
    if ir.filters:
        conds = " & ".join(
            f'(df["{cols_by_id[f.column_id].name}"] {f.op} {f.value!r})'
            for f in ir.filters
        )
        parts.append(f"[{conds}]")

    aggs = [s for s in ir.select if isinstance(s, AggSelect)]
    cols = [s for s in ir.select if isinstance(s, ColumnSelect)]
    if aggs:
        # Plain column names are not rendered in this branch, so they are not
        # looked up either — a log helper should not raise on data it never
        # uses.
        if ir.group_by:
            group_names = [cols_by_id[g].name for g in ir.group_by]
            parts.append(f'.groupby({group_names})')
        for agg in aggs:
            col = f'["{cols_by_id[agg.column_id].name}"]' if agg.column_id else ""
            parts.append(f'{col}.{fn_map.get(agg.fn, agg.fn + "()")}')
    elif cols:
        col_names = [cols_by_id[s.column_id].name for s in cols]
        parts.append(f'[{col_names}]')
    if ir.limit:
        parts.append(f'.head({ir.limit})')
    return "".join(parts)
b/src/query/executors/db_executor.py deleted file mode 100644 index 38ada5fbab967f1dd1d752b538ccde2d7bc64bfb..0000000000000000000000000000000000000000 --- a/src/query/executors/db_executor.py +++ /dev/null @@ -1,648 +0,0 @@ -"""Executor for registered database sources (source_type="database"). - -Flow per (client_id, question): - 1. Collect all relevant (table_name, column_name) pairs from retrieval results. - 2. Fetch the FULL schema for those tables from PGVector (not just top-k columns). - 3. Build a schema context string and send to LLM → structured SQLQuery output. - 4. Validate via sqlglot: SELECT-only, schema-grounded, LIMIT enforced. - 5. Execute on the user's DB via engine_scope + asyncio.to_thread. - 6. Return QueryResult per client_id (may span multiple tables via JOINs). - -Supported db_types: postgres, supabase, mysql. -Other types are skipped with a warning — they do not raise. -""" - -import asyncio -from collections import defaultdict -from typing import Any - -import sqlglot -import sqlglot.expressions as exp -import tiktoken -from langchain_core.prompts import ChatPromptTemplate -from langchain_openai import AzureChatOpenAI -from sqlalchemy import text -from sqlalchemy.ext.asyncio import AsyncSession - -from src.config.settings import settings -from src.database_client.database_client_service import database_client_service -from src.db.postgres.connection import _pgvector_engine -from src.middlewares.logging import get_logger -from src.models.sql_query import SQLQuery -from src.pipeline.db_pipeline import db_pipeline_service -from src.query.base import BaseExecutor, QueryResult -from src.rag.base import RetrievalResult -from src.utils.db_credential_encryption import decrypt_credentials_dict - -logger = get_logger("db_executor") - -_enc = tiktoken.get_encoding("cl100k_base") - -_SUPPORTED_DB_TYPES = {"postgres", "supabase", "mysql"} -_MAX_RETRIES = 3 -_MAX_LIMIT = 500 -_FK_EXPANSION_MAX_TABLES = 5 - -_SQL_SYSTEM_PROMPT = """\ -You are a SQL data 
analyst working with a user's database. -Generate a single SQL SELECT statement that answers the user's question. - -Database dialect: {dialect} - -Rules: -- ONLY reference tables and columns listed in the schema below. Do not invent names. -- Always include a LIMIT clause (max {limit}). -- Do not use DELETE, UPDATE, INSERT, DROP, TRUNCATE, ALTER, CREATE, or any DDL. -- Prefer explicit JOINs over subqueries when combining tables. -- For aggregations, always alias the result column (e.g. COUNT(*) AS order_count). -- For date filtering, use dialect-appropriate functions ({dialect} syntax). - -Schema: -{schema} - -{error_section}""" - - -class DbExecutor(BaseExecutor): - def __init__(self) -> None: - self._llm = AzureChatOpenAI( - azure_deployment=settings.azureai_deployment_name_4o, - openai_api_version=settings.azureai_api_version_4o, - azure_endpoint=settings.azureai_endpoint_url_4o, - api_key=settings.azureai_api_key_4o, - temperature=0, - ) - self._prompt = ChatPromptTemplate.from_messages([ - ("system", _SQL_SYSTEM_PROMPT), - ("human", "{question}"), - ]) - self._chain = self._prompt | self._llm.with_structured_output(SQLQuery) - - # ------------------------------------------------------------------ - # Public interface - # ------------------------------------------------------------------ - - async def execute( - self, - results: list[RetrievalResult], - user_id: str, - db: AsyncSession, - question: str, - limit: int = 100, - ) -> list[QueryResult]: - db_results = [r for r in results if r.source_type == "database"] - if not db_results: - return [] - - # Group by client_id — one SQL generation + execution pass per client - by_client: dict[str, list[RetrievalResult]] = defaultdict(list) - for r in db_results: - client_id = r.metadata.get("database_client_id", "") - if client_id: - by_client[client_id].append(r) - else: - logger.warning("db result missing database_client_id, skipping") - - query_results: list[QueryResult] = [] - for client_id, client_results in 
by_client.items(): - try: - qr = await self._execute_for_client(client_id, client_results, user_id, db, question, limit) - if qr: - query_results.append(qr) - except Exception as e: - logger.error("db executor failed for client", client_id=client_id, error=str(e)) - - return query_results - - # ------------------------------------------------------------------ - # Per-client execution - # ------------------------------------------------------------------ - - async def _execute_for_client( - self, - client_id: str, - results: list[RetrievalResult], - user_id: str, - db: AsyncSession, - question: str, - limit: int, - ) -> QueryResult | None: - client = await database_client_service.get(db, client_id) - if not client: - logger.warning("database client not found", client_id=client_id) - return None - if client.user_id != user_id: - logger.warning("client ownership mismatch", client_id=client_id) - return None - if client.db_type not in _SUPPORTED_DB_TYPES: - logger.warning("unsupported db_type for query execution", db_type=client.db_type) - return None - - # Hit tables = tables retrieval pointed at directly. Get full per-column - # schema for these. Related tables (one FK hop away, both directions) are - # fetched separately in abbreviated form to give the LLM enough context - # to JOIN without paying the per-column profile token cost. 
- hit_tables = list({ - r.metadata.get("data", {}).get("table_name") - for r in results - if r.metadata.get("data", {}).get("table_name") - }) - if not hit_tables: - logger.warning("no table_name on any retrieval result", client_id=client_id) - return None - - full_schema = await self._fetch_full_schema(client_id, hit_tables, user_id) - if not full_schema: - logger.warning("no schema found in vector store", client_id=client_id, tables=hit_tables) - return None - - related_tables = await self._find_related_tables(client_id, user_id, hit_tables) - related_schema = ( - await self._fetch_abbreviated_schema(client_id, user_id, related_tables) - if related_tables else {} - ) - - schema_ctx = self._build_schema_context(full_schema, related_schema) - capped_limit = min(limit, _MAX_LIMIT) - dialect = client.db_type - - # SQL generation with retry - validated_sql: str | None = None - prev_error: str = "" - prev_reasoning: str = "" - for attempt in range(_MAX_RETRIES): - if prev_error: - error_section = ( - f"Previous attempt reasoning: {prev_reasoning}\n" - f"Previous attempt failed: {prev_error}\n" - "Fix the issue above." 
- ) - else: - error_section = "" - try: - prompt_text = schema_ctx + error_section + question - input_tokens = len(_enc.encode(prompt_text)) - logger.info("sql generation input tokens", attempt=attempt + 1, tokens=input_tokens) - - result: SQLQuery = await self._chain.ainvoke({ - "schema": schema_ctx, - "dialect": dialect, - "limit": capped_limit, - "error_section": error_section, - "question": question, - }) - sql = result.sql.strip() - allowed_tables = set(full_schema) | set(related_schema) - column_map: dict[str, set[str]] = { - t: {c["name"] for c in cols} for t, cols in full_schema.items() - } - for t, info in related_schema.items(): - column_map[t] = set(info.get("column_names") or []) - validation_error = self._validate(sql, allowed_tables, capped_limit, column_map) - if validation_error: - prev_error = validation_error - prev_reasoning = result.reasoning - logger.warning("sql validation failed", attempt=attempt + 1, error=validation_error) - continue - validated_sql = self._enforce_limit(sql, capped_limit) - output_tokens = len(_enc.encode(result.sql)) + len(_enc.encode(result.reasoning)) - logger.info( - "sql generated", - attempt=attempt + 1, - input_tokens=input_tokens, - output_tokens=output_tokens, - total_tokens=input_tokens + output_tokens, - reasoning=result.reasoning, - ) - break - except Exception as e: - prev_error = str(e) - logger.warning("sql generation error", attempt=attempt + 1, error=prev_error) - - if not validated_sql: - logger.error("sql generation failed after retries", client_id=client_id) - return None - - # Execute on user's DB - creds = decrypt_credentials_dict(client.credentials) - with db_pipeline_service.engine_scope(client.db_type, creds) as engine: - rows = await asyncio.to_thread(self._run_sql, engine, validated_sql) - - column_types = { - col["name"]: col["type"] - for cols in full_schema.values() - for col in cols - } - columns = list(rows[0].keys()) if rows else [] - - return QueryResult( - source_type="database", - 
source_id=client_id, - table_or_file=", ".join(hit_tables), - columns=columns, - rows=rows, - row_count=len(rows), - metadata={ - "db_type": client.db_type, - "client_name": client.name, - "sql": validated_sql, - "column_types": {c: column_types.get(c, "unknown") for c in columns}, - }, - ) - - # ------------------------------------------------------------------ - # Schema helpers - # ------------------------------------------------------------------ - - async def _find_related_tables( - self, - client_id: str, - user_id: str, - hit_tables: list[str], - ) -> list[str]: - """One-hop FK neighbours of `hit_tables`, both directions, excluding hits. - - Prefers chunk_level='table' rows; if none exist for the client (legacy - ingest predating Phase 1), falls back to aggregating from column-chunk - metadata. Returns [] when no FK metadata is available. - - Capped at _FK_EXPANSION_MAX_TABLES, ranked by edge count desc then - table name asc. A warning is logged when the cap kicks in. - """ - if not hit_tables: - return [] - - hit_set = set(hit_tables) - # edge_counts[related_table] = number of FK edges connecting it to the hit set - edge_counts: dict[str, int] = defaultdict(int) - - # ---- Primary path: table-level chunks ---- - sql = text(""" - SELECT lpe.cmetadata - FROM langchain_pg_embedding lpe - JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid - WHERE lpc.name = 'document_embeddings' - AND lpe.cmetadata->>'user_id' = :user_id - AND lpe.cmetadata->>'source_type' = 'database' - AND lpe.cmetadata->>'database_client_id' = :client_id - AND lpe.cmetadata->>'chunk_level' = 'table' - """) - async with _pgvector_engine.connect() as conn: - result = await conn.execute(sql, {"user_id": user_id, "client_id": client_id}) - table_rows = result.fetchall() - - if table_rows: - for row in table_rows: - data = row.cmetadata.get("data", {}) - table = data.get("table_name") - fks = data.get("foreign_keys") or [] - if not table: - continue - if table in hit_set: - # 
Outgoing: this hit's FKs point at related tables - for fk in fks: - target = fk.get("target_table") - if target and target not in hit_set: - edge_counts[target] += 1 - else: - # Incoming: this non-hit table's FKs point into the hit set - for fk in fks: - target = fk.get("target_table") - if target in hit_set: - edge_counts[table] += 1 - else: - # ---- Fallback: aggregate from column chunks ---- - sql = text(""" - SELECT lpe.cmetadata->'data'->>'table_name' AS src_table, - lpe.cmetadata->'data'->>'foreign_key' AS fk - FROM langchain_pg_embedding lpe - JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid - WHERE lpc.name = 'document_embeddings' - AND lpe.cmetadata->>'user_id' = :user_id - AND lpe.cmetadata->>'source_type' = 'database' - AND lpe.cmetadata->>'database_client_id' = :client_id - AND lpe.cmetadata->>'chunk_level' = 'column' - AND lpe.cmetadata->'data'->>'foreign_key' IS NOT NULL - """) - async with _pgvector_engine.connect() as conn: - result = await conn.execute(sql, {"user_id": user_id, "client_id": client_id}) - col_rows = result.fetchall() - - for row in col_rows: - src = row.src_table - fk = row.fk - if not src or not fk: - continue - target = fk.split(".", 1)[0] - if src in hit_set and target and target not in hit_set: - edge_counts[target] += 1 - elif src not in hit_set and target in hit_set: - edge_counts[src] += 1 - - if not edge_counts: - return [] - - ranked = sorted(edge_counts.items(), key=lambda kv: (-kv[1], kv[0])) - if len(ranked) > _FK_EXPANSION_MAX_TABLES: - logger.warning( - "fk expansion cap hit", - client_id=client_id, - total=len(ranked), - cap=_FK_EXPANSION_MAX_TABLES, - dropped=[t for t, _ in ranked[_FK_EXPANSION_MAX_TABLES:]], - ) - ranked = ranked[:_FK_EXPANSION_MAX_TABLES] - - related = [t for t, _ in ranked] - logger.info("fk-related tables", hit=sorted(hit_set), related=related) - return related - - async def _fetch_abbreviated_schema( - self, - client_id: str, - user_id: str, - table_names: list[str], - ) -> 
dict[str, dict[str, Any]]: - """Abbreviated schema: name, row_count, PK, FKs, column names — no profiles. - - Prefers chunk_level='table' rows. Falls back to aggregating column-chunk - metadata when table chunks are missing for a given table_name. - - Returns {table_name: {"row_count": int|None, "primary_key": [str], - "foreign_keys": [{column, target_table, target_column}], - "column_names": [str]}}. - """ - if not table_names: - return {} - - placeholders = ", ".join(f":t{i}" for i in range(len(table_names))) - params: dict[str, Any] = {"user_id": user_id, "client_id": client_id} - for i, name in enumerate(table_names): - params[f"t{i}"] = name - - # Primary path: one row per table from chunk_level='table' - sql_table = text(f""" - SELECT lpe.cmetadata - FROM langchain_pg_embedding lpe - JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid - WHERE lpc.name = 'document_embeddings' - AND lpe.cmetadata->>'user_id' = :user_id - AND lpe.cmetadata->>'source_type' = 'database' - AND lpe.cmetadata->>'database_client_id' = :client_id - AND lpe.cmetadata->>'chunk_level' = 'table' - AND lpe.cmetadata->'data'->>'table_name' IN ({placeholders}) - """) - async with _pgvector_engine.connect() as conn: - result = await conn.execute(sql_table, params) - t_rows = result.fetchall() - - out: dict[str, dict[str, Any]] = {} - for row in t_rows: - data = row.cmetadata.get("data", {}) - tname = data.get("table_name") - if not tname: - continue - out[tname] = { - "row_count": data.get("row_count"), - "primary_key": list(data.get("primary_key") or []), - "foreign_keys": list(data.get("foreign_keys") or []), - "column_names": list(data.get("column_names") or []), - } - - # Fallback for tables with no table-chunk: aggregate column chunks - missing = [t for t in table_names if t not in out] - if missing: - placeholders_m = ", ".join(f":m{i}" for i in range(len(missing))) - params_m: dict[str, Any] = {"user_id": user_id, "client_id": client_id} - for i, name in enumerate(missing): 
- params_m[f"m{i}"] = name - sql_col = text(f""" - SELECT lpe.cmetadata - FROM langchain_pg_embedding lpe - JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid - WHERE lpc.name = 'document_embeddings' - AND lpe.cmetadata->>'user_id' = :user_id - AND lpe.cmetadata->>'source_type' = 'database' - AND lpe.cmetadata->>'database_client_id' = :client_id - AND lpe.cmetadata->>'chunk_level' = 'column' - AND lpe.cmetadata->'data'->>'table_name' IN ({placeholders_m}) - ORDER BY lpe.cmetadata->'data'->>'table_name', lpe.cmetadata->'data'->>'column_name' - """) - async with _pgvector_engine.connect() as conn: - result = await conn.execute(sql_col, params_m) - c_rows = result.fetchall() - - agg: dict[str, dict[str, Any]] = { - t: {"row_count": None, "primary_key": [], "foreign_keys": [], "column_names": []} - for t in missing - } - for row in c_rows: - data = row.cmetadata.get("data", {}) - tname = data.get("table_name") - cname = data.get("column_name") - if not tname or tname not in agg or not cname: - continue - bucket = agg[tname] - bucket["column_names"].append(cname) - if data.get("is_primary_key"): - bucket["primary_key"].append(cname) - fk = data.get("foreign_key") - if fk: - target_table, _, target_col = fk.partition(".") - bucket["foreign_keys"].append({ - "column": cname, - "target_table": target_table, - "target_column": target_col, - }) - for t, v in agg.items(): - if v["column_names"]: - out[t] = v - - return out - - async def _fetch_full_schema( - self, - client_id: str, - table_names: list[str], - user_id: str, - ) -> dict[str, list[dict[str, Any]]]: - """Fetch ALL column chunks for the given tables from PGVector. 
- - Returns {table_name: [{"name": ..., "type": ..., "is_primary_key": ..., - "foreign_key": ..., "content": ...}]} - """ - placeholders = ", ".join(f":t{i}" for i in range(len(table_names))) - sql = text(f""" - SELECT lpe.cmetadata, lpe.document - FROM langchain_pg_embedding lpe - JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid - WHERE lpc.name = 'document_embeddings' - AND lpe.cmetadata->>'user_id' = :user_id - AND lpe.cmetadata->>'source_type' = 'database' - AND lpe.cmetadata->>'chunk_level' = 'column' - AND lpe.cmetadata->>'database_client_id' = :client_id - AND lpe.cmetadata->'data'->>'table_name' IN ({placeholders}) - ORDER BY lpe.cmetadata->'data'->>'table_name', lpe.cmetadata->'data'->>'column_name' - """) - - params: dict[str, Any] = {"user_id": user_id, "client_id": client_id} - for i, name in enumerate(table_names): - params[f"t{i}"] = name - - async with _pgvector_engine.connect() as conn: - result = await conn.execute(sql, params) - rows = result.fetchall() - - schema: dict[str, list[dict[str, Any]]] = defaultdict(list) - for row in rows: - data = row.cmetadata.get("data", {}) - table = data.get("table_name") - if table: - schema[table].append({ - "name": data.get("column_name", ""), - "type": data.get("column_type", ""), - "is_primary_key": data.get("is_primary_key", False), - "foreign_key": data.get("foreign_key"), - "content": row.document, # chunk text includes top values / samples - }) - return dict(schema) - - def _build_schema_context( - self, - schema: dict[str, list[dict[str, Any]]], - related_schema: dict[str, dict[str, Any]] | None = None, - ) -> str: - lines: list[str] = [] - for table, columns in schema.items(): - lines.append(f"Table: {table}") - for col in columns: - flags = [] - if col["is_primary_key"]: - flags.append("PRIMARY KEY") - if col["foreign_key"]: - flags.append(f"FK -> {col['foreign_key']}") - flag_str = f" [{', '.join(flags)}]" if flags else "" - lines.append(f" - {col['name']} {col['type']}{flag_str}") - # 
Include sample/top-values line from chunk content if present - for line in col["content"].splitlines(): - if line.startswith(("Top values:", "Sample values:")): - lines.append(f" {line}") - break - lines.append("") - - related_block = self._build_related_schema_block(related_schema or {}) - if related_block: - lines.append(related_block) - - return "\n".join(lines).strip() - - def _build_related_schema_block(self, related_schema: dict[str, dict[str, Any]]) -> str: - """Format the abbreviated FK-related-tables section. Empty string when no related.""" - if not related_schema: - return "" - lines: list[str] = ["Related tables (one hop via FK, abbreviated — use for JOINs only):"] - for table, info in related_schema.items(): - row_count = info.get("row_count") - header = f"- {table} ({row_count} rows)" if row_count is not None else f"- {table}" - lines.append(header) - pk = info.get("primary_key") or [] - lines.append(f" Primary key: {', '.join(pk) if pk else '(none)'}") - fks = info.get("foreign_keys") or [] - if fks: - fk_strs = [ - f"{fk.get('column')} -> {fk.get('target_table')}.{fk.get('target_column')}" - for fk in fks - ] - lines.append(f" Foreign keys: {', '.join(fk_strs)}") - else: - lines.append(" Foreign keys: (none)") - cols = info.get("column_names") or [] - lines.append(f" Columns: {', '.join(cols)}") - return "\n".join(lines) - - # ------------------------------------------------------------------ - # Guardrails - # ------------------------------------------------------------------ - - def _validate( - self, - sql: str, - allowed_tables: set[str], - limit: int, - column_map: dict[str, set[str]] | None = None, - ) -> str: - """Return an error string if validation fails, empty string if OK. - - `allowed_tables` is the union of hit-table names and FK-related table - names — both are legal targets for SELECT/JOIN. - - `column_map` maps table_name → set of valid column names. 
When provided, - any qualified table.column reference not found in the map triggers a retry - with an informative error so the LLM can self-correct without hallucinating. - """ - # Layer 1: sqlglot parse + SELECT-only check - try: - parsed = sqlglot.parse_one(sql) - except sqlglot.errors.ParseError as e: - return f"SQL parse error: {e}" - - if not isinstance(parsed, exp.Select): - return f"Only SELECT statements are allowed. Got: {type(parsed).__name__}" - - # Check for DML anywhere in the AST (including writeable CTEs) - for node in parsed.find_all((exp.Insert, exp.Update, exp.Delete)): - return f"DML ({type(node).__name__}) is not allowed." - - # Layer 2: schema grounding — table names - known_tables = {t.lower() for t in allowed_tables} - alias_to_table: dict[str, str] = {} - for tbl in parsed.find_all(exp.Table): - name = tbl.name.lower() - if name and name not in known_tables: - return f"Unknown table '{tbl.name}'. Only use tables from the schema." - alias = (tbl.alias or tbl.name).lower() - alias_to_table[alias] = name - - # Layer 3: column grounding — qualified references only (table.column) - if column_map: - normalized_map = {t.lower(): {c.lower() for c in cols} for t, cols in column_map.items()} - for col_node in parsed.find_all(exp.Column): - tbl_ref = col_node.table - if not tbl_ref: - continue # unqualified — skip, can't resolve without full alias tracking - tbl_name = alias_to_table.get(tbl_ref.lower(), tbl_ref.lower()) - col_name = col_node.name.lower() - if tbl_name in normalized_map and col_name not in normalized_map[tbl_name]: - available = ", ".join(sorted(normalized_map[tbl_name])) - return ( - f"Column '{col_node.name}' does not exist on table '{tbl_name}'. " - f"Available columns: {available}." 
- ) - - # Layer 4: LIMIT enforcement (inject if missing — done before execution) - return "" - - # ------------------------------------------------------------------ - # SQL execution - # ------------------------------------------------------------------ - - def _enforce_limit(self, sql: str, limit: int) -> str: - """Inject or cap LIMIT using sqlglot AST manipulation.""" - parsed = sqlglot.parse_one(sql) - existing = parsed.find(exp.Limit) - if existing: - current = int(existing.expression.this) - if current > limit: - return parsed.limit(limit).sql() - else: - return parsed.limit(limit).sql() - return parsed.sql() - - def _run_sql(self, engine: Any, sql: str) -> list[dict]: - # Ensure the user DB connection is a read-only credential — sqlglot validation alone is not sufficient. - with engine.connect() as conn: - result = conn.execute(text(sql)) - return [dict(row) for row in result.mappings()] - - -db_executor = DbExecutor() diff --git a/src/query/executors/tabular.py b/src/query/executors/tabular.py deleted file mode 100644 index 3afee6105da19782fb8249d700947419c42ebf95..0000000000000000000000000000000000000000 --- a/src/query/executors/tabular.py +++ /dev/null @@ -1,287 +0,0 @@ -"""Executor for tabular document sources (source_type="document", file_type csv/xlsx). - -Flow: - 1. Group RetrievalResult chunks by (document_id, sheet_name). - 2. Per group: download Parquet from Azure Blob → pandas DataFrame. - 3. Build schema context from DataFrame columns + sample values. - 4. LLM decides operation (groupby_sum, filter, top_n, etc.) via structured output. - 5. Pandas runs the operation; retry up to 3x on error with feedback to LLM. - 6. Fallback to raw rows if all retries fail. - 7. Return QueryResult per group. 
-""" -import asyncio -from typing import Literal, TypedDict - -import pandas as pd -from langchain_core.prompts import ChatPromptTemplate -from langchain_openai import AzureChatOpenAI -from pydantic import BaseModel -from sqlalchemy.ext.asyncio import AsyncSession - -from src.config.settings import settings -from src.knowledge.parquet_service import download_parquet -from src.middlewares.logging import get_logger -from src.query.base import BaseExecutor, QueryResult -from src.rag.base import RetrievalResult - -logger = get_logger("tabular_executor") - - -class _GroupInfo(TypedDict): - filename: str - file_type: str - - -_TABULAR_FILE_TYPES = ("csv", "xlsx") -_MAX_RETRIES = 3 - -_SYSTEM_PROMPT = """\ -You are a data analyst. Given a DataFrame schema and a user question, \ -decide which pandas operation to perform. - -IMPORTANT rules: -- Use ONLY the exact column names as written in the schema below. Never translate or rename them. -- For top_n: always set value_col to the column to sort by. Do NOT use sort_col for top_n. -- For sort: use sort_col for the column to sort by. -- For filter with comparison (>, <, >=, <=, !=): set filter_operator accordingly (gt, lt, gte, lte, ne). Default is eq (==). -- For multi-condition filters (AND logic), use the filters field as a list of {{"col", "value", "op"}} dicts instead of filter_col/filter_value. - Example: status=SUCCESS AND amount_paid>200000 → filters=[{{"col":"status","value":"SUCCESS","op":"eq"}},{{"col":"amount_paid","value":"200000","op":"gt"}}] -- For OR conditions on a column (e.g. value is A or B), use or_filters. Combine with filters for mixed AND+OR logic. - Example: (status=FAILED OR status=REVERSED) AND payment_channel=X → or_filters=[{{"col":"status","value":"FAILED","op":"eq"}},{{"col":"status","value":"REVERSED","op":"eq"}}], filters=[{{"col":"payment_channel","value":"X","op":"eq"}}] -- For groupby with a pre-filter (e.g. 
count SUCCESS per channel): use filters or or_filters to narrow rows first, then use groupby_count/groupby_sum/groupby_avg on the filtered data by setting both filters and group_col. - -Schema: -{schema} - -{error_section}""" - - -class TabularOperation(BaseModel): - operation: Literal[ - "filter", "groupby_sum", "groupby_avg", "groupby_count", - "top_n", "sort", "aggregate", "raw" - ] - group_col: str | None = None # for groupby_* - value_col: str | None = None # for groupby_*, top_n, aggregate - filter_col: str | None = None # for single filter - filter_value: str | None = None # for single filter - filter_operator: Literal["eq", "ne", "gt", "gte", "lt", "lte"] = "eq" # for single filter - filters: list[dict] | None = None # for multi-condition AND: [{"col": ..., "value": ..., "op": ...}] - or_filters: list[dict] | None = None # for OR conditions, applied before AND filters - sort_col: str | None = None # for sort - ascending: bool = True # for sort - n: int | None = None # for top_n - agg_func: Literal["sum", "avg", "min", "max", "count"] | None = None # for aggregate - reasoning: str - - -def _get_filter_mask(df: pd.DataFrame, col: str, value: str, operator: str) -> pd.Series: - numeric = pd.to_numeric(df[col], errors="coerce") - if operator == "eq": - return df[col].astype(str) == str(value) - elif operator == "ne": - return df[col].astype(str) != str(value) - elif operator == "gt": - return numeric > float(value) - elif operator == "gte": - return numeric >= float(value) - elif operator == "lt": - return numeric < float(value) - elif operator == "lte": - return numeric <= float(value) - raise ValueError(f"Unknown operator: {operator}") - - -def _apply_single_filter(df: pd.DataFrame, col: str, value: str, operator: str) -> pd.DataFrame: - return df[_get_filter_mask(df, col, value, operator)] - - -def _build_schema_context(df: pd.DataFrame) -> str: - lines = [] - for col in df.columns: - sample = df[col].dropna().head(3).tolist() - lines.append(f"- {col} 
({df[col].dtype}): sample values: {sample}") - return "\n".join(lines) - - -def _apply_operation(df: pd.DataFrame, op: TabularOperation, limit: int) -> pd.DataFrame: - if op.operation == "groupby_sum": - if not op.group_col or not op.value_col: - raise ValueError(f"groupby_sum requires group_col and value_col, got {op}") - return df.groupby(op.group_col)[op.value_col].sum().reset_index().nlargest(limit, op.value_col) - elif op.operation == "groupby_avg": - if not op.group_col or not op.value_col: - raise ValueError(f"groupby_avg requires group_col and value_col, got {op}") - return df.groupby(op.group_col)[op.value_col].mean().reset_index().nlargest(limit, op.value_col) - elif op.operation == "groupby_count": - if not op.group_col: - raise ValueError(f"groupby_count requires group_col, got {op}") - df_filtered = df.copy() - if op.or_filters: - or_mask = pd.Series([False] * len(df_filtered), index=df_filtered.index) - for f in op.or_filters: - or_mask = or_mask | _get_filter_mask(df_filtered, f["col"], f["value"], f.get("op", "eq")) - df_filtered = df_filtered[or_mask] - if op.filters: - for f in op.filters: - df_filtered = _apply_single_filter(df_filtered, f["col"], f["value"], f.get("op", "eq")) - elif op.filter_col and op.filter_value is not None: - df_filtered = _apply_single_filter(df_filtered, op.filter_col, op.filter_value, op.filter_operator) - return df_filtered.groupby(op.group_col).size().reset_index(name="count").nlargest(limit, "count") - elif op.operation == "filter": - result = df.copy() - if op.or_filters: - or_mask = pd.Series([False] * len(result), index=result.index) - for f in op.or_filters: - or_mask = or_mask | _get_filter_mask(result, f["col"], f["value"], f.get("op", "eq")) - result = result[or_mask] - if op.filters: - for f in op.filters: - result = _apply_single_filter(result, f["col"], f["value"], f.get("op", "eq")) - elif op.filter_col and op.filter_value is not None and not op.or_filters: - result = _apply_single_filter(result, 
op.filter_col, op.filter_value, op.filter_operator) - elif not op.or_filters and not op.filters and (not op.filter_col or op.filter_value is None): - raise ValueError(f"filter requires filter_col/filter_value or filters or or_filters, got {op}") - return result.head(limit) - elif op.operation == "top_n": - col = op.value_col - if not col: - raise ValueError(f"top_n requires value_col, got {op}") - n = op.n or limit - return df.nlargest(n, col) - elif op.operation == "sort": - if not op.sort_col: - raise ValueError(f"sort requires sort_col, got {op}") - return df.sort_values(op.sort_col, ascending=op.ascending).head(limit) - elif op.operation == "aggregate": - if not op.value_col or not op.agg_func: - raise ValueError(f"aggregate requires value_col and agg_func, got {op}") - funcs = {"sum": "sum", "avg": "mean", "min": "min", "max": "max", "count": "count"} - value = getattr(df[op.value_col], funcs[op.agg_func])() - return pd.DataFrame([{op.value_col: value, "operation": op.agg_func}]) - else: # "raw" - return df.head(limit) - - -class TabularExecutor(BaseExecutor): - def __init__(self) -> None: - self._llm = AzureChatOpenAI( - azure_deployment=settings.azureai_deployment_name_4o, - openai_api_version=settings.azureai_api_version_4o, - azure_endpoint=settings.azureai_endpoint_url_4o, - api_key=settings.azureai_api_key_4o, - temperature=0, - ) - self._prompt = ChatPromptTemplate.from_messages([ - ("system", _SYSTEM_PROMPT), - ("human", "{question}"), - ]) - self._chain = self._prompt | self._llm.with_structured_output(TabularOperation) - - async def execute( - self, - results: list[RetrievalResult], - user_id: str, - _db: AsyncSession, - question: str, - limit: int = 100, - ) -> list[QueryResult]: - tabular = [ - r for r in results - if r.source_type == "document" - and r.metadata.get("data", {}).get("file_type") in _TABULAR_FILE_TYPES - ] - - if not tabular: - return [] - - # Group by (document_id, sheet_name) — one parquet download per group - groups: 
dict[tuple[str, str | None], _GroupInfo] = {} - for r in tabular: - data = r.metadata.get("data", {}) - doc_id = data.get("document_id") - if not doc_id: - continue - sheet_name = data.get("sheet_name") # None for CSV - key = (doc_id, sheet_name) - if key not in groups: - groups[key] = { - "filename": data.get("filename", ""), - "file_type": data.get("file_type", ""), - } - - async def _process_group( - doc_id: str, sheet_name: str | None, info: _GroupInfo - ) -> QueryResult | None: - try: - df = await download_parquet(user_id, doc_id, sheet_name) - df_result = await self._query_with_agent(df, question, limit) - - table_label = info["filename"] - if sheet_name: - table_label += f" / sheet: {sheet_name}" - - logger.info( - "tabular query complete", - document_id=doc_id, - sheet=sheet_name, - file_type=info["file_type"], - rows=len(df_result), - columns=len(df_result.columns), - ) - return QueryResult( - source_type="document", - source_id=doc_id, - table_or_file=table_label, - columns=list(df_result.columns), - rows=df_result.to_dict(orient="records"), - row_count=len(df_result), - ) - except Exception as e: - logger.error( - "tabular query failed", - document_id=doc_id, - sheet=sheet_name, - error=str(e), - ) - return None - - gathered = await asyncio.gather(*[ - _process_group(doc_id, sheet_name, info) - for (doc_id, sheet_name), info in groups.items() - ]) - return [r for r in gathered if r is not None] - - async def _query_with_agent( - self, df: pd.DataFrame, question: str, limit: int - ) -> pd.DataFrame: - schema_ctx = _build_schema_context(df) - prev_error = "" - - for attempt in range(_MAX_RETRIES): - error_section = ( - f"Previous attempt failed: {prev_error}\nFix the issue." 
- if prev_error else "" - ) - try: - op: TabularOperation = await self._chain.ainvoke({ - "schema": schema_ctx, - "error_section": error_section, - "question": question, - }) - logger.info( - "tabular operation decided", - operation=op.operation, - reasoning=op.reasoning, - ) - return _apply_operation(df, op, limit) - except Exception as e: - prev_error = str(e) - logger.warning("tabular agent error", attempt=attempt + 1, error=prev_error) - - # Fallback: return raw rows - logger.warning("tabular agent failed after retries, returning raw rows") - return df.head(limit) - - -tabular_executor = TabularExecutor() diff --git a/src/query/ir/__init__.py b/src/query/ir/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..324639d5767dae50cf7f5894ffe2b10acd58b2c2 --- /dev/null +++ b/src/query/ir/__init__.py @@ -0,0 +1 @@ +"""JSON IR (intermediate representation) for catalog-driven queries.""" diff --git a/src/query/ir/models.py b/src/query/ir/models.py new file mode 100644 index 0000000000000000000000000000000000000000..3fb50f8410e81db66205dbd78350b145f7350d44 --- /dev/null +++ b/src/query/ir/models.py @@ -0,0 +1,59 @@ +"""JSON IR (intermediate representation) Pydantic models. + +See ARCHITECTURE.md §7 for the schema. + +Initial scope: single-table; filter, group_by, agg, order_by, limit. +Joins, having, offset, boolean tree filters are deferred to later versions. 
+""" + +from typing import Any, Literal + +from pydantic import BaseModel, Field + +FilterOp = Literal[ + "=", "!=", "<", "<=", ">", ">=", + "in", "not_in", "is_null", "is_not_null", + "like", "between", +] +AggFn = Literal["count", "count_distinct", "sum", "avg", "min", "max"] +ValueType = Literal["int", "decimal", "string", "datetime", "date", "bool"] +SortDir = Literal["asc", "desc"] + + +class ColumnSelect(BaseModel): + kind: Literal["column"] = "column" + column_id: str + alias: str | None = None + + +class AggSelect(BaseModel): + kind: Literal["agg"] = "agg" + fn: AggFn + column_id: str | None = None + alias: str | None = None + + +SelectItem = ColumnSelect | AggSelect + + +class FilterClause(BaseModel): + column_id: str + op: FilterOp + value: Any + value_type: ValueType + + +class OrderByClause(BaseModel): + column_id: str + dir: SortDir = "asc" + + +class QueryIR(BaseModel): + ir_version: str = "1.0" + source_id: str + table_id: str + select: list[SelectItem] + filters: list[FilterClause] = Field(default_factory=list) + group_by: list[str] = Field(default_factory=list) + order_by: list[OrderByClause] = Field(default_factory=list) + limit: int | None = None diff --git a/src/query/ir/operators.py b/src/query/ir/operators.py new file mode 100644 index 0000000000000000000000000000000000000000..3415cf8de317cdf8f0b1b03815ae1e7f13cec1dd --- /dev/null +++ b/src/query/ir/operators.py @@ -0,0 +1,27 @@ +"""Whitelisted operators + aggregation functions for IR validation.""" + +ALLOWED_FILTER_OPS = frozenset({ + "=", "!=", "<", "<=", ">", ">=", + "in", "not_in", "is_null", "is_not_null", + "like", "between", +}) + +ALLOWED_AGG_FNS = frozenset({ + "count", "count_distinct", "sum", "avg", "min", "max", +}) + +LIMIT_HARD_CAP = 10_000 + +# Type compatibility: which value_types may appear in a FilterClause when the +# referenced column has the given data_type. Numeric types are mutually +# compatible (decimal literal against int column is fine). 
Date/datetime accept +# string so the planner can emit ISO-8601 literals without mode juggling. +TYPE_COMPATIBILITY: dict[str, frozenset[str]] = { + "int": frozenset({"int", "decimal"}), + "decimal": frozenset({"int", "decimal"}), + "string": frozenset({"string"}), + "datetime": frozenset({"datetime", "date", "string"}), + "date": frozenset({"date", "datetime", "string"}), + "bool": frozenset({"bool"}), + "json": frozenset({"string"}), +} diff --git a/src/query/ir/validator.py b/src/query/ir/validator.py new file mode 100644 index 0000000000000000000000000000000000000000..c6d05abfdbdb8d2efa2f0fbfeaaa6cf86e7edff4 --- /dev/null +++ b/src/query/ir/validator.py @@ -0,0 +1,129 @@ +"""IRValidator — checks a QueryIR against a user's catalog. + +See ARCHITECTURE.md §7 for the validation rules. On failure, the planner +is re-prompted with the error context (max 3 retries) — error messages +must therefore be specific enough that the LLM can self-correct. +""" + +from ...catalog.models import Catalog, Column, Source, Table +from .models import QueryIR +from .operators import ( + ALLOWED_AGG_FNS, + ALLOWED_FILTER_OPS, + LIMIT_HARD_CAP, + TYPE_COMPATIBILITY, +) + +_NULLARY_FILTER_OPS = frozenset({"is_null", "is_not_null"}) + + +class IRValidationError(Exception): + pass + + +class IRValidator: + """Reject IRs that reference unknown sources/tables/columns or use disallowed ops. 
+ + Rules: + - source_id exists in catalog for this user + - table_id belongs to that source + - every column_id exists in that table + - every agg.fn and filter.op is whitelisted (see operators.py) + - value_type consistent with column.data_type (TYPE_COMPATIBILITY) + - limit positive int, ≤ LIMIT_HARD_CAP + """ + + def validate(self, ir: QueryIR, catalog: Catalog) -> None: + source = self._find_source(catalog, ir.source_id) + table = self._find_table(source, ir.table_id) + columns_by_id: dict[str, Column] = {c.column_id: c for c in table.columns} + + select_aliases: set[str] = set() + for i, item in enumerate(ir.select): + where = f"select[{i}]" + if item.kind == "column": + self._require_column(columns_by_id, item.column_id, where) + else: # "agg" + if item.fn not in ALLOWED_AGG_FNS: + raise IRValidationError( + f"{where}.fn: must be in {sorted(ALLOWED_AGG_FNS)}, " + f"got {item.fn!r}" + ) + if item.column_id is not None: + self._require_column(columns_by_id, item.column_id, where) + elif item.fn != "count": + raise IRValidationError( + f"{where}.fn={item.fn!r} requires a column_id " + "(only 'count' may omit it for COUNT(*))" + ) + if item.alias: + select_aliases.add(item.alias) + + for i, f in enumerate(ir.filters): + where = f"filters[{i}]" + col = self._require_column(columns_by_id, f.column_id, where) + if f.op not in ALLOWED_FILTER_OPS: + raise IRValidationError( + f"{where}.op: must be in {sorted(ALLOWED_FILTER_OPS)}, " + f"got {f.op!r}" + ) + if f.op not in _NULLARY_FILTER_OPS: + allowed = TYPE_COMPATIBILITY.get(col.data_type, frozenset()) + if f.value_type not in allowed: + raise IRValidationError( + f"{where}: value_type {f.value_type!r} incompatible with " + f"column.data_type {col.data_type!r} " + f"(allowed: {sorted(allowed)})" + ) + + for i, col_id in enumerate(ir.group_by): + self._require_column(columns_by_id, col_id, f"group_by[{i}]") + + for i, ob in enumerate(ir.order_by): + if ob.column_id not in columns_by_id and ob.column_id not in 
select_aliases: + raise IRValidationError( + f"order_by[{i}].column_id: {ob.column_id!r} not found in table " + f"{ir.table_id!r} columns or select aliases " + f"(known columns: {sorted(columns_by_id.keys())}, " + f"aliases: {sorted(select_aliases)})" + ) + + if ir.limit is not None: + if ir.limit <= 0: + raise IRValidationError(f"limit must be positive, got {ir.limit}") + if ir.limit > LIMIT_HARD_CAP: + raise IRValidationError( + f"limit {ir.limit} exceeds hard cap {LIMIT_HARD_CAP}" + ) + + @staticmethod + def _find_source(catalog: Catalog, source_id: str) -> Source: + for s in catalog.sources: + if s.source_id == source_id: + return s + raise IRValidationError( + f"source_id {source_id!r} not in catalog " + f"(known: {[s.source_id for s in catalog.sources]})" + ) + + @staticmethod + def _find_table(source: Source, table_id: str) -> Table: + for t in source.tables: + if t.table_id == table_id: + return t + raise IRValidationError( + f"table_id {table_id!r} not in source {source.source_id!r} " + f"(known: {[t.table_id for t in source.tables]})" + ) + + @staticmethod + def _require_column( + columns_by_id: dict[str, Column], col_id: str, where: str + ) -> Column: + col = columns_by_id.get(col_id) + if col is None: + raise IRValidationError( + f"{where}.column_id: {col_id!r} not in table " + f"(known: {sorted(columns_by_id.keys())})" + ) + return col diff --git a/src/query/planner/__init__.py b/src/query/planner/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fec06fecc2a602d3792698d854f7eb6aead41a92 --- /dev/null +++ b/src/query/planner/__init__.py @@ -0,0 +1 @@ +"""LLM-based query planner — turns user questions + catalog into JSON IR.""" diff --git a/src/query/planner/prompt.py b/src/query/planner/prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..b50ce43c8f0cc5903e360ffe4ebff5c52bdfec3d --- /dev/null +++ b/src/query/planner/prompt.py @@ -0,0 +1,45 @@ +"""Builds the planner LLM prompt from question + 
catalog. + +Renders the catalog into a compact textual form that fits the LLM context +window. For users with ≤50 tables the full catalog goes in verbatim. +""" + +from __future__ import annotations + +from ...catalog.models import Catalog +from ...catalog.render import render_source + + +def render_catalog(catalog: Catalog) -> str: + """Render every Source in the catalog as text. One blank line between sources.""" + if not catalog.sources: + return "(catalog is empty — the user has not registered any structured data yet)" + return "\n\n".join(render_source(s) for s in catalog.sources) + + +def build_planner_prompt( + question: str, + catalog: Catalog, + previous_error: str | None = None, +) -> str: + """Return the human-message content for the planner LLM. + + Composed of three sections in order: + 1. The user's question. + 2. The user's full catalog (rendered). + 3. (optional) The previous attempt's error, on retry. + + The system prompt (`config/prompts/query_planner.md`) is loaded + separately by `QueryPlannerService`. + """ + sections = [ + f"# Question\n\n{question}", + f"# Catalog\n\n{render_catalog(catalog)}", + ] + if previous_error: + sections.append( + "# Previous attempt failed validation\n\n" + f"{previous_error}\n\n" + "Emit a corrected IR. Do not repeat the same mistake." + ) + return "\n\n".join(sections) diff --git a/src/query/planner/service.py b/src/query/planner/service.py new file mode 100644 index 0000000000000000000000000000000000000000..7316b1c5d64d34b59ad68fad6a08991cd19a8829 --- /dev/null +++ b/src/query/planner/service.py @@ -0,0 +1,101 @@ +"""QueryPlannerService — single LLM call: question + catalog → JSON IR. + +Prompt: src/config/prompts/query_planner.md (system) + the human content +built by `prompt.build_planner_prompt(...)`. + +Output: a QueryIR ready for the IRValidator. Validation + retry are the +caller's concern (`QueryService` orchestrates that loop). 
+""" + +from __future__ import annotations + +from pathlib import Path + +from langchain_core.messages import SystemMessage +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.runnables import Runnable +from langchain_openai import AzureChatOpenAI + +from src.middlewares.logging import get_logger + +from ...catalog.models import Catalog +from ..ir.models import QueryIR +from .prompt import build_planner_prompt + +logger = get_logger("query_planner") + +_PROMPT_PATH = ( + Path(__file__).resolve().parent.parent.parent + / "config" + / "prompts" + / "query_planner.md" +) + + +def _load_prompt_text() -> str: + return _PROMPT_PATH.read_text(encoding="utf-8") + + +def _build_default_chain() -> Runnable: + from src.config.settings import settings + + llm = AzureChatOpenAI( + azure_deployment=settings.azureai_deployment_name_4o, + openai_api_version=settings.azureai_api_version_4o, + azure_endpoint=settings.azureai_endpoint_url_4o, + api_key=settings.azureai_api_key_4o, + temperature=0, + ) + prompt = ChatPromptTemplate.from_messages( + [ + SystemMessage(content=_load_prompt_text()), + ("human", "{human_content}"), + ] + ) + return prompt | llm.with_structured_output(QueryIR) + + +_default_chain: Runnable | None = None + + +def _get_default_chain() -> Runnable: + global _default_chain + if _default_chain is None: + _default_chain = _build_default_chain() + return _default_chain + + +class QueryPlannerService: + """Wraps the LLM call with structured-output parsing into QueryIR. + + Inject `structured_chain` for tests. The planner prompt is composed + by `build_planner_prompt(question, catalog, previous_error)` so retry + callers can append the prior error context to nudge the LLM. 
+ """ + + def __init__(self, structured_chain: Runnable | None = None) -> None: + self._chain = structured_chain + + def _ensure_chain(self) -> Runnable: + if self._chain is None: + self._chain = _get_default_chain() + return self._chain + + async def plan( + self, + question: str, + catalog: Catalog, + previous_error: str | None = None, + ) -> QueryIR: + human_content = build_planner_prompt(question, catalog, previous_error) + chain = self._ensure_chain() + ir: QueryIR = await chain.ainvoke({"human_content": human_content}) + logger.info( + "query planned", + source_id=ir.source_id, + table_id=ir.table_id, + select_n=len(ir.select), + filters_n=len(ir.filters), + retry=previous_error is not None, + ) + return ir diff --git a/src/query/query_executor.py b/src/query/query_executor.py deleted file mode 100644 index 824f51a8c1dc7173562783ffc436e5e3e9c088d7..0000000000000000000000000000000000000000 --- a/src/query/query_executor.py +++ /dev/null @@ -1,42 +0,0 @@ -"""QueryExecutor — dispatches retrieval results to the appropriate executor by source_type.""" - -import asyncio - -from sqlalchemy.ext.asyncio import AsyncSession - -from src.middlewares.logging import get_logger -from src.query.base import QueryResult -from src.query.executors.db_executor import db_executor -from src.query.executors.tabular import tabular_executor -from src.rag.base import RetrievalResult - -logger = get_logger("query_executor") - - -class QueryExecutor: - async def execute( - self, - results: list[RetrievalResult], - user_id: str, - db: AsyncSession, - question: str, - limit: int = 100, - ) -> list[QueryResult]: - batches = await asyncio.gather( - db_executor.execute(results, user_id, db, question, limit), - tabular_executor.execute(results, user_id, db, question, limit), - return_exceptions=True, - ) - - query_results: list[QueryResult] = [] - for batch in batches: - if isinstance(batch, Exception): - logger.error("executor failed", error=str(batch)) - continue - 
query_results.extend(batch) - - logger.info("query execution complete", total=len(query_results)) - return query_results - - -query_executor = QueryExecutor() diff --git a/src/query/service.py b/src/query/service.py new file mode 100644 index 0000000000000000000000000000000000000000..3adc804de852423e468191dfc9c2b9255a4223e4 --- /dev/null +++ b/src/query/service.py @@ -0,0 +1,138 @@ +"""QueryService — orchestrates plan → validate → compile → execute. + +Top-level entry point for catalog-driven structured queries. Wired into +the chat endpoint when source_hint == "structured". + +Flow per call: + 1. Plan (LLM): question + catalog → QueryIR + 2. Validate IR against catalog. On failure, re-prompt the planner with the + error context and retry (up to `max_retries` total attempts). + 3. Dispatch IR to the right executor by `source.source_type`. + 4. Execute. Any exception (including NotImplementedError from the + TabularExecutor placeholder) is caught and surfaced via + `QueryResult.error` so the chatbot can branch on success / failure. + +The service never raises — every code path returns a `QueryResult`. +""" + +from __future__ import annotations + +from collections.abc import Callable + +from src.middlewares.logging import get_logger + +from ..catalog.models import Catalog +from .executor.base import QueryResult +from .executor.dispatcher import ExecutorDispatcher +from .ir.validator import IRValidationError, IRValidator +from .planner.service import QueryPlannerService + +logger = get_logger("query_service") + + +class QueryService: + """End-to-end runner for a user question against a catalog. + + All heavy dependencies are injectable so unit tests don't need real + LLMs or DB engines. Default constructors lazy-build the production + deps so importing this module is side-effect-free. 
+ """ + + def __init__( + self, + planner: QueryPlannerService | None = None, + validator: IRValidator | None = None, + dispatcher_factory: Callable[[Catalog], ExecutorDispatcher] | None = None, + max_retries: int = 3, + ) -> None: + self._planner = planner or QueryPlannerService() + self._validator = validator or IRValidator() + self._dispatcher_factory = dispatcher_factory or ExecutorDispatcher + self._max_retries = max(1, max_retries) + + async def run(self, user_id: str, question: str, catalog: Catalog) -> QueryResult: + if not catalog.sources: + return _error_result( + source_id="", + error="No structured data registered yet — connect a database " + "or upload a CSV/XLSX before asking data questions.", + ) + + # ---------- planner + validator with retry ------------------ + previous_error: str | None = None + ir = None + for attempt in range(1, self._max_retries + 1): + try: + ir = await self._planner.plan(question, catalog, previous_error) + except Exception as e: + logger.error("planner crashed", attempt=attempt, error=str(e)) + return _error_result(source_id="", error=f"planner failed: {e}") + + try: + self._validator.validate(ir, catalog) + logger.info( + "ir planned and validated", + attempt=attempt, + source_id=ir.source_id, + table_id=ir.table_id, + select=[s.model_dump() for s in ir.select], + filters=[f.model_dump() for f in ir.filters], + group_by=ir.group_by, + ) + break + except IRValidationError as e: + previous_error = str(e) + logger.warning( + "ir validation failed", + attempt=attempt, + error=previous_error, + ) + ir = None # discard invalid IR + continue + else: + return _error_result( + source_id="", + error=( + f"Planner could not produce a valid IR after " + f"{self._max_retries} attempts. 
Last error: {previous_error}" + ), + ) + + # `ir` is non-None and valid here (guarded by the for/else above) + assert ir is not None + + # ---------- dispatch + execute ------------------------------ + try: + dispatcher = self._dispatcher_factory(catalog) + executor = dispatcher.pick(ir) + except Exception as e: + logger.error("dispatch failed", source_id=ir.source_id, error=str(e)) + return _error_result(source_id=ir.source_id, error=f"dispatch failed: {e}") + + try: + return await executor.run(ir) + except NotImplementedError as e: + # TabularExecutor placeholder — TAB owner ships PR3-TAB + logger.warning( + "executor not yet implemented", + source_id=ir.source_id, + error=str(e), + ) + return _error_result( + source_id=ir.source_id, + error="Tabular execution is not yet available — pending PR3-TAB.", + ) + except Exception as e: + logger.error("executor crashed", source_id=ir.source_id, error=str(e)) + return _error_result( + source_id=ir.source_id, error=f"executor failed: {e}" + ) + + +def _error_result(source_id: str, error: str) -> QueryResult: + """Build a uniform error QueryResult. + + `backend` is intentionally empty when the failure happens before an + executor is picked — the caller can still distinguish via `error`. 
+ """ + return QueryResult(source_id=source_id, backend="", error=error) diff --git a/src/rag/__init__.py b/src/rag/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/src/rag/retriever.py b/src/rag/retriever.py deleted file mode 100644 index ee0e704c7d130ef40c229b83d6d0b8f4ca7faa18..0000000000000000000000000000000000000000 --- a/src/rag/retriever.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Public retrieval API — thin wrapper around RetrievalRouter.""" - -from sqlalchemy.ext.asyncio import AsyncSession - -from src.middlewares.logging import get_logger -from src.rag.base import RetrievalResult -from src.rag.retrievers.document import document_retriever -from src.rag.retrievers.schema import schema_retriever -from src.rag.router import RetrievalRouter, SourceHint - -logger = get_logger("retriever") - - -class RetrieverService: - """Public retrieval service used by chat.py and search tools. - - Delegates to RetrievalRouter which dispatches based on source_hint. - Returns RetrievalResult objects directly so downstream consumers - (db_executor, tabular_executor) can be fed without lossy dict - conversion. The `db` parameter is accepted for call-site compatibility - but currently unused — retrieval reads PGVector via _pgvector_engine - inside each retriever. 
- """ - - def __init__(self): - self._router = RetrievalRouter( - schema_retriever=schema_retriever, - document_retriever=document_retriever, - ) - - async def retrieve( - self, - query: str, - user_id: str, - db: AsyncSession, - k: int = 5, - source_hint: SourceHint = "both", - ) -> list[RetrievalResult]: - try: - return await self._router.retrieve(query, user_id, source_hint, k) - except Exception as e: - logger.error("retrieval failed", error=str(e)) - return [] - - -retriever = RetrieverService() diff --git a/src/rag/retrievers/__init__.py b/src/rag/retrievers/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/src/rag/retrievers/baseline.py b/src/rag/retrievers/baseline.py deleted file mode 100644 index 029e76a31ce9910891dc53856fa3c19f21223007..0000000000000000000000000000000000000000 --- a/src/rag/retrievers/baseline.py +++ /dev/null @@ -1,76 +0,0 @@ -"""Service for retrieving relevant documents from vector store.""" - -import hashlib -import json -from src.db.postgres.vector_store import get_vector_store -from src.db.redis.connection import get_redis -from sqlalchemy.ext.asyncio import AsyncSession -from src.middlewares.logging import get_logger -from typing import List, Dict, Any - -logger = get_logger("retriever") - -_RETRIEVAL_CACHE_TTL = 3600 # 1 hour - - -class BaselineRetrieverService: - """Baseline (pre-Phase-1) retriever — preserved for benchmark comparison. - - Renamed from RetrieverService so it doesn't shadow the production wrapper - at src/rag/retriever.py. Production code imports from src.rag.retriever; - benchmark scripts that want this baseline must import explicitly from - src.rag.retrievers.baseline. 
- """ - - def __init__(self): - self.vector_store = get_vector_store() - - async def retrieve( - self, - query: str, - user_id: str, - db: AsyncSession, - k: int = 5 - ) -> List[Dict[str, Any]]: - """Retrieve relevant chunks for a query, scoped to the user's documents. - - Returns: - List of dicts with keys: content, metadata - metadata includes: document_id, user_id, filename, chunk_index, page_label (if PDF) - """ - try: - redis = await get_redis() - query_hash = hashlib.md5(query.encode()).hexdigest() - cache_key = f"retrieval:{user_id}:{query_hash}:{k}" - - cached = await redis.get(cache_key) - if cached: - logger.info("Returning cached retrieval results") - return json.loads(cached) - - logger.info(f"Retrieving for user {user_id}, query: {query[:50]}...") - - docs = await self.vector_store.asimilarity_search( - query=query, - k=k, - filter={"user_id": user_id} - ) - - results = [ - { - "content": doc.page_content, - "metadata": doc.metadata, - } - for doc in docs - ] - - logger.info(f"Retrieved {len(results)} chunks") - await redis.setex(cache_key, _RETRIEVAL_CACHE_TTL, json.dumps(results)) - return results - - except Exception as e: - logger.error("Retrieval failed", error=str(e)) - return [] - - -baseline_retriever = BaselineRetrieverService() \ No newline at end of file diff --git a/src/rag/retrievers/schema.py b/src/rag/retrievers/schema.py deleted file mode 100644 index 968d44c5c889a3b11678c5427b4ada29689cfd27..0000000000000000000000000000000000000000 --- a/src/rag/retrievers/schema.py +++ /dev/null @@ -1,411 +0,0 @@ -"""Schema retriever — handles DB schemas (source_type="database") and tabular file -columns stored as source_type="document" with file_type in ("csv","xlsx"). - -Strategy: hybrid_bm25 — RRF merge of dense cosine search (DB columns + DB tables -+ tabular columns + tabular sheets) and PostgreSQL full-text search (DB columns only). -Embeds the query once, fans out five legs in parallel. 
- -The DB-tables leg surfaces table-level summary chunks (chunk_level='table') as -a recall signal for multi-table questions: when a relevant table's columns -don't individually win on similarity, the table chunk can still pull the table -into the hit set, where db_executor's downstream full-schema fetch picks up -the per-column detail. - -FTS requires a GIN index on langchain_pg_embedding.document (created by init_db.py). -""" - -import asyncio - -from sqlalchemy import text - -from src.db.postgres.connection import _pgvector_engine -from src.db.postgres.vector_store import get_vector_store -from src.middlewares.logging import get_logger -from src.rag.base import BaseRetriever, RetrievalResult - -logger = get_logger("schema_retriever") - -_TABULAR_FILE_TYPES = ("csv", "xlsx") -_TABLE_CHUNK_K_MULTIPLIER = 2 # how many table chunks to pull before RRF - - -class SchemaRetriever(BaseRetriever): - def __init__(self): - self.vector_store = get_vector_store() - - # ------------------------------------------------------------------ - # Internal helpers - # ------------------------------------------------------------------ - - async def _embed_query(self, query: str) -> list[float]: - return await asyncio.to_thread(self.vector_store.embeddings.embed_query, query) - - async def _search_db( - self, embedding: list[float], user_id: str, k: int - ) -> list[RetrievalResult]: - """Cosine vector search over database chunks.""" - emb_str = "[" + ",".join(str(x) for x in embedding) + "]" - - sql = text(f""" - SELECT lpe.document, lpe.cmetadata, - 1.0 - (lpe.embedding <=> '{emb_str}'::vector) AS score - FROM langchain_pg_embedding lpe - JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid - WHERE lpc.name = 'document_embeddings' - AND lpe.cmetadata->>'user_id' = :user_id - AND lpe.cmetadata->>'source_type' = 'database' - AND lpe.cmetadata->>'chunk_level' = 'column' - ORDER BY lpe.embedding <=> '{emb_str}'::vector ASC - LIMIT :k - """) - - async with 
_pgvector_engine.connect() as conn: - result = await conn.execute(sql, {"user_id": user_id, "k": k * 4}) - rows = result.fetchall() - - return [ - RetrievalResult( - content=row.document, - metadata=row.cmetadata, - score=float(row.score), - source_type="database", - ) - for row in rows - ] - - async def _search_db_tables( - self, embedding: list[float], user_id: str, k: int - ) -> list[RetrievalResult]: - """Cosine vector search over database TABLE-level chunks. - - Recall channel for multi-table questions. The chunk's content is - discarded downstream — db_executor only consumes its `data.table_name` - to seed full-schema fetch. - """ - emb_str = "[" + ",".join(str(x) for x in embedding) + "]" - - sql = text(f""" - SELECT lpe.document, lpe.cmetadata, - 1.0 - (lpe.embedding <=> '{emb_str}'::vector) AS score - FROM langchain_pg_embedding lpe - JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid - WHERE lpc.name = 'document_embeddings' - AND lpe.cmetadata->>'user_id' = :user_id - AND lpe.cmetadata->>'source_type' = 'database' - AND lpe.cmetadata->>'chunk_level' = 'table' - ORDER BY lpe.embedding <=> '{emb_str}'::vector ASC - LIMIT :k - """) - - async with _pgvector_engine.connect() as conn: - result = await conn.execute( - sql, {"user_id": user_id, "k": k * _TABLE_CHUNK_K_MULTIPLIER} - ) - rows = result.fetchall() - - return [ - RetrievalResult( - content=row.document, - metadata=row.cmetadata, - score=float(row.score), - source_type="database", - ) - for row in rows - ] - - async def _search_tabular( - self, embedding: list[float], user_id: str, k: int - ) -> list[RetrievalResult]: - """Cosine vector search over tabular document chunks (csv/xlsx).""" - emb_str = "[" + ",".join(str(x) for x in embedding) + "]" - - sql = text(f""" - SELECT lpe.document, lpe.cmetadata, - 1.0 - (lpe.embedding <=> '{emb_str}'::vector) AS score - FROM langchain_pg_embedding lpe - JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid - WHERE lpc.name = 
'document_embeddings' - AND lpe.cmetadata->>'user_id' = :user_id - AND lpe.cmetadata->>'source_type' = 'document' - AND lpe.cmetadata->>'chunk_level' = 'column' - AND (lpe.cmetadata->'data'->>'file_type' = 'csv' - OR lpe.cmetadata->'data'->>'file_type' = 'xlsx') - ORDER BY lpe.embedding <=> '{emb_str}'::vector ASC - LIMIT :k - """) - - async with _pgvector_engine.connect() as conn: - result = await conn.execute(sql, {"user_id": user_id, "k": k * 4}) - rows = result.fetchall() - - return [ - RetrievalResult( - content=row.document, - metadata=row.cmetadata, - score=float(row.score), - source_type="document", - ) - for row in rows - ] - - async def _search_tabular_sheets( - self, embedding: list[float], user_id: str, k: int - ) -> list[RetrievalResult]: - """Leg 5: sheet-level summary chunks from CSV/XLSX files.""" - emb_str = "[" + ",".join(str(x) for x in embedding) + "]" - - sql = text(f""" - SELECT lpe.document, lpe.cmetadata, - 1.0 - (lpe.embedding <=> '{emb_str}'::vector) AS score - FROM langchain_pg_embedding lpe - JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid - WHERE lpc.name = 'document_embeddings' - AND lpe.cmetadata->>'user_id' = :user_id - AND lpe.cmetadata->>'source_type' = 'document' - AND lpe.cmetadata->>'chunk_level' = 'sheet' - AND (lpe.cmetadata->'data'->>'file_type' = 'csv' - OR lpe.cmetadata->'data'->>'file_type' = 'xlsx') - ORDER BY lpe.embedding <=> '{emb_str}'::vector ASC - LIMIT :k - """) - - async with _pgvector_engine.connect() as conn: - result = await conn.execute(sql, {"user_id": user_id, "k": k}) - rows = result.fetchall() - - return [ - RetrievalResult( - content=row.document, - metadata=row.cmetadata, - score=float(row.score), - source_type="document", - ) - for row in rows - ] - - async def _search_fts_db(self, query: str, user_id: str, k: int) -> list[RetrievalResult]: - """Full-text search over DB schema chunks using PostgreSQL tsvector.""" - sql = text(""" - SELECT lpe.document, lpe.cmetadata, - 
ts_rank(to_tsvector('english', lpe.document), - plainto_tsquery('english', :query)) AS rank - FROM langchain_pg_embedding lpe - JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid - WHERE lpc.name = 'document_embeddings' - AND lpe.cmetadata->>'user_id' = :user_id - AND lpe.cmetadata->>'source_type' = 'database' - AND lpe.cmetadata->>'chunk_level' = 'column' - AND to_tsvector('english', lpe.document) @@ plainto_tsquery('english', :query) - ORDER BY rank DESC - LIMIT :k - """) - - async with _pgvector_engine.connect() as conn: - result = await conn.execute(sql, {"query": query, "user_id": user_id, "k": k}) - rows = result.fetchall() - - return [ - RetrievalResult( - content=row.document, - metadata=row.cmetadata, - score=float(row.rank), - source_type="database", - ) - for row in rows - ] - - def _rank_tabular_sheets( - self, - sheet_results: list[RetrievalResult], - column_results: list[RetrievalResult], - top_k: int, - k_rrf: int = 60, - ) -> list[RetrievalResult]: - """Rank tabular sheets by RRF across two voting legs: - L1 (primary): sheet-chunk cosine score - L2 (vote): best column-chunk position per (doc_id, sheet_name) - - Returns top-k sheet-level RetrievalResults. The full column list of - each sheet is already in the sheet chunk's data.column_names from - ingestion, so downstream tabular_executor can read full sheet context. - - For sheets surfaced by column votes but missing a sheet chunk (rare — - ingestion always creates one), a minimal stub is returned and - tabular_executor falls back to reading columns from the parquet. 
- """ - # L1: sheets indexed by (doc_id, sheet_name) from sheet chunks - sheet_index: dict[tuple, RetrievalResult] = {} - sheet_ranked: list[tuple] = [] - for r in sheet_results: - d = r.metadata.get("data", {}) - key = (d.get("document_id"), d.get("sheet_name")) - if key[0] and key not in sheet_index: - sheet_index[key] = r - sheet_ranked.append(key) - - # L2: sheets ranked by first-appearance in column-chunk results - col_sheet_ranked: list[tuple] = [] - seen: set[tuple] = set() - for r in column_results: - d = r.metadata.get("data", {}) - key = (d.get("document_id"), d.get("sheet_name")) - if key[0] and key not in seen: - col_sheet_ranked.append(key) - seen.add(key) - - # RRF over (doc_id, sheet_name) across the two legs - rrf_scores: dict[tuple, float] = {} - for ranked_list in [sheet_ranked, col_sheet_ranked]: - for rank, key in enumerate(ranked_list): - rrf_scores[key] = rrf_scores.get(key, 0.0) + 1.0 / (k_rrf + rank + 1) - - top_sheets = sorted(rrf_scores, key=lambda k: rrf_scores[k], reverse=True)[:top_k] - - results: list[RetrievalResult] = [] - for key in top_sheets: - if key in sheet_index: - r = sheet_index[key] - r.score = rrf_scores[key] - results.append(r) - else: - # Surfaced by column votes only — build stub from a representative - # column result so tabular_executor can group correctly. 
- doc_id, sheet_name = key - rep = next( - (r for r in column_results - if r.metadata.get("data", {}).get("document_id") == doc_id - and r.metadata.get("data", {}).get("sheet_name") == sheet_name), - None, - ) - if rep is None: - continue - stub_data = dict(rep.metadata.get("data", {})) - stub_data.pop("column_name", None) - stub_data.pop("column_type", None) - results.append(RetrievalResult( - content=f"Sheet: {stub_data.get('filename', '')}" - + (f" / sheet: {sheet_name}" if sheet_name else ""), - metadata={**rep.metadata, "data": stub_data, "chunk_level": "sheet"}, - score=rrf_scores[key], - source_type="document", - )) - return results - - def _rank_db_tables( - self, - tbl_results: list[RetrievalResult], - col_results: list[RetrievalResult], - fts_results: list[RetrievalResult], - top_k: int, - k_rrf: int = 60, - ) -> list[RetrievalResult]: - """Rank DB tables by RRF across three legs: - L1 (primary): table-summary chunk similarity - L2 (vote): best column-chunk position per table - L3 (vote): best FTS position per table - - Returns top-k table-chunk RetrievalResults. For tables surfaced by - L2/L3 but missing a table chunk, a minimal stub is returned so that - db_executor._fetch_full_schema can seed off data.table_name. 
- """ - # L1: tables ranked by table-chunk cosine score - tbl_index: dict[str, RetrievalResult] = {} - tbl_ranked: list[str] = [] - for r in tbl_results: - tname = r.metadata.get("data", {}).get("table_name") - if tname and tname not in tbl_index: - tbl_index[tname] = r - tbl_ranked.append(tname) - - # L2: tables ranked by first-appearance in column-chunk list (best col score) - col_table_ranked: list[str] = [] - seen: set[str] = set() - for r in col_results: - tname = r.metadata.get("data", {}).get("table_name") - if tname and tname not in seen: - col_table_ranked.append(tname) - seen.add(tname) - - # L3: tables ranked by first-appearance in FTS list - fts_table_ranked: list[str] = [] - seen = set() - for r in fts_results: - tname = r.metadata.get("data", {}).get("table_name") - if tname and tname not in seen: - fts_table_ranked.append(tname) - seen.add(tname) - - # RRF over table names across the three legs - rrf_scores: dict[str, float] = {} - for ranked_list in [tbl_ranked, col_table_ranked, fts_table_ranked]: - for rank, tname in enumerate(ranked_list): - rrf_scores[tname] = rrf_scores.get(tname, 0.0) + 1.0 / (k_rrf + rank + 1) - - top_tables = sorted(rrf_scores, key=lambda t: rrf_scores[t], reverse=True)[:top_k] - - results: list[RetrievalResult] = [] - for tname in top_tables: - if tname in tbl_index: - r = tbl_index[tname] - r.score = rrf_scores[tname] - results.append(r) - else: - # Surfaced by column/FTS votes with no table chunk — minimal stub - results.append(RetrievalResult( - content=f"Table: {tname}", - metadata={"data": {"table_name": tname}, "source_type": "database"}, - score=rrf_scores[tname], - source_type="database", - )) - return results - - # ------------------------------------------------------------------ - # Public interface — called by the router - # ------------------------------------------------------------------ - - async def retrieve(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]: - """Table-first retrieval for 
DB sources; chunk-level for tabular. - - DB tables are ranked via RRF across three legs: - L1 (primary): table-summary chunk similarity - L2 (vote): top-K column-chunk cosine, grouped by table - L3 (vote): top-K FTS column hits, grouped by table - - db_executor downstream fetches the full per-column schema for the - ranked table set via _fetch_full_schema — the column chunks returned - here are intentionally NOT used as the schema source, only for voting. - - Tabular (CSV/XLSX) sheets are ranked via RRF across two legs: - L1: sheet-chunk cosine - L2: column-chunk votes (best position per sheet) - Returns sheet-level RetrievalResults so tabular_executor receives - full sheet context (all columns) rather than fragmented column hits. - """ - embedding = await self._embed_query(query) - db_col_results, db_tbl_results, tabular_results, fts_results, sheet_results = await asyncio.gather( - self._search_db(embedding, user_id, k), - self._search_db_tables(embedding, user_id, k), - self._search_tabular(embedding, user_id, k), - self._search_fts_db(query, user_id, k * 4), - self._search_tabular_sheets(embedding, user_id, k), - ) - - db_ranked = self._rank_db_tables(db_tbl_results, db_col_results, fts_results, top_k=k) - tabular_ranked = self._rank_tabular_sheets(sheet_results, tabular_results, top_k=k) - - results = sorted(db_ranked + tabular_ranked, key=lambda r: r.score, reverse=True) - logger.info( - "schema retrieval", - count=len(results), - db_tables_ranked=len(db_ranked), - db_cols=len(db_col_results), - db_tables=len(db_tbl_results), - tabular_cols=len(tabular_results), - tabular_sheets=len(sheet_results), - tabular_ranked=len(tabular_ranked), - fts=len(fts_results), - ) - return results - - -schema_retriever = SchemaRetriever() diff --git a/src/rag/router.py b/src/rag/router.py deleted file mode 100644 index be23b0809913a4715228d697a38d4b029f22cf8f..0000000000000000000000000000000000000000 --- a/src/rag/router.py +++ /dev/null @@ -1,179 +0,0 @@ -"""Routes retrieval 
requests to the appropriate retriever based on source_hint. - -Cross-retriever merging uses Reciprocal Rank Fusion (RRF) on per-retriever -ranked lists — score scales differ across retrievers (RRF, cosine, distance) -and aren't directly comparable, so we rank-merge instead of score-merge. -""" - -import asyncio -import hashlib -import json -from dataclasses import asdict -from typing import Literal - -from src.db.redis.connection import get_redis -from src.middlewares.logging import get_logger -from src.rag.base import BaseRetriever, RetrievalResult - -logger = get_logger("retrieval_router") - -_CACHE_TTL = 3600 # 1 hour -_CACHE_KEY_PREFIX = "retrieval" -_RRF_K = 60 # standard RRF constant -SourceHint = Literal["document", "schema", "both"] - - -def _result_dedup_key(r: RetrievalResult) -> tuple: - """Cross-retriever dedup key — distinguishes DB columns vs DB tables vs - tabular columns vs prose chunks vs sheet-level chunks.""" - data = r.metadata.get("data", {}) - return ( - r.source_type, - data.get("table_name"), - data.get("column_name"), - data.get("filename"), - data.get("sheet_name"), - data.get("chunk_index"), # disambiguates multiple prose chunks per doc - r.metadata.get("chunk_level"), # distinguishes sheet vs column chunks - ) - - -def _rrf_merge( - ranked_lists: list[list[RetrievalResult]], - top_k: int, - k_rrf: int = _RRF_K, -) -> list[RetrievalResult]: - """Reciprocal Rank Fusion across retriever batches. - - Each input list is treated as already best-first ordered. Items are - deduped via _result_dedup_key and re-ranked by aggregated reciprocal - rank across all lists. Score on the returned RetrievalResult is the - aggregated RRF score (uniform scale across legs). 
- """ - scores: dict[tuple, float] = {} - index: dict[tuple, RetrievalResult] = {} - - for ranked in ranked_lists: - for rank, result in enumerate(ranked): - key = _result_dedup_key(result) - scores[key] = scores.get(key, 0.0) + 1.0 / (k_rrf + rank + 1) - # Keep the first occurrence; metadata is identical for the same - # key across lists, so any copy is fine. - if key not in index: - index[key] = result - - merged = sorted(index.values(), key=lambda r: scores[_result_dedup_key(r)], reverse=True) - # Overwrite score with RRF score so downstream consumers see a uniform scale. - for r in merged: - r.score = scores[_result_dedup_key(r)] - return merged[:top_k] - - -async def invalidate_retrieval_cache(user_id: str) -> int: - """Delete every cached retrieval entry for `user_id`. - - Called by ingest/upload/delete API handlers after a successful write so - the next retrieval picks up the new data instead of stale cached top-k. - Returns the number of keys removed. - """ - redis = await get_redis() - pattern = f"{_CACHE_KEY_PREFIX}:{user_id}:*" - keys = [key async for key in redis.scan_iter(match=pattern)] - if not keys: - return 0 - deleted = await redis.delete(*keys) - logger.info("retrieval cache invalidated", user_id=user_id, deleted=deleted) - return int(deleted) - - -class RetrievalRouter: - def __init__( - self, - schema_retriever: BaseRetriever, - document_retriever: BaseRetriever, - ): - self._retrievers: dict[str, BaseRetriever] = { - "schema": schema_retriever, - "document": document_retriever, - } - - def _route(self, source_hint: SourceHint) -> list[tuple[str, BaseRetriever]]: - if source_hint == "schema": - return [("schema", self._retrievers["schema"])] - if source_hint == "document": - return [("document", self._retrievers["document"])] - return list(self._retrievers.items()) - - async def retrieve( - self, - query: str, - user_id: str, - source_hint: SourceHint = "both", - k: int = 10, - ) -> list[RetrievalResult]: - redis = await get_redis() - 
query_hash = hashlib.md5(query.encode()).hexdigest() - cache_key = f"{_CACHE_KEY_PREFIX}:{user_id}:{source_hint}:{query_hash}:{k}" - - cached = await redis.get(cache_key) - if cached: - try: - raw = json.loads(cached) - logger.info("returning cached retrieval results", source_hint=source_hint) - return [RetrievalResult(**r) for r in raw] - except Exception: - logger.warning("corrupted retrieval cache, fetching fresh", cache_key=cache_key) - - results = await self._retrieve_uncached(query, user_id, source_hint, k) - - # Empty-result fallback: orchestrator may have misclassified intent. - # Retry once with "both" before giving up. No-op when source_hint is - # already "both". - if not results and source_hint != "both": - logger.warning( - "empty retrieval, falling back to source_hint='both'", - original_source_hint=source_hint, - ) - results = await self._retrieve_uncached(query, user_id, "both", k) - - await redis.setex( - cache_key, - _CACHE_TTL, - json.dumps([asdict(r) for r in results]), - ) - return results - - async def _retrieve_uncached( - self, - query: str, - user_id: str, - source_hint: SourceHint, - k: int, - ) -> list[RetrievalResult]: - routed = self._route(source_hint) - batches = await asyncio.gather( - *[r.retrieve(query, user_id, k) for _, r in routed], - return_exceptions=True, - ) - - valid_lists: list[list[RetrievalResult]] = [] - per_retriever: dict[str, int | str] = {} - for (name, _), batch in zip(routed, batches): - if isinstance(batch, Exception): - logger.error("retriever failed", retriever=name, error=str(batch)) - per_retriever[name] = "error" - continue - valid_lists.append(batch) - per_retriever[name] = len(batch) - - results = _rrf_merge(valid_lists, top_k=k) - - logger.info( - "router result", - source_hint=source_hint, - per_retriever=per_retriever, - final_count=len(results), - top_score=results[0].score if results else None, - bottom_score=results[-1].score if results else None, - ) - return results diff --git 
a/src/retrieval/README.md b/src/retrieval/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0c68b65fb35b0d0aa1d4183865c9e5674e50c3f4 --- /dev/null +++ b/src/retrieval/README.md @@ -0,0 +1,8 @@ +# retrieval + +Unstructured-source retrieval (PDF, DOCX, TXT) — Cu in the architecture. +Dense similarity over prose chunks via PGVector. + +Structured (DB / tabular) sources do **not** pass through here — they go through `catalog/` + `query/`. + +See `ARCHITECTURE.md` (root) for the full design. diff --git a/src/retrieval/__init__.py b/src/retrieval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6ab40ab2ec3f257a1dc22e93e67e986e49970b83 --- /dev/null +++ b/src/retrieval/__init__.py @@ -0,0 +1 @@ +"""Retrieval for unstructured sources (Cu) — prose chunks via dense similarity.""" diff --git a/src/rag/base.py b/src/retrieval/base.py similarity index 87% rename from src/rag/base.py rename to src/retrieval/base.py index a485af2ed8749af18ff66084ec95df6c61fda435..06ac2a2feb5fd5121fa127a9b467ee9ca452f6fe 100644 --- a/src/rag/base.py +++ b/src/retrieval/base.py @@ -1,4 +1,4 @@ -"""Shared contract for all retriever implementations.""" +"""Shared types for the retrieval layer.""" from abc import ABC, abstractmethod from dataclasses import dataclass diff --git a/src/rag/retrievers/document.py b/src/retrieval/document.py similarity index 74% rename from src/rag/retrievers/document.py rename to src/retrieval/document.py index af33ba22386aa6f319b4c9048cfeab5bb2a4d123..6eaff1bf3831ead264ee58e002cf96b731b71fa4 100644 --- a/src/rag/retrievers/document.py +++ b/src/retrieval/document.py @@ -1,5 +1,10 @@ -"""Document retriever — handles PDF, DOCX, TXT chunks (source_type="document", non-tabular).""" +"""DocumentRetriever — dense similarity over prose chunks (Cu). +For unstructured sources only (PDF / DOCX / TXT). Backed by PGVector with +collection `document_embeddings`. Methods: MMR, cosine, euclidean, etc. 
+""" + +import functools import math from langchain_postgres import PGVector @@ -11,7 +16,7 @@ from src.config.settings import settings from src.db.postgres.connection import _pgvector_engine from src.db.postgres.vector_store import get_vector_store from src.middlewares.logging import get_logger -from src.rag.base import BaseRetriever, RetrievalResult +from src.retrieval.base import BaseRetriever, RetrievalResult logger = get_logger("document_retriever") @@ -24,32 +29,40 @@ _FETCH_K = 20 _LAMBDA_MULT = 0.5 _COLLECTION_NAME = "document_embeddings" -_embeddings = AzureOpenAIEmbeddings( - azure_deployment=settings.azureai_deployment_name_embedding, - openai_api_version=settings.azureai_api_version_embedding, - azure_endpoint=settings.azureai_endpoint_url_embedding, - api_key=settings.azureai_api_key_embedding, -) - -_euclidean_store = PGVector( - embeddings=_embeddings, - connection=_pgvector_engine, - collection_name=_COLLECTION_NAME, - distance_strategy=DistanceStrategy.EUCLIDEAN, - use_jsonb=True, - async_mode=True, - create_extension=False, -) - -_ip_store = PGVector( - embeddings=_embeddings, - connection=_pgvector_engine, - collection_name=_COLLECTION_NAME, - distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT, - use_jsonb=True, - async_mode=True, - create_extension=False, -) +@functools.cache +def _get_embeddings() -> AzureOpenAIEmbeddings: + return AzureOpenAIEmbeddings( + azure_deployment=settings.azureai_deployment_name_embedding, + openai_api_version=settings.azureai_api_version_embedding, + azure_endpoint=settings.azureai_endpoint_url_embedding, + api_key=settings.azureai_api_key_embedding, + ) + + +@functools.cache +def _get_euclidean_store() -> PGVector: + return PGVector( + embeddings=_get_embeddings(), + connection=_pgvector_engine, + collection_name=_COLLECTION_NAME, + distance_strategy=DistanceStrategy.EUCLIDEAN, + use_jsonb=True, + async_mode=True, + create_extension=False, + ) + + +@functools.cache +def _get_ip_store() -> PGVector: + return 
PGVector( + embeddings=_get_embeddings(), + connection=_pgvector_engine, + collection_name=_COLLECTION_NAME, + distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT, + use_jsonb=True, + async_mode=True, + create_extension=False, + ) _MANHATTAN_SQL = text(""" SELECT @@ -93,11 +106,11 @@ class DocumentRetriever(BaseRetriever): score_map = {doc.page_content: score for doc, score in cosine} docs_with_scores = [(doc, score_map.get(doc.page_content, 0.0)) for doc in docs] elif _RETRIEVAL_METHOD == "euclidean": - docs_with_scores = await _euclidean_store.asimilarity_search_with_score( + docs_with_scores = await _get_euclidean_store().asimilarity_search_with_score( query=query, k=fetch_k, filter=filter_, ) elif _RETRIEVAL_METHOD == "inner_product": - docs_with_scores = await _ip_store.asimilarity_search_with_score( + docs_with_scores = await _get_ip_store().asimilarity_search_with_score( query=query, k=fetch_k, filter=filter_, ) else: # cosine @@ -124,7 +137,7 @@ class DocumentRetriever(BaseRetriever): async def _retrieve_manhattan( self, query: str, user_id: str, k: int, fetch_k: int ) -> list[RetrievalResult]: - query_vector = await _embeddings.aembed_query(query) + query_vector = await _get_embeddings().aembed_query(query) if not all(math.isfinite(v) for v in query_vector): raise ValueError("Embedding vector contains NaN or Infinity values.") vector_str = "[" + ",".join(str(v) for v in query_vector) + "]" diff --git a/src/retrieval/router.py b/src/retrieval/router.py new file mode 100644 index 0000000000000000000000000000000000000000..e38d16a72e56c41bcb7636574879bd9a65b441f9 --- /dev/null +++ b/src/retrieval/router.py @@ -0,0 +1,80 @@ +"""Retrieval router — dispatches to DocumentRetriever for unstructured sources. 
+ +Routing rules: + - unstructured / document / both → DocumentRetriever (PGVector, PDF/DOCX/TXT) + - structured / schema → empty list; handled by query/service.py + - chat → empty list; bypasses retrieval entirely + +Exposes the same interface as the old src/rag/retriever.py so call sites in +chat.py require no changes beyond the import path. +""" + +import hashlib +import json +from dataclasses import asdict + +from src.db.redis.connection import get_redis +from src.middlewares.logging import get_logger +from src.retrieval.base import RetrievalResult +from src.retrieval.document import DocumentRetriever + +logger = get_logger("retrieval_router") + +_CACHE_TTL = 3600 +_CACHE_KEY_PREFIX = "retrieval" + + +class RetrievalRouter: + def __init__(self) -> None: + self._retriever: DocumentRetriever | None = None + + def _get_retriever(self) -> DocumentRetriever: + if self._retriever is None: + self._retriever = DocumentRetriever() + return self._retriever + + async def retrieve( + self, + query: str, + user_id: str, + k: int = 5, + ) -> list[RetrievalResult]: + redis = await get_redis() + query_hash = hashlib.md5(query.encode()).hexdigest() + cache_key = f"{_CACHE_KEY_PREFIX}:{user_id}:{query_hash}:{k}" + + cached = await redis.get(cache_key) + if cached: + try: + raw = json.loads(cached) + logger.info("returning cached retrieval results") + return [RetrievalResult(**r) for r in raw] + except Exception: + logger.warning("corrupted retrieval cache, fetching fresh") + + try: + results = await self._get_retriever().retrieve(query, user_id, k) + except Exception as e: + logger.error("retrieval failed", error=str(e)) + return [] + + await redis.setex( + cache_key, + _CACHE_TTL, + json.dumps([asdict(r) for r in results]), + ) + return results + + async def invalidate_cache(self, user_id: str) -> int: + """Delete all cached retrieval entries for a user. 
Call after upload/delete.""" + redis = await get_redis() + pattern = f"{_CACHE_KEY_PREFIX}:{user_id}:*" + keys = [key async for key in redis.scan_iter(match=pattern)] + if not keys: + return 0 + deleted = await redis.delete(*keys) + logger.info("retrieval cache invalidated", user_id=user_id, deleted=deleted) + return int(deleted) + + +retrieval_router = RetrievalRouter() diff --git a/src/security/README.md b/src/security/README.md new file mode 100644 index 0000000000000000000000000000000000000000..775920361a4d449efe9000df72a8577dca0ef680 --- /dev/null +++ b/src/security/README.md @@ -0,0 +1,8 @@ +# security + +Cross-cutting security primitives: +- credential encryption (Fernet) for stored DB credentials +- authentication / password / JWT helpers +- PII detection patterns used by the catalog introspectors + +Consolidates utilities previously split between `utils/` and `users/`. diff --git a/src/security/__init__.py b/src/security/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2cfa62d0a931aa3ae3c667dee40ef600ecca9c6b --- /dev/null +++ b/src/security/__init__.py @@ -0,0 +1 @@ +"""Security primitives — credentials, auth, PII patterns.""" diff --git a/src/security/auth.py b/src/security/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..5e130e6b9220c61680d3e8c69fa1aa17bbf9420d --- /dev/null +++ b/src/security/auth.py @@ -0,0 +1,21 @@ +"""Authentication helpers: password hashing, JWT encode/decode, get_user. + +Receives the working implementation from the previous src/users/users.py +during the cleanup phase. 
+""" + + +def hash_password(plaintext: str) -> str: + raise NotImplementedError + + +def verify_password(plaintext: str, hashed: str) -> bool: + raise NotImplementedError + + +def encode_jwt(payload: dict) -> str: + raise NotImplementedError + + +def decode_jwt(token: str) -> dict: + raise NotImplementedError diff --git a/src/security/credentials.py b/src/security/credentials.py new file mode 100644 index 0000000000000000000000000000000000000000..5422f3c8679fd3a8cce8d82435f84b82eccb034b --- /dev/null +++ b/src/security/credentials.py @@ -0,0 +1,15 @@ +"""Fernet-encrypted credential storage for user-registered DB connections. + +Receives the working implementation from the previous +src/utils/db_credential_encryption.py during the cleanup phase. + +Key: settings.dataeyond__db__credential__key (Fernet, kept out of source). +""" + + +def encrypt_credential(plaintext: str) -> str: + raise NotImplementedError + + +def decrypt_credential(ciphertext: str) -> str: + raise NotImplementedError diff --git a/src/security/pii_patterns.py b/src/security/pii_patterns.py new file mode 100644 index 0000000000000000000000000000000000000000..548e01b24231ba7376a3c6f520825593c2feac14 --- /dev/null +++ b/src/security/pii_patterns.py @@ -0,0 +1,20 @@ +"""Regex patterns and column-name heuristics for PII detection. + +Used by catalog/pii_detector.py at ingestion time. Default policy: +when in doubt, set pii_flag=True. False positives cost nothing; false +negatives leak data. 
+""" + +import re + +PII_NAME_PATTERNS = frozenset({ + "email", + "phone", "mobile", "telp", "telephone", + "ssn", "tin", "passport", "ktp", "nik", + "name", "fullname", "first_name", "last_name", "surname", + "address", "street", "zipcode", "postal", + "birthdate", "dob", "birthday", +}) + +EMAIL_REGEX = re.compile(r"^[\w.+-]+@[\w-]+\.[\w.-]+$") +PHONE_REGEX = re.compile(r"^\+?[\d\s\-()]{7,}$") diff --git a/src/knowledge/parquet_service.py b/src/storage/parquet.py similarity index 93% rename from src/knowledge/parquet_service.py rename to src/storage/parquet.py index 1a47cc903f94c70b0d95f0813620570dc36de08e..ac41a399cd2203fbcab05864a7644e8ab593fe46 100644 --- a/src/knowledge/parquet_service.py +++ b/src/storage/parquet.py @@ -1,4 +1,4 @@ -"""Parquet service — converts, uploads, downloads, and deletes Parquet files for CSV/XLSX. +"""Parquet storage helpers — converts, uploads, downloads, and deletes Parquet files for CSV/XLSX. Parquet files are stored in Azure Blob alongside the original document using a deterministic naming convention based on document_id: @@ -18,7 +18,7 @@ import pandas as pd from src.middlewares.logging import get_logger from src.storage.az_blob.az_blob import blob_storage -logger = get_logger("parquet_service") +logger = get_logger("storage.parquet") def _safe_sheet_name(sheet_name: str) -> str: @@ -27,7 +27,7 @@ def _safe_sheet_name(sheet_name: str) -> str: def parquet_blob_name(user_id: str, document_id: str, sheet_name: str | None = None) -> str: """Construct deterministic Parquet blob name.""" - if sheet_name: + if sheet_name is not None: return f"{user_id}/{document_id}__{_safe_sheet_name(sheet_name)}.parquet" return f"{user_id}/{document_id}.parquet" diff --git a/src/tools/__init__.py b/src/tools/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/src/tools/search.py b/src/tools/search.py deleted file mode 100644 index 
9ab80f31bdf71e3866ee2f4c653d1b3c55a04173..0000000000000000000000000000000000000000 --- a/src/tools/search.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Search tool for agent.""" - -from langchain_core.tools import tool -from src.rag.retriever import retriever -from sqlalchemy.ext.asyncio import AsyncSession -from src.middlewares.logging import get_logger - -logger = get_logger("search_tool") - - -@tool -async def search_documents( - query: str, - user_id: str, - db: AsyncSession, - num_results: int = 5 -) -> str: - """Search user's uploaded documents for relevant information. - - Args: - query: The search query or question - user_id: The user's ID - db: Database session - num_results: Number of results to return (default: 5) - - Returns: - Relevant document excerpts with source and page information - """ - try: - results = await retriever.retrieve(query, user_id, db, num_results) - - if not results: - return "No relevant information found in the documents." - - formatted_results = [] - for result in results: - filename = result.metadata.get("filename", "Unknown") - page = result.metadata.get("page_label") - source_label = f"{filename}, p.{page}" if page else filename - formatted_results.append(f"[Source: {source_label}]\n{result.content}\n") - - return "\n".join(formatted_results) - - except Exception as e: - logger.error("Search failed", error=str(e)) - return "Sorry, I encountered an error while searching the documents."