Mayank Chugh commited on
Commit
d44b33d
·
0 Parent(s):

Deploy DocuAudit AI to Hugging Face Space (no binaries)

Browse files
.dockerignore ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Markdown: omit from build context except repo root README.md (Dockerfile COPY README.md).
2
+ *.md
3
+ **/*.md
4
+ !README.md
5
+
6
+ .git
7
+ .gitignore
8
+ .env
9
+ .venv
10
+ venv
11
+ __pycache__
12
+ *.py[cod]
13
+ *$py.class
14
+ .pytest_cache
15
+ .mypy_cache
16
+ .ruff_cache
17
+ *.egg-info
18
+ dist
19
+ build
20
+ .coverage
21
+ htmlcov
22
+ .DS_Store
23
+ docs
24
+ tests
25
+ .cursor
26
+ terminals
27
+ *.log
28
+ data/chroma
29
+ chroma
.env.example ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DocuAudit AI — environment template (see docs/DOCUAUDIT_AI_REQUIREMENTS.md)
2
+
3
+ # LLM Provider: ollama | anthropic | openai | huggingface
4
+ LLM_PROVIDER=ollama
5
+
6
+ # OpenAI (optional)
7
+ OPENAI_API_KEY=
8
+ OPENAI_MODEL=gpt-4o
9
+ OPENAI_EMBEDDING_MODEL=text-embedding-3-small
10
+
11
+ # Anthropic (optional)
12
+ ANTHROPIC_API_KEY=
13
+ ANTHROPIC_MODEL=claude-3-5-sonnet-20241022
14
+
15
+ # Hugging Face Inference API (when LLM_PROVIDER=huggingface — typical on Hugging Face Spaces)
16
+ # Use a fine-grained token with "Make calls to Inference Providers" / Inference API where required.
17
+ HUGGINGFACE_API_KEY=
18
+ # Use a model your Hub gates allow (e.g. Llama 3.8B under “Meta Llama 3”, or Mistral instruct). Llama 3.1 needs its own gate. Chat: hf-inference then router auto.
19
+ #HUGGINGFACE_MODEL=mistralai/Mistral-7B-Instruct-v0.3
20
+ #HUGGINGFACE_MODEL=meta-llama/Meta-Llama-3.1-8B-Instruct
21
+ HUGGINGFACE_MODEL=meta-llama/Meta-Llama-3-8B-Instruct
22
+ HUGGINGFACE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
23
+ # Optional: huggingface_hub InferenceClient provider. Leave unset: primary hf-inference, then router auto for chat (Mistral instruct ids also try Novita).
24
+ # Use `auto` for router-only primary client (may pick Novita and break some models).
25
+ HUGGINGFACE_INFERENCE_PROVIDER=
26
+ # On Hugging Face Spaces you can omit HUGGINGFACE_API_KEY if the Space provides HF_TOKEN (mapped
27
+ # automatically when LLM_PROVIDER=huggingface). For local .env you can set HF_TOKEN instead.
28
+
29
+ # Ollama (recommended local default)
30
+ OLLAMA_BASE_URL=http://localhost:11434
31
+ OLLAMA_CHAT_MODEL=llama3.1:8b
32
+ OLLAMA_EMBEDDING_MODEL=nomic-embed-text
33
+
34
+ # App
35
+ APP_NAME=DocuAudit AI
36
+ APP_VERSION=1.0.0
37
+ DEBUG=false
38
+ MAX_FILE_SIZE_MB=50
39
+ # Spec name alias (optional; mapped to MAX_FILE_SIZE_MB in settings)
40
+ MAX_UPLOAD_SIZE_MB=
41
+
42
+ # ChromaDB
43
+ CHROMA_PERSIST_DIRECTORY=./data/chroma
44
+ CHROMA_PERSIST_DIR=
45
+ CHROMA_COLLECTION_NAME=docuaudit_docs
46
+
47
+ # Chunking
48
+ CHUNK_SIZE=1000
49
+ CHUNK_OVERLAP=200
50
+
51
+ # Retrieval default (overridable per request on /query/ask via top_k)
52
+ TOP_K_RESULTS=5
53
+
54
+ # Audit + jobs SQLite
55
+ AUDIT_DB_PATH=./audit.db
56
+ JOBS_DB_PATH=./data/jobs.db
57
+
58
+ # Limits
59
+ MAX_DOCUMENTS_PER_BATCH=100
60
+
61
+ # URL ingest (POST /ingest/url). SEC.gov blocks undeclared bots — use "Company Name you@email.com".
62
+ # INGEST_USER_AGENT=DocuAudit AI you@example.com
63
+
64
+ # Streamlit → API (Streamlit process reads these when set in the shell / OS env)
65
+ STREAMLIT_BACKEND_URL=http://localhost:8000
66
+ DOC_AUDI_API_BASE=http://127.0.0.1:8000
67
+ # Read timeout (seconds) for Ask/Summarise HTTP calls; default in code is 3600 if unset
68
+ DOC_AUDI_HTTP_READ_TIMEOUT=3600
69
+
70
+ # --- Docker Compose (Milestone 12) ---
71
+ # Copy this file to `.env` before `docker compose up` (Compose loads `.env` for substitution and `env_file`).
72
+ #
73
+ # Persistent paths below are overridden in docker-compose.yml to a single volume mount at /data:
74
+ # CHROMA_PERSIST_DIRECTORY=/data/chroma, AUDIT_DB_PATH=/data/audit.db, JOBS_DB_PATH=/data/jobs.db
75
+ # You do not need to duplicate those in .env for compose unless you use a custom override file.
76
+ #
77
+ # Ollama from the API container cannot reach localhost on your machine; default in compose is:
78
+ # OLLAMA_BASE_URL=http://host.docker.internal:11434
79
+ # (extra_hosts host-gateway is set for Linux.) Run `ollama serve` on the host, or start the bundled
80
+ # Ollama service: docker compose --profile ollama up -d
81
+ # When using the compose `ollama` profile, set in .env:
82
+ # OLLAMA_BASE_URL=http://ollama:11434
83
+ #
84
+ # Compose sets DOC_AUDI_API_BASE / STREAMLIT_BACKEND_URL to http://api:8000 for the Streamlit service
85
+ # so server-side HTTP calls reach the API on the Docker network (do not override for UI in compose).
86
+ #
87
+ # Optional port overrides: API_PORT=8000, STREAMLIT_PORT=8501, OLLAMA_HOST_PORT=11434
88
+
89
+ # --- Hugging Face Spaces ---
90
+ # Recommended for CPU Spaces (no Ollama): set in Space Settings → Repository secrets → Variables
91
+ # LLM_PROVIDER=huggingface
92
+ # HUGGINGFACE_API_KEY=<token> OR rely on built-in HF_TOKEN (same value as a Hub token secret)
93
+ # HUGGINGFACE_MODEL / HUGGINGFACE_EMBEDDING_MODEL as needed
94
+ # If the API runs in a second Space or external URL, set for the Streamlit Space:
95
+ # DOC_AUDI_API_BASE=https://your-api....hf.space (or your FastAPI public URL)
96
+ # Streamlit on Spaces must listen on port 8501 (default). Entry file: app.py (see docs/HUGGING_FACE_SPACES.md).
97
+ # On Streamlit SDK Spaces, only Streamlit starts by default; app.py auto-starts uvicorn on 127.0.0.1:8000 when
98
+ # SPACE_ID is set (built-in Hub env). Set DOC_AUDI_EMBED_API=0 to disable if you use a separate API URL above.
99
+ # Repository secrets (HF_TOKEN / HUGGINGFACE_API_KEY) are copied from st.secrets into the API subprocess env.
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.11
Dockerfile ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Single image for API (uvicorn) and UI (Streamlit); compose overrides the command per service.
2
+ FROM python:3.11-slim-bookworm
3
+
4
+ ENV PYTHONDONTWRITEBYTECODE=1 \
5
+ PYTHONUNBUFFERED=1 \
6
+ PYTHONPATH=/app \
7
+ PIP_NO_CACHE_DIR=1 \
8
+ ANONYMIZED_TELEMETRY=FALSE
9
+
10
+ WORKDIR /app
11
+
12
+ # PyMuPDF / scientific wheels are manylinux; minimal OS deps for SSL and fonts used by PDF tooling.
13
+ RUN apt-get update && apt-get install -y --no-install-recommends \
14
+ ca-certificates \
15
+ && rm -rf /var/lib/apt/lists/*
16
+
17
+ COPY requirements.txt .
18
+ RUN pip install --upgrade pip && pip install -r requirements.txt
19
+
20
+ COPY api/ api/
21
+ COPY models/ models/
22
+ COPY rag/ rag/
23
+ COPY storage/ storage/
24
+ COPY workers/ workers/
25
+ COPY app.py streamlit_app.py main.py pyproject.toml README.md ./
26
+
27
+ EXPOSE 8000 8501
28
+
29
+ CMD ["uvicorn", "api.main:app", "--host", "0.0.0.0", "--port", "8000"]
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Document-Audit RAG
3
+ emoji: 📑
4
+ colorFrom: blue
5
+ colorTo: indigo
6
+ sdk: streamlit
7
+ sdk_version: "1.39.0"
8
+ app_file: app.py
9
+ ---
10
+
11
+ # DocuAudit AI
12
+
13
+ **DocuAudit AI** is a production-oriented FastAPI backend plus optional Streamlit UI for **multi-document RAG**: upload documents, build a Chroma vector index, ask grounded questions with citations, and retain a **SQLite audit trail** of every query.
14
+
15
+ ## Architecture
16
+
17
+ ```mermaid
18
+ flowchart LR
19
+ subgraph ingest [Ingestion]
20
+ A[PDF / TXT / MD] --> B[Loader]
21
+ B --> C[Chunker]
22
+ C --> D[Embedder]
23
+ D --> E[(ChromaDB)]
24
+ end
25
+ subgraph query [Query path]
26
+ Q[User question] --> R[Semantic search]
27
+ R --> E
28
+ R --> T[Top-K chunks]
29
+ T --> L[LLM]
30
+ L --> U[Answer + citations]
31
+ end
32
+ U --> V[(SQLite audit)]
33
+ ```
34
+
35
+ ASCII equivalent:
36
+
37
+ ```
38
+ PDF Upload → Parser → Chunker → Embedder → ChromaDB
39
+
40
+ User Query → Semantic Search → Top-K Chunks → LLM → Answer + Citations
41
+
42
+ Audit Log (SQLite)
43
+ ```
44
+
45
+ ## Use cases
46
+
47
+ - **Litigation document analysis** — trace claims to exact pages and filenames.
48
+ - **Corporate finance review** — compare disclosures and filings under a consistent audit log.
49
+ - **Investigation support** — bulk ingest, async jobs, and reproducible query history.
50
+
51
+ ## Deploying on Hugging Face Spaces
52
+
53
+ - Set **`LLM_PROVIDER=huggingface`**; use **`HUGGINGFACE_API_KEY`** and/or the Space secret **`HF_TOKEN`** (see [`.env.example`](.env.example)).
54
+ - Use root **`app.py`** as the Streamlit entry for the default Hub command.
55
+ - Hub UI, secrets, hardware, and Streamlit SDK details: [Streamlit Spaces](https://huggingface.co/docs/hub/spaces-sdks-streamlit), [Spaces overview](https://huggingface.co/docs/hub/spaces-overview).
56
+ - **Test locally before deploy:** `uv run python scripts/verify_huggingface_inference.py` (requires `LLM_PROVIDER=huggingface` in `.env`).
57
+
58
+ ## Quick start with Docker
59
+
60
+ Requires [Docker Engine](https://docs.docker.com/engine/) and Compose v2. The snippet below matches the shipped **`docker-compose.yml`**: API on **8000**, Streamlit on **8501**, with Chroma and SQLite under **`/data`** inside the API container. After **`docker compose up -d`**, expect **`curl http://localhost:8000/health`** to return JSON including **`"status":"ok"`**.
61
+
62
+ ```bash
63
+ git clone <repository-url> doc-Audi-ai
64
+ cd doc-Audi-ai
65
+ cp .env.example .env
66
+ # edit .env as needed; for compose Ollama: OLLAMA_BASE_URL=http://ollama:11434
67
+ # (with host Ollama: run `ollama serve`; compose defaults to host.docker.internal:11434)
68
+
69
+ docker compose build
70
+ docker compose up -d
71
+ curl -s http://localhost:8000/health
72
+ # http://localhost:8501 — Streamlit
73
+ docker compose down
74
+ ```
75
+
76
+ Optional all-in-one Ollama in Compose: `docker compose --profile ollama up -d` (then set `OLLAMA_BASE_URL=http://ollama:11434` in `.env` and recreate containers).
77
+
78
+ ## How it works (user workflow)
79
+
80
+ Collections, ingestion vs querying, jobs vs audit, Streamlit tabs, and **per-button UI flows**: **[docs/USER_WORKFLOW.md](docs/USER_WORKFLOW.md)**.
81
+
82
+ ## Run and test (step-by-step)
83
+
84
+ For ingestion formats, URL rules, job polling, sample `sample.txt` walkthrough, curl/PowerShell examples, and troubleshooting, see **[docs/RUN_AND_TEST_GUIDE.md](docs/RUN_AND_TEST_GUIDE.md)**.
85
+
86
+ For SQLite vs Memcached, offline DB inspection, and the Cursor **SQLite Viewer** extension (`qwtel.sqlite-viewer`), see **[docs/SQLITE_AND_DB_INSPECTION.md](docs/SQLITE_AND_DB_INSPECTION.md)**.
87
+
88
+ ## Quick start (local, without Docker)
89
+
90
+ Run the API with **uv** (or your preferred tool):
91
+
92
+ ```bash
93
+ git clone <repository-url> doc-Audi-ai
94
+ cd doc-Audi-ai
95
+ cp .env.example .env
96
+ uv sync
97
+ ollama pull llama3.1:8b
98
+ ollama pull nomic-embed-text
99
+ uv run uvicorn api.main:app --host 0.0.0.0 --port 8000 --reload
100
+
101
+ uv run uvicorn api.main:app --host 0.0.0.0 --port 8000 --reload --reload-dir api --reload-dir storage
102
+ ```
103
+
104
+ Optional UI:
105
+
106
+ ```bash
107
+ uv run streamlit run streamlit_app.py --server.port 8501 --server.address 0.0.0.0
108
+ ```
109
+
110
+ ## API overview
111
+
112
+ | Method | Path | Description |
113
+ |--------|------|-------------|
114
+ | GET | `/health` | Liveness; returns configured app name and version |
115
+ | POST | `/ingest/upload` | Multipart **`files`** (one or more); queues background ingest job |
116
+ | POST | `/ingest/url` | JSON **`urls`** array (1–100); download and queue ingest |
117
+ | GET | `/ingest/collections` | Lists collections with **`document_count`** and optional **`created_at`** |
118
+ | DELETE | `/ingest/collection/{collection_name}` | Drops a collection; returns **`documents_removed`** |
119
+ | GET | `/jobs` | Lists jobs with **`total`** count |
120
+ | GET | `/jobs/{job_id}` | Job status with **`progress_percent`**, file counters, timestamps, **`errors`** |
121
+ | POST | `/query/ask` | Grounded answer; request includes **`top_k`**, **`user_id`** |
122
+ | POST | `/query/summarise` | Collection summary; distinct response shape (`summary`, `document_count`, …) |
123
+ | POST | `/query` | Legacy alias of **`/query/ask`** |
124
+ | GET | `/audit/logs` | Filterable audit index (`user_id`, `from_date`, `to_date`, pagination) |
125
+ | GET | `/audit/logs/{query_id}` | Full stored answer and citations for one query |
126
+
127
+ Interactive docs: `http://localhost:8000/docs`.
128
+
129
+ ## Sample request and response (`POST /query/ask`)
130
+
131
+ Request:
132
+
133
+ ```json
134
+ {
135
+ "question": "What were the key risk factors identified in the Q3 2023 financial report?",
136
+ "collection_name": "default",
137
+ "top_k": 5,
138
+ "user_id": "analyst_001"
139
+ }
140
+ ```
141
+
142
+ Response (shape; values depend on your documents and model):
143
+
144
+ ```json
145
+ {
146
+ "query_id": "uuid-string",
147
+ "question": "What were the key risk factors identified in the Q3 2023 financial report?",
148
+ "answer": "… grounded text with citations …",
149
+ "sources": [
150
+ {
151
+ "document_name": "q3_financial_report.pdf",
152
+ "page_number": 12,
153
+ "chunk_text": "Key risk factors include …",
154
+ "relevance_score": 0.91
155
+ }
156
+ ],
157
+ "model_used": "llama3.1:8b",
158
+ "tokens_used": 0,
159
+ "response_time_ms": 1820,
160
+ "timestamp": "2026-05-03T12:00:00Z"
161
+ }
162
+ ```
163
+
164
+ ## Design decisions
165
+
166
+ - **Source citations** — High-stakes review requires every substantive claim to be tied to **document name** and **page** (where available), not a free-floating model monologue.
167
+ - **Auditability** — Each ask/summarise persists **query id**, **user id**, timing, model id, token usage (when the provider exposes it), and serialized sources so regulators or counsel can reconstruct what the system returned.
168
+
169
+ ## Scale note
170
+
171
+ Architecture is designed for **high-volume document ingestion** via **async background jobs** (FastAPI `BackgroundTasks`), persistent Chroma collections, and a stateless API tier that can be replicated once you add a shared vector store and job queue.
172
+
173
+ ## Tests
174
+
175
+ Automated API tests use **pytest** with isolated temp databases; they do **not** require a running server or Ollama.
176
+
177
+ ```bash
178
+ uv sync
179
+ uv run pytest tests/ -q
180
+ ```
181
+
182
+ Full guide (commands, coverage by file, mocks vs manual smoke tests, troubleshooting): **[docs/TESTING.md](docs/TESTING.md)**.
183
+
184
+ ## Configuration
185
+
186
+ See **`.env.example`**. Common variables include `LLM_PROVIDER`, Ollama/OpenAI/Anthropic keys and models, `CHROMA_PERSIST_DIRECTORY`, `AUDIT_DB_PATH`, `JOBS_DB_PATH`, and upload limits (`MAX_FILE_SIZE_MB`; **`MAX_UPLOAD_SIZE_MB`** is accepted as an alias via settings normalization).
api/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """HTTP API package: FastAPI app, settings, and route modules."""
api/config.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Application configuration loaded from environment variables and ``.env``.
2
+
3
+ ``Settings`` is the single source of truth for LLM provider choice, Chroma paths,
4
+ chunking limits, upload caps, and SQLite locations. Use :func:`get_settings` (cached)
5
+ from route handlers and RAG modules instead of reading ``os.environ`` directly.
6
+ """
7
+
8
+ import os
9
+ from functools import lru_cache
10
+ from typing import Any, Self
11
+
12
+ from pydantic import Field, model_validator
13
+
14
+ from pydantic_settings import BaseSettings, SettingsConfigDict
15
+
16
+
17
+ class Settings(BaseSettings):
18
+ """Pydantic-settings model for DocuAudit AI; fields map to env vars (case-insensitive)."""
19
+
20
+ model_config = SettingsConfigDict(
21
+ env_file=".env",
22
+ env_file_encoding="utf-8",
23
+ extra="ignore",
24
+ case_sensitive=False,
25
+ populate_by_name=True,
26
+ )
27
+
28
+ @model_validator(mode="before")
29
+ @classmethod
30
+ def _map_max_upload_env_alias(cls, data: Any) -> Any:
31
+ if not isinstance(data, dict):
32
+ return data
33
+ out = dict(data)
34
+ if out.get("max_file_size_mb") in (None, "") and out.get("max_upload_size_mb") not in (None, ""):
35
+ out["max_file_size_mb"] = out.pop("max_upload_size_mb")
36
+ elif "max_upload_size_mb" in out and "max_file_size_mb" not in out:
37
+ out["max_file_size_mb"] = out.pop("max_upload_size_mb")
38
+ return out
39
+
40
+ app_name: str = Field(default="DocuAudit AI", description="FastAPI title and product name")
41
+ app_version: str = Field(default="1.0.0", description="Application version")
42
+ app_description: str = Field(
43
+ default=(
44
+ "Multi-document RAG API for high-stakes consulting environments. "
45
+ "Every answer is grounded in source documents with full audit trails."
46
+ ),
47
+ description="OpenAPI /docs description",
48
+ )
49
+ llm_provider: str = Field(default="ollama", description="Embedding provider")
50
+
51
+ openai_api_key: str | None = Field(default=None, description="OpenAI API key")
52
+ openai_model: str = "gpt-4o"
53
+ openai_embedding_model: str = "text-embedding-3-small"
54
+
55
+ anthropic_api_key: str = ""
56
+ anthropic_model: str = "claude-3-5-sonnet-20241022"
57
+
58
+ huggingface_api_key: str = ""
59
+ huggingface_model: str = Field(
60
+ default="meta-llama/Meta-Llama-3-8B-Instruct",
61
+ description=(
62
+ "HF chat model id (use a repo your Hub account already has access to; Llama 3.1 needs the "
63
+ "separate Llama 3.1 gate). Chat tries hf-inference then router auto when unset."
64
+ ),
65
+ )
66
+ huggingface_embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
67
+ huggingface_inference_provider: str | None = Field(
68
+ default=None,
69
+ description=(
70
+ "Optional huggingface_hub InferenceClient provider (e.g. hf-inference, together). "
71
+ "Unset uses hf-inference in chat code; set to `auto` for router auto-routing."
72
+ ),
73
+ )
74
+
75
+ ollama_base_url: str = Field(default="http://localhost:11434", description="Ollama base URL")
76
+ ollama_chat_model: str = "llama3.1:8b"
77
+ ollama_embedding_model: str = "nomic-embed-text"
78
+
79
+ chroma_persist_directory: str = Field(default="./data/chroma", description="Chroma persistence path")
80
+
81
+ chroma_persist_dir: str = Field(default="./chroma", description="Chroma persistence path")
82
+ chroma_collection_name: str = "docuaudit_docs"
83
+
84
+ chunk_size: int = Field(default=1000, ge=100, le=8000, description="Chunk size for splitting")
85
+ chunk_overlap: int = Field(default=200, ge=0, le=2000, description="Chunk overlap for splitting")
86
+ top_k_results: int = Field(default=5, ge=1, le=20, description="Default number of chunks to retrieve")
87
+
88
+ audit_db_path: str = "./audit.db"
89
+ jobs_db_path: str = Field(default="./data/jobs.db", description="SQLite path for ingest job tracking")
90
+
91
+ max_file_size_mb: int = Field(default=50, ge=1, le=200, description="Max upload file size (MB)")
92
+ max_documents_per_batch: int = Field(default=100, ge=1, le=1000, description="Max documents per batch")
93
+ ingest_user_agent: str = Field(
94
+ default="DocuAudit AI docuaudit-ingest@example.com",
95
+ description=(
96
+ "HTTP User-Agent for POST /ingest/url downloads. SEC.gov requires "
97
+ "'Company Name contact@email.com' with a reachable address (see sec.gov/os/accessing-edgar-data)."
98
+ ),
99
+ )
100
+
101
+ @model_validator(mode="after")
102
+ def _space_default_llm_provider(self) -> Self:
103
+ """Hugging Face Spaces do not run Ollama locally; use Hub inference unless the user set LLM_PROVIDER."""
104
+ if not (os.environ.get("SPACE_ID") or "").strip():
105
+ return self
106
+ if "LLM_PROVIDER" in os.environ:
107
+ return self
108
+ if self.llm_provider.lower() != "ollama":
109
+ return self
110
+ self.llm_provider = "huggingface"
111
+ return self
112
+
113
+ @model_validator(mode="after")
114
+ def _huggingface_token_from_hub_env(self) -> Self:
115
+ """When using the Hugging Face inference stack, accept the Hub token from standard env names.
116
+
117
+ Spaces often expose `HF_TOKEN` (read/write per Space secrets). Map it into `huggingface_api_key`
118
+ when `HUGGINGFACE_API_KEY` is unset so embedder/chat clients receive a token.
119
+ """
120
+ if self.llm_provider.lower() != "huggingface":
121
+ return self
122
+ if (self.huggingface_api_key or "").strip():
123
+ return self
124
+ for key in ("HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"):
125
+ token = (os.environ.get(key) or "").strip()
126
+ if token:
127
+ self.huggingface_api_key = token
128
+ break
129
+ return self
130
+
131
+
132
+ @lru_cache
133
+ def get_settings() -> Settings:
134
+ """Return the process-wide settings singleton (cleared in tests via ``cache_clear()``)."""
135
+ return Settings()
api/main.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI application entry point for DocuAudit AI.
2
+
3
+ Creates the ASGI app, registers CORS, mounts route modules (ingest, query, jobs, audit),
4
+ and initializes SQLite audit and job stores on startup.
5
+
6
+ Run locally::
7
+
8
+ uv run uvicorn api.main:app --host 0.0.0.0 --port 8000 --reload
9
+
10
+ Health check: ``GET /health``.
11
+ """
12
+
13
+ import os
14
+
15
+ # Before any route imports that touch Chroma: disable product telemetry (avoids posthog capture() errors in logs).
16
+ os.environ.setdefault("ANONYMIZED_TELEMETRY", "FALSE")
17
+
18
+ from fastapi import FastAPI
19
+ from fastapi.middleware.cors import CORSMiddleware
20
+
21
+ from api.config import get_settings
22
+ from storage.audit_store import init_audit_db
23
+ from storage.job_store import init_jobs_db
24
+ from .routes import audit, ingest, jobs, query
25
+
26
+ _settings = get_settings()
27
+ app = FastAPI(
28
+ title=_settings.app_name,
29
+ version=_settings.app_version,
30
+ description=_settings.app_description,
31
+ )
32
+
33
+ app.add_middleware(
34
+ CORSMiddleware,
35
+ allow_origins=["*"],
36
+ allow_credentials=True,
37
+ allow_methods=["*"],
38
+ allow_headers=["*"],
39
+ )
40
+
41
+ app.include_router(audit.router)
42
+ app.include_router(ingest.router)
43
+ app.include_router(jobs.router)
44
+ app.include_router(query.router)
45
+ app.include_router(query.legacy_query_router)
46
+
47
+
48
+ @app.on_event("startup")
49
+ async def startup() -> None:
50
+ """Ensure audit and ingest-job SQLite schemas exist before serving traffic."""
51
+ settings = get_settings()
52
+ await init_audit_db(settings.audit_db_path)
53
+ await init_jobs_db(settings.jobs_db_path)
54
+
55
+
56
+ @app.get("/health", tags=["Health"])
57
+ def health() -> dict[str, str]:
58
+ """Liveness probe returning app name, version, and ``status: ok``."""
59
+ settings = get_settings()
60
+ return {
61
+ "status": "ok",
62
+ "app": settings.app_name,
63
+ "version": settings.app_version,
64
+ }
api/routes/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """FastAPI routers grouped by domain: ingest, query, jobs, and audit."""
api/routes/audit.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Query audit log HTTP routes.
2
+
3
+ Every successful ask/summarise call writes to SQLite via :mod:`storage.audit_store`.
4
+ These endpoints expose paginated list and per-query detail for compliance review.
5
+ """
6
+
7
+ from typing import Annotated
8
+
9
+ from fastapi import APIRouter, Depends, HTTPException, Query, status
10
+
11
+ from api.config import get_settings
12
+ from models.requests import AuditListParams
13
+ from models.responses import AuditLogDetailResponse, AuditLogsResponse
14
+ from storage.audit_store import get_audit_event, list_audit_events
15
+
16
+
17
+ def _audit_list_params(
18
+ limit: Annotated[int, Query(ge=1, le=100)] = 50,
19
+ offset: Annotated[int, Query(ge=0)] = 0,
20
+ user_id: Annotated[str | None, Query(max_length=256)] = None,
21
+ from_date: Annotated[str | None, Query(description="ISO 8601 lower bound")] = None,
22
+ to_date: Annotated[str | None, Query(description="ISO 8601 upper bound")] = None,
23
+ ) -> AuditListParams:
24
+ return AuditListParams(
25
+ limit=limit,
26
+ offset=offset,
27
+ user_id=user_id,
28
+ from_date=from_date,
29
+ to_date=to_date,
30
+ )
31
+
32
+
33
+ router = APIRouter(prefix="/audit", tags=["audit"])
34
+
35
+
36
+ @router.get("/logs", response_model=AuditLogsResponse)
37
+ async def audit_logs(
38
+ params: Annotated[AuditListParams, Depends(_audit_list_params)],
39
+ ) -> AuditLogsResponse:
40
+ """Paginated audit trail with optional user and date filters."""
41
+ settings = get_settings()
42
+ logs, total = await list_audit_events(
43
+ settings.audit_db_path,
44
+ limit=params.limit,
45
+ offset=params.offset,
46
+ user_id=params.user_id,
47
+ from_date=params.from_date,
48
+ to_date=params.to_date,
49
+ )
50
+ return AuditLogsResponse(
51
+ logs=logs,
52
+ total=total,
53
+ limit=params.limit,
54
+ offset=params.offset,
55
+ )
56
+
57
+
58
+ @router.get("/logs/{query_id}", response_model=AuditLogDetailResponse)
59
+ async def audit_log_detail(query_id: str) -> AuditLogDetailResponse:
60
+ """Full answer and citations for one audited query."""
61
+ settings = get_settings()
62
+ event = await get_audit_event(settings.audit_db_path, query_id)
63
+ if event is None:
64
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Audit event not found.")
65
+ return event
api/routes/ingest.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Document ingestion HTTP routes.
2
+
3
+ Endpoints under ``/ingest`` queue background jobs that load PDF/TXT/MD files (upload or URL),
4
+ chunk and embed them, and write vectors into a named Chroma collection. Poll ``/jobs/{id}``
5
+ for progress. Collection listing and deletion are synchronous.
6
+ """
7
+
8
+ from datetime import datetime, timezone
9
+ from pathlib import Path
10
+ from tempfile import NamedTemporaryFile
11
+ from typing import Annotated
12
+ from urllib.parse import unquote, urlparse
13
+
14
+ import httpx
15
+ from fastapi import APIRouter, BackgroundTasks, File, Form, HTTPException, UploadFile, status
16
+
17
+ from api.config import get_settings
18
+ from models.requests import URLIngestRequest
19
+ from models.responses import (
20
+ CollectionItem,
21
+ IngestCollectionsResponse,
22
+ IngestDeleteCollectionResponse,
23
+ IngestUploadResponse,
24
+ UrlIngestResponse,
25
+ )
26
+ from rag.vector_store import (
27
+ collection_created_at,
28
+ collection_document_count,
29
+ delete_collection,
30
+ ensure_collection_created_at,
31
+ list_collection_names,
32
+ )
33
+ from storage.job_store import create_ingest_job, earliest_job_created_at_for_collection
34
+ from workers.ingest_worker import run_ingest_job
35
+
36
+ router = APIRouter(prefix="/ingest", tags=["ingest"])
37
+
38
+ _SUPPORTED_EXTENSIONS = frozenset({".pdf", ".txt", ".md"})
39
+
40
+ _CONTENT_TYPE_SUFFIX: dict[str, str] = {
41
+ "application/pdf": ".pdf",
42
+ "text/plain": ".txt",
43
+ "text/markdown": ".md",
44
+ "text/x-markdown": ".md",
45
+ }
46
+
47
+
48
+ def _validate_file(file: UploadFile, max_bytes: int) -> str:
49
+ """Check extension and size; return normalized suffix (e.g. ``.pdf``)."""
50
+ filename = (file.filename or "").strip()
51
+ if not filename:
52
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Filename is required.")
53
+
54
+ suffix = Path(filename).suffix.lower()
55
+ if suffix not in _SUPPORTED_EXTENSIONS:
56
+ raise HTTPException(
57
+ status_code=status.HTTP_400_BAD_REQUEST,
58
+ detail="Unsupported file type. Only PDF, TXT, and MD are accepted.",
59
+ )
60
+
61
+ file.file.seek(0, 2)
62
+ size = file.file.tell()
63
+ file.file.seek(0)
64
+ if size > max_bytes:
65
+ raise HTTPException(
66
+ status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
67
+ detail=f"File too large. Max allowed is {max_bytes // (1024 * 1024)}MB.",
68
+ )
69
+
70
+ return suffix
71
+
72
+
73
+ def _suffix_from_url_path(url: str) -> str | None:
74
+ path = urlparse(url).path
75
+ suffix = Path(unquote(path)).suffix.lower()
76
+ return suffix if suffix in _SUPPORTED_EXTENSIONS else None
77
+
78
+
79
+ def _suffix_from_content_type(content_type: str | None) -> str | None:
80
+ if not content_type:
81
+ return None
82
+ base = content_type.split(";")[0].strip().lower()
83
+ return _CONTENT_TYPE_SUFFIX.get(base)
84
+
85
+
86
+ def _download_request_headers(user_agent: str) -> dict[str, str]:
87
+ """Headers for remote URL fetches (SEC.gov requires declared User-Agent + Accept-Encoding)."""
88
+ return {
89
+ "User-Agent": user_agent.strip() or "DocuAudit AI docuaudit-ingest@example.com",
90
+ "Accept-Encoding": "gzip, deflate",
91
+ "Accept": "application/pdf,text/plain,text/markdown,*/*;q=0.8",
92
+ }
93
+
94
+
95
+ def _display_name_from_url(url: str, suffix: str) -> str:
96
+ name = Path(unquote(urlparse(url).path)).name.strip()
97
+ if not name or name in {"/", "."}:
98
+ return f"download{suffix}"
99
+ if Path(name).suffix.lower() not in _SUPPORTED_EXTENSIONS:
100
+ return f"{name}{suffix}" if not name.endswith(suffix) else name
101
+ return name
102
+
103
+
104
+ async def _download_url_to_temp(url: str, max_bytes: int, user_agent: str | None = None) -> tuple[str, str]:
105
+ """Stream-download a URL to a temp file; return ``(path, display_name)``."""
106
+ parsed = urlparse(url)
107
+ if parsed.scheme not in ("http", "https"):
108
+ raise HTTPException(
109
+ status_code=status.HTTP_400_BAD_REQUEST,
110
+ detail="Only http and https URLs are supported.",
111
+ )
112
+
113
+ ua = user_agent or get_settings().ingest_user_agent
114
+ timeout = httpx.Timeout(60.0, connect=10.0)
115
+ limits = httpx.Limits(max_keepalive_connections=5, max_connections=5)
116
+ headers = _download_request_headers(ua)
117
+
118
+ try:
119
+ async with httpx.AsyncClient(timeout=timeout, limits=limits, follow_redirects=True) as client:
120
+ async with client.stream("GET", url, headers=headers) as response:
121
+ response.raise_for_status()
122
+ content_type = response.headers.get("content-type")
123
+ suffix = _suffix_from_url_path(url) or _suffix_from_content_type(content_type)
124
+ if not suffix:
125
+ raise HTTPException(
126
+ status_code=status.HTTP_400_BAD_REQUEST,
127
+ detail=(
128
+ "Could not determine file type from the URL path or Content-Type. "
129
+ "Provide a .pdf, .txt, or .md resource with matching content-type."
130
+ ),
131
+ )
132
+
133
+ display_name = _display_name_from_url(url, suffix)
134
+ total = 0
135
+ with NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
136
+ temp_path = tmp.name
137
+ async for chunk in response.aiter_bytes(chunk_size=65536):
138
+ total += len(chunk)
139
+ if total > max_bytes:
140
+ raise HTTPException(
141
+ status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
142
+ detail=f"Download too large. Max allowed is {max_bytes // (1024 * 1024)}MB.",
143
+ )
144
+ tmp.write(chunk)
145
+ except HTTPException:
146
+ raise
147
+ except httpx.HTTPStatusError as exc:
148
+ code = exc.response.status_code if exc.response else "unknown"
149
+ detail = f"Remote server returned HTTP {code}."
150
+ if code == 403 and "sec.gov" in parsed.netloc.lower():
151
+ detail += (
152
+ " SEC.gov requires a declared User-Agent ('Company Name you@email.com'). "
153
+ "Set INGEST_USER_AGENT in .env (see sec.gov/os/accessing-edgar-data)."
154
+ )
155
+ raise HTTPException(
156
+ status_code=status.HTTP_502_BAD_GATEWAY,
157
+ detail=detail,
158
+ ) from exc
159
+ except httpx.RequestError as exc:
160
+ raise HTTPException(
161
+ status_code=status.HTTP_502_BAD_GATEWAY,
162
+ detail=f"Failed to download URL: {exc}",
163
+ ) from exc
164
+
165
+ return temp_path, display_name
166
+
167
+
168
+ def _parse_created_at(raw: str | None) -> datetime | None:
169
+ if not raw:
170
+ return None
171
+ s = raw.strip()
172
+ if s.endswith("Z"):
173
+ s = s[:-1] + "+00:00"
174
+ try:
175
+ dt = datetime.fromisoformat(s)
176
+ if dt.tzinfo is None:
177
+ return dt.replace(tzinfo=timezone.utc)
178
+ return dt
179
+ except ValueError:
180
+ return None
181
+
182
+
183
+ @router.post("/upload", response_model=IngestUploadResponse)
184
+ async def upload_endpoint(
185
+ background_tasks: BackgroundTasks,
186
+ files: list[UploadFile] = File(..., description="One or more PDF, TXT, or MD files"),
187
+ collection_name: Annotated[str, Form(min_length=1, max_length=256)] = "default",
188
+ ) -> IngestUploadResponse:
189
+ """Accept multipart file uploads, validate, and queue a background ingest job."""
190
+ settings = get_settings()
191
+ if not files:
192
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="At least one file is required.")
193
+ if len(files) > settings.max_documents_per_batch:
194
+ raise HTTPException(
195
+ status_code=status.HTTP_400_BAD_REQUEST,
196
+ detail=f"Too many files in one request (max {settings.max_documents_per_batch}).",
197
+ )
198
+
199
+ max_bytes = settings.max_file_size_mb * 1024 * 1024
200
+ temp_paths: list[tuple[str, str]] = []
201
+ filenames: list[str] = []
202
+ try:
203
+ for file in files:
204
+ suffix = _validate_file(file, max_bytes)
205
+ display_name = (file.filename or "upload").strip()
206
+ file_bytes = await file.read()
207
+ with NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
208
+ tmp.write(file_bytes)
209
+ temp_paths.append((tmp.name, display_name))
210
+ filenames.append(display_name)
211
+ await file.close()
212
+
213
+ job_id = await create_ingest_job(
214
+ settings.jobs_db_path,
215
+ collection_name=collection_name.strip(),
216
+ filenames=filenames,
217
+ )
218
+
219
+ background_tasks.add_task(
220
+ run_ingest_job,
221
+ job_id,
222
+ temp_paths,
223
+ collection_name.strip(),
224
+ settings.jobs_db_path,
225
+ settings.chroma_persist_directory,
226
+ )
227
+
228
+ return IngestUploadResponse(
229
+ job_id=job_id,
230
+ status="queued",
231
+ total_files=len(filenames),
232
+ filenames=filenames,
233
+ message=f"Documents queued for processing. Poll /jobs/{job_id} for status.",
234
+ )
235
+ except HTTPException:
236
+ for path, _ in temp_paths:
237
+ Path(path).unlink(missing_ok=True)
238
+ raise
239
+ except Exception as exc:
240
+ for path, _ in temp_paths:
241
+ Path(path).unlink(missing_ok=True)
242
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
243
+
244
+
245
+ @router.post("/url", response_model=UrlIngestResponse)
246
+ async def ingest_url_endpoint(
247
+ background_tasks: BackgroundTasks,
248
+ payload: URLIngestRequest,
249
+ ) -> UrlIngestResponse:
250
+ """Download one or more HTTP(S) documents and queue them for ingestion."""
251
+ settings = get_settings()
252
+ max_bytes = settings.max_file_size_mb * 1024 * 1024
253
+ url_strings = [str(u).strip() for u in payload.urls]
254
+ if len(url_strings) > settings.max_documents_per_batch:
255
+ raise HTTPException(
256
+ status_code=status.HTTP_400_BAD_REQUEST,
257
+ detail=f"Too many URLs in one request (max {settings.max_documents_per_batch}).",
258
+ )
259
+
260
+ downloaded: list[tuple[str, str]] = []
261
+ try:
262
+ for url_str in url_strings:
263
+ temp_path, display_name = await _download_url_to_temp(
264
+ url_str, max_bytes, user_agent=settings.ingest_user_agent
265
+ )
266
+ downloaded.append((temp_path, display_name))
267
+
268
+ coll = (payload.collection_name or "default").strip()
269
+ job_id = await create_ingest_job(
270
+ settings.jobs_db_path,
271
+ collection_name=coll,
272
+ filenames=[name for _, name in downloaded],
273
+ )
274
+
275
+ background_tasks.add_task(
276
+ run_ingest_job,
277
+ job_id,
278
+ downloaded,
279
+ coll,
280
+ settings.jobs_db_path,
281
+ settings.chroma_persist_directory,
282
+ )
283
+
284
+ return UrlIngestResponse(
285
+ job_id=job_id,
286
+ status="queued",
287
+ total_urls=len(downloaded),
288
+ message="URLs queued for download and processing.",
289
+ )
290
+ except HTTPException:
291
+ for path, _ in downloaded:
292
+ Path(path).unlink(missing_ok=True)
293
+ raise
294
+ except Exception as exc:
295
+ for path, _ in downloaded:
296
+ Path(path).unlink(missing_ok=True)
297
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
298
+
299
+
300
+ @router.get("/collections", response_model=IngestCollectionsResponse)
301
+ async def list_collections_endpoint() -> IngestCollectionsResponse:
302
+ """List Chroma collections with document counts and creation timestamps."""
303
+ settings = get_settings()
304
+ try:
305
+ names = list_collection_names(settings.chroma_persist_directory)
306
+ except Exception as exc:
307
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
308
+ items: list[CollectionItem] = []
309
+ for n in names:
310
+ cnt = collection_document_count(settings.chroma_persist_directory, n)
311
+ raw_created = collection_created_at(settings.chroma_persist_directory, n)
312
+ if not raw_created:
313
+ job_fallback = await earliest_job_created_at_for_collection(settings.jobs_db_path, n)
314
+ raw_created = ensure_collection_created_at(
315
+ settings.chroma_persist_directory,
316
+ n,
317
+ fallback=job_fallback,
318
+ )
319
+ items.append(
320
+ CollectionItem(
321
+ name=n,
322
+ document_count=cnt,
323
+ created_at=_parse_created_at(raw_created),
324
+ )
325
+ )
326
+ return IngestCollectionsResponse(collections=items, total=len(items))
327
+
328
+
329
+ @router.delete("/collection/{collection_name}", response_model=IngestDeleteCollectionResponse)
330
+ async def delete_collection_endpoint(collection_name: str) -> IngestDeleteCollectionResponse:
331
+ """Remove a Chroma collection and all embedded chunks."""
332
+ if not collection_name.strip():
333
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="collection_name is required.")
334
+ settings = get_settings()
335
+ name = collection_name.strip()
336
+ try:
337
+ existing = list_collection_names(settings.chroma_persist_directory)
338
+ if name not in existing:
339
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Collection not found.")
340
+ removed = delete_collection(settings.chroma_persist_directory, name)
341
+ except HTTPException:
342
+ raise
343
+ except Exception as exc:
344
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
345
+ return IngestDeleteCollectionResponse(
346
+ message=f"Collection '{name}' deleted successfully.",
347
+ documents_removed=removed,
348
+ )
api/routes/jobs.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Ingest job status and listing.
2
+
3
+ Jobs are created by upload/URL ingest routes and updated by :mod:`workers.ingest_worker`.
4
+ """
5
+
6
+ from typing import Annotated
7
+
8
+ from fastapi import APIRouter, Depends, HTTPException, Query, status
9
+
10
+ from api.config import get_settings
11
+ from models.requests import JobsListParams
12
+ from models.responses import JobListResponse, JobStatusResponse
13
+ from storage.job_store import get_job_status, list_ingest_jobs
14
+
15
+
16
+ def _jobs_list_params(
17
+ limit: Annotated[int, Query(ge=1, le=100)] = 10,
18
+ offset: Annotated[int, Query(ge=0)] = 0,
19
+ ) -> JobsListParams:
20
+ return JobsListParams(limit=limit, offset=offset)
21
+
22
+
23
+ router = APIRouter(tags=["jobs"])
24
+
25
+
26
+ @router.get("/jobs", response_model=JobListResponse)
27
+ async def list_jobs(
28
+ params: Annotated[JobsListParams, Depends(_jobs_list_params)],
29
+ ) -> JobListResponse:
30
+ """Paginated list of ingest jobs (newest first)."""
31
+ settings = get_settings()
32
+ jobs, total = await list_ingest_jobs(
33
+ settings.jobs_db_path,
34
+ limit=params.limit,
35
+ offset=params.offset,
36
+ )
37
+ return JobListResponse(jobs=jobs, total=total)
38
+
39
+
40
+ @router.get("/jobs/{job_id}", response_model=JobStatusResponse)
41
+ async def get_job(job_id: str) -> JobStatusResponse:
42
+ """Poll a single job by id (404 if unknown)."""
43
+ settings = get_settings()
44
+ job = await get_job_status(settings.jobs_db_path, job_id)
45
+ if job is None:
46
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Job not found.")
47
+ return job
api/routes/query.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Grounded Q&A and summarisation routes.
2
+
3
+ ``POST /query/ask`` retrieves top-K chunks from Chroma, calls the configured LLM with
4
+ citations enforced in the prompt, persists an audit row, and returns answer + sources.
5
+ ``POST /query/summarise`` uses a retrieval-oriented query then a summary-focused prompt.
6
+ ``POST /query`` is a legacy alias for ``/query/ask``.
7
+ """
8
+
9
+ import time
10
+ from datetime import datetime, timezone
11
+ from uuid import uuid4
12
+
13
+ from fastapi import APIRouter, HTTPException, status
14
+
15
+ from api.config import Settings, get_settings
16
+ from models.requests import QueryRequest, SummariseRequest
17
+ from models.responses import AskQueryResponse, SourceCitation, SummariseQueryResponse
18
+ from rag.embedder import create_embedding_function
19
+ from rag.retriever import (
20
+ SUMMARY_RETRIEVAL_QUERY,
21
+ RetrievedChunk,
22
+ answer_with_grounding,
23
+ retrieve_chunks,
24
+ summarise_with_grounding,
25
+ )
26
+ from rag.vector_store import collection_document_count, get_vector_store
27
+ from storage.audit_store import persist_query_audit
28
+
29
+ router = APIRouter(prefix="/query", tags=["query"])
30
+
31
+
32
+ def _model_used_label(settings: Settings) -> str:
33
+ provider = settings.llm_provider.lower()
34
+ if provider == "openai":
35
+ return settings.openai_model
36
+ if provider == "ollama":
37
+ return settings.ollama_chat_model
38
+ if provider == "anthropic":
39
+ return settings.anthropic_model
40
+ if provider == "huggingface":
41
+ return settings.huggingface_model
42
+ return f"{provider}:unknown"
43
+
44
+
45
+ def _chunks_to_citations(chunks: list[RetrievedChunk]) -> list[SourceCitation]:
46
+ citations: list[SourceCitation] = []
47
+ for chunk in chunks:
48
+ page = chunk.page if chunk.page is not None else 0
49
+ score = float(chunk.score) if chunk.score is not None else 0.0
50
+ citations.append(
51
+ SourceCitation(
52
+ document_name=chunk.source or "unknown",
53
+ page_number=page,
54
+ chunk_text=chunk.text,
55
+ relevance_score=score,
56
+ )
57
+ )
58
+ return citations
59
+
60
+
61
+ async def _run_ask(
62
+ settings: Settings,
63
+ payload: QueryRequest,
64
+ ) -> AskQueryResponse:
65
+ """Retrieve, generate grounded answer, audit, and build the API response."""
66
+ top_k = payload.top_k
67
+ t0 = time.perf_counter()
68
+ embedding_function = create_embedding_function()
69
+ vector_store = get_vector_store(
70
+ persist_directory=settings.chroma_persist_directory,
71
+ collection_name=payload.collection_name or "default",
72
+ embedding_function=embedding_function,
73
+ )
74
+ chunks = retrieve_chunks(vector_store, payload.question, top_k)
75
+ answer, tokens_used = answer_with_grounding(settings, payload.question, chunks)
76
+ elapsed_ms = int((time.perf_counter() - t0) * 1000)
77
+ citations = _chunks_to_citations(chunks)
78
+ query_id = str(uuid4())
79
+ ts = datetime.now(timezone.utc)
80
+ response = AskQueryResponse(
81
+ query_id=query_id,
82
+ question=payload.question,
83
+ answer=answer,
84
+ sources=citations,
85
+ model_used=_model_used_label(settings),
86
+ tokens_used=tokens_used,
87
+ response_time_ms=elapsed_ms,
88
+ timestamp=ts,
89
+ )
90
+ await persist_query_audit(
91
+ settings.audit_db_path,
92
+ query_id=query_id,
93
+ action="query",
94
+ user_id=payload.user_id,
95
+ question=payload.question,
96
+ collection_name=payload.collection_name or "default",
97
+ answer=answer,
98
+ sources=citations,
99
+ model_used=response.model_used,
100
+ tokens_used=tokens_used,
101
+ response_time_ms=elapsed_ms,
102
+ kind="ask",
103
+ )
104
+ return response
105
+
106
+
107
+ async def _run_summarise(
108
+ settings: Settings,
109
+ payload: SummariseRequest,
110
+ ) -> SummariseQueryResponse:
111
+ """Retrieve with focus or default overview query, summarise, and audit."""
112
+ top_k = settings.top_k_results
113
+ retrieval_query = (payload.focus or "").strip() or SUMMARY_RETRIEVAL_QUERY
114
+ audit_question = payload.focus.strip() if payload.focus and payload.focus.strip() else "Summarise collection"
115
+ t0 = time.perf_counter()
116
+ embedding_function = create_embedding_function()
117
+ vector_store = get_vector_store(
118
+ persist_directory=settings.chroma_persist_directory,
119
+ collection_name=payload.collection_name,
120
+ embedding_function=embedding_function,
121
+ )
122
+ chunks = retrieve_chunks(vector_store, retrieval_query, top_k)
123
+ summary, tokens_used = summarise_with_grounding(settings, focus=payload.focus, chunks=chunks)
124
+ elapsed_ms = int((time.perf_counter() - t0) * 1000)
125
+ citations = _chunks_to_citations(chunks)
126
+ doc_count = collection_document_count(settings.chroma_persist_directory, payload.collection_name)
127
+ query_id = str(uuid4())
128
+ ts = datetime.now(timezone.utc)
129
+ response = SummariseQueryResponse(
130
+ query_id=query_id,
131
+ summary=summary,
132
+ document_count=doc_count,
133
+ sources=citations,
134
+ timestamp=ts,
135
+ )
136
+ await persist_query_audit(
137
+ settings.audit_db_path,
138
+ query_id=query_id,
139
+ action="summarise",
140
+ user_id=payload.user_id,
141
+ question=audit_question,
142
+ collection_name=payload.collection_name,
143
+ answer=summary,
144
+ sources=citations,
145
+ model_used=_model_used_label(settings),
146
+ tokens_used=tokens_used,
147
+ response_time_ms=elapsed_ms,
148
+ kind="summarise",
149
+ )
150
+ return response
151
+
152
+
153
+ @router.post("/ask", response_model=AskQueryResponse)
154
+ async def ask_endpoint(payload: QueryRequest) -> AskQueryResponse:
155
+ """Grounded question answering against a Chroma collection."""
156
+ settings = get_settings()
157
+ try:
158
+ return await _run_ask(settings, payload)
159
+ except Exception as exc:
160
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
161
+
162
+
163
+ @router.post("/summarise", response_model=SummariseQueryResponse)
164
+ async def summarise_endpoint(payload: SummariseRequest) -> SummariseQueryResponse:
165
+ """Collection-wide summary with optional focus for retrieval."""
166
+ settings = get_settings()
167
+ try:
168
+ return await _run_summarise(settings, payload)
169
+ except Exception as exc:
170
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
171
+
172
+
173
+ legacy_query_router = APIRouter(tags=["query"])
174
+
175
+
176
+ @legacy_query_router.post("/query", response_model=AskQueryResponse)
177
+ async def query_post_compat(payload: QueryRequest) -> AskQueryResponse:
178
+ """Same behavior as POST /query/ask; kept for older clients and docs that used POST /query."""
179
+ return await ask_endpoint(payload)
app.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hugging Face Spaces default entry (Streamlit SDK expects `app.py`).
2
+
3
+ Local development can still use `streamlit run streamlit_app.py`; Docker Compose uses `app.py`
4
+ so the same entry path works on the Hub and in containers.
5
+
6
+ On Hugging Face Streamlit Spaces only `streamlit run app.py` is started — no separate uvicorn
7
+ process — so we spawn the FastAPI app on 127.0.0.1:8000 when `SPACE_ID` is present (see Hub
8
+ built-in env vars). Set `DOC_AUDI_EMBED_API=0` to disable. Use `DOC_AUDI_EMBED_API=1` to force
9
+ embedding elsewhere (e.g. demos).
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import atexit
15
+ import os
16
+ import socket
17
+ import subprocess
18
+ import sys
19
+ import time
20
+
21
+ _uvicorn_proc: subprocess.Popen[bytes] | None = None
22
+ _cleanup_registered = False
23
+
24
+
25
+ def _port_accepting_connections(host: str, port: int) -> bool:
26
+ try:
27
+ with socket.create_connection((host, port), timeout=0.3):
28
+ return True
29
+ except OSError:
30
+ return False
31
+
32
+
33
+ def _want_embedded_api() -> bool:
34
+ if os.environ.get("DOC_AUDI_EMBED_API", "").lower() in ("0", "false", "no"):
35
+ return False
36
+ if os.environ.get("DOC_AUDI_EMBED_API", "").lower() in ("1", "true", "yes"):
37
+ return True
38
+ return bool(os.environ.get("SPACE_ID"))
39
+
40
+
41
+ def _propagate_streamlit_secrets_to_environ() -> None:
42
+ """Copy Hub tokens from Streamlit secrets into os.environ for the embedded uvicorn child.
43
+
44
+ On Hugging Face Streamlit Spaces, repository secrets are often available as ``st.secrets``
45
+ but are not always present in ``os.environ``. ``subprocess.Popen`` only forwards the
46
+ process environment, so the API would miss ``HF_TOKEN`` / ``HUGGINGFACE_API_KEY`` otherwise.
47
+ """
48
+ try:
49
+ import streamlit as st
50
+ except ImportError:
51
+ return
52
+ secrets = getattr(st, "secrets", None)
53
+ if secrets is None:
54
+ return
55
+ for key in ("HF_TOKEN", "HUGGINGFACE_API_KEY", "HUGGING_FACE_HUB_TOKEN"):
56
+ if (os.environ.get(key) or "").strip():
57
+ continue
58
+ try:
59
+ raw = secrets[key]
60
+ except Exception:
61
+ continue
62
+ if raw is not None and str(raw).strip():
63
+ os.environ[key] = str(raw).strip()
64
+
65
+
66
+ def _maybe_start_embedded_uvicorn() -> None:
67
+ """Start uvicorn in-process when running on HF Spaces (or when DOC_AUDI_EMBED_API=1)."""
68
+ global _uvicorn_proc, _cleanup_registered
69
+ if not _want_embedded_api():
70
+ return
71
+ _propagate_streamlit_secrets_to_environ()
72
+ if _port_accepting_connections("127.0.0.1", 8000):
73
+ return
74
+ if _uvicorn_proc is not None and _uvicorn_proc.poll() is None:
75
+ for _ in range(120):
76
+ if _port_accepting_connections("127.0.0.1", 8000):
77
+ return
78
+ time.sleep(0.05)
79
+ return
80
+
81
+ cmd = [
82
+ sys.executable,
83
+ "-m",
84
+ "uvicorn",
85
+ "api.main:app",
86
+ "--host",
87
+ "127.0.0.1",
88
+ "--port",
89
+ "8000",
90
+ ]
91
+ _uvicorn_proc = subprocess.Popen(cmd)
92
+ proc = _uvicorn_proc
93
+
94
+ if not _cleanup_registered:
95
+
96
+ def _cleanup(p: subprocess.Popen[bytes] = proc) -> None:
97
+ if p.poll() is None:
98
+ p.terminate()
99
+ try:
100
+ p.wait(timeout=10)
101
+ except subprocess.TimeoutExpired:
102
+ p.kill()
103
+
104
+ atexit.register(_cleanup)
105
+ _cleanup_registered = True
106
+
107
+ for _ in range(120):
108
+ if _port_accepting_connections("127.0.0.1", 8000):
109
+ return
110
+ time.sleep(0.05)
111
+
112
+
113
+ _maybe_start_embedded_uvicorn()
114
+
115
+ from streamlit_app import main # noqa: E402 — start API before importing Streamlit stack
116
+
117
+ main()
docker-compose.yml ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Requires a project `.env` (copy from `.env.example`) for `env_file` and variable substitution.
2
+ name: docuaudit-ai
3
+
4
+ x-app: &app
5
+ build: .
6
+ image: docuaudit-ai:${IMAGE_TAG:-local}
7
+
8
+ services:
9
+ api:
10
+ <<: *app
11
+ command: uvicorn api.main:app --host 0.0.0.0 --port 8000
12
+ ports:
13
+ - "${API_PORT:-8000}:8000"
14
+ env_file:
15
+ - .env
16
+ environment:
17
+ CHROMA_PERSIST_DIRECTORY: /data/chroma
18
+ AUDIT_DB_PATH: /data/audit.db
19
+ JOBS_DB_PATH: /data/jobs.db
20
+ OLLAMA_BASE_URL: ${OLLAMA_BASE_URL:-http://host.docker.internal:11434}
21
+ volumes:
22
+ - docuaudit_data:/data
23
+ extra_hosts:
24
+ - "host.docker.internal:host-gateway"
25
+ healthcheck:
26
+ test:
27
+ [
28
+ "CMD",
29
+ "python",
30
+ "-c",
31
+ "import urllib.request; urllib.request.urlopen('http://127.0.0.1:8000/health', timeout=5)",
32
+ ]
33
+ interval: 15s
34
+ timeout: 5s
35
+ retries: 5
36
+ start_period: 40s
37
+
38
+ streamlit:
39
+ <<: *app
40
+ command: >
41
+ streamlit run app.py
42
+ --server.port=8501
43
+ --server.address=0.0.0.0
44
+ --server.headless=true
45
+ --browser.gatherUsageStats=false
46
+ ports:
47
+ - "${STREAMLIT_PORT:-8501}:8501"
48
+ env_file:
49
+ - .env
50
+ environment:
51
+ DOC_AUDI_API_BASE: http://api:8000
52
+ STREAMLIT_BACKEND_URL: http://api:8000
53
+ depends_on:
54
+ api:
55
+ condition: service_healthy
56
+
57
+ ollama:
58
+ image: ollama/ollama:latest
59
+ profiles: ["ollama"]
60
+ ports:
61
+ - "${OLLAMA_HOST_PORT:-11434}:11434"
62
+ volumes:
63
+ - ollama_data:/root/.ollama
64
+
65
+ volumes:
66
+ docuaudit_data:
67
+ ollama_data:
main.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Minimal CLI placeholder (not used by Docker or Hugging Face entrypoints).
2
+
3
+ Production entrypoints: ``api.main:app`` (FastAPI) and ``app.py`` / ``streamlit_app.py`` (UI).
4
+ """
5
+
6
+
7
+ def main() -> None:
8
+ """Print a hello message when run as ``python main.py``."""
9
+ print("Hello from doc-audi-ai!")
10
+
11
+
12
+ if __name__ == "__main__":
13
+ main()
models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """API contract models: request payloads and response DTOs."""
models/requests.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pydantic request bodies and query-parameter models for the HTTP API.
2
+
3
+ Used by FastAPI route handlers for validation and OpenAPI schema generation.
4
+ """
5
+
6
+ from typing import Optional
7
+
8
+ from pydantic import BaseModel, ConfigDict, Field, HttpUrl
9
+
10
+
11
+ class QueryRequest(BaseModel):
12
+ model_config = ConfigDict(extra="forbid")
13
+
14
+ question: str = Field(min_length=5, max_length=2000, description="Natural language question")
15
+ collection_name: Optional[str] = Field(
16
+ default="default",
17
+ min_length=1,
18
+ max_length=256,
19
+ description="Chroma collection to search",
20
+ )
21
+ top_k: int = Field(default=5, ge=1, le=20, description="Number of chunks to retrieve")
22
+ user_id: str = Field(default="anonymous", max_length=256, description="Caller id for audit filtering")
23
+
24
+
25
+ class SummariseRequest(BaseModel):
26
+ model_config = ConfigDict(extra="forbid")
27
+
28
+ collection_name: str = Field(
29
+ default="default",
30
+ min_length=1,
31
+ max_length=256,
32
+ description="Chroma collection to summarise",
33
+ )
34
+ focus: str | None = Field(
35
+ default=None,
36
+ max_length=8000,
37
+ description="Optional angle or scope for retrieval and the summary",
38
+ )
39
+ user_id: str = Field(default="anonymous", max_length=256, description="Caller id for audit filtering")
40
+
41
+
42
+ class URLIngestRequest(BaseModel):
43
+ model_config = ConfigDict(extra="forbid")
44
+
45
+ urls: list[HttpUrl] = Field(
46
+ min_length=1,
47
+ max_length=100,
48
+ description="One or more HTTP(S) URLs to PDF, TXT, or Markdown documents",
49
+ )
50
+ collection_name: Optional[str] = Field(
51
+ default="default",
52
+ min_length=1,
53
+ max_length=256,
54
+ description="Target Chroma collection name",
55
+ )
56
+
57
+
58
+ class JobsListParams(BaseModel):
59
+ model_config = ConfigDict(extra="forbid")
60
+
61
+ limit: int = Field(default=10, ge=1, le=100, description="Max jobs to return")
62
+ offset: int = Field(default=0, ge=0, description="Offset for pagination")
63
+
64
+
65
+ class AuditListParams(BaseModel):
66
+ model_config = ConfigDict(extra="forbid")
67
+
68
+ limit: int = Field(default=50, ge=1, le=100, description="Max log entries to return")
69
+ offset: int = Field(default=0, ge=0, description="Offset for pagination")
70
+ user_id: str | None = Field(default=None, max_length=256, description="Filter by user id")
71
+ from_date: str | None = Field(
72
+ default=None,
73
+ description="ISO 8601 datetime lower bound (inclusive) on timestamp",
74
+ )
75
+ to_date: str | None = Field(
76
+ default=None,
77
+ description="ISO 8601 datetime upper bound (inclusive) on timestamp",
78
+ )
models/responses.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pydantic response models returned by FastAPI routes.
2
+
3
+ Shared shape: :class:`SourceCitation` appears on ask, summarise, and audit detail responses.
4
+ """
5
+
6
+ from datetime import datetime
7
+
8
+ from pydantic import BaseModel, Field
9
+
10
+
11
+ # --- Shared citations (spec-shaped) ---
12
+
13
+
14
+ class SourceCitation(BaseModel):
15
+ document_name: str
16
+ page_number: int
17
+ chunk_text: str
18
+ relevance_score: float
19
+
20
+
21
+ # --- Query: ask ---
22
+
23
+
24
+ class AskQueryResponse(BaseModel):
25
+ query_id: str
26
+ question: str
27
+ answer: str
28
+ sources: list[SourceCitation] = Field(default_factory=list)
29
+ model_used: str
30
+ tokens_used: int
31
+ response_time_ms: int
32
+ timestamp: datetime
33
+
34
+
35
+ # --- Query: summarise ---
36
+
37
+
38
+ class SummariseQueryResponse(BaseModel):
39
+ query_id: str
40
+ summary: str
41
+ document_count: int
42
+ sources: list[SourceCitation] = Field(default_factory=list)
43
+ timestamp: datetime
44
+
45
+
46
+ # --- Ingest ---
47
+
48
+
49
+ class IngestUploadResponse(BaseModel):
50
+ job_id: str
51
+ status: str
52
+ total_files: int
53
+ filenames: list[str]
54
+ message: str
55
+
56
+
57
+ class UrlIngestResponse(BaseModel):
58
+ job_id: str
59
+ status: str
60
+ total_urls: int
61
+ message: str
62
+
63
+
64
+ class CollectionItem(BaseModel):
65
+ name: str
66
+ document_count: int
67
+ created_at: datetime | None = None
68
+
69
+
70
+ class IngestCollectionsResponse(BaseModel):
71
+ collections: list[CollectionItem] = Field(default_factory=list)
72
+ total: int
73
+
74
+
75
+ class IngestDeleteCollectionResponse(BaseModel):
76
+ message: str
77
+ documents_removed: int
78
+
79
+
80
+ # --- Jobs ---
81
+
82
+
83
+ class JobStatusResponse(BaseModel):
84
+ job_id: str
85
+ status: str
86
+ total_files: int
87
+ processed_files: int
88
+ failed_files: int
89
+ progress_percent: int
90
+ started_at: datetime | None
91
+ completed_at: datetime | None
92
+ errors: list[str] = Field(default_factory=list)
93
+
94
+
95
+ class JobListItem(BaseModel):
96
+ job_id: str
97
+ status: str
98
+ total_files: int
99
+ completed_at: datetime | None = None
100
+
101
+
102
+ class JobListResponse(BaseModel):
103
+ jobs: list[JobListItem] = Field(default_factory=list)
104
+ total: int
105
+
106
+
107
+ # --- Audit ---
108
+
109
+
110
+ class AuditLogEntry(BaseModel):
111
+ query_id: str
112
+ user_id: str
113
+ question: str
114
+ answer_summary: str
115
+ sources_count: int
116
+ model_used: str | None
117
+ timestamp: datetime
118
+
119
+
120
+ class AuditLogsResponse(BaseModel):
121
+ logs: list[AuditLogEntry] = Field(default_factory=list)
122
+ total: int
123
+ limit: int
124
+ offset: int
125
+
126
+
127
+ class AuditLogDetailResponse(BaseModel):
128
+ query_id: str
129
+ user_id: str
130
+ question: str
131
+ full_answer: str
132
+ sources: list[SourceCitation] = Field(default_factory=list)
133
+ model_used: str | None
134
+ tokens_used: int | None
135
+ timestamp: datetime
pyproject.toml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "doc-audi-ai"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.11"
7
+ dependencies = [
8
+ "fastapi==0.111.0",
9
+ "langchain==0.2.0",
10
+ "langchain-openai==0.1.7",
11
+ "langchain-community==0.2.0",
12
+ "langchain-chroma==0.1.4",
13
+ "langchain-text-splitters==0.2.0",
14
+ "langchain-anthropic==0.1.15",
15
+ "langchain-ollama==0.1.3",
16
+ "chromadb==0.5.0",
17
+ # Chroma 0.5 calls posthog.capture(distinct_id, event, props); posthog 6+ removed that API (breaks telemetry + spams stderr).
18
+ "posthog>=3.7.0,<4",
19
+ "openai==1.30.1",
20
+ "anthropic==0.28.1",
21
+ "pydantic-settings==2.3.4",
22
+ "pymupdf==1.25.5",
23
+ "python-multipart==0.0.9",
24
+ "aiosqlite>=0.21.0",
25
+ "httpx>=0.27.0",
26
+ "uvicorn[standard]==0.29.0",
27
+ "huggingface-hub>=1.13.0",
28
+ "langchain-huggingface>=0.0.3",
29
+ "streamlit>=1.39.0",
30
+ "pytest>=8.4.2",
31
+ "pytest-asyncio>=1.2.0",
32
+ "onnxruntime==1.23.2 ; sys_platform == 'darwin' and platform_machine == 'x86_64'",
33
+ "torch==2.2.2 ; sys_platform == 'darwin' and platform_machine == 'x86_64'",
34
+ ]
pytest.ini ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [pytest]
2
+ testpaths = tests
3
+ python_files = test_*.py
rag/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """RAG pipeline: load → chunk → embed → store → retrieve → generate.
2
+
3
+ Submodules: :mod:`loader`, :mod:`chunker`, :mod:`embedder`, :mod:`vector_store`,
4
+ :mod:`retriever`, and :mod:`hf_hub_inference` for Hugging Face Hub compatibility.
5
+ """
6
+
rag/chunker.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Split loaded documents into overlapping chunks for embedding.
2
+
3
+ Chunk size and overlap come from :func:`api.config.get_settings`. Each chunk receives
4
+ ``chunk_index``, ``source``, and ``page`` metadata.
5
+ """
6
+
7
+ from langchain_core.documents import Document
8
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
9
+
10
+ from api.config import get_settings
11
+
12
+ def chunk_documents(
13
+ documents: list[Document],
14
+ ) -> list[Document]:
15
+ """Recursive character split of all input documents."""
16
+ settings = get_settings()
17
+ splitter = RecursiveCharacterTextSplitter(
18
+ chunk_size=settings.chunk_size,
19
+ chunk_overlap=settings.chunk_overlap,
20
+ separators=["\n\n", "\n", ". ", " ", ""],
21
+ )
22
+ chunks = splitter.split_documents(documents)
23
+ for idx, chunk in enumerate(chunks):
24
+ chunk.metadata["chunk_index"] = idx
25
+ chunk.metadata.setdefault("source", "unknown")
26
+ chunk.metadata.setdefault("page", 0)
27
+ return chunks
28
+
rag/embedder.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Factory for LangChain embedding backends (OpenAI, Ollama, Hugging Face).
2
+
3
+ The active provider is ``Settings.llm_provider``. Used by ingest and query paths when
4
+ opening or querying Chroma collections.
5
+ """
6
+
7
+ from langchain_core.embeddings import Embeddings
8
+ from langchain_ollama import OllamaEmbeddings
9
+ from langchain_openai import OpenAIEmbeddings
10
+ from pydantic import SecretStr
11
+
12
+ from api.config import get_settings
13
+ from rag.hf_hub_inference import HubInferenceEmbeddings
14
+
15
+
16
+ def create_embedding_function() -> Embeddings:
17
+ """Return an ``Embeddings`` implementation matching the configured LLM provider."""
18
+ settings = get_settings()
19
+ provider = settings.llm_provider.lower()
20
+
21
+ if provider == "openai":
22
+ if not settings.openai_api_key:
23
+ raise ValueError("OPENAI_API_KEY is required when LLM_PROVIDER=openai")
24
+ return OpenAIEmbeddings(
25
+ model=settings.openai_embedding_model,
26
+ api_key=SecretStr(settings.openai_api_key),
27
+ )
28
+ if provider == "huggingface":
29
+ if not settings.huggingface_api_key:
30
+ raise ValueError(
31
+ "A Hugging Face token is required when LLM_PROVIDER=huggingface "
32
+ "(set HUGGINGFACE_API_KEY or HF_TOKEN / HUGGING_FACE_HUB_TOKEN on Spaces)."
33
+ )
34
+ return HubInferenceEmbeddings(
35
+ model=settings.huggingface_embedding_model,
36
+ api_token=settings.huggingface_api_key,
37
+ )
38
+ if provider == "ollama":
39
+ return OllamaEmbeddings(
40
+ model=settings.ollama_embedding_model,
41
+ base_url=settings.ollama_base_url,
42
+ )
43
+ raise ValueError(f"Unsupported LLM_PROVIDER: {settings.llm_provider}")
44
+
rag/hf_hub_inference.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hugging Face Inference API via ``huggingface_hub.InferenceClient``.
2
+
3
+ ``langchain_huggingface`` 0.0.x uses ``InferenceClient.post()``, which was removed in
4
+ ``huggingface_hub`` 1.x. Chat tries ``InferenceClient.chat_completion`` on the primary
5
+ provider, then (for repo ids containing ``mistral`` when primary is not Novita) Novita,
6
+ which often maps those weights to conversational chat only. On router errors or local
7
+ ``ValueError`` (Hub sometimes omits ``pipeline_tag``), we fall back to ``text_generation``
8
+ providers, then the classic **api-inference** ``POST /models/{id}`` JSON API.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from typing import Any, List, Optional
14
+
15
+ import httpx
16
+ import numpy as np
17
+ from langchain_core.embeddings import Embeddings
18
+ from langchain_core.language_models.chat_models import BaseChatModel
19
+ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage
20
+ from langchain_core.outputs import ChatGeneration, ChatResult
21
+ from langchain_core.pydantic_v1 import Field, root_validator
22
+
23
+ from huggingface_hub import InferenceClient, constants
24
+ from huggingface_hub.errors import BadRequestError, HfHubHTTPError
25
+
26
+
27
+ def _lc_messages_to_hf_chat(messages: List[BaseMessage]) -> list[dict[str, str]]:
28
+ """Map LangChain messages to Hugging Face ``chat_completion`` message dicts."""
29
+ out: list[dict[str, str]] = []
30
+ for m in messages:
31
+ content = m.content if isinstance(m.content, str) else str(m.content)
32
+ if isinstance(m, SystemMessage):
33
+ out.append({"role": "system", "content": content})
34
+ elif isinstance(m, HumanMessage):
35
+ out.append({"role": "user", "content": content})
36
+ elif isinstance(m, AIMessage):
37
+ out.append({"role": "assistant", "content": content})
38
+ elif isinstance(m, ToolMessage):
39
+ out.append({"role": "user", "content": f"[tool result]\n{content}"})
40
+ else:
41
+ out.append({"role": "user", "content": content})
42
+ return out
43
+
44
+
45
+ def _messages_to_text_generation_prompt(repo_id: str, messages: List[BaseMessage]) -> str:
46
+ """Build a single prompt for causal / text-generation APIs (instruct templates)."""
47
+ blocks: list[str] = []
48
+ for m in messages:
49
+ content = m.content if isinstance(m.content, str) else str(m.content)
50
+ if isinstance(m, SystemMessage):
51
+ blocks.append(content)
52
+ elif isinstance(m, HumanMessage):
53
+ blocks.append(content)
54
+ elif isinstance(m, AIMessage):
55
+ blocks.append(content)
56
+ elif isinstance(m, ToolMessage):
57
+ blocks.append(f"[tool]\n{content}")
58
+ else:
59
+ blocks.append(content)
60
+ body = "\n\n".join(blocks)
61
+ rid = repo_id.lower()
62
+ if "mistral" in rid:
63
+ return f"<s>[INST] {body} [/INST]"
64
+ return f"{body}\n\nAssistant:\n"
65
+
66
+
67
+ def _chat_completion_text_and_usage(out: Any) -> tuple[str, dict[str, int] | None]:
68
+ """Extract assistant text and optional token usage from ``ChatCompletionOutput``."""
69
+ choices = getattr(out, "choices", None) or []
70
+ if not choices:
71
+ return (str(out).strip(), None)
72
+ msg = getattr(choices[0], "message", None)
73
+ text = (getattr(msg, "content", None) or "").strip() if msg is not None else ""
74
+
75
+ usage_meta: dict[str, int] | None = None
76
+ u = getattr(out, "usage", None)
77
+ if u is not None:
78
+ usage_meta = {}
79
+ tt = getattr(u, "total_tokens", None)
80
+ pt = getattr(u, "prompt_tokens", None)
81
+ ct = getattr(u, "completion_tokens", None)
82
+ if tt is not None:
83
+ usage_meta["total_tokens"] = int(tt)
84
+ if pt is not None:
85
+ usage_meta["input_tokens"] = int(pt)
86
+ if ct is not None:
87
+ usage_meta["output_tokens"] = int(ct)
88
+ if not usage_meta:
89
+ usage_meta = None
90
+
91
+ return text, usage_meta
92
+
93
+
94
+ def _legacy_api_text_generation(
95
+ model_id: str,
96
+ api_token: str,
97
+ prompt: str,
98
+ *,
99
+ max_new_tokens: int,
100
+ temperature: float,
101
+ stop: list[str] | None,
102
+ ) -> str:
103
+ """Classic HF Inference API (bypasses strict ``InferenceClient`` task checks)."""
104
+ url = f"{constants.INFERENCE_ENDPOINT.rstrip('/')}/models/{model_id}"
105
+ parameters: dict[str, Any] = {
106
+ "max_new_tokens": max_new_tokens,
107
+ "temperature": temperature,
108
+ "return_full_text": False,
109
+ }
110
+ if stop:
111
+ parameters["stop"] = stop
112
+ body = {"inputs": prompt, "parameters": parameters}
113
+ headers = {"Authorization": f"Bearer {api_token}"}
114
+ timeout = httpx.Timeout(60.0, read=300.0)
115
+ with httpx.Client(timeout=timeout) as client:
116
+ resp = client.post(url, json=body, headers=headers)
117
+ try:
118
+ resp.raise_for_status()
119
+ except httpx.HTTPStatusError as exc:
120
+ _raise_legacy_inference_http_error(model_id, exc)
121
+ data = resp.json()
122
+ if isinstance(data, dict) and data.get("error"):
123
+ raise RuntimeError(str(data["error"]))
124
+ if isinstance(data, list) and data:
125
+ first = data[0]
126
+ if isinstance(first, dict) and "generated_text" in first:
127
+ return str(first["generated_text"]).strip()
128
+ if isinstance(data, dict) and "generated_text" in data:
129
+ return str(data["generated_text"]).strip()
130
+ raise RuntimeError(f"Unexpected legacy inference response: {data!r}")
131
+
132
+
133
+ class LegacyInferenceNotFoundError(RuntimeError):
134
+ """Classic ``api-inference`` returned 404 for this model id (weights not on that route)."""
135
+
136
+
137
+ def _raise_legacy_inference_http_error(model_id: str, exc: httpx.HTTPStatusError) -> None:
138
+ if exc.response.status_code == 404:
139
+ raise LegacyInferenceNotFoundError(
140
+ f"Hugging Face legacy inference returned 404 for model {model_id!r}. "
141
+ "The classic api-inference route often no longer serves this checkpoint, and router chat "
142
+ "can 404 as well depending on provider health. Try "
143
+ "HUGGINGFACE_MODEL=meta-llama/Meta-Llama-3-8B-Instruct (or another id your token can call), "
144
+ "another model id your token can reach, or Ollama/local inference."
145
+ ) from exc
146
+ raise exc
147
+
148
+
149
+ class HubInferenceEmbeddings(Embeddings):
150
+ """Embeddings through ``InferenceClient.feature_extraction``."""
151
+
152
+ def __init__(self, *, model: str, api_token: str) -> None:
153
+ self._model = model
154
+ self._client = InferenceClient(model=model, token=api_token or None)
155
+
156
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
157
+ out: list[list[float]] = []
158
+ for text in texts:
159
+ t = text.replace("\n", " ")
160
+ raw = self._client.feature_extraction(t, model=self._model)
161
+ vec = np.asarray(raw, dtype=np.float32)
162
+ if vec.ndim > 1:
163
+ vec = vec.mean(axis=0)
164
+ out.append(vec.flatten().tolist())
165
+ return out
166
+
167
+ def embed_query(self, text: str) -> List[float]:
168
+ return self.embed_documents([text])[0]
169
+
170
+
171
+ class HubInferenceChatModel(BaseChatModel):
172
+ """HF Inference: ``chat_completion`` when supported, else ``text_generation`` fallback."""
173
+
174
+ repo_id: str = Field(..., description="Hugging Face model id for inference")
175
+ huggingfacehub_api_token: str = Field(..., repr=False)
176
+ temperature: float = Field(default=0.2)
177
+ max_new_tokens: int = Field(default=2048)
178
+ inference_provider: Optional[str] = Field(
179
+ default=None,
180
+ description=(
181
+ "huggingface_hub provider id. Default is hf-inference (avoids Novita-only mappings). "
182
+ "Set to `auto` for router auto-routing (provider=None)."
183
+ ),
184
+ )
185
+
186
+ class Config:
187
+ """Pydantic v1 config."""
188
+
189
+ arbitrary_types_allowed = True
190
+
191
+ client: Any = Field(default=None, exclude=True)
192
+
193
+ @root_validator(skip_on_failure=True)
194
+ def _build_client(cls, values: dict) -> dict:
195
+ if values.get("client") is not None:
196
+ return values
197
+ raw = values.get("inference_provider")
198
+ if isinstance(raw, str):
199
+ raw = raw.strip() or None
200
+ # Auto-routing often picks Novita for Mistral instruct; Novita maps that model to
201
+ # "conversational" only, so text_generation fails. Default to HF's inference proxy.
202
+ if raw is None:
203
+ client_provider: str | None = "hf-inference"
204
+ stored = "hf-inference"
205
+ elif raw.lower() == "auto":
206
+ client_provider = None
207
+ stored = "auto"
208
+ else:
209
+ client_provider = raw
210
+ stored = raw
211
+ values["inference_provider"] = stored
212
+ values["client"] = InferenceClient(
213
+ model=values["repo_id"],
214
+ token=values.get("huggingfacehub_api_token") or None,
215
+ provider=client_provider,
216
+ )
217
+ return values
218
+
219
+ def _chat_inference_clients(self) -> list[InferenceClient]:
220
+ """Ordered ``InferenceClient`` instances for ``chat_completion``.
221
+
222
+ - Primary client (usually ``hf-inference`` when unset).
223
+ - For Mistral instruct ids, Novita often exposes **conversational** chat while HF task checks
224
+ or ``hf-inference`` reject the same repo.
225
+ - When primary is ``hf-inference``, append **router auto** (``provider=None``): many models
226
+ (e.g. Llama 3.1 Instruct) return *Model not supported by provider hf-inference* on the
227
+ serverless HF proxy but work via the inference router to another provider.
228
+ """
229
+ token = self.huggingfacehub_api_token or None
230
+ rid = self.repo_id
231
+ clients: list[InferenceClient] = [self.client]
232
+ ip = (self.inference_provider or "").strip().lower()
233
+ if "mistral" in rid.lower() and ip != "novita":
234
+ clients.append(InferenceClient(model=rid, token=token, provider="novita"))
235
+ if ip == "hf-inference":
236
+ clients.append(InferenceClient(model=rid, token=token, provider=None))
237
+ return clients
238
+
239
+ @property
240
+ def _llm_type(self) -> str:
241
+ return "hf-hub-inference"
242
+
243
+ @property
244
+ def _identifying_params(self) -> dict[str, Any]:
245
+ return {
246
+ "repo_id": self.repo_id,
247
+ "temperature": self.temperature,
248
+ "max_new_tokens": self.max_new_tokens,
249
+ "inference_provider": self.inference_provider,
250
+ }
251
+
252
+ def _text_generation_fallback(self, messages: List[BaseMessage], stop: Optional[List[str]]) -> str:
253
+ prompt = _messages_to_text_generation_prompt(self.repo_id, messages)
254
+ token = self.huggingfacehub_api_token
255
+ rid = self.repo_id
256
+ chain_raw: list[str | None] = []
257
+ p = (self.inference_provider or "").strip()
258
+ if p.lower() == "auto":
259
+ chain_raw.append(None)
260
+ elif p and p.lower() != "hf-inference":
261
+ chain_raw.append(p)
262
+ chain_raw.append("hf-inference")
263
+ chain_raw.append(None)
264
+ chain: list[str | None] = []
265
+ seen: set[str] = set()
266
+ for prov in chain_raw:
267
+ key = prov if prov is not None else "__auto__"
268
+ if key in seen:
269
+ continue
270
+ seen.add(key)
271
+ chain.append(prov)
272
+
273
+ last: Exception | None = None
274
+ for prov in chain:
275
+ try:
276
+ cli = InferenceClient(model=rid, token=token, provider=prov)
277
+ raw = cli.text_generation(
278
+ prompt,
279
+ model=rid,
280
+ max_new_tokens=self.max_new_tokens,
281
+ temperature=self.temperature,
282
+ stop=stop,
283
+ return_full_text=False,
284
+ )
285
+ return (raw if isinstance(raw, str) else str(raw)).strip()
286
+ except Exception as exc:
287
+ last = exc
288
+ continue
289
+ try:
290
+ return _legacy_api_text_generation(
291
+ rid,
292
+ token,
293
+ prompt,
294
+ max_new_tokens=self.max_new_tokens,
295
+ temperature=self.temperature,
296
+ stop=stop,
297
+ )
298
+ except Exception as legacy_exc:
299
+ if last is not None:
300
+ # Prefer the legacy endpoint error (e.g. explicit 404 guidance) over the last
301
+ # provider text_generation failure (often a task-mapping ValueError).
302
+ raise legacy_exc from last
303
+ raise legacy_exc
304
+
305
+ def _generate(
306
+ self,
307
+ messages: List[BaseMessage],
308
+ stop: Optional[List[str]] = None,
309
+ run_manager: Optional[Any] = None,
310
+ **kwargs: Any,
311
+ ) -> ChatResult:
312
+ chat_payload = _lc_messages_to_hf_chat(messages)
313
+ last_chat_err: BaseException | None = None
314
+
315
+ for cli in self._chat_inference_clients():
316
+ try:
317
+ out = cli.chat_completion(
318
+ chat_payload,
319
+ model=self.repo_id,
320
+ max_tokens=self.max_new_tokens,
321
+ temperature=self.temperature,
322
+ stop=stop,
323
+ )
324
+ text, usage_meta = _chat_completion_text_and_usage(out)
325
+ message = AIMessage(content=text, usage_metadata=usage_meta)
326
+ return ChatResult(generations=[ChatGeneration(message=message)])
327
+ except BadRequestError as exc:
328
+ last_chat_err = exc
329
+ err = str(exc).lower()
330
+ if (
331
+ "not a chat model" in err
332
+ or "model_not_supported" in err
333
+ or "not supported by provider" in err
334
+ # Defer to post-loop handling so we can explain gated / unknown ids without masking
335
+ # earlier recoverable errors from another client.
336
+ or "model_not_found" in err
337
+ or ("does not exist" in err and "model" in err)
338
+ ):
339
+ continue
340
+ raise
341
+ except HfHubHTTPError as exc:
342
+ last_chat_err = exc
343
+ code = getattr(exc.response, "status_code", None)
344
+ # Novita/router may 404 a model or route; try remaining clients then completion fallbacks.
345
+ if code in (404, 410):
346
+ continue
347
+ raise
348
+ except ValueError as exc:
349
+ # e.g. hf-inference _check_supported_task when Hub model card has no pipeline_tag
350
+ last_chat_err = exc
351
+ continue
352
+
353
+ if last_chat_err is not None and isinstance(last_chat_err, BadRequestError):
354
+ le = str(last_chat_err).lower()
355
+ if "model_not_found" in le or (
356
+ "does not exist" in le and ("model" in le or "requested model" in le)
357
+ ):
358
+ raise RuntimeError(
359
+ f"Inference router could not use chat model {self.repo_id!r} "
360
+ "(common for gated models: open the model page on the Hugging Face Hub, accept the "
361
+ "license, ensure your API token has read access to that model, then retry)."
362
+ ) from last_chat_err
363
+
364
+ try:
365
+ text = self._text_generation_fallback(messages, stop)
366
+ except LegacyInferenceNotFoundError:
367
+ raise
368
+ except Exception as exc:
369
+ hint = (
370
+ f"Hugging Face chat_completion failed for {self.repo_id!r} on all tried providers; "
371
+ "text_generation / legacy fallbacks also failed. "
372
+ "Accept the model license on the Hub, check your token, or set "
373
+ "HUGGINGFACE_INFERENCE_PROVIDER=auto to use only router routing."
374
+ )
375
+ if last_chat_err is not None:
376
+ raise RuntimeError(f"{hint} Last chat error: {last_chat_err!r}") from exc
377
+ raise RuntimeError(hint) from exc
378
+
379
+ message = AIMessage(content=text, usage_metadata=None)
380
+ return ChatResult(generations=[ChatGeneration(message=message)])
rag/loader.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Load raw documents from disk into LangChain ``Document`` objects.
2
+
3
+ Supports PDF (PyMuPDF), plain text, and Markdown. Each document gets ``source`` and
4
+ ``page`` metadata for downstream chunking and citations.
5
+ """
6
+
7
+ from pathlib import Path
8
+
9
+ from langchain_core.documents import Document
10
+ from langchain_community.document_loaders import PyMuPDFLoader, TextLoader
11
+
12
+
13
+ def load_documents(paths: str | list[str]) -> list[Document]:
14
+ """Load one or more files; raise ``ValueError`` for unsupported extensions."""
15
+ normalized_paths = [paths] if isinstance(paths, str) else paths
16
+ all_docs: list[Document] = []
17
+ for path_str in normalized_paths:
18
+ path = Path(path_str)
19
+ suffix = path.suffix.lower()
20
+
21
+ if suffix == ".pdf":
22
+ loader = PyMuPDFLoader(str(path_str))
23
+ elif suffix in {".txt", ".md"}:
24
+ loader = TextLoader(str(path_str), encoding="utf-8")
25
+ else:
26
+ raise ValueError(f"Unsupported file type: {suffix or 'unknown'}")
27
+
28
+ documents = loader.load()
29
+ for doc in documents:
30
+ doc.metadata.setdefault("source", path.name)
31
+ doc.metadata.setdefault("page", 0)
32
+ all_docs.extend(documents)
33
+
34
+ return all_docs
35
+
rag/retriever.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Semantic retrieval and grounded LLM generation for ask and summarise flows.
2
+
3
+ Pipeline: similarity search on Chroma → relevance filter → provider-specific chat model
4
+ → answer with citations. Prompt templates enforce document-only answers for consulting use.
5
+ """
6
+
7
+ from dataclasses import dataclass
8
+
9
+ from langchain_chroma import Chroma
10
+ from langchain_core.language_models import BaseChatModel
11
+ from langchain_core.messages import HumanMessage, SystemMessage
12
+ from langchain_ollama import ChatOllama
13
+ from langchain_openai import ChatOpenAI
14
+ from pydantic import SecretStr
15
+
16
+ try:
17
+ from langchain_anthropic import ChatAnthropic
18
+ except ImportError:
19
+ ChatAnthropic = None # type: ignore[assignment]
20
+
21
+ from api.config import Settings
22
+ from rag.hf_hub_inference import HubInferenceChatModel
23
+
24
+ NO_MATCH_ANSWER = "I cannot find this information in the uploaded documents."
25
+ MIN_RELEVANCE_SCORE = 0.15
26
+
27
+ # Verbatim from DOCUAUDIT_AI_REQUIREMENTS.md (placeholders filled at runtime).
28
+ DOCUAUDIT_ASK_TEMPLATE = """You are DocuAudit AI, an expert document analyst for consulting environments.
29
+
30
+ RULES:
31
+ 1. Answer ONLY based on the provided document excerpts below.
32
+ 2. If the answer is not in the documents, say: "I cannot find this information in the uploaded documents."
33
+ 3. ALWAYS cite your sources: mention the document name and page number for every claim.
34
+ 4. Be precise and professional. This is a high-stakes consulting environment.
35
+ 5. Do not speculate or add information not present in the documents.
36
+
37
+ DOCUMENT EXCERPTS:
38
+ {context}
39
+
40
+ QUESTION: {question}
41
+
42
+ ANSWER (with source citations):
43
+ """
44
+
45
+
46
+ @dataclass
47
+ class RetrievedChunk:
48
+ """One search hit with metadata needed for prompts and API citations."""
49
+
50
+ text: str
51
+ score: float | None
52
+ source: str
53
+ page: int | None
54
+ chunk_index: int | None
55
+
56
+
57
+ def retrieve_chunks(vector_store: Chroma, question: str, k: int) -> list[RetrievedChunk]:
58
+ """Top-K similarity search with relevance scores from Chroma/LangChain."""
59
+ results = vector_store.similarity_search_with_relevance_scores(question, k=k)
60
+ chunks: list[RetrievedChunk] = []
61
+ for doc, score in results:
62
+ metadata = doc.metadata or {}
63
+ chunks.append(
64
+ RetrievedChunk(
65
+ text=doc.page_content,
66
+ score=score,
67
+ source=str(metadata.get("source", "unknown")),
68
+ page=_to_int_or_none(metadata.get("page")),
69
+ chunk_index=_to_int_or_none(metadata.get("chunk_index")),
70
+ )
71
+ )
72
+ return chunks
73
+
74
+
75
+ SUMMARY_RETRIEVAL_QUERY = (
76
+ "Overview of the document: main topics, key definitions, obligations, risks, and conclusions."
77
+ )
78
+
79
+
80
+ def answer_with_grounding(settings: Settings, question: str, chunks: list[RetrievedChunk]) -> tuple[str, int]:
81
+ """Generate a cited answer from chunks; return ``(answer_text, token_count)``."""
82
+ ranked_chunks = [chunk for chunk in chunks if chunk.score is None or chunk.score >= MIN_RELEVANCE_SCORE]
83
+ if not ranked_chunks:
84
+ return NO_MATCH_ANSWER, 0
85
+
86
+ llm = _create_chat_model(settings)
87
+ prompt_context = _format_context(ranked_chunks)
88
+ user_content = DOCUAUDIT_ASK_TEMPLATE.format(context=prompt_context, question=question)
89
+ messages = [HumanMessage(content=user_content)]
90
+ response = llm.invoke(messages)
91
+ answer = _extract_message_text(response).strip()
92
+ tokens = _extract_usage_tokens(response)
93
+ return (answer or NO_MATCH_ANSWER), tokens
94
+
95
+
96
+ def summarise_with_grounding(
97
+ settings: Settings,
98
+ *,
99
+ focus: str | None,
100
+ chunks: list[RetrievedChunk],
101
+ ) -> tuple[str, int]:
102
+ """Produce a structured summary grounded in retrieved excerpts."""
103
+ ranked_chunks = [chunk for chunk in chunks if chunk.score is None or chunk.score >= MIN_RELEVANCE_SCORE]
104
+ if not ranked_chunks:
105
+ return NO_MATCH_ANSWER, 0
106
+
107
+ llm = _create_chat_model(settings)
108
+ prompt_context = _format_context(ranked_chunks)
109
+ user_instruction = (
110
+ focus.strip()
111
+ if focus and focus.strip()
112
+ else "Summarise the main themes, structure, and important details. Use bullet points where helpful."
113
+ )
114
+ messages = [
115
+ SystemMessage(
116
+ content=(
117
+ "You write accurate summaries using only the provided document excerpts. "
118
+ "Do not invent facts. If the excerpts are insufficient, say what is missing."
119
+ )
120
+ ),
121
+ HumanMessage(
122
+ content=(
123
+ f"Summary request: {user_instruction}\n\n"
124
+ f"Document excerpts:\n{prompt_context}\n\n"
125
+ "Return a structured, concise summary grounded in the excerpts above."
126
+ )
127
+ ),
128
+ ]
129
+ response = llm.invoke(messages)
130
+ answer = _extract_message_text(response).strip()
131
+ tokens = _extract_usage_tokens(response)
132
+ return (answer or NO_MATCH_ANSWER), tokens
133
+
134
+
135
+ def _create_chat_model(settings: Settings) -> BaseChatModel:
136
+ provider = settings.llm_provider.lower()
137
+
138
+ if provider == "openai":
139
+ if not settings.openai_api_key:
140
+ raise ValueError("OPENAI_API_KEY is required when LLM_PROVIDER=openai")
141
+ return ChatOpenAI(model=settings.openai_model, api_key=SecretStr(settings.openai_api_key))
142
+ if provider == "ollama":
143
+ return ChatOllama(model=settings.ollama_chat_model, base_url=settings.ollama_base_url)
144
+ if provider == "anthropic":
145
+ if ChatAnthropic is None:
146
+ raise ValueError("langchain-anthropic is not installed for LLM_PROVIDER=anthropic")
147
+ if not settings.anthropic_api_key:
148
+ raise ValueError("ANTHROPIC_API_KEY is required when LLM_PROVIDER=anthropic")
149
+ return ChatAnthropic(model=settings.anthropic_model, api_key=SecretStr(settings.anthropic_api_key))
150
+ if provider == "huggingface":
151
+ if not settings.huggingface_api_key:
152
+ raise ValueError(
153
+ "A Hugging Face token is required when LLM_PROVIDER=huggingface "
154
+ "(set HUGGINGFACE_API_KEY or HF_TOKEN / HUGGING_FACE_HUB_TOKEN on Spaces)."
155
+ )
156
+ return HubInferenceChatModel(
157
+ repo_id=settings.huggingface_model,
158
+ huggingfacehub_api_token=settings.huggingface_api_key,
159
+ temperature=0.2,
160
+ max_new_tokens=2048,
161
+ inference_provider=settings.huggingface_inference_provider,
162
+ )
163
+
164
+ raise ValueError(f"Unsupported LLM_PROVIDER: {settings.llm_provider}")
165
+
166
+
167
+ def _format_context(chunks: list[RetrievedChunk]) -> str:
168
+ lines: list[str] = []
169
+ for idx, chunk in enumerate(chunks, start=1):
170
+ lines.append(
171
+ f"[{idx}] source={chunk.source}, page={chunk.page}, chunk={chunk.chunk_index}, score={chunk.score}\n"
172
+ f"{chunk.text}"
173
+ )
174
+ return "\n\n".join(lines)
175
+
176
+
177
+ def _to_int_or_none(value: object) -> int | None:
178
+ try:
179
+ if value is None:
180
+ return None
181
+ return int(value)
182
+ except (TypeError, ValueError):
183
+ return None
184
+
185
+
186
+ def _extract_usage_tokens(response: object) -> int:
187
+ um = getattr(response, "usage_metadata", None)
188
+ if isinstance(um, dict):
189
+ total = um.get("total_tokens")
190
+ if total is not None:
191
+ return int(total)
192
+ inp = int(um.get("input_tokens", 0) or 0)
193
+ out = int(um.get("output_tokens", 0) or 0)
194
+ return inp + out
195
+ rm = getattr(response, "response_metadata", None) or {}
196
+ if isinstance(rm, dict):
197
+ tu = rm.get("token_usage")
198
+ if isinstance(tu, dict):
199
+ if tu.get("total_tokens") is not None:
200
+ return int(tu["total_tokens"])
201
+ return int(tu.get("prompt_tokens", 0) or 0) + int(tu.get("completion_tokens", 0) or 0)
202
+ return 0
203
+
204
+
205
+ def _extract_message_text(response: object) -> str:
206
+ content = getattr(response, "content", "")
207
+ if isinstance(content, str):
208
+ return content
209
+ if isinstance(content, list):
210
+ text_parts: list[str] = []
211
+ for item in content:
212
+ if isinstance(item, str):
213
+ text_parts.append(item)
214
+ elif isinstance(item, dict) and "text" in item:
215
+ text_parts.append(str(item["text"]))
216
+ return "\n".join(part for part in text_parts if part)
217
+ return str(content)
218
+
rag/vector_store.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ChromaDB persistence and LangChain ``Chroma`` vector store helpers.
2
+
3
+ Collections are named per ingest target; documents are stored with UUID chunk ids.
4
+ Telemetry is disabled at the client level for quieter logs in production.
5
+ """
6
+
7
+ from datetime import datetime, timezone
8
+ from pathlib import Path
9
+ from uuid import uuid4
10
+
11
+ import chromadb
12
+ from chromadb.config import Settings
13
+ from langchain_chroma import Chroma
14
+ from langchain_core.documents import Document
15
+ from langchain_core.embeddings import Embeddings
16
+
17
+ _CHROMA_CLIENT_SETTINGS = Settings(anonymized_telemetry=False)
18
+
19
+
20
+ def _utc_now_iso() -> str:
21
+ return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
22
+
23
+
24
+ def _chroma_client(persist_directory: str) -> chromadb.PersistentClient:
25
+ Path(persist_directory).mkdir(parents=True, exist_ok=True)
26
+ return chromadb.PersistentClient(path=persist_directory, settings=_CHROMA_CLIENT_SETTINGS)
27
+
28
+
29
+ def get_vector_store(
30
+ persist_directory: str,
31
+ collection_name: str,
32
+ embedding_function: Embeddings,
33
+ ) -> Chroma:
34
+ """Open or create a persisted Chroma collection wired to the given embedder."""
35
+ client = _chroma_client(persist_directory)
36
+ try:
37
+ client.get_collection(name=collection_name)
38
+ except Exception:
39
+ client.get_or_create_collection(
40
+ name=collection_name,
41
+ metadata={"created_at": _utc_now_iso()},
42
+ )
43
+ return Chroma(
44
+ collection_name=collection_name,
45
+ embedding_function=embedding_function,
46
+ persist_directory=persist_directory,
47
+ client_settings=_CHROMA_CLIENT_SETTINGS,
48
+ )
49
+
50
+
51
+ def add_documents(vector_store: Chroma, chunks: list[Document]) -> list[str]:
52
+ """Embed and insert chunks; return the generated vector ids."""
53
+ document_ids = [str(uuid4()) for _ in chunks]
54
+ vector_store.add_documents(documents=chunks, ids=document_ids)
55
+ return document_ids
56
+
57
+
58
+ def list_collection_names(persist_directory: str) -> list[str]:
59
+ """Sorted list of collection names in the persist directory."""
60
+ client = _chroma_client(persist_directory)
61
+ return sorted(c.name for c in client.list_collections())
62
+
63
+
64
+ def delete_collection(persist_directory: str, collection_name: str) -> int:
65
+ """Delete a collection and return the number of documents that were removed (best effort)."""
66
+ client = _chroma_client(persist_directory)
67
+ removed = 0
68
+ try:
69
+ col = client.get_collection(name=collection_name)
70
+ removed = int(col.count())
71
+ except Exception:
72
+ removed = 0
73
+ client.delete_collection(name=collection_name)
74
+ return removed
75
+
76
+
77
+ def collection_document_count(persist_directory: str, collection_name: str) -> int:
78
+ """Number of vectors in a collection, or 0 if the collection does not exist."""
79
+ client = _chroma_client(persist_directory)
80
+ try:
81
+ col = client.get_collection(name=collection_name)
82
+ return int(col.count())
83
+ except Exception:
84
+ return 0
85
+
86
+
87
+ def collection_created_at(persist_directory: str, collection_name: str) -> str | None:
88
+ """Return collection metadata ``created_at`` if present (Chroma-specific)."""
89
+ client = _chroma_client(persist_directory)
90
+ try:
91
+ col = client.get_collection(name=collection_name)
92
+ meta = getattr(col, "metadata", None) or {}
93
+ if isinstance(meta, dict):
94
+ raw = meta.get("created_at") or meta.get("created")
95
+ if raw is not None:
96
+ return str(raw)
97
+ except Exception:
98
+ pass
99
+ return None
100
+
101
+
102
+ def ensure_collection_created_at(
103
+ persist_directory: str,
104
+ collection_name: str,
105
+ *,
106
+ fallback: str | None = None,
107
+ ) -> str | None:
108
+ """Persist ``created_at`` on the Chroma collection when missing; never overwrites an existing value."""
109
+ client = _chroma_client(persist_directory)
110
+ try:
111
+ col = client.get_collection(name=collection_name)
112
+ except Exception:
113
+ return None
114
+ meta = getattr(col, "metadata", None) or {}
115
+ if not isinstance(meta, dict):
116
+ meta = {}
117
+ raw = meta.get("created_at") or meta.get("created")
118
+ if raw is not None:
119
+ return str(raw)
120
+ value = fallback or _utc_now_iso()
121
+ updated = dict(meta)
122
+ updated["created_at"] = value
123
+ col.modify(metadata=updated)
124
+ return value
125
+
requirements.txt ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.111.0
2
+ uvicorn[standard]==0.29.0
3
+ pydantic-settings==2.3.4
4
+ langchain==0.2.0
5
+ langchain-openai==0.1.7
6
+ langchain-community==0.2.0
7
+ langchain-chroma==0.1.4
8
+ langchain-text-splitters==0.2.0
9
+ langchain-anthropic==0.1.15
10
+ langchain-ollama==0.1.3
11
+ chromadb==0.5.0
12
+ posthog>=3.7.0,<4
13
+ openai==1.30.1
14
+ anthropic==0.28.1
15
+ pymupdf==1.25.5
16
+ python-multipart==0.0.9
17
+ aiosqlite
18
+ httpx>=0.27.0
19
+ huggingface-hub
20
+ langchain-huggingface
21
+ streamlit>=1.39.0
22
+ pytest>=8.4.2
23
+ pytest-asyncio>=1.2.0
sample.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Doc-Audi-AI RAG Smoke Test Document
2
+
3
+ Project: Doc-Audi-AI
4
+ Environment: Lightning AI deployment with Ollama embeddings.
5
+
6
+ This sample document is used to test ingestion and retrieval.
7
+ The system should split this file into chunks, generate embeddings, and store vectors in Chroma.
8
+
9
+ Key facts:
10
+ - The project supports file ingestion for PDF, TXT, and MD formats.
11
+ - The default collection name for tests is "default".
12
+ - A typical retrieval question is: "What is this document about?"
13
+ - Another test question is: "Which file formats are supported?"
14
+
15
+ Expected behavior:
16
+ If ingestion succeeds, querying should return text snippets from this document with relevance scores.
storage/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Persistence layer: SQLite audit log and ingest job tracking."""
storage/audit_store.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SQLite persistence for query and summarise audit events.
2
+
3
+ Schema is created/migrated on first use. Stores full answers, citation JSON, token usage,
4
+ and optional filters (user_id, date range) for list endpoints.
5
+ """
6
+
7
+ import json
8
+ from datetime import datetime, timezone
9
+ from pathlib import Path
10
+ from typing import Any
11
+ from uuid import uuid4
12
+
13
+ import aiosqlite
14
+
15
+ from models.responses import AuditLogDetailResponse, AuditLogEntry, SourceCitation
16
+
17
+
18
+ def _utc_now_iso() -> str:
19
+ return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
20
+
21
+
22
+ def _parse_ts(value: object) -> datetime:
23
+ if value is None or value == "":
24
+ return datetime.now(timezone.utc)
25
+ s = str(value).strip()
26
+ if s.endswith("Z"):
27
+ s = s[:-1] + "+00:00"
28
+ try:
29
+ dt = datetime.fromisoformat(s)
30
+ if dt.tzinfo is None:
31
+ return dt.replace(tzinfo=timezone.utc)
32
+ return dt
33
+ except ValueError:
34
+ return datetime.now(timezone.utc)
35
+
36
+
37
+ async def _migrate_audit_columns(conn: aiosqlite.Connection) -> None:
38
+ cursor = await conn.execute("PRAGMA table_info(audit_events)")
39
+ rows = await cursor.fetchall()
40
+ col_names = {str(r[1]) for r in rows}
41
+ alters: list[str] = []
42
+ if "user_id" not in col_names:
43
+ alters.append("ALTER TABLE audit_events ADD COLUMN user_id TEXT NOT NULL DEFAULT 'anonymous'")
44
+ if "model_used" not in col_names:
45
+ alters.append("ALTER TABLE audit_events ADD COLUMN model_used TEXT")
46
+ if "tokens_used" not in col_names:
47
+ alters.append("ALTER TABLE audit_events ADD COLUMN tokens_used INTEGER")
48
+ if "response_time_ms" not in col_names:
49
+ alters.append("ALTER TABLE audit_events ADD COLUMN response_time_ms INTEGER")
50
+ if "answer_summary" not in col_names:
51
+ alters.append("ALTER TABLE audit_events ADD COLUMN answer_summary TEXT")
52
+ if "kind" not in col_names:
53
+ alters.append("ALTER TABLE audit_events ADD COLUMN kind TEXT NOT NULL DEFAULT 'ask'")
54
+ for stmt in alters:
55
+ await conn.execute(stmt)
56
+ if alters:
57
+ await conn.commit()
58
+
59
+
60
+ async def init_audit_db(db_path: str) -> None:
61
+ """Create ``audit_events`` table and apply additive column migrations."""
62
+ db_file = Path(db_path)
63
+ db_file.parent.mkdir(parents=True, exist_ok=True)
64
+ async with aiosqlite.connect(db_file.as_posix()) as conn:
65
+ await conn.execute(
66
+ """
67
+ CREATE TABLE IF NOT EXISTS audit_events (
68
+ event_id TEXT PRIMARY KEY,
69
+ action TEXT NOT NULL,
70
+ question TEXT NOT NULL,
71
+ collection_name TEXT NOT NULL,
72
+ answer TEXT,
73
+ status TEXT NOT NULL,
74
+ message TEXT NOT NULL,
75
+ sources_json TEXT NOT NULL,
76
+ results_json TEXT NOT NULL,
77
+ created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
78
+ user_id TEXT NOT NULL DEFAULT 'anonymous',
79
+ model_used TEXT,
80
+ tokens_used INTEGER,
81
+ response_time_ms INTEGER,
82
+ answer_summary TEXT,
83
+ kind TEXT NOT NULL DEFAULT 'ask'
84
+ )
85
+ """
86
+ )
87
+ await conn.commit()
88
+ await _migrate_audit_columns(conn)
89
+
90
+
91
+ def _summary_from_answer(answer: str, max_len: int = 280) -> str:
92
+ text = (answer or "").strip()
93
+ if len(text) <= max_len:
94
+ return text
95
+ return text[: max_len - 1].rstrip() + "…"
96
+
97
+
98
+ def _sources_to_citations(raw: list[dict[str, Any]]) -> list[SourceCitation]:
99
+ out: list[SourceCitation] = []
100
+ for item in raw:
101
+ if not isinstance(item, dict):
102
+ continue
103
+ if "document_name" in item:
104
+ doc = str(item.get("document_name", ""))
105
+ page = int(item.get("page_number", 0) or 0)
106
+ chunk = str(item.get("chunk_text", ""))
107
+ score = float(item.get("relevance_score", 0.0) or 0.0)
108
+ else:
109
+ doc = str(item.get("source", item.get("document_name", "")))
110
+ p = item.get("page_number", item.get("page"))
111
+ try:
112
+ page = int(p) if p is not None else 0
113
+ except (TypeError, ValueError):
114
+ page = 0
115
+ chunk = str(item.get("chunk_text", item.get("excerpt", item.get("text", ""))))
116
+ s = item.get("relevance_score", item.get("score"))
117
+ try:
118
+ score = float(s) if s is not None else 0.0
119
+ except (TypeError, ValueError):
120
+ score = 0.0
121
+ out.append(
122
+ SourceCitation(
123
+ document_name=doc or "unknown",
124
+ page_number=page,
125
+ chunk_text=chunk,
126
+ relevance_score=score,
127
+ )
128
+ )
129
+ return out
130
+
131
+
132
+ async def persist_query_audit(
133
+ db_path: str,
134
+ *,
135
+ query_id: str,
136
+ action: str,
137
+ user_id: str,
138
+ question: str,
139
+ collection_name: str,
140
+ answer: str,
141
+ sources: list[SourceCitation],
142
+ model_used: str,
143
+ tokens_used: int,
144
+ response_time_ms: int,
145
+ status: str = "success",
146
+ message: str = "ok",
147
+ kind: str = "ask",
148
+ ) -> str:
149
+ """Insert one audit row after a successful ask or summarise; returns ``query_id``."""
150
+ await init_audit_db(db_path)
151
+ sources_payload = [s.model_dump(mode="json") for s in sources]
152
+ summary = _summary_from_answer(answer)
153
+ created = _utc_now_iso()
154
+ async with aiosqlite.connect(db_path) as conn:
155
+ await conn.execute(
156
+ """
157
+ INSERT INTO audit_events (
158
+ event_id, action, question, collection_name, answer, status, message,
159
+ sources_json, results_json, created_at, user_id, model_used, tokens_used,
160
+ response_time_ms, answer_summary, kind
161
+ ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, '[]', ?, ?, ?, ?, ?, ?, ?)
162
+ """,
163
+ (
164
+ query_id,
165
+ action,
166
+ question,
167
+ collection_name,
168
+ answer,
169
+ status,
170
+ message,
171
+ json.dumps(sources_payload),
172
+ created,
173
+ user_id,
174
+ model_used,
175
+ tokens_used,
176
+ response_time_ms,
177
+ summary,
178
+ kind,
179
+ ),
180
+ )
181
+ await conn.commit()
182
+ return query_id
183
+
184
+
185
+ async def count_audit_events(
186
+ db_path: str,
187
+ *,
188
+ user_id: str | None = None,
189
+ from_date: str | None = None,
190
+ to_date: str | None = None,
191
+ ) -> int:
192
+ await init_audit_db(db_path)
193
+ where, params = _audit_filters(user_id, from_date, to_date)
194
+ async with aiosqlite.connect(db_path) as conn:
195
+ cur = await conn.execute(f"SELECT COUNT(*) AS c FROM audit_events {where}", params)
196
+ row = await cur.fetchone()
197
+ return int(row[0]) if row else 0
198
+
199
+
200
+ def _audit_filters(user_id: str | None, from_date: str | None, to_date: str | None) -> tuple[str, list[Any]]:
201
+ clauses: list[str] = []
202
+ params: list[Any] = []
203
+ if user_id:
204
+ clauses.append("user_id = ?")
205
+ params.append(user_id)
206
+ if from_date:
207
+ clauses.append("datetime(created_at) >= datetime(?)")
208
+ params.append(from_date)
209
+ if to_date:
210
+ clauses.append("datetime(created_at) <= datetime(?)")
211
+ params.append(to_date)
212
+ if not clauses:
213
+ return "", []
214
+ return "WHERE " + " AND ".join(clauses), params
215
+
216
+
217
+ async def list_audit_events(
218
+ db_path: str,
219
+ *,
220
+ limit: int,
221
+ offset: int,
222
+ user_id: str | None = None,
223
+ from_date: str | None = None,
224
+ to_date: str | None = None,
225
+ ) -> tuple[list[AuditLogEntry], int]:
226
+ """Paginated audit list with optional user and ISO datetime filters."""
227
+ await init_audit_db(db_path)
228
+ where, fparams = _audit_filters(user_id, from_date, to_date)
229
+ total = await count_audit_events(db_path, user_id=user_id, from_date=from_date, to_date=to_date)
230
+ async with aiosqlite.connect(db_path) as conn:
231
+ conn.row_factory = aiosqlite.Row
232
+ cursor = await conn.execute(
233
+ f"""
234
+ SELECT event_id, user_id, question, answer, answer_summary, sources_json, model_used, created_at
235
+ FROM audit_events
236
+ {where}
237
+ ORDER BY datetime(created_at) DESC, rowid DESC
238
+ LIMIT ? OFFSET ?
239
+ """,
240
+ [*fparams, limit, offset],
241
+ )
242
+ rows = await cursor.fetchall()
243
+ logs: list[AuditLogEntry] = []
244
+ for row in rows:
245
+ src_raw = json.loads(row["sources_json"] or "[]")
246
+ if not isinstance(src_raw, list):
247
+ src_raw = []
248
+ summary_cell = row["answer_summary"]
249
+ summary_text = str(summary_cell).strip() if summary_cell else ""
250
+ if not summary_text:
251
+ summary_text = _summary_from_answer(str(row["answer"] or ""))
252
+ logs.append(
253
+ AuditLogEntry(
254
+ query_id=str(row["event_id"]),
255
+ user_id=str(row["user_id"] or "anonymous"),
256
+ question=str(row["question"]),
257
+ answer_summary=summary_text,
258
+ sources_count=len(src_raw),
259
+ model_used=row["model_used"],
260
+ timestamp=_parse_ts(row["created_at"]),
261
+ )
262
+ )
263
+ return logs, total
264
+
265
+
266
+ async def get_audit_event(db_path: str, query_id: str) -> AuditLogDetailResponse | None:
267
+ """Full audit record for one ``query_id``, or ``None`` if missing."""
268
+ await init_audit_db(db_path)
269
+ async with aiosqlite.connect(db_path) as conn:
270
+ conn.row_factory = aiosqlite.Row
271
+ cursor = await conn.execute(
272
+ """
273
+ SELECT event_id, user_id, question, answer, sources_json, model_used, tokens_used, created_at
274
+ FROM audit_events
275
+ WHERE event_id = ?
276
+ """,
277
+ (query_id,),
278
+ )
279
+ row = await cursor.fetchone()
280
+ if row is None:
281
+ return None
282
+ src_raw = json.loads(row["sources_json"] or "[]")
283
+ if not isinstance(src_raw, list):
284
+ src_raw = []
285
+ citations = _sources_to_citations(src_raw)
286
+ return AuditLogDetailResponse(
287
+ query_id=str(row["event_id"]),
288
+ user_id=str(row["user_id"] or "anonymous"),
289
+ question=str(row["question"]),
290
+ full_answer=str(row["answer"] or ""),
291
+ sources=citations,
292
+ model_used=row["model_used"],
293
+ tokens_used=row["tokens_used"],
294
+ timestamp=_parse_ts(row["created_at"]),
295
+ )
storage/job_store.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SQLite tracking for asynchronous document ingest jobs.
2
+
3
+ Jobs move through ``queued`` → ``processing`` → ``completed`` or ``failed``. Progress
4
+ fields support multi-file batches and per-file error messages.
5
+ """
6
+
7
+ import json
8
+ from datetime import datetime, timezone
9
+ from pathlib import Path
10
+ from typing import Any
11
+ from uuid import uuid4
12
+
13
+ import aiosqlite
14
+
15
+ from models.responses import JobListItem, JobStatusResponse
16
+
17
+
18
+ def _utc_now_iso() -> str:
19
+ return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
20
+
21
+
22
+ async def _migrate_jobs_columns(conn: aiosqlite.Connection) -> None:
23
+ cursor = await conn.execute("PRAGMA table_info(ingest_jobs)")
24
+ rows = await cursor.fetchall()
25
+ col_names = {str(r[1]) for r in rows}
26
+ alters: list[str] = []
27
+ if "total_files" not in col_names:
28
+ alters.append("ALTER TABLE ingest_jobs ADD COLUMN total_files INTEGER NOT NULL DEFAULT 1")
29
+ if "processed_files" not in col_names:
30
+ alters.append("ALTER TABLE ingest_jobs ADD COLUMN processed_files INTEGER NOT NULL DEFAULT 0")
31
+ if "failed_files" not in col_names:
32
+ alters.append("ALTER TABLE ingest_jobs ADD COLUMN failed_files INTEGER NOT NULL DEFAULT 0")
33
+ if "filenames_json" not in col_names:
34
+ alters.append("ALTER TABLE ingest_jobs ADD COLUMN filenames_json TEXT NOT NULL DEFAULT '[]'")
35
+ if "errors_json" not in col_names:
36
+ alters.append("ALTER TABLE ingest_jobs ADD COLUMN errors_json TEXT NOT NULL DEFAULT '[]'")
37
+ if "started_at" not in col_names:
38
+ alters.append("ALTER TABLE ingest_jobs ADD COLUMN started_at TEXT")
39
+ if "completed_at" not in col_names:
40
+ alters.append("ALTER TABLE ingest_jobs ADD COLUMN completed_at TEXT")
41
+ for stmt in alters:
42
+ await conn.execute(stmt)
43
+ if alters:
44
+ await conn.commit()
45
+ await _backfill_job_filenames(conn)
46
+
47
+
48
+ async def _backfill_job_filenames(conn: aiosqlite.Connection) -> None:
49
+ conn.row_factory = aiosqlite.Row
50
+ cursor = await conn.execute("SELECT job_id, filename, filenames_json, total_files FROM ingest_jobs")
51
+ rows = await cursor.fetchall()
52
+ for row in rows:
53
+ raw = row["filenames_json"] or "[]"
54
+ try:
55
+ parsed: Any = json.loads(raw)
56
+ except json.JSONDecodeError:
57
+ parsed = []
58
+ if not parsed and row["filename"]:
59
+ await conn.execute(
60
+ """
61
+ UPDATE ingest_jobs
62
+ SET filenames_json = ?, total_files = CASE WHEN total_files IS NULL OR total_files < 1 THEN 1 ELSE total_files END
63
+ WHERE job_id = ?
64
+ """,
65
+ (json.dumps([row["filename"]]), row["job_id"]),
66
+ )
67
+ await conn.commit()
68
+
69
+
70
+ async def init_jobs_db(db_path: str) -> None:
71
+ """Create ``ingest_jobs`` table and apply additive column migrations."""
72
+ db_file = Path(db_path)
73
+ db_file.parent.mkdir(parents=True, exist_ok=True)
74
+ async with aiosqlite.connect(db_file.as_posix()) as conn:
75
+ await conn.execute(
76
+ """
77
+ CREATE TABLE IF NOT EXISTS ingest_jobs (
78
+ job_id TEXT PRIMARY KEY,
79
+ status TEXT NOT NULL,
80
+ collection_name TEXT NOT NULL,
81
+ filename TEXT NOT NULL,
82
+ message TEXT NOT NULL DEFAULT '',
83
+ document_ids_json TEXT NOT NULL DEFAULT '[]',
84
+ created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
85
+ updated_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
86
+ total_files INTEGER NOT NULL DEFAULT 1,
87
+ processed_files INTEGER NOT NULL DEFAULT 0,
88
+ failed_files INTEGER NOT NULL DEFAULT 0,
89
+ filenames_json TEXT NOT NULL DEFAULT '[]',
90
+ errors_json TEXT NOT NULL DEFAULT '[]',
91
+ started_at TEXT,
92
+ completed_at TEXT
93
+ )
94
+ """
95
+ )
96
+ await conn.commit()
97
+ await _migrate_jobs_columns(conn)
98
+
99
+
100
+ async def create_ingest_job(
101
+ db_path: str,
102
+ *,
103
+ collection_name: str,
104
+ filenames: list[str],
105
+ ) -> str:
106
+ """Insert a new queued job; return the generated ``job_id``."""
107
+ if not filenames:
108
+ raise ValueError("filenames must not be empty")
109
+ job_id = str(uuid4())
110
+ primary = filenames[0]
111
+ names_json = json.dumps(filenames)
112
+ total = len(filenames)
113
+ await init_jobs_db(db_path)
114
+ async with aiosqlite.connect(db_path) as conn:
115
+ await conn.execute(
116
+ """
117
+ INSERT INTO ingest_jobs (
118
+ job_id, status, collection_name, filename, message, document_ids_json,
119
+ total_files, processed_files, failed_files, filenames_json, errors_json
120
+ ) VALUES (?, 'queued', ?, ?, '', '[]', ?, 0, 0, ?, '[]')
121
+ """,
122
+ (job_id, collection_name, primary, total, names_json),
123
+ )
124
+ await conn.commit()
125
+ return job_id
126
+
127
+
128
+ async def mark_job_processing(db_path: str, job_id: str) -> None:
129
+ await init_jobs_db(db_path)
130
+ started = _utc_now_iso()
131
+ async with aiosqlite.connect(db_path) as conn:
132
+ await conn.execute(
133
+ """
134
+ UPDATE ingest_jobs
135
+ SET status = 'processing', message = 'Ingestion in progress.', started_at = COALESCE(started_at, ?),
136
+ updated_at = CURRENT_TIMESTAMP
137
+ WHERE job_id = ?
138
+ """,
139
+ (started, job_id),
140
+ )
141
+ await conn.commit()
142
+
143
+
144
+ async def update_job_progress(
145
+ db_path: str,
146
+ job_id: str,
147
+ *,
148
+ processed_files: int,
149
+ failed_files: int,
150
+ errors: list[str],
151
+ message: str | None = None,
152
+ ) -> None:
153
+ await init_jobs_db(db_path)
154
+ async with aiosqlite.connect(db_path) as conn:
155
+ await conn.execute(
156
+ """
157
+ UPDATE ingest_jobs
158
+ SET processed_files = ?, failed_files = ?, errors_json = ?,
159
+ message = COALESCE(?, message), updated_at = CURRENT_TIMESTAMP
160
+ WHERE job_id = ?
161
+ """,
162
+ (processed_files, failed_files, json.dumps(errors), message, job_id),
163
+ )
164
+ await conn.commit()
165
+
166
+
167
+ async def complete_ingest_job(
168
+ db_path: str,
169
+ job_id: str,
170
+ *,
171
+ document_ids: list[str],
172
+ message: str,
173
+ ) -> None:
174
+ await init_jobs_db(db_path)
175
+ completed = _utc_now_iso()
176
+ async with aiosqlite.connect(db_path) as conn:
177
+ await conn.execute(
178
+ """
179
+ UPDATE ingest_jobs
180
+ SET status = 'completed', message = ?, document_ids_json = ?,
181
+ completed_at = ?, updated_at = CURRENT_TIMESTAMP
182
+ WHERE job_id = ?
183
+ """,
184
+ (message, json.dumps(document_ids), completed, job_id),
185
+ )
186
+ await conn.commit()
187
+
188
+
189
+ async def fail_ingest_job(db_path: str, job_id: str, *, message: str, errors: list[str] | None = None) -> None:
190
+ await init_jobs_db(db_path)
191
+ completed = _utc_now_iso()
192
+ err_json = json.dumps(errors or [message])
193
+ async with aiosqlite.connect(db_path) as conn:
194
+ await conn.execute(
195
+ """
196
+ UPDATE ingest_jobs
197
+ SET status = 'failed', message = ?, errors_json = ?, completed_at = ?,
198
+ updated_at = CURRENT_TIMESTAMP
199
+ WHERE job_id = ?
200
+ """,
201
+ (message, err_json, completed, job_id),
202
+ )
203
+ await conn.commit()
204
+
205
+
206
+ async def get_job_status(db_path: str, job_id: str) -> JobStatusResponse | None:
207
+ """Job status DTO for API, including computed ``progress_percent``."""
208
+ await init_jobs_db(db_path)
209
+ async with aiosqlite.connect(db_path) as conn:
210
+ conn.row_factory = aiosqlite.Row
211
+ cursor = await conn.execute(
212
+ """
213
+ SELECT job_id, status, total_files, processed_files, failed_files, errors_json,
214
+ started_at, completed_at, message
215
+ FROM ingest_jobs
216
+ WHERE job_id = ?
217
+ """,
218
+ (job_id,),
219
+ )
220
+ row = await cursor.fetchone()
221
+ if row is None:
222
+ return None
223
+ data = dict(row)
224
+ total = int(data["total_files"] or 0)
225
+ processed = int(data["processed_files"] or 0)
226
+ failed = int(data["failed_files"] or 0)
227
+ denom = total if total > 0 else 1
228
+ progress = int(min(100, max(0, round((processed + failed) / denom * 100))))
229
+ errors = json.loads(data.get("errors_json") or "[]")
230
+ if not isinstance(errors, list):
231
+ errors = [str(errors)]
232
+ errors_str = [str(e) for e in errors]
233
+ return JobStatusResponse(
234
+ job_id=str(data["job_id"]),
235
+ status=str(data["status"]),
236
+ total_files=total,
237
+ processed_files=processed,
238
+ failed_files=failed,
239
+ progress_percent=progress,
240
+ started_at=_parse_dt(data.get("started_at")),
241
+ completed_at=_parse_dt(data.get("completed_at")),
242
+ errors=errors_str,
243
+ )
244
+
245
+
246
+ async def earliest_job_created_at_for_collection(db_path: str, collection_name: str) -> str | None:
247
+ """Earliest ingest job timestamp for a collection (SQLite ``created_at`` string)."""
248
+ await init_jobs_db(db_path)
249
+ async with aiosqlite.connect(db_path) as conn:
250
+ conn.row_factory = aiosqlite.Row
251
+ cursor = await conn.execute(
252
+ """
253
+ SELECT MIN(created_at) AS earliest
254
+ FROM ingest_jobs
255
+ WHERE collection_name = ?
256
+ """,
257
+ (collection_name,),
258
+ )
259
+ row = await cursor.fetchone()
260
+ if row is None or row["earliest"] is None:
261
+ return None
262
+ return str(row["earliest"])
263
+
264
+
265
+ async def list_ingest_jobs(db_path: str, *, limit: int, offset: int) -> tuple[list[JobListItem], int]:
266
+ """Recent jobs summary list and total count for pagination."""
267
+ await init_jobs_db(db_path)
268
+ async with aiosqlite.connect(db_path) as conn:
269
+ conn.row_factory = aiosqlite.Row
270
+ cur_total = await conn.execute("SELECT COUNT(*) AS c FROM ingest_jobs")
271
+ total_row = await cur_total.fetchone()
272
+ total = int(total_row["c"]) if total_row else 0
273
+ cursor = await conn.execute(
274
+ """
275
+ SELECT job_id, status, total_files, completed_at
276
+ FROM ingest_jobs
277
+ ORDER BY datetime(updated_at) DESC, rowid DESC
278
+ LIMIT ? OFFSET ?
279
+ """,
280
+ (limit, offset),
281
+ )
282
+ rows = await cursor.fetchall()
283
+ items = [
284
+ JobListItem(
285
+ job_id=str(r["job_id"]),
286
+ status=str(r["status"]),
287
+ total_files=int(r["total_files"] or 0),
288
+ completed_at=_parse_dt(r["completed_at"]),
289
+ )
290
+ for r in rows
291
+ ]
292
+ return items, total
293
+
294
+
295
+ def _parse_dt(value: object) -> datetime | None:
296
+ if value is None or value == "":
297
+ return None
298
+ s = str(value).strip()
299
+ if not s:
300
+ return None
301
+ if s.endswith("Z"):
302
+ s = s[:-1] + "+00:00"
303
+ try:
304
+ dt = datetime.fromisoformat(s)
305
+ if dt.tzinfo is None:
306
+ return dt.replace(tzinfo=timezone.utc)
307
+ return dt
308
+ except ValueError:
309
+ return None
streamlit_app.py ADDED
@@ -0,0 +1,513 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Streamlit UI for doc-audi-ai — talks to the FastAPI backend only."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import time
7
+ from typing import Any
8
+
9
+ import httpx
10
+ import streamlit as st
11
+
12
+ DEFAULT_API_BASE = os.environ.get("DOC_AUDI_API_BASE", "http://127.0.0.1:8000")
13
+
14
+ # httpx read timeout for Ask/Summarise: embeddings + LLM on CPU or cold Ollama often exceeds 10 minutes.
15
+ _HTTP_READ_TIMEOUT_DEFAULT_S = 3600.0
16
+ _HTTP_READ_TIMEOUT_MIN_S = 60.0
17
+ _HTTP_READ_TIMEOUT_MAX_S = 7200.0
18
+
19
+
20
+ def _http_read_timeout_seconds() -> float:
21
+ raw = os.environ.get(
22
+ "DOC_AUDI_HTTP_READ_TIMEOUT",
23
+ str(int(_HTTP_READ_TIMEOUT_DEFAULT_S)),
24
+ )
25
+ try:
26
+ read_s = float(raw)
27
+ except ValueError:
28
+ read_s = _HTTP_READ_TIMEOUT_DEFAULT_S
29
+ return max(_HTTP_READ_TIMEOUT_MIN_S, min(read_s, _HTTP_READ_TIMEOUT_MAX_S))
30
+
31
+
32
+ def _http_timeout() -> httpx.Timeout:
33
+ """LLM + embeddings can exceed a few minutes on CPU or cold Ollama; Streamlit uses this, not Uvicorn."""
34
+ read_s = _http_read_timeout_seconds()
35
+ return httpx.Timeout(connect=20.0, read=read_s, write=120.0, pool=30.0)
36
+
37
+
38
+ def _fmt_timeout_hint() -> str:
39
+ cap = int(_http_read_timeout_seconds())
40
+ lo, hi = int(_HTTP_READ_TIMEOUT_MIN_S), int(_HTTP_READ_TIMEOUT_MAX_S)
41
+ return (
42
+ f"The UI stops waiting after **{cap}s** per request (set **DOC_AUDI_HTTP_READ_TIMEOUT**, "
43
+ f"allowed **{lo}–{hi}** s). "
44
+ "Ensure `ollama serve` is running; cold models or CPU inference can exceed a few minutes."
45
+ )
46
+
47
+
48
+ def _api_base() -> str:
49
+ """Resolve API base URL. Whitespace-only sidebar input must not win over default (breaks httpx)."""
50
+ raw = st.session_state.get("api_base")
51
+ if raw is None:
52
+ return DEFAULT_API_BASE.rstrip("/")
53
+ s = str(raw).strip()
54
+ if not s:
55
+ return DEFAULT_API_BASE.rstrip("/")
56
+ return s.rstrip("/")
57
+
58
+
59
+ def _client() -> httpx.Client:
60
+ return httpx.Client(base_url=_api_base(), timeout=_http_timeout())
61
+
62
+
63
+ def _fmt_api_error(exc: httpx.HTTPStatusError) -> str:
64
+ try:
65
+ body = exc.response.json()
66
+ except Exception:
67
+ return f"HTTP {exc.response.status_code}: {exc.response.text[:500]}"
68
+ detail = body.get("detail")
69
+ if isinstance(detail, list):
70
+ parts = []
71
+ for item in detail:
72
+ if isinstance(item, dict):
73
+ loc = item.get("loc", ())
74
+ msg = item.get("msg", "")
75
+ parts.append(f"{'/'.join(str(x) for x in loc)}: {msg}")
76
+ else:
77
+ parts.append(str(item))
78
+ return f"HTTP {exc.response.status_code}: " + "; ".join(parts)
79
+ if detail is not None:
80
+ return f"HTTP {exc.response.status_code}: {detail}"
81
+ return f"HTTP {exc.response.status_code}"
82
+
83
+
84
+ def _fmt_request_error(exc: httpx.RequestError) -> str:
85
+ """Human-readable transport errors (connection, timeouts, TLS, etc.)."""
86
+ base = _api_base()
87
+ if isinstance(exc, httpx.ReadTimeout):
88
+ return (
89
+ f"**Read timeout** — `{base}` did not send a full response in time (embeddings/LLM can be slow). "
90
+ f"{_fmt_timeout_hint()}"
91
+ )
92
+ if isinstance(exc, httpx.ConnectTimeout):
93
+ return (
94
+ f"**Connect timeout** — could not open TCP to `{base}` in time. "
95
+ "Confirm the FastAPI process is listening (`uv run uvicorn api.main:app --host 0.0.0.0 --port 8000`)."
96
+ )
97
+ if isinstance(exc, httpx.ConnectError):
98
+ return (
99
+ f"**Connection failed** — nothing is accepting HTTP at `{base}`: {exc}. "
100
+ "Start the API, or fix **API base URL** / **`DOC_AUDI_API_BASE`** (use `http://127.0.0.1:8000` from the same machine, not `0.0.0.0`)."
101
+ )
102
+ if isinstance(exc, httpx.TimeoutException):
103
+ return f"**Timeout** ({type(exc).__name__}): {exc}. {_fmt_timeout_hint()}"
104
+ return f"**Request error** ({type(exc).__name__}): {exc}. Backend: `{base}`."
105
+
106
+
107
+ def _post_query_ask(
108
+ client: httpx.Client,
109
+ *,
110
+ question: str,
111
+ collection_name: str,
112
+ top_k: int = 5,
113
+ user_id: str = "anonymous",
114
+ ) -> httpx.Response:
115
+ """POST /query/ask (falls back to POST /query on older servers)."""
116
+ body: dict[str, object] = {
117
+ "question": question.strip(),
118
+ "collection_name": collection_name,
119
+ "top_k": top_k,
120
+ "user_id": user_id,
121
+ }
122
+ r = client.post("/query/ask", json=body)
123
+ if r.status_code == 404:
124
+ r = client.post("/query", json=body)
125
+ return r
126
+
127
+
128
+ def _get_audit_logs(
129
+ client: httpx.Client,
130
+ *,
131
+ limit: int,
132
+ offset: int,
133
+ user_id: str | None = None,
134
+ from_date: str | None = None,
135
+ to_date: str | None = None,
136
+ ) -> httpx.Response:
137
+ params: dict[str, object] = {"limit": limit, "offset": offset}
138
+ if user_id:
139
+ params["user_id"] = user_id
140
+ if from_date:
141
+ params["from_date"] = from_date
142
+ if to_date:
143
+ params["to_date"] = to_date
144
+ r = client.get("/audit/logs", params=params)
145
+ if r.status_code == 404:
146
+ r = client.get("/audit", params=params)
147
+ return r
148
+
149
+
150
+ def _get_audit_event_detail(client: httpx.Client, event_id: str) -> httpx.Response:
151
+ r = client.get(f"/audit/logs/{event_id}")
152
+ if r.status_code == 404:
153
+ r = client.get(f"/audit/{event_id}")
154
+ return r
155
+
156
+
157
+ def _health_check() -> tuple[bool, str]:
158
+ try:
159
+ with _client() as c:
160
+ r = c.get("/health")
161
+ r.raise_for_status()
162
+ data = r.json()
163
+ return True, str(data)
164
+ except httpx.HTTPStatusError as e:
165
+ return False, _fmt_api_error(e)
166
+ except httpx.RequestError as e:
167
+ return False, _fmt_request_error(e)
168
+ except Exception as e:
169
+ return False, str(e)
170
+
171
+
172
+ def main() -> None:
173
+ st.set_page_config(page_title="doc-audi-ai", layout="wide")
174
+ if "api_base" not in st.session_state:
175
+ st.session_state.api_base = DEFAULT_API_BASE
176
+
177
+ st.title("doc-audi-ai")
178
+ st.caption("Ingest, query, and audit via the FastAPI backend.")
179
+ st.caption(f"Requests go to: `{_api_base()}`")
180
+
181
+ with st.sidebar:
182
+ st.subheader("Backend")
183
+ st.text_input(
184
+ "API base URL",
185
+ key="api_base",
186
+ placeholder=DEFAULT_API_BASE,
187
+ help=f"Default: {DEFAULT_API_BASE}. Clear the field to use the default.",
188
+ )
189
+ st.caption(
190
+ f"Ask/Summarise wait up to **{int(_http_read_timeout_seconds())}s** per request "
191
+ f"(env `DOC_AUDI_HTTP_READ_TIMEOUT`, range {int(_HTTP_READ_TIMEOUT_MIN_S)}–{int(_HTTP_READ_TIMEOUT_MAX_S)})."
192
+ )
193
+ if st.button("Test connection"):
194
+ ok, msg = _health_check()
195
+ if ok:
196
+ st.success(msg)
197
+ else:
198
+ st.error(msg)
199
+
200
+ tab_upload, tab_jobs, tab_ask, tab_sum, tab_audit = st.tabs(
201
+ ["Upload", "Jobs", "Ask", "Summarise", "Audit"]
202
+ )
203
+
204
+ with tab_upload:
205
+ st.subheader("Upload document")
206
+ col_u1, col_u2 = st.columns(2)
207
+ with col_u1:
208
+ up_collection = st.text_input("Collection", value="default", key="up_col")
209
+ uploaded = st.file_uploader("PDF, TXT, or Markdown", type=["pdf", "txt", "md"], key="up_file")
210
+ with col_u2:
211
+ if st.button("Submit upload", key="btn_upload", disabled=uploaded is None):
212
+ if uploaded is None:
213
+ st.warning("Choose a file first.")
214
+ else:
215
+ try:
216
+ files = {"files": (uploaded.name, uploaded.getvalue(), uploaded.type or "application/octet-stream")}
217
+ data = {"collection_name": up_collection}
218
+ with _client() as c:
219
+ r = c.post("/ingest/upload", files=files, data=data)
220
+ r.raise_for_status()
221
+ out = r.json()
222
+ st.success(out.get("message", "Queued"))
223
+ st.json(out)
224
+ if out.get("job_id"):
225
+ st.session_state["last_job_id"] = out["job_id"]
226
+ except httpx.HTTPStatusError as e:
227
+ st.error(_fmt_api_error(e))
228
+ except httpx.RequestError as e:
229
+ st.error(_fmt_request_error(e))
230
+ except Exception as e:
231
+ st.exception(e)
232
+
233
+ st.subheader("Ingest from URL")
234
+ url_col = st.columns([3, 1])
235
+ with url_col[0]:
236
+ ingest_url = st.text_input("Document URL (http/https)", key="ingest_url")
237
+ with url_col[1]:
238
+ url_collection = st.text_input("Collection", value="default", key="url_col")
239
+ if st.button("Queue URL ingest", key="btn_url"):
240
+ if not ingest_url.strip():
241
+ st.warning("Enter a URL.")
242
+ else:
243
+ try:
244
+ with _client() as c:
245
+ r = c.post(
246
+ "/ingest/url",
247
+ json={"urls": [ingest_url.strip()], "collection_name": url_collection},
248
+ )
249
+ r.raise_for_status()
250
+ out = r.json()
251
+ st.success(out.get("message", "Queued"))
252
+ st.json(out)
253
+ if out.get("job_id"):
254
+ st.session_state["last_job_id"] = out["job_id"]
255
+ except httpx.HTTPStatusError as e:
256
+ st.error(_fmt_api_error(e))
257
+ except httpx.RequestError as e:
258
+ st.error(_fmt_request_error(e))
259
+ except Exception as e:
260
+ st.exception(e)
261
+
262
+ st.subheader("Collections")
263
+ if st.button("Refresh collections", key="btn_collections"):
264
+ try:
265
+ with _client() as c:
266
+ r = c.get("/ingest/collections")
267
+ r.raise_for_status()
268
+ cols = r.json()
269
+ rows = cols.get("collections", [])
270
+ st.write(f"{cols.get('total', len(rows))} collection(s).")
271
+ if rows:
272
+ st.dataframe(rows, hide_index=True, use_container_width=True)
273
+ else:
274
+ st.info("No collections yet.")
275
+ except httpx.HTTPStatusError as e:
276
+ st.error(_fmt_api_error(e))
277
+ except httpx.RequestError as e:
278
+ st.error(_fmt_request_error(e))
279
+ except Exception as e:
280
+ st.exception(e)
281
+
282
+ del_name = st.text_input("Delete collection name (optional)", key="del_col")
283
+ if st.button("Delete collection", key="btn_del_col"):
284
+ if not del_name.strip():
285
+ st.warning("Enter a collection name.")
286
+ else:
287
+ try:
288
+ with _client() as c:
289
+ r = c.delete(f"/ingest/collection/{del_name.strip()}")
290
+ r.raise_for_status()
291
+ del_body = r.json()
292
+ st.success(del_body.get("message", "Deleted"))
293
+ if "documents_removed" in del_body:
294
+ st.caption(f"Documents removed: **{del_body['documents_removed']}**")
295
+ except httpx.HTTPStatusError as e:
296
+ st.error(_fmt_api_error(e))
297
+ except httpx.RequestError as e:
298
+ st.error(_fmt_request_error(e))
299
+ except Exception as e:
300
+ st.exception(e)
301
+
302
+ with tab_jobs:
303
+ st.subheader("Job list")
304
+ j1, j2 = st.columns(2)
305
+ with j1:
306
+ j_limit = st.number_input("Limit", min_value=1, max_value=100, value=20, key="j_lim")
307
+ with j2:
308
+ j_offset = st.number_input("Offset", min_value=0, value=0, key="j_off")
309
+ if st.button("List jobs", key="btn_jobs"):
310
+ try:
311
+ with _client() as c:
312
+ r = c.get("/jobs", params={"limit": int(j_limit), "offset": int(j_offset)})
313
+ r.raise_for_status()
314
+ payload = r.json()
315
+ jobs: list[dict[str, Any]] = payload.get("jobs", [])
316
+ st.caption(f"Total jobs (matching filters): **{payload.get('total', len(jobs))}**")
317
+ if jobs:
318
+ st.dataframe(jobs, hide_index=True, use_container_width=True)
319
+ else:
320
+ st.info("No jobs in this window.")
321
+ except httpx.HTTPStatusError as e:
322
+ st.error(_fmt_api_error(e))
323
+ except httpx.RequestError as e:
324
+ st.error(_fmt_request_error(e))
325
+ except Exception as e:
326
+ st.exception(e)
327
+
328
+ st.subheader("Job detail")
329
+ default_job = st.session_state.get("last_job_id", "")
330
+ job_id = st.text_input("Job ID", value=default_job, key="job_id_in")
331
+ c1, c2 = st.columns(2)
332
+ with c1:
333
+ fetch_job = st.button("Fetch job", key="btn_job_one")
334
+ with c2:
335
+ poll_job = st.button("Poll until completed/failed", key="btn_job_poll")
336
+
337
+ if fetch_job and job_id.strip():
338
+ try:
339
+ with _client() as c:
340
+ r = c.get(f"/jobs/{job_id.strip()}")
341
+ r.raise_for_status()
342
+ detail = r.json()
343
+ st.json(detail)
344
+ except httpx.HTTPStatusError as e:
345
+ st.error(_fmt_api_error(e))
346
+ except httpx.RequestError as e:
347
+ st.error(_fmt_request_error(e))
348
+ except Exception as e:
349
+ st.exception(e)
350
+
351
+ if poll_job and job_id.strip():
352
+ status_ph = st.empty()
353
+ try:
354
+ with _client() as c:
355
+ for i in range(120):
356
+ r = c.get(f"/jobs/{job_id.strip()}")
357
+ r.raise_for_status()
358
+ body = r.json()
359
+ st_ = body.get("status", "")
360
+ status_ph.write(f"Poll {i + 1}: **{st_}** — {body.get('progress_percent', 0)}%")
361
+ if st_ in ("completed", "failed"):
362
+ st.json(body)
363
+ break
364
+ time.sleep(1)
365
+ else:
366
+ status_ph.write("Stopped after 120 attempts (~2 min).")
367
+ st.json(body)
368
+ except httpx.HTTPStatusError as e:
369
+ st.error(_fmt_api_error(e))
370
+ except httpx.RequestError as e:
371
+ st.error(_fmt_request_error(e))
372
+ except Exception as e:
373
+ st.exception(e)
374
+
375
+ with tab_ask:
376
+ st.subheader("Ask a question")
377
+ q_col = st.text_input("Collection", value="default", key="ask_col")
378
+ question = st.text_area("Question", height=120, key="ask_q")
379
+ if st.button("Ask", key="btn_ask"):
380
+ if not question.strip():
381
+ st.warning("Enter a question.")
382
+ else:
383
+ try:
384
+ with st.spinner(
385
+ "Calling the API (embeddings + LLM can take several minutes on a slow machine; "
386
+ "ensure Ollama is running). Timeout is controlled by DOC_AUDI_HTTP_READ_TIMEOUT…"
387
+ ):
388
+ with _client() as c:
389
+ r = _post_query_ask(
390
+ c,
391
+ question=question,
392
+ collection_name=q_col,
393
+ )
394
+ r.raise_for_status()
395
+ ans = r.json()
396
+ st.success(f"Query id: `{ans.get('query_id', '')}`")
397
+ if ans.get("answer"):
398
+ st.markdown("### Answer")
399
+ st.markdown(ans["answer"])
400
+ else:
401
+ st.warning(
402
+ "The API returned no **answer** text. "
403
+ "Check the collection has ingested chunks, LLM env, and expand **Raw response** below."
404
+ )
405
+ src = ans.get("sources") or []
406
+ if src:
407
+ with st.expander(f"Sources ({len(src)})"):
408
+ st.json(src)
409
+ else:
410
+ st.caption("No sources in this response (empty retrieval or model returned nothing).")
411
+ with st.expander("Raw response (debug)"):
412
+ st.json(ans)
413
+ except httpx.HTTPStatusError as e:
414
+ st.error(_fmt_api_error(e))
415
+ except httpx.RequestError as e:
416
+ st.error(_fmt_request_error(e))
417
+ except Exception as e:
418
+ st.exception(e)
419
+
420
+ with tab_sum:
421
+ st.subheader("Summarise collection")
422
+ s_col = st.text_input("Collection", value="default", key="sum_col")
423
+ focus = st.text_input("Optional focus / angle", value="", key="sum_focus")
424
+ if st.button("Summarise", key="btn_sum"):
425
+ try:
426
+ body: dict[str, Any] = {"collection_name": s_col}
427
+ if focus.strip():
428
+ body["focus"] = focus.strip()
429
+ with st.spinner("Calling summarise (can take 1–2 minutes on a cold model)…"):
430
+ with _client() as c:
431
+ r = c.post("/query/summarise", json=body)
432
+ r.raise_for_status()
433
+ ans = r.json()
434
+ st.success(f"Query id: `{ans.get('query_id', '')}` · documents: **{ans.get('document_count', '')}**")
435
+ summary_text = ans.get("summary") or ans.get("answer")
436
+ if summary_text:
437
+ st.markdown("### Summary")
438
+ st.markdown(summary_text)
439
+ else:
440
+ st.warning("No summary text in the response; see **Raw response** below.")
441
+ src = ans.get("sources") or []
442
+ if src:
443
+ with st.expander(f"Sources ({len(src)})"):
444
+ st.json(src)
445
+ with st.expander("Raw response (debug)"):
446
+ st.json(ans)
447
+ except httpx.HTTPStatusError as e:
448
+ st.error(_fmt_api_error(e))
449
+ except httpx.RequestError as e:
450
+ st.error(_fmt_request_error(e))
451
+ except Exception as e:
452
+ st.exception(e)
453
+
454
+ with tab_audit:
455
+ st.subheader("Audit log")
456
+ a1, a2 = st.columns(2)
457
+ with a1:
458
+ a_limit = st.number_input("Limit", min_value=1, max_value=100, value=20, key="a_lim")
459
+ with a2:
460
+ a_offset = st.number_input("Offset", min_value=0, value=0, key="a_off")
461
+ if st.button("List audit events", key="btn_audit_list"):
462
+ try:
463
+ with _client() as c:
464
+ r = _get_audit_logs(
465
+ c,
466
+ limit=int(a_limit),
467
+ offset=int(a_offset),
468
+ )
469
+ r.raise_for_status()
470
+ payload = r.json()
471
+ events = payload.get("logs", payload.get("events", []))
472
+ st.caption(f"Total matching: **{payload.get('total', len(events))}**")
473
+ if events:
474
+ st.dataframe(events, hide_index=True, use_container_width=True)
475
+ ids = [
476
+ e.get("query_id") or e.get("event_id")
477
+ for e in events
478
+ if isinstance(e, dict) and (e.get("query_id") or e.get("event_id"))
479
+ ]
480
+ if ids:
481
+ st.session_state["_audit_ids"] = ids
482
+ else:
483
+ st.info("No audit events.")
484
+ except httpx.HTTPStatusError as e:
485
+ st.error(_fmt_api_error(e))
486
+ except httpx.RequestError as e:
487
+ st.error(_fmt_request_error(e))
488
+ except Exception as e:
489
+ st.exception(e)
490
+
491
+ st.subheader("Audit event detail")
492
+ ids_for_select = st.session_state.get("_audit_ids", [])
493
+ pick = ""
494
+ if ids_for_select:
495
+ pick = st.selectbox("Event ID", options=[""] + list(ids_for_select), key="audit_pick")
496
+ manual_id = st.text_input("Or enter query / event ID", key="audit_manual")
497
+ ev_id = (manual_id.strip() or (pick or "").strip()).strip()
498
+ if st.button("Load detail", key="btn_audit_detail") and ev_id:
499
+ try:
500
+ with _client() as c:
501
+ r = _get_audit_event_detail(c, ev_id)
502
+ r.raise_for_status()
503
+ st.json(r.json())
504
+ except httpx.HTTPStatusError as e:
505
+ st.error(_fmt_api_error(e))
506
+ except httpx.RequestError as e:
507
+ st.error(_fmt_request_error(e))
508
+ except Exception as e:
509
+ st.exception(e)
510
+
511
+
512
+ if __name__ == "__main__":
513
+ main()
tests/conftest.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pytest fixtures: isolated temp DB/Chroma paths and a patched FastAPI test client."""
2
+
3
+ import sys
4
+ from pathlib import Path
5
+
6
+ import pytest
7
+ from fastapi.testclient import TestClient
8
+
9
+ PROJECT_ROOT = Path(__file__).resolve().parents[1]
10
+ if str(PROJECT_ROOT) not in sys.path:
11
+ sys.path.insert(0, str(PROJECT_ROOT))
12
+
13
+ from api.config import Settings
14
+ from api.main import app
15
+
16
+
17
+ @pytest.fixture
18
+ def test_settings(tmp_path) -> Settings:
19
+ return Settings(
20
+ llm_provider="ollama",
21
+ chroma_persist_directory=str(tmp_path / "chroma"),
22
+ audit_db_path=str(tmp_path / "audit.db"),
23
+ jobs_db_path=str(tmp_path / "jobs.db"),
24
+ max_file_size_mb=1,
25
+ top_k_results=3,
26
+ )
27
+
28
+
29
+ @pytest.fixture
30
+ def settings(test_settings) -> Settings:
31
+ """Alias for audit tests that name the fixture `settings`."""
32
+ return test_settings
33
+
34
+
35
+ @pytest.fixture
36
+ def client(test_settings, monkeypatch):
37
+ monkeypatch.setattr("api.main.get_settings", lambda: test_settings)
38
+ for route_mod in ("ingest", "query", "audit", "jobs"):
39
+ monkeypatch.setattr(f"api.routes.{route_mod}.get_settings", lambda ts=test_settings: ts)
40
+ with TestClient(app) as test_client:
41
+ yield test_client
tests/test_audit.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for audit log list, detail, filters, and post-query persistence."""
2
+
3
+ import asyncio
4
+ from unittest.mock import AsyncMock
5
+ from uuid import uuid4
6
+
7
+ import pytest
8
+ from fastapi.testclient import TestClient
9
+
10
+ from api.config import Settings
11
+ from api.main import app
12
+ from models.responses import SourceCitation
13
+ from rag.retriever import RetrievedChunk
14
+ from storage.audit_store import persist_query_audit
15
+
16
+
17
+ def _seed_audit(settings: Settings, question: str = "What are key risks?", user_id: str = "analyst_001") -> str:
18
+ query_id = str(uuid4())
19
+ asyncio.run(
20
+ persist_query_audit(
21
+ settings.audit_db_path,
22
+ query_id=query_id,
23
+ action="query",
24
+ user_id=user_id,
25
+ question=question,
26
+ collection_name="default",
27
+ answer="Grounded answer text for audit trail.",
28
+ sources=[
29
+ SourceCitation(
30
+ document_name="report.pdf",
31
+ page_number=3,
32
+ chunk_text="Risk disclosure excerpt.",
33
+ relevance_score=0.9,
34
+ )
35
+ ],
36
+ model_used="ollama:llama3.1:8b",
37
+ tokens_used=120,
38
+ response_time_ms=50,
39
+ kind="ask",
40
+ )
41
+ )
42
+ return query_id
43
+
44
+
45
+ def test_audit_logs_and_detail_success(client, settings):
46
+ query_id = _seed_audit(settings)
47
+
48
+ list_response = client.get("/audit/logs?limit=10&offset=0")
49
+ assert list_response.status_code == 200
50
+ body = list_response.json()
51
+ assert "logs" in body
52
+ assert body["total"] >= 1
53
+ assert any(entry["query_id"] == query_id for entry in body["logs"])
54
+
55
+ detail_response = client.get(f"/audit/logs/{query_id}")
56
+ assert detail_response.status_code == 200
57
+ detail = detail_response.json()
58
+ assert detail["query_id"] == query_id
59
+ assert detail["question"] == "What are key risks?"
60
+ assert detail["full_answer"] == "Grounded answer text for audit trail."
61
+ assert len(detail["sources"]) == 1
62
+ assert detail["sources"][0]["document_name"] == "report.pdf"
63
+
64
+
65
+ def test_audit_logs_filter_by_user_id(client, settings):
66
+ q1 = _seed_audit(settings, question="Q one", user_id="user_a")
67
+ _seed_audit(settings, question="Q two", user_id="user_b")
68
+
69
+ r = client.get("/audit/logs", params={"user_id": "user_a", "limit": 50, "offset": 0})
70
+ assert r.status_code == 200
71
+ body = r.json()
72
+ ids = {e["query_id"] for e in body["logs"]}
73
+ assert q1 in ids
74
+ assert all(e["user_id"] == "user_a" for e in body["logs"])
75
+
76
+
77
+ def test_audit_logs_filter_by_from_date(client, settings):
78
+ query_id = str(uuid4())
79
+ asyncio.run(
80
+ persist_query_audit(
81
+ settings.audit_db_path,
82
+ query_id=query_id,
83
+ action="query",
84
+ user_id="u",
85
+ question="Future dated row",
86
+ collection_name="default",
87
+ answer="A",
88
+ sources=[],
89
+ model_used="m",
90
+ tokens_used=0,
91
+ response_time_ms=1,
92
+ kind="ask",
93
+ )
94
+ )
95
+ r = client.get("/audit/logs", params={"from_date": "2099-01-01T00:00:00Z", "limit": 50, "offset": 0})
96
+ assert r.status_code == 200
97
+ body = r.json()
98
+ assert query_id not in {e["query_id"] for e in body["logs"]}
99
+
100
+
101
+ def test_audit_logs_filter_by_to_date(client, settings):
102
+ """Spec: date filtering on /audit/logs (upper bound)."""
103
+ query_id = str(uuid4())
104
+ asyncio.run(
105
+ persist_query_audit(
106
+ settings.audit_db_path,
107
+ query_id=query_id,
108
+ action="query",
109
+ user_id="u",
110
+ question="Recent row",
111
+ collection_name="default",
112
+ answer="B",
113
+ sources=[],
114
+ model_used="m",
115
+ tokens_used=0,
116
+ response_time_ms=1,
117
+ kind="ask",
118
+ )
119
+ )
120
+ r = client.get("/audit/logs", params={"to_date": "2000-01-01T00:00:00Z", "limit": 50, "offset": 0})
121
+ assert r.status_code == 200
122
+ body = r.json()
123
+ assert query_id not in {e["query_id"] for e in body["logs"]}
124
+
125
+
126
+ def test_ask_is_logged_after_query_ask(client, monkeypatch):
127
+ """Spec: ask is logged after POST /query/ask."""
128
+ chunks = [
129
+ RetrievedChunk(
130
+ text="Audit trail test chunk.",
131
+ score=0.9,
132
+ source="audit-test.txt",
133
+ page=1,
134
+ chunk_index=0,
135
+ )
136
+ ]
137
+ monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
138
+ monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
139
+ monkeypatch.setattr("api.routes.query.retrieve_chunks", lambda *_: chunks)
140
+ monkeypatch.setattr(
141
+ "api.routes.query.answer_with_grounding",
142
+ lambda *_: ("Answer stored in audit.", 11),
143
+ )
144
+
145
+ ask = client.post(
146
+ "/query/ask",
147
+ json={
148
+ "question": "What should appear in the audit log?",
149
+ "collection_name": "default",
150
+ "user_id": "audit_user",
151
+ },
152
+ )
153
+ assert ask.status_code == 200
154
+ query_id = ask.json()["query_id"]
155
+
156
+ detail = client.get(f"/audit/logs/{query_id}")
157
+ assert detail.status_code == 200
158
+ body = detail.json()
159
+ assert body["user_id"] == "audit_user"
160
+ assert body["full_answer"] == "Answer stored in audit."
161
+ assert body["question"] == "What should appear in the audit log?"
162
+
163
+
164
+ def test_summarise_is_logged_after_query_summarise(client, monkeypatch):
165
+ """Spec: summarise is logged after POST /query/summarise."""
166
+ chunks = [
167
+ RetrievedChunk(
168
+ text="Summary source chunk.",
169
+ score=0.85,
170
+ source="summary.md",
171
+ page=2,
172
+ chunk_index=0,
173
+ )
174
+ ]
175
+ monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
176
+ monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
177
+ monkeypatch.setattr("api.routes.query.retrieve_chunks", lambda *_: chunks)
178
+ monkeypatch.setattr(
179
+ "api.routes.query.summarise_with_grounding",
180
+ lambda *_, **__: ("Collection summary for audit.", 7),
181
+ )
182
+ monkeypatch.setattr("api.routes.query.collection_document_count", lambda *_: 2)
183
+
184
+ summarise = client.post(
185
+ "/query/summarise",
186
+ json={"collection_name": "default", "focus": "key themes", "user_id": "sum_user"},
187
+ )
188
+ assert summarise.status_code == 200
189
+ query_id = summarise.json()["query_id"]
190
+
191
+ detail = client.get(f"/audit/logs/{query_id}")
192
+ assert detail.status_code == 200
193
+ assert detail.json()["full_answer"] == "Collection summary for audit."
194
+ assert detail.json()["user_id"] == "sum_user"
195
+
196
+
197
+ def test_audit_logs_validation_error_for_bad_limit(client):
198
+ response = client.get("/audit/logs?limit=0&offset=0")
199
+ assert response.status_code == 422
200
+
201
+
202
+ def test_audit_detail_not_found(client):
203
+ response = client.get("/audit/logs/does-not-exist")
204
+ assert response.status_code == 404
205
+ assert "not found" in response.json()["detail"].lower()
206
+
207
+
208
+ def test_audit_logs_returns_500_on_store_failure(settings, monkeypatch):
209
+ monkeypatch.setattr("api.main.get_settings", lambda: settings)
210
+ monkeypatch.setattr("api.routes.audit.get_settings", lambda: settings)
211
+ monkeypatch.setattr(
212
+ "api.routes.audit.list_audit_events",
213
+ AsyncMock(side_effect=RuntimeError("audit store failure")),
214
+ )
215
+ with TestClient(app, raise_server_exceptions=False) as test_client:
216
+ response = test_client.get("/audit/logs")
217
+
218
+ assert response.status_code == 500
tests/test_config.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Settings behaviour for Hugging Face Spaces and Hub tokens."""
2
+
3
+ from api.config import Settings
4
+
5
+
6
+ def test_space_id_without_llm_provider_env_uses_huggingface_and_hf_token(monkeypatch):
7
+ monkeypatch.setenv("SPACE_ID", "author/repo")
8
+ monkeypatch.delenv("LLM_PROVIDER", raising=False)
9
+ monkeypatch.delenv("HUGGINGFACE_API_KEY", raising=False)
10
+ monkeypatch.setenv("HF_TOKEN", "hf_test_token")
11
+ s = Settings(_env_file=None)
12
+ assert s.llm_provider == "huggingface"
13
+ assert s.huggingface_api_key == "hf_test_token"
14
+
15
+
16
+ def test_space_id_respects_explicit_llm_provider_ollama(monkeypatch):
17
+ monkeypatch.setenv("SPACE_ID", "author/repo")
18
+ monkeypatch.setenv("LLM_PROVIDER", "ollama")
19
+ monkeypatch.delenv("HUGGINGFACE_API_KEY", raising=False)
20
+ s = Settings(_env_file=None)
21
+ assert s.llm_provider == "ollama"
tests/test_health.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ """Smoke test for the liveness endpoint."""
2
+
3
+ def test_health_returns_ok(client):
4
+ response = client.get("/health")
5
+ assert response.status_code == 200
6
+ body = response.json()
7
+ assert body["status"] == "ok"
8
+ assert "app" in body
9
+ assert "version" in body
tests/test_ingest.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for ``/ingest`` upload, URL ingest, and collection management."""
2
+
3
+ import asyncio
4
+ from unittest.mock import AsyncMock
5
+
6
+ from api.routes import ingest as ingest_route
7
+ from storage.job_store import create_ingest_job, mark_job_processing
8
+
9
+
10
+ def test_upload_queues_job_success(client, monkeypatch):
11
+ monkeypatch.setattr("api.routes.ingest.create_ingest_job", AsyncMock(return_value="job-123"))
12
+ monkeypatch.setattr("api.routes.ingest.run_ingest_job", AsyncMock(return_value=None))
13
+
14
+ response = client.post(
15
+ "/ingest/upload",
16
+ data={"collection_name": "default"},
17
+ files=[("files", ("sample.txt", b"hello world", "text/plain"))],
18
+ )
19
+
20
+ assert response.status_code == 200
21
+ body = response.json()
22
+ assert body["status"] == "queued"
23
+ assert body["job_id"] == "job-123"
24
+ assert body["total_files"] == 1
25
+ assert body["filenames"] == ["sample.txt"]
26
+ assert "Poll /jobs/job-123" in body["message"]
27
+
28
+
29
+ def test_upload_rejects_unsupported_extension(client):
30
+ response = client.post(
31
+ "/ingest/upload",
32
+ data={"collection_name": "default"},
33
+ files=[("files", ("sample.csv", b"a,b\n1,2", "text/csv"))],
34
+ )
35
+
36
+ assert response.status_code == 400
37
+ assert "Unsupported file type" in response.json()["detail"]
38
+
39
+
40
+ def test_upload_rejects_oversized_file(client):
41
+ oversized = b"x" * (2 * 1024 * 1024)
42
+ response = client.post(
43
+ "/ingest/upload",
44
+ data={"collection_name": "default"},
45
+ files=[("files", ("large.txt", oversized, "text/plain"))],
46
+ )
47
+
48
+ assert response.status_code == 413
49
+ assert "too large" in response.json()["detail"].lower()
50
+
51
+
52
+ def test_upload_returns_500_on_job_creation_error(client, monkeypatch):
53
+ monkeypatch.setattr(
54
+ "api.routes.ingest.create_ingest_job",
55
+ AsyncMock(side_effect=RuntimeError("job store unavailable")),
56
+ )
57
+ monkeypatch.setattr("api.routes.ingest.run_ingest_job", AsyncMock(return_value=None))
58
+
59
+ response = client.post(
60
+ "/ingest/upload",
61
+ data={"collection_name": "default"},
62
+ files=[("files", ("sample.txt", b"hello", "text/plain"))],
63
+ )
64
+
65
+ assert response.status_code == 500
66
+ assert "job store unavailable" in response.json()["detail"]
67
+
68
+
69
+ def test_download_request_headers_sec_compliant():
70
+ headers = ingest_route._download_request_headers("DocuAudit AI test@example.com")
71
+ assert headers["User-Agent"] == "DocuAudit AI test@example.com"
72
+ assert headers["Accept-Encoding"] == "gzip, deflate"
73
+ assert "application/pdf" in headers["Accept"]
74
+
75
+
76
+ def test_ingest_url_rejects_non_http_scheme(client, monkeypatch):
77
+ monkeypatch.setattr(
78
+ "api.routes.ingest._download_url_to_temp",
79
+ AsyncMock(
80
+ side_effect=ingest_route.HTTPException(status_code=400, detail="Only http and https URLs are supported.")
81
+ ),
82
+ )
83
+
84
+ response = client.post(
85
+ "/ingest/url",
86
+ json={"urls": ["https://example.com/file.txt"], "collection_name": "default"},
87
+ )
88
+
89
+ assert response.status_code == 400
90
+ assert "http and https" in response.json()["detail"]
91
+
92
+
93
+ def test_upload_pdf_queues_job_with_job_id(client, monkeypatch):
94
+ """Spec: single PDF upload returns job_id."""
95
+ monkeypatch.setattr("api.routes.ingest.create_ingest_job", AsyncMock(return_value="pdf-job-99"))
96
+ monkeypatch.setattr("api.routes.ingest.run_ingest_job", AsyncMock(return_value=None))
97
+
98
+ response = client.post(
99
+ "/ingest/upload",
100
+ data={"collection_name": "default"},
101
+ files=[("files", ("brief.pdf", b"%PDF-1.4 minimal", "application/pdf"))],
102
+ )
103
+
104
+ assert response.status_code == 200
105
+ body = response.json()
106
+ assert body["job_id"] == "pdf-job-99"
107
+ assert body["filenames"] == ["brief.pdf"]
108
+
109
+
110
+ def test_list_collections_backfills_created_at_from_jobs(client, test_settings, monkeypatch):
111
+ monkeypatch.setattr(
112
+ "api.routes.ingest.list_collection_names",
113
+ lambda *_: ["default"],
114
+ )
115
+ monkeypatch.setattr("api.routes.ingest.collection_document_count", lambda *_: 3)
116
+ monkeypatch.setattr("api.routes.ingest.collection_created_at", lambda *_: None)
117
+ monkeypatch.setattr(
118
+ "api.routes.ingest.earliest_job_created_at_for_collection",
119
+ AsyncMock(return_value="2026-05-21 07:05:38"),
120
+ )
121
+ monkeypatch.setattr(
122
+ "api.routes.ingest.ensure_collection_created_at",
123
+ lambda *_a, **_k: "2026-05-21T07:05:38Z",
124
+ )
125
+
126
+ response = client.get("/ingest/collections")
127
+ assert response.status_code == 200
128
+ body = response.json()
129
+ assert body["total"] == 1
130
+ assert body["collections"][0]["name"] == "default"
131
+ assert body["collections"][0]["document_count"] == 3
132
+ assert body["collections"][0]["created_at"] is not None
133
+
134
+
135
+ def test_job_status_polling_after_real_job_create(client, test_settings):
136
+ """Spec: job status polling returns correct structure."""
137
+ job_id = asyncio.run(
138
+ create_ingest_job(
139
+ test_settings.jobs_db_path,
140
+ collection_name="default",
141
+ filenames=["sample.txt"],
142
+ )
143
+ )
144
+ asyncio.run(mark_job_processing(test_settings.jobs_db_path, job_id))
145
+
146
+ response = client.get(f"/jobs/{job_id}")
147
+ assert response.status_code == 200
148
+ body = response.json()
149
+ assert body["job_id"] == job_id
150
+ assert body["status"] == "processing"
151
+ assert body["total_files"] == 1
152
+ assert "progress_percent" in body
153
+ assert "errors" in body
tests/test_jobs.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for ingest job listing and status endpoints."""
2
+
3
+ import asyncio
4
+
5
+ from storage.job_store import create_ingest_job, update_job_progress
6
+
7
+
8
+ def test_get_job_status_returns_spec_shape(client, test_settings):
9
+ job_id = asyncio.run(
10
+ create_ingest_job(
11
+ test_settings.jobs_db_path,
12
+ collection_name="default",
13
+ filenames=["report.pdf", "notes.txt"],
14
+ )
15
+ )
16
+ asyncio.run(
17
+ update_job_progress(
18
+ test_settings.jobs_db_path,
19
+ job_id,
20
+ processed_files=1,
21
+ failed_files=0,
22
+ errors=[],
23
+ message="Processing first file",
24
+ )
25
+ )
26
+
27
+ response = client.get(f"/jobs/{job_id}")
28
+ assert response.status_code == 200
29
+ body = response.json()
30
+ assert body["job_id"] == job_id
31
+ assert body["status"] in ("queued", "processing", "completed", "failed")
32
+ assert body["total_files"] == 2
33
+ assert body["processed_files"] == 1
34
+ assert body["failed_files"] == 0
35
+ assert 0 <= body["progress_percent"] <= 100
36
+ assert isinstance(body["errors"], list)
37
+
38
+
39
+ def test_list_jobs_includes_total(client, test_settings):
40
+ job_id = asyncio.run(
41
+ create_ingest_job(
42
+ test_settings.jobs_db_path,
43
+ collection_name="default",
44
+ filenames=["sample.txt"],
45
+ )
46
+ )
47
+
48
+ response = client.get("/jobs", params={"limit": 10, "offset": 0})
49
+ assert response.status_code == 200
50
+ body = response.json()
51
+ assert body["total"] >= 1
52
+ assert any(j["job_id"] == job_id for j in body["jobs"])
53
+
54
+
55
+ def test_get_job_not_found_returns_404(client):
56
+ response = client.get("/jobs/nonexistent-job-id")
57
+ assert response.status_code == 404
58
+ assert "not found" in response.json()["detail"].lower()
tests/test_query.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for ``/query/ask``, ``/query/summarise``, and legacy ``POST /query``."""
2
+
3
+ from unittest.mock import AsyncMock
4
+
5
+ from rag.retriever import NO_MATCH_ANSWER, RetrievedChunk
6
+
7
+
8
+ def test_ask_returns_grounded_answer_with_sources(client, monkeypatch):
9
+ chunks = [
10
+ RetrievedChunk(
11
+ text="Audi has strategic EV expansion plans.",
12
+ score=0.92,
13
+ source="strategy.md",
14
+ page=1,
15
+ chunk_index=0,
16
+ )
17
+ ]
18
+ monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
19
+ monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
20
+ monkeypatch.setattr("api.routes.query.retrieve_chunks", lambda *_: chunks)
21
+ monkeypatch.setattr("api.routes.query.answer_with_grounding", lambda *_: ("Audi is expanding EV investment.", 42))
22
+ monkeypatch.setattr("api.routes.query.persist_query_audit", AsyncMock(return_value="evt-1"))
23
+
24
+ response = client.post(
25
+ "/query/ask",
26
+ json={
27
+ "question": "What is Audi doing in EV markets worldwide?",
28
+ "collection_name": "default",
29
+ "top_k": 3,
30
+ "user_id": "tester",
31
+ },
32
+ )
33
+
34
+ assert response.status_code == 200
35
+ body = response.json()
36
+ assert body["answer"] == "Audi is expanding EV investment."
37
+ assert "query_id" in body
38
+ assert body["question"].startswith("What is Audi")
39
+ assert len(body["sources"]) == 1
40
+ assert body["sources"][0]["document_name"] == "strategy.md"
41
+ assert body["sources"][0]["page_number"] == 1
42
+ assert body["tokens_used"] == 42
43
+ assert "response_time_ms" in body
44
+ assert "model_used" in body
45
+
46
+
47
+ def test_ask_respects_top_k_in_retrieve_call(client, monkeypatch):
48
+ captured: dict[str, object] = {}
49
+
50
+ def capture_retrieve(vs, question, k):
51
+ captured["k"] = k
52
+ return []
53
+
54
+ monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
55
+ monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
56
+ monkeypatch.setattr("api.routes.query.retrieve_chunks", capture_retrieve)
57
+ monkeypatch.setattr("api.routes.query.answer_with_grounding", lambda *_: ("No match answer", 0))
58
+ monkeypatch.setattr("api.routes.query.persist_query_audit", AsyncMock())
59
+
60
+ response = client.post(
61
+ "/query/ask",
62
+ json={"question": "What is known about the topic here?", "collection_name": "default", "top_k": 7},
63
+ )
64
+ assert response.status_code == 200
65
+ assert captured.get("k") == 7
66
+
67
+
68
+ def test_ask_empty_collection_returns_no_match_message(client, monkeypatch):
69
+ """Spec: query on empty collection returns appropriate message."""
70
+ monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
71
+ monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
72
+ monkeypatch.setattr("api.routes.query.retrieve_chunks", lambda *_: [])
73
+ monkeypatch.setattr("api.routes.query.persist_query_audit", AsyncMock())
74
+
75
+ response = client.post(
76
+ "/query/ask",
77
+ json={
78
+ "question": "What does the document say about revenue?",
79
+ "collection_name": "default",
80
+ "top_k": 5,
81
+ },
82
+ )
83
+
84
+ assert response.status_code == 200
85
+ assert response.json()["answer"] == NO_MATCH_ANSWER
86
+ assert response.json()["sources"] == []
87
+
88
+
89
+ def test_ask_low_relevance_chunks_returns_no_match_message(client, monkeypatch):
90
+ low_score_chunks = [
91
+ RetrievedChunk(
92
+ text="Unrelated fragment.",
93
+ score=0.05,
94
+ source="noise.txt",
95
+ page=1,
96
+ chunk_index=0,
97
+ )
98
+ ]
99
+ monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
100
+ monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
101
+ monkeypatch.setattr("api.routes.query.retrieve_chunks", lambda *_: low_score_chunks)
102
+ monkeypatch.setattr("api.routes.query.persist_query_audit", AsyncMock())
103
+
104
+ response = client.post(
105
+ "/query/ask",
106
+ json={"question": "What are the key risk factors?", "collection_name": "default"},
107
+ )
108
+
109
+ assert response.status_code == 200
110
+ assert response.json()["answer"] == NO_MATCH_ANSWER
111
+
112
+
113
+ def test_ask_returns_422_for_invalid_payload(client):
114
+ response = client.post("/query/ask", json={"collection_name": "default"})
115
+ assert response.status_code == 422
116
+
117
+
118
+ def test_ask_returns_422_for_short_question(client):
119
+ response = client.post(
120
+ "/query/ask",
121
+ json={"question": "hi", "collection_name": "default"},
122
+ )
123
+ assert response.status_code == 422
124
+
125
+
126
+ def test_ask_returns_500_when_retrieval_fails(client, monkeypatch):
127
+ monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
128
+ monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
129
+ monkeypatch.setattr("api.routes.query.retrieve_chunks", lambda *_: (_ for _ in ()).throw(RuntimeError("retrieval failed")))
130
+
131
+ response = client.post(
132
+ "/query/ask",
133
+ json={"question": "What happened in the documents?", "collection_name": "default"},
134
+ )
135
+
136
+ assert response.status_code == 500
137
+ assert "retrieval failed" in response.json()["detail"]
138
+
139
+
140
+ def test_summarise_returns_summary_payload(client, monkeypatch):
141
+ """Spec: /query/summarise returns summary payload when collection has documents."""
142
+ chunks = [
143
+ RetrievedChunk(
144
+ text="Revenue grew year over year.",
145
+ score=0.9,
146
+ source="report.txt",
147
+ page=2,
148
+ chunk_index=0,
149
+ )
150
+ ]
151
+ monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
152
+ monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
153
+ monkeypatch.setattr("api.routes.query.retrieve_chunks", lambda *_: chunks)
154
+ monkeypatch.setattr("api.routes.query.summarise_with_grounding", lambda *_, **__: ("Executive summary text.", 25))
155
+ monkeypatch.setattr("api.routes.query.collection_document_count", lambda *_: 3)
156
+ monkeypatch.setattr("api.routes.query.persist_query_audit", AsyncMock())
157
+
158
+ response = client.post(
159
+ "/query/summarise",
160
+ json={"collection_name": "default", "focus": "financial highlights", "user_id": "analyst"},
161
+ )
162
+
163
+ assert response.status_code == 200
164
+ body = response.json()
165
+ assert body["summary"] == "Executive summary text."
166
+ assert body["document_count"] == 3
167
+ assert "query_id" in body
168
+ assert len(body["sources"]) == 1
169
+ assert body["sources"][0]["document_name"] == "report.txt"
170
+
171
+
172
+ def test_summarise_returns_500_when_audit_persist_fails(client, monkeypatch):
173
+ chunks = [
174
+ RetrievedChunk(
175
+ text="Revenue and risks are discussed in the report.",
176
+ score=0.88,
177
+ source="report.txt",
178
+ page=None,
179
+ chunk_index=2,
180
+ )
181
+ ]
182
+ monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
183
+ monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
184
+ monkeypatch.setattr("api.routes.query.retrieve_chunks", lambda *_: chunks)
185
+ monkeypatch.setattr("api.routes.query.summarise_with_grounding", lambda *_, **__: ("Summary output", 10))
186
+ monkeypatch.setattr("api.routes.query.collection_document_count", lambda *_: 5)
187
+ monkeypatch.setattr(
188
+ "api.routes.query.persist_query_audit",
189
+ AsyncMock(side_effect=RuntimeError("audit write failed")),
190
+ )
191
+
192
+ response = client.post(
193
+ "/query/summarise",
194
+ json={"collection_name": "default", "focus": "summarise risks", "user_id": "u1"},
195
+ )
196
+
197
+ assert response.status_code == 500
198
+ assert "audit write failed" in response.json()["detail"]
199
+
200
+
201
+ def test_legacy_query_endpoint_matches_ask(client, monkeypatch):
202
+ chunks = [
203
+ RetrievedChunk(
204
+ text="Clause about indemnity.",
205
+ score=0.8,
206
+ source="contract.md",
207
+ page=4,
208
+ chunk_index=1,
209
+ )
210
+ ]
211
+ monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
212
+ monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
213
+ monkeypatch.setattr("api.routes.query.retrieve_chunks", lambda *_: chunks)
214
+ monkeypatch.setattr("api.routes.query.answer_with_grounding", lambda *_: ("Indemnity is capped.", 5))
215
+ monkeypatch.setattr("api.routes.query.persist_query_audit", AsyncMock())
216
+
217
+ payload = {
218
+ "question": "What are the indemnity limits in the contract?",
219
+ "collection_name": "default",
220
+ "top_k": 3,
221
+ }
222
+ ask = client.post("/query/ask", json=payload)
223
+ legacy = client.post("/query", json=payload)
224
+
225
+ assert ask.status_code == 200
226
+ assert legacy.status_code == 200
227
+ assert legacy.json()["answer"] == ask.json()["answer"]
228
+ assert "query_id" in legacy.json()
229
+ assert legacy.json()["sources"][0]["document_name"] == "contract.md"
uv.lock ADDED
The diff for this file is too large to render. See raw diff
 
workers/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Background workers (ingest pipeline)."""
workers/ingest_worker.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Background ingest worker invoked from FastAPI ``BackgroundTasks``.
2
+
3
+ For each temp file: load → chunk → embed → add to Chroma, then update job progress in SQLite.
4
+ Temp files are always deleted in a ``finally`` block.
5
+ """
6
+
7
+ import asyncio
8
+ from pathlib import Path
9
+
10
+ from rag.chunker import chunk_documents
11
+ from rag.embedder import create_embedding_function
12
+ from rag.loader import load_documents
13
+ from rag.vector_store import add_documents, get_vector_store
14
+ from storage.job_store import (
15
+ complete_ingest_job,
16
+ fail_ingest_job,
17
+ mark_job_processing,
18
+ update_job_progress,
19
+ )
20
+
21
+
22
+ def _ingest_one_file_sync(temp_path: str, collection_name: str, chroma_persist_directory: str) -> tuple[list[str], int]:
23
+ """Blocking ingest for one path; returns ``(chunk_vector_ids, chunk_count)``."""
24
+ documents = load_documents(temp_path)
25
+ chunks = chunk_documents(documents)
26
+ if not chunks:
27
+ raise ValueError("No content to ingest.")
28
+ embedding_function = create_embedding_function()
29
+ vector_store = get_vector_store(
30
+ persist_directory=chroma_persist_directory,
31
+ collection_name=collection_name,
32
+ embedding_function=embedding_function,
33
+ )
34
+ document_ids = add_documents(vector_store, chunks)
35
+ return document_ids, len(chunks)
36
+
37
+
38
+ async def run_ingest_job(
39
+ job_id: str,
40
+ files: list[tuple[str, str]],
41
+ collection_name: str,
42
+ jobs_db_path: str,
43
+ chroma_persist_directory: str,
44
+ ) -> None:
45
+ """
46
+ Process one or more temp files for a single job. ``files`` is (temp_path, display_name).
47
+ """
48
+ all_doc_ids: list[str] = []
49
+ errors: list[str] = []
50
+ processed = 0
51
+ failed = 0
52
+ total = len(files)
53
+ if total == 0:
54
+ await fail_ingest_job(jobs_db_path, job_id, message="No files to ingest.")
55
+ return
56
+
57
+ try:
58
+ await mark_job_processing(jobs_db_path, job_id)
59
+ for temp_path, display_name in files:
60
+ try:
61
+ doc_ids, num_chunks = await asyncio.to_thread(
62
+ _ingest_one_file_sync,
63
+ temp_path,
64
+ collection_name,
65
+ chroma_persist_directory,
66
+ )
67
+ all_doc_ids.extend(doc_ids)
68
+ processed += 1
69
+ await update_job_progress(
70
+ jobs_db_path,
71
+ job_id,
72
+ processed_files=processed,
73
+ failed_files=failed,
74
+ errors=errors,
75
+ message=f"Ingested {display_name} ({num_chunks} chunks).",
76
+ )
77
+ except Exception as exc:
78
+ failed += 1
79
+ errors.append(f"{display_name}: {exc}")
80
+ await update_job_progress(
81
+ jobs_db_path,
82
+ job_id,
83
+ processed_files=processed,
84
+ failed_files=failed,
85
+ errors=errors,
86
+ message=f"Failed on {display_name}: {exc}",
87
+ )
88
+ finally:
89
+ Path(temp_path).unlink(missing_ok=True)
90
+
91
+ if processed == 0:
92
+ await fail_ingest_job(
93
+ jobs_db_path,
94
+ job_id,
95
+ message="All files failed ingestion.",
96
+ errors=errors,
97
+ )
98
+ return
99
+
100
+ chunk_note = f"{len(all_doc_ids)} chunk vector(s) across {processed} file(s)."
101
+ await complete_ingest_job(
102
+ jobs_db_path,
103
+ job_id,
104
+ document_ids=all_doc_ids,
105
+ message=f"Ingestion completed. {chunk_note}",
106
+ )
107
+ except Exception as exc:
108
+ await fail_ingest_job(jobs_db_path, job_id, message=str(exc), errors=errors + [str(exc)])