Mayank Chugh committed on
Commit a32f9e3 · 1 Parent(s): fceb91f

Enhance environment configuration and API documentation for Milestone 11


- Update `.env.example` to reflect new application name, version, and additional configuration options for LLM providers.
- Revise `README.md` to improve architecture overview, quick start instructions, and use cases for the DocuAudit AI application.
- Add detailed specifications for API endpoints in `milestones.md`, including new features for multi-file uploads, query parameters, and audit logging.
- Refactor API routes to support enhanced query and ingestion functionalities, including user ID tracking and improved response structures.
- Update request and response models to align with new API specifications and ensure comprehensive coverage of expected behaviors.

.env.example CHANGED
@@ -1,12 +1,48 @@
1
- APP_NAME=doc-audi-ai
2
- APP_VERSION=0.1.0
 
3
  LLM_PROVIDER=ollama
4
- EMBEDDING_MODEL_NAME=nomic-embed-text
5
- OLLAMA_BASE_URL=http://localhost:11434
6
  OPENAI_API_KEY=
7
  CHROMA_PERSIST_DIRECTORY=./data/chroma
8
- JOBS_DB_PATH=./data/jobs.db
 
 
 
9
  CHUNK_SIZE=1000
10
- CHUNK_OVERLAP=150
11
- RETRIEVAL_K=4
12
- MAX_FILE_SIZE_MB=25
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DocuAudit AI — environment template (see docs/DOCUAUDIT_AI_REQUIREMENTS.md)
2
+
3
+ # LLM Provider: ollama | anthropic | openai | huggingface
4
  LLM_PROVIDER=ollama
5
+
6
+ # OpenAI (optional)
7
  OPENAI_API_KEY=
8
+ OPENAI_MODEL=gpt-4o
9
+ OPENAI_EMBEDDING_MODEL=text-embedding-3-small
10
+
11
+ # Anthropic (optional)
12
+ ANTHROPIC_API_KEY=
13
+ ANTHROPIC_MODEL=claude-3-5-sonnet-20241022
14
+
15
+ # Ollama (recommended local default)
16
+ OLLAMA_BASE_URL=http://localhost:11434
17
+ OLLAMA_CHAT_MODEL=llama3.1:8b
18
+ OLLAMA_EMBEDDING_MODEL=nomic-embed-text
19
+
20
+ # App
21
+ APP_NAME=DocuAudit AI
22
+ APP_VERSION=1.0.0
23
+ DEBUG=false
24
+ MAX_FILE_SIZE_MB=50
25
+ # Spec name alias (optional; mapped to MAX_FILE_SIZE_MB in settings)
26
+ MAX_UPLOAD_SIZE_MB=
27
+
28
+ # ChromaDB
29
  CHROMA_PERSIST_DIRECTORY=./data/chroma
30
+ CHROMA_PERSIST_DIR=
31
+ CHROMA_COLLECTION_NAME=docuaudit_docs
32
+
33
+ # Chunking
34
  CHUNK_SIZE=1000
35
+ CHUNK_OVERLAP=200
36
+
37
+ # Retrieval default (overridable per request on /query/ask via top_k)
38
+ TOP_K_RESULTS=5
39
+
40
+ # Audit + jobs SQLite
41
+ AUDIT_DB_PATH=./audit.db
42
+ JOBS_DB_PATH=./data/jobs.db
43
+
44
+ # Limits
45
+ MAX_DOCUMENTS_PER_BATCH=100
46
+
47
+ # Streamlit → API
48
+ STREAMLIT_BACKEND_URL=http://localhost:8000
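Reviewer sketch (not part of the commit): the template above groups model variables by provider, so a reader has to work out which model variable applies for a given `LLM_PROVIDER`. The snippet below is a minimal, standard-library illustration of that lookup. `parse_env` and `chat_model_for` are hypothetical helpers, and the provider-to-variable mapping is an assumption read off the template; the repository's real settings loader lives in `api/config.py` and may resolve this differently.

```python
from typing import Dict


def parse_env(text: str) -> Dict[str, str]:
    """Parse KEY=VALUE lines, skipping blank lines and # comments."""
    values: Dict[str, str] = {}
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        values[key.strip()] = value.strip()
    return values


def chat_model_for(env: Dict[str, str]) -> str:
    """Return the chat model variable matching LLM_PROVIDER (assumed mapping)."""
    provider = env.get("LLM_PROVIDER", "ollama").lower()
    key_by_provider = {
        "ollama": "OLLAMA_CHAT_MODEL",
        "openai": "OPENAI_MODEL",
        "anthropic": "ANTHROPIC_MODEL",
    }
    if provider not in key_by_provider:
        # The template lists "huggingface" but defines no model variable for it.
        raise ValueError(f"no model variable known for provider {provider!r}")
    return env.get(key_by_provider[provider], "")


SAMPLE = """
# LLM Provider: ollama | anthropic | openai | huggingface
LLM_PROVIDER=ollama
OLLAMA_CHAT_MODEL=llama3.1:8b
OPENAI_MODEL=gpt-4o
MAX_UPLOAD_SIZE_MB=
"""

env = parse_env(SAMPLE)
print(chat_model_for(env))
```

An empty assignment such as `MAX_UPLOAD_SIZE_MB=` parses to an empty string here, which is the case the alias handling in `api/config.py` has to treat as "unset".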
README.md CHANGED
@@ -1,42 +1,138 @@
1
- # doc-Audi-ai
2
 
 
3
 
4
- create requirements.txt & .env
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- # 1. Setup environment
8
- uv venv --python 3.11.14
9
- uv init --python 3.11.14
10
- uv add requirements.txt
11
- uv pip install -r requirements.txt
12
- copy .env.example .env
13
 
14
- Note for Intel Macs (x86_64):
15
- - `pyproject.toml` includes platform-specific pins for `onnxruntime` and `torch` to ensure `uv` resolves versions that have compatible wheels.
 
 
 
 
 
16
 
17
- Install Ollama
18
- curl -fsSL https://ollama.com/install.sh | sh
19
 
20
- Pull required models
21
- ollama pull llama3.1:8b # LLM for answer generation (~2 GB)
22
- ollama pull nomic-embed-text # Embedding model (~274 MB)
23
 
24
- Start Ollama server
25
- ollama serve &
26
 
27
- Verify models are running
28
- curl http://localhost:11434/api/tags
29
 
30
- ### Start backend server- fastapi
31
  ```bash
32
- uv run uvicorn api.main:app --host 0.0.0.0 --port 8000
 
 
 
 
 
 
33
  ```
34
- ### Start frontend server- Streamlit
 
 
35
  ```bash
36
- uv run streamlit run streamlit_app.py --server.address=0.0.0.0 --server.port=8501
37
  ```
38
 
39
- git diff --name-status --diff-filter=AM <from_commit_hash> <to_commit_hash>
40
 
41
- git diff --name-status --diff-filter=AM 18ad0e6c94d041b1fd902e7f9b60113738eee1fa 0f2ee3afa124348adece82df0ff0e5a0943a7b8b
42
 
 
 
1
+ # DocuAudit AI
2
 
3
+ **DocuAudit AI** is a production-oriented FastAPI backend plus optional Streamlit UI for **multi-document RAG**: upload documents, build a Chroma vector index, ask grounded questions with citations, and retain a **SQLite audit trail** of every query.
4
 
5
+ ## Architecture
6
 
7
+ ```mermaid
8
+ flowchart LR
9
+ subgraph ingest [Ingestion]
10
+ A[PDF / TXT / MD] --> B[Loader]
11
+ B --> C[Chunker]
12
+ C --> D[Embedder]
13
+ D --> E[(ChromaDB)]
14
+ end
15
+ subgraph query [Query path]
16
+ Q[User question] --> R[Semantic search]
17
+ R --> E
18
+ R --> T[Top-K chunks]
19
+ T --> L[LLM]
20
+ L --> U[Answer + citations]
21
+ end
22
+ U --> V[(SQLite audit)]
23
+ ```
24
 
25
+ ASCII equivalent:
 
 
 
 
 
26
 
27
+ ```
28
+ PDF Upload → Parser → Chunker → Embedder → ChromaDB
29
+
30
+ User Query → Semantic Search → Top-K Chunks → LLM → Answer + Citations
31
+
32
+ Audit Log (SQLite)
33
+ ```
34
 
35
+ ## Use cases
 
36
 
37
+ - **Litigation document analysis**: trace claims to exact pages and filenames.
38
+ - **Corporate finance review**: compare disclosures and filings under a consistent audit log.
39
+ - **Investigation support**: bulk ingest, async jobs, and reproducible query history.
40
 
41
+ ## Quick start (local, without Docker)
 
42
 
43
+ Docker and Compose are planned under **Milestone 12**. Until then, run the API with **uv** (or your preferred tool):
 
44
 
 
45
  ```bash
46
+ git clone <repository-url> doc-Audi-ai
47
+ cd doc-Audi-ai
48
+ cp .env.example .env  # Windows cmd: copy .env.example .env
49
+ uv sync
50
+ ollama pull llama3.1:8b
51
+ ollama pull nomic-embed-text
52
+ uv run uvicorn api.main:app --host 0.0.0.0 --port 8000 --reload
53
  ```
54
+
55
+ Optional UI:
56
+
57
  ```bash
58
+ uv run streamlit run streamlit_app.py --server.port 8501 --server.address 0.0.0.0
59
  ```
60
 
61
+ After **Milestone 12**, the intended one-command experience will be `docker compose up` for API (`localhost:8000`) and UI (`localhost:8501`).
62
+
63
+ ## API overview
64
+
65
+ | Method | Path | Description |
66
+ |--------|------|-------------|
67
+ | GET | `/health` | Liveness; returns configured app name and version |
68
+ | POST | `/ingest/upload` | Multipart **`files`** (one or more); queues background ingest job |
69
+ | POST | `/ingest/url` | JSON **`urls`** array (1–100); download and queue ingest |
70
+ | GET | `/ingest/collections` | Lists collections with **`document_count`** and optional **`created_at`** |
71
+ | DELETE | `/ingest/collection/{collection_name}` | Drops a collection; returns **`documents_removed`** |
72
+ | GET | `/jobs` | Lists jobs with **`total`** count |
73
+ | GET | `/jobs/{job_id}` | Job status with **`progress_percent`**, file counters, timestamps, **`errors`** |
74
+ | POST | `/query/ask` | Grounded answer; request includes **`top_k`**, **`user_id`** |
75
+ | POST | `/query/summarise` | Collection summary; distinct response shape (`summary`, `document_count`, …) |
76
+ | POST | `/query` | Legacy alias of **`/query/ask`** |
77
+ | GET | `/audit/logs` | Filterable audit index (`user_id`, `from_date`, `to_date`, pagination) |
78
+ | GET | `/audit/logs/{query_id}` | Full stored answer and citations for one query |
79
+
80
+ Interactive docs: `http://localhost:8000/docs`.
81
+
82
+ ## Sample request and response (`POST /query/ask`)
83
+
84
+ Request:
85
+
86
+ ```json
87
+ {
88
+ "question": "What were the key risk factors identified in the Q3 2023 financial report?",
89
+ "collection_name": "default",
90
+ "top_k": 5,
91
+ "user_id": "analyst_001"
92
+ }
93
+ ```
94
+
95
+ Response (shape; values depend on your documents and model):
96
+
97
+ ```json
98
+ {
99
+ "query_id": "uuid-string",
100
+ "question": "What were the key risk factors identified in the Q3 2023 financial report?",
101
+ "answer": "… grounded text with citations …",
102
+ "sources": [
103
+ {
104
+ "document_name": "q3_financial_report.pdf",
105
+ "page_number": 12,
106
+ "chunk_text": "Key risk factors include …",
107
+ "relevance_score": 0.91
108
+ }
109
+ ],
110
+ "model_used": "llama3.1:8b",
111
+ "tokens_used": 0,
112
+ "response_time_ms": 1820,
113
+ "timestamp": "2026-05-03T12:00:00Z"
114
+ }
115
+ ```
116
+
117
+ ## Design decisions
118
+
119
+ - **Source citations** — High-stakes review requires every substantive claim to be tied to **document name** and **page** (where available), not a free-floating model monologue.
120
+ - **Auditability** — Each ask/summarise persists **query id**, **user id**, timing, model id, token usage (when the provider exposes it), and serialized sources so regulators or counsel can reconstruct what the system returned.
121
+
122
+ ## Scale note
123
+
124
+ The architecture is designed for **high-volume document ingestion** via **async background jobs** (FastAPI `BackgroundTasks`, which run inside the API process after the response is sent), persistent Chroma collections, and a stateless API tier that can be replicated once you add a shared vector store and job queue.
125
+
126
+ ## Tests
127
+
128
+ ```bash
129
+ uv run pytest tests/ -q
130
+ ```
131
+
132
+ ## Configuration
133
+
134
+ See **`.env.example`**. Common variables include `LLM_PROVIDER`, Ollama/OpenAI/Anthropic keys and models, `CHROMA_PERSIST_DIRECTORY`, `AUDIT_DB_PATH`, `JOBS_DB_PATH`, and upload limits (`MAX_FILE_SIZE_MB`; **`MAX_UPLOAD_SIZE_MB`** is accepted as an alias via settings normalization).
135
 
136
+ ## Specification
137
 
138
+ Authoritative product and API shapes: **`docs/DOCUAUDIT_AI_REQUIREMENTS.md`**. Gap tracking: **`docs/REQUIREMENTS_IMPLEMENTATION_GAPS.md`**.
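Reviewer sketch (not part of the commit): the README's sample request and response for `POST /query/ask` can be exercised from a small client. The snippet below uses only the standard library to build the request and to turn the `sources` array into readable citation lines. `build_ask_request` and `citation_lines` are hypothetical names, and treating `user_id` as optional is an assumption; the authoritative shapes are in `models/requests.py` and `models/responses.py`.

```python
import json
from typing import Optional
from urllib import request


def build_ask_request(
    base_url: str,
    question: str,
    collection_name: str = "default",
    top_k: int = 5,
    user_id: Optional[str] = None,
) -> request.Request:
    """Build (but do not send) a POST /query/ask request."""
    payload = {"question": question, "collection_name": collection_name, "top_k": top_k}
    if user_id is not None:
        payload["user_id"] = user_id
    return request.Request(
        url=f"{base_url.rstrip('/')}/query/ask",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def citation_lines(response_body: dict) -> list[str]:
    """Format each source as 'document (page), score'."""
    lines = []
    for src in response_body.get("sources", []):
        page = src.get("page_number")
        where = f"p. {page}" if page is not None else "page n/a"
        lines.append(f"{src['document_name']} ({where}), score {src['relevance_score']:.2f}")
    return lines


req = build_ask_request(
    "http://localhost:8000",
    "What were the key risk factors identified in the Q3 2023 financial report?",
    user_id="analyst_001",
)
sample = {
    "sources": [
        {
            "document_name": "q3_financial_report.pdf",
            "page_number": 12,
            "chunk_text": "Key risk factors include ...",
            "relevance_score": 0.91,
        }
    ]
}
print(req.full_url)
print(citation_lines(sample))
```

Sending the request is a single `urllib.request.urlopen(req)` call once the API is running on `localhost:8000`.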
api/config.py CHANGED
@@ -1,5 +1,7 @@
1
  from functools import lru_cache
2
- from pydantic import Field
 
 
3
 
4
  from pydantic_settings import BaseSettings, SettingsConfigDict
5
 
@@ -10,10 +12,30 @@ class Settings(BaseSettings):
10
  env_file_encoding="utf-8",
11
  extra="ignore",
12
  case_sensitive=False,
 
13
  )
14
 
15
- app_name: str = Field(default="doc-audi-ai", description="The name of the application")
16
- app_version: str = Field(default="0.1.0", description="The version of the application")
17
  llm_provider: str = Field(default="ollama", description="Embedding provider")
18
 
19
  openai_api_key: str | None = Field(default=None, description="OpenAI API key")
@@ -38,12 +60,12 @@ class Settings(BaseSettings):
38
 
39
  chunk_size: int = Field(default=1000, ge=100, le=8000, description="Chunk size for splitting")
40
  chunk_overlap: int = Field(default=200, ge=0, le=2000, description="Chunk overlap for splitting")
41
- top_k_results: int = Field(default=4, ge=1, le=20, description="Number of chunks to retrieve")
42
 
43
  audit_db_path: str = "./audit.db"
44
  jobs_db_path: str = Field(default="./data/jobs.db", description="SQLite path for ingest job tracking")
45
 
46
- max_file_size_mb: int = Field(default=50, ge=1, le=200, description="Max upload file size")
47
  max_documents_per_batch: int = Field(default=100, ge=1, le=1000, description="Max documents per batch")
48
 
49
 
 
1
  from functools import lru_cache
2
+ from typing import Any
3
+
4
+ from pydantic import Field, model_validator
5
 
6
  from pydantic_settings import BaseSettings, SettingsConfigDict
7
 
 
12
  env_file_encoding="utf-8",
13
  extra="ignore",
14
  case_sensitive=False,
15
+ populate_by_name=True,
16
  )
17
 
18
+ @model_validator(mode="before")
19
+ @classmethod
20
+ def _map_max_upload_env_alias(cls, data: Any) -> Any:
21
+ if not isinstance(data, dict):
22
+ return data
23
+ out = dict(data)
24
+ if out.get("max_file_size_mb") in (None, "") and out.get("max_upload_size_mb") not in (None, ""):
25
+ out["max_file_size_mb"] = out.pop("max_upload_size_mb")
26
+ elif "max_upload_size_mb" in out and "max_file_size_mb" not in out:
27
+ out["max_file_size_mb"] = out.pop("max_upload_size_mb")
28
+ return out
29
+
30
+ app_name: str = Field(default="DocuAudit AI", description="FastAPI title and product name")
31
+ app_version: str = Field(default="1.0.0", description="Application version")
32
+ app_description: str = Field(
33
+ default=(
34
+ "Multi-document RAG API for high-stakes consulting environments. "
35
+ "Every answer is grounded in source documents with full audit trails."
36
+ ),
37
+ description="OpenAPI /docs description",
38
+ )
39
  llm_provider: str = Field(default="ollama", description="Embedding provider")
40
 
41
  openai_api_key: str | None = Field(default=None, description="OpenAI API key")
 
60
 
61
  chunk_size: int = Field(default=1000, ge=100, le=8000, description="Chunk size for splitting")
62
  chunk_overlap: int = Field(default=200, ge=0, le=2000, description="Chunk overlap for splitting")
63
+ top_k_results: int = Field(default=5, ge=1, le=20, description="Default number of chunks to retrieve")
64
 
65
  audit_db_path: str = "./audit.db"
66
  jobs_db_path: str = Field(default="./data/jobs.db", description="SQLite path for ingest job tracking")
67
 
68
+ max_file_size_mb: int = Field(default=50, ge=1, le=200, description="Max upload file size (MB)")
69
  max_documents_per_batch: int = Field(default=100, ge=1, le=1000, description="Max documents per batch")
70
 
71
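Reviewer sketch (not part of the commit): the new `_map_max_upload_env_alias` validator is easier to reason about as a plain function over dictionaries. The snippet below copies its two branches into a standalone helper (`map_max_upload_alias` is a hypothetical name) and runs four inputs through it. It only illustrates the dictionary logic; it does not show which keys pydantic-settings actually passes to a `mode="before"` validator.

```python
from typing import Any

_EMPTY = (None, "")


def map_max_upload_alias(data: Any) -> Any:
    """Standalone copy of the alias-mapping branches from the diff."""
    if not isinstance(data, dict):
        return data
    out = dict(data)
    if out.get("max_file_size_mb") in _EMPTY and out.get("max_upload_size_mb") not in _EMPTY:
        # Canonical key unset or blank, alias has a value: promote the alias.
        out["max_file_size_mb"] = out.pop("max_upload_size_mb")
    elif "max_upload_size_mb" in out and "max_file_size_mb" not in out:
        # Alias present (possibly blank) and canonical key absent: move it over.
        out["max_file_size_mb"] = out.pop("max_upload_size_mb")
    return out


cases = [
    {"max_upload_size_mb": "75"},
    {"max_file_size_mb": "50", "max_upload_size_mb": "75"},
    {"max_file_size_mb": "", "max_upload_size_mb": "75"},
    {"max_upload_size_mb": ""},
]
for case in cases:
    print(case, "->", map_max_upload_alias(case))
```

In this sketch the canonical key wins when both are set, and a blank alias on its own is copied into `max_file_size_mb` as an empty string. An integer field would likely reject that value, so the last case may deserve a test in the real settings class.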
 
api/main.py CHANGED
@@ -4,13 +4,27 @@ import os
4
  os.environ.setdefault("ANONYMIZED_TELEMETRY", "FALSE")
5
 
6
  from fastapi import FastAPI
 
7
 
8
  from api.config import get_settings
9
  from storage.audit_store import init_audit_db
10
  from storage.job_store import init_jobs_db
11
  from .routes import audit, ingest, jobs, query
12
 
13
- app = FastAPI()
14
 
15
  app.include_router(audit.router)
16
  app.include_router(ingest.router)
@@ -25,6 +39,12 @@ async def startup() -> None:
25
  await init_audit_db(settings.audit_db_path)
26
  await init_jobs_db(settings.jobs_db_path)
27
 
 
28
  @app.get("/health", tags=["Health"])
29
  def health() -> dict[str, str]:
30
- return {"status": "ok","app_name": "doc-audi-ai", "version": "0.1.0"}
 
 
 
 
 
 
4
  os.environ.setdefault("ANONYMIZED_TELEMETRY", "FALSE")
5
 
6
  from fastapi import FastAPI
7
+ from fastapi.middleware.cors import CORSMiddleware
8
 
9
  from api.config import get_settings
10
  from storage.audit_store import init_audit_db
11
  from storage.job_store import init_jobs_db
12
  from .routes import audit, ingest, jobs, query
13
 
14
+ _settings = get_settings()
15
+ app = FastAPI(
16
+ title=_settings.app_name,
17
+ version=_settings.app_version,
18
+ description=_settings.app_description,
19
+ )
20
+
21
+ app.add_middleware(
22
+ CORSMiddleware,
23
+ allow_origins=["*"],
24
+ allow_credentials=True,
25
+ allow_methods=["*"],
26
+ allow_headers=["*"],
27
+ )
28
 
29
  app.include_router(audit.router)
30
  app.include_router(ingest.router)
 
39
  await init_audit_db(settings.audit_db_path)
40
  await init_jobs_db(settings.jobs_db_path)
41
 
42
+
43
  @app.get("/health", tags=["Health"])
44
  def health() -> dict[str, str]:
45
+ settings = get_settings()
46
+ return {
47
+ "status": "ok",
48
+ "app": settings.app_name,
49
+ "version": settings.app_version,
50
+ }
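Reviewer sketch (not part of the commit): the new middleware block combines `allow_origins=["*"]` with `allow_credentials=True`. Browsers refuse credentialed cross-origin responses whose `Access-Control-Allow-Origin` is the wildcard, so the two settings pull in different directions. One possible approach, shown below as a standard-library helper, is to derive the middleware arguments from a configured origin list and enable credentials only when origins are explicit. `cors_kwargs` is a hypothetical helper, and a configurable origin list is an assumption; no such setting exists in this diff.

```python
from typing import Any, Dict, List


def cors_kwargs(allowed_origins: List[str]) -> Dict[str, Any]:
    """Build keyword arguments for a CORS middleware from an origin list."""
    wildcard = "*" in allowed_origins
    return {
        "allow_origins": ["*"] if wildcard else list(allowed_origins),
        # Credentials are only enabled when origins are listed explicitly.
        "allow_credentials": not wildcard,
        "allow_methods": ["*"],
        "allow_headers": ["*"],
    }


print(cors_kwargs(["*"]))
print(cors_kwargs(["http://localhost:8501"]))
```

The result could be passed as `app.add_middleware(CORSMiddleware, **cors_kwargs(origins))`, with `http://localhost:8501` being the Streamlit address named in the README.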
api/routes/audit.py CHANGED
@@ -4,42 +4,54 @@ from fastapi import APIRouter, Depends, HTTPException, Query, status
4
 
5
  from api.config import get_settings
6
  from models.requests import AuditListParams
7
- from models.responses import AuditDetailResponse, AuditEvent, AuditListResponse
8
  from storage.audit_store import get_audit_event, list_audit_events
9
 
10
 
11
  def _audit_list_params(
12
- limit: Annotated[int, Query(ge=1, le=100)] = 10,
13
  offset: Annotated[int, Query(ge=0)] = 0,
 
 
 
14
  ) -> AuditListParams:
15
- return AuditListParams(limit=limit, offset=offset)
 
 
 
 
 
 
16
 
17
 
18
  router = APIRouter(prefix="/audit", tags=["audit"])
19
 
20
 
21
- @router.get("/logs", response_model=AuditListResponse)
22
  async def audit_logs(
23
  params: Annotated[AuditListParams, Depends(_audit_list_params)],
24
- ) -> AuditListResponse:
25
  settings = get_settings()
26
- rows = await list_audit_events(settings.audit_db_path, limit=params.limit, offset=params.offset)
27
- events = [AuditEvent.model_validate(row) for row in rows]
28
- return AuditListResponse(
29
- status="success",
30
- message=f"Returned {len(events)} audit event(s).",
31
- events=events,
 
 
 
 
 
 
 
32
  )
33
 
34
 
35
- @router.get("/logs/{query_id}", response_model=AuditDetailResponse)
36
- async def audit_log_detail(query_id: str) -> AuditDetailResponse:
37
  settings = get_settings()
38
  event = await get_audit_event(settings.audit_db_path, query_id)
39
  if event is None:
40
  raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Audit event not found.")
41
- return AuditDetailResponse(
42
- status="success",
43
- message="Audit event retrieved.",
44
- event=event,
45
- )
 
4
 
5
  from api.config import get_settings
6
  from models.requests import AuditListParams
7
+ from models.responses import AuditLogDetailResponse, AuditLogsResponse
8
  from storage.audit_store import get_audit_event, list_audit_events
9
 
10
 
11
  def _audit_list_params(
12
+ limit: Annotated[int, Query(ge=1, le=100)] = 50,
13
  offset: Annotated[int, Query(ge=0)] = 0,
14
+ user_id: Annotated[str | None, Query(max_length=256)] = None,
15
+ from_date: Annotated[str | None, Query(description="ISO 8601 lower bound")] = None,
16
+ to_date: Annotated[str | None, Query(description="ISO 8601 upper bound")] = None,
17
  ) -> AuditListParams:
18
+ return AuditListParams(
19
+ limit=limit,
20
+ offset=offset,
21
+ user_id=user_id,
22
+ from_date=from_date,
23
+ to_date=to_date,
24
+ )
25
 
26
 
27
  router = APIRouter(prefix="/audit", tags=["audit"])
28
 
29
 
30
+ @router.get("/logs", response_model=AuditLogsResponse)
31
  async def audit_logs(
32
  params: Annotated[AuditListParams, Depends(_audit_list_params)],
33
+ ) -> AuditLogsResponse:
34
  settings = get_settings()
35
+ logs, total = await list_audit_events(
36
+ settings.audit_db_path,
37
+ limit=params.limit,
38
+ offset=params.offset,
39
+ user_id=params.user_id,
40
+ from_date=params.from_date,
41
+ to_date=params.to_date,
42
+ )
43
+ return AuditLogsResponse(
44
+ logs=logs,
45
+ total=total,
46
+ limit=params.limit,
47
+ offset=params.offset,
48
  )
49
 
50
 
51
+ @router.get("/logs/{query_id}", response_model=AuditLogDetailResponse)
52
+ async def audit_log_detail(query_id: str) -> AuditLogDetailResponse:
53
  settings = get_settings()
54
  event = await get_audit_event(settings.audit_db_path, query_id)
55
  if event is None:
56
  raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Audit event not found.")
57
+ return event
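Reviewer sketch (not part of the commit): `from_date` and `to_date` arrive as plain strings described as ISO 8601 bounds, and this diff does not show where they are parsed. The snippet below is one way the bounds might be normalized before they reach the store, mirroring the `Z` handling that `_parse_created_at` uses in `api/routes/ingest.py`. `parse_iso_bound` and `validate_range` are hypothetical helpers, and treating naive timestamps as UTC is an assumption.

```python
from datetime import datetime, timezone
from typing import Optional, Tuple


def parse_iso_bound(raw: Optional[str]) -> Optional[datetime]:
    """Parse an ISO 8601 string into an aware UTC datetime, or None if blank."""
    if raw is None or not raw.strip():
        return None
    text = raw.strip()
    if text.endswith("Z"):
        text = text[:-1] + "+00:00"
    parsed = datetime.fromisoformat(text)  # raises ValueError on malformed input
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)  # assumption: naive means UTC
    return parsed.astimezone(timezone.utc)


def validate_range(
    from_date: Optional[str], to_date: Optional[str]
) -> Tuple[Optional[datetime], Optional[datetime]]:
    """Parse both bounds and reject an inverted range."""
    lower = parse_iso_bound(from_date)
    upper = parse_iso_bound(to_date)
    if lower is not None and upper is not None and lower > upper:
        raise ValueError("from_date must not be later than to_date")
    return lower, upper


lower, upper = validate_range("2026-05-01", "2026-05-03T12:00:00Z")
print(lower.isoformat(), upper.isoformat())
```

A route could catch the `ValueError` and return a 4xx response rather than passing an unparseable bound to the SQLite query.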
 
 
 
 
api/routes/ingest.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from pathlib import Path
2
  from tempfile import NamedTemporaryFile
3
  from typing import Annotated
@@ -7,14 +8,20 @@ import httpx
7
  from fastapi import APIRouter, BackgroundTasks, File, Form, HTTPException, UploadFile, status
8
 
9
  from api.config import get_settings
10
- from models.requests import IngestUrlRequest
11
  from models.responses import (
 
12
  IngestCollectionsResponse,
13
  IngestDeleteCollectionResponse,
14
  IngestUploadResponse,
15
- CollectionItem,
 
 
 
 
 
 
16
  )
17
- from rag.vector_store import delete_collection, list_collection_names
18
  from storage.job_store import create_ingest_job
19
  from workers.ingest_worker import run_ingest_job
20
 
@@ -86,7 +93,7 @@ async def _download_url_to_temp(url: str, max_bytes: int) -> tuple[str, str]:
86
 
87
  timeout = httpx.Timeout(60.0, connect=10.0)
88
  limits = httpx.Limits(max_keepalive_connections=5, max_connections=5)
89
- headers = {"User-Agent": "doc-audi-ai/ingest"}
90
 
91
  try:
92
  async with httpx.AsyncClient(timeout=timeout, limits=limits, follow_redirects=True) as client:
@@ -132,97 +139,131 @@ async def _download_url_to_temp(url: str, max_bytes: int) -> tuple[str, str]:
132
  return temp_path, display_name
133
 
134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  @router.post("/upload", response_model=IngestUploadResponse)
136
  async def upload_endpoint(
137
  background_tasks: BackgroundTasks,
138
- file: UploadFile = File(..., description="PDF/TXT/MD document to ingest"),
139
  collection_name: Annotated[str, Form(min_length=1, max_length=256)] = "default",
140
  ) -> IngestUploadResponse:
141
  settings = get_settings()
142
- max_bytes = settings.max_file_size_mb * 1024 * 1024
143
- suffix = _validate_file(file, max_bytes)
144
- display_name = (file.filename or "upload").strip()
 
 
 
 
145
 
146
- temp_path = ""
 
 
147
  try:
148
- file_bytes = await file.read()
149
- with NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
150
- temp_path = tmp.name
151
- tmp.write(file_bytes)
 
 
 
 
 
152
 
153
  job_id = await create_ingest_job(
154
  settings.jobs_db_path,
155
- collection_name=collection_name,
156
- filename=display_name,
157
  )
158
 
159
  background_tasks.add_task(
160
  run_ingest_job,
161
  job_id,
162
- temp_path,
163
- collection_name,
164
  settings.jobs_db_path,
165
  settings.chroma_persist_directory,
166
  )
167
 
168
  return IngestUploadResponse(
169
- status="queued",
170
- message=f"Ingestion job accepted. Poll GET /jobs/{job_id} for status.",
171
  job_id=job_id,
172
- document_ids=[],
 
 
 
173
  )
174
  except HTTPException:
175
- if temp_path:
176
- Path(temp_path).unlink(missing_ok=True)
177
  raise
178
  except Exception as exc:
179
- if temp_path:
180
- Path(temp_path).unlink(missing_ok=True)
181
  raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
182
- finally:
183
- await file.close()
184
 
185
 
186
- @router.post("/url", response_model=IngestUploadResponse)
187
  async def ingest_url_endpoint(
188
  background_tasks: BackgroundTasks,
189
- payload: IngestUrlRequest,
190
- ) -> IngestUploadResponse:
191
  settings = get_settings()
192
  max_bytes = settings.max_file_size_mb * 1024 * 1024
193
- url_str = str(payload.url).strip()
194
- temp_path = ""
 
 
 
 
 
 
195
  try:
196
- temp_path, display_name = await _download_url_to_temp(url_str, max_bytes)
 
 
197
 
 
198
  job_id = await create_ingest_job(
199
  settings.jobs_db_path,
200
- collection_name=payload.collection_name,
201
- filename=display_name,
202
  )
203
 
204
  background_tasks.add_task(
205
  run_ingest_job,
206
  job_id,
207
- temp_path,
208
- payload.collection_name,
209
  settings.jobs_db_path,
210
  settings.chroma_persist_directory,
211
  )
212
 
213
- return IngestUploadResponse(
214
- status="queued",
215
- message=f"Ingestion job accepted. Poll GET /jobs/{job_id} for status.",
216
  job_id=job_id,
217
- document_ids=[],
 
 
218
  )
219
  except HTTPException:
220
- if temp_path:
221
- Path(temp_path).unlink(missing_ok=True)
222
  raise
223
  except Exception as exc:
224
- if temp_path:
225
- Path(temp_path).unlink(missing_ok=True)
226
  raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
227
 
228
 
@@ -233,12 +274,18 @@ async def list_collections_endpoint() -> IngestCollectionsResponse:
233
  names = list_collection_names(settings.chroma_persist_directory)
234
  except Exception as exc:
235
  raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
236
- items = [CollectionItem(name=n) for n in names]
237
- return IngestCollectionsResponse(
238
- status="success",
239
- message=f"Found {len(items)} collection(s).",
240
- collections=items,
241
- )
 
 
 
 
 
 
242
 
243
 
244
  @router.delete("/collection/{collection_name}", response_model=IngestDeleteCollectionResponse)
@@ -251,13 +298,12 @@ async def delete_collection_endpoint(collection_name: str) -> IngestDeleteCollec
251
  existing = list_collection_names(settings.chroma_persist_directory)
252
  if name not in existing:
253
  raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Collection not found.")
254
- delete_collection(settings.chroma_persist_directory, name)
255
  except HTTPException:
256
  raise
257
  except Exception as exc:
258
  raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
259
  return IngestDeleteCollectionResponse(
260
- status="success",
261
- message=f"Deleted collection '{name}'.",
262
- collection_name=name,
263
  )
 
1
+ from datetime import datetime, timezone
2
  from pathlib import Path
3
  from tempfile import NamedTemporaryFile
4
  from typing import Annotated
 
8
  from fastapi import APIRouter, BackgroundTasks, File, Form, HTTPException, UploadFile, status
9
 
10
  from api.config import get_settings
11
+ from models.requests import URLIngestRequest
12
  from models.responses import (
13
+ CollectionItem,
14
  IngestCollectionsResponse,
15
  IngestDeleteCollectionResponse,
16
  IngestUploadResponse,
17
+ UrlIngestResponse,
18
+ )
19
+ from rag.vector_store import (
20
+ collection_created_at,
21
+ collection_document_count,
22
+ delete_collection,
23
+ list_collection_names,
24
  )
 
25
  from storage.job_store import create_ingest_job
26
  from workers.ingest_worker import run_ingest_job
27
 
 
93
 
94
  timeout = httpx.Timeout(60.0, connect=10.0)
95
  limits = httpx.Limits(max_keepalive_connections=5, max_connections=5)
96
+ headers = {"User-Agent": "docuaudit-ai/ingest"}
97
 
98
  try:
99
  async with httpx.AsyncClient(timeout=timeout, limits=limits, follow_redirects=True) as client:
 
139
  return temp_path, display_name
140
 
141
 
142
+ def _parse_created_at(raw: str | None) -> datetime | None:
143
+ if not raw:
144
+ return None
145
+ s = raw.strip()
146
+ if s.endswith("Z"):
147
+ s = s[:-1] + "+00:00"
148
+ try:
149
+ dt = datetime.fromisoformat(s)
150
+ if dt.tzinfo is None:
151
+ return dt.replace(tzinfo=timezone.utc)
152
+ return dt
153
+ except ValueError:
154
+ return None
155
+
156
+
157
  @router.post("/upload", response_model=IngestUploadResponse)
158
  async def upload_endpoint(
159
  background_tasks: BackgroundTasks,
160
+ files: list[UploadFile] = File(..., description="One or more PDF, TXT, or MD files"),
161
  collection_name: Annotated[str, Form(min_length=1, max_length=256)] = "default",
162
  ) -> IngestUploadResponse:
163
  settings = get_settings()
164
+ if not files:
165
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="At least one file is required.")
166
+ if len(files) > settings.max_documents_per_batch:
167
+ raise HTTPException(
168
+ status_code=status.HTTP_400_BAD_REQUEST,
169
+ detail=f"Too many files in one request (max {settings.max_documents_per_batch}).",
170
+ )
171
 
172
+ max_bytes = settings.max_file_size_mb * 1024 * 1024
173
+ temp_paths: list[tuple[str, str]] = []
174
+ filenames: list[str] = []
175
  try:
176
+ for file in files:
177
+ suffix = _validate_file(file, max_bytes)
178
+ display_name = (file.filename or "upload").strip()
179
+ file_bytes = await file.read()
180
+ with NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
181
+ tmp.write(file_bytes)
182
+ temp_paths.append((tmp.name, display_name))
183
+ filenames.append(display_name)
184
+ await file.close()
185
 
186
  job_id = await create_ingest_job(
187
  settings.jobs_db_path,
188
+ collection_name=collection_name.strip(),
189
+ filenames=filenames,
190
  )
191
 
192
  background_tasks.add_task(
193
  run_ingest_job,
194
  job_id,
195
+ temp_paths,
196
+ collection_name.strip(),
197
  settings.jobs_db_path,
198
  settings.chroma_persist_directory,
199
  )
200
 
201
  return IngestUploadResponse(
 
 
202
  job_id=job_id,
203
+ status="queued",
204
+ total_files=len(filenames),
205
+ filenames=filenames,
206
+ message=f"Documents queued for processing. Poll /jobs/{job_id} for status.",
207
  )
208
  except HTTPException:
209
+ for path, _ in temp_paths:
210
+ Path(path).unlink(missing_ok=True)
211
  raise
212
  except Exception as exc:
213
+ for path, _ in temp_paths:
214
+ Path(path).unlink(missing_ok=True)
215
  raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
 
 
216
 
217
 
218
+ @router.post("/url", response_model=UrlIngestResponse)
219
  async def ingest_url_endpoint(
220
  background_tasks: BackgroundTasks,
221
+ payload: URLIngestRequest,
222
+ ) -> UrlIngestResponse:
223
  settings = get_settings()
224
  max_bytes = settings.max_file_size_mb * 1024 * 1024
225
+ url_strings = [str(u).strip() for u in payload.urls]
226
+ if len(url_strings) > settings.max_documents_per_batch:
227
+ raise HTTPException(
228
+ status_code=status.HTTP_400_BAD_REQUEST,
229
+ detail=f"Too many URLs in one request (max {settings.max_documents_per_batch}).",
230
+ )
231
+
232
+ downloaded: list[tuple[str, str]] = []
233
  try:
234
+ for url_str in url_strings:
235
+ temp_path, display_name = await _download_url_to_temp(url_str, max_bytes)
236
+ downloaded.append((temp_path, display_name))
237
 
238
+ coll = (payload.collection_name or "default").strip()
239
  job_id = await create_ingest_job(
240
  settings.jobs_db_path,
241
+ collection_name=coll,
242
+ filenames=[name for _, name in downloaded],
243
  )
244
 
245
  background_tasks.add_task(
246
  run_ingest_job,
247
  job_id,
248
+ downloaded,
249
+ coll,
250
  settings.jobs_db_path,
251
  settings.chroma_persist_directory,
252
  )
253
 
254
+ return UrlIngestResponse(
 
 
255
  job_id=job_id,
256
+ status="queued",
257
+ total_urls=len(downloaded),
258
+ message="URLs queued for download and processing.",
259
  )
260
  except HTTPException:
261
+ for path, _ in downloaded:
262
+ Path(path).unlink(missing_ok=True)
263
  raise
264
  except Exception as exc:
265
+ for path, _ in downloaded:
266
+ Path(path).unlink(missing_ok=True)
267
  raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
268
 
269
 
 
274
  names = list_collection_names(settings.chroma_persist_directory)
275
  except Exception as exc:
276
  raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
277
+ items: list[CollectionItem] = []
278
+ for n in names:
279
+ cnt = collection_document_count(settings.chroma_persist_directory, n)
280
+ raw_created = collection_created_at(settings.chroma_persist_directory, n)
281
+ items.append(
282
+ CollectionItem(
283
+ name=n,
284
+ document_count=cnt,
285
+ created_at=_parse_created_at(raw_created),
286
+ )
287
+ )
288
+ return IngestCollectionsResponse(collections=items, total=len(items))
289
 
290
 
291
  @router.delete("/collection/{collection_name}", response_model=IngestDeleteCollectionResponse)
 
298
  existing = list_collection_names(settings.chroma_persist_directory)
299
  if name not in existing:
300
  raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Collection not found.")
301
+ removed = delete_collection(settings.chroma_persist_directory, name)
302
  except HTTPException:
303
  raise
304
  except Exception as exc:
305
  raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
306
  return IngestDeleteCollectionResponse(
307
+ message=f"Collection '{name}' deleted successfully.",
308
+ documents_removed=removed,
 
309
  )
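Reviewer sketch (not part of the commit): the new upload loop calls `await file.read()` for each file, which holds a whole upload in memory before it is written to the temporary file. `_validate_file` is not shown in this diff, so how it enforces `max_bytes` is unknown. The snippet below shows a chunked copy with a running byte limit as one alternative. `copy_with_limit` and `UploadTooLarge` are hypothetical names, and the 64 KiB chunk size is an arbitrary choice.

```python
import io
from typing import BinaryIO


class UploadTooLarge(Exception):
    """Raised when a stream exceeds the configured byte limit."""


def copy_with_limit(
    src: BinaryIO, dst: BinaryIO, max_bytes: int, chunk_size: int = 64 * 1024
) -> int:
    """Copy src to dst in chunks; stop with an error once max_bytes is exceeded."""
    written = 0
    while True:
        chunk = src.read(chunk_size)
        if not chunk:
            return written
        written += len(chunk)
        if written > max_bytes:
            raise UploadTooLarge(f"stream exceeds {max_bytes} bytes")
        dst.write(chunk)


source = io.BytesIO(b"x" * 150_000)
target = io.BytesIO()
print(copy_with_limit(source, target, max_bytes=200_000))
```

In the route, the source could be the synchronous file object that `UploadFile` exposes as `.file`, and the destination the `NamedTemporaryFile`; the caller would still need to delete the partial temporary file when `UploadTooLarge` is raised.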
api/routes/jobs.py CHANGED
@@ -4,8 +4,8 @@ from fastapi import APIRouter, Depends, HTTPException, Query, status
4
 
5
  from api.config import get_settings
6
  from models.requests import JobsListParams
7
- from models.responses import IngestJobDetailResponse, JobListResponse, JobSummary
8
- from storage.job_store import get_ingest_job, list_ingest_jobs
9
 
10
 
11
  def _jobs_list_params(
@@ -23,27 +23,18 @@ async def list_jobs(
23
  params: Annotated[JobsListParams, Depends(_jobs_list_params)],
24
  ) -> JobListResponse:
25
  settings = get_settings()
26
- rows = await list_ingest_jobs(
27
  settings.jobs_db_path,
28
  limit=params.limit,
29
  offset=params.offset,
30
  )
31
- jobs = [JobSummary.model_validate(row) for row in rows]
32
- return JobListResponse(
33
- status="success",
34
- message=f"Returned {len(jobs)} job(s).",
35
- jobs=jobs,
36
- )
37
 
38
 
39
- @router.get("/jobs/{job_id}", response_model=IngestJobDetailResponse)
40
- async def get_job(job_id: str) -> IngestJobDetailResponse:
41
  settings = get_settings()
42
- job = await get_ingest_job(settings.jobs_db_path, job_id)
43
  if job is None:
44
  raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Job not found.")
45
- return IngestJobDetailResponse(
46
- status="success",
47
- message="Job found.",
48
- job=job,
49
- )
 
4
 
5
  from api.config import get_settings
6
  from models.requests import JobsListParams
7
+ from models.responses import JobListResponse, JobStatusResponse
8
+ from storage.job_store import get_job_status, list_ingest_jobs
9
 
10
 
11
  def _jobs_list_params(
 
23
  params: Annotated[JobsListParams, Depends(_jobs_list_params)],
24
  ) -> JobListResponse:
25
  settings = get_settings()
26
+ jobs, total = await list_ingest_jobs(
27
  settings.jobs_db_path,
28
  limit=params.limit,
29
  offset=params.offset,
30
  )
31
+ return JobListResponse(jobs=jobs, total=total)
32
 
33
 
34
+ @router.get("/jobs/{job_id}", response_model=JobStatusResponse)
35
+ async def get_job(job_id: str) -> JobStatusResponse:
36
  settings = get_settings()
37
+ job = await get_job_status(settings.jobs_db_path, job_id)
38
  if job is None:
39
  raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Job not found.")
40
+ return job
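Review note: `list_ingest_jobs` now returns a `(jobs, total)` pair, so the route can report the overall count alongside one page of results. The sketch below shows that limit/offset pattern in isolation. It uses the synchronous stdlib `sqlite3` module instead of the project's async job store, and the `jobs` table schema and the `list_jobs_page` helper are assumptions made for illustration, not the actual `storage.job_store` implementation.

```python
import sqlite3


def list_jobs_page(conn: sqlite3.Connection, *, limit: int, offset: int):
    """Return one page of jobs plus the total row count (hypothetical helper)."""
    # Total is computed without LIMIT/OFFSET so clients can work out how many pages exist.
    total = conn.execute("SELECT COUNT(*) FROM jobs").fetchone()[0]
    rows = conn.execute(
        "SELECT job_id, status FROM jobs ORDER BY rowid DESC LIMIT ? OFFSET ?",
        (limit, offset),
    ).fetchall()
    return [{"job_id": r[0], "status": r[1]} for r in rows], total


# Assumed schema, for illustration only.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE jobs (job_id TEXT PRIMARY KEY, status TEXT NOT NULL)")
conn.executemany(
    "INSERT INTO jobs VALUES (?, ?)",
    [(f"job-{i}", "completed") for i in range(5)],
)

page, total = list_jobs_page(conn, limit=2, offset=1)
print(total, [job["job_id"] for job in page])
```

Returning the total separately from the page keeps `JobListResponse.total` meaningful even when `limit` truncates the list.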
api/routes/query.py CHANGED
@@ -1,8 +1,12 @@
1
  from fastapi import APIRouter, HTTPException, status
2
 
3
- from api.config import get_settings
4
  from models.requests import QueryRequest, SummariseRequest
5
- from models.responses import QueryResponse, QueryResultItem, QuerySourceItem
6
  from rag.embedder import create_embedding_function
7
  from rag.retriever import (
8
  SUMMARY_RETRIEVAL_QUERY,
@@ -11,117 +15,153 @@ from rag.retriever import (
11
  retrieve_chunks,
12
  summarise_with_grounding,
13
  )
14
- from rag.vector_store import get_vector_store
15
  from storage.audit_store import persist_query_audit
16
 
17
  router = APIRouter(prefix="/query", tags=["query"])
18
 
19
 
20
- def _response_from_chunks(
21
- *,
22
- collection_name: str,
23
- chunks: list[RetrievedChunk],
24
- answer: str,
25
- message: str,
26
- ) -> QueryResponse:
27
- results = [QueryResultItem(text=chunk.text, score=chunk.score) for chunk in chunks]
28
- sources = [
29
- QuerySourceItem(
30
- source=chunk.source,
31
- page=chunk.page,
32
- chunk_index=chunk.chunk_index,
33
- score=chunk.score,
34
- excerpt=chunk.text[:280],
35
- )
36
- for chunk in chunks
37
- ]
38
- return QueryResponse(
39
- status="success",
40
- message=message,
41
- answer=answer,
42
- sources=sources,
43
- results=results,
44
- )
45
 
46
 
47
- @router.post("/ask", response_model=QueryResponse)
48
- async def ask_endpoint(payload: QueryRequest) -> QueryResponse:
49
- settings = get_settings()
50
- try:
51
- embedding_function = create_embedding_function()
52
- vector_store = get_vector_store(
53
- persist_directory=settings.chroma_persist_directory,
54
- collection_name=payload.collection_name,
55
- embedding_function=embedding_function,
56
  )
57
- chunks = retrieve_chunks(vector_store, payload.question, settings.top_k_results)
58
- answer = answer_with_grounding(settings, payload.question, chunks)
59
- except Exception as exc:
60
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
61
 
62
- response = _response_from_chunks(
63
- collection_name=payload.collection_name,
64
- chunks=chunks,
65
  answer=answer,
66
- message=(
67
- f"Retrieved {len(chunks)} chunks from '{payload.collection_name}' and generated a grounded answer."
68
- ),
69
  )
70
- try:
71
- await persist_query_audit(
72
- settings.audit_db_path,
73
- action="query",
74
- question=payload.question,
75
- collection_name=payload.collection_name,
76
- response=response,
77
- )
78
- except Exception as exc:
79
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
80
  return response
81
 
82
 
83
- @router.post("/summarise", response_model=QueryResponse)
84
- async def summarise_endpoint(payload: SummariseRequest) -> QueryResponse:
85
- settings = get_settings()
86
  retrieval_query = (payload.focus or "").strip() or SUMMARY_RETRIEVAL_QUERY
87
  audit_question = payload.focus.strip() if payload.focus and payload.focus.strip() else "Summarise collection"
88
  try:
89
- embedding_function = create_embedding_function()
90
- vector_store = get_vector_store(
91
- persist_directory=settings.chroma_persist_directory,
92
- collection_name=payload.collection_name,
93
- embedding_function=embedding_function,
94
- )
95
- chunks = retrieve_chunks(vector_store, retrieval_query, settings.top_k_results)
96
- answer = summarise_with_grounding(settings, focus=payload.focus, chunks=chunks)
97
  except Exception as exc:
98
  raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
99
 
100
- response = _response_from_chunks(
101
- collection_name=payload.collection_name,
102
- chunks=chunks,
103
- answer=answer,
104
- message=(
105
- f"Retrieved {len(chunks)} chunks from '{payload.collection_name}' and generated a grounded summary."
106
- ),
107
- )
108
  try:
109
- await persist_query_audit(
110
- settings.audit_db_path,
111
- action="summarise",
112
- question=audit_question,
113
- collection_name=payload.collection_name,
114
- response=response,
115
- )
116
  except Exception as exc:
117
  raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
118
- return response
119
 
120
 
121
  legacy_query_router = APIRouter(tags=["query"])
122
 
123
 
124
- @legacy_query_router.post("/query", response_model=QueryResponse)
125
- async def query_post_compat(payload: QueryRequest) -> QueryResponse:
126
  """Same behavior as POST /query/ask; kept for older clients and docs that used POST /query."""
127
  return await ask_endpoint(payload)
 
1
+ import time
2
+ from datetime import datetime, timezone
3
+ from uuid import uuid4
4
+
5
  from fastapi import APIRouter, HTTPException, status
6
 
7
+ from api.config import Settings, get_settings
8
  from models.requests import QueryRequest, SummariseRequest
9
+ from models.responses import AskQueryResponse, SourceCitation, SummariseQueryResponse
10
  from rag.embedder import create_embedding_function
11
  from rag.retriever import (
12
  SUMMARY_RETRIEVAL_QUERY,
 
15
  retrieve_chunks,
16
  summarise_with_grounding,
17
  )
18
+ from rag.vector_store import collection_document_count, get_vector_store
19
  from storage.audit_store import persist_query_audit
20
 
21
  router = APIRouter(prefix="/query", tags=["query"])
22
 
23
 
24
+ def _model_used_label(settings: Settings) -> str:
25
+ provider = settings.llm_provider.lower()
26
+ if provider == "openai":
27
+ return settings.openai_model
28
+ if provider == "ollama":
29
+ return settings.ollama_chat_model
30
+ if provider == "anthropic":
31
+ return settings.anthropic_model
32
+ if provider == "huggingface":
33
+ return settings.huggingface_model
34
+ return f"{provider}:unknown"
35
 
36
 
37
+ def _chunks_to_citations(chunks: list[RetrievedChunk]) -> list[SourceCitation]:
38
+ citations: list[SourceCitation] = []
39
+ for chunk in chunks:
40
+ page = chunk.page if chunk.page is not None else 0
41
+ score = float(chunk.score) if chunk.score is not None else 0.0
42
+ citations.append(
43
+ SourceCitation(
44
+ document_name=chunk.source or "unknown",
45
+ page_number=page,
46
+ chunk_text=chunk.text,
47
+ relevance_score=score,
48
+ )
49
  )
50
+ return citations
51
 
52
+
53
+ async def _run_ask(
54
+ settings: Settings,
55
+ payload: QueryRequest,
56
+ ) -> AskQueryResponse:
57
+ top_k = payload.top_k
58
+ t0 = time.perf_counter()
59
+ embedding_function = create_embedding_function()
60
+ vector_store = get_vector_store(
61
+ persist_directory=settings.chroma_persist_directory,
62
+ collection_name=payload.collection_name or "default",
63
+ embedding_function=embedding_function,
64
+ )
65
+ chunks = retrieve_chunks(vector_store, payload.question, top_k)
66
+ answer, tokens_used = answer_with_grounding(settings, payload.question, chunks)
67
+ elapsed_ms = int((time.perf_counter() - t0) * 1000)
68
+ citations = _chunks_to_citations(chunks)
69
+ query_id = str(uuid4())
70
+ ts = datetime.now(timezone.utc)
71
+ response = AskQueryResponse(
72
+ query_id=query_id,
73
+ question=payload.question,
74
  answer=answer,
75
+ sources=citations,
76
+ model_used=_model_used_label(settings),
77
+ tokens_used=tokens_used,
78
+ response_time_ms=elapsed_ms,
79
+ timestamp=ts,
80
+ )
81
+ await persist_query_audit(
82
+ settings.audit_db_path,
83
+ query_id=query_id,
84
+ action="query",
85
+ user_id=payload.user_id,
86
+ question=payload.question,
87
+ collection_name=payload.collection_name or "default",
88
+ answer=answer,
89
+ sources=citations,
90
+ model_used=response.model_used,
91
+ tokens_used=tokens_used,
92
+ response_time_ms=elapsed_ms,
93
+ kind="ask",
94
  )
95
  return response
96
 
97
 
98
+ async def _run_summarise(
99
+ settings: Settings,
100
+ payload: SummariseRequest,
101
+ ) -> SummariseQueryResponse:
102
+ top_k = settings.top_k_results
103
  retrieval_query = (payload.focus or "").strip() or SUMMARY_RETRIEVAL_QUERY
104
  audit_question = payload.focus.strip() if payload.focus and payload.focus.strip() else "Summarise collection"
105
+ t0 = time.perf_counter()
106
+ embedding_function = create_embedding_function()
107
+ vector_store = get_vector_store(
108
+ persist_directory=settings.chroma_persist_directory,
109
+ collection_name=payload.collection_name,
110
+ embedding_function=embedding_function,
111
+ )
112
+ chunks = retrieve_chunks(vector_store, retrieval_query, top_k)
113
+ summary, tokens_used = summarise_with_grounding(settings, focus=payload.focus, chunks=chunks)
114
+ elapsed_ms = int((time.perf_counter() - t0) * 1000)
115
+ citations = _chunks_to_citations(chunks)
116
+ doc_count = collection_document_count(settings.chroma_persist_directory, payload.collection_name)
117
+ query_id = str(uuid4())
118
+ ts = datetime.now(timezone.utc)
119
+ response = SummariseQueryResponse(
120
+ query_id=query_id,
121
+ summary=summary,
122
+ document_count=doc_count,
123
+ sources=citations,
124
+ timestamp=ts,
125
+ )
126
+ await persist_query_audit(
127
+ settings.audit_db_path,
128
+ query_id=query_id,
129
+ action="summarise",
130
+ user_id=payload.user_id,
131
+ question=audit_question,
132
+ collection_name=payload.collection_name,
133
+ answer=summary,
134
+ sources=citations,
135
+ model_used=_model_used_label(settings),
136
+ tokens_used=tokens_used,
137
+ response_time_ms=elapsed_ms,
138
+ kind="summarise",
139
+ )
140
+ return response
141
+
142
+
143
+ @router.post("/ask", response_model=AskQueryResponse)
144
+ async def ask_endpoint(payload: QueryRequest) -> AskQueryResponse:
145
+ settings = get_settings()
146
  try:
147
+ return await _run_ask(settings, payload)
148
  except Exception as exc:
149
  raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
150
 
151
+
152
+ @router.post("/summarise", response_model=SummariseQueryResponse)
153
+ async def summarise_endpoint(payload: SummariseRequest) -> SummariseQueryResponse:
154
+ settings = get_settings()
155
  try:
156
+ return await _run_summarise(settings, payload)
157
  except Exception as exc:
158
  raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
 
159
 
160
 
161
  legacy_query_router = APIRouter(tags=["query"])
162
 
163
 
164
+ @legacy_query_router.post("/query", response_model=AskQueryResponse)
165
+ async def query_post_compat(payload: QueryRequest) -> AskQueryResponse:
166
  """Same behavior as POST /query/ask; kept for older clients and docs that used POST /query."""
167
  return await ask_endpoint(payload)
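Review note: `_chunks_to_citations` fills in defaults when a retrieved chunk has no page, score, or source. The sketch below reproduces that mapping with plain dataclasses so the fallback behavior can be checked without the RAG stack. `Chunk` and `Citation` are hypothetical stand-ins for `RetrievedChunk` and `SourceCitation`, carrying only the fields the diff uses.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Chunk:  # hypothetical stand-in for rag.retriever.RetrievedChunk
    text: str
    source: str | None = None
    page: int | None = None
    score: float | None = None


@dataclass
class Citation:  # hypothetical stand-in for models.responses.SourceCitation
    document_name: str
    page_number: int
    chunk_text: str
    relevance_score: float


def chunks_to_citations(chunks: list[Chunk]) -> list[Citation]:
    """Same defaults as the diff: page -> 0, score -> 0.0, source -> 'unknown'."""
    citations: list[Citation] = []
    for chunk in chunks:
        citations.append(
            Citation(
                document_name=chunk.source or "unknown",
                page_number=chunk.page if chunk.page is not None else 0,
                chunk_text=chunk.text,
                relevance_score=float(chunk.score) if chunk.score is not None else 0.0,
            )
        )
    return citations


citations = chunks_to_citations(
    [
        Chunk(text="Payment is due in 30 days.", source="contract.pdf", page=4, score=0.82),
        Chunk(text="orphan text"),  # no metadata at all
    ]
)
print(citations)
```

One consequence worth noting: because `page_number` and `relevance_score` are non-optional in the response model, a missing page is reported as `0`, which a client cannot distinguish from a chunk whose metadata genuinely says page 0.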
models/requests.py CHANGED
@@ -1,16 +1,20 @@
 
 
1
  from pydantic import BaseModel, ConfigDict, Field, HttpUrl
2
 
3
 
4
  class QueryRequest(BaseModel):
5
  model_config = ConfigDict(extra="forbid")
6
 
7
- question: str = Field(min_length=1, max_length=8000, description="The question to ask the document")
8
- collection_name: str = Field(
9
  default="default",
10
  min_length=1,
11
  max_length=256,
12
- description="The name of the collection to ask the question from",
13
  )
 
 
14
 
15
 
16
  class SummariseRequest(BaseModel):
@@ -20,37 +24,50 @@ class SummariseRequest(BaseModel):
20
  default="default",
21
  min_length=1,
22
  max_length=256,
23
- description="Chroma collection to summarise from",
24
  )
25
  focus: str | None = Field(
26
  default=None,
27
  max_length=8000,
28
- description="Optional angle or scope for retrieval and the summary (e.g. 'contract payment terms')",
29
  )
 
30
 
31
 
32
- class IngestUrlRequest(BaseModel):
33
  model_config = ConfigDict(extra="forbid")
34
 
35
- url: HttpUrl = Field(description="HTTP(S) URL to a PDF, TXT, or Markdown document")
36
- collection_name: str = Field(
 
 
 
 
37
  default="default",
38
  min_length=1,
39
  max_length=256,
40
  description="Target Chroma collection name",
41
  )
42
 
43
- class IngestUploadRequest(BaseModel):
44
- model_config = ConfigDict(extra="forbid")
45
- collection_name: str = Field(default="default", min_length=1, max_length=256, description="The name of the collection to upload the document to")
46
- filename: str = Field(min_length=1, max_length=1024, description="The name of the file to upload")
47
 
48
  class JobsListParams(BaseModel):
49
  model_config = ConfigDict(extra="forbid")
50
- limit: int = Field(default=10, ge=1, le=100, description="The limit of the jobs to list")
51
- offset: int = Field(default=0, ge=0, description="The offset of the jobs to list")
 
 
52
 
53
  class AuditListParams(BaseModel):
54
  model_config = ConfigDict(extra="forbid")
55
- limit: int = Field(default=10, ge=1, le=100, description="The limit of the audit to list")
56
- offset: int = Field(default=0, ge=0, description="The offset of the audit to list")
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
  from pydantic import BaseModel, ConfigDict, Field, HttpUrl
4
 
5
 
6
  class QueryRequest(BaseModel):
7
  model_config = ConfigDict(extra="forbid")
8
 
9
+ question: str = Field(min_length=5, max_length=2000, description="Natural language question")
10
+ collection_name: Optional[str] = Field(
11
  default="default",
12
  min_length=1,
13
  max_length=256,
14
+ description="Chroma collection to search",
15
  )
16
+ top_k: int = Field(default=5, ge=1, le=20, description="Number of chunks to retrieve")
17
+ user_id: str = Field(default="anonymous", max_length=256, description="Caller id for audit filtering")
18
 
19
 
20
  class SummariseRequest(BaseModel):
 
24
  default="default",
25
  min_length=1,
26
  max_length=256,
27
+ description="Chroma collection to summarise",
28
  )
29
  focus: str | None = Field(
30
  default=None,
31
  max_length=8000,
32
+ description="Optional angle or scope for retrieval and the summary",
33
  )
34
+ user_id: str = Field(default="anonymous", max_length=256, description="Caller id for audit filtering")
35
 
36
 
37
+ class URLIngestRequest(BaseModel):
38
  model_config = ConfigDict(extra="forbid")
39
 
40
+ urls: list[HttpUrl] = Field(
41
+ min_length=1,
42
+ max_length=100,
43
+ description="One or more HTTP(S) URLs to PDF, TXT, or Markdown documents",
44
+ )
45
+ collection_name: Optional[str] = Field(
46
  default="default",
47
  min_length=1,
48
  max_length=256,
49
  description="Target Chroma collection name",
50
  )
51
 
 
 
 
 
52
 
53
  class JobsListParams(BaseModel):
54
  model_config = ConfigDict(extra="forbid")
55
+
56
+ limit: int = Field(default=10, ge=1, le=100, description="Max jobs to return")
57
+ offset: int = Field(default=0, ge=0, description="Offset for pagination")
58
+
59
 
60
  class AuditListParams(BaseModel):
61
  model_config = ConfigDict(extra="forbid")
62
+
63
+ limit: int = Field(default=50, ge=1, le=100, description="Max log entries to return")
64
+ offset: int = Field(default=0, ge=0, description="Offset for pagination")
65
+ user_id: str | None = Field(default=None, max_length=256, description="Filter by user id")
66
+ from_date: str | None = Field(
67
+ default=None,
68
+ description="ISO 8601 datetime lower bound (inclusive) on timestamp",
69
+ )
70
+ to_date: str | None = Field(
71
+ default=None,
72
+ description="ISO 8601 datetime upper bound (inclusive) on timestamp",
73
+ )
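Review note: `from_date` and `to_date` are declared as plain strings, so they are not checked as ISO 8601 at the model level and need parsing before they are used as bounds. A minimal sketch of one way to do that, following the same "Z" handling that `_parse_ts` uses in `storage/audit_store.py`. The helper name `parse_iso_bound` and the choice to treat naive values as UTC are assumptions, not part of the commit.

```python
from __future__ import annotations

from datetime import datetime, timezone


def parse_iso_bound(value: str | None) -> datetime | None:
    """Parse an optional ISO 8601 bound into an aware UTC datetime (hypothetical helper)."""
    if value is None or not value.strip():
        return None
    s = value.strip()
    # datetime.fromisoformat() only accepts a trailing "Z" from Python 3.11 onward.
    if s.endswith("Z"):
        s = s[:-1] + "+00:00"
    dt = datetime.fromisoformat(s)  # raises ValueError on malformed input
    if dt.tzinfo is None:
        # Assumption: a value without an offset is interpreted as UTC.
        dt = dt.replace(tzinfo=timezone.utc)
    return dt.astimezone(timezone.utc)


lower = parse_iso_bound("2024-05-01T12:00:00Z")
same_instant = parse_iso_bound("2024-05-01T14:00:00+02:00")
date_only = parse_iso_bound("2024-05-01")
print(lower, same_instant, date_only)
```

Letting the `ValueError` propagate gives the route a single place to translate bad input into a 422-style error, rather than silently ignoring the filter.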
models/responses.py CHANGED
@@ -1,102 +1,130 @@
 
 
1
  from pydantic import BaseModel, Field
2
 
3
 
4
- class QueryResultItem(BaseModel):
5
- text: str | None = None
6
- score: float | None = None
7
 
8
- class QuerySourceItem(BaseModel):
9
- source: str
10
- page: int | None = None
11
- chunk_index: int | None = None
12
- score: float | None = None
13
- excerpt: str | None = None
14
 
15
- class QueryResponse(BaseModel):
16
- status: str
17
- message: str
18
- answer: str | None = None
19
- sources: list[QuerySourceItem] = Field(default_factory=list)
20
- results: list[QueryResultItem] = Field(default_factory=list)
21
 
22
  class IngestUploadResponse(BaseModel):
 
23
  status: str
 
 
24
  message: str
 
 
 
25
  job_id: str
26
- document_ids: list[str] = Field(default_factory=list)
 
 
27
 
28
 
29
  class CollectionItem(BaseModel):
30
  name: str
 
 
31
 
32
 
33
  class IngestCollectionsResponse(BaseModel):
34
- status: str
35
- message: str
36
  collections: list[CollectionItem] = Field(default_factory=list)
 
37
 
38
 
39
  class IngestDeleteCollectionResponse(BaseModel):
40
- status: str
41
  message: str
42
- collection_name: str
43
 
44
- class JobSummary(BaseModel):
45
- job_id: str
46
- status: str
47
- collection_name: str | None = None
48
- filename: str | None = None
49
- created_at: str | None = None
50
 
51
- class JobListResponse(BaseModel):
52
- status: str
53
- message: str
54
- jobs: list[JobSummary] = Field(default_factory=list)
55
 
56
 
57
- class IngestJobDetail(BaseModel):
58
  job_id: str
59
  status: str
60
- collection_name: str
61
- filename: str
62
- message: str
63
- document_ids: list[str] = Field(default_factory=list)
64
- created_at: str
65
- updated_at: str
 
66
 
67
 
68
- class IngestJobDetailResponse(BaseModel):
 
69
  status: str
70
- message: str
71
- job: IngestJobDetail | None = None
72
 
73
- class AuditEvent(BaseModel):
74
- event_id: str
75
- action: str
76
- question: str | None = None
77
- collection_name: str | None = None
78
- created_at: str | None = None
79
 
80
- class AuditListResponse(BaseModel):
81
- status: str
82
- message: str
83
- events: list[AuditEvent] = Field(default_factory=list)
 
 
84
 
85
 
86
- class AuditDetail(BaseModel):
87
- event_id: str
88
- action: str
89
  question: str
90
- collection_name: str
91
- answer: str | None = None
92
- status: str
93
- message: str
94
- sources: list[QuerySourceItem] = Field(default_factory=list)
95
- results: list[QueryResultItem] = Field(default_factory=list)
96
- created_at: str
97
 
98
 
99
- class AuditDetailResponse(BaseModel):
100
- status: str
101
- message: str
102
- event: AuditDetail | None = None
1
+ from datetime import datetime
2
+
3
  from pydantic import BaseModel, Field
4
 
5
 
6
+ # --- Shared citations (spec-shaped) ---
 
 
7
 
 
 
 
 
 
 
8
 
9
+ class SourceCitation(BaseModel):
10
+ document_name: str
11
+ page_number: int
12
+ chunk_text: str
13
+ relevance_score: float
14
+
15
+
16
+ # --- Query: ask ---
17
+
18
+
19
+ class AskQueryResponse(BaseModel):
20
+ query_id: str
21
+ question: str
22
+ answer: str
23
+ sources: list[SourceCitation] = Field(default_factory=list)
24
+ model_used: str
25
+ tokens_used: int
26
+ response_time_ms: int
27
+ timestamp: datetime
28
+
29
+
30
+ # --- Query: summarise ---
31
+
32
+
33
+ class SummariseQueryResponse(BaseModel):
34
+ query_id: str
35
+ summary: str
36
+ document_count: int
37
+ sources: list[SourceCitation] = Field(default_factory=list)
38
+ timestamp: datetime
39
+
40
+
41
+ # --- Ingest ---
42
+
43
 
44
  class IngestUploadResponse(BaseModel):
45
+ job_id: str
46
  status: str
47
+ total_files: int
48
+ filenames: list[str]
49
  message: str
50
+
51
+
52
+ class UrlIngestResponse(BaseModel):
53
  job_id: str
54
+ status: str
55
+ total_urls: int
56
+ message: str
57
 
58
 
59
  class CollectionItem(BaseModel):
60
  name: str
61
+ document_count: int
62
+ created_at: datetime | None = None
63
 
64
 
65
  class IngestCollectionsResponse(BaseModel):
 
 
66
  collections: list[CollectionItem] = Field(default_factory=list)
67
+ total: int
68
 
69
 
70
  class IngestDeleteCollectionResponse(BaseModel):
 
71
  message: str
72
+ documents_removed: int
73
 
 
 
 
 
 
 
74
 
75
+ # --- Jobs ---
 
 
 
76
 
77
 
78
+ class JobStatusResponse(BaseModel):
79
  job_id: str
80
  status: str
81
+ total_files: int
82
+ processed_files: int
83
+ failed_files: int
84
+ progress_percent: int
85
+ started_at: datetime | None
86
+ completed_at: datetime | None
87
+ errors: list[str] = Field(default_factory=list)
88
 
89
 
90
+ class JobListItem(BaseModel):
91
+ job_id: str
92
  status: str
93
+ total_files: int
94
+ completed_at: datetime | None = None
95
 
 
 
 
 
 
 
96
 
97
+ class JobListResponse(BaseModel):
98
+ jobs: list[JobListItem] = Field(default_factory=list)
99
+ total: int
100
+
101
+
102
+ # --- Audit ---
103
 
104
 
105
+ class AuditLogEntry(BaseModel):
106
+ query_id: str
107
+ user_id: str
108
  question: str
109
+ answer_summary: str
110
+ sources_count: int
111
+ model_used: str | None
112
+ timestamp: datetime
 
 
 
113
 
114
 
115
+ class AuditLogsResponse(BaseModel):
116
+ logs: list[AuditLogEntry] = Field(default_factory=list)
117
+ total: int
118
+ limit: int
119
+ offset: int
120
+
121
+
122
+ class AuditLogDetailResponse(BaseModel):
123
+ query_id: str
124
+ user_id: str
125
+ question: str
126
+ full_answer: str
127
+ sources: list[SourceCitation] = Field(default_factory=list)
128
+ model_used: str | None
129
+ tokens_used: int | None
130
+ timestamp: datetime
pytest.ini ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [pytest]
2
+ testpaths = tests
3
+ python_files = test_*.py
rag/retriever.py CHANGED
@@ -27,9 +27,27 @@ except ImportError:
27
 
28
  from api.config import Settings
29
 
30
- NO_MATCH_ANSWER = "I could not find this information in the uploaded documents."
31
  MIN_RELEVANCE_SCORE = 0.15
32
 
33
 
34
  @dataclass
35
  class RetrievedChunk:
@@ -62,31 +80,19 @@ SUMMARY_RETRIEVAL_QUERY = (
62
  )
63
 
64
 
65
- def answer_with_grounding(settings: Settings, question: str, chunks: list[RetrievedChunk]) -> str:
66
  ranked_chunks = [chunk for chunk in chunks if chunk.score is None or chunk.score >= MIN_RELEVANCE_SCORE]
67
  if not ranked_chunks:
68
- return NO_MATCH_ANSWER
69
 
70
  llm = _create_chat_model(settings)
71
  prompt_context = _format_context(ranked_chunks)
72
- messages = [
73
- SystemMessage(
74
- content=(
75
- "You answer questions using only the provided context from uploaded documents. "
76
- "If the answer is not in context, say you do not know."
77
- )
78
- ),
79
- HumanMessage(
80
- content=(
81
- f"Question: {question}\n\n"
82
- f"Context:\n{prompt_context}\n\n"
83
- "Return a concise grounded answer."
84
- )
85
- ),
86
- ]
87
  response = llm.invoke(messages)
88
  answer = _extract_message_text(response).strip()
89
- return answer or NO_MATCH_ANSWER
 
90
 
91
 
92
  def summarise_with_grounding(
@@ -94,10 +100,10 @@ def summarise_with_grounding(
94
  *,
95
  focus: str | None,
96
  chunks: list[RetrievedChunk],
97
- ) -> str:
98
  ranked_chunks = [chunk for chunk in chunks if chunk.score is None or chunk.score >= MIN_RELEVANCE_SCORE]
99
  if not ranked_chunks:
100
- return NO_MATCH_ANSWER
101
 
102
  llm = _create_chat_model(settings)
103
  prompt_context = _format_context(ranked_chunks)
@@ -123,7 +129,8 @@ def summarise_with_grounding(
123
  ]
124
  response = llm.invoke(messages)
125
  answer = _extract_message_text(response).strip()
126
- return answer or NO_MATCH_ANSWER
 
127
 
128
 
129
  def _create_chat_model(settings: Settings) -> BaseChatModel:
@@ -186,6 +193,25 @@ def _to_int_or_none(value: object) -> int | None:
186
  return None
187
 
188
 
189
  def _extract_message_text(response: object) -> str:
190
  content = getattr(response, "content", "")
191
  if isinstance(content, str):
 
27
 
28
  from api.config import Settings
29
 
30
+ NO_MATCH_ANSWER = "I cannot find this information in the uploaded documents."
31
  MIN_RELEVANCE_SCORE = 0.15
32
 
33
+ # Verbatim from DOCUAUDIT_AI_REQUIREMENTS.md (placeholders filled at runtime).
34
+ DOCUAUDIT_ASK_TEMPLATE = """You are DocuAudit AI, an expert document analyst for consulting environments.
35
+
36
+ RULES:
37
+ 1. Answer ONLY based on the provided document excerpts below.
38
+ 2. If the answer is not in the documents, say: "I cannot find this information in the uploaded documents."
39
+ 3. ALWAYS cite your sources: mention the document name and page number for every claim.
40
+ 4. Be precise and professional. This is a high-stakes consulting environment.
41
+ 5. Do not speculate or add information not present in the documents.
42
+
43
+ DOCUMENT EXCERPTS:
44
+ {context}
45
+
46
+ QUESTION: {question}
47
+
48
+ ANSWER (with source citations):
49
+ """
50
+
51
 
52
  @dataclass
53
  class RetrievedChunk:
 
80
  )
81
 
82
 
83
+ def answer_with_grounding(settings: Settings, question: str, chunks: list[RetrievedChunk]) -> tuple[str, int]:
84
  ranked_chunks = [chunk for chunk in chunks if chunk.score is None or chunk.score >= MIN_RELEVANCE_SCORE]
85
  if not ranked_chunks:
86
+ return NO_MATCH_ANSWER, 0
87
 
88
  llm = _create_chat_model(settings)
89
  prompt_context = _format_context(ranked_chunks)
90
+ user_content = DOCUAUDIT_ASK_TEMPLATE.format(context=prompt_context, question=question)
91
+ messages = [HumanMessage(content=user_content)]
92
  response = llm.invoke(messages)
93
  answer = _extract_message_text(response).strip()
94
+ tokens = _extract_usage_tokens(response)
95
+ return (answer or NO_MATCH_ANSWER), tokens
96
 
97
 
98
  def summarise_with_grounding(
 
100
  *,
101
  focus: str | None,
102
  chunks: list[RetrievedChunk],
103
+ ) -> tuple[str, int]:
104
  ranked_chunks = [chunk for chunk in chunks if chunk.score is None or chunk.score >= MIN_RELEVANCE_SCORE]
105
  if not ranked_chunks:
106
+ return NO_MATCH_ANSWER, 0
107
 
108
  llm = _create_chat_model(settings)
109
  prompt_context = _format_context(ranked_chunks)
 
129
  ]
130
  response = llm.invoke(messages)
131
  answer = _extract_message_text(response).strip()
132
+ tokens = _extract_usage_tokens(response)
133
+ return (answer or NO_MATCH_ANSWER), tokens
134
 
135
 
136
  def _create_chat_model(settings: Settings) -> BaseChatModel:
 
193
  return None
194
 
195
 
196
+ def _extract_usage_tokens(response: object) -> int:
197
+ um = getattr(response, "usage_metadata", None)
198
+ if isinstance(um, dict):
199
+ total = um.get("total_tokens")
200
+ if total is not None:
201
+ return int(total)
202
+ inp = int(um.get("input_tokens", 0) or 0)
203
+ out = int(um.get("output_tokens", 0) or 0)
204
+ return inp + out
205
+ rm = getattr(response, "response_metadata", None) or {}
206
+ if isinstance(rm, dict):
207
+ tu = rm.get("token_usage")
208
+ if isinstance(tu, dict):
209
+ if tu.get("total_tokens") is not None:
210
+ return int(tu["total_tokens"])
211
+ return int(tu.get("prompt_tokens", 0) or 0) + int(tu.get("completion_tokens", 0) or 0)
212
+ return 0
213
+
214
+
215
  def _extract_message_text(response: object) -> str:
216
  content = getattr(response, "content", "")
217
  if isinstance(content, str):
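Review note: `_extract_usage_tokens` checks two metadata shapes and falls back to `0`, which is easy to verify with stand-in objects. The sketch below copies the function body from the diff and feeds it `types.SimpleNamespace` instances in place of real chat-model responses. The stand-ins and their token numbers are invented for illustration and do not come from any provider.

```python
from types import SimpleNamespace


def extract_usage_tokens(response: object) -> int:
    """Copy of _extract_usage_tokens from the diff, renamed for standalone use."""
    um = getattr(response, "usage_metadata", None)
    if isinstance(um, dict):
        total = um.get("total_tokens")
        if total is not None:
            return int(total)
        inp = int(um.get("input_tokens", 0) or 0)
        out = int(um.get("output_tokens", 0) or 0)
        return inp + out
    rm = getattr(response, "response_metadata", None) or {}
    if isinstance(rm, dict):
        tu = rm.get("token_usage")
        if isinstance(tu, dict):
            if tu.get("total_tokens") is not None:
                return int(tu["total_tokens"])
            return int(tu.get("prompt_tokens", 0) or 0) + int(tu.get("completion_tokens", 0) or 0)
    return 0


# Hypothetical stand-ins for chat-model responses.
with_total = SimpleNamespace(usage_metadata={"total_tokens": 120})
split_counts = SimpleNamespace(usage_metadata={"input_tokens": 90, "output_tokens": 15})
legacy_shape = SimpleNamespace(
    response_metadata={"token_usage": {"prompt_tokens": 40, "completion_tokens": 2}}
)
no_usage = SimpleNamespace(content="hello")

results = [
    extract_usage_tokens(r)
    for r in (with_total, split_counts, legacy_shape, no_usage)
]
print(results)
```

Because a response with no usage data yields `0`, a stored `tokens_used` of `0` may mean "not reported" rather than "no tokens consumed".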
rag/vector_store.py CHANGED
@@ -36,8 +36,42 @@ def list_collection_names(persist_directory: str) -> list[str]:
36
  return sorted(c.name for c in client.list_collections())
37
 
38
 
39
- def delete_collection(persist_directory: str, collection_name: str) -> None:
 
40
  Path(persist_directory).mkdir(parents=True, exist_ok=True)
41
  client = chromadb.PersistentClient(path=persist_directory, settings=_CHROMA_CLIENT_SETTINGS)
 
 
 
 
 
 
42
  client.delete_collection(name=collection_name)
43
 
 
36
  return sorted(c.name for c in client.list_collections())
37
 
38
 
39
+ def delete_collection(persist_directory: str, collection_name: str) -> int:
40
+ """Delete a collection and return the number of documents that were removed (best effort)."""
41
  Path(persist_directory).mkdir(parents=True, exist_ok=True)
42
  client = chromadb.PersistentClient(path=persist_directory, settings=_CHROMA_CLIENT_SETTINGS)
43
+ removed = 0
44
+ try:
45
+ col = client.get_collection(name=collection_name)
46
+ removed = int(col.count())
47
+ except Exception:
48
+ removed = 0
49
  client.delete_collection(name=collection_name)
50
+ return removed
51
+
52
+
53
+ def collection_document_count(persist_directory: str, collection_name: str) -> int:
54
+ Path(persist_directory).mkdir(parents=True, exist_ok=True)
55
+ client = chromadb.PersistentClient(path=persist_directory, settings=_CHROMA_CLIENT_SETTINGS)
56
+ try:
57
+ col = client.get_collection(name=collection_name)
58
+ return int(col.count())
59
+ except Exception:
60
+ return 0
61
+
62
+
63
+ def collection_created_at(persist_directory: str, collection_name: str) -> str | None:
64
+ """Return collection metadata ``created_at`` if present (Chroma-specific)."""
65
+ Path(persist_directory).mkdir(parents=True, exist_ok=True)
66
+ client = chromadb.PersistentClient(path=persist_directory, settings=_CHROMA_CLIENT_SETTINGS)
67
+ try:
68
+ col = client.get_collection(name=collection_name)
69
+ meta = getattr(col, "metadata", None) or {}
70
+ if isinstance(meta, dict):
71
+ raw = meta.get("created_at") or meta.get("created")
72
+ if raw is not None:
73
+ return str(raw)
74
+ except Exception:
75
+ pass
76
+ return None
77
 
storage/audit_store.py CHANGED
@@ -1,11 +1,54 @@
1
  import json
 
2
  from pathlib import Path
3
  from typing import Any
4
  from uuid import uuid4
5
 
6
  import aiosqlite
7
 
8
- from models.responses import AuditDetail, QueryResponse
9
 
10
 
11
  async def init_audit_db(db_path: str) -> None:
@@ -24,80 +67,219 @@ async def init_audit_db(db_path: str) -> None:
24
  message TEXT NOT NULL,
25
  sources_json TEXT NOT NULL,
26
  results_json TEXT NOT NULL,
27
- created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
 
 
 
 
 
 
28
  )
29
  """
30
  )
31
  await conn.commit()
32
 
33
 
34
  async def persist_query_audit(
35
  db_path: str,
36
  *,
 
37
  action: str,
 
38
  question: str,
39
  collection_name: str,
40
- response: QueryResponse,
 
 
 
 
 
 
 
41
  ) -> str:
42
- event_id = str(uuid4())
43
  await init_audit_db(db_path)
 
 
 
44
  async with aiosqlite.connect(db_path) as conn:
45
  await conn.execute(
46
  """
47
  INSERT INTO audit_events (
48
- event_id, action, question, collection_name, answer, status, message, sources_json, results_json
49
- ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
 
 
50
  """,
51
  (
52
- event_id,
53
  action,
54
  question,
55
  collection_name,
56
- response.answer,
57
- response.status,
58
- response.message,
59
- json.dumps([item.model_dump() for item in response.sources]),
60
- json.dumps([item.model_dump() for item in response.results]),
 
 
 
 
 
 
61
  ),
62
  )
63
  await conn.commit()
64
- return event_id
65
 
66
 
67
- async def list_audit_events(db_path: str, *, limit: int, offset: int) -> list[dict[str, Any]]:
68
  await init_audit_db(db_path)
69
  async with aiosqlite.connect(db_path) as conn:
70
  conn.row_factory = aiosqlite.Row
71
  cursor = await conn.execute(
72
- """
73
- SELECT event_id, action, question, collection_name, created_at
74
  FROM audit_events
 
75
  ORDER BY datetime(created_at) DESC, rowid DESC
76
  LIMIT ? OFFSET ?
77
  """,
78
- (limit, offset),
79
  )
80
  rows = await cursor.fetchall()
81
- return [dict(row) for row in rows]
82
 
83
 
84
- async def get_audit_event(db_path: str, event_id: str) -> AuditDetail | None:
85
  await init_audit_db(db_path)
86
  async with aiosqlite.connect(db_path) as conn:
87
  conn.row_factory = aiosqlite.Row
88
  cursor = await conn.execute(
89
  """
90
- SELECT event_id, action, question, collection_name, answer, status, message, sources_json, results_json, created_at
91
  FROM audit_events
92
  WHERE event_id = ?
93
  """,
94
- (event_id,),
95
  )
96
  row = await cursor.fetchone()
97
  if row is None:
98
  return None
99
-
100
- payload = dict(row)
101
- payload["sources"] = json.loads(payload.pop("sources_json") or "[]")
102
- payload["results"] = json.loads(payload.pop("results_json") or "[]")
103
- return AuditDetail.model_validate(payload)
1
  import json
2
+ from datetime import datetime, timezone
3
  from pathlib import Path
4
  from typing import Any
5
  from uuid import uuid4
6
 
7
  import aiosqlite
8
 
9
+ from models.responses import AuditLogDetailResponse, AuditLogEntry, SourceCitation
10
+
11
+
12
+ def _utc_now_iso() -> str:
13
+ return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
14
+
15
+
16
+ def _parse_ts(value: object) -> datetime:
17
+ if value is None or value == "":
18
+ return datetime.now(timezone.utc)
19
+ s = str(value).strip()
20
+ if s.endswith("Z"):
21
+ s = s[:-1] + "+00:00"
22
+ try:
23
+ dt = datetime.fromisoformat(s)
24
+ if dt.tzinfo is None:
25
+ return dt.replace(tzinfo=timezone.utc)
26
+ return dt
27
+ except ValueError:
28
+ return datetime.now(timezone.utc)
29
+
30
+
31
+ async def _migrate_audit_columns(conn: aiosqlite.Connection) -> None:
32
+ cursor = await conn.execute("PRAGMA table_info(audit_events)")
33
+ rows = await cursor.fetchall()
34
+ col_names = {str(r[1]) for r in rows}
35
+ alters: list[str] = []
36
+ if "user_id" not in col_names:
37
+ alters.append("ALTER TABLE audit_events ADD COLUMN user_id TEXT NOT NULL DEFAULT 'anonymous'")
38
+ if "model_used" not in col_names:
39
+ alters.append("ALTER TABLE audit_events ADD COLUMN model_used TEXT")
40
+ if "tokens_used" not in col_names:
41
+ alters.append("ALTER TABLE audit_events ADD COLUMN tokens_used INTEGER")
42
+ if "response_time_ms" not in col_names:
43
+ alters.append("ALTER TABLE audit_events ADD COLUMN response_time_ms INTEGER")
44
+ if "answer_summary" not in col_names:
45
+ alters.append("ALTER TABLE audit_events ADD COLUMN answer_summary TEXT")
46
+ if "kind" not in col_names:
47
+ alters.append("ALTER TABLE audit_events ADD COLUMN kind TEXT NOT NULL DEFAULT 'ask'")
48
+ for stmt in alters:
49
+ await conn.execute(stmt)
50
+ if alters:
51
+ await conn.commit()
52
 
53
 
54
  async def init_audit_db(db_path: str) -> None:
 
67
  message TEXT NOT NULL,
68
  sources_json TEXT NOT NULL,
69
  results_json TEXT NOT NULL,
70
+ created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
71
+ user_id TEXT NOT NULL DEFAULT 'anonymous',
72
+ model_used TEXT,
73
+ tokens_used INTEGER,
74
+ response_time_ms INTEGER,
75
+ answer_summary TEXT,
76
+ kind TEXT NOT NULL DEFAULT 'ask'
77
  )
78
  """
79
  )
80
  await conn.commit()
81
+ await _migrate_audit_columns(conn)
82
+
83
+
84
+ def _summary_from_answer(answer: str, max_len: int = 280) -> str:
85
+ text = (answer or "").strip()
86
+ if len(text) <= max_len:
87
+ return text
88
+ return text[: max_len - 1].rstrip() + "…"
89
+
90
+
91
+ def _sources_to_citations(raw: list[dict[str, Any]]) -> list[SourceCitation]:
92
+ out: list[SourceCitation] = []
93
+ for item in raw:
94
+ if not isinstance(item, dict):
95
+ continue
96
+ if "document_name" in item:
97
+ doc = str(item.get("document_name", ""))
98
+ page = int(item.get("page_number", 0) or 0)
99
+ chunk = str(item.get("chunk_text", ""))
100
+ score = float(item.get("relevance_score", 0.0) or 0.0)
101
+ else:
102
+ doc = str(item.get("source", item.get("document_name", "")))
103
+ p = item.get("page_number", item.get("page"))
104
+ try:
105
+ page = int(p) if p is not None else 0
106
+ except (TypeError, ValueError):
107
+ page = 0
108
+ chunk = str(item.get("chunk_text", item.get("excerpt", item.get("text", ""))))
109
+ s = item.get("relevance_score", item.get("score"))
110
+ try:
111
+ score = float(s) if s is not None else 0.0
112
+ except (TypeError, ValueError):
113
+ score = 0.0
114
+ out.append(
115
+ SourceCitation(
116
+ document_name=doc or "unknown",
117
+ page_number=page,
118
+ chunk_text=chunk,
119
+ relevance_score=score,
120
+ )
121
+ )
122
+ return out
123
 
124
 
125
  async def persist_query_audit(
126
  db_path: str,
127
  *,
128
+ query_id: str,
129
  action: str,
130
+ user_id: str,
131
  question: str,
132
  collection_name: str,
133
+ answer: str,
134
+ sources: list[SourceCitation],
135
+ model_used: str,
136
+ tokens_used: int,
137
+ response_time_ms: int,
138
+ status: str = "success",
139
+ message: str = "ok",
140
+ kind: str = "ask",
141
  ) -> str:
 
142
  await init_audit_db(db_path)
143
+ sources_payload = [s.model_dump(mode="json") for s in sources]
144
+ summary = _summary_from_answer(answer)
145
+ created = _utc_now_iso()
146
  async with aiosqlite.connect(db_path) as conn:
147
  await conn.execute(
148
  """
149
  INSERT INTO audit_events (
150
+ event_id, action, question, collection_name, answer, status, message,
151
+ sources_json, results_json, created_at, user_id, model_used, tokens_used,
152
+ response_time_ms, answer_summary, kind
153
+ ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, '[]', ?, ?, ?, ?, ?, ?, ?)
154
  """,
155
  (
156
+ query_id,
157
  action,
158
  question,
159
  collection_name,
160
+ answer,
161
+ status,
162
+ message,
163
+ json.dumps(sources_payload),
164
+ created,
165
+ user_id,
166
+ model_used,
167
+ tokens_used,
168
+ response_time_ms,
169
+ summary,
170
+ kind,
171
  ),
172
  )
173
  await conn.commit()
174
+ return query_id
175
 
176
 
177
+ async def count_audit_events(
178
+ db_path: str,
179
+ *,
180
+ user_id: str | None = None,
181
+ from_date: str | None = None,
182
+ to_date: str | None = None,
183
+ ) -> int:
184
  await init_audit_db(db_path)
185
+ where, params = _audit_filters(user_id, from_date, to_date)
186
+ async with aiosqlite.connect(db_path) as conn:
187
+ cur = await conn.execute(f"SELECT COUNT(*) AS c FROM audit_events {where}", params)
188
+ row = await cur.fetchone()
189
+ return int(row[0]) if row else 0
190
+
191
+
192
+ def _audit_filters(user_id: str | None, from_date: str | None, to_date: str | None) -> tuple[str, list[Any]]:
193
+ clauses: list[str] = []
194
+ params: list[Any] = []
195
+ if user_id:
196
+ clauses.append("user_id = ?")
197
+ params.append(user_id)
198
+ if from_date:
199
+ clauses.append("datetime(created_at) >= datetime(?)")
200
+ params.append(from_date)
201
+ if to_date:
202
+ clauses.append("datetime(created_at) <= datetime(?)")
203
+ params.append(to_date)
204
+ if not clauses:
205
+ return "", []
206
+ return "WHERE " + " AND ".join(clauses), params
207
+
208
+
209
+ async def list_audit_events(
210
+ db_path: str,
211
+ *,
212
+ limit: int,
213
+ offset: int,
214
+ user_id: str | None = None,
215
+ from_date: str | None = None,
216
+ to_date: str | None = None,
217
+ ) -> tuple[list[AuditLogEntry], int]:
218
+ await init_audit_db(db_path)
219
+ where, fparams = _audit_filters(user_id, from_date, to_date)
220
+ total = await count_audit_events(db_path, user_id=user_id, from_date=from_date, to_date=to_date)
221
  async with aiosqlite.connect(db_path) as conn:
222
  conn.row_factory = aiosqlite.Row
223
  cursor = await conn.execute(
224
+ f"""
225
+ SELECT event_id, user_id, question, answer, answer_summary, sources_json, model_used, created_at
226
  FROM audit_events
227
+ {where}
228
  ORDER BY datetime(created_at) DESC, rowid DESC
229
  LIMIT ? OFFSET ?
230
  """,
231
+ [*fparams, limit, offset],
232
  )
233
  rows = await cursor.fetchall()
234
+ logs: list[AuditLogEntry] = []
235
+ for row in rows:
236
+ src_raw = json.loads(row["sources_json"] or "[]")
237
+ if not isinstance(src_raw, list):
238
+ src_raw = []
239
+ summary_cell = row["answer_summary"]
240
+ summary_text = str(summary_cell).strip() if summary_cell else ""
241
+ if not summary_text:
242
+ summary_text = _summary_from_answer(str(row["answer"] or ""))
243
+ logs.append(
244
+ AuditLogEntry(
245
+ query_id=str(row["event_id"]),
246
+ user_id=str(row["user_id"] or "anonymous"),
247
+ question=str(row["question"]),
248
+ answer_summary=summary_text,
249
+ sources_count=len(src_raw),
250
+ model_used=row["model_used"],
251
+ timestamp=_parse_ts(row["created_at"]),
252
+ )
253
+ )
254
+ return logs, total
255
 
256
 
257
+ async def get_audit_event(db_path: str, query_id: str) -> AuditLogDetailResponse | None:
258
  await init_audit_db(db_path)
259
  async with aiosqlite.connect(db_path) as conn:
260
  conn.row_factory = aiosqlite.Row
261
  cursor = await conn.execute(
262
  """
263
+ SELECT event_id, user_id, question, answer, sources_json, model_used, tokens_used, created_at
264
  FROM audit_events
265
  WHERE event_id = ?
266
  """,
267
+ (query_id,),
268
  )
269
  row = await cursor.fetchone()
270
  if row is None:
271
  return None
272
+ src_raw = json.loads(row["sources_json"] or "[]")
273
+ if not isinstance(src_raw, list):
274
+ src_raw = []
275
+ citations = _sources_to_citations(src_raw)
276
+ return AuditLogDetailResponse(
277
+ query_id=str(row["event_id"]),
278
+ user_id=str(row["user_id"] or "anonymous"),
279
+ question=str(row["question"]),
280
+ full_answer=str(row["answer"] or ""),
281
+ sources=citations,
282
+ model_used=row["model_used"],
283
+ tokens_used=row["tokens_used"],
284
+ timestamp=_parse_ts(row["created_at"]),
285
+ )
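Note on the audit filters above: the new `_audit_filters` helper builds a `WHERE` clause from optional filters and keeps the values in positional parameters, so only fixed SQL fragments are ever interpolated into the query string. A minimal sketch of that pattern follows, assuming the standard-library `sqlite3` module in place of `aiosqlite` and a reduced three-column table. The names `build_filters` and `demo` are illustrative and do not appear in the commit.

```python
import sqlite3
from typing import Any, Optional


def build_filters(
    user_id: Optional[str], from_date: Optional[str], to_date: Optional[str]
) -> tuple[str, list[Any]]:
    """Return a WHERE clause (or "") plus its positional parameters."""
    clauses: list[str] = []
    params: list[Any] = []
    if user_id:
        clauses.append("user_id = ?")
        params.append(user_id)
    if from_date:
        clauses.append("datetime(created_at) >= datetime(?)")
        params.append(from_date)
    if to_date:
        clauses.append("datetime(created_at) <= datetime(?)")
        params.append(to_date)
    where = "WHERE " + " AND ".join(clauses) if clauses else ""
    return where, params


def demo() -> tuple[int, list[str]]:
    # Reduced, in-memory version of the audit table (assumption for the sketch).
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE audit_events (event_id TEXT, user_id TEXT, created_at TEXT)")
    conn.executemany(
        "INSERT INTO audit_events VALUES (?, ?, ?)",
        [
            ("q1", "alice", "2024-01-01T10:00:00Z"),
            ("q2", "bob", "2024-01-02T10:00:00Z"),
            ("q3", "alice", "2024-01-03T10:00:00Z"),
        ],
    )
    where, params = build_filters("alice", "2024-01-02T00:00:00Z", None)
    # The same clause and parameters drive both the count and the page query.
    total = conn.execute(f"SELECT COUNT(*) FROM audit_events {where}", params).fetchone()[0]
    rows = conn.execute(
        f"SELECT event_id FROM audit_events {where} "
        "ORDER BY datetime(created_at) DESC LIMIT ? OFFSET ?",
        [*params, 10, 0],
    ).fetchall()
    conn.close()
    return total, [r[0] for r in rows]
```

Reusing one clause for the count and the page keeps `total` consistent with the rows returned, which is what the paginated `/audit/logs` response relies on.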
storage/job_store.py CHANGED
@@ -1,11 +1,64 @@
1
  import json
 
2
  from pathlib import Path
3
  from typing import Any
4
  from uuid import uuid4
5
 
6
  import aiosqlite
7
 
8
- from models.responses import IngestJobDetail
9
 
10
 
11
  async def init_jobs_db(db_path: str) -> None:
@@ -22,74 +75,134 @@ async def init_jobs_db(db_path: str) -> None:
22
  message TEXT NOT NULL DEFAULT '',
23
  document_ids_json TEXT NOT NULL DEFAULT '[]',
24
  created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
25
- updated_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
26
  )
27
  """
28
  )
29
  await conn.commit()
 
30
 
31
 
32
  async def create_ingest_job(
33
  db_path: str,
34
  *,
35
  collection_name: str,
36
- filename: str,
37
  ) -> str:
 
 
38
  job_id = str(uuid4())
 
 
 
39
  await init_jobs_db(db_path)
40
  async with aiosqlite.connect(db_path) as conn:
41
  await conn.execute(
42
  """
43
  INSERT INTO ingest_jobs (
44
- job_id, status, collection_name, filename, message, document_ids_json
45
- ) VALUES (?, 'queued', ?, ?, '', '[]')
 
46
  """,
47
- (job_id, collection_name, filename),
48
  )
49
  await conn.commit()
50
  return job_id
51
 
52
 
53
- async def update_ingest_job(
54
  db_path: str,
55
  job_id: str,
56
  *,
57
- status: str,
 
 
58
  message: str | None = None,
59
- document_ids: list[str] | None = None,
60
  ) -> None:
61
  await init_jobs_db(db_path)
62
  async with aiosqlite.connect(db_path) as conn:
63
- if document_ids is not None:
64
- await conn.execute(
65
- """
66
- UPDATE ingest_jobs
67
- SET status = ?, message = COALESCE(?, message), document_ids_json = ?,
68
- updated_at = CURRENT_TIMESTAMP
69
- WHERE job_id = ?
70
- """,
71
- (status, message, json.dumps(document_ids), job_id),
72
- )
73
- else:
74
- await conn.execute(
75
- """
76
- UPDATE ingest_jobs
77
- SET status = ?, message = COALESCE(?, message),
78
- updated_at = CURRENT_TIMESTAMP
79
- WHERE job_id = ?
80
- """,
81
- (status, message, job_id),
82
- )
83
  await conn.commit()
84
 
85
 
86
- async def get_ingest_job(db_path: str, job_id: str) -> IngestJobDetail | None:
87
  await init_jobs_db(db_path)
88
  async with aiosqlite.connect(db_path) as conn:
89
  conn.row_factory = aiosqlite.Row
90
  cursor = await conn.execute(
91
  """
92
- SELECT job_id, status, collection_name, filename, message, document_ids_json, created_at, updated_at
 
93
  FROM ingest_jobs
94
  WHERE job_id = ?
95
  """,
@@ -98,18 +211,39 @@ async def get_ingest_job(db_path: str, job_id: str) -> IngestJobDetail | None:
98
  row = await cursor.fetchone()
99
  if row is None:
100
  return None
101
- payload = dict(row)
102
- payload["document_ids"] = json.loads(payload.pop("document_ids_json") or "[]")
103
- return IngestJobDetail.model_validate(payload)
104
 
105
 
106
- async def list_ingest_jobs(db_path: str, *, limit: int, offset: int) -> list[dict[str, Any]]:
107
  await init_jobs_db(db_path)
108
  async with aiosqlite.connect(db_path) as conn:
109
  conn.row_factory = aiosqlite.Row
 
 
 
110
  cursor = await conn.execute(
111
  """
112
- SELECT job_id, status, collection_name, filename, created_at
113
  FROM ingest_jobs
114
  ORDER BY datetime(updated_at) DESC, rowid DESC
115
  LIMIT ? OFFSET ?
@@ -117,4 +251,30 @@ async def list_ingest_jobs(db_path: str, *, limit: int, offset: int) -> list[dic
117
  (limit, offset),
118
  )
119
  rows = await cursor.fetchall()
120
- return [dict(row) for row in rows]
1
  import json
2
+ from datetime import datetime, timezone
3
  from pathlib import Path
4
  from typing import Any
5
  from uuid import uuid4
6
 
7
  import aiosqlite
8
 
9
+ from models.responses import JobListItem, JobStatusResponse
10
+
11
+
12
+ def _utc_now_iso() -> str:
13
+ return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
14
+
15
+
16
+ async def _migrate_jobs_columns(conn: aiosqlite.Connection) -> None:
17
+ cursor = await conn.execute("PRAGMA table_info(ingest_jobs)")
18
+ rows = await cursor.fetchall()
19
+ col_names = {str(r[1]) for r in rows}
20
+ alters: list[str] = []
21
+ if "total_files" not in col_names:
22
+ alters.append("ALTER TABLE ingest_jobs ADD COLUMN total_files INTEGER NOT NULL DEFAULT 1")
23
+ if "processed_files" not in col_names:
24
+ alters.append("ALTER TABLE ingest_jobs ADD COLUMN processed_files INTEGER NOT NULL DEFAULT 0")
25
+ if "failed_files" not in col_names:
26
+ alters.append("ALTER TABLE ingest_jobs ADD COLUMN failed_files INTEGER NOT NULL DEFAULT 0")
27
+ if "filenames_json" not in col_names:
28
+ alters.append("ALTER TABLE ingest_jobs ADD COLUMN filenames_json TEXT NOT NULL DEFAULT '[]'")
29
+ if "errors_json" not in col_names:
30
+ alters.append("ALTER TABLE ingest_jobs ADD COLUMN errors_json TEXT NOT NULL DEFAULT '[]'")
31
+ if "started_at" not in col_names:
32
+ alters.append("ALTER TABLE ingest_jobs ADD COLUMN started_at TEXT")
33
+ if "completed_at" not in col_names:
34
+ alters.append("ALTER TABLE ingest_jobs ADD COLUMN completed_at TEXT")
35
+ for stmt in alters:
36
+ await conn.execute(stmt)
37
+ if alters:
38
+ await conn.commit()
39
+ await _backfill_job_filenames(conn)
40
+
41
+
42
+ async def _backfill_job_filenames(conn: aiosqlite.Connection) -> None:
43
+ conn.row_factory = aiosqlite.Row
44
+ cursor = await conn.execute("SELECT job_id, filename, filenames_json, total_files FROM ingest_jobs")
45
+ rows = await cursor.fetchall()
46
+ for row in rows:
47
+ raw = row["filenames_json"] or "[]"
48
+ try:
49
+ parsed: Any = json.loads(raw)
50
+ except json.JSONDecodeError:
51
+ parsed = []
52
+ if not parsed and row["filename"]:
53
+ await conn.execute(
54
+ """
55
+ UPDATE ingest_jobs
56
+ SET filenames_json = ?, total_files = CASE WHEN total_files IS NULL OR total_files < 1 THEN 1 ELSE total_files END
57
+ WHERE job_id = ?
58
+ """,
59
+ (json.dumps([row["filename"]]), row["job_id"]),
60
+ )
61
+ await conn.commit()
62
 
63
 
64
  async def init_jobs_db(db_path: str) -> None:
 
75
  message TEXT NOT NULL DEFAULT '',
76
  document_ids_json TEXT NOT NULL DEFAULT '[]',
77
  created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
78
+ updated_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
79
+ total_files INTEGER NOT NULL DEFAULT 1,
80
+ processed_files INTEGER NOT NULL DEFAULT 0,
81
+ failed_files INTEGER NOT NULL DEFAULT 0,
82
+ filenames_json TEXT NOT NULL DEFAULT '[]',
83
+ errors_json TEXT NOT NULL DEFAULT '[]',
84
+ started_at TEXT,
85
+ completed_at TEXT
86
  )
87
  """
88
  )
89
  await conn.commit()
90
+ await _migrate_jobs_columns(conn)
91
 
92
 
93
  async def create_ingest_job(
94
  db_path: str,
95
  *,
96
  collection_name: str,
97
+ filenames: list[str],
98
  ) -> str:
99
+ if not filenames:
100
+ raise ValueError("filenames must not be empty")
101
  job_id = str(uuid4())
102
+ primary = filenames[0]
103
+ names_json = json.dumps(filenames)
104
+ total = len(filenames)
105
  await init_jobs_db(db_path)
106
  async with aiosqlite.connect(db_path) as conn:
107
  await conn.execute(
108
  """
109
  INSERT INTO ingest_jobs (
110
+ job_id, status, collection_name, filename, message, document_ids_json,
111
+ total_files, processed_files, failed_files, filenames_json, errors_json
112
+ ) VALUES (?, 'queued', ?, ?, '', '[]', ?, 0, 0, ?, '[]')
113
  """,
114
+ (job_id, collection_name, primary, total, names_json),
115
  )
116
  await conn.commit()
117
  return job_id
118
 
119
 
120
+ async def mark_job_processing(db_path: str, job_id: str) -> None:
121
+ await init_jobs_db(db_path)
122
+ started = _utc_now_iso()
123
+ async with aiosqlite.connect(db_path) as conn:
124
+ await conn.execute(
125
+ """
126
+ UPDATE ingest_jobs
127
+ SET status = 'processing', message = 'Ingestion in progress.', started_at = COALESCE(started_at, ?),
128
+ updated_at = CURRENT_TIMESTAMP
129
+ WHERE job_id = ?
130
+ """,
131
+ (started, job_id),
132
+ )
133
+ await conn.commit()
134
+
135
+
136
+ async def update_job_progress(
137
  db_path: str,
138
  job_id: str,
139
  *,
140
+ processed_files: int,
141
+ failed_files: int,
142
+ errors: list[str],
143
  message: str | None = None,
 
144
  ) -> None:
145
  await init_jobs_db(db_path)
146
  async with aiosqlite.connect(db_path) as conn:
147
+ await conn.execute(
148
+ """
149
+ UPDATE ingest_jobs
150
+ SET processed_files = ?, failed_files = ?, errors_json = ?,
151
+ message = COALESCE(?, message), updated_at = CURRENT_TIMESTAMP
152
+ WHERE job_id = ?
153
+ """,
154
+ (processed_files, failed_files, json.dumps(errors), message, job_id),
155
+ )
156
  await conn.commit()
157
 
158
 
159
+ async def complete_ingest_job(
160
+ db_path: str,
161
+ job_id: str,
162
+ *,
163
+ document_ids: list[str],
164
+ message: str,
165
+ ) -> None:
166
+ await init_jobs_db(db_path)
167
+ completed = _utc_now_iso()
168
+ async with aiosqlite.connect(db_path) as conn:
169
+ await conn.execute(
170
+ """
171
+ UPDATE ingest_jobs
172
+ SET status = 'completed', message = ?, document_ids_json = ?,
173
+ completed_at = ?, updated_at = CURRENT_TIMESTAMP
174
+ WHERE job_id = ?
175
+ """,
176
+ (message, json.dumps(document_ids), completed, job_id),
177
+ )
178
+ await conn.commit()
179
+
180
+
181
+ async def fail_ingest_job(db_path: str, job_id: str, *, message: str, errors: list[str] | None = None) -> None:
182
+ await init_jobs_db(db_path)
183
+ completed = _utc_now_iso()
184
+ err_json = json.dumps(errors or [message])
185
+ async with aiosqlite.connect(db_path) as conn:
186
+ await conn.execute(
187
+ """
188
+ UPDATE ingest_jobs
189
+ SET status = 'failed', message = ?, errors_json = ?, completed_at = ?,
190
+ updated_at = CURRENT_TIMESTAMP
191
+ WHERE job_id = ?
192
+ """,
193
+ (message, err_json, completed, job_id),
194
+ )
195
+ await conn.commit()
196
+
197
+
198
+ async def get_job_status(db_path: str, job_id: str) -> JobStatusResponse | None:
199
  await init_jobs_db(db_path)
200
  async with aiosqlite.connect(db_path) as conn:
201
  conn.row_factory = aiosqlite.Row
202
  cursor = await conn.execute(
203
  """
204
+ SELECT job_id, status, total_files, processed_files, failed_files, errors_json,
205
+ started_at, completed_at, message
206
  FROM ingest_jobs
207
  WHERE job_id = ?
208
  """,
 
211
  row = await cursor.fetchone()
212
  if row is None:
213
  return None
214
+ data = dict(row)
215
+ total = int(data["total_files"] or 0)
216
+ processed = int(data["processed_files"] or 0)
217
+ failed = int(data["failed_files"] or 0)
218
+ denom = total if total > 0 else 1
219
+ progress = int(min(100, max(0, round((processed + failed) / denom * 100))))
220
+ errors = json.loads(data.get("errors_json") or "[]")
221
+ if not isinstance(errors, list):
222
+ errors = [str(errors)]
223
+ errors_str = [str(e) for e in errors]
224
+ return JobStatusResponse(
225
+ job_id=str(data["job_id"]),
226
+ status=str(data["status"]),
227
+ total_files=total,
228
+ processed_files=processed,
229
+ failed_files=failed,
230
+ progress_percent=progress,
231
+ started_at=_parse_dt(data.get("started_at")),
232
+ completed_at=_parse_dt(data.get("completed_at")),
233
+ errors=errors_str,
234
+ )
235
 
236
 
237
+ async def list_ingest_jobs(db_path: str, *, limit: int, offset: int) -> tuple[list[JobListItem], int]:
238
  await init_jobs_db(db_path)
239
  async with aiosqlite.connect(db_path) as conn:
240
  conn.row_factory = aiosqlite.Row
241
+ cur_total = await conn.execute("SELECT COUNT(*) AS c FROM ingest_jobs")
242
+ total_row = await cur_total.fetchone()
243
+ total = int(total_row["c"]) if total_row else 0
244
  cursor = await conn.execute(
245
  """
246
+ SELECT job_id, status, total_files, completed_at
247
  FROM ingest_jobs
248
  ORDER BY datetime(updated_at) DESC, rowid DESC
249
  LIMIT ? OFFSET ?
 
251
  (limit, offset),
252
  )
253
  rows = await cursor.fetchall()
254
+ items = [
255
+ JobListItem(
256
+ job_id=str(r["job_id"]),
257
+ status=str(r["status"]),
258
+ total_files=int(r["total_files"] or 0),
259
+ completed_at=_parse_dt(r["completed_at"]),
260
+ )
261
+ for r in rows
262
+ ]
263
+ return items, total
264
+
265
+
266
+ def _parse_dt(value: object) -> datetime | None:
267
+ if value is None or value == "":
268
+ return None
269
+ s = str(value).strip()
270
+ if not s:
271
+ return None
272
+ if s.endswith("Z"):
273
+ s = s[:-1] + "+00:00"
274
+ try:
275
+ dt = datetime.fromisoformat(s)
276
+ if dt.tzinfo is None:
277
+ return dt.replace(tzinfo=timezone.utc)
278
+ return dt
279
+ except ValueError:
280
+ return None
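Note on the job-store migration above: `_migrate_jobs_columns` reads `PRAGMA table_info` and issues `ALTER TABLE ... ADD COLUMN` only for columns that are missing, so it can run on every startup without failing on an already-migrated database. A minimal sketch of that additive, idempotent pattern follows, assuming the standard-library `sqlite3` module instead of `aiosqlite`. The helper `add_missing_columns` and the reduced table are illustrative and not taken from the commit.

```python
import sqlite3


def add_missing_columns(conn: sqlite3.Connection, table: str, columns: dict[str, str]) -> list[str]:
    """Add each column in `columns` (name -> type/default DDL) that `table` lacks.

    Returns the names that were added. `table` and the column names are
    interpolated into SQL, so they must come from trusted code, never user input.
    """
    # Row layout of PRAGMA table_info: (cid, name, type, notnull, dflt_value, pk).
    existing = {row[1] for row in conn.execute(f"PRAGMA table_info({table})")}
    added: list[str] = []
    for name, ddl in columns.items():
        if name not in existing:
            conn.execute(f"ALTER TABLE {table} ADD COLUMN {name} {ddl}")
            added.append(name)
    if added:
        conn.commit()
    return added


def demo():
    # Start from an "old" schema with one pre-existing row (assumption for the sketch).
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE ingest_jobs (job_id TEXT PRIMARY KEY, filename TEXT)")
    conn.execute("INSERT INTO ingest_jobs VALUES ('j1', 'a.pdf')")
    wanted = {
        "total_files": "INTEGER NOT NULL DEFAULT 1",
        "filenames_json": "TEXT NOT NULL DEFAULT '[]'",
        "started_at": "TEXT",
    }
    first = add_missing_columns(conn, "ingest_jobs", wanted)
    second = add_missing_columns(conn, "ingest_jobs", wanted)  # second run is a no-op
    row = conn.execute(
        "SELECT total_files, filenames_json, started_at FROM ingest_jobs"
    ).fetchone()
    conn.close()
    return first, second, row
```

SQLite requires a non-NULL default when adding a `NOT NULL` column, which is why each such column in the diff carries a `DEFAULT`; existing rows pick up that default automatically.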
streamlit_app.py CHANGED
@@ -71,17 +71,43 @@ def _fmt_api_error(exc: httpx.HTTPStatusError) -> str:
71
  return f"HTTP {exc.response.status_code}"
72
 
73
 
74
- def _post_query_ask(client: httpx.Client, *, question: str, collection_name: str) -> httpx.Response:
75
- """Milestone 8 uses POST /query/ask; older servers only expose POST /query."""
76
- body = {"question": question.strip(), "collection_name": collection_name}
77
  r = client.post("/query/ask", json=body)
78
  if r.status_code == 404:
79
  r = client.post("/query", json=body)
80
  return r
81
 
82
 
83
- def _get_audit_logs(client: httpx.Client, *, limit: int, offset: int) -> httpx.Response:
84
- params = {"limit": limit, "offset": offset}
85
  r = client.get("/audit/logs", params=params)
86
  if r.status_code == 404:
87
  r = client.get("/audit", params=params)
@@ -156,7 +182,7 @@ def main() -> None:
156
  st.warning("Choose a file first.")
157
  else:
158
  try:
159
- files = {"file": (uploaded.name, uploaded.getvalue(), uploaded.type or "application/octet-stream")}
160
  data = {"collection_name": up_collection}
161
  with _client() as c:
162
  r = c.post("/ingest/upload", files=files, data=data)
@@ -185,7 +211,10 @@ def main() -> None:
185
  else:
186
  try:
187
  with _client() as c:
188
- r = c.post("/ingest/url", json={"url": ingest_url.strip(), "collection_name": url_collection})
 
 
 
189
  r.raise_for_status()
190
  out = r.json()
191
  st.success(out.get("message", "Queued"))
@@ -206,10 +235,10 @@ def main() -> None:
206
  r = c.get("/ingest/collections")
207
  r.raise_for_status()
208
  cols = r.json()
209
- names = [x["name"] for x in cols.get("collections", [])]
210
- st.write(cols.get("message", ""))
211
- if names:
212
- st.dataframe({"name": names}, hide_index=True, use_container_width=True)
213
  else:
214
  st.info("No collections yet.")
215
  except httpx.HTTPStatusError as e:
@@ -228,7 +257,10 @@ def main() -> None:
228
  with _client() as c:
229
  r = c.delete(f"/ingest/collection/{del_name.strip()}")
230
  r.raise_for_status()
231
- st.success(r.json().get("message", "Deleted"))
 
 
 
232
  except httpx.HTTPStatusError as e:
233
  st.error(_fmt_api_error(e))
234
  except httpx.ConnectError as e:
@@ -250,7 +282,7 @@ def main() -> None:
250
  r.raise_for_status()
251
  payload = r.json()
252
  jobs: list[dict[str, Any]] = payload.get("jobs", [])
253
- st.caption(payload.get("message", ""))
254
  if jobs:
255
  st.dataframe(jobs, hide_index=True, use_container_width=True)
256
  else:
@@ -293,9 +325,8 @@ def main() -> None:
293
  r = c.get(f"/jobs/{job_id.strip()}")
294
  r.raise_for_status()
295
  body = r.json()
296
- job = body.get("job") or {}
297
- st_ = job.get("status", "")
298
- status_ph.write(f"Poll {i + 1}: **{st_}** — {job.get('message', '')}")
299
  if st_ in ("completed", "failed"):
300
  st.json(body)
301
  break
@@ -331,8 +362,7 @@ def main() -> None:
331
  )
332
  r.raise_for_status()
333
  ans = r.json()
334
- msg = ans.get("message") or ""
335
- st.success(msg if msg else "Request completed.")
336
  if ans.get("answer"):
337
  st.markdown("### Answer")
338
  st.markdown(ans["answer"])
@@ -370,11 +400,11 @@ def main() -> None:
370
  r = c.post("/query/summarise", json=body)
371
  r.raise_for_status()
372
  ans = r.json()
373
- msg = ans.get("message") or ""
374
- st.success(msg if msg else "Request completed.")
375
- if ans.get("answer"):
376
  st.markdown("### Summary")
377
- st.markdown(ans["answer"])
378
  else:
379
  st.warning("No summary text in the response; see **Raw response** below.")
380
  src = ans.get("sources") or []
@@ -411,11 +441,15 @@ def main() -> None:
411
  )
412
  r.raise_for_status()
413
  payload = r.json()
414
- events = payload.get("events", [])
415
- st.caption(payload.get("message", ""))
416
  if events:
417
  st.dataframe(events, hide_index=True, use_container_width=True)
418
- ids = [e["event_id"] for e in events if isinstance(e, dict) and "event_id" in e]
 
 
 
 
419
  if ids:
420
  st.session_state["_audit_ids"] = ids
421
  else:
@@ -432,7 +466,7 @@ def main() -> None:
432
  pick = ""
433
  if ids_for_select:
434
  pick = st.selectbox("Event ID", options=[""] + list(ids_for_select), key="audit_pick")
435
- manual_id = st.text_input("Or enter event ID", key="audit_manual")
436
  ev_id = (manual_id.strip() or (pick or "").strip()).strip()
437
  if st.button("Load detail", key="btn_audit_detail") and ev_id:
438
  try:
 
71
  return f"HTTP {exc.response.status_code}"
72
 
73
 
74
+ def _post_query_ask(
75
+ client: httpx.Client,
76
+ *,
77
+ question: str,
78
+ collection_name: str,
79
+ top_k: int = 5,
80
+ user_id: str = "anonymous",
81
+ ) -> httpx.Response:
82
+ """POST /query/ask (falls back to POST /query on older servers)."""
83
+ body: dict[str, object] = {
84
+ "question": question.strip(),
85
+ "collection_name": collection_name,
86
+ "top_k": top_k,
87
+ "user_id": user_id,
88
+ }
89
  r = client.post("/query/ask", json=body)
90
  if r.status_code == 404:
91
  r = client.post("/query", json=body)
92
  return r
93
 
94
 
95
+ def _get_audit_logs(
96
+ client: httpx.Client,
97
+ *,
98
+ limit: int,
99
+ offset: int,
100
+ user_id: str | None = None,
101
+ from_date: str | None = None,
102
+ to_date: str | None = None,
103
+ ) -> httpx.Response:
104
+ params: dict[str, object] = {"limit": limit, "offset": offset}
105
+ if user_id:
106
+ params["user_id"] = user_id
107
+ if from_date:
108
+ params["from_date"] = from_date
109
+ if to_date:
110
+ params["to_date"] = to_date
111
  r = client.get("/audit/logs", params=params)
112
  if r.status_code == 404:
113
  r = client.get("/audit", params=params)
 
182
  st.warning("Choose a file first.")
183
  else:
184
  try:
185
+ files = {"files": (uploaded.name, uploaded.getvalue(), uploaded.type or "application/octet-stream")}
186
  data = {"collection_name": up_collection}
187
  with _client() as c:
188
  r = c.post("/ingest/upload", files=files, data=data)
 
211
  else:
212
  try:
213
  with _client() as c:
214
+ r = c.post(
215
+ "/ingest/url",
216
+ json={"urls": [ingest_url.strip()], "collection_name": url_collection},
217
+ )
218
  r.raise_for_status()
219
  out = r.json()
220
  st.success(out.get("message", "Queued"))
 
235
  r = c.get("/ingest/collections")
236
  r.raise_for_status()
237
  cols = r.json()
238
+ rows = cols.get("collections", [])
239
+ st.write(f"{cols.get('total', len(rows))} collection(s).")
240
+ if rows:
241
+ st.dataframe(rows, hide_index=True, use_container_width=True)
242
  else:
243
  st.info("No collections yet.")
244
  except httpx.HTTPStatusError as e:
 
257
  with _client() as c:
258
  r = c.delete(f"/ingest/collection/{del_name.strip()}")
259
  r.raise_for_status()
260
+ del_body = r.json()
261
+ st.success(del_body.get("message", "Deleted"))
262
+ if "documents_removed" in del_body:
263
+ st.caption(f"Documents removed: **{del_body['documents_removed']}**")
264
  except httpx.HTTPStatusError as e:
265
  st.error(_fmt_api_error(e))
266
  except httpx.ConnectError as e:
 
282
  r.raise_for_status()
  payload = r.json()
  jobs: list[dict[str, Any]] = payload.get("jobs", [])
+ st.caption(f"Total jobs (matching filters): **{payload.get('total', len(jobs))}**")
  if jobs:
  st.dataframe(jobs, hide_index=True, use_container_width=True)
  else:

  r = c.get(f"/jobs/{job_id.strip()}")
  r.raise_for_status()
  body = r.json()
+ st_ = body.get("status", "")
+ status_ph.write(f"Poll {i + 1}: **{st_}** {body.get('progress_percent', 0)}%")
  if st_ in ("completed", "failed"):
  st.json(body)
  break

  )
  r.raise_for_status()
  ans = r.json()
+ st.success(f"Query id: `{ans.get('query_id', '')}`")
  if ans.get("answer"):
  st.markdown("### Answer")
  st.markdown(ans["answer"])

  r = c.post("/query/summarise", json=body)
  r.raise_for_status()
  ans = r.json()
+ st.success(f"Query id: `{ans.get('query_id', '')}` · documents: **{ans.get('document_count', '')}**")
+ summary_text = ans.get("summary") or ans.get("answer")
+ if summary_text:
  st.markdown("### Summary")
+ st.markdown(summary_text)
  else:
  st.warning("No summary text in the response; see **Raw response** below.")
  src = ans.get("sources") or []

  )
  r.raise_for_status()
  payload = r.json()
+ events = payload.get("logs", payload.get("events", []))
+ st.caption(f"Total matching: **{payload.get('total', len(events))}**")
  if events:
  st.dataframe(events, hide_index=True, use_container_width=True)
+ ids = [
+ e.get("query_id") or e.get("event_id")
+ for e in events
+ if isinstance(e, dict) and (e.get("query_id") or e.get("event_id"))
+ ]
  if ids:
  st.session_state["_audit_ids"] = ids
  else:

  pick = ""
  if ids_for_select:
  pick = st.selectbox("Event ID", options=[""] + list(ids_for_select), key="audit_pick")
+ manual_id = st.text_input("Or enter query / event ID", key="audit_manual")
  ev_id = (manual_id.strip() or (pick or "").strip()).strip()
  if st.button("Load detail", key="btn_audit_detail") and ev_id:
  try:
tests/test_audit.py CHANGED
@@ -1,12 +1,13 @@
  import asyncio
  from unittest.mock import AsyncMock

  import pytest
  from fastapi.testclient import TestClient

  from api.config import Settings
  from api.main import app
- from models.responses import QueryResponse
  from storage.audit_store import persist_query_audit


@@ -32,39 +33,89 @@ def client(settings, monkeypatch):
          yield test_client


- def _seed_audit(settings: Settings, question: str = "What are key risks?") -> str:
-     return asyncio.run(
          persist_query_audit(
              settings.audit_db_path,
              action="query",
              question=question,
              collection_name="default",
-             response=QueryResponse(
-                 status="success",
-                 message="ok",
-                 answer="Grounded answer",
-                 sources=[],
-                 results=[],
-             ),
          )
      )


  def test_audit_logs_and_detail_success(client, settings):
-     event_id = _seed_audit(settings)

      list_response = client.get("/audit/logs?limit=10&offset=0")
      assert list_response.status_code == 200
      body = list_response.json()
-     assert body["status"] == "success"
-     assert len(body["events"]) >= 1
-     assert any(event["event_id"] == event_id for event in body["events"])

-     detail_response = client.get(f"/audit/logs/{event_id}")
      assert detail_response.status_code == 200
      detail = detail_response.json()
-     assert detail["status"] == "success"
-     assert detail["event"]["question"] == "What are key risks?"


  def test_audit_logs_validation_error_for_bad_limit(client):
 
  import asyncio
  from unittest.mock import AsyncMock
+ from uuid import uuid4

  import pytest
  from fastapi.testclient import TestClient

  from api.config import Settings
  from api.main import app
+ from models.responses import SourceCitation
  from storage.audit_store import persist_query_audit


 
          yield test_client


+ def _seed_audit(settings: Settings, question: str = "What are key risks?", user_id: str = "analyst_001") -> str:
+     query_id = str(uuid4())
+     asyncio.run(
          persist_query_audit(
              settings.audit_db_path,
+             query_id=query_id,
              action="query",
+             user_id=user_id,
              question=question,
              collection_name="default",
+             answer="Grounded answer text for audit trail.",
+             sources=[
+                 SourceCitation(
+                     document_name="report.pdf",
+                     page_number=3,
+                     chunk_text="Risk disclosure excerpt.",
+                     relevance_score=0.9,
+                 )
+             ],
+             model_used="ollama:llama3.1:8b",
+             tokens_used=120,
+             response_time_ms=50,
+             kind="ask",
          )
      )
+     return query_id


  def test_audit_logs_and_detail_success(client, settings):
+     query_id = _seed_audit(settings)

      list_response = client.get("/audit/logs?limit=10&offset=0")
      assert list_response.status_code == 200
      body = list_response.json()
+     assert "logs" in body
+     assert body["total"] >= 1
+     assert any(entry["query_id"] == query_id for entry in body["logs"])

+     detail_response = client.get(f"/audit/logs/{query_id}")
      assert detail_response.status_code == 200
      detail = detail_response.json()
+     assert detail["query_id"] == query_id
+     assert detail["question"] == "What are key risks?"
+     assert detail["full_answer"] == "Grounded answer text for audit trail."
+     assert len(detail["sources"]) == 1
+     assert detail["sources"][0]["document_name"] == "report.pdf"
+
+
+ def test_audit_logs_filter_by_user_id(client, settings):
+     q1 = _seed_audit(settings, question="Q one", user_id="user_a")
+     _seed_audit(settings, question="Q two", user_id="user_b")
+
+     r = client.get("/audit/logs", params={"user_id": "user_a", "limit": 50, "offset": 0})
+     assert r.status_code == 200
+     body = r.json()
+     ids = {e["query_id"] for e in body["logs"]}
+     assert q1 in ids
+     assert all(e["user_id"] == "user_a" for e in body["logs"])
+
+
+ def test_audit_logs_filter_by_from_date(client, settings):
+     query_id = str(uuid4())
+     future = "2099-01-01T00:00:00Z"
+     asyncio.run(
+         persist_query_audit(
+             settings.audit_db_path,
+             query_id=query_id,
+             action="query",
+             user_id="u",
+             question="Future dated row",
+             collection_name="default",
+             answer="A",
+             sources=[],
+             model_used="m",
+             tokens_used=0,
+             response_time_ms=1,
+             kind="ask",
+         )
+     )
+     r = client.get("/audit/logs", params={"from_date": future, "limit": 50, "offset": 0})
+     assert r.status_code == 200
+     body = r.json()
+     assert query_id not in {e["query_id"] for e in body["logs"]}


  def test_audit_logs_validation_error_for_bad_limit(client):
tests/test_ingest.py CHANGED
@@ -34,21 +34,23 @@ def test_upload_queues_job_success(client, monkeypatch):
      response = client.post(
          "/ingest/upload",
          data={"collection_name": "default"},
-         files={"file": ("sample.txt", b"hello world", "text/plain")},
      )

      assert response.status_code == 200
      body = response.json()
      assert body["status"] == "queued"
      assert body["job_id"] == "job-123"
-     assert "Poll GET /jobs/job-123" in body["message"]


  def test_upload_rejects_unsupported_extension(client):
      response = client.post(
          "/ingest/upload",
          data={"collection_name": "default"},
-         files={"file": ("sample.csv", b"a,b\n1,2", "text/csv")},
      )

      assert response.status_code == 400
@@ -65,7 +67,7 @@ def test_upload_returns_500_on_job_creation_error(client, monkeypatch):
      response = client.post(
          "/ingest/upload",
          data={"collection_name": "default"},
-         files={"file": ("sample.txt", b"hello", "text/plain")},
      )

      assert response.status_code == 500
@@ -75,12 +77,14 @@ def test_upload_returns_500_on_job_creation_error(client, monkeypatch):
  def test_ingest_url_rejects_non_http_scheme(client, monkeypatch):
      monkeypatch.setattr(
          "api.routes.ingest._download_url_to_temp",
-         AsyncMock(side_effect=ingest_route.HTTPException(status_code=400, detail="Only http and https URLs are supported.")),
      )

      response = client.post(
          "/ingest/url",
-         json={"url": "https://example.com/file.txt", "collection_name": "default"},
      )

      assert response.status_code == 400
 
      response = client.post(
          "/ingest/upload",
          data={"collection_name": "default"},
+         files=[("files", ("sample.txt", b"hello world", "text/plain"))],
      )

      assert response.status_code == 200
      body = response.json()
      assert body["status"] == "queued"
      assert body["job_id"] == "job-123"
+     assert body["total_files"] == 1
+     assert body["filenames"] == ["sample.txt"]
+     assert "Poll /jobs/job-123" in body["message"]


  def test_upload_rejects_unsupported_extension(client):
      response = client.post(
          "/ingest/upload",
          data={"collection_name": "default"},
+         files=[("files", ("sample.csv", b"a,b\n1,2", "text/csv"))],
      )

      assert response.status_code == 400
 
      response = client.post(
          "/ingest/upload",
          data={"collection_name": "default"},
+         files=[("files", ("sample.txt", b"hello", "text/plain"))],
      )

      assert response.status_code == 500
 
  def test_ingest_url_rejects_non_http_scheme(client, monkeypatch):
      monkeypatch.setattr(
          "api.routes.ingest._download_url_to_temp",
+         AsyncMock(
+             side_effect=ingest_route.HTTPException(status_code=400, detail="Only http and https URLs are supported.")
+         ),
      )

      response = client.post(
          "/ingest/url",
+         json={"urls": ["https://example.com/file.txt"], "collection_name": "default"},
      )

      assert response.status_code == 400
tests/test_query.py CHANGED
@@ -40,20 +40,51 @@ def test_ask_returns_grounded_answer_with_sources(client, monkeypatch):
      monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
      monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
      monkeypatch.setattr("api.routes.query.retrieve_chunks", lambda *_: chunks)
-     monkeypatch.setattr("api.routes.query.answer_with_grounding", lambda *_: "Audi is expanding EV investment.")
      monkeypatch.setattr("api.routes.query.persist_query_audit", AsyncMock(return_value="evt-1"))

      response = client.post(
          "/query/ask",
-         json={"question": "What is Audi doing in EV?", "collection_name": "default"},
      )

      assert response.status_code == 200
      body = response.json()
-     assert body["status"] == "success"
      assert body["answer"] == "Audi is expanding EV investment."
      assert len(body["sources"]) == 1
-     assert body["sources"][0]["source"] == "strategy.md"


  def test_ask_returns_422_for_invalid_payload(client):
@@ -61,6 +92,14 @@ def test_ask_returns_422_for_invalid_payload(client):
      assert response.status_code == 422


  def test_ask_returns_500_when_retrieval_fails(client, monkeypatch):
      monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
      monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
@@ -68,7 +107,7 @@ def test_ask_returns_500_when_retrieval_fails(client, monkeypatch):

      response = client.post(
          "/query/ask",
-         json={"question": "What happened?", "collection_name": "default"},
      )

      assert response.status_code == 500
@@ -88,7 +127,8 @@ def test_summarise_returns_500_when_audit_persist_fails(client, monkeypatch):
      monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
      monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
      monkeypatch.setattr("api.routes.query.retrieve_chunks", lambda *_: chunks)
-     monkeypatch.setattr("api.routes.query.summarise_with_grounding", lambda *_, **__: "Summary output")
      monkeypatch.setattr(
          "api.routes.query.persist_query_audit",
          AsyncMock(side_effect=RuntimeError("audit write failed")),
@@ -96,7 +136,7 @@ def test_summarise_returns_500_when_audit_persist_fails(client, monkeypatch):

      response = client.post(
          "/query/summarise",
-         json={"collection_name": "default", "focus": "summarise risks"},
      )

      assert response.status_code == 500
 
      monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
      monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
      monkeypatch.setattr("api.routes.query.retrieve_chunks", lambda *_: chunks)
+     monkeypatch.setattr("api.routes.query.answer_with_grounding", lambda *_: ("Audi is expanding EV investment.", 42))
      monkeypatch.setattr("api.routes.query.persist_query_audit", AsyncMock(return_value="evt-1"))

      response = client.post(
          "/query/ask",
+         json={
+             "question": "What is Audi doing in EV markets worldwide?",
+             "collection_name": "default",
+             "top_k": 3,
+             "user_id": "tester",
+         },
      )

      assert response.status_code == 200
      body = response.json()
      assert body["answer"] == "Audi is expanding EV investment."
+     assert "query_id" in body
+     assert body["question"].startswith("What is Audi")
      assert len(body["sources"]) == 1
+     assert body["sources"][0]["document_name"] == "strategy.md"
+     assert body["sources"][0]["page_number"] == 1
+     assert body["tokens_used"] == 42
+     assert "response_time_ms" in body
+     assert "model_used" in body
+
+
+ def test_ask_respects_top_k_in_retrieve_call(client, monkeypatch):
+     captured: dict[str, object] = {}
+
+     def capture_retrieve(vs, question, k):
+         captured["k"] = k
+         return []
+
+     monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
+     monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
+     monkeypatch.setattr("api.routes.query.retrieve_chunks", capture_retrieve)
+     monkeypatch.setattr("api.routes.query.answer_with_grounding", lambda *_: ("No match answer", 0))
+     monkeypatch.setattr("api.routes.query.persist_query_audit", AsyncMock())
+
+     response = client.post(
+         "/query/ask",
+         json={"question": "What is known about the topic here?", "collection_name": "default", "top_k": 7},
+     )
+     assert response.status_code == 200
+     assert captured.get("k") == 7


  def test_ask_returns_422_for_invalid_payload(client):
 
      assert response.status_code == 422


+ def test_ask_returns_422_for_short_question(client):
+     response = client.post(
+         "/query/ask",
+         json={"question": "hi", "collection_name": "default"},
+     )
+     assert response.status_code == 422
+
+
  def test_ask_returns_500_when_retrieval_fails(client, monkeypatch):
      monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
      monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
 

      response = client.post(
          "/query/ask",
+         json={"question": "What happened in the documents?", "collection_name": "default"},
      )

      assert response.status_code == 500
 
      monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
      monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
      monkeypatch.setattr("api.routes.query.retrieve_chunks", lambda *_: chunks)
+     monkeypatch.setattr("api.routes.query.summarise_with_grounding", lambda *_, **__: ("Summary output", 10))
+     monkeypatch.setattr("api.routes.query.collection_document_count", lambda *_: 5)
      monkeypatch.setattr(
          "api.routes.query.persist_query_audit",
          AsyncMock(side_effect=RuntimeError("audit write failed")),
 

      response = client.post(
          "/query/summarise",
+         json={"collection_name": "default", "focus": "summarise risks", "user_id": "u1"},
      )

      assert response.status_code == 500
workers/ingest_worker.py CHANGED
@@ -5,10 +5,15 @@ from rag.chunker import chunk_documents
  from rag.embedder import create_embedding_function
  from rag.loader import load_documents
  from rag.vector_store import add_documents, get_vector_store
- from storage.job_store import update_ingest_job


- def _ingest_sync(temp_path: str, collection_name: str, chroma_persist_directory: str) -> tuple[list[str], int]:
      documents = load_documents(temp_path)
      chunks = chunk_documents(documents)
      if not chunks:
@@ -25,37 +30,72 @@ def _ingest_sync(temp_path: str, collection_name: str, chroma_persist_directory:

  async def run_ingest_job(
      job_id: str,
-     temp_path: str,
      collection_name: str,
      jobs_db_path: str,
      chroma_persist_directory: str,
  ) -> None:
      try:
-         await update_ingest_job(
-             jobs_db_path,
-             job_id,
-             status="processing",
-             message="Ingestion in progress.",
-         )
-         document_ids, num_chunks = await asyncio.to_thread(
-             _ingest_sync,
-             temp_path,
-             collection_name,
-             chroma_persist_directory,
-         )
-         await update_ingest_job(
              jobs_db_path,
              job_id,
-             status="completed",
-             message=f"Ingested {num_chunks} chunks.",
-             document_ids=document_ids,
          )
      except Exception as exc:
-         await update_ingest_job(
-             jobs_db_path,
-             job_id,
-             status="failed",
-             message=str(exc),
-         )
-     finally:
-         Path(temp_path).unlink(missing_ok=True)
 
  from rag.embedder import create_embedding_function
  from rag.loader import load_documents
  from rag.vector_store import add_documents, get_vector_store
+ from storage.job_store import (
+     complete_ingest_job,
+     fail_ingest_job,
+     mark_job_processing,
+     update_job_progress,
+ )


+ def _ingest_one_file_sync(temp_path: str, collection_name: str, chroma_persist_directory: str) -> tuple[list[str], int]:
      documents = load_documents(temp_path)
      chunks = chunk_documents(documents)
      if not chunks:
 

  async def run_ingest_job(
      job_id: str,
+     files: list[tuple[str, str]],
      collection_name: str,
      jobs_db_path: str,
      chroma_persist_directory: str,
  ) -> None:
+     """
+     Process one or more temp files for a single job. ``files`` is (temp_path, display_name).
+     """
+     all_doc_ids: list[str] = []
+     errors: list[str] = []
+     processed = 0
+     failed = 0
+     total = len(files)
+     if total == 0:
+         await fail_ingest_job(jobs_db_path, job_id, message="No files to ingest.")
+         return
+
      try:
+         await mark_job_processing(jobs_db_path, job_id)
+         for temp_path, display_name in files:
+             try:
+                 doc_ids, num_chunks = await asyncio.to_thread(
+                     _ingest_one_file_sync,
+                     temp_path,
+                     collection_name,
+                     chroma_persist_directory,
+                 )
+                 all_doc_ids.extend(doc_ids)
+                 processed += 1
+                 await update_job_progress(
+                     jobs_db_path,
+                     job_id,
+                     processed_files=processed,
+                     failed_files=failed,
+                     errors=errors,
+                     message=f"Ingested {display_name} ({num_chunks} chunks).",
+                 )
+             except Exception as exc:
+                 failed += 1
+                 errors.append(f"{display_name}: {exc}")
+                 await update_job_progress(
+                     jobs_db_path,
+                     job_id,
+                     processed_files=processed,
+                     failed_files=failed,
+                     errors=errors,
+                     message=f"Failed on {display_name}: {exc}",
+                 )
+             finally:
+                 Path(temp_path).unlink(missing_ok=True)
+
+         if processed == 0:
+             await fail_ingest_job(
+                 jobs_db_path,
+                 job_id,
+                 message="All files failed ingestion.",
+                 errors=errors,
+             )
+             return
+
+         chunk_note = f"{len(all_doc_ids)} chunk vector(s) across {processed} file(s)."
+         await complete_ingest_job(
              jobs_db_path,
              job_id,
+             document_ids=all_doc_ids,
+             message=f"Ingestion completed. {chunk_note}",
          )
      except Exception as exc:
+         await fail_ingest_job(jobs_db_path, job_id, message=str(exc), errors=errors + [str(exc)])
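
The per-file error handling that the new `run_ingest_job` introduces can be illustrated in isolation. What follows is a minimal sketch under stated assumptions, not the project's code: `ingest_one_file`, `ingest_files`, and `demo` are hypothetical names, the blocking step just counts words, and the job-store calls are replaced by an in-memory result dict. It shows the same shape as the worker loop: each file is processed off the event loop, a failure is recorded without aborting the remaining files, the temp file is always removed, and the job only fails overall when no file succeeded.

```python
import asyncio
import tempfile
from pathlib import Path


def ingest_one_file(temp_path: str) -> int:
    """Hypothetical stand-in for the blocking per-file step; returns a chunk count."""
    text = Path(temp_path).read_text(encoding="utf-8")
    if not text.strip():
        raise ValueError("no extractable text")
    return len(text.split())


async def ingest_files(files: list[tuple[str, str]]) -> dict:
    """Process (temp_path, display_name) pairs; one bad file does not stop the rest."""
    processed = 0
    failed = 0
    errors: list[str] = []
    for temp_path, display_name in files:
        try:
            # Run the blocking work in a thread so the event loop stays responsive.
            await asyncio.to_thread(ingest_one_file, temp_path)
            processed += 1
        except Exception as exc:
            failed += 1
            errors.append(f"{display_name}: {exc}")
        finally:
            # The temp file is removed whether ingestion succeeded or not.
            Path(temp_path).unlink(missing_ok=True)
    # Assumption mirroring the diff: the job fails only if nothing was ingested.
    status = "completed" if processed else "failed"
    return {"status": status, "processed": processed, "failed": failed, "errors": errors}


def demo() -> dict:
    """Ingest one readable file and one blank file and return the summary."""
    with tempfile.TemporaryDirectory() as tmp:
        good = Path(tmp) / "good.txt"
        bad = Path(tmp) / "bad.txt"
        good.write_text("hello world", encoding="utf-8")
        bad.write_text("   ", encoding="utf-8")
        return asyncio.run(ingest_files([(str(good), "good.txt"), (str(bad), "bad.txt")]))
```

Calling `demo()` yields a summary in which one file is counted as processed and the blank file appears in `errors` under its display name, which is the partial-success outcome the multi-file job is designed to report.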