Mayank Chugh committed on
Commit ·
a32f9e3
1
Parent(s): fceb91f
Enhance environment configuration and API documentation for Milestone 11
Browse files
- Update `.env.example` to reflect new application name, version, and additional configuration options for LLM providers.
- Revise `README.md` to improve architecture overview, quick start instructions, and use cases for the DocuAudit AI application.
- Add detailed specifications for API endpoints in `milestones.md`, including new features for multi-file uploads, query parameters, and audit logging.
- Refactor API routes to support enhanced query and ingestion functionalities, including user ID tracking and improved response structures.
- Update request and response models to align with new API specifications and ensure comprehensive coverage of expected behaviors.
- .env.example +44 -8
- README.md +121 -25
- api/config.py +27 -5
- api/main.py +22 -2
- api/routes/audit.py +30 -18
- api/routes/ingest.py +100 -54
- api/routes/jobs.py +8 -17
- api/routes/query.py +126 -86
- models/requests.py +33 -16
- models/responses.py +92 -64
- pytest.ini +3 -0
- rag/retriever.py +48 -22
- rag/vector_store.py +35 -1
- storage/audit_store.py +208 -26
- storage/job_store.py +197 -37
- streamlit_app.py +60 -26
- tests/test_audit.py +68 -17
- tests/test_ingest.py +10 -6
- tests/test_query.py +47 -7
- workers/ingest_worker.py +67 -27
.env.example
CHANGED
|
@@ -1,12 +1,48 @@
|
|
| 1 |
-
|
| 2 |
-
|
|
|
|
| 3 |
LLM_PROVIDER=ollama
|
| 4 |
-
|
| 5 |
-
|
| 6 |
OPENAI_API_KEY=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
CHROMA_PERSIST_DIRECTORY=./data/chroma
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
| 9 |
CHUNK_SIZE=1000
|
| 10 |
-
CHUNK_OVERLAP=
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DocuAudit AI — environment template (see docs/DOCUAUDIT_AI_REQUIREMENTS.md)
|
| 2 |
+
|
| 3 |
+
# LLM Provider: ollama | anthropic | openai | huggingface
|
| 4 |
LLM_PROVIDER=ollama
|
| 5 |
+
|
| 6 |
+
# OpenAI (optional)
|
| 7 |
OPENAI_API_KEY=
|
| 8 |
+
OPENAI_MODEL=gpt-4o
|
| 9 |
+
OPENAI_EMBEDDING_MODEL=text-embedding-3-small
|
| 10 |
+
|
| 11 |
+
# Anthropic (optional)
|
| 12 |
+
ANTHROPIC_API_KEY=
|
| 13 |
+
ANTHROPIC_MODEL=claude-3-5-sonnet-20241022
|
| 14 |
+
|
| 15 |
+
# Ollama (recommended local default)
|
| 16 |
+
OLLAMA_BASE_URL=http://localhost:11434
|
| 17 |
+
OLLAMA_CHAT_MODEL=llama3.1:8b
|
| 18 |
+
OLLAMA_EMBEDDING_MODEL=nomic-embed-text
|
| 19 |
+
|
| 20 |
+
# App
|
| 21 |
+
APP_NAME=DocuAudit AI
|
| 22 |
+
APP_VERSION=1.0.0
|
| 23 |
+
DEBUG=false
|
| 24 |
+
MAX_FILE_SIZE_MB=50
|
| 25 |
+
# Spec name alias (optional; mapped to MAX_FILE_SIZE_MB in settings)
|
| 26 |
+
MAX_UPLOAD_SIZE_MB=
|
| 27 |
+
|
| 28 |
+
# ChromaDB
|
| 29 |
CHROMA_PERSIST_DIRECTORY=./data/chroma
|
| 30 |
+
CHROMA_PERSIST_DIR=
|
| 31 |
+
CHROMA_COLLECTION_NAME=docuaudit_docs
|
| 32 |
+
|
| 33 |
+
# Chunking
|
| 34 |
CHUNK_SIZE=1000
|
| 35 |
+
CHUNK_OVERLAP=200
|
| 36 |
+
|
| 37 |
+
# Retrieval default (overridable per request on /query/ask via top_k)
|
| 38 |
+
TOP_K_RESULTS=5
|
| 39 |
+
|
| 40 |
+
# Audit + jobs SQLite
|
| 41 |
+
AUDIT_DB_PATH=./audit.db
|
| 42 |
+
JOBS_DB_PATH=./data/jobs.db
|
| 43 |
+
|
| 44 |
+
# Limits
|
| 45 |
+
MAX_DOCUMENTS_PER_BATCH=100
|
| 46 |
+
|
| 47 |
+
# Streamlit → API
|
| 48 |
+
STREAMLIT_BACKEND_URL=http://localhost:8000
|
README.md
CHANGED
|
@@ -1,42 +1,138 @@
|
|
| 1 |
-
#
|
| 2 |
|
|
|
|
| 3 |
|
| 4 |
-
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
-
|
| 8 |
-
uv venv --python 3.11.14
|
| 9 |
-
uv init --python 3.11.14
|
| 10 |
-
uv add requirements.txt
|
| 11 |
-
uv pip install -r requirements.txt
|
| 12 |
-
copy .env.example .env
|
| 13 |
|
| 14 |
-
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
-
|
| 18 |
-
curl -fsSL https://ollama.com/install.sh | sh
|
| 19 |
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
|
| 24 |
-
|
| 25 |
-
ollama serve &
|
| 26 |
|
| 27 |
-
|
| 28 |
-
curl http://localhost:11434/api/tags
|
| 29 |
|
| 30 |
-
### Start backend server- fastapi
|
| 31 |
```bash
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
```
|
| 34 |
-
|
|
|
|
|
|
|
| 35 |
```bash
|
| 36 |
-
uv run streamlit run streamlit_app.py --server.address
|
| 37 |
```
|
| 38 |
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
-
|
| 42 |
|
|
|
|
|
|
| 1 |
+
# DocuAudit AI
|
| 2 |
|
| 3 |
+
**DocuAudit AI** is a production-oriented FastAPI backend plus optional Streamlit UI for **multi-document RAG**: upload documents, build a Chroma vector index, ask grounded questions with citations, and retain a **SQLite audit trail** of every query.
|
| 4 |
|
| 5 |
+
## Architecture
|
| 6 |
|
| 7 |
+
```mermaid
|
| 8 |
+
flowchart LR
|
| 9 |
+
subgraph ingest [Ingestion]
|
| 10 |
+
A[PDF / TXT / MD] --> B[Loader]
|
| 11 |
+
B --> C[Chunker]
|
| 12 |
+
C --> D[Embedder]
|
| 13 |
+
D --> E[(ChromaDB)]
|
| 14 |
+
end
|
| 15 |
+
subgraph query [Query path]
|
| 16 |
+
Q[User question] --> R[Semantic search]
|
| 17 |
+
R --> E
|
| 18 |
+
R --> T[Top-K chunks]
|
| 19 |
+
T --> L[LLM]
|
| 20 |
+
L --> U[Answer + citations]
|
| 21 |
+
end
|
| 22 |
+
U --> V[(SQLite audit)]
|
| 23 |
+
```
|
| 24 |
|
| 25 |
+
ASCII equivalent:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
+
```
|
| 28 |
+
PDF Upload → Parser → Chunker → Embedder → ChromaDB
|
| 29 |
+
↓
|
| 30 |
+
User Query → Semantic Search → Top-K Chunks → LLM → Answer + Citations
|
| 31 |
+
↓
|
| 32 |
+
Audit Log (SQLite)
|
| 33 |
+
```
|
| 34 |
|
| 35 |
+
## Use cases
|
|
|
|
| 36 |
|
| 37 |
+
- **Litigation document analysis** — trace claims to exact pages and filenames.
|
| 38 |
+
- **Corporate finance review** — compare disclosures and filings under a consistent audit log.
|
| 39 |
+
- **Investigation support** — bulk ingest, async jobs, and reproducible query history.
|
| 40 |
|
| 41 |
+
## Quick start (local, without Docker)
|
|
|
|
| 42 |
|
| 43 |
+
Docker and Compose are planned under **Milestone 12**. Until then, run the API with **uv** (or your preferred tool):
|
|
|
|
| 44 |
|
|
|
|
| 45 |
```bash
|
| 46 |
+
git clone <repository-url> doc-Audi-ai
|
| 47 |
+
cd doc-Audi-ai
|
| 48 |
+
cp .env.example .env   # Windows cmd: copy .env.example .env
|
| 49 |
+
uv sync
|
| 50 |
+
ollama pull llama3.1:8b
|
| 51 |
+
ollama pull nomic-embed-text
|
| 52 |
+
uv run uvicorn api.main:app --host 0.0.0.0 --port 8000 --reload
|
| 53 |
```
|
| 54 |
+
|
| 55 |
+
Optional UI:
|
| 56 |
+
|
| 57 |
```bash
|
| 58 |
+
uv run streamlit run streamlit_app.py --server.port 8501 --server.address 0.0.0.0
|
| 59 |
```
|
| 60 |
|
| 61 |
+
After **Milestone 12**, the intended one-command experience will be `docker compose up` for API (`localhost:8000`) and UI (`localhost:8501`).
|
| 62 |
+
|
| 63 |
+
## API overview
|
| 64 |
+
|
| 65 |
+
| Method | Path | Description |
|
| 66 |
+
|--------|------|-------------|
|
| 67 |
+
| GET | `/health` | Liveness; returns configured app name and version |
|
| 68 |
+
| POST | `/ingest/upload` | Multipart **`files`** (one or more); queues background ingest job |
|
| 69 |
+
| POST | `/ingest/url` | JSON **`urls`** array (1–100); download and queue ingest |
|
| 70 |
+
| GET | `/ingest/collections` | Lists collections with **`document_count`** and optional **`created_at`** |
|
| 71 |
+
| DELETE | `/ingest/collection/{collection_name}` | Drops a collection; returns **`documents_removed`** |
|
| 72 |
+
| GET | `/jobs` | Lists jobs with **`total`** count |
|
| 73 |
+
| GET | `/jobs/{job_id}` | Job status with **`progress_percent`**, file counters, timestamps, **`errors`** |
|
| 74 |
+
| POST | `/query/ask` | Grounded answer; request includes **`top_k`**, **`user_id`** |
|
| 75 |
+
| POST | `/query/summarise` | Collection summary; distinct response shape (`summary`, `document_count`, …) |
|
| 76 |
+
| POST | `/query` | Legacy alias of **`/query/ask`** |
|
| 77 |
+
| GET | `/audit/logs` | Filterable audit index (`user_id`, `from_date`, `to_date`, pagination) |
|
| 78 |
+
| GET | `/audit/logs/{query_id}` | Full stored answer and citations for one query |
|
| 79 |
+
|
| 80 |
+
Interactive docs: `http://localhost:8000/docs`.
|
| 81 |
+
|
| 82 |
+
## Sample request and response (`POST /query/ask`)
|
| 83 |
+
|
| 84 |
+
Request:
|
| 85 |
+
|
| 86 |
+
```json
|
| 87 |
+
{
|
| 88 |
+
"question": "What were the key risk factors identified in the Q3 2023 financial report?",
|
| 89 |
+
"collection_name": "default",
|
| 90 |
+
"top_k": 5,
|
| 91 |
+
"user_id": "analyst_001"
|
| 92 |
+
}
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
Response (shape; values depend on your documents and model):
|
| 96 |
+
|
| 97 |
+
```json
|
| 98 |
+
{
|
| 99 |
+
"query_id": "uuid-string",
|
| 100 |
+
"question": "What were the key risk factors identified in the Q3 2023 financial report?",
|
| 101 |
+
"answer": "… grounded text with citations …",
|
| 102 |
+
"sources": [
|
| 103 |
+
{
|
| 104 |
+
"document_name": "q3_financial_report.pdf",
|
| 105 |
+
"page_number": 12,
|
| 106 |
+
"chunk_text": "Key risk factors include …",
|
| 107 |
+
"relevance_score": 0.91
|
| 108 |
+
}
|
| 109 |
+
],
|
| 110 |
+
"model_used": "llama3.1:8b",
|
| 111 |
+
"tokens_used": 0,
|
| 112 |
+
"response_time_ms": 1820,
|
| 113 |
+
"timestamp": "2026-05-03T12:00:00Z"
|
| 114 |
+
}
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
## Design decisions
|
| 118 |
+
|
| 119 |
+
- **Source citations** — High-stakes review requires every substantive claim to be tied to **document name** and **page** (where available), not a free-floating model monologue.
|
| 120 |
+
- **Auditability** — Each ask/summarise persists **query id**, **user id**, timing, model id, token usage (when the provider exposes it), and serialized sources so regulators or counsel can reconstruct what the system returned.
|
| 121 |
+
|
| 122 |
+
## Scale note
|
| 123 |
+
|
| 124 |
+
Architecture is designed for **high-volume document ingestion** via **async background jobs** (FastAPI `BackgroundTasks`), persistent Chroma collections, and a stateless API tier that can be replicated once you add a shared vector store and job queue.
|
| 125 |
+
|
| 126 |
+
## Tests
|
| 127 |
+
|
| 128 |
+
```bash
|
| 129 |
+
uv run pytest tests/ -q
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
+
## Configuration
|
| 133 |
+
|
| 134 |
+
See **`.env.example`**. Common variables include `LLM_PROVIDER`, Ollama/OpenAI/Anthropic keys and models, `CHROMA_PERSIST_DIRECTORY`, `AUDIT_DB_PATH`, `JOBS_DB_PATH`, and upload limits (`MAX_FILE_SIZE_MB`; **`MAX_UPLOAD_SIZE_MB`** is accepted as an alias via settings normalization).
|
| 135 |
|
| 136 |
+
## Specification
|
| 137 |
|
| 138 |
+
Authoritative product and API shapes: **`docs/DOCUAUDIT_AI_REQUIREMENTS.md`**. Gap tracking: **`docs/REQUIREMENTS_IMPLEMENTATION_GAPS.md`**.
|
api/config.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
| 1 |
from functools import lru_cache
|
| 2 |
-
from
|
|
|
|
|
|
|
| 3 |
|
| 4 |
from pydantic_settings import BaseSettings, SettingsConfigDict
|
| 5 |
|
|
@@ -10,10 +12,30 @@ class Settings(BaseSettings):
|
|
| 10 |
env_file_encoding="utf-8",
|
| 11 |
extra="ignore",
|
| 12 |
case_sensitive=False,
|
|
|
|
| 13 |
)
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
llm_provider: str = Field(default="ollama", description="Embedding provider")
|
| 18 |
|
| 19 |
openai_api_key: str | None = Field(default=None, description="OpenAI API key")
|
|
@@ -38,12 +60,12 @@ class Settings(BaseSettings):
|
|
| 38 |
|
| 39 |
chunk_size: int = Field(default=1000, ge=100, le=8000, description="Chunk size for splitting")
|
| 40 |
chunk_overlap: int = Field(default=200, ge=0, le=2000, description="Chunk overlap for splitting")
|
| 41 |
-
top_k_results: int = Field(default=
|
| 42 |
|
| 43 |
audit_db_path: str = "./audit.db"
|
| 44 |
jobs_db_path: str = Field(default="./data/jobs.db", description="SQLite path for ingest job tracking")
|
| 45 |
|
| 46 |
-
max_file_size_mb: int = Field(default=50, ge=1, le=200, description="Max upload file size")
|
| 47 |
max_documents_per_batch: int = Field(default=100, ge=1, le=1000, description="Max documents per batch")
|
| 48 |
|
| 49 |
|
|
|
|
| 1 |
from functools import lru_cache
|
| 2 |
+
from typing import Any
|
| 3 |
+
|
| 4 |
+
from pydantic import Field, model_validator
|
| 5 |
|
| 6 |
from pydantic_settings import BaseSettings, SettingsConfigDict
|
| 7 |
|
|
|
|
| 12 |
env_file_encoding="utf-8",
|
| 13 |
extra="ignore",
|
| 14 |
case_sensitive=False,
|
| 15 |
+
populate_by_name=True,
|
| 16 |
)
|
| 17 |
|
| 18 |
+
@model_validator(mode="before")
@classmethod
def _map_max_upload_env_alias(cls, data: Any) -> Any:
    """Accept ``max_upload_size_mb`` as an alias for ``max_file_size_mb``.

    Runs before field validation. Non-dict input is returned untouched.

    Args:
        data: Raw settings mapping gathered from init kwargs / environment.

    Returns:
        A shallow copy of ``data`` in which a non-blank ``max_upload_size_mb``
        fills a missing or blank ``max_file_size_mb``. An explicit, non-blank
        ``max_file_size_mb`` always wins over the alias.
    """
    if not isinstance(data, dict):
        return data

    out = dict(data)
    blank = (None, "")
    # The alias key is never a real field, so always remove it from the payload.
    alias_value = out.pop("max_upload_size_mb", None)

    if out.get("max_file_size_mb") in blank:
        if alias_value not in blank:
            out["max_file_size_mb"] = alias_value
        else:
            # Neither name carries a usable value (e.g. `MAX_UPLOAD_SIZE_MB=` left
            # empty in .env). Drop the blank so the field default applies instead
            # of failing int validation on "" / None.
            out.pop("max_file_size_mb", None)
    return out
|
| 29 |
+
|
| 30 |
+
app_name: str = Field(default="DocuAudit AI", description="FastAPI title and product name")
|
| 31 |
+
app_version: str = Field(default="1.0.0", description="Application version")
|
| 32 |
+
app_description: str = Field(
|
| 33 |
+
default=(
|
| 34 |
+
"Multi-document RAG API for high-stakes consulting environments. "
|
| 35 |
+
"Every answer is grounded in source documents with full audit trails."
|
| 36 |
+
),
|
| 37 |
+
description="OpenAPI /docs description",
|
| 38 |
+
)
|
| 39 |
llm_provider: str = Field(default="ollama", description="Embedding provider")
|
| 40 |
|
| 41 |
openai_api_key: str | None = Field(default=None, description="OpenAI API key")
|
|
|
|
| 60 |
|
| 61 |
chunk_size: int = Field(default=1000, ge=100, le=8000, description="Chunk size for splitting")
|
| 62 |
chunk_overlap: int = Field(default=200, ge=0, le=2000, description="Chunk overlap for splitting")
|
| 63 |
+
top_k_results: int = Field(default=5, ge=1, le=20, description="Default number of chunks to retrieve")
|
| 64 |
|
| 65 |
audit_db_path: str = "./audit.db"
|
| 66 |
jobs_db_path: str = Field(default="./data/jobs.db", description="SQLite path for ingest job tracking")
|
| 67 |
|
| 68 |
+
max_file_size_mb: int = Field(default=50, ge=1, le=200, description="Max upload file size (MB)")
|
| 69 |
max_documents_per_batch: int = Field(default=100, ge=1, le=1000, description="Max documents per batch")
|
| 70 |
|
| 71 |
|
api/main.py
CHANGED
|
@@ -4,13 +4,27 @@ import os
|
|
| 4 |
os.environ.setdefault("ANONYMIZED_TELEMETRY", "FALSE")
|
| 5 |
|
| 6 |
from fastapi import FastAPI
|
|
|
|
| 7 |
|
| 8 |
from api.config import get_settings
|
| 9 |
from storage.audit_store import init_audit_db
|
| 10 |
from storage.job_store import init_jobs_db
|
| 11 |
from .routes import audit, ingest, jobs, query
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
app.include_router(audit.router)
|
| 16 |
app.include_router(ingest.router)
|
|
@@ -25,6 +39,12 @@ async def startup() -> None:
|
|
| 25 |
await init_audit_db(settings.audit_db_path)
|
| 26 |
await init_jobs_db(settings.jobs_db_path)
|
| 27 |
|
|
|
|
| 28 |
@app.get("/health", tags=["Health"])
|
| 29 |
def health() -> dict[str, str]:
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
os.environ.setdefault("ANONYMIZED_TELEMETRY", "FALSE")
|
| 5 |
|
| 6 |
from fastapi import FastAPI
|
| 7 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
|
| 9 |
from api.config import get_settings
|
| 10 |
from storage.audit_store import init_audit_db
|
| 11 |
from storage.job_store import init_jobs_db
|
| 12 |
from .routes import audit, ingest, jobs, query
|
| 13 |
|
| 14 |
+
# Application instance. Title, version and description are read from Settings,
# the same source the /health endpoint reports from.
_settings = get_settings()
app = FastAPI(
    title=_settings.app_name,
    version=_settings.app_version,
    description=_settings.app_description,
)

# CORS: any origin may call the API, but credentialed requests are NOT allowed.
# A wildcard origin must not be combined with allow_credentials=True: the
# middleware would echo back whatever Origin the browser sends, letting any site
# make cookie/authorization-bearing requests. If credentialed browser access is
# needed later, replace "*" with an explicit list of trusted origins first.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
| 28 |
|
| 29 |
app.include_router(audit.router)
|
| 30 |
app.include_router(ingest.router)
|
|
|
|
| 39 |
await init_audit_db(settings.audit_db_path)
|
| 40 |
await init_jobs_db(settings.jobs_db_path)
|
| 41 |
|
| 42 |
+
|
| 43 |
@app.get("/health", tags=["Health"])
def health() -> dict[str, str]:
    """Liveness probe: report ``ok`` plus the configured app name and version."""
    cfg = get_settings()
    payload: dict[str, str] = {"status": "ok"}
    payload["app"] = cfg.app_name
    payload["version"] = cfg.app_version
    return payload
|
api/routes/audit.py
CHANGED
|
@@ -4,42 +4,54 @@ from fastapi import APIRouter, Depends, HTTPException, Query, status
|
|
| 4 |
|
| 5 |
from api.config import get_settings
|
| 6 |
from models.requests import AuditListParams
|
| 7 |
-
from models.responses import
|
| 8 |
from storage.audit_store import get_audit_event, list_audit_events
|
| 9 |
|
| 10 |
|
| 11 |
def _audit_list_params(
|
| 12 |
-
limit: Annotated[int, Query(ge=1, le=100)] =
|
| 13 |
offset: Annotated[int, Query(ge=0)] = 0,
|
|
|
|
|
|
|
|
|
|
| 14 |
) -> AuditListParams:
|
| 15 |
-
return AuditListParams(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
router = APIRouter(prefix="/audit", tags=["audit"])
|
| 19 |
|
| 20 |
|
| 21 |
-
@router.get("/logs", response_model=
|
| 22 |
async def audit_logs(
|
| 23 |
params: Annotated[AuditListParams, Depends(_audit_list_params)],
|
| 24 |
-
) ->
|
| 25 |
settings = get_settings()
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
)
|
| 33 |
|
| 34 |
|
| 35 |
-
@router.get("/logs/{query_id}", response_model=
|
| 36 |
-
async def audit_log_detail(query_id: str) ->
|
| 37 |
settings = get_settings()
|
| 38 |
event = await get_audit_event(settings.audit_db_path, query_id)
|
| 39 |
if event is None:
|
| 40 |
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Audit event not found.")
|
| 41 |
-
return
|
| 42 |
-
status="success",
|
| 43 |
-
message="Audit event retrieved.",
|
| 44 |
-
event=event,
|
| 45 |
-
)
|
|
|
|
| 4 |
|
| 5 |
from api.config import get_settings
|
| 6 |
from models.requests import AuditListParams
|
| 7 |
+
from models.responses import AuditLogDetailResponse, AuditLogsResponse
|
| 8 |
from storage.audit_store import get_audit_event, list_audit_events
|
| 9 |
|
| 10 |
|
| 11 |
def _audit_list_params(
    limit: Annotated[int, Query(ge=1, le=100)] = 50,
    offset: Annotated[int, Query(ge=0)] = 0,
    user_id: Annotated[str | None, Query(max_length=256)] = None,
    from_date: Annotated[str | None, Query(description="ISO 8601 lower bound")] = None,
    to_date: Annotated[str | None, Query(description="ISO 8601 upper bound")] = None,
) -> AuditListParams:
    """Collect the audit-list query-string parameters into one ``AuditListParams``.

    Used as a FastAPI dependency; every argument is forwarded unchanged under
    the same name.
    """
    collected = {
        "limit": limit,
        "offset": offset,
        "user_id": user_id,
        "from_date": from_date,
        "to_date": to_date,
    }
    return AuditListParams(**collected)
|
| 25 |
|
| 26 |
|
| 27 |
router = APIRouter(prefix="/audit", tags=["audit"])
|
| 28 |
|
| 29 |
|
| 30 |
+
@router.get("/logs", response_model=AuditLogsResponse)
async def audit_logs(
    params: Annotated[AuditListParams, Depends(_audit_list_params)],
) -> AuditLogsResponse:
    """Return one page of audit log entries matching the supplied filters.

    The filters and paging values in ``params`` are passed straight to
    ``list_audit_events``; the response echoes the requested ``limit`` and
    ``offset`` alongside the ``logs`` and ``total`` the store returned.
    """
    db_path = get_settings().audit_db_path
    page_size, start = params.limit, params.offset
    entries, match_count = await list_audit_events(
        db_path,
        limit=page_size,
        offset=start,
        user_id=params.user_id,
        from_date=params.from_date,
        to_date=params.to_date,
    )
    return AuditLogsResponse(logs=entries, total=match_count, limit=page_size, offset=start)
|
| 49 |
|
| 50 |
|
| 51 |
+
@router.get("/logs/{query_id}", response_model=AuditLogDetailResponse)
async def audit_log_detail(query_id: str) -> AuditLogDetailResponse:
    """Fetch the stored audit record for ``query_id``.

    Raises:
        HTTPException: 404 when ``get_audit_event`` returns ``None``.
    """
    record = await get_audit_event(get_settings().audit_db_path, query_id)
    if record is not None:
        return record
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Audit event not found.")
|
|
|
|
|
|
|
|
|
|
|
|
api/routes/ingest.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
from pathlib import Path
|
| 2 |
from tempfile import NamedTemporaryFile
|
| 3 |
from typing import Annotated
|
|
@@ -7,14 +8,20 @@ import httpx
|
|
| 7 |
from fastapi import APIRouter, BackgroundTasks, File, Form, HTTPException, UploadFile, status
|
| 8 |
|
| 9 |
from api.config import get_settings
|
| 10 |
-
from models.requests import
|
| 11 |
from models.responses import (
|
|
|
|
| 12 |
IngestCollectionsResponse,
|
| 13 |
IngestDeleteCollectionResponse,
|
| 14 |
IngestUploadResponse,
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
)
|
| 17 |
-
from rag.vector_store import delete_collection, list_collection_names
|
| 18 |
from storage.job_store import create_ingest_job
|
| 19 |
from workers.ingest_worker import run_ingest_job
|
| 20 |
|
|
@@ -86,7 +93,7 @@ async def _download_url_to_temp(url: str, max_bytes: int) -> tuple[str, str]:
|
|
| 86 |
|
| 87 |
timeout = httpx.Timeout(60.0, connect=10.0)
|
| 88 |
limits = httpx.Limits(max_keepalive_connections=5, max_connections=5)
|
| 89 |
-
headers = {"User-Agent": "
|
| 90 |
|
| 91 |
try:
|
| 92 |
async with httpx.AsyncClient(timeout=timeout, limits=limits, follow_redirects=True) as client:
|
|
@@ -132,97 +139,131 @@ async def _download_url_to_temp(url: str, max_bytes: int) -> tuple[str, str]:
|
|
| 132 |
return temp_path, display_name
|
| 133 |
|
| 134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
@router.post("/upload", response_model=IngestUploadResponse)
|
| 136 |
async def upload_endpoint(
|
| 137 |
background_tasks: BackgroundTasks,
|
| 138 |
-
|
| 139 |
collection_name: Annotated[str, Form(min_length=1, max_length=256)] = "default",
|
| 140 |
) -> IngestUploadResponse:
|
| 141 |
settings = get_settings()
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
|
| 146 |
-
|
|
|
|
|
|
|
| 147 |
try:
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
|
| 153 |
job_id = await create_ingest_job(
|
| 154 |
settings.jobs_db_path,
|
| 155 |
-
collection_name=collection_name,
|
| 156 |
-
|
| 157 |
)
|
| 158 |
|
| 159 |
background_tasks.add_task(
|
| 160 |
run_ingest_job,
|
| 161 |
job_id,
|
| 162 |
-
|
| 163 |
-
collection_name,
|
| 164 |
settings.jobs_db_path,
|
| 165 |
settings.chroma_persist_directory,
|
| 166 |
)
|
| 167 |
|
| 168 |
return IngestUploadResponse(
|
| 169 |
-
status="queued",
|
| 170 |
-
message=f"Ingestion job accepted. Poll GET /jobs/{job_id} for status.",
|
| 171 |
job_id=job_id,
|
| 172 |
-
|
|
|
|
|
|
|
|
|
|
| 173 |
)
|
| 174 |
except HTTPException:
|
| 175 |
-
|
| 176 |
-
Path(
|
| 177 |
raise
|
| 178 |
except Exception as exc:
|
| 179 |
-
|
| 180 |
-
Path(
|
| 181 |
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
|
| 182 |
-
finally:
|
| 183 |
-
await file.close()
|
| 184 |
|
| 185 |
|
| 186 |
-
@router.post("/url", response_model=
|
| 187 |
async def ingest_url_endpoint(
|
| 188 |
background_tasks: BackgroundTasks,
|
| 189 |
-
payload:
|
| 190 |
-
) ->
|
| 191 |
settings = get_settings()
|
| 192 |
max_bytes = settings.max_file_size_mb * 1024 * 1024
|
| 193 |
-
|
| 194 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
try:
|
| 196 |
-
|
|
|
|
|
|
|
| 197 |
|
|
|
|
| 198 |
job_id = await create_ingest_job(
|
| 199 |
settings.jobs_db_path,
|
| 200 |
-
collection_name=
|
| 201 |
-
|
| 202 |
)
|
| 203 |
|
| 204 |
background_tasks.add_task(
|
| 205 |
run_ingest_job,
|
| 206 |
job_id,
|
| 207 |
-
|
| 208 |
-
|
| 209 |
settings.jobs_db_path,
|
| 210 |
settings.chroma_persist_directory,
|
| 211 |
)
|
| 212 |
|
| 213 |
-
return
|
| 214 |
-
status="queued",
|
| 215 |
-
message=f"Ingestion job accepted. Poll GET /jobs/{job_id} for status.",
|
| 216 |
job_id=job_id,
|
| 217 |
-
|
|
|
|
|
|
|
| 218 |
)
|
| 219 |
except HTTPException:
|
| 220 |
-
|
| 221 |
-
Path(
|
| 222 |
raise
|
| 223 |
except Exception as exc:
|
| 224 |
-
|
| 225 |
-
Path(
|
| 226 |
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
|
| 227 |
|
| 228 |
|
|
@@ -233,12 +274,18 @@ async def list_collections_endpoint() -> IngestCollectionsResponse:
|
|
| 233 |
names = list_collection_names(settings.chroma_persist_directory)
|
| 234 |
except Exception as exc:
|
| 235 |
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
|
| 236 |
-
items
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
|
| 243 |
|
| 244 |
@router.delete("/collection/{collection_name}", response_model=IngestDeleteCollectionResponse)
|
|
@@ -251,13 +298,12 @@ async def delete_collection_endpoint(collection_name: str) -> IngestDeleteCollec
|
|
| 251 |
existing = list_collection_names(settings.chroma_persist_directory)
|
| 252 |
if name not in existing:
|
| 253 |
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Collection not found.")
|
| 254 |
-
delete_collection(settings.chroma_persist_directory, name)
|
| 255 |
except HTTPException:
|
| 256 |
raise
|
| 257 |
except Exception as exc:
|
| 258 |
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
|
| 259 |
return IngestDeleteCollectionResponse(
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
collection_name=name,
|
| 263 |
)
|
|
|
|
| 1 |
+
from datetime import datetime, timezone
|
| 2 |
from pathlib import Path
|
| 3 |
from tempfile import NamedTemporaryFile
|
| 4 |
from typing import Annotated
|
|
|
|
| 8 |
from fastapi import APIRouter, BackgroundTasks, File, Form, HTTPException, UploadFile, status
|
| 9 |
|
| 10 |
from api.config import get_settings
|
| 11 |
+
from models.requests import URLIngestRequest
|
| 12 |
from models.responses import (
|
| 13 |
+
CollectionItem,
|
| 14 |
IngestCollectionsResponse,
|
| 15 |
IngestDeleteCollectionResponse,
|
| 16 |
IngestUploadResponse,
|
| 17 |
+
UrlIngestResponse,
|
| 18 |
+
)
|
| 19 |
+
from rag.vector_store import (
|
| 20 |
+
collection_created_at,
|
| 21 |
+
collection_document_count,
|
| 22 |
+
delete_collection,
|
| 23 |
+
list_collection_names,
|
| 24 |
)
|
|
|
|
| 25 |
from storage.job_store import create_ingest_job
|
| 26 |
from workers.ingest_worker import run_ingest_job
|
| 27 |
|
|
|
|
| 93 |
|
| 94 |
timeout = httpx.Timeout(60.0, connect=10.0)
|
| 95 |
limits = httpx.Limits(max_keepalive_connections=5, max_connections=5)
|
| 96 |
+
headers = {"User-Agent": "docuaudit-ai/ingest"}
|
| 97 |
|
| 98 |
try:
|
| 99 |
async with httpx.AsyncClient(timeout=timeout, limits=limits, follow_redirects=True) as client:
|
|
|
|
| 139 |
return temp_path, display_name
|
| 140 |
|
| 141 |
|
| 142 |
+
def _parse_created_at(raw: str | None) -> datetime | None:
|
| 143 |
+
if not raw:
|
| 144 |
+
return None
|
| 145 |
+
s = raw.strip()
|
| 146 |
+
if s.endswith("Z"):
|
| 147 |
+
s = s[:-1] + "+00:00"
|
| 148 |
+
try:
|
| 149 |
+
dt = datetime.fromisoformat(s)
|
| 150 |
+
if dt.tzinfo is None:
|
| 151 |
+
return dt.replace(tzinfo=timezone.utc)
|
| 152 |
+
return dt
|
| 153 |
+
except ValueError:
|
| 154 |
+
return None
|
| 155 |
+
|
| 156 |
+
|
| 157 |
@router.post("/upload", response_model=IngestUploadResponse)
async def upload_endpoint(
    background_tasks: BackgroundTasks,
    files: list[UploadFile] = File(..., description="One or more PDF, TXT, or MD files"),
    collection_name: Annotated[str, Form(min_length=1, max_length=256)] = "default",
) -> IngestUploadResponse:
    """Spool uploaded files to temp storage and queue one background ingest job.

    Args:
        background_tasks: FastAPI task queue that will run ``run_ingest_job``.
        files: Uploaded documents; each is checked with ``_validate_file``.
        collection_name: Target collection; surrounding whitespace is removed and
            a blank name falls back to ``"default"``.

    Returns:
        A ``queued`` response carrying the job id and the accepted filenames.

    Raises:
        HTTPException: 400 when no files are sent or the batch exceeds
            ``max_documents_per_batch``; any ``HTTPException`` raised during
            validation is re-raised unchanged; every other failure becomes a 500.
    """
    settings = get_settings()
    if not files:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="At least one file is required.")
    if len(files) > settings.max_documents_per_batch:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Too many files in one request (max {settings.max_documents_per_batch}).",
        )

    max_bytes = settings.max_file_size_mb * 1024 * 1024
    # min_length=1 is checked before stripping, so "   " would otherwise become "".
    collection = collection_name.strip() or "default"
    temp_paths: list[tuple[str, str]] = []
    filenames: list[str] = []
    try:
        for upload in files:
            suffix = _validate_file(upload, max_bytes)
            display_name = (upload.filename or "upload").strip()
            file_bytes = await upload.read()
            with NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
                # Register the path before writing so a failed write is still cleaned up.
                temp_paths.append((tmp.name, display_name))
                tmp.write(file_bytes)
            filenames.append(display_name)

        job_id = await create_ingest_job(
            settings.jobs_db_path,
            collection_name=collection,
            filenames=filenames,
        )

        # On success the temp files are handed to the background job.
        background_tasks.add_task(
            run_ingest_job,
            job_id,
            temp_paths,
            collection,
            settings.jobs_db_path,
            settings.chroma_persist_directory,
        )

        return IngestUploadResponse(
            job_id=job_id,
            status="queued",
            total_files=len(filenames),
            filenames=filenames,
            message=f"Documents queued for processing. Poll /jobs/{job_id} for status.",
        )
    except Exception as exc:
        # Nothing will consume the temp files once the request has failed.
        for path, _ in temp_paths:
            Path(path).unlink(missing_ok=True)
        if isinstance(exc, HTTPException):
            raise
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
    finally:
        # Close every upload on all paths, including files never reached
        # because an earlier one failed validation.
        for upload in files:
            await upload.close()
|
|
|
|
|
|
|
| 216 |
|
| 217 |
|
| 218 |
+
@router.post("/url", response_model=UrlIngestResponse)
async def ingest_url_endpoint(
    background_tasks: BackgroundTasks,
    payload: URLIngestRequest,
) -> UrlIngestResponse:
    """Download the requested URLs to temp files and queue one background ingest job.

    Args:
        background_tasks: FastAPI task queue that will run ``run_ingest_job``.
        payload: Request body with ``urls`` and an optional ``collection_name``;
            a missing or blank name falls back to ``"default"``.

    Returns:
        A ``queued`` response carrying the job id and the number of downloaded URLs.

    Raises:
        HTTPException: 400 when the batch exceeds ``max_documents_per_batch``;
            any ``HTTPException`` raised while downloading is re-raised
            unchanged; every other failure becomes a 500.
    """
    settings = get_settings()
    max_bytes = settings.max_file_size_mb * 1024 * 1024
    url_strings = [str(u).strip() for u in payload.urls]
    if len(url_strings) > settings.max_documents_per_batch:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Too many URLs in one request (max {settings.max_documents_per_batch}).",
        )

    downloaded: list[tuple[str, str]] = []
    try:
        # Strip first, then fall back, so a whitespace-only name cannot become "".
        coll = (payload.collection_name or "").strip() or "default"

        # NOTE(review): downloads run inline, one after another, before the job is
        # created, although the response message says they are queued for download.
        for url_str in url_strings:
            downloaded.append(await _download_url_to_temp(url_str, max_bytes))

        job_id = await create_ingest_job(
            settings.jobs_db_path,
            collection_name=coll,
            filenames=[name for _, name in downloaded],
        )

        # On success the temp files are handed to the background job.
        background_tasks.add_task(
            run_ingest_job,
            job_id,
            downloaded,
            coll,
            settings.jobs_db_path,
            settings.chroma_persist_directory,
        )

        return UrlIngestResponse(
            job_id=job_id,
            status="queued",
            total_urls=len(downloaded),
            message="URLs queued for download and processing.",
        )
    except Exception as exc:
        # Nothing will consume the temp files once the request has failed.
        for path, _ in downloaded:
            Path(path).unlink(missing_ok=True)
        if isinstance(exc, HTTPException):
            raise
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
|
| 268 |
|
| 269 |
|
|
|
|
| 274 |
names = list_collection_names(settings.chroma_persist_directory)
|
| 275 |
except Exception as exc:
|
| 276 |
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
|
| 277 |
+
items: list[CollectionItem] = []
|
| 278 |
+
for n in names:
|
| 279 |
+
cnt = collection_document_count(settings.chroma_persist_directory, n)
|
| 280 |
+
raw_created = collection_created_at(settings.chroma_persist_directory, n)
|
| 281 |
+
items.append(
|
| 282 |
+
CollectionItem(
|
| 283 |
+
name=n,
|
| 284 |
+
document_count=cnt,
|
| 285 |
+
created_at=_parse_created_at(raw_created),
|
| 286 |
+
)
|
| 287 |
+
)
|
| 288 |
+
return IngestCollectionsResponse(collections=items, total=len(items))
|
| 289 |
|
| 290 |
|
| 291 |
@router.delete("/collection/{collection_name}", response_model=IngestDeleteCollectionResponse)
|
|
|
|
| 298 |
existing = list_collection_names(settings.chroma_persist_directory)
|
| 299 |
if name not in existing:
|
| 300 |
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Collection not found.")
|
| 301 |
+
removed = delete_collection(settings.chroma_persist_directory, name)
|
| 302 |
except HTTPException:
|
| 303 |
raise
|
| 304 |
except Exception as exc:
|
| 305 |
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
|
| 306 |
return IngestDeleteCollectionResponse(
|
| 307 |
+
message=f"Collection '{name}' deleted successfully.",
|
| 308 |
+
documents_removed=removed,
|
|
|
|
| 309 |
)
|
api/routes/jobs.py
CHANGED
|
@@ -4,8 +4,8 @@ from fastapi import APIRouter, Depends, HTTPException, Query, status
|
|
| 4 |
|
| 5 |
from api.config import get_settings
|
| 6 |
from models.requests import JobsListParams
|
| 7 |
-
from models.responses import
|
| 8 |
-
from storage.job_store import
|
| 9 |
|
| 10 |
|
| 11 |
def _jobs_list_params(
|
|
@@ -23,27 +23,18 @@ async def list_jobs(
|
|
| 23 |
params: Annotated[JobsListParams, Depends(_jobs_list_params)],
|
| 24 |
) -> JobListResponse:
|
| 25 |
settings = get_settings()
|
| 26 |
-
|
| 27 |
settings.jobs_db_path,
|
| 28 |
limit=params.limit,
|
| 29 |
offset=params.offset,
|
| 30 |
)
|
| 31 |
-
|
| 32 |
-
return JobListResponse(
|
| 33 |
-
status="success",
|
| 34 |
-
message=f"Returned {len(jobs)} job(s).",
|
| 35 |
-
jobs=jobs,
|
| 36 |
-
)
|
| 37 |
|
| 38 |
|
| 39 |
-
@router.get("/jobs/{job_id}", response_model=
|
| 40 |
-
async def get_job(job_id: str) ->
|
| 41 |
settings = get_settings()
|
| 42 |
-
job = await
|
| 43 |
if job is None:
|
| 44 |
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Job not found.")
|
| 45 |
-
return
|
| 46 |
-
status="success",
|
| 47 |
-
message="Job found.",
|
| 48 |
-
job=job,
|
| 49 |
-
)
|
|
|
|
| 4 |
|
| 5 |
from api.config import get_settings
|
| 6 |
from models.requests import JobsListParams
|
| 7 |
+
from models.responses import JobListResponse, JobStatusResponse
|
| 8 |
+
from storage.job_store import get_job_status, list_ingest_jobs
|
| 9 |
|
| 10 |
|
| 11 |
def _jobs_list_params(
|
|
|
|
| 23 |
params: Annotated[JobsListParams, Depends(_jobs_list_params)],
|
| 24 |
) -> JobListResponse:
|
| 25 |
settings = get_settings()
|
| 26 |
+
jobs, total = await list_ingest_jobs(
|
| 27 |
settings.jobs_db_path,
|
| 28 |
limit=params.limit,
|
| 29 |
offset=params.offset,
|
| 30 |
)
|
| 31 |
+
return JobListResponse(jobs=jobs, total=total)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
|
| 34 |
+
@router.get("/jobs/{job_id}", response_model=JobStatusResponse)
|
| 35 |
+
async def get_job(job_id: str) -> JobStatusResponse:
|
| 36 |
settings = get_settings()
|
| 37 |
+
job = await get_job_status(settings.jobs_db_path, job_id)
|
| 38 |
if job is None:
|
| 39 |
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Job not found.")
|
| 40 |
+
return job
|
|
|
|
|
|
|
|
|
|
|
|
api/routes/query.py
CHANGED
|
@@ -1,8 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from fastapi import APIRouter, HTTPException, status
|
| 2 |
|
| 3 |
-
from api.config import get_settings
|
| 4 |
from models.requests import QueryRequest, SummariseRequest
|
| 5 |
-
from models.responses import
|
| 6 |
from rag.embedder import create_embedding_function
|
| 7 |
from rag.retriever import (
|
| 8 |
SUMMARY_RETRIEVAL_QUERY,
|
|
@@ -11,117 +15,153 @@ from rag.retriever import (
|
|
| 11 |
retrieve_chunks,
|
| 12 |
summarise_with_grounding,
|
| 13 |
)
|
| 14 |
-
from rag.vector_store import get_vector_store
|
| 15 |
from storage.audit_store import persist_query_audit
|
| 16 |
|
| 17 |
router = APIRouter(prefix="/query", tags=["query"])
|
| 18 |
|
| 19 |
|
| 20 |
-
def
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
page=chunk.page,
|
| 32 |
-
chunk_index=chunk.chunk_index,
|
| 33 |
-
score=chunk.score,
|
| 34 |
-
excerpt=chunk.text[:280],
|
| 35 |
-
)
|
| 36 |
-
for chunk in chunks
|
| 37 |
-
]
|
| 38 |
-
return QueryResponse(
|
| 39 |
-
status="success",
|
| 40 |
-
message=message,
|
| 41 |
-
answer=answer,
|
| 42 |
-
sources=sources,
|
| 43 |
-
results=results,
|
| 44 |
-
)
|
| 45 |
|
| 46 |
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
| 56 |
)
|
| 57 |
-
|
| 58 |
-
answer = answer_with_grounding(settings, payload.question, chunks)
|
| 59 |
-
except Exception as exc:
|
| 60 |
-
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
|
| 61 |
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
answer=answer,
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
)
|
| 70 |
-
try:
|
| 71 |
-
await persist_query_audit(
|
| 72 |
-
settings.audit_db_path,
|
| 73 |
-
action="query",
|
| 74 |
-
question=payload.question,
|
| 75 |
-
collection_name=payload.collection_name,
|
| 76 |
-
response=response,
|
| 77 |
-
)
|
| 78 |
-
except Exception as exc:
|
| 79 |
-
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
|
| 80 |
return response
|
| 81 |
|
| 82 |
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
|
|
|
|
|
|
| 86 |
retrieval_query = (payload.focus or "").strip() or SUMMARY_RETRIEVAL_QUERY
|
| 87 |
audit_question = payload.focus.strip() if payload.focus and payload.focus.strip() else "Summarise collection"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
try:
|
| 89 |
-
|
| 90 |
-
vector_store = get_vector_store(
|
| 91 |
-
persist_directory=settings.chroma_persist_directory,
|
| 92 |
-
collection_name=payload.collection_name,
|
| 93 |
-
embedding_function=embedding_function,
|
| 94 |
-
)
|
| 95 |
-
chunks = retrieve_chunks(vector_store, retrieval_query, settings.top_k_results)
|
| 96 |
-
answer = summarise_with_grounding(settings, focus=payload.focus, chunks=chunks)
|
| 97 |
except Exception as exc:
|
| 98 |
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
|
| 99 |
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
message=(
|
| 105 |
-
f"Retrieved {len(chunks)} chunks from '{payload.collection_name}' and generated a grounded summary."
|
| 106 |
-
),
|
| 107 |
-
)
|
| 108 |
try:
|
| 109 |
-
await
|
| 110 |
-
settings.audit_db_path,
|
| 111 |
-
action="summarise",
|
| 112 |
-
question=audit_question,
|
| 113 |
-
collection_name=payload.collection_name,
|
| 114 |
-
response=response,
|
| 115 |
-
)
|
| 116 |
except Exception as exc:
|
| 117 |
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
|
| 118 |
-
return response
|
| 119 |
|
| 120 |
|
| 121 |
legacy_query_router = APIRouter(tags=["query"])
|
| 122 |
|
| 123 |
|
| 124 |
-
@legacy_query_router.post("/query", response_model=
|
| 125 |
-
async def query_post_compat(payload: QueryRequest) ->
|
| 126 |
"""Same behavior as POST /query/ask; kept for older clients and docs that used POST /query."""
|
| 127 |
return await ask_endpoint(payload)
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
from datetime import datetime, timezone
|
| 3 |
+
from uuid import uuid4
|
| 4 |
+
|
| 5 |
from fastapi import APIRouter, HTTPException, status
|
| 6 |
|
| 7 |
+
from api.config import Settings, get_settings
|
| 8 |
from models.requests import QueryRequest, SummariseRequest
|
| 9 |
+
from models.responses import AskQueryResponse, SourceCitation, SummariseQueryResponse
|
| 10 |
from rag.embedder import create_embedding_function
|
| 11 |
from rag.retriever import (
|
| 12 |
SUMMARY_RETRIEVAL_QUERY,
|
|
|
|
| 15 |
retrieve_chunks,
|
| 16 |
summarise_with_grounding,
|
| 17 |
)
|
| 18 |
+
from rag.vector_store import collection_document_count, get_vector_store
|
| 19 |
from storage.audit_store import persist_query_audit
|
| 20 |
|
| 21 |
router = APIRouter(prefix="/query", tags=["query"])
|
| 22 |
|
| 23 |
|
| 24 |
+
def _model_used_label(settings: Settings) -> str:
|
| 25 |
+
provider = settings.llm_provider.lower()
|
| 26 |
+
if provider == "openai":
|
| 27 |
+
return settings.openai_model
|
| 28 |
+
if provider == "ollama":
|
| 29 |
+
return settings.ollama_chat_model
|
| 30 |
+
if provider == "anthropic":
|
| 31 |
+
return settings.anthropic_model
|
| 32 |
+
if provider == "huggingface":
|
| 33 |
+
return settings.huggingface_model
|
| 34 |
+
return f"{provider}:unknown"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
|
| 37 |
+
def _chunks_to_citations(chunks: list[RetrievedChunk]) -> list[SourceCitation]:
|
| 38 |
+
citations: list[SourceCitation] = []
|
| 39 |
+
for chunk in chunks:
|
| 40 |
+
page = chunk.page if chunk.page is not None else 0
|
| 41 |
+
score = float(chunk.score) if chunk.score is not None else 0.0
|
| 42 |
+
citations.append(
|
| 43 |
+
SourceCitation(
|
| 44 |
+
document_name=chunk.source or "unknown",
|
| 45 |
+
page_number=page,
|
| 46 |
+
chunk_text=chunk.text,
|
| 47 |
+
relevance_score=score,
|
| 48 |
+
)
|
| 49 |
)
|
| 50 |
+
return citations
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
+
|
| 53 |
+
async def _run_ask(
|
| 54 |
+
settings: Settings,
|
| 55 |
+
payload: QueryRequest,
|
| 56 |
+
) -> AskQueryResponse:
|
| 57 |
+
top_k = payload.top_k
|
| 58 |
+
t0 = time.perf_counter()
|
| 59 |
+
embedding_function = create_embedding_function()
|
| 60 |
+
vector_store = get_vector_store(
|
| 61 |
+
persist_directory=settings.chroma_persist_directory,
|
| 62 |
+
collection_name=payload.collection_name or "default",
|
| 63 |
+
embedding_function=embedding_function,
|
| 64 |
+
)
|
| 65 |
+
chunks = retrieve_chunks(vector_store, payload.question, top_k)
|
| 66 |
+
answer, tokens_used = answer_with_grounding(settings, payload.question, chunks)
|
| 67 |
+
elapsed_ms = int((time.perf_counter() - t0) * 1000)
|
| 68 |
+
citations = _chunks_to_citations(chunks)
|
| 69 |
+
query_id = str(uuid4())
|
| 70 |
+
ts = datetime.now(timezone.utc)
|
| 71 |
+
response = AskQueryResponse(
|
| 72 |
+
query_id=query_id,
|
| 73 |
+
question=payload.question,
|
| 74 |
answer=answer,
|
| 75 |
+
sources=citations,
|
| 76 |
+
model_used=_model_used_label(settings),
|
| 77 |
+
tokens_used=tokens_used,
|
| 78 |
+
response_time_ms=elapsed_ms,
|
| 79 |
+
timestamp=ts,
|
| 80 |
+
)
|
| 81 |
+
await persist_query_audit(
|
| 82 |
+
settings.audit_db_path,
|
| 83 |
+
query_id=query_id,
|
| 84 |
+
action="query",
|
| 85 |
+
user_id=payload.user_id,
|
| 86 |
+
question=payload.question,
|
| 87 |
+
collection_name=payload.collection_name or "default",
|
| 88 |
+
answer=answer,
|
| 89 |
+
sources=citations,
|
| 90 |
+
model_used=response.model_used,
|
| 91 |
+
tokens_used=tokens_used,
|
| 92 |
+
response_time_ms=elapsed_ms,
|
| 93 |
+
kind="ask",
|
| 94 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
return response
|
| 96 |
|
| 97 |
|
| 98 |
+
async def _run_summarise(
|
| 99 |
+
settings: Settings,
|
| 100 |
+
payload: SummariseRequest,
|
| 101 |
+
) -> SummariseQueryResponse:
|
| 102 |
+
top_k = settings.top_k_results
|
| 103 |
retrieval_query = (payload.focus or "").strip() or SUMMARY_RETRIEVAL_QUERY
|
| 104 |
audit_question = payload.focus.strip() if payload.focus and payload.focus.strip() else "Summarise collection"
|
| 105 |
+
t0 = time.perf_counter()
|
| 106 |
+
embedding_function = create_embedding_function()
|
| 107 |
+
vector_store = get_vector_store(
|
| 108 |
+
persist_directory=settings.chroma_persist_directory,
|
| 109 |
+
collection_name=payload.collection_name,
|
| 110 |
+
embedding_function=embedding_function,
|
| 111 |
+
)
|
| 112 |
+
chunks = retrieve_chunks(vector_store, retrieval_query, top_k)
|
| 113 |
+
summary, tokens_used = summarise_with_grounding(settings, focus=payload.focus, chunks=chunks)
|
| 114 |
+
elapsed_ms = int((time.perf_counter() - t0) * 1000)
|
| 115 |
+
citations = _chunks_to_citations(chunks)
|
| 116 |
+
doc_count = collection_document_count(settings.chroma_persist_directory, payload.collection_name)
|
| 117 |
+
query_id = str(uuid4())
|
| 118 |
+
ts = datetime.now(timezone.utc)
|
| 119 |
+
response = SummariseQueryResponse(
|
| 120 |
+
query_id=query_id,
|
| 121 |
+
summary=summary,
|
| 122 |
+
document_count=doc_count,
|
| 123 |
+
sources=citations,
|
| 124 |
+
timestamp=ts,
|
| 125 |
+
)
|
| 126 |
+
await persist_query_audit(
|
| 127 |
+
settings.audit_db_path,
|
| 128 |
+
query_id=query_id,
|
| 129 |
+
action="summarise",
|
| 130 |
+
user_id=payload.user_id,
|
| 131 |
+
question=audit_question,
|
| 132 |
+
collection_name=payload.collection_name,
|
| 133 |
+
answer=summary,
|
| 134 |
+
sources=citations,
|
| 135 |
+
model_used=_model_used_label(settings),
|
| 136 |
+
tokens_used=tokens_used,
|
| 137 |
+
response_time_ms=elapsed_ms,
|
| 138 |
+
kind="summarise",
|
| 139 |
+
)
|
| 140 |
+
return response
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
@router.post("/ask", response_model=AskQueryResponse)
|
| 144 |
+
async def ask_endpoint(payload: QueryRequest) -> AskQueryResponse:
|
| 145 |
+
settings = get_settings()
|
| 146 |
try:
|
| 147 |
+
return await _run_ask(settings, payload)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
except Exception as exc:
|
| 149 |
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
|
| 150 |
|
| 151 |
+
|
| 152 |
+
@router.post("/summarise", response_model=SummariseQueryResponse)
|
| 153 |
+
async def summarise_endpoint(payload: SummariseRequest) -> SummariseQueryResponse:
|
| 154 |
+
settings = get_settings()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
try:
|
| 156 |
+
return await _run_summarise(settings, payload)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
except Exception as exc:
|
| 158 |
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
|
|
|
|
| 159 |
|
| 160 |
|
| 161 |
legacy_query_router = APIRouter(tags=["query"])
|
| 162 |
|
| 163 |
|
| 164 |
+
@legacy_query_router.post("/query", response_model=AskQueryResponse)
|
| 165 |
+
async def query_post_compat(payload: QueryRequest) -> AskQueryResponse:
|
| 166 |
"""Same behavior as POST /query/ask; kept for older clients and docs that used POST /query."""
|
| 167 |
return await ask_endpoint(payload)
|
models/requests.py
CHANGED
|
@@ -1,16 +1,20 @@
|
|
|
|
|
|
|
|
| 1 |
from pydantic import BaseModel, ConfigDict, Field, HttpUrl
|
| 2 |
|
| 3 |
|
| 4 |
class QueryRequest(BaseModel):
|
| 5 |
model_config = ConfigDict(extra="forbid")
|
| 6 |
|
| 7 |
-
question: str = Field(min_length=
|
| 8 |
-
collection_name: str = Field(
|
| 9 |
default="default",
|
| 10 |
min_length=1,
|
| 11 |
max_length=256,
|
| 12 |
-
description="
|
| 13 |
)
|
|
|
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
class SummariseRequest(BaseModel):
|
|
@@ -20,37 +24,50 @@ class SummariseRequest(BaseModel):
|
|
| 20 |
default="default",
|
| 21 |
min_length=1,
|
| 22 |
max_length=256,
|
| 23 |
-
description="Chroma collection to summarise
|
| 24 |
)
|
| 25 |
focus: str | None = Field(
|
| 26 |
default=None,
|
| 27 |
max_length=8000,
|
| 28 |
-
description="Optional angle or scope for retrieval and the summary
|
| 29 |
)
|
|
|
|
| 30 |
|
| 31 |
|
| 32 |
-
class
|
| 33 |
model_config = ConfigDict(extra="forbid")
|
| 34 |
|
| 35 |
-
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
default="default",
|
| 38 |
min_length=1,
|
| 39 |
max_length=256,
|
| 40 |
description="Target Chroma collection name",
|
| 41 |
)
|
| 42 |
|
| 43 |
-
class IngestUploadRequest(BaseModel):
|
| 44 |
-
model_config = ConfigDict(extra="forbid")
|
| 45 |
-
collection_name: str = Field(default="default", min_length=1, max_length=256, description="The name of the collection to upload the document to")
|
| 46 |
-
filename: str = Field(min_length=1, max_length=1024, description="The name of the file to upload")
|
| 47 |
|
| 48 |
class JobsListParams(BaseModel):
|
| 49 |
model_config = ConfigDict(extra="forbid")
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
| 52 |
|
| 53 |
class AuditListParams(BaseModel):
|
| 54 |
model_config = ConfigDict(extra="forbid")
|
| 55 |
-
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
from pydantic import BaseModel, ConfigDict, Field, HttpUrl
|
| 4 |
|
| 5 |
|
| 6 |
class QueryRequest(BaseModel):
|
| 7 |
model_config = ConfigDict(extra="forbid")
|
| 8 |
|
| 9 |
+
question: str = Field(min_length=5, max_length=2000, description="Natural language question")
|
| 10 |
+
collection_name: Optional[str] = Field(
|
| 11 |
default="default",
|
| 12 |
min_length=1,
|
| 13 |
max_length=256,
|
| 14 |
+
description="Chroma collection to search",
|
| 15 |
)
|
| 16 |
+
top_k: int = Field(default=5, ge=1, le=20, description="Number of chunks to retrieve")
|
| 17 |
+
user_id: str = Field(default="anonymous", max_length=256, description="Caller id for audit filtering")
|
| 18 |
|
| 19 |
|
| 20 |
class SummariseRequest(BaseModel):
|
|
|
|
| 24 |
default="default",
|
| 25 |
min_length=1,
|
| 26 |
max_length=256,
|
| 27 |
+
description="Chroma collection to summarise",
|
| 28 |
)
|
| 29 |
focus: str | None = Field(
|
| 30 |
default=None,
|
| 31 |
max_length=8000,
|
| 32 |
+
description="Optional angle or scope for retrieval and the summary",
|
| 33 |
)
|
| 34 |
+
user_id: str = Field(default="anonymous", max_length=256, description="Caller id for audit filtering")
|
| 35 |
|
| 36 |
|
| 37 |
+
class URLIngestRequest(BaseModel):
|
| 38 |
model_config = ConfigDict(extra="forbid")
|
| 39 |
|
| 40 |
+
urls: list[HttpUrl] = Field(
|
| 41 |
+
min_length=1,
|
| 42 |
+
max_length=100,
|
| 43 |
+
description="One or more HTTP(S) URLs to PDF, TXT, or Markdown documents",
|
| 44 |
+
)
|
| 45 |
+
collection_name: Optional[str] = Field(
|
| 46 |
default="default",
|
| 47 |
min_length=1,
|
| 48 |
max_length=256,
|
| 49 |
description="Target Chroma collection name",
|
| 50 |
)
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
class JobsListParams(BaseModel):
|
| 54 |
model_config = ConfigDict(extra="forbid")
|
| 55 |
+
|
| 56 |
+
limit: int = Field(default=10, ge=1, le=100, description="Max jobs to return")
|
| 57 |
+
offset: int = Field(default=0, ge=0, description="Offset for pagination")
|
| 58 |
+
|
| 59 |
|
| 60 |
class AuditListParams(BaseModel):
|
| 61 |
model_config = ConfigDict(extra="forbid")
|
| 62 |
+
|
| 63 |
+
limit: int = Field(default=50, ge=1, le=100, description="Max log entries to return")
|
| 64 |
+
offset: int = Field(default=0, ge=0, description="Offset for pagination")
|
| 65 |
+
user_id: str | None = Field(default=None, max_length=256, description="Filter by user id")
|
| 66 |
+
from_date: str | None = Field(
|
| 67 |
+
default=None,
|
| 68 |
+
description="ISO 8601 datetime lower bound (inclusive) on timestamp",
|
| 69 |
+
)
|
| 70 |
+
to_date: str | None = Field(
|
| 71 |
+
default=None,
|
| 72 |
+
description="ISO 8601 datetime upper bound (inclusive) on timestamp",
|
| 73 |
+
)
|
models/responses.py
CHANGED
|
@@ -1,102 +1,130 @@
|
|
|
|
|
|
|
|
| 1 |
from pydantic import BaseModel, Field
|
| 2 |
|
| 3 |
|
| 4 |
-
|
| 5 |
-
text: str | None = None
|
| 6 |
-
score: float | None = None
|
| 7 |
|
| 8 |
-
class QuerySourceItem(BaseModel):
|
| 9 |
-
source: str
|
| 10 |
-
page: int | None = None
|
| 11 |
-
chunk_index: int | None = None
|
| 12 |
-
score: float | None = None
|
| 13 |
-
excerpt: str | None = None
|
| 14 |
|
| 15 |
-
class
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
class IngestUploadResponse(BaseModel):
|
|
|
|
| 23 |
status: str
|
|
|
|
|
|
|
| 24 |
message: str
|
|
|
|
|
|
|
|
|
|
| 25 |
job_id: str
|
| 26 |
-
|
|
|
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
class CollectionItem(BaseModel):
|
| 30 |
name: str
|
|
|
|
|
|
|
| 31 |
|
| 32 |
|
| 33 |
class IngestCollectionsResponse(BaseModel):
|
| 34 |
-
status: str
|
| 35 |
-
message: str
|
| 36 |
collections: list[CollectionItem] = Field(default_factory=list)
|
|
|
|
| 37 |
|
| 38 |
|
| 39 |
class IngestDeleteCollectionResponse(BaseModel):
|
| 40 |
-
status: str
|
| 41 |
message: str
|
| 42 |
-
|
| 43 |
|
| 44 |
-
class JobSummary(BaseModel):
|
| 45 |
-
job_id: str
|
| 46 |
-
status: str
|
| 47 |
-
collection_name: str | None = None
|
| 48 |
-
filename: str | None = None
|
| 49 |
-
created_at: str | None = None
|
| 50 |
|
| 51 |
-
|
| 52 |
-
status: str
|
| 53 |
-
message: str
|
| 54 |
-
jobs: list[JobSummary] = Field(default_factory=list)
|
| 55 |
|
| 56 |
|
| 57 |
-
class
|
| 58 |
job_id: str
|
| 59 |
status: str
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
|
|
|
| 66 |
|
| 67 |
|
| 68 |
-
class
|
|
|
|
| 69 |
status: str
|
| 70 |
-
|
| 71 |
-
|
| 72 |
|
| 73 |
-
class AuditEvent(BaseModel):
|
| 74 |
-
event_id: str
|
| 75 |
-
action: str
|
| 76 |
-
question: str | None = None
|
| 77 |
-
collection_name: str | None = None
|
| 78 |
-
created_at: str | None = None
|
| 79 |
|
| 80 |
-
class
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
| 84 |
|
| 85 |
|
| 86 |
-
class
|
| 87 |
-
|
| 88 |
-
|
| 89 |
question: str
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
sources: list[QuerySourceItem] = Field(default_factory=list)
|
| 95 |
-
results: list[QueryResultItem] = Field(default_factory=list)
|
| 96 |
-
created_at: str
|
| 97 |
|
| 98 |
|
| 99 |
-
class
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime
|
| 2 |
+
|
| 3 |
from pydantic import BaseModel, Field
|
| 4 |
|
| 5 |
|
| 6 |
+
# --- Shared citations (spec-shaped) ---
|
|
|
|
|
|
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
+
class SourceCitation(BaseModel):
|
| 10 |
+
document_name: str
|
| 11 |
+
page_number: int
|
| 12 |
+
chunk_text: str
|
| 13 |
+
relevance_score: float
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# --- Query: ask ---
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class AskQueryResponse(BaseModel):
|
| 20 |
+
query_id: str
|
| 21 |
+
question: str
|
| 22 |
+
answer: str
|
| 23 |
+
sources: list[SourceCitation] = Field(default_factory=list)
|
| 24 |
+
model_used: str
|
| 25 |
+
tokens_used: int
|
| 26 |
+
response_time_ms: int
|
| 27 |
+
timestamp: datetime
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# --- Query: summarise ---
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class SummariseQueryResponse(BaseModel):
|
| 34 |
+
query_id: str
|
| 35 |
+
summary: str
|
| 36 |
+
document_count: int
|
| 37 |
+
sources: list[SourceCitation] = Field(default_factory=list)
|
| 38 |
+
timestamp: datetime
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# --- Ingest ---
|
| 42 |
+
|
| 43 |
|
| 44 |
class IngestUploadResponse(BaseModel):
|
| 45 |
+
job_id: str
|
| 46 |
status: str
|
| 47 |
+
total_files: int
|
| 48 |
+
filenames: list[str]
|
| 49 |
message: str
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class UrlIngestResponse(BaseModel):
|
| 53 |
job_id: str
|
| 54 |
+
status: str
|
| 55 |
+
total_urls: int
|
| 56 |
+
message: str
|
| 57 |
|
| 58 |
|
| 59 |
class CollectionItem(BaseModel):
|
| 60 |
name: str
|
| 61 |
+
document_count: int
|
| 62 |
+
created_at: datetime | None = None
|
| 63 |
|
| 64 |
|
| 65 |
class IngestCollectionsResponse(BaseModel):
|
|
|
|
|
|
|
| 66 |
collections: list[CollectionItem] = Field(default_factory=list)
|
| 67 |
+
total: int
|
| 68 |
|
| 69 |
|
| 70 |
class IngestDeleteCollectionResponse(BaseModel):
|
|
|
|
| 71 |
message: str
|
| 72 |
+
documents_removed: int
|
| 73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
+
# --- Jobs ---
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
|
| 78 |
+
class JobStatusResponse(BaseModel):
|
| 79 |
job_id: str
|
| 80 |
status: str
|
| 81 |
+
total_files: int
|
| 82 |
+
processed_files: int
|
| 83 |
+
failed_files: int
|
| 84 |
+
progress_percent: int
|
| 85 |
+
started_at: datetime | None
|
| 86 |
+
completed_at: datetime | None
|
| 87 |
+
errors: list[str] = Field(default_factory=list)
|
| 88 |
|
| 89 |
|
| 90 |
+
class JobListItem(BaseModel):
|
| 91 |
+
job_id: str
|
| 92 |
status: str
|
| 93 |
+
total_files: int
|
| 94 |
+
completed_at: datetime | None = None
|
| 95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
|
| 97 |
+
class JobListResponse(BaseModel):
|
| 98 |
+
jobs: list[JobListItem] = Field(default_factory=list)
|
| 99 |
+
total: int
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# --- Audit ---
|
| 103 |
|
| 104 |
|
| 105 |
+
class AuditLogEntry(BaseModel):
|
| 106 |
+
query_id: str
|
| 107 |
+
user_id: str
|
| 108 |
question: str
|
| 109 |
+
answer_summary: str
|
| 110 |
+
sources_count: int
|
| 111 |
+
model_used: str | None
|
| 112 |
+
timestamp: datetime
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
|
| 115 |
+
class AuditLogsResponse(BaseModel):
|
| 116 |
+
logs: list[AuditLogEntry] = Field(default_factory=list)
|
| 117 |
+
total: int
|
| 118 |
+
limit: int
|
| 119 |
+
offset: int
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class AuditLogDetailResponse(BaseModel):
|
| 123 |
+
query_id: str
|
| 124 |
+
user_id: str
|
| 125 |
+
question: str
|
| 126 |
+
full_answer: str
|
| 127 |
+
sources: list[SourceCitation] = Field(default_factory=list)
|
| 128 |
+
model_used: str | None
|
| 129 |
+
tokens_used: int | None
|
| 130 |
+
timestamp: datetime
|
pytest.ini
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[pytest]
|
| 2 |
+
testpaths = tests
|
| 3 |
+
python_files = test_*.py
|
rag/retriever.py
CHANGED
|
@@ -27,9 +27,27 @@ except ImportError:
|
|
| 27 |
|
| 28 |
from api.config import Settings
|
| 29 |
|
| 30 |
-
NO_MATCH_ANSWER = "I
|
| 31 |
MIN_RELEVANCE_SCORE = 0.15
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
@dataclass
|
| 35 |
class RetrievedChunk:
|
|
@@ -62,31 +80,19 @@ SUMMARY_RETRIEVAL_QUERY = (
|
|
| 62 |
)
|
| 63 |
|
| 64 |
|
| 65 |
-
def answer_with_grounding(settings: Settings, question: str, chunks: list[RetrievedChunk]) -> str:
|
| 66 |
ranked_chunks = [chunk for chunk in chunks if chunk.score is None or chunk.score >= MIN_RELEVANCE_SCORE]
|
| 67 |
if not ranked_chunks:
|
| 68 |
-
return NO_MATCH_ANSWER
|
| 69 |
|
| 70 |
llm = _create_chat_model(settings)
|
| 71 |
prompt_context = _format_context(ranked_chunks)
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
content=(
|
| 75 |
-
"You answer questions using only the provided context from uploaded documents. "
|
| 76 |
-
"If the answer is not in context, say you do not know."
|
| 77 |
-
)
|
| 78 |
-
),
|
| 79 |
-
HumanMessage(
|
| 80 |
-
content=(
|
| 81 |
-
f"Question: {question}\n\n"
|
| 82 |
-
f"Context:\n{prompt_context}\n\n"
|
| 83 |
-
"Return a concise grounded answer."
|
| 84 |
-
)
|
| 85 |
-
),
|
| 86 |
-
]
|
| 87 |
response = llm.invoke(messages)
|
| 88 |
answer = _extract_message_text(response).strip()
|
| 89 |
-
|
|
|
|
| 90 |
|
| 91 |
|
| 92 |
def summarise_with_grounding(
|
|
@@ -94,10 +100,10 @@ def summarise_with_grounding(
|
|
| 94 |
*,
|
| 95 |
focus: str | None,
|
| 96 |
chunks: list[RetrievedChunk],
|
| 97 |
-
) -> str:
|
| 98 |
ranked_chunks = [chunk for chunk in chunks if chunk.score is None or chunk.score >= MIN_RELEVANCE_SCORE]
|
| 99 |
if not ranked_chunks:
|
| 100 |
-
return NO_MATCH_ANSWER
|
| 101 |
|
| 102 |
llm = _create_chat_model(settings)
|
| 103 |
prompt_context = _format_context(ranked_chunks)
|
|
@@ -123,7 +129,8 @@ def summarise_with_grounding(
|
|
| 123 |
]
|
| 124 |
response = llm.invoke(messages)
|
| 125 |
answer = _extract_message_text(response).strip()
|
| 126 |
-
|
|
|
|
| 127 |
|
| 128 |
|
| 129 |
def _create_chat_model(settings: Settings) -> BaseChatModel:
|
|
@@ -186,6 +193,25 @@ def _to_int_or_none(value: object) -> int | None:
|
|
| 186 |
return None
|
| 187 |
|
| 188 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
def _extract_message_text(response: object) -> str:
|
| 190 |
content = getattr(response, "content", "")
|
| 191 |
if isinstance(content, str):
|
|
|
|
| 27 |
|
| 28 |
from api.config import Settings
|
| 29 |
|
| 30 |
+
NO_MATCH_ANSWER = "I cannot find this information in the uploaded documents."
|
| 31 |
MIN_RELEVANCE_SCORE = 0.15
|
| 32 |
|
| 33 |
+
# Verbatim from DOCUAUDIT_AI_REQUIREMENTS.md (placeholders filled at runtime).
|
| 34 |
+
DOCUAUDIT_ASK_TEMPLATE = """You are DocuAudit AI, an expert document analyst for consulting environments.
|
| 35 |
+
|
| 36 |
+
RULES:
|
| 37 |
+
1. Answer ONLY based on the provided document excerpts below.
|
| 38 |
+
2. If the answer is not in the documents, say: "I cannot find this information in the uploaded documents."
|
| 39 |
+
3. ALWAYS cite your sources: mention the document name and page number for every claim.
|
| 40 |
+
4. Be precise and professional. This is a high-stakes consulting environment.
|
| 41 |
+
5. Do not speculate or add information not present in the documents.
|
| 42 |
+
|
| 43 |
+
DOCUMENT EXCERPTS:
|
| 44 |
+
{context}
|
| 45 |
+
|
| 46 |
+
QUESTION: {question}
|
| 47 |
+
|
| 48 |
+
ANSWER (with source citations):
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
|
| 52 |
@dataclass
|
| 53 |
class RetrievedChunk:
|
|
|
|
| 80 |
)
|
| 81 |
|
| 82 |
|
| 83 |
+
def answer_with_grounding(settings: Settings, question: str, chunks: list[RetrievedChunk]) -> tuple[str, int]:
    """Answer ``question`` from the retrieved chunks and report token usage.

    Chunks whose score is below ``MIN_RELEVANCE_SCORE`` are dropped; chunks
    without a score are kept. If nothing survives the filter the model is not
    called and ``(NO_MATCH_ANSWER, 0)`` is returned.

    Returns:
        ``(answer, tokens)``. ``answer`` falls back to ``NO_MATCH_ANSWER`` when
        the model reply is empty after stripping; ``tokens`` is the value
        produced by ``_extract_usage_tokens`` for the model response.
    """
    relevant = [c for c in chunks if c.score is None or c.score >= MIN_RELEVANCE_SCORE]
    if not relevant:
        return NO_MATCH_ANSWER, 0

    chat_model = _create_chat_model(settings)
    # The whole prompt (rules, excerpts, question) is sent as one human message.
    prompt = DOCUAUDIT_ASK_TEMPLATE.format(context=_format_context(relevant), question=question)
    reply = chat_model.invoke([HumanMessage(content=prompt)])
    text = _extract_message_text(reply).strip()
    return (text or NO_MATCH_ANSWER), _extract_usage_tokens(reply)
|
| 96 |
|
| 97 |
|
| 98 |
def summarise_with_grounding(
|
|
|
|
| 100 |
*,
|
| 101 |
focus: str | None,
|
| 102 |
chunks: list[RetrievedChunk],
|
| 103 |
+
) -> tuple[str, int]:
|
| 104 |
ranked_chunks = [chunk for chunk in chunks if chunk.score is None or chunk.score >= MIN_RELEVANCE_SCORE]
|
| 105 |
if not ranked_chunks:
|
| 106 |
+
return NO_MATCH_ANSWER, 0
|
| 107 |
|
| 108 |
llm = _create_chat_model(settings)
|
| 109 |
prompt_context = _format_context(ranked_chunks)
|
|
|
|
| 129 |
]
|
| 130 |
response = llm.invoke(messages)
|
| 131 |
answer = _extract_message_text(response).strip()
|
| 132 |
+
tokens = _extract_usage_tokens(response)
|
| 133 |
+
return (answer or NO_MATCH_ANSWER), tokens
|
| 134 |
|
| 135 |
|
| 136 |
def _create_chat_model(settings: Settings) -> BaseChatModel:
|
|
|
|
| 193 |
return None
|
| 194 |
|
| 195 |
|
| 196 |
+
def _extract_usage_tokens(response: object) -> int:
|
| 197 |
+
um = getattr(response, "usage_metadata", None)
|
| 198 |
+
if isinstance(um, dict):
|
| 199 |
+
total = um.get("total_tokens")
|
| 200 |
+
if total is not None:
|
| 201 |
+
return int(total)
|
| 202 |
+
inp = int(um.get("input_tokens", 0) or 0)
|
| 203 |
+
out = int(um.get("output_tokens", 0) or 0)
|
| 204 |
+
return inp + out
|
| 205 |
+
rm = getattr(response, "response_metadata", None) or {}
|
| 206 |
+
if isinstance(rm, dict):
|
| 207 |
+
tu = rm.get("token_usage")
|
| 208 |
+
if isinstance(tu, dict):
|
| 209 |
+
if tu.get("total_tokens") is not None:
|
| 210 |
+
return int(tu["total_tokens"])
|
| 211 |
+
return int(tu.get("prompt_tokens", 0) or 0) + int(tu.get("completion_tokens", 0) or 0)
|
| 212 |
+
return 0
|
| 213 |
+
|
| 214 |
+
|
| 215 |
def _extract_message_text(response: object) -> str:
|
| 216 |
content = getattr(response, "content", "")
|
| 217 |
if isinstance(content, str):
|
rag/vector_store.py
CHANGED
|
@@ -36,8 +36,42 @@ def list_collection_names(persist_directory: str) -> list[str]:
|
|
| 36 |
return sorted(c.name for c in client.list_collections())
|
| 37 |
|
| 38 |
|
| 39 |
-
def delete_collection(persist_directory: str, collection_name: str) ->
|
|
|
|
| 40 |
Path(persist_directory).mkdir(parents=True, exist_ok=True)
|
| 41 |
client = chromadb.PersistentClient(path=persist_directory, settings=_CHROMA_CLIENT_SETTINGS)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
client.delete_collection(name=collection_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
|
|
|
| 36 |
return sorted(c.name for c in client.list_collections())
|
| 37 |
|
| 38 |
|
| 39 |
+
def delete_collection(persist_directory: str, collection_name: str) -> int:
    """Delete a collection and return how many records it held (best effort).

    The count is read before the deletion. If the count cannot be obtained
    for any reason, 0 is reported. The deletion itself is not guarded, so any
    error it raises propagates to the caller.
    """
    Path(persist_directory).mkdir(parents=True, exist_ok=True)
    chroma = chromadb.PersistentClient(path=persist_directory, settings=_CHROMA_CLIENT_SETTINGS)
    try:
        record_total = int(chroma.get_collection(name=collection_name).count())
    except Exception:
        record_total = 0
    chroma.delete_collection(name=collection_name)
    return record_total
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def collection_document_count(persist_directory: str, collection_name: str) -> int:
    """Return the record count of the named collection, or 0 if it cannot be read.

    Any exception raised while fetching or counting the collection is treated
    as "no records". The persist directory is created if it does not exist.
    """
    Path(persist_directory).mkdir(parents=True, exist_ok=True)
    chroma = chromadb.PersistentClient(path=persist_directory, settings=_CHROMA_CLIENT_SETTINGS)
    try:
        size = int(chroma.get_collection(name=collection_name).count())
    except Exception:
        size = 0
    return size
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def collection_created_at(persist_directory: str, collection_name: str) -> str | None:
    """Return the collection's ``created_at`` metadata as text, if present.

    Looks at the collection's ``metadata`` mapping for ``created_at`` and then
    ``created``. Returns ``None`` when neither is set, when the metadata is
    not a dict, or when anything goes wrong while reading the collection.
    """
    Path(persist_directory).mkdir(parents=True, exist_ok=True)
    chroma = chromadb.PersistentClient(path=persist_directory, settings=_CHROMA_CLIENT_SETTINGS)
    try:
        collection = chroma.get_collection(name=collection_name)
        metadata = getattr(collection, "metadata", None) or {}
        if not isinstance(metadata, dict):
            return None
        stamp = metadata.get("created_at") or metadata.get("created")
        return None if stamp is None else str(stamp)
    except Exception:
        return None
|
| 77 |
|
storage/audit_store.py
CHANGED
|
@@ -1,11 +1,54 @@
|
|
| 1 |
import json
|
|
|
|
| 2 |
from pathlib import Path
|
| 3 |
from typing import Any
|
| 4 |
from uuid import uuid4
|
| 5 |
|
| 6 |
import aiosqlite
|
| 7 |
|
| 8 |
-
from models.responses import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
async def init_audit_db(db_path: str) -> None:
|
|
@@ -24,80 +67,219 @@ async def init_audit_db(db_path: str) -> None:
|
|
| 24 |
message TEXT NOT NULL,
|
| 25 |
sources_json TEXT NOT NULL,
|
| 26 |
results_json TEXT NOT NULL,
|
| 27 |
-
created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
)
|
| 29 |
"""
|
| 30 |
)
|
| 31 |
await conn.commit()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
|
| 34 |
async def persist_query_audit(
|
| 35 |
db_path: str,
|
| 36 |
*,
|
|
|
|
| 37 |
action: str,
|
|
|
|
| 38 |
question: str,
|
| 39 |
collection_name: str,
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
) -> str:
|
| 42 |
-
event_id = str(uuid4())
|
| 43 |
await init_audit_db(db_path)
|
|
|
|
|
|
|
|
|
|
| 44 |
async with aiosqlite.connect(db_path) as conn:
|
| 45 |
await conn.execute(
|
| 46 |
"""
|
| 47 |
INSERT INTO audit_events (
|
| 48 |
-
event_id, action, question, collection_name, answer, status, message,
|
| 49 |
-
|
|
|
|
|
|
|
| 50 |
""",
|
| 51 |
(
|
| 52 |
-
|
| 53 |
action,
|
| 54 |
question,
|
| 55 |
collection_name,
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
json.dumps(
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
),
|
| 62 |
)
|
| 63 |
await conn.commit()
|
| 64 |
-
return
|
| 65 |
|
| 66 |
|
| 67 |
-
async def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
await init_audit_db(db_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
async with aiosqlite.connect(db_path) as conn:
|
| 70 |
conn.row_factory = aiosqlite.Row
|
| 71 |
cursor = await conn.execute(
|
| 72 |
-
"""
|
| 73 |
-
SELECT event_id,
|
| 74 |
FROM audit_events
|
|
|
|
| 75 |
ORDER BY datetime(created_at) DESC, rowid DESC
|
| 76 |
LIMIT ? OFFSET ?
|
| 77 |
""",
|
| 78 |
-
|
| 79 |
)
|
| 80 |
rows = await cursor.fetchall()
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
|
| 84 |
-
async def get_audit_event(db_path: str,
|
| 85 |
await init_audit_db(db_path)
|
| 86 |
async with aiosqlite.connect(db_path) as conn:
|
| 87 |
conn.row_factory = aiosqlite.Row
|
| 88 |
cursor = await conn.execute(
|
| 89 |
"""
|
| 90 |
-
SELECT event_id,
|
| 91 |
FROM audit_events
|
| 92 |
WHERE event_id = ?
|
| 93 |
""",
|
| 94 |
-
(
|
| 95 |
)
|
| 96 |
row = await cursor.fetchone()
|
| 97 |
if row is None:
|
| 98 |
return None
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import json
|
| 2 |
+
from datetime import datetime, timezone
|
| 3 |
from pathlib import Path
|
| 4 |
from typing import Any
|
| 5 |
from uuid import uuid4
|
| 6 |
|
| 7 |
import aiosqlite
|
| 8 |
|
| 9 |
+
from models.responses import AuditLogDetailResponse, AuditLogEntry, SourceCitation
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _utc_now_iso() -> str:
|
| 13 |
+
return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _parse_ts(value: object) -> datetime:
|
| 17 |
+
if value is None or value == "":
|
| 18 |
+
return datetime.now(timezone.utc)
|
| 19 |
+
s = str(value).strip()
|
| 20 |
+
if s.endswith("Z"):
|
| 21 |
+
s = s[:-1] + "+00:00"
|
| 22 |
+
try:
|
| 23 |
+
dt = datetime.fromisoformat(s)
|
| 24 |
+
if dt.tzinfo is None:
|
| 25 |
+
return dt.replace(tzinfo=timezone.utc)
|
| 26 |
+
return dt
|
| 27 |
+
except ValueError:
|
| 28 |
+
return datetime.now(timezone.utc)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
async def _migrate_audit_columns(conn: aiosqlite.Connection) -> None:
    """Add any ``audit_events`` columns that an existing database lacks.

    The current columns are read with ``PRAGMA table_info``; one
    ``ALTER TABLE ... ADD COLUMN`` is issued per missing column, in the order
    listed below. A commit happens only when at least one column was added.
    """
    additions = (
        ("user_id", "ALTER TABLE audit_events ADD COLUMN user_id TEXT NOT NULL DEFAULT 'anonymous'"),
        ("model_used", "ALTER TABLE audit_events ADD COLUMN model_used TEXT"),
        ("tokens_used", "ALTER TABLE audit_events ADD COLUMN tokens_used INTEGER"),
        ("response_time_ms", "ALTER TABLE audit_events ADD COLUMN response_time_ms INTEGER"),
        ("answer_summary", "ALTER TABLE audit_events ADD COLUMN answer_summary TEXT"),
        ("kind", "ALTER TABLE audit_events ADD COLUMN kind TEXT NOT NULL DEFAULT 'ask'"),
    )

    cursor = await conn.execute("PRAGMA table_info(audit_events)")
    # Index 1 of each table_info row is taken as the column name.
    existing = {str(info[1]) for info in await cursor.fetchall()}

    pending = [ddl for column, ddl in additions if column not in existing]
    for ddl in pending:
        await conn.execute(ddl)
    if pending:
        await conn.commit()
|
| 52 |
|
| 53 |
|
| 54 |
async def init_audit_db(db_path: str) -> None:
|
|
|
|
| 67 |
message TEXT NOT NULL,
|
| 68 |
sources_json TEXT NOT NULL,
|
| 69 |
results_json TEXT NOT NULL,
|
| 70 |
+
created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
| 71 |
+
user_id TEXT NOT NULL DEFAULT 'anonymous',
|
| 72 |
+
model_used TEXT,
|
| 73 |
+
tokens_used INTEGER,
|
| 74 |
+
response_time_ms INTEGER,
|
| 75 |
+
answer_summary TEXT,
|
| 76 |
+
kind TEXT NOT NULL DEFAULT 'ask'
|
| 77 |
)
|
| 78 |
"""
|
| 79 |
)
|
| 80 |
await conn.commit()
|
| 81 |
+
await _migrate_audit_columns(conn)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _summary_from_answer(answer: str, max_len: int = 280) -> str:
|
| 85 |
+
text = (answer or "").strip()
|
| 86 |
+
if len(text) <= max_len:
|
| 87 |
+
return text
|
| 88 |
+
return text[: max_len - 1].rstrip() + "…"
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def _sources_to_citations(raw: list[dict[str, Any]]) -> list[SourceCitation]:
    """Convert stored source dicts into ``SourceCitation`` objects.

    Two stored shapes are understood:

    * current rows, identified by a ``document_name`` key, using
      ``page_number`` / ``chunk_text`` / ``relevance_score``;
    * older rows, using ``source``, ``page_number`` or ``page``,
      ``chunk_text`` or ``excerpt`` or ``text``, and ``relevance_score`` or
      ``score``.

    Non-dict entries are skipped. Numeric fields that are missing or cannot be
    converted become ``0`` / ``0.0`` in both shapes (previously only the older
    shape was guarded, so one malformed value in a current row raised).
    ``None`` text fields become ``""`` instead of the string ``"None"``, and
    an empty document name is reported as ``"unknown"``.
    """

    def _text(value: object) -> str:
        return "" if value is None else str(value)

    def _int_or_zero(value: object) -> int:
        try:
            return int(value) if value is not None else 0
        except (TypeError, ValueError):
            return 0

    def _float_or_zero(value: object) -> float:
        try:
            return float(value) if value is not None else 0.0
        except (TypeError, ValueError):
            return 0.0

    out: list[SourceCitation] = []
    for item in raw:
        if not isinstance(item, dict):
            continue
        if "document_name" in item:
            doc = _text(item.get("document_name", ""))
            page = _int_or_zero(item.get("page_number", 0))
            chunk = _text(item.get("chunk_text", ""))
            score = _float_or_zero(item.get("relevance_score", 0.0))
        else:
            doc = _text(item.get("source", ""))
            page = _int_or_zero(item.get("page_number", item.get("page")))
            chunk = _text(item.get("chunk_text", item.get("excerpt", item.get("text", ""))))
            score = _float_or_zero(item.get("relevance_score", item.get("score")))
        out.append(
            SourceCitation(
                document_name=doc or "unknown",
                page_number=page,
                chunk_text=chunk,
                relevance_score=score,
            )
        )
    return out
|
| 123 |
|
| 124 |
|
| 125 |
async def persist_query_audit(
    db_path: str,
    *,
    query_id: str,
    action: str,
    user_id: str,
    question: str,
    collection_name: str,
    answer: str,
    sources: list[SourceCitation],
    model_used: str,
    tokens_used: int,
    response_time_ms: int,
    status: str = "success",
    message: str = "ok",
    kind: str = "ask",
) -> str:
    """Insert one audit row for a query and return its ``query_id``.

    ``query_id`` is stored in the ``event_id`` column. Sources are serialised
    to JSON via ``model_dump(mode="json")``, ``results_json`` is always stored
    as ``'[]'``, ``created_at`` is the current UTC time, and
    ``answer_summary`` is derived from ``answer`` with
    ``_summary_from_answer``.
    """
    await init_audit_db(db_path)

    # Order matches the column list in the INSERT below (results_json is a literal).
    row = (
        query_id,
        action,
        question,
        collection_name,
        answer,
        status,
        message,
        json.dumps([citation.model_dump(mode="json") for citation in sources]),
        _utc_now_iso(),
        user_id,
        model_used,
        tokens_used,
        response_time_ms,
        _summary_from_answer(answer),
        kind,
    )
    insert_sql = """
        INSERT INTO audit_events (
            event_id, action, question, collection_name, answer, status, message,
            sources_json, results_json, created_at, user_id, model_used, tokens_used,
            response_time_ms, answer_summary, kind
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, '[]', ?, ?, ?, ?, ?, ?, ?)
    """
    async with aiosqlite.connect(db_path) as conn:
        await conn.execute(insert_sql, row)
        await conn.commit()
    return query_id
|
| 175 |
|
| 176 |
|
| 177 |
+
async def count_audit_events(
    db_path: str,
    *,
    user_id: str | None = None,
    from_date: str | None = None,
    to_date: str | None = None,
) -> int:
    """Return the number of audit rows matching the optional filters.

    The filters are turned into a WHERE clause by ``_audit_filters``; with no
    filters every row is counted.
    """
    await init_audit_db(db_path)
    where_sql, bind_values = _audit_filters(user_id, from_date, to_date)
    query = f"SELECT COUNT(*) AS c FROM audit_events {where_sql}"
    async with aiosqlite.connect(db_path) as conn:
        cursor = await conn.execute(query, bind_values)
        first = await cursor.fetchone()
    if not first:
        return 0
    return int(first[0])
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def _audit_filters(user_id: str | None, from_date: str | None, to_date: str | None) -> tuple[str, list[Any]]:
|
| 193 |
+
clauses: list[str] = []
|
| 194 |
+
params: list[Any] = []
|
| 195 |
+
if user_id:
|
| 196 |
+
clauses.append("user_id = ?")
|
| 197 |
+
params.append(user_id)
|
| 198 |
+
if from_date:
|
| 199 |
+
clauses.append("datetime(created_at) >= datetime(?)")
|
| 200 |
+
params.append(from_date)
|
| 201 |
+
if to_date:
|
| 202 |
+
clauses.append("datetime(created_at) <= datetime(?)")
|
| 203 |
+
params.append(to_date)
|
| 204 |
+
if not clauses:
|
| 205 |
+
return "", []
|
| 206 |
+
return "WHERE " + " AND ".join(clauses), params
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
async def list_audit_events(
    db_path: str,
    *,
    limit: int,
    offset: int,
    user_id: str | None = None,
    from_date: str | None = None,
    to_date: str | None = None,
) -> tuple[list[AuditLogEntry], int]:
    """Return one page of audit entries plus the total number of matches.

    Rows are filtered by the optional ``user_id`` / ``from_date`` / ``to_date``
    arguments, ordered newest first (``created_at`` descending, then ``rowid``
    descending) and paged with ``limit`` / ``offset``. The total is the count
    of all rows matching the filters, independent of paging.

    Per row, ``sources_count`` is the length of the stored source list, and
    ``answer_summary`` falls back to a summary of the full answer when the
    stored summary is empty. A ``sources_json`` value that is not valid JSON,
    or does not decode to a list, is treated as "no sources" so a single
    corrupt row cannot fail the whole listing.
    """
    await init_audit_db(db_path)
    where, fparams = _audit_filters(user_id, from_date, to_date)
    total = await count_audit_events(db_path, user_id=user_id, from_date=from_date, to_date=to_date)
    async with aiosqlite.connect(db_path) as conn:
        conn.row_factory = aiosqlite.Row
        cursor = await conn.execute(
            f"""
            SELECT event_id, user_id, question, answer, answer_summary, sources_json, model_used, created_at
            FROM audit_events
            {where}
            ORDER BY datetime(created_at) DESC, rowid DESC
            LIMIT ? OFFSET ?
            """,
            [*fparams, limit, offset],
        )
        rows = await cursor.fetchall()

    logs: list[AuditLogEntry] = []
    for row in rows:
        try:
            src_raw = json.loads(row["sources_json"] or "[]")
        except (TypeError, ValueError):
            # json.JSONDecodeError is a ValueError; corrupt payload -> no sources.
            src_raw = []
        if not isinstance(src_raw, list):
            src_raw = []
        summary_cell = row["answer_summary"]
        summary_text = str(summary_cell).strip() if summary_cell else ""
        if not summary_text:
            summary_text = _summary_from_answer(str(row["answer"] or ""))
        logs.append(
            AuditLogEntry(
                query_id=str(row["event_id"]),
                user_id=str(row["user_id"] or "anonymous"),
                question=str(row["question"]),
                answer_summary=summary_text,
                sources_count=len(src_raw),
                model_used=row["model_used"],
                timestamp=_parse_ts(row["created_at"]),
            )
        )
    return logs, total
|
| 255 |
|
| 256 |
|
| 257 |
+
async def get_audit_event(db_path: str, query_id: str) -> AuditLogDetailResponse | None:
    """Return the full audit record whose ``event_id`` equals ``query_id``.

    Returns ``None`` when no such row exists. Stored sources are converted
    with ``_sources_to_citations``. A ``sources_json`` value that is not valid
    JSON, or does not decode to a list, is treated as an empty source list so
    the rest of the record can still be returned.
    """
    await init_audit_db(db_path)
    async with aiosqlite.connect(db_path) as conn:
        conn.row_factory = aiosqlite.Row
        cursor = await conn.execute(
            """
            SELECT event_id, user_id, question, answer, sources_json, model_used, tokens_used, created_at
            FROM audit_events
            WHERE event_id = ?
            """,
            (query_id,),
        )
        row = await cursor.fetchone()
    if row is None:
        return None

    try:
        src_raw = json.loads(row["sources_json"] or "[]")
    except (TypeError, ValueError):
        # json.JSONDecodeError is a ValueError; corrupt payload -> no sources.
        src_raw = []
    if not isinstance(src_raw, list):
        src_raw = []
    citations = _sources_to_citations(src_raw)
    return AuditLogDetailResponse(
        query_id=str(row["event_id"]),
        user_id=str(row["user_id"] or "anonymous"),
        question=str(row["question"]),
        full_answer=str(row["answer"] or ""),
        sources=citations,
        model_used=row["model_used"],
        tokens_used=row["tokens_used"],
        timestamp=_parse_ts(row["created_at"]),
    )
|
storage/job_store.py
CHANGED
|
@@ -1,11 +1,64 @@
|
|
| 1 |
import json
|
|
|
|
| 2 |
from pathlib import Path
|
| 3 |
from typing import Any
|
| 4 |
from uuid import uuid4
|
| 5 |
|
| 6 |
import aiosqlite
|
| 7 |
|
| 8 |
-
from models.responses import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
async def init_jobs_db(db_path: str) -> None:
|
|
@@ -22,74 +75,134 @@ async def init_jobs_db(db_path: str) -> None:
|
|
| 22 |
message TEXT NOT NULL DEFAULT '',
|
| 23 |
document_ids_json TEXT NOT NULL DEFAULT '[]',
|
| 24 |
created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
| 25 |
-
updated_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
)
|
| 27 |
"""
|
| 28 |
)
|
| 29 |
await conn.commit()
|
|
|
|
| 30 |
|
| 31 |
|
| 32 |
async def create_ingest_job(
|
| 33 |
db_path: str,
|
| 34 |
*,
|
| 35 |
collection_name: str,
|
| 36 |
-
|
| 37 |
) -> str:
|
|
|
|
|
|
|
| 38 |
job_id = str(uuid4())
|
|
|
|
|
|
|
|
|
|
| 39 |
await init_jobs_db(db_path)
|
| 40 |
async with aiosqlite.connect(db_path) as conn:
|
| 41 |
await conn.execute(
|
| 42 |
"""
|
| 43 |
INSERT INTO ingest_jobs (
|
| 44 |
-
job_id, status, collection_name, filename, message, document_ids_json
|
| 45 |
-
|
|
|
|
| 46 |
""",
|
| 47 |
-
(job_id, collection_name,
|
| 48 |
)
|
| 49 |
await conn.commit()
|
| 50 |
return job_id
|
| 51 |
|
| 52 |
|
| 53 |
-
async def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
db_path: str,
|
| 55 |
job_id: str,
|
| 56 |
*,
|
| 57 |
-
|
|
|
|
|
|
|
| 58 |
message: str | None = None,
|
| 59 |
-
document_ids: list[str] | None = None,
|
| 60 |
) -> None:
|
| 61 |
await init_jobs_db(db_path)
|
| 62 |
async with aiosqlite.connect(db_path) as conn:
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
)
|
| 73 |
-
else:
|
| 74 |
-
await conn.execute(
|
| 75 |
-
"""
|
| 76 |
-
UPDATE ingest_jobs
|
| 77 |
-
SET status = ?, message = COALESCE(?, message),
|
| 78 |
-
updated_at = CURRENT_TIMESTAMP
|
| 79 |
-
WHERE job_id = ?
|
| 80 |
-
""",
|
| 81 |
-
(status, message, job_id),
|
| 82 |
-
)
|
| 83 |
await conn.commit()
|
| 84 |
|
| 85 |
|
| 86 |
-
async def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
await init_jobs_db(db_path)
|
| 88 |
async with aiosqlite.connect(db_path) as conn:
|
| 89 |
conn.row_factory = aiosqlite.Row
|
| 90 |
cursor = await conn.execute(
|
| 91 |
"""
|
| 92 |
-
SELECT job_id, status,
|
|
|
|
| 93 |
FROM ingest_jobs
|
| 94 |
WHERE job_id = ?
|
| 95 |
""",
|
|
@@ -98,18 +211,39 @@ async def get_ingest_job(db_path: str, job_id: str) -> IngestJobDetail | None:
|
|
| 98 |
row = await cursor.fetchone()
|
| 99 |
if row is None:
|
| 100 |
return None
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
|
| 106 |
-
async def list_ingest_jobs(db_path: str, *, limit: int, offset: int) -> list[
|
| 107 |
await init_jobs_db(db_path)
|
| 108 |
async with aiosqlite.connect(db_path) as conn:
|
| 109 |
conn.row_factory = aiosqlite.Row
|
|
|
|
|
|
|
|
|
|
| 110 |
cursor = await conn.execute(
|
| 111 |
"""
|
| 112 |
-
SELECT job_id, status,
|
| 113 |
FROM ingest_jobs
|
| 114 |
ORDER BY datetime(updated_at) DESC, rowid DESC
|
| 115 |
LIMIT ? OFFSET ?
|
|
@@ -117,4 +251,30 @@ async def list_ingest_jobs(db_path: str, *, limit: int, offset: int) -> list[dic
|
|
| 117 |
(limit, offset),
|
| 118 |
)
|
| 119 |
rows = await cursor.fetchall()
|
| 120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import json
|
| 2 |
+
from datetime import datetime, timezone
|
| 3 |
from pathlib import Path
|
| 4 |
from typing import Any
|
| 5 |
from uuid import uuid4
|
| 6 |
|
| 7 |
import aiosqlite
|
| 8 |
|
| 9 |
+
from models.responses import JobListItem, JobStatusResponse
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _utc_now_iso() -> str:
|
| 13 |
+
return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
async def _migrate_jobs_columns(conn: aiosqlite.Connection) -> None:
    """Add any ``ingest_jobs`` columns that an existing database lacks.

    The current columns are read with ``PRAGMA table_info``; one
    ``ALTER TABLE ... ADD COLUMN`` is issued per missing column, in the order
    listed below, followed by a commit if anything was added. The filename
    backfill then runs unconditionally.
    """
    additions = (
        ("total_files", "ALTER TABLE ingest_jobs ADD COLUMN total_files INTEGER NOT NULL DEFAULT 1"),
        ("processed_files", "ALTER TABLE ingest_jobs ADD COLUMN processed_files INTEGER NOT NULL DEFAULT 0"),
        ("failed_files", "ALTER TABLE ingest_jobs ADD COLUMN failed_files INTEGER NOT NULL DEFAULT 0"),
        ("filenames_json", "ALTER TABLE ingest_jobs ADD COLUMN filenames_json TEXT NOT NULL DEFAULT '[]'"),
        ("errors_json", "ALTER TABLE ingest_jobs ADD COLUMN errors_json TEXT NOT NULL DEFAULT '[]'"),
        ("started_at", "ALTER TABLE ingest_jobs ADD COLUMN started_at TEXT"),
        ("completed_at", "ALTER TABLE ingest_jobs ADD COLUMN completed_at TEXT"),
    )

    cursor = await conn.execute("PRAGMA table_info(ingest_jobs)")
    # Index 1 of each table_info row is taken as the column name.
    existing = {str(info[1]) for info in await cursor.fetchall()}

    pending = [ddl for column, ddl in additions if column not in existing]
    for ddl in pending:
        await conn.execute(ddl)
    if pending:
        await conn.commit()
    await _backfill_job_filenames(conn)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
async def _backfill_job_filenames(conn: aiosqlite.Connection) -> None:
    """Populate ``filenames_json`` for rows that only have the single ``filename``.

    Every job row is inspected. When its ``filenames_json`` decodes to a falsy
    value (or is not valid JSON) and ``filename`` is set, the row is updated
    to a one-element list holding that filename, and ``total_files`` is raised
    to 1 if it is NULL or below 1. Sets ``conn.row_factory`` to
    ``aiosqlite.Row`` and always commits at the end.
    """
    update_sql = """
        UPDATE ingest_jobs
        SET filenames_json = ?, total_files = CASE WHEN total_files IS NULL OR total_files < 1 THEN 1 ELSE total_files END
        WHERE job_id = ?
    """
    conn.row_factory = aiosqlite.Row
    cursor = await conn.execute("SELECT job_id, filename, filenames_json, total_files FROM ingest_jobs")
    for job in await cursor.fetchall():
        try:
            names: Any = json.loads(job["filenames_json"] or "[]")
        except json.JSONDecodeError:
            names = []
        # Nothing to do when names are already present or there is no filename to copy.
        if names or not job["filename"]:
            continue
        await conn.execute(update_sql, (json.dumps([job["filename"]]), job["job_id"]))
    await conn.commit()
|
| 62 |
|
| 63 |
|
| 64 |
async def init_jobs_db(db_path: str) -> None:
|
|
|
|
| 75 |
message TEXT NOT NULL DEFAULT '',
|
| 76 |
document_ids_json TEXT NOT NULL DEFAULT '[]',
|
| 77 |
created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
| 78 |
+
updated_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
| 79 |
+
total_files INTEGER NOT NULL DEFAULT 1,
|
| 80 |
+
processed_files INTEGER NOT NULL DEFAULT 0,
|
| 81 |
+
failed_files INTEGER NOT NULL DEFAULT 0,
|
| 82 |
+
filenames_json TEXT NOT NULL DEFAULT '[]',
|
| 83 |
+
errors_json TEXT NOT NULL DEFAULT '[]',
|
| 84 |
+
started_at TEXT,
|
| 85 |
+
completed_at TEXT
|
| 86 |
)
|
| 87 |
"""
|
| 88 |
)
|
| 89 |
await conn.commit()
|
| 90 |
+
await _migrate_jobs_columns(conn)
|
| 91 |
|
| 92 |
|
| 93 |
async def create_ingest_job(
    db_path: str,
    *,
    collection_name: str,
    filenames: list[str],
) -> str:
    """Insert a new ingest job in the ``queued`` state and return its id.

    The first filename is stored in the ``filename`` column, the whole list in
    ``filenames_json``, and ``total_files`` is the list length. Progress
    counters start at 0.

    Raises:
        ValueError: if ``filenames`` is empty.
    """
    if not filenames:
        raise ValueError("filenames must not be empty")

    job_id = str(uuid4())
    row = (job_id, collection_name, filenames[0], len(filenames), json.dumps(filenames))
    insert_sql = """
        INSERT INTO ingest_jobs (
            job_id, status, collection_name, filename, message, document_ids_json,
            total_files, processed_files, failed_files, filenames_json, errors_json
        ) VALUES (?, 'queued', ?, ?, '', '[]', ?, 0, 0, ?, '[]')
    """
    await init_jobs_db(db_path)
    async with aiosqlite.connect(db_path) as conn:
        await conn.execute(insert_sql, row)
        await conn.commit()
    return job_id
|
| 118 |
|
| 119 |
|
| 120 |
+
async def mark_job_processing(db_path: str, job_id: str) -> None:
    """Mark a job as ``processing`` and record when it started.

    ``started_at`` is written through ``COALESCE(started_at, ?)``, so an
    already-recorded start time is kept if this is called again.
    """
    await init_jobs_db(db_path)
    started_at = _utc_now_iso()
    statement = """
        UPDATE ingest_jobs
        SET status = 'processing', message = 'Ingestion in progress.', started_at = COALESCE(started_at, ?),
            updated_at = CURRENT_TIMESTAMP
        WHERE job_id = ?
    """
    async with aiosqlite.connect(db_path) as conn:
        await conn.execute(statement, (started_at, job_id))
        await conn.commit()
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
async def update_job_progress(
    db_path: str,
    job_id: str,
    *,
    processed_files: int,
    failed_files: int,
    errors: list[str],
    message: str | None = None,
) -> None:
    """Store the latest progress counters and error list for a job.

    ``errors`` replaces the stored list (serialised as JSON). ``message`` is
    only overwritten when a non-``None`` value is given, because the update
    uses ``COALESCE(?, message)``.
    """
    await init_jobs_db(db_path)
    bindings = (processed_files, failed_files, json.dumps(errors), message, job_id)
    statement = """
        UPDATE ingest_jobs
        SET processed_files = ?, failed_files = ?, errors_json = ?,
            message = COALESCE(?, message), updated_at = CURRENT_TIMESTAMP
        WHERE job_id = ?
    """
    async with aiosqlite.connect(db_path) as conn:
        await conn.execute(statement, bindings)
        await conn.commit()
|
| 157 |
|
| 158 |
|
| 159 |
+
async def complete_ingest_job(
    db_path: str,
    job_id: str,
    *,
    document_ids: list[str],
    message: str,
) -> None:
    """Mark a job as ``completed`` and store its resulting document ids.

    ``completed_at`` is set to the current UTC time and ``document_ids`` is
    stored as JSON.
    """
    await init_jobs_db(db_path)
    finished_at = _utc_now_iso()
    bindings = (message, json.dumps(document_ids), finished_at, job_id)
    statement = """
        UPDATE ingest_jobs
        SET status = 'completed', message = ?, document_ids_json = ?,
            completed_at = ?, updated_at = CURRENT_TIMESTAMP
        WHERE job_id = ?
    """
    async with aiosqlite.connect(db_path) as conn:
        await conn.execute(statement, bindings)
        await conn.commit()
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
async def fail_ingest_job(db_path: str, job_id: str, *, message: str, errors: list[str] | None = None) -> None:
    """Mark a job as ``failed`` and record why.

    When ``errors`` is missing or empty, the stored error list is a
    one-element list containing ``message``. ``completed_at`` is set to the
    current UTC time.
    """
    await init_jobs_db(db_path)
    finished_at = _utc_now_iso()
    recorded_errors = errors or [message]
    bindings = (message, json.dumps(recorded_errors), finished_at, job_id)
    statement = """
        UPDATE ingest_jobs
        SET status = 'failed', message = ?, errors_json = ?, completed_at = ?,
            updated_at = CURRENT_TIMESTAMP
        WHERE job_id = ?
    """
    async with aiosqlite.connect(db_path) as conn:
        await conn.execute(statement, bindings)
        await conn.commit()
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
async def get_job_status(db_path: str, job_id: str) -> JobStatusResponse | None:
|
| 199 |
await init_jobs_db(db_path)
|
| 200 |
async with aiosqlite.connect(db_path) as conn:
|
| 201 |
conn.row_factory = aiosqlite.Row
|
| 202 |
cursor = await conn.execute(
|
| 203 |
"""
|
| 204 |
+
SELECT job_id, status, total_files, processed_files, failed_files, errors_json,
|
| 205 |
+
started_at, completed_at, message
|
| 206 |
FROM ingest_jobs
|
| 207 |
WHERE job_id = ?
|
| 208 |
""",
|
|
|
|
| 211 |
row = await cursor.fetchone()
|
| 212 |
if row is None:
|
| 213 |
return None
|
| 214 |
+
data = dict(row)
|
| 215 |
+
total = int(data["total_files"] or 0)
|
| 216 |
+
processed = int(data["processed_files"] or 0)
|
| 217 |
+
failed = int(data["failed_files"] or 0)
|
| 218 |
+
denom = total if total > 0 else 1
|
| 219 |
+
progress = int(min(100, max(0, round((processed + failed) / denom * 100))))
|
| 220 |
+
errors = json.loads(data.get("errors_json") or "[]")
|
| 221 |
+
if not isinstance(errors, list):
|
| 222 |
+
errors = [str(errors)]
|
| 223 |
+
errors_str = [str(e) for e in errors]
|
| 224 |
+
return JobStatusResponse(
|
| 225 |
+
job_id=str(data["job_id"]),
|
| 226 |
+
status=str(data["status"]),
|
| 227 |
+
total_files=total,
|
| 228 |
+
processed_files=processed,
|
| 229 |
+
failed_files=failed,
|
| 230 |
+
progress_percent=progress,
|
| 231 |
+
started_at=_parse_dt(data.get("started_at")),
|
| 232 |
+
completed_at=_parse_dt(data.get("completed_at")),
|
| 233 |
+
errors=errors_str,
|
| 234 |
+
)
|
| 235 |
|
| 236 |
|
| 237 |
+
async def list_ingest_jobs(db_path: str, *, limit: int, offset: int) -> tuple[list[JobListItem], int]:
|
| 238 |
await init_jobs_db(db_path)
|
| 239 |
async with aiosqlite.connect(db_path) as conn:
|
| 240 |
conn.row_factory = aiosqlite.Row
|
| 241 |
+
cur_total = await conn.execute("SELECT COUNT(*) AS c FROM ingest_jobs")
|
| 242 |
+
total_row = await cur_total.fetchone()
|
| 243 |
+
total = int(total_row["c"]) if total_row else 0
|
| 244 |
cursor = await conn.execute(
|
| 245 |
"""
|
| 246 |
+
SELECT job_id, status, total_files, completed_at
|
| 247 |
FROM ingest_jobs
|
| 248 |
ORDER BY datetime(updated_at) DESC, rowid DESC
|
| 249 |
LIMIT ? OFFSET ?
|
|
|
|
| 251 |
(limit, offset),
|
| 252 |
)
|
| 253 |
rows = await cursor.fetchall()
|
| 254 |
+
items = [
|
| 255 |
+
JobListItem(
|
| 256 |
+
job_id=str(r["job_id"]),
|
| 257 |
+
status=str(r["status"]),
|
| 258 |
+
total_files=int(r["total_files"] or 0),
|
| 259 |
+
completed_at=_parse_dt(r["completed_at"]),
|
| 260 |
+
)
|
| 261 |
+
for r in rows
|
| 262 |
+
]
|
| 263 |
+
return items, total
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def _parse_dt(value: object) -> datetime | None:
|
| 267 |
+
if value is None or value == "":
|
| 268 |
+
return None
|
| 269 |
+
s = str(value).strip()
|
| 270 |
+
if not s:
|
| 271 |
+
return None
|
| 272 |
+
if s.endswith("Z"):
|
| 273 |
+
s = s[:-1] + "+00:00"
|
| 274 |
+
try:
|
| 275 |
+
dt = datetime.fromisoformat(s)
|
| 276 |
+
if dt.tzinfo is None:
|
| 277 |
+
return dt.replace(tzinfo=timezone.utc)
|
| 278 |
+
return dt
|
| 279 |
+
except ValueError:
|
| 280 |
+
return None
|
streamlit_app.py
CHANGED
|
@@ -71,17 +71,43 @@ def _fmt_api_error(exc: httpx.HTTPStatusError) -> str:
|
|
| 71 |
return f"HTTP {exc.response.status_code}"
|
| 72 |
|
| 73 |
|
| 74 |
-
def _post_query_ask(
|
| 75 |
-
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
r = client.post("/query/ask", json=body)
|
| 78 |
if r.status_code == 404:
|
| 79 |
r = client.post("/query", json=body)
|
| 80 |
return r
|
| 81 |
|
| 82 |
|
| 83 |
-
def _get_audit_logs(
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
r = client.get("/audit/logs", params=params)
|
| 86 |
if r.status_code == 404:
|
| 87 |
r = client.get("/audit", params=params)
|
|
@@ -156,7 +182,7 @@ def main() -> None:
|
|
| 156 |
st.warning("Choose a file first.")
|
| 157 |
else:
|
| 158 |
try:
|
| 159 |
-
files = {"
|
| 160 |
data = {"collection_name": up_collection}
|
| 161 |
with _client() as c:
|
| 162 |
r = c.post("/ingest/upload", files=files, data=data)
|
|
@@ -185,7 +211,10 @@ def main() -> None:
|
|
| 185 |
else:
|
| 186 |
try:
|
| 187 |
with _client() as c:
|
| 188 |
-
r = c.post(
|
|
|
|
|
|
|
|
|
|
| 189 |
r.raise_for_status()
|
| 190 |
out = r.json()
|
| 191 |
st.success(out.get("message", "Queued"))
|
|
@@ -206,10 +235,10 @@ def main() -> None:
|
|
| 206 |
r = c.get("/ingest/collections")
|
| 207 |
r.raise_for_status()
|
| 208 |
cols = r.json()
|
| 209 |
-
|
| 210 |
-
st.write(cols.get(
|
| 211 |
-
if
|
| 212 |
-
st.dataframe(
|
| 213 |
else:
|
| 214 |
st.info("No collections yet.")
|
| 215 |
except httpx.HTTPStatusError as e:
|
|
@@ -228,7 +257,10 @@ def main() -> None:
|
|
| 228 |
with _client() as c:
|
| 229 |
r = c.delete(f"/ingest/collection/{del_name.strip()}")
|
| 230 |
r.raise_for_status()
|
| 231 |
-
|
|
|
|
|
|
|
|
|
|
| 232 |
except httpx.HTTPStatusError as e:
|
| 233 |
st.error(_fmt_api_error(e))
|
| 234 |
except httpx.ConnectError as e:
|
|
@@ -250,7 +282,7 @@ def main() -> None:
|
|
| 250 |
r.raise_for_status()
|
| 251 |
payload = r.json()
|
| 252 |
jobs: list[dict[str, Any]] = payload.get("jobs", [])
|
| 253 |
-
st.caption(payload.get(
|
| 254 |
if jobs:
|
| 255 |
st.dataframe(jobs, hide_index=True, use_container_width=True)
|
| 256 |
else:
|
|
@@ -293,9 +325,8 @@ def main() -> None:
|
|
| 293 |
r = c.get(f"/jobs/{job_id.strip()}")
|
| 294 |
r.raise_for_status()
|
| 295 |
body = r.json()
|
| 296 |
-
|
| 297 |
-
st_
|
| 298 |
-
status_ph.write(f"Poll {i + 1}: **{st_}** — {job.get('message', '')}")
|
| 299 |
if st_ in ("completed", "failed"):
|
| 300 |
st.json(body)
|
| 301 |
break
|
|
@@ -331,8 +362,7 @@ def main() -> None:
|
|
| 331 |
)
|
| 332 |
r.raise_for_status()
|
| 333 |
ans = r.json()
|
| 334 |
-
|
| 335 |
-
st.success(msg if msg else "Request completed.")
|
| 336 |
if ans.get("answer"):
|
| 337 |
st.markdown("### Answer")
|
| 338 |
st.markdown(ans["answer"])
|
|
@@ -370,11 +400,11 @@ def main() -> None:
|
|
| 370 |
r = c.post("/query/summarise", json=body)
|
| 371 |
r.raise_for_status()
|
| 372 |
ans = r.json()
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
if
|
| 376 |
st.markdown("### Summary")
|
| 377 |
-
st.markdown(
|
| 378 |
else:
|
| 379 |
st.warning("No summary text in the response; see **Raw response** below.")
|
| 380 |
src = ans.get("sources") or []
|
|
@@ -411,11 +441,15 @@ def main() -> None:
|
|
| 411 |
)
|
| 412 |
r.raise_for_status()
|
| 413 |
payload = r.json()
|
| 414 |
-
events = payload.get("events", [])
|
| 415 |
-
st.caption(payload.get(
|
| 416 |
if events:
|
| 417 |
st.dataframe(events, hide_index=True, use_container_width=True)
|
| 418 |
-
ids = [
|
|
|
|
|
|
|
|
|
|
|
|
|
| 419 |
if ids:
|
| 420 |
st.session_state["_audit_ids"] = ids
|
| 421 |
else:
|
|
@@ -432,7 +466,7 @@ def main() -> None:
|
|
| 432 |
pick = ""
|
| 433 |
if ids_for_select:
|
| 434 |
pick = st.selectbox("Event ID", options=[""] + list(ids_for_select), key="audit_pick")
|
| 435 |
-
manual_id = st.text_input("Or enter event ID", key="audit_manual")
|
| 436 |
ev_id = (manual_id.strip() or (pick or "").strip()).strip()
|
| 437 |
if st.button("Load detail", key="btn_audit_detail") and ev_id:
|
| 438 |
try:
|
|
|
|
| 71 |
return f"HTTP {exc.response.status_code}"
|
| 72 |
|
| 73 |
|
| 74 |
+
def _post_query_ask(
    client: httpx.Client,
    *,
    question: str,
    collection_name: str,
    top_k: int = 5,
    user_id: str = "anonymous",
) -> httpx.Response:
    """Send a question to ``POST /query/ask``, retrying on ``POST /query`` after a 404.

    Args:
        client: HTTP client used to issue the request(s).
        question: Question text; surrounding whitespace is stripped before sending.
        collection_name: Value sent as ``collection_name`` in the JSON body.
        top_k: Value sent as ``top_k`` in the JSON body.
        user_id: Value sent as ``user_id`` in the JSON body.

    Returns:
        The response from ``/query/ask``, or from ``/query`` when the first
        call answered with status 404 (older servers).
    """
    payload: dict[str, object] = {}
    payload["question"] = question.strip()
    payload["collection_name"] = collection_name
    payload["top_k"] = top_k
    payload["user_id"] = user_id

    response = client.post("/query/ask", json=payload)
    if response.status_code != 404:
        return response
    # Same payload, legacy route.
    return client.post("/query", json=payload)
|
| 93 |
|
| 94 |
|
| 95 |
+
def _get_audit_logs(
|
| 96 |
+
client: httpx.Client,
|
| 97 |
+
*,
|
| 98 |
+
limit: int,
|
| 99 |
+
offset: int,
|
| 100 |
+
user_id: str | None = None,
|
| 101 |
+
from_date: str | None = None,
|
| 102 |
+
to_date: str | None = None,
|
| 103 |
+
) -> httpx.Response:
|
| 104 |
+
params: dict[str, object] = {"limit": limit, "offset": offset}
|
| 105 |
+
if user_id:
|
| 106 |
+
params["user_id"] = user_id
|
| 107 |
+
if from_date:
|
| 108 |
+
params["from_date"] = from_date
|
| 109 |
+
if to_date:
|
| 110 |
+
params["to_date"] = to_date
|
| 111 |
r = client.get("/audit/logs", params=params)
|
| 112 |
if r.status_code == 404:
|
| 113 |
r = client.get("/audit", params=params)
|
|
|
|
| 182 |
st.warning("Choose a file first.")
|
| 183 |
else:
|
| 184 |
try:
|
| 185 |
+
files = {"files": (uploaded.name, uploaded.getvalue(), uploaded.type or "application/octet-stream")}
|
| 186 |
data = {"collection_name": up_collection}
|
| 187 |
with _client() as c:
|
| 188 |
r = c.post("/ingest/upload", files=files, data=data)
|
|
|
|
| 211 |
else:
|
| 212 |
try:
|
| 213 |
with _client() as c:
|
| 214 |
+
r = c.post(
|
| 215 |
+
"/ingest/url",
|
| 216 |
+
json={"urls": [ingest_url.strip()], "collection_name": url_collection},
|
| 217 |
+
)
|
| 218 |
r.raise_for_status()
|
| 219 |
out = r.json()
|
| 220 |
st.success(out.get("message", "Queued"))
|
|
|
|
| 235 |
r = c.get("/ingest/collections")
|
| 236 |
r.raise_for_status()
|
| 237 |
cols = r.json()
|
| 238 |
+
rows = cols.get("collections", [])
|
| 239 |
+
st.write(f"{cols.get('total', len(rows))} collection(s).")
|
| 240 |
+
if rows:
|
| 241 |
+
st.dataframe(rows, hide_index=True, use_container_width=True)
|
| 242 |
else:
|
| 243 |
st.info("No collections yet.")
|
| 244 |
except httpx.HTTPStatusError as e:
|
|
|
|
| 257 |
with _client() as c:
|
| 258 |
r = c.delete(f"/ingest/collection/{del_name.strip()}")
|
| 259 |
r.raise_for_status()
|
| 260 |
+
del_body = r.json()
|
| 261 |
+
st.success(del_body.get("message", "Deleted"))
|
| 262 |
+
if "documents_removed" in del_body:
|
| 263 |
+
st.caption(f"Documents removed: **{del_body['documents_removed']}**")
|
| 264 |
except httpx.HTTPStatusError as e:
|
| 265 |
st.error(_fmt_api_error(e))
|
| 266 |
except httpx.ConnectError as e:
|
|
|
|
| 282 |
r.raise_for_status()
|
| 283 |
payload = r.json()
|
| 284 |
jobs: list[dict[str, Any]] = payload.get("jobs", [])
|
| 285 |
+
st.caption(f"Total jobs (matching filters): **{payload.get('total', len(jobs))}**")
|
| 286 |
if jobs:
|
| 287 |
st.dataframe(jobs, hide_index=True, use_container_width=True)
|
| 288 |
else:
|
|
|
|
| 325 |
r = c.get(f"/jobs/{job_id.strip()}")
|
| 326 |
r.raise_for_status()
|
| 327 |
body = r.json()
|
| 328 |
+
st_ = body.get("status", "")
|
| 329 |
+
status_ph.write(f"Poll {i + 1}: **{st_}** — {body.get('progress_percent', 0)}%")
|
|
|
|
| 330 |
if st_ in ("completed", "failed"):
|
| 331 |
st.json(body)
|
| 332 |
break
|
|
|
|
| 362 |
)
|
| 363 |
r.raise_for_status()
|
| 364 |
ans = r.json()
|
| 365 |
+
st.success(f"Query id: `{ans.get('query_id', '')}`")
|
|
|
|
| 366 |
if ans.get("answer"):
|
| 367 |
st.markdown("### Answer")
|
| 368 |
st.markdown(ans["answer"])
|
|
|
|
| 400 |
r = c.post("/query/summarise", json=body)
|
| 401 |
r.raise_for_status()
|
| 402 |
ans = r.json()
|
| 403 |
+
st.success(f"Query id: `{ans.get('query_id', '')}` · documents: **{ans.get('document_count', '')}**")
|
| 404 |
+
summary_text = ans.get("summary") or ans.get("answer")
|
| 405 |
+
if summary_text:
|
| 406 |
st.markdown("### Summary")
|
| 407 |
+
st.markdown(summary_text)
|
| 408 |
else:
|
| 409 |
st.warning("No summary text in the response; see **Raw response** below.")
|
| 410 |
src = ans.get("sources") or []
|
|
|
|
| 441 |
)
|
| 442 |
r.raise_for_status()
|
| 443 |
payload = r.json()
|
| 444 |
+
events = payload.get("logs", payload.get("events", []))
|
| 445 |
+
st.caption(f"Total matching: **{payload.get('total', len(events))}**")
|
| 446 |
if events:
|
| 447 |
st.dataframe(events, hide_index=True, use_container_width=True)
|
| 448 |
+
ids = [
|
| 449 |
+
e.get("query_id") or e.get("event_id")
|
| 450 |
+
for e in events
|
| 451 |
+
if isinstance(e, dict) and (e.get("query_id") or e.get("event_id"))
|
| 452 |
+
]
|
| 453 |
if ids:
|
| 454 |
st.session_state["_audit_ids"] = ids
|
| 455 |
else:
|
|
|
|
| 466 |
pick = ""
|
| 467 |
if ids_for_select:
|
| 468 |
pick = st.selectbox("Event ID", options=[""] + list(ids_for_select), key="audit_pick")
|
| 469 |
+
manual_id = st.text_input("Or enter query / event ID", key="audit_manual")
|
| 470 |
ev_id = (manual_id.strip() or (pick or "").strip()).strip()
|
| 471 |
if st.button("Load detail", key="btn_audit_detail") and ev_id:
|
| 472 |
try:
|
tests/test_audit.py
CHANGED
|
@@ -1,12 +1,13 @@
|
|
| 1 |
import asyncio
|
| 2 |
from unittest.mock import AsyncMock
|
|
|
|
| 3 |
|
| 4 |
import pytest
|
| 5 |
from fastapi.testclient import TestClient
|
| 6 |
|
| 7 |
from api.config import Settings
|
| 8 |
from api.main import app
|
| 9 |
-
from models.responses import
|
| 10 |
from storage.audit_store import persist_query_audit
|
| 11 |
|
| 12 |
|
|
@@ -32,39 +33,89 @@ def client(settings, monkeypatch):
|
|
| 32 |
yield test_client
|
| 33 |
|
| 34 |
|
| 35 |
-
def _seed_audit(settings: Settings, question: str = "What are key risks?") -> str:
|
| 36 |
-
|
|
|
|
| 37 |
persist_query_audit(
|
| 38 |
settings.audit_db_path,
|
|
|
|
| 39 |
action="query",
|
|
|
|
| 40 |
question=question,
|
| 41 |
collection_name="default",
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
)
|
| 50 |
)
|
|
|
|
| 51 |
|
| 52 |
|
| 53 |
def test_audit_logs_and_detail_success(client, settings):
|
| 54 |
-
|
| 55 |
|
| 56 |
list_response = client.get("/audit/logs?limit=10&offset=0")
|
| 57 |
assert list_response.status_code == 200
|
| 58 |
body = list_response.json()
|
| 59 |
-
assert
|
| 60 |
-
assert
|
| 61 |
-
assert any(
|
| 62 |
|
| 63 |
-
detail_response = client.get(f"/audit/logs/{
|
| 64 |
assert detail_response.status_code == 200
|
| 65 |
detail = detail_response.json()
|
| 66 |
-
assert detail["
|
| 67 |
-
assert detail["
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
|
| 70 |
def test_audit_logs_validation_error_for_bad_limit(client):
|
|
|
|
| 1 |
import asyncio
|
| 2 |
from unittest.mock import AsyncMock
|
| 3 |
+
from uuid import uuid4
|
| 4 |
|
| 5 |
import pytest
|
| 6 |
from fastapi.testclient import TestClient
|
| 7 |
|
| 8 |
from api.config import Settings
|
| 9 |
from api.main import app
|
| 10 |
+
from models.responses import SourceCitation
|
| 11 |
from storage.audit_store import persist_query_audit
|
| 12 |
|
| 13 |
|
|
|
|
| 33 |
yield test_client
|
| 34 |
|
| 35 |
|
| 36 |
+
def _seed_audit(settings: Settings, question: str = "What are key risks?", user_id: str = "analyst_001") -> str:
    """Persist one audit row through ``persist_query_audit`` and return its query id.

    The row is written to ``settings.audit_db_path`` with a freshly generated
    UUID, a fixed answer, and a single ``report.pdf`` source citation.
    """
    seeded_id = str(uuid4())
    citation = SourceCitation(
        document_name="report.pdf",
        page_number=3,
        chunk_text="Risk disclosure excerpt.",
        relevance_score=0.9,
    )
    write_row = persist_query_audit(
        settings.audit_db_path,
        query_id=seeded_id,
        action="query",
        user_id=user_id,
        question=question,
        collection_name="default",
        answer="Grounded answer text for audit trail.",
        sources=[citation],
        model_used="ollama:llama3.1:8b",
        tokens_used=120,
        response_time_ms=50,
        kind="ask",
    )
    # The helper is synchronous, so the coroutine is driven to completion here.
    asyncio.run(write_row)
    return seeded_id
|
| 62 |
|
| 63 |
|
| 64 |
def test_audit_logs_and_detail_success(client, settings):
    """A seeded audit row appears in the log listing and in its detail view."""
    seeded_id = _seed_audit(settings)

    # Listing endpoint: the seeded row must be among the returned logs.
    listing = client.get("/audit/logs?limit=10&offset=0")
    assert listing.status_code == 200
    listing_json = listing.json()
    assert "logs" in listing_json
    assert listing_json["total"] >= 1
    assert any(row["query_id"] == seeded_id for row in listing_json["logs"])

    # Detail endpoint: the stored question, answer and citation are returned.
    detail_resp = client.get(f"/audit/logs/{seeded_id}")
    assert detail_resp.status_code == 200
    record = detail_resp.json()
    assert record["query_id"] == seeded_id
    assert record["question"] == "What are key risks?"
    assert record["full_answer"] == "Grounded answer text for audit trail."
    cited = record["sources"]
    assert len(cited) == 1
    assert cited[0]["document_name"] == "report.pdf"
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def test_audit_logs_filter_by_user_id(client, settings):
    """Filtering the log listing by ``user_id`` returns only that user's rows."""
    wanted_id = _seed_audit(settings, question="Q one", user_id="user_a")
    _seed_audit(settings, question="Q two", user_id="user_b")

    query = {"user_id": "user_a", "limit": 50, "offset": 0}
    resp = client.get("/audit/logs", params=query)
    assert resp.status_code == 200

    logs = resp.json()["logs"]
    returned_ids = {row["query_id"] for row in logs}
    assert wanted_id in returned_ids
    assert all(row["user_id"] == "user_a" for row in logs)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def test_audit_logs_filter_by_from_date(client, settings):
    """A row written now is excluded when ``from_date`` lies far in the future."""
    seeded_id = str(uuid4())
    cutoff = "2099-01-01T00:00:00Z"

    write_row = persist_query_audit(
        settings.audit_db_path,
        query_id=seeded_id,
        action="query",
        user_id="u",
        question="Future dated row",
        collection_name="default",
        answer="A",
        sources=[],
        model_used="m",
        tokens_used=0,
        response_time_ms=1,
        kind="ask",
    )
    asyncio.run(write_row)

    resp = client.get("/audit/logs", params={"from_date": cutoff, "limit": 50, "offset": 0})
    assert resp.status_code == 200
    returned_ids = {row["query_id"] for row in resp.json()["logs"]}
    assert seeded_id not in returned_ids
|
| 119 |
|
| 120 |
|
| 121 |
def test_audit_logs_validation_error_for_bad_limit(client):
|
tests/test_ingest.py
CHANGED
|
@@ -34,21 +34,23 @@ def test_upload_queues_job_success(client, monkeypatch):
|
|
| 34 |
response = client.post(
|
| 35 |
"/ingest/upload",
|
| 36 |
data={"collection_name": "default"},
|
| 37 |
-
files=
|
| 38 |
)
|
| 39 |
|
| 40 |
assert response.status_code == 200
|
| 41 |
body = response.json()
|
| 42 |
assert body["status"] == "queued"
|
| 43 |
assert body["job_id"] == "job-123"
|
| 44 |
-
assert
|
|
|
|
|
|
|
| 45 |
|
| 46 |
|
| 47 |
def test_upload_rejects_unsupported_extension(client):
|
| 48 |
response = client.post(
|
| 49 |
"/ingest/upload",
|
| 50 |
data={"collection_name": "default"},
|
| 51 |
-
files=
|
| 52 |
)
|
| 53 |
|
| 54 |
assert response.status_code == 400
|
|
@@ -65,7 +67,7 @@ def test_upload_returns_500_on_job_creation_error(client, monkeypatch):
|
|
| 65 |
response = client.post(
|
| 66 |
"/ingest/upload",
|
| 67 |
data={"collection_name": "default"},
|
| 68 |
-
files=
|
| 69 |
)
|
| 70 |
|
| 71 |
assert response.status_code == 500
|
|
@@ -75,12 +77,14 @@ def test_upload_returns_500_on_job_creation_error(client, monkeypatch):
|
|
| 75 |
def test_ingest_url_rejects_non_http_scheme(client, monkeypatch):
|
| 76 |
monkeypatch.setattr(
|
| 77 |
"api.routes.ingest._download_url_to_temp",
|
| 78 |
-
AsyncMock(
|
|
|
|
|
|
|
| 79 |
)
|
| 80 |
|
| 81 |
response = client.post(
|
| 82 |
"/ingest/url",
|
| 83 |
-
json={"
|
| 84 |
)
|
| 85 |
|
| 86 |
assert response.status_code == 400
|
|
|
|
| 34 |
response = client.post(
|
| 35 |
"/ingest/upload",
|
| 36 |
data={"collection_name": "default"},
|
| 37 |
+
files=[("files", ("sample.txt", b"hello world", "text/plain"))],
|
| 38 |
)
|
| 39 |
|
| 40 |
assert response.status_code == 200
|
| 41 |
body = response.json()
|
| 42 |
assert body["status"] == "queued"
|
| 43 |
assert body["job_id"] == "job-123"
|
| 44 |
+
assert body["total_files"] == 1
|
| 45 |
+
assert body["filenames"] == ["sample.txt"]
|
| 46 |
+
assert "Poll /jobs/job-123" in body["message"]
|
| 47 |
|
| 48 |
|
| 49 |
def test_upload_rejects_unsupported_extension(client):
|
| 50 |
response = client.post(
|
| 51 |
"/ingest/upload",
|
| 52 |
data={"collection_name": "default"},
|
| 53 |
+
files=[("files", ("sample.csv", b"a,b\n1,2", "text/csv"))],
|
| 54 |
)
|
| 55 |
|
| 56 |
assert response.status_code == 400
|
|
|
|
| 67 |
response = client.post(
|
| 68 |
"/ingest/upload",
|
| 69 |
data={"collection_name": "default"},
|
| 70 |
+
files=[("files", ("sample.txt", b"hello", "text/plain"))],
|
| 71 |
)
|
| 72 |
|
| 73 |
assert response.status_code == 500
|
|
|
|
| 77 |
def test_ingest_url_rejects_non_http_scheme(client, monkeypatch):
|
| 78 |
monkeypatch.setattr(
|
| 79 |
"api.routes.ingest._download_url_to_temp",
|
| 80 |
+
AsyncMock(
|
| 81 |
+
side_effect=ingest_route.HTTPException(status_code=400, detail="Only http and https URLs are supported.")
|
| 82 |
+
),
|
| 83 |
)
|
| 84 |
|
| 85 |
response = client.post(
|
| 86 |
"/ingest/url",
|
| 87 |
+
json={"urls": ["https://example.com/file.txt"], "collection_name": "default"},
|
| 88 |
)
|
| 89 |
|
| 90 |
assert response.status_code == 400
|
tests/test_query.py
CHANGED
|
@@ -40,20 +40,51 @@ def test_ask_returns_grounded_answer_with_sources(client, monkeypatch):
|
|
| 40 |
monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
|
| 41 |
monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
|
| 42 |
monkeypatch.setattr("api.routes.query.retrieve_chunks", lambda *_: chunks)
|
| 43 |
-
monkeypatch.setattr("api.routes.query.answer_with_grounding", lambda *_: "Audi is expanding EV investment.")
|
| 44 |
monkeypatch.setattr("api.routes.query.persist_query_audit", AsyncMock(return_value="evt-1"))
|
| 45 |
|
| 46 |
response = client.post(
|
| 47 |
"/query/ask",
|
| 48 |
-
json={
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
)
|
| 50 |
|
| 51 |
assert response.status_code == 200
|
| 52 |
body = response.json()
|
| 53 |
-
assert body["status"] == "success"
|
| 54 |
assert body["answer"] == "Audi is expanding EV investment."
|
|
|
|
|
|
|
| 55 |
assert len(body["sources"]) == 1
|
| 56 |
-
assert body["sources"][0]["
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
|
| 59 |
def test_ask_returns_422_for_invalid_payload(client):
|
|
@@ -61,6 +92,14 @@ def test_ask_returns_422_for_invalid_payload(client):
|
|
| 61 |
assert response.status_code == 422
|
| 62 |
|
| 63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
def test_ask_returns_500_when_retrieval_fails(client, monkeypatch):
|
| 65 |
monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
|
| 66 |
monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
|
|
@@ -68,7 +107,7 @@ def test_ask_returns_500_when_retrieval_fails(client, monkeypatch):
|
|
| 68 |
|
| 69 |
response = client.post(
|
| 70 |
"/query/ask",
|
| 71 |
-
json={"question": "What happened?", "collection_name": "default"},
|
| 72 |
)
|
| 73 |
|
| 74 |
assert response.status_code == 500
|
|
@@ -88,7 +127,8 @@ def test_summarise_returns_500_when_audit_persist_fails(client, monkeypatch):
|
|
| 88 |
monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
|
| 89 |
monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
|
| 90 |
monkeypatch.setattr("api.routes.query.retrieve_chunks", lambda *_: chunks)
|
| 91 |
-
monkeypatch.setattr("api.routes.query.summarise_with_grounding", lambda *_, **__: "Summary output")
|
|
|
|
| 92 |
monkeypatch.setattr(
|
| 93 |
"api.routes.query.persist_query_audit",
|
| 94 |
AsyncMock(side_effect=RuntimeError("audit write failed")),
|
|
@@ -96,7 +136,7 @@ def test_summarise_returns_500_when_audit_persist_fails(client, monkeypatch):
|
|
| 96 |
|
| 97 |
response = client.post(
|
| 98 |
"/query/summarise",
|
| 99 |
-
json={"collection_name": "default", "focus": "summarise risks"},
|
| 100 |
)
|
| 101 |
|
| 102 |
assert response.status_code == 500
|
|
|
|
| 40 |
monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
|
| 41 |
monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
|
| 42 |
monkeypatch.setattr("api.routes.query.retrieve_chunks", lambda *_: chunks)
|
| 43 |
+
monkeypatch.setattr("api.routes.query.answer_with_grounding", lambda *_: ("Audi is expanding EV investment.", 42))
|
| 44 |
monkeypatch.setattr("api.routes.query.persist_query_audit", AsyncMock(return_value="evt-1"))
|
| 45 |
|
| 46 |
response = client.post(
|
| 47 |
"/query/ask",
|
| 48 |
+
json={
|
| 49 |
+
"question": "What is Audi doing in EV markets worldwide?",
|
| 50 |
+
"collection_name": "default",
|
| 51 |
+
"top_k": 3,
|
| 52 |
+
"user_id": "tester",
|
| 53 |
+
},
|
| 54 |
)
|
| 55 |
|
| 56 |
assert response.status_code == 200
|
| 57 |
body = response.json()
|
|
|
|
| 58 |
assert body["answer"] == "Audi is expanding EV investment."
|
| 59 |
+
assert "query_id" in body
|
| 60 |
+
assert body["question"].startswith("What is Audi")
|
| 61 |
assert len(body["sources"]) == 1
|
| 62 |
+
assert body["sources"][0]["document_name"] == "strategy.md"
|
| 63 |
+
assert body["sources"][0]["page_number"] == 1
|
| 64 |
+
assert body["tokens_used"] == 42
|
| 65 |
+
assert "response_time_ms" in body
|
| 66 |
+
assert "model_used" in body
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def test_ask_respects_top_k_in_retrieve_call(client, monkeypatch):
    """The ``top_k`` sent to ``/query/ask`` is the ``k`` handed to ``retrieve_chunks``."""
    seen: dict[str, object] = {}

    def capture_retrieve(vs, question, k):
        # Record the requested k and report that nothing matched.
        seen["k"] = k
        return []

    monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
    monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
    monkeypatch.setattr("api.routes.query.retrieve_chunks", capture_retrieve)
    monkeypatch.setattr("api.routes.query.answer_with_grounding", lambda *_: ("No match answer", 0))
    monkeypatch.setattr("api.routes.query.persist_query_audit", AsyncMock())

    payload = {
        "question": "What is known about the topic here?",
        "collection_name": "default",
        "top_k": 7,
    }
    resp = client.post("/query/ask", json=payload)

    assert resp.status_code == 200
    assert seen.get("k") == 7
|
| 88 |
|
| 89 |
|
| 90 |
def test_ask_returns_422_for_invalid_payload(client):
|
|
|
|
| 92 |
assert response.status_code == 422
|
| 93 |
|
| 94 |
|
| 95 |
+
def test_ask_returns_422_for_short_question(client):
    """Posting the two-character question "hi" to ``/query/ask`` yields HTTP 422."""
    payload = {"question": "hi", "collection_name": "default"}
    resp = client.post("/query/ask", json=payload)
    assert resp.status_code == 422
|
| 101 |
+
|
| 102 |
+
|
| 103 |
def test_ask_returns_500_when_retrieval_fails(client, monkeypatch):
|
| 104 |
monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
|
| 105 |
monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
|
|
|
|
| 107 |
|
| 108 |
response = client.post(
|
| 109 |
"/query/ask",
|
| 110 |
+
json={"question": "What happened in the documents?", "collection_name": "default"},
|
| 111 |
)
|
| 112 |
|
| 113 |
assert response.status_code == 500
|
|
|
|
| 127 |
monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
|
| 128 |
monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
|
| 129 |
monkeypatch.setattr("api.routes.query.retrieve_chunks", lambda *_: chunks)
|
| 130 |
+
monkeypatch.setattr("api.routes.query.summarise_with_grounding", lambda *_, **__: ("Summary output", 10))
|
| 131 |
+
monkeypatch.setattr("api.routes.query.collection_document_count", lambda *_: 5)
|
| 132 |
monkeypatch.setattr(
|
| 133 |
"api.routes.query.persist_query_audit",
|
| 134 |
AsyncMock(side_effect=RuntimeError("audit write failed")),
|
|
|
|
| 136 |
|
| 137 |
response = client.post(
|
| 138 |
"/query/summarise",
|
| 139 |
+
json={"collection_name": "default", "focus": "summarise risks", "user_id": "u1"},
|
| 140 |
)
|
| 141 |
|
| 142 |
assert response.status_code == 500
|
workers/ingest_worker.py
CHANGED
|
@@ -5,10 +5,15 @@ from rag.chunker import chunk_documents
|
|
| 5 |
from rag.embedder import create_embedding_function
|
| 6 |
from rag.loader import load_documents
|
| 7 |
from rag.vector_store import add_documents, get_vector_store
|
| 8 |
-
from storage.job_store import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
-
def
|
| 12 |
documents = load_documents(temp_path)
|
| 13 |
chunks = chunk_documents(documents)
|
| 14 |
if not chunks:
|
|
@@ -25,37 +30,72 @@ def _ingest_sync(temp_path: str, collection_name: str, chroma_persist_directory:
|
|
| 25 |
|
| 26 |
async def run_ingest_job(
|
| 27 |
job_id: str,
|
| 28 |
-
|
| 29 |
collection_name: str,
|
| 30 |
jobs_db_path: str,
|
| 31 |
chroma_persist_directory: str,
|
| 32 |
) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
try:
|
| 34 |
-
await
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
jobs_db_path,
|
| 48 |
job_id,
|
| 49 |
-
|
| 50 |
-
message=f"
|
| 51 |
-
document_ids=document_ids,
|
| 52 |
)
|
| 53 |
except Exception as exc:
|
| 54 |
-
await
|
| 55 |
-
jobs_db_path,
|
| 56 |
-
job_id,
|
| 57 |
-
status="failed",
|
| 58 |
-
message=str(exc),
|
| 59 |
-
)
|
| 60 |
-
finally:
|
| 61 |
-
Path(temp_path).unlink(missing_ok=True)
|
|
|
|
| 5 |
from rag.embedder import create_embedding_function
|
| 6 |
from rag.loader import load_documents
|
| 7 |
from rag.vector_store import add_documents, get_vector_store
|
| 8 |
+
from storage.job_store import (
|
| 9 |
+
complete_ingest_job,
|
| 10 |
+
fail_ingest_job,
|
| 11 |
+
mark_job_processing,
|
| 12 |
+
update_job_progress,
|
| 13 |
+
)
|
| 14 |
|
| 15 |
|
| 16 |
+
def _ingest_one_file_sync(temp_path: str, collection_name: str, chroma_persist_directory: str) -> tuple[list[str], int]:
|
| 17 |
documents = load_documents(temp_path)
|
| 18 |
chunks = chunk_documents(documents)
|
| 19 |
if not chunks:
|
|
|
|
| 30 |
|
| 31 |
async def run_ingest_job(
    job_id: str,
    files: list[tuple[str, str]],
    collection_name: str,
    jobs_db_path: str,
    chroma_persist_directory: str,
) -> None:
    """Process one or more temp files for a single ingest job.

    Each file is ingested in a worker thread via ``_ingest_one_file_sync``.
    Job progress is written after every file, and the job is finally marked
    completed (at least one file ingested) or failed (none ingested, or an
    unexpected error occurred).

    Args:
        job_id: Identifier of the job row to update.
        files: ``(temp_path, display_name)`` pairs; ``display_name`` is used
            in progress and error messages.
        collection_name: Target collection, passed to the per-file ingest.
        jobs_db_path: Path of the jobs database, passed to the job-store calls.
        chroma_persist_directory: Passed through to the per-file ingest.

    Returns:
        None. Outcomes are reported through the job store rather than raised,
        except for errors raised by ``fail_ingest_job`` itself.
    """
    all_doc_ids: list[str] = []
    errors: list[str] = []
    processed = 0
    failed = 0

    if not files:
        await fail_ingest_job(jobs_db_path, job_id, message="No files to ingest.")
        return

    try:
        await mark_job_processing(jobs_db_path, job_id)
        for temp_path, display_name in files:
            # Only the ingest itself is guarded per file, so a failing progress
            # write cannot make a successfully ingested file count as failed too.
            try:
                doc_ids, num_chunks = await asyncio.to_thread(
                    _ingest_one_file_sync,
                    temp_path,
                    collection_name,
                    chroma_persist_directory,
                )
            except Exception as exc:
                failed += 1
                errors.append(f"{display_name}: {exc}")
                progress_message = f"Failed on {display_name}: {exc}"
            else:
                all_doc_ids.extend(doc_ids)
                processed += 1
                progress_message = f"Ingested {display_name} ({num_chunks} chunks)."
            finally:
                # Remove the temp file as soon as this file has been handled.
                Path(temp_path).unlink(missing_ok=True)

            await update_job_progress(
                jobs_db_path,
                job_id,
                processed_files=processed,
                failed_files=failed,
                errors=errors,
                message=progress_message,
            )

        if processed == 0:
            await fail_ingest_job(
                jobs_db_path,
                job_id,
                message="All files failed ingestion.",
                errors=errors,
            )
            return

        chunk_note = f"{len(all_doc_ids)} chunk vector(s) across {processed} file(s)."
        await complete_ingest_job(
            jobs_db_path,
            job_id,
            document_ids=all_doc_ids,
            message=f"Ingestion completed. {chunk_note}",
        )
    except Exception as exc:
        await fail_ingest_job(jobs_db_path, job_id, message=str(exc), errors=errors + [str(exc)])
    finally:
        # Best-effort sweep: if the job aborted early (e.g. a job-store write
        # raised), files the loop never reached would otherwise be left on disk.
        for leftover_path, _ in files:
            try:
                Path(leftover_path).unlink(missing_ok=True)
            except OSError:
                # Cleanup must not mask the job outcome already recorded above.
                pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|