Mayank Chugh committed on
Commit ·
d44b33d
0
Parent(s):
Deploy DocuAudit AI to Hugging Face Space (no binaries)
Browse files- .dockerignore +29 -0
- .env.example +99 -0
- .python-version +1 -0
- Dockerfile +29 -0
- LICENSE +201 -0
- README.md +186 -0
- api/__init__.py +1 -0
- api/config.py +135 -0
- api/main.py +64 -0
- api/routes/__init__.py +1 -0
- api/routes/audit.py +65 -0
- api/routes/ingest.py +348 -0
- api/routes/jobs.py +47 -0
- api/routes/query.py +179 -0
- app.py +117 -0
- docker-compose.yml +67 -0
- main.py +13 -0
- models/__init__.py +1 -0
- models/requests.py +78 -0
- models/responses.py +135 -0
- pyproject.toml +34 -0
- pytest.ini +3 -0
- rag/__init__.py +6 -0
- rag/chunker.py +28 -0
- rag/embedder.py +44 -0
- rag/hf_hub_inference.py +380 -0
- rag/loader.py +35 -0
- rag/retriever.py +218 -0
- rag/vector_store.py +125 -0
- requirements.txt +23 -0
- sample.txt +16 -0
- storage/__init__.py +1 -0
- storage/audit_store.py +295 -0
- storage/job_store.py +309 -0
- streamlit_app.py +513 -0
- tests/conftest.py +41 -0
- tests/test_audit.py +218 -0
- tests/test_config.py +21 -0
- tests/test_health.py +9 -0
- tests/test_ingest.py +153 -0
- tests/test_jobs.py +58 -0
- tests/test_query.py +229 -0
- uv.lock +0 -0
- workers/__init__.py +1 -0
- workers/ingest_worker.py +108 -0
.dockerignore
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Markdown: omit from build context except repo root README.md (Dockerfile COPY README.md).
|
| 2 |
+
*.md
|
| 3 |
+
**/*.md
|
| 4 |
+
!README.md
|
| 5 |
+
|
| 6 |
+
.git
|
| 7 |
+
.gitignore
|
| 8 |
+
.env
|
| 9 |
+
.venv
|
| 10 |
+
venv
|
| 11 |
+
__pycache__
|
| 12 |
+
*.py[cod]
|
| 13 |
+
*$py.class
|
| 14 |
+
.pytest_cache
|
| 15 |
+
.mypy_cache
|
| 16 |
+
.ruff_cache
|
| 17 |
+
*.egg-info
|
| 18 |
+
dist
|
| 19 |
+
build
|
| 20 |
+
.coverage
|
| 21 |
+
htmlcov
|
| 22 |
+
.DS_Store
|
| 23 |
+
docs
|
| 24 |
+
tests
|
| 25 |
+
.cursor
|
| 26 |
+
terminals
|
| 27 |
+
*.log
|
| 28 |
+
data/chroma
|
| 29 |
+
chroma
|
.env.example
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DocuAudit AI — environment template (see docs/DOCUAUDIT_AI_REQUIREMENTS.md)
|
| 2 |
+
|
| 3 |
+
# LLM Provider: ollama | anthropic | openai | huggingface
|
| 4 |
+
LLM_PROVIDER=ollama
|
| 5 |
+
|
| 6 |
+
# OpenAI (optional)
|
| 7 |
+
OPENAI_API_KEY=
|
| 8 |
+
OPENAI_MODEL=gpt-4o
|
| 9 |
+
OPENAI_EMBEDDING_MODEL=text-embedding-3-small
|
| 10 |
+
|
| 11 |
+
# Anthropic (optional)
|
| 12 |
+
ANTHROPIC_API_KEY=
|
| 13 |
+
ANTHROPIC_MODEL=claude-3-5-sonnet-20241022
|
| 14 |
+
|
| 15 |
+
# Hugging Face Inference API (when LLM_PROVIDER=huggingface — typical on Hugging Face Spaces)
|
| 16 |
+
# Use a fine-grained token with "Make calls to Inference Providers" / Inference API where required.
|
| 17 |
+
HUGGINGFACE_API_KEY=
|
| 18 |
+
# Use a model your Hub gates allow (e.g. Llama 3 8B under “Meta Llama 3”, or Mistral instruct). Llama 3.1 needs its own gate. Chat: hf-inference then router auto.
|
| 19 |
+
#HUGGINGFACE_MODEL=mistralai/Mistral-7B-Instruct-v0.3
|
| 20 |
+
#HUGGINGFACE_MODEL=meta-llama/Meta-Llama-3.1-8B-Instruct
|
| 21 |
+
HUGGINGFACE_MODEL=meta-llama/Meta-Llama-3-8B-Instruct
|
| 22 |
+
HUGGINGFACE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
|
| 23 |
+
# Optional: huggingface_hub InferenceClient provider. Leave unset: primary hf-inference, then router auto for chat (Mistral instruct ids also try Novita).
|
| 24 |
+
# Use `auto` for router-only primary client (may pick Novita and break some models).
|
| 25 |
+
HUGGINGFACE_INFERENCE_PROVIDER=
|
| 26 |
+
# On Hugging Face Spaces you can omit HUGGINGFACE_API_KEY if the Space provides HF_TOKEN (mapped
|
| 27 |
+
# automatically when LLM_PROVIDER=huggingface). For local .env you can set HF_TOKEN instead.
|
| 28 |
+
|
| 29 |
+
# Ollama (recommended local default)
|
| 30 |
+
OLLAMA_BASE_URL=http://localhost:11434
|
| 31 |
+
OLLAMA_CHAT_MODEL=llama3.1:8b
|
| 32 |
+
OLLAMA_EMBEDDING_MODEL=nomic-embed-text
|
| 33 |
+
|
| 34 |
+
# App
|
| 35 |
+
APP_NAME=DocuAudit AI
|
| 36 |
+
APP_VERSION=1.0.0
|
| 37 |
+
DEBUG=false
|
| 38 |
+
MAX_FILE_SIZE_MB=50
|
| 39 |
+
# Spec name alias (optional; mapped to MAX_FILE_SIZE_MB in settings)
|
| 40 |
+
MAX_UPLOAD_SIZE_MB=
|
| 41 |
+
|
| 42 |
+
# ChromaDB
|
| 43 |
+
CHROMA_PERSIST_DIRECTORY=./data/chroma
|
| 44 |
+
CHROMA_PERSIST_DIR=
|
| 45 |
+
CHROMA_COLLECTION_NAME=docuaudit_docs
|
| 46 |
+
|
| 47 |
+
# Chunking
|
| 48 |
+
CHUNK_SIZE=1000
|
| 49 |
+
CHUNK_OVERLAP=200
|
| 50 |
+
|
| 51 |
+
# Retrieval default (overridable per request on /query/ask via top_k)
|
| 52 |
+
TOP_K_RESULTS=5
|
| 53 |
+
|
| 54 |
+
# Audit + jobs SQLite
|
| 55 |
+
AUDIT_DB_PATH=./audit.db
|
| 56 |
+
JOBS_DB_PATH=./data/jobs.db
|
| 57 |
+
|
| 58 |
+
# Limits
|
| 59 |
+
MAX_DOCUMENTS_PER_BATCH=100
|
| 60 |
+
|
| 61 |
+
# URL ingest (POST /ingest/url). SEC.gov blocks undeclared bots — use "Company Name you@email.com".
|
| 62 |
+
# INGEST_USER_AGENT=DocuAudit AI you@example.com
|
| 63 |
+
|
| 64 |
+
# Streamlit → API (Streamlit process reads these when set in the shell / OS env)
|
| 65 |
+
STREAMLIT_BACKEND_URL=http://localhost:8000
|
| 66 |
+
DOC_AUDI_API_BASE=http://127.0.0.1:8000
|
| 67 |
+
# Read timeout (seconds) for Ask/Summarise HTTP calls; default in code is 3600 if unset
|
| 68 |
+
DOC_AUDI_HTTP_READ_TIMEOUT=3600
|
| 69 |
+
|
| 70 |
+
# --- Docker Compose (Milestone 12) ---
|
| 71 |
+
# Copy this file to `.env` before `docker compose up` (Compose loads `.env` for substitution and `env_file`).
|
| 72 |
+
#
|
| 73 |
+
# Persistent paths below are overridden in docker-compose.yml to a single volume mount at /data:
|
| 74 |
+
# CHROMA_PERSIST_DIRECTORY=/data/chroma, AUDIT_DB_PATH=/data/audit.db, JOBS_DB_PATH=/data/jobs.db
|
| 75 |
+
# You do not need to duplicate those in .env for compose unless you use a custom override file.
|
| 76 |
+
#
|
| 77 |
+
# Ollama from the API container cannot reach localhost on your machine; default in compose is:
|
| 78 |
+
# OLLAMA_BASE_URL=http://host.docker.internal:11434
|
| 79 |
+
# (extra_hosts host-gateway is set for Linux.) Run `ollama serve` on the host, or start the bundled
|
| 80 |
+
# Ollama service: docker compose --profile ollama up -d
|
| 81 |
+
# When using the compose `ollama` profile, set in .env:
|
| 82 |
+
# OLLAMA_BASE_URL=http://ollama:11434
|
| 83 |
+
#
|
| 84 |
+
# Compose sets DOC_AUDI_API_BASE / STREAMLIT_BACKEND_URL to http://api:8000 for the Streamlit service
|
| 85 |
+
# so server-side HTTP calls reach the API on the Docker network (do not override for UI in compose).
|
| 86 |
+
#
|
| 87 |
+
# Optional port overrides: API_PORT=8000, STREAMLIT_PORT=8501, OLLAMA_HOST_PORT=11434
|
| 88 |
+
|
| 89 |
+
# --- Hugging Face Spaces ---
|
| 90 |
+
# Recommended for CPU Spaces (no Ollama): set in Space Settings → Repository secrets → Variables
|
| 91 |
+
# LLM_PROVIDER=huggingface
|
| 92 |
+
# HUGGINGFACE_API_KEY=<token> OR rely on built-in HF_TOKEN (same value as a Hub token secret)
|
| 93 |
+
# HUGGINGFACE_MODEL / HUGGINGFACE_EMBEDDING_MODEL as needed
|
| 94 |
+
# If the API runs in a second Space or external URL, set for the Streamlit Space:
|
| 95 |
+
# DOC_AUDI_API_BASE=https://your-api....hf.space (or your FastAPI public URL)
|
| 96 |
+
# Streamlit on Spaces must listen on port 8501 (default). Entry file: app.py (see docs/HUGGING_FACE_SPACES.md).
|
| 97 |
+
# On Streamlit SDK Spaces, only Streamlit starts by default; app.py auto-starts uvicorn on 127.0.0.1:8000 when
|
| 98 |
+
# SPACE_ID is set (built-in Hub env). Set DOC_AUDI_EMBED_API=0 to disable if you use a separate API URL above.
|
| 99 |
+
# Repository secrets (HF_TOKEN / HUGGINGFACE_API_KEY) are copied from st.secrets into the API subprocess env.
|
.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.11
|
Dockerfile
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Single image for API (uvicorn) and UI (Streamlit); compose overrides the command per service.
|
| 2 |
+
FROM python:3.11-slim-bookworm
|
| 3 |
+
|
| 4 |
+
ENV PYTHONDONTWRITEBYTECODE=1 \
|
| 5 |
+
PYTHONUNBUFFERED=1 \
|
| 6 |
+
PYTHONPATH=/app \
|
| 7 |
+
PIP_NO_CACHE_DIR=1 \
|
| 8 |
+
ANONYMIZED_TELEMETRY=FALSE
|
| 9 |
+
|
| 10 |
+
WORKDIR /app
|
| 11 |
+
|
| 12 |
+
# PyMuPDF / scientific wheels are manylinux, so the only OS package installed here is ca-certificates (TLS trust store for outbound HTTPS).
|
| 13 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 14 |
+
ca-certificates \
|
| 15 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 16 |
+
|
| 17 |
+
COPY requirements.txt .
|
| 18 |
+
RUN pip install --upgrade pip && pip install -r requirements.txt
|
| 19 |
+
|
| 20 |
+
COPY api/ api/
|
| 21 |
+
COPY models/ models/
|
| 22 |
+
COPY rag/ rag/
|
| 23 |
+
COPY storage/ storage/
|
| 24 |
+
COPY workers/ workers/
|
| 25 |
+
COPY app.py streamlit_app.py main.py pyproject.toml README.md ./
|
| 26 |
+
|
| 27 |
+
EXPOSE 8000 8501
|
| 28 |
+
|
| 29 |
+
CMD ["uvicorn", "api.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
README.md
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Document-Audit RAG
|
| 3 |
+
emoji: 📑
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
+
sdk: streamlit
|
| 7 |
+
sdk_version: "1.39.0"
|
| 8 |
+
app_file: app.py
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
# DocuAudit AI
|
| 12 |
+
|
| 13 |
+
**DocuAudit AI** is a production-oriented FastAPI backend plus optional Streamlit UI for **multi-document RAG**: upload documents, build a Chroma vector index, ask grounded questions with citations, and retain a **SQLite audit trail** of every query.
|
| 14 |
+
|
| 15 |
+
## Architecture
|
| 16 |
+
|
| 17 |
+
```mermaid
|
| 18 |
+
flowchart LR
|
| 19 |
+
subgraph ingest [Ingestion]
|
| 20 |
+
A[PDF / TXT / MD] --> B[Loader]
|
| 21 |
+
B --> C[Chunker]
|
| 22 |
+
C --> D[Embedder]
|
| 23 |
+
D --> E[(ChromaDB)]
|
| 24 |
+
end
|
| 25 |
+
subgraph query [Query path]
|
| 26 |
+
Q[User question] --> R[Semantic search]
|
| 27 |
+
R --> E
|
| 28 |
+
R --> T[Top-K chunks]
|
| 29 |
+
T --> L[LLM]
|
| 30 |
+
L --> U[Answer + citations]
|
| 31 |
+
end
|
| 32 |
+
U --> V[(SQLite audit)]
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
ASCII equivalent:
|
| 36 |
+
|
| 37 |
+
```
|
| 38 |
+
PDF Upload → Parser → Chunker → Embedder → ChromaDB
|
| 39 |
+
↓
|
| 40 |
+
User Query → Semantic Search → Top-K Chunks → LLM → Answer + Citations
|
| 41 |
+
↓
|
| 42 |
+
Audit Log (SQLite)
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
## Use cases
|
| 46 |
+
|
| 47 |
+
- **Litigation document analysis** — trace claims to exact pages and filenames.
|
| 48 |
+
- **Corporate finance review** — compare disclosures and filings under a consistent audit log.
|
| 49 |
+
- **Investigation support** — bulk ingest, async jobs, and reproducible query history.
|
| 50 |
+
|
| 51 |
+
## Deploying on Hugging Face Spaces
|
| 52 |
+
|
| 53 |
+
- Set **`LLM_PROVIDER=huggingface`**; use **`HUGGINGFACE_API_KEY`** and/or the Space secret **`HF_TOKEN`** (see [`.env.example`](.env.example)).
|
| 54 |
+
- Use root **`app.py`** as the Streamlit entry for the default Hub command.
|
| 55 |
+
- Hub UI, secrets, hardware, and Streamlit SDK details: [Streamlit Spaces](https://huggingface.co/docs/hub/spaces-sdks-streamlit), [Spaces overview](https://huggingface.co/docs/hub/spaces-overview).
|
| 56 |
+
- **Test locally before deploy:** `uv run python scripts/verify_huggingface_inference.py` (requires `LLM_PROVIDER=huggingface` in `.env`).
|
| 57 |
+
|
| 58 |
+
## Quick start with Docker
|
| 59 |
+
|
| 60 |
+
Requires [Docker Engine](https://docs.docker.com/engine/) and Compose v2. The snippet below matches the shipped **`docker-compose.yml`**: API on **8000**, Streamlit on **8501**, with Chroma and SQLite under **`/data`** inside the API container. After **`docker compose up -d`**, expect **`curl http://localhost:8000/health`** to return JSON including **`"status":"ok"`**.
|
| 61 |
+
|
| 62 |
+
```bash
|
| 63 |
+
git clone <repository-url> doc-Audi-ai
|
| 64 |
+
cd doc-Audi-ai
|
| 65 |
+
cp .env.example .env
|
| 66 |
+
# edit .env as needed; for compose Ollama: OLLAMA_BASE_URL=http://ollama:11434
|
| 67 |
+
# (with host Ollama: run `ollama serve`; compose defaults to host.docker.internal:11434)
|
| 68 |
+
|
| 69 |
+
docker compose build
|
| 70 |
+
docker compose up -d
|
| 71 |
+
curl -s http://localhost:8000/health
|
| 72 |
+
# http://localhost:8501 — Streamlit
|
| 73 |
+
docker compose down
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
Optional all-in-one Ollama in Compose: `docker compose --profile ollama up -d` (then set `OLLAMA_BASE_URL=http://ollama:11434` in `.env` and recreate containers).
|
| 77 |
+
|
| 78 |
+
## How it works (user workflow)
|
| 79 |
+
|
| 80 |
+
Collections, ingestion vs querying, jobs vs audit, Streamlit tabs, and **per-button UI flows**: **[docs/USER_WORKFLOW.md](docs/USER_WORKFLOW.md)**.
|
| 81 |
+
|
| 82 |
+
## Run and test (step-by-step)
|
| 83 |
+
|
| 84 |
+
For ingestion formats, URL rules, job polling, sample `sample.txt` walkthrough, curl/PowerShell examples, and troubleshooting, see **[docs/RUN_AND_TEST_GUIDE.md](docs/RUN_AND_TEST_GUIDE.md)**.
|
| 85 |
+
|
| 86 |
+
For SQLite vs Memcached, offline DB inspection, and the Cursor **SQLite Viewer** extension (`qwtel.sqlite-viewer`), see **[docs/SQLITE_AND_DB_INSPECTION.md](docs/SQLITE_AND_DB_INSPECTION.md)**.
|
| 87 |
+
|
| 88 |
+
## Quick start (local, without Docker)
|
| 89 |
+
|
| 90 |
+
Run the API with **uv** (or your preferred tool):
|
| 91 |
+
|
| 92 |
+
```bash
|
| 93 |
+
git clone <repository-url> doc-Audi-ai
|
| 94 |
+
cd doc-Audi-ai
|
| 95 |
+
cp .env.example .env
|
| 96 |
+
uv sync
|
| 97 |
+
ollama pull llama3.1:8b
|
| 98 |
+
ollama pull nomic-embed-text
|
| 99 |
+
uv run uvicorn api.main:app --host 0.0.0.0 --port 8000 --reload
|
| 100 |
+
|
| 101 |
+
# Alternative to the command above: only watch api/ and storage/ for auto-reload
uv run uvicorn api.main:app --host 0.0.0.0 --port 8000 --reload --reload-dir api --reload-dir storage
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
Optional UI:
|
| 105 |
+
|
| 106 |
+
```bash
|
| 107 |
+
uv run streamlit run streamlit_app.py --server.port 8501 --server.address 0.0.0.0
|
| 108 |
+
```
|
| 109 |
+
|
| 110 |
+
## API overview
|
| 111 |
+
|
| 112 |
+
| Method | Path | Description |
|
| 113 |
+
|--------|------|-------------|
|
| 114 |
+
| GET | `/health` | Liveness; returns configured app name and version |
|
| 115 |
+
| POST | `/ingest/upload` | Multipart **`files`** (one or more); queues background ingest job |
|
| 116 |
+
| POST | `/ingest/url` | JSON **`urls`** array (1–100); download and queue ingest |
|
| 117 |
+
| GET | `/ingest/collections` | Lists collections with **`document_count`** and optional **`created_at`** |
|
| 118 |
+
| DELETE | `/ingest/collection/{collection_name}` | Drops a collection; returns **`documents_removed`** |
|
| 119 |
+
| GET | `/jobs` | Lists jobs with **`total`** count |
|
| 120 |
+
| GET | `/jobs/{job_id}` | Job status with **`progress_percent`**, file counters, timestamps, **`errors`** |
|
| 121 |
+
| POST | `/query/ask` | Grounded answer; request includes **`top_k`**, **`user_id`** |
|
| 122 |
+
| POST | `/query/summarise` | Collection summary; distinct response shape (`summary`, `document_count`, …) |
|
| 123 |
+
| POST | `/query` | Legacy alias of **`/query/ask`** |
|
| 124 |
+
| GET | `/audit/logs` | Filterable audit index (`user_id`, `from_date`, `to_date`, pagination) |
|
| 125 |
+
| GET | `/audit/logs/{query_id}` | Full stored answer and citations for one query |
|
| 126 |
+
|
| 127 |
+
Interactive docs: `http://localhost:8000/docs`.
|
| 128 |
+
|
| 129 |
+
## Sample request and response (`POST /query/ask`)
|
| 130 |
+
|
| 131 |
+
Request:
|
| 132 |
+
|
| 133 |
+
```json
|
| 134 |
+
{
|
| 135 |
+
"question": "What were the key risk factors identified in the Q3 2023 financial report?",
|
| 136 |
+
"collection_name": "default",
|
| 137 |
+
"top_k": 5,
|
| 138 |
+
"user_id": "analyst_001"
|
| 139 |
+
}
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
Response (shape; values depend on your documents and model):
|
| 143 |
+
|
| 144 |
+
```json
|
| 145 |
+
{
|
| 146 |
+
"query_id": "uuid-string",
|
| 147 |
+
"question": "What were the key risk factors identified in the Q3 2023 financial report?",
|
| 148 |
+
"answer": "… grounded text with citations …",
|
| 149 |
+
"sources": [
|
| 150 |
+
{
|
| 151 |
+
"document_name": "q3_financial_report.pdf",
|
| 152 |
+
"page_number": 12,
|
| 153 |
+
"chunk_text": "Key risk factors include …",
|
| 154 |
+
"relevance_score": 0.91
|
| 155 |
+
}
|
| 156 |
+
],
|
| 157 |
+
"model_used": "llama3.1:8b",
|
| 158 |
+
"tokens_used": 0,
|
| 159 |
+
"response_time_ms": 1820,
|
| 160 |
+
"timestamp": "2026-05-03T12:00:00Z"
|
| 161 |
+
}
|
| 162 |
+
```
|
| 163 |
+
|
| 164 |
+
## Design decisions
|
| 165 |
+
|
| 166 |
+
- **Source citations** — High-stakes review requires every substantive claim to be tied to **document name** and **page** (where available), not a free-floating model monologue.
|
| 167 |
+
- **Auditability** — Each ask/summarise persists **query id**, **user id**, timing, model id, token usage (when the provider exposes it), and serialized sources so regulators or counsel can reconstruct what the system returned.
|
| 168 |
+
|
| 169 |
+
## Scale note
|
| 170 |
+
|
| 171 |
+
Architecture is designed for **high-volume document ingestion** via **async background jobs** (FastAPI `BackgroundTasks`), persistent Chroma collections, and a stateless API tier that can be replicated once you add a shared vector store and job queue.
|
| 172 |
+
|
| 173 |
+
## Tests
|
| 174 |
+
|
| 175 |
+
Automated API tests use **pytest** with isolated temp databases; they do **not** require a running server or Ollama.
|
| 176 |
+
|
| 177 |
+
```bash
|
| 178 |
+
uv sync
|
| 179 |
+
uv run pytest tests/ -q
|
| 180 |
+
```
|
| 181 |
+
|
| 182 |
+
Full guide (commands, coverage by file, mocks vs manual smoke tests, troubleshooting): **[docs/TESTING.md](docs/TESTING.md)**.
|
| 183 |
+
|
| 184 |
+
## Configuration
|
| 185 |
+
|
| 186 |
+
See **`.env.example`**. Common variables include `LLM_PROVIDER`, Ollama/OpenAI/Anthropic/Hugging Face keys and models, `CHROMA_PERSIST_DIRECTORY`, `AUDIT_DB_PATH`, `JOBS_DB_PATH`, and upload limits (`MAX_FILE_SIZE_MB`; **`MAX_UPLOAD_SIZE_MB`** is accepted as an alias via settings normalization).
|
api/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""HTTP API package: FastAPI app, settings, and route modules."""
|
api/config.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Application configuration loaded from environment variables and ``.env``.
|
| 2 |
+
|
| 3 |
+
``Settings`` is the single source of truth for LLM provider choice, Chroma paths,
|
| 4 |
+
chunking limits, upload caps, and SQLite locations. Use :func:`get_settings` (cached)
|
| 5 |
+
from route handlers and RAG modules instead of reading ``os.environ`` directly.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
from functools import lru_cache
|
| 10 |
+
from typing import Any, Self
|
| 11 |
+
|
| 12 |
+
from pydantic import Field, model_validator
|
| 13 |
+
|
| 14 |
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Settings(BaseSettings):
    """Pydantic-settings model for DocuAudit AI; fields map to env vars (case-insensitive).

    Values come from the process environment and ``.env``; keys that do not match a
    field are ignored (``extra="ignore"``).
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
        case_sensitive=False,
        populate_by_name=True,
    )

    @model_validator(mode="before")
    @classmethod
    def _map_max_upload_env_alias(cls, data: Any) -> Any:
        """Accept ``max_upload_size_mb`` as an alias for ``max_file_size_mb``.

        The alias only fills in the canonical field when that field is missing or blank;
        an explicit ``max_file_size_mb`` always wins. A blank alias (``None`` or ``""``)
        is discarded so the field default applies instead of failing int validation.
        """
        if not isinstance(data, dict):
            return data
        out = dict(data)
        blank = (None, "")
        alias_value = out.pop("max_upload_size_mb", None)
        if alias_value in blank:
            # NOTE(review): as far as I can tell pydantic-settings only resolves declared
            # fields from the process environment, so an alias set via `export` / Docker /
            # Space variables never reaches `data`. Look it up directly (case-insensitive,
            # matching `case_sensitive=False`). Confirm against the pinned library version.
            alias_value = next(
                (value for key, value in os.environ.items() if key.lower() == "max_upload_size_mb"),
                None,
            )
        if alias_value not in blank and out.get("max_file_size_mb") in blank:
            out["max_file_size_mb"] = alias_value
        return out

    # --- Application metadata (used for the FastAPI title/version/description) ---
    app_name: str = Field(default="DocuAudit AI", description="FastAPI title and product name")
    app_version: str = Field(default="1.0.0", description="Application version")
    app_description: str = Field(
        default=(
            "Multi-document RAG API for high-stakes consulting environments. "
            "Every answer is grounded in source documents with full audit trails."
        ),
        description="OpenAPI /docs description",
    )

    # Compared against "ollama" / "huggingface" by the validators below.
    llm_provider: str = Field(
        default="ollama",
        description="Active LLM and embedding provider (e.g. ollama, openai, anthropic, huggingface)",
    )

    # --- OpenAI ---
    openai_api_key: str | None = Field(default=None, description="OpenAI API key")
    openai_model: str = "gpt-4o"
    openai_embedding_model: str = "text-embedding-3-small"

    # --- Anthropic ---
    anthropic_api_key: str = ""
    anthropic_model: str = "claude-3-5-sonnet-20241022"

    # --- Hugging Face Hub inference ---
    huggingface_api_key: str = ""
    huggingface_model: str = Field(
        default="meta-llama/Meta-Llama-3-8B-Instruct",
        description=(
            "HF chat model id (use a repo your Hub account already has access to; Llama 3.1 needs the "
            "separate Llama 3.1 gate). Chat tries hf-inference then router auto when unset."
        ),
    )
    huggingface_embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
    huggingface_inference_provider: str | None = Field(
        default=None,
        description=(
            "Optional huggingface_hub InferenceClient provider (e.g. hf-inference, together). "
            "Unset uses hf-inference in chat code; set to `auto` for router auto-routing."
        ),
    )

    # --- Ollama ---
    ollama_base_url: str = Field(default="http://localhost:11434", description="Ollama base URL")
    ollama_chat_model: str = "llama3.1:8b"
    ollama_embedding_model: str = "nomic-embed-text"

    # --- Chroma ---
    chroma_persist_directory: str = Field(default="./data/chroma", description="Chroma persistence path")
    # NOTE(review): second persistence-path field with a *different* default. The ingest
    # routes read `chroma_persist_directory`; confirm whether anything still reads this one
    # before the two drift apart in a deployment.
    chroma_persist_dir: str = Field(default="./chroma", description="Chroma persistence path")
    chroma_collection_name: str = "docuaudit_docs"

    # --- Chunking / retrieval ---
    chunk_size: int = Field(default=1000, ge=100, le=8000, description="Chunk size for splitting")
    chunk_overlap: int = Field(default=200, ge=0, le=2000, description="Chunk overlap for splitting")
    top_k_results: int = Field(default=5, ge=1, le=20, description="Default number of chunks to retrieve")

    # --- SQLite stores ---
    audit_db_path: str = "./audit.db"
    jobs_db_path: str = Field(default="./data/jobs.db", description="SQLite path for ingest job tracking")

    # --- Ingestion limits ---
    max_file_size_mb: int = Field(default=50, ge=1, le=200, description="Max upload file size (MB)")
    max_documents_per_batch: int = Field(default=100, ge=1, le=1000, description="Max documents per batch")
    ingest_user_agent: str = Field(
        default="DocuAudit AI docuaudit-ingest@example.com",
        description=(
            "HTTP User-Agent for POST /ingest/url downloads. SEC.gov requires "
            "'Company Name contact@email.com' with a reachable address (see sec.gov/os/accessing-edgar-data)."
        ),
    )

    @model_validator(mode="after")
    def _space_default_llm_provider(self) -> Self:
        """Hugging Face Spaces do not run Ollama locally; use Hub inference unless the user set LLM_PROVIDER."""
        # SPACE_ID unset or blank: not running inside a Space, leave the provider alone.
        if not (os.environ.get("SPACE_ID") or "").strip():
            return self
        # An explicit LLM_PROVIDER in the process environment is respected, whatever its value.
        if "LLM_PROVIDER" in os.environ:
            return self
        # Only the "ollama" value is swapped; any other provider is kept as configured.
        if self.llm_provider.lower() != "ollama":
            return self
        self.llm_provider = "huggingface"
        return self

    @model_validator(mode="after")
    def _huggingface_token_from_hub_env(self) -> Self:
        """When using the Hugging Face inference stack, accept the Hub token from standard env names.

        Spaces often expose `HF_TOKEN` (read/write per Space secrets). Map it into `huggingface_api_key`
        when `HUGGINGFACE_API_KEY` is unset so embedder/chat clients receive a token.
        """
        if self.llm_provider.lower() != "huggingface":
            return self
        if (self.huggingface_api_key or "").strip():
            return self
        # First non-blank variable wins; HF_TOKEN is checked before HUGGING_FACE_HUB_TOKEN.
        for key in ("HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"):
            token = (os.environ.get(key) or "").strip()
            if token:
                self.huggingface_api_key = token
                break
        return self
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
@lru_cache
def get_settings() -> Settings:
    """Build ``Settings`` once per process and hand back the cached instance afterwards.

    Tests reset the singleton with ``get_settings.cache_clear()``.
    """
    settings = Settings()
    return settings
|
api/main.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI application entry point for DocuAudit AI.
|
| 2 |
+
|
| 3 |
+
Creates the ASGI app, registers CORS, mounts route modules (ingest, query, jobs, audit),
|
| 4 |
+
and initializes SQLite audit and job stores on startup.
|
| 5 |
+
|
| 6 |
+
Run locally::
|
| 7 |
+
|
| 8 |
+
uv run uvicorn api.main:app --host 0.0.0.0 --port 8000 --reload
|
| 9 |
+
|
| 10 |
+
Health check: ``GET /health``.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import os
|
| 14 |
+
|
| 15 |
+
# Before any route imports that touch Chroma: disable product telemetry (avoids posthog capture() errors in logs).
|
| 16 |
+
os.environ.setdefault("ANONYMIZED_TELEMETRY", "FALSE")
|
| 17 |
+
|
| 18 |
+
from fastapi import FastAPI
|
| 19 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 20 |
+
|
| 21 |
+
from api.config import get_settings
|
| 22 |
+
from storage.audit_store import init_audit_db
|
| 23 |
+
from storage.job_store import init_jobs_db
|
| 24 |
+
from .routes import audit, ingest, jobs, query
|
| 25 |
+
|
| 26 |
+
# Settings are resolved once at import time to populate the OpenAPI metadata.
_settings = get_settings()
app = FastAPI(
    title=_settings.app_name,
    version=_settings.app_version,
    description=_settings.app_description,
)

# Wildcard origins must not be combined with credentialed requests: the FastAPI/Starlette
# CORS documentation requires explicit origins, methods, and headers whenever
# allow_credentials=True. The API stays open to any origin, so credentials are disabled.
# To support cookie/authorization-based browser clients, list their origins explicitly
# and re-enable allow_credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Route modules, registered in the same order as before.
app.include_router(audit.router)
app.include_router(ingest.router)
app.include_router(jobs.router)
app.include_router(query.router)
app.include_router(query.legacy_query_router)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@app.on_event("startup")
async def startup() -> None:
    """Initialise the audit database first, then the ingest-job database, before serving requests."""
    cfg = get_settings()
    for initialise, db_path in (
        (init_audit_db, cfg.audit_db_path),
        (init_jobs_db, cfg.jobs_db_path),
    ):
        await initialise(db_path)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@app.get("/health", tags=["Health"])
def health() -> dict[str, str]:
    """Report liveness: a fixed ``status`` of ``ok`` plus the configured app name and version."""
    cfg = get_settings()
    payload: dict[str, str] = {"status": "ok"}
    payload["app"] = cfg.app_name
    payload["version"] = cfg.app_version
    return payload
|
api/routes/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI routers grouped by domain: ingest, query, jobs, and audit."""
|
api/routes/audit.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Query audit log HTTP routes.
|
| 2 |
+
|
| 3 |
+
Every successful ask/summarise call writes to SQLite via :mod:`storage.audit_store`.
|
| 4 |
+
These endpoints expose paginated list and per-query detail for compliance review.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from typing import Annotated
|
| 8 |
+
|
| 9 |
+
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
| 10 |
+
|
| 11 |
+
from api.config import get_settings
|
| 12 |
+
from models.requests import AuditListParams
|
| 13 |
+
from models.responses import AuditLogDetailResponse, AuditLogsResponse
|
| 14 |
+
from storage.audit_store import get_audit_event, list_audit_events
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _audit_list_params(
    limit: Annotated[int, Query(ge=1, le=100)] = 50,
    offset: Annotated[int, Query(ge=0)] = 0,
    user_id: Annotated[str | None, Query(max_length=256)] = None,
    from_date: Annotated[str | None, Query(description="ISO 8601 lower bound")] = None,
    to_date: Annotated[str | None, Query(description="ISO 8601 upper bound")] = None,
) -> AuditListParams:
    """FastAPI dependency: gather the audit-list query-string values into one ``AuditListParams``."""
    collected = {
        "limit": limit,
        "offset": offset,
        "user_id": user_id,
        "from_date": from_date,
        "to_date": to_date,
    }
    return AuditListParams(**collected)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
router = APIRouter(tags=["audit"], prefix="/audit")


@router.get("/logs", response_model=AuditLogsResponse)
async def audit_logs(
    params: Annotated[AuditListParams, Depends(_audit_list_params)],
) -> AuditLogsResponse:
    """Return one page of the audit trail, optionally narrowed by user id and date bounds."""
    db_path = get_settings().audit_db_path
    # The same paging values go to the store query and back out in the response.
    paging = {"limit": params.limit, "offset": params.offset}
    entries, match_count = await list_audit_events(
        db_path,
        user_id=params.user_id,
        from_date=params.from_date,
        to_date=params.to_date,
        **paging,
    )
    return AuditLogsResponse(logs=entries, total=match_count, **paging)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@router.get("/logs/{query_id}", response_model=AuditLogDetailResponse)
async def audit_log_detail(query_id: str) -> AuditLogDetailResponse:
    """Look up a single audited query by id; respond 404 when the store has no such event."""
    record = await get_audit_event(get_settings().audit_db_path, query_id)
    if record is not None:
        return record
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Audit event not found.")
|
api/routes/ingest.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Document ingestion HTTP routes.
|
| 2 |
+
|
| 3 |
+
Endpoints under ``/ingest`` queue background jobs that load PDF/TXT/MD files (upload or URL),
|
| 4 |
+
chunk and embed them, and write vectors into a named Chroma collection. Poll ``/jobs/{id}``
|
| 5 |
+
for progress. Collection listing and deletion are synchronous.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from datetime import datetime, timezone
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from tempfile import NamedTemporaryFile
|
| 11 |
+
from typing import Annotated
|
| 12 |
+
from urllib.parse import unquote, urlparse
|
| 13 |
+
|
| 14 |
+
import httpx
|
| 15 |
+
from fastapi import APIRouter, BackgroundTasks, File, Form, HTTPException, UploadFile, status
|
| 16 |
+
|
| 17 |
+
from api.config import get_settings
|
| 18 |
+
from models.requests import URLIngestRequest
|
| 19 |
+
from models.responses import (
|
| 20 |
+
CollectionItem,
|
| 21 |
+
IngestCollectionsResponse,
|
| 22 |
+
IngestDeleteCollectionResponse,
|
| 23 |
+
IngestUploadResponse,
|
| 24 |
+
UrlIngestResponse,
|
| 25 |
+
)
|
| 26 |
+
from rag.vector_store import (
|
| 27 |
+
collection_created_at,
|
| 28 |
+
collection_document_count,
|
| 29 |
+
delete_collection,
|
| 30 |
+
ensure_collection_created_at,
|
| 31 |
+
list_collection_names,
|
| 32 |
+
)
|
| 33 |
+
from storage.job_store import create_ingest_job, earliest_job_created_at_for_collection
|
| 34 |
+
from workers.ingest_worker import run_ingest_job
|
| 35 |
+
|
| 36 |
+
router = APIRouter(tags=["ingest"], prefix="/ingest")

# File extensions (lower-case, with the leading dot) that ingestion accepts.
_SUPPORTED_EXTENSIONS = frozenset((".pdf", ".txt", ".md"))

# Bare media type (parameters already stripped) -> extension to use for the temp file.
_CONTENT_TYPE_SUFFIX: dict[str, str] = {
    "application/pdf": ".pdf",
    "text/plain": ".txt",
    **dict.fromkeys(("text/markdown", "text/x-markdown"), ".md"),
}
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _validate_file(file: UploadFile, max_bytes: int) -> str:
    """Validate an uploaded file's name, extension, and size.

    Returns the lower-cased extension (e.g. ``.pdf``). Raises ``HTTPException`` with
    400 for a missing filename or unsupported extension and 413 when the file is
    larger than ``max_bytes``. The underlying stream is left rewound to position 0.
    """
    name = (file.filename or "").strip()
    if not name:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Filename is required.")

    extension = Path(name).suffix.lower()
    if extension not in _SUPPORTED_EXTENSIONS:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Unsupported file type. Only PDF, TXT, and MD are accepted.",
        )

    # Measure the size by seeking to the end (whence=2), then rewind for the later read.
    stream = file.file
    stream.seek(0, 2)
    byte_count = stream.tell()
    stream.seek(0)

    if byte_count <= max_bytes:
        return extension
    raise HTTPException(
        status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
        detail=f"File too large. Max allowed is {max_bytes // (1024 * 1024)}MB.",
    )
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _suffix_from_url_path(url: str) -> str | None:
    """Return the supported extension found in the URL's percent-decoded path, else ``None``."""
    decoded_path = unquote(urlparse(url).path)
    extension = Path(decoded_path).suffix.lower()
    if extension in _SUPPORTED_EXTENSIONS:
        return extension
    return None
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _suffix_from_content_type(content_type: str | None) -> str | None:
    """Map a ``Content-Type`` header value to a supported extension, or ``None`` if unknown."""
    if not content_type:
        return None
    # Drop any parameters such as "; charset=utf-8" and normalise the media type.
    media_type, _sep, _params = content_type.partition(";")
    return _CONTENT_TYPE_SUFFIX.get(media_type.strip().lower())
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _download_request_headers(user_agent: str) -> dict[str, str]:
|
| 87 |
+
"""Headers for remote URL fetches (SEC.gov requires declared User-Agent + Accept-Encoding)."""
|
| 88 |
+
return {
|
| 89 |
+
"User-Agent": user_agent.strip() or "DocuAudit AI docuaudit-ingest@example.com",
|
| 90 |
+
"Accept-Encoding": "gzip, deflate",
|
| 91 |
+
"Accept": "application/pdf,text/plain,text/markdown,*/*;q=0.8",
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _display_name_from_url(url: str, suffix: str) -> str:
    """Derive a human-readable filename from the last segment of the URL path.

    Falls back to ``download<suffix>`` when the path has no usable final segment, and
    appends ``suffix`` when that segment neither has a supported extension nor already
    ends with ``suffix``.
    """
    leaf = Path(unquote(urlparse(url).path)).name.strip()
    if leaf in {"", "/", "."}:
        return f"download{suffix}"
    has_supported_extension = Path(leaf).suffix.lower() in _SUPPORTED_EXTENSIONS
    if has_supported_extension or leaf.endswith(suffix):
        return leaf
    return f"{leaf}{suffix}"
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
async def _download_url_to_temp(url: str, max_bytes: int, user_agent: str | None = None) -> tuple[str, str]:
    """Stream-download a URL to a temp file; return ``(path, display_name)``.

    The file type is taken from the URL path first, then from the response
    ``Content-Type``. On success the caller owns the returned temp file (it is created
    with ``delete=False``). On any failure the partially written temp file is removed
    before the error propagates.

    Raises:
        HTTPException: 400 for a non-http(s) scheme or undeterminable file type, 413 when
            the body exceeds ``max_bytes``, 502 when the remote server returns an error
            status or the request itself fails.
    """
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Only http and https URLs are supported.",
        )

    # NOTE(security): `url` is caller-supplied and fetched server-side with redirects
    # followed; only the scheme is checked here. There is no allow-list or private-address
    # filtering, so this is an SSRF surface unless it is restricted elsewhere.
    ua = user_agent or get_settings().ingest_user_agent
    timeout = httpx.Timeout(60.0, connect=10.0)
    limits = httpx.Limits(max_keepalive_connections=5, max_connections=5)
    headers = _download_request_headers(ua)

    temp_path: str | None = None
    display_name = ""
    completed = False
    try:
        async with httpx.AsyncClient(timeout=timeout, limits=limits, follow_redirects=True) as client:
            async with client.stream("GET", url, headers=headers) as response:
                response.raise_for_status()
                content_type = response.headers.get("content-type")
                suffix = _suffix_from_url_path(url) or _suffix_from_content_type(content_type)
                if not suffix:
                    raise HTTPException(
                        status_code=status.HTTP_400_BAD_REQUEST,
                        detail=(
                            "Could not determine file type from the URL path or Content-Type. "
                            "Provide a .pdf, .txt, or .md resource with matching content-type."
                        ),
                    )

                display_name = _display_name_from_url(url, suffix)
                total = 0
                with NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
                    temp_path = tmp.name
                    async for chunk in response.aiter_bytes(chunk_size=65536):
                        total += len(chunk)
                        # Enforce the cap while streaming so an oversized body is never fully stored.
                        if total > max_bytes:
                            raise HTTPException(
                                status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
                                detail=f"Download too large. Max allowed is {max_bytes // (1024 * 1024)}MB.",
                            )
                        tmp.write(chunk)
        completed = True
    except HTTPException:
        raise
    except httpx.HTTPStatusError as exc:
        # Compare against None explicitly rather than relying on the response's truthiness.
        code = exc.response.status_code if exc.response is not None else "unknown"
        detail = f"Remote server returned HTTP {code}."
        if code == 403 and "sec.gov" in parsed.netloc.lower():
            detail += (
                " SEC.gov requires a declared User-Agent ('Company Name you@email.com'). "
                "Set INGEST_USER_AGENT in .env (see sec.gov/os/accessing-edgar-data)."
            )
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=detail,
        ) from exc
    except httpx.RequestError as exc:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"Failed to download URL: {exc}",
        ) from exc
    finally:
        # delete=False means nothing else removes the file: drop it on every failure path
        # (size cap, network error mid-stream, cancellation, unexpected exception).
        if not completed and temp_path is not None:
            Path(temp_path).unlink(missing_ok=True)

    return temp_path, display_name
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def _parse_created_at(raw: str | None) -> datetime | None:
|
| 169 |
+
if not raw:
|
| 170 |
+
return None
|
| 171 |
+
s = raw.strip()
|
| 172 |
+
if s.endswith("Z"):
|
| 173 |
+
s = s[:-1] + "+00:00"
|
| 174 |
+
try:
|
| 175 |
+
dt = datetime.fromisoformat(s)
|
| 176 |
+
if dt.tzinfo is None:
|
| 177 |
+
return dt.replace(tzinfo=timezone.utc)
|
| 178 |
+
return dt
|
| 179 |
+
except ValueError:
|
| 180 |
+
return None
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
@router.post("/upload", response_model=IngestUploadResponse)
async def upload_endpoint(
    background_tasks: BackgroundTasks,
    files: list[UploadFile] = File(..., description="One or more PDF, TXT, or MD files"),
    collection_name: Annotated[str, Form(min_length=1, max_length=256)] = "default",
) -> IngestUploadResponse:
    """Accept multipart file uploads, validate, and queue a background ingest job.

    Each file is validated, copied to a temp file, and handed to ``run_ingest_job`` via
    ``BackgroundTasks``. If anything fails before the job is queued, every temp file
    written so far is removed.

    Raises:
        HTTPException: 400 for no files, too many files, a blank collection name, or an
            invalid file; 413 for an oversized file; 500 for any unexpected failure.
    """
    settings = get_settings()
    if not files:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="At least one file is required.")
    if len(files) > settings.max_documents_per_batch:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Too many files in one request (max {settings.max_documents_per_batch}).",
        )

    # min_length=1 still admits whitespace-only names, which would strip down to "".
    # Reject them up front, as the collection delete route does.
    target_collection = collection_name.strip()
    if not target_collection:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="collection_name is required.")

    max_bytes = settings.max_file_size_mb * 1024 * 1024
    temp_paths: list[tuple[str, str]] = []
    filenames: list[str] = []
    queued = False
    try:
        for file in files:
            suffix = _validate_file(file, max_bytes)
            display_name = (file.filename or "upload").strip()
            file_bytes = await file.read()
            with NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
                # Register the path before writing so a failed write is still cleaned up.
                temp_paths.append((tmp.name, display_name))
                tmp.write(file_bytes)
            filenames.append(display_name)
            await file.close()

        job_id = await create_ingest_job(
            settings.jobs_db_path,
            collection_name=target_collection,
            filenames=filenames,
        )

        background_tasks.add_task(
            run_ingest_job,
            job_id,
            temp_paths,
            target_collection,
            settings.jobs_db_path,
            settings.chroma_persist_directory,
        )

        response = IngestUploadResponse(
            job_id=job_id,
            status="queued",
            total_files=len(filenames),
            filenames=filenames,
            message=f"Documents queued for processing. Poll /jobs/{job_id} for status.",
        )
        queued = True
        return response
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
    finally:
        # Once queued, the temp files are handed to the background job and must be kept.
        # On every other exit (including cancellation) remove what was written.
        if not queued:
            for path, _ in temp_paths:
                Path(path).unlink(missing_ok=True)
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
@router.post("/url", response_model=UrlIngestResponse)
async def ingest_url_endpoint(
    background_tasks: BackgroundTasks,
    payload: URLIngestRequest,
) -> UrlIngestResponse:
    """Download one or more HTTP(S) documents and queue them for ingestion.

    URLs are downloaded sequentially inside the request; the background job then receives
    the resulting temp files. If any download or the job creation fails, every temp file
    downloaded so far is removed.

    Raises:
        HTTPException: 400 for too many URLs, plus whatever ``_download_url_to_temp``
            raises (400/413/502); 500 for any unexpected failure.
    """
    settings = get_settings()
    max_bytes = settings.max_file_size_mb * 1024 * 1024
    url_strings = [str(u).strip() for u in payload.urls]
    if len(url_strings) > settings.max_documents_per_batch:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Too many URLs in one request (max {settings.max_documents_per_batch}).",
        )

    downloaded: list[tuple[str, str]] = []
    queued = False
    try:
        for url_str in url_strings:
            temp_path, display_name = await _download_url_to_temp(
                url_str, max_bytes, user_agent=settings.ingest_user_agent
            )
            downloaded.append((temp_path, display_name))

        # Missing, empty, and whitespace-only names all fall back to "default"; stripping
        # only after the fallback would have let a blank name through as "".
        coll = (payload.collection_name or "").strip() or "default"
        job_id = await create_ingest_job(
            settings.jobs_db_path,
            collection_name=coll,
            filenames=[name for _, name in downloaded],
        )

        background_tasks.add_task(
            run_ingest_job,
            job_id,
            downloaded,
            coll,
            settings.jobs_db_path,
            settings.chroma_persist_directory,
        )

        response = UrlIngestResponse(
            job_id=job_id,
            status="queued",
            total_urls=len(downloaded),
            message="URLs queued for download and processing.",
        )
        queued = True
        return response
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
    finally:
        # Once queued, the downloaded files are handed to the background job and must be kept.
        # On every other exit (including cancellation) remove them.
        if not queued:
            for path, _ in downloaded:
                Path(path).unlink(missing_ok=True)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
@router.get("/collections", response_model=IngestCollectionsResponse)
async def list_collections_endpoint() -> IngestCollectionsResponse:
    """Report every Chroma collection with its document count and creation time.

    A failure while listing collection names is returned as HTTP 500.
    """
    cfg = get_settings()
    chroma_dir = cfg.chroma_persist_directory
    try:
        collection_names = list_collection_names(chroma_dir)
    except Exception as exc:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc

    summaries: list[CollectionItem] = []
    for coll_name in collection_names:
        doc_count = collection_document_count(chroma_dir, coll_name)
        created_raw = collection_created_at(chroma_dir, coll_name)
        if not created_raw:
            # No stored creation time: use the earliest ingest job for this collection as
            # the fallback (presumably ensure_collection_created_at also persists it; verify).
            earliest_job = await earliest_job_created_at_for_collection(cfg.jobs_db_path, coll_name)
            created_raw = ensure_collection_created_at(
                chroma_dir,
                coll_name,
                fallback=earliest_job,
            )
        summary = CollectionItem(
            name=coll_name,
            document_count=doc_count,
            created_at=_parse_created_at(created_raw),
        )
        summaries.append(summary)
    return IngestCollectionsResponse(collections=summaries, total=len(summaries))
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
@router.delete("/collection/{collection_name}", response_model=IngestDeleteCollectionResponse)
async def delete_collection_endpoint(collection_name: str) -> IngestDeleteCollectionResponse:
    """Delete one Chroma collection by name.

    Responds 400 for a blank name, 404 when no such collection exists, and 500 when the
    vector store raises.
    """
    name = collection_name.strip()
    if name == "":
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="collection_name is required.")

    chroma_dir = get_settings().chroma_persist_directory
    try:
        known_names = list_collection_names(chroma_dir)
        if name not in known_names:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Collection not found.")
        removed_count = delete_collection(chroma_dir, name)
    except HTTPException:
        # Let the 404 above through untouched instead of wrapping it as a 500.
        raise
    except Exception as exc:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc

    return IngestDeleteCollectionResponse(
        message=f"Collection '{name}' deleted successfully.",
        documents_removed=removed_count,
    )
|
api/routes/jobs.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Ingest job status and listing.
|
| 2 |
+
|
| 3 |
+
Jobs are created by upload/URL ingest routes and updated by :mod:`workers.ingest_worker`.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from typing import Annotated
|
| 7 |
+
|
| 8 |
+
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
| 9 |
+
|
| 10 |
+
from api.config import get_settings
|
| 11 |
+
from models.requests import JobsListParams
|
| 12 |
+
from models.responses import JobListResponse, JobStatusResponse
|
| 13 |
+
from storage.job_store import get_job_status, list_ingest_jobs
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _jobs_list_params(
    limit: Annotated[int, Query(ge=1, le=100)] = 10,
    offset: Annotated[int, Query(ge=0)] = 0,
) -> JobsListParams:
    """Bundle the ``limit``/``offset`` query parameters into a ``JobsListParams``.

    Intended to be used through ``Depends``: the bounds are declared on the
    ``Query`` annotations, and the values are then passed to the model.
    """
    pagination = JobsListParams(limit=limit, offset=offset)
    return pagination
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
router = APIRouter(tags=["jobs"])
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@router.get("/jobs", response_model=JobListResponse)
async def list_jobs(
    params: Annotated[JobsListParams, Depends(_jobs_list_params)],
) -> JobListResponse:
    """Paginated list of ingest jobs (newest first)."""
    # The docstring above is kept verbatim: FastAPI publishes it as the
    # operation description in the OpenAPI schema.
    jobs_db = get_settings().jobs_db_path
    # The store returns the requested page together with a total count.
    page, total_count = await list_ingest_jobs(
        jobs_db,
        limit=params.limit,
        offset=params.offset,
    )
    return JobListResponse(jobs=page, total=total_count)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@router.get("/jobs/{job_id}", response_model=JobStatusResponse)
async def get_job(job_id: str) -> JobStatusResponse:
    """Poll a single job by id (404 if unknown)."""
    jobs_db = get_settings().jobs_db_path
    found = await get_job_status(jobs_db, job_id)
    if found is not None:
        return found
    # The store signals "no such job" with None; surface that as a 404.
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Job not found.")
|
api/routes/query.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Grounded Q&A and summarisation routes.
|
| 2 |
+
|
| 3 |
+
``POST /query/ask`` retrieves top-K chunks from Chroma, calls the configured LLM with
|
| 4 |
+
citations enforced in the prompt, persists an audit row, and returns answer + sources.
|
| 5 |
+
``POST /query/summarise`` uses a retrieval-oriented query then a summary-focused prompt.
|
| 6 |
+
``POST /query`` is a legacy alias for ``/query/ask``.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import time
|
| 10 |
+
from datetime import datetime, timezone
|
| 11 |
+
from uuid import uuid4
|
| 12 |
+
|
| 13 |
+
from fastapi import APIRouter, HTTPException, status
|
| 14 |
+
|
| 15 |
+
from api.config import Settings, get_settings
|
| 16 |
+
from models.requests import QueryRequest, SummariseRequest
|
| 17 |
+
from models.responses import AskQueryResponse, SourceCitation, SummariseQueryResponse
|
| 18 |
+
from rag.embedder import create_embedding_function
|
| 19 |
+
from rag.retriever import (
|
| 20 |
+
SUMMARY_RETRIEVAL_QUERY,
|
| 21 |
+
RetrievedChunk,
|
| 22 |
+
answer_with_grounding,
|
| 23 |
+
retrieve_chunks,
|
| 24 |
+
summarise_with_grounding,
|
| 25 |
+
)
|
| 26 |
+
from rag.vector_store import collection_document_count, get_vector_store
|
| 27 |
+
from storage.audit_store import persist_query_audit
|
| 28 |
+
|
| 29 |
+
router = APIRouter(prefix="/query", tags=["query"])
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _model_used_label(settings: Settings) -> str:
    """Return the chat model name configured for the active LLM provider.

    The provider name is compared case-insensitively. A provider that has no
    known model setting yields the lowercased provider name followed by
    ``:unknown``.
    """
    provider = settings.llm_provider.lower()
    # Provider name -> Settings attribute that holds its chat model.
    model_attr_by_provider = {
        "openai": "openai_model",
        "ollama": "ollama_chat_model",
        "anthropic": "anthropic_model",
        "huggingface": "huggingface_model",
    }
    attr_name = model_attr_by_provider.get(provider)
    if attr_name is None:
        return f"{provider}:unknown"
    # Only the matching attribute is read, as in the original if-chain.
    return getattr(settings, attr_name)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _chunks_to_citations(chunks: list[RetrievedChunk]) -> list[SourceCitation]:
    """Convert retrieved chunks into ``SourceCitation`` models, keeping their order.

    Missing values get neutral defaults: a falsy source becomes ``"unknown"``,
    a ``None`` page becomes ``0`` and a ``None`` score becomes ``0.0``.
    """
    return [
        SourceCitation(
            document_name=chunk.source or "unknown",
            page_number=0 if chunk.page is None else chunk.page,
            chunk_text=chunk.text,
            relevance_score=0.0 if chunk.score is None else float(chunk.score),
        )
        for chunk in chunks
    ]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
async def _run_ask(
    settings: Settings,
    payload: QueryRequest,
) -> AskQueryResponse:
    """Answer ``payload.question`` from a Chroma collection and record an audit row.

    Steps: open the vector store, retrieve ``payload.top_k`` chunks, generate
    the grounded answer, build the response, then persist the audit entry
    before returning. Exceptions from any step propagate to the caller.
    """
    # ``collection_name`` may be null in the request; resolve the fallback once
    # so retrieval and the audit row always agree.
    collection = payload.collection_name or "default"

    started = time.perf_counter()
    store = get_vector_store(
        persist_directory=settings.chroma_persist_directory,
        collection_name=collection,
        embedding_function=create_embedding_function(),
    )
    # NOTE(review): these two calls are not awaited, so they run synchronously
    # inside this coroutine - confirm whether they should be offloaded.
    retrieved = retrieve_chunks(store, payload.question, payload.top_k)
    answer, tokens_used = answer_with_grounding(settings, payload.question, retrieved)
    # The timing covers store setup, retrieval and generation, not the audit write.
    elapsed_ms = int((time.perf_counter() - started) * 1000)

    sources = _chunks_to_citations(retrieved)
    query_id = str(uuid4())
    model_label = _model_used_label(settings)
    result = AskQueryResponse(
        query_id=query_id,
        question=payload.question,
        answer=answer,
        sources=sources,
        model_used=model_label,
        tokens_used=tokens_used,
        response_time_ms=elapsed_ms,
        timestamp=datetime.now(timezone.utc),
    )
    await persist_query_audit(
        settings.audit_db_path,
        query_id=query_id,
        action="query",
        user_id=payload.user_id,
        question=payload.question,
        collection_name=collection,
        answer=answer,
        sources=sources,
        model_used=model_label,
        tokens_used=tokens_used,
        response_time_ms=elapsed_ms,
        kind="ask",
    )
    return result
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
async def _run_summarise(
    settings: Settings,
    payload: SummariseRequest,
) -> SummariseQueryResponse:
    """Summarise a Chroma collection and record an audit row.

    A non-blank ``payload.focus`` (stripped) is used as the retrieval query and
    as the audited question. Otherwise retrieval uses
    ``SUMMARY_RETRIEVAL_QUERY`` and the audit row stores
    ``"Summarise collection"``. The number of retrieved chunks comes from
    ``settings.top_k_results``.
    """
    collection = payload.collection_name
    focus_text = (payload.focus or "").strip()
    retrieval_query = focus_text or SUMMARY_RETRIEVAL_QUERY
    audit_question = focus_text or "Summarise collection"

    started = time.perf_counter()
    store = get_vector_store(
        persist_directory=settings.chroma_persist_directory,
        collection_name=collection,
        embedding_function=create_embedding_function(),
    )
    retrieved = retrieve_chunks(store, retrieval_query, settings.top_k_results)
    # The summariser receives the caller's focus exactly as sent (unstripped).
    summary, tokens_used = summarise_with_grounding(settings, focus=payload.focus, chunks=retrieved)
    # The timing covers store setup, retrieval and generation only.
    elapsed_ms = int((time.perf_counter() - started) * 1000)

    sources = _chunks_to_citations(retrieved)
    doc_count = collection_document_count(settings.chroma_persist_directory, collection)
    query_id = str(uuid4())
    result = SummariseQueryResponse(
        query_id=query_id,
        summary=summary,
        document_count=doc_count,
        sources=sources,
        timestamp=datetime.now(timezone.utc),
    )
    await persist_query_audit(
        settings.audit_db_path,
        query_id=query_id,
        action="summarise",
        user_id=payload.user_id,
        question=audit_question,
        collection_name=collection,
        answer=summary,
        sources=sources,
        model_used=_model_used_label(settings),
        tokens_used=tokens_used,
        response_time_ms=elapsed_ms,
        kind="summarise",
    )
    return result
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
@router.post("/ask", response_model=AskQueryResponse)
async def ask_endpoint(payload: QueryRequest) -> AskQueryResponse:
    """Grounded question answering against a Chroma collection."""
    # Raises HTTPException: re-raised unchanged when one comes from further
    # down the stack; any other exception becomes a 500 whose detail is the
    # exception text.
    settings = get_settings()
    try:
        return await _run_ask(settings, payload)
    except HTTPException:
        # Keep the status code chosen downstream instead of masking it as a
        # 500 (same pattern as the collection-delete route).
        raise
    except Exception as exc:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
@router.post("/summarise", response_model=SummariseQueryResponse)
async def summarise_endpoint(payload: SummariseRequest) -> SummariseQueryResponse:
    """Collection-wide summary with optional focus for retrieval."""
    # Raises HTTPException: re-raised unchanged when one comes from further
    # down the stack; any other exception becomes a 500 whose detail is the
    # exception text.
    settings = get_settings()
    try:
        return await _run_summarise(settings, payload)
    except HTTPException:
        # Keep the status code chosen downstream instead of masking it as a
        # 500 (same pattern as the collection-delete route).
        raise
    except Exception as exc:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
legacy_query_router = APIRouter(tags=["query"])
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
@legacy_query_router.post("/query", response_model=AskQueryResponse)
async def query_post_compat(payload: QueryRequest) -> AskQueryResponse:
    """Same behavior as POST /query/ask; kept for older clients and docs that used POST /query."""
    # Delegate to the /query/ask handler so both paths share one implementation,
    # including its error mapping.
    legacy_response = await ask_endpoint(payload)
    return legacy_response
|
app.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hugging Face Spaces default entry (Streamlit SDK expects `app.py`).
|
| 2 |
+
|
| 3 |
+
Local development can still use `streamlit run streamlit_app.py`; Docker Compose uses `app.py`
|
| 4 |
+
so the same entry path works on the Hub and in containers.
|
| 5 |
+
|
| 6 |
+
On Hugging Face Streamlit Spaces only `streamlit run app.py` is started — no separate uvicorn
|
| 7 |
+
process — so we spawn the FastAPI app on 127.0.0.1:8000 when `SPACE_ID` is present (see Hub
|
| 8 |
+
built-in env vars). Set `DOC_AUDI_EMBED_API=0` to disable. Use `DOC_AUDI_EMBED_API=1` to force
|
| 9 |
+
embedding elsewhere (e.g. demos).
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import atexit
|
| 15 |
+
import os
|
| 16 |
+
import socket
|
| 17 |
+
import subprocess
|
| 18 |
+
import sys
|
| 19 |
+
import time
|
| 20 |
+
|
| 21 |
+
_uvicorn_proc: subprocess.Popen[bytes] | None = None
|
| 22 |
+
_cleanup_registered = False
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _port_accepting_connections(host: str, port: int, timeout: float = 0.3) -> bool:
    """Return True when a TCP connection to ``host:port`` can be opened.

    Args:
        host: Host name or address to probe.
        port: TCP port to probe.
        timeout: Seconds to wait for the connection attempt. Defaults to the
            previously hard-coded 0.3 s.

    Returns:
        True if the connection succeeded (it is closed again immediately),
        False if the attempt raised ``OSError`` (refused, timed out, ...).
    """
    try:
        # The context manager closes the probe socket as soon as it connects.
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _want_embedded_api() -> bool:
    """Decide whether this process should spawn the embedded FastAPI server.

    ``DOC_AUDI_EMBED_API`` wins when it holds a recognised value
    (case-insensitive, surrounding whitespace ignored): ``0``/``false``/``no``
    disable the embedded API and ``1``/``true``/``yes`` force it. Any other
    value, or no value, falls back to whether ``SPACE_ID`` is set to a
    non-empty string.
    """
    # Read once and normalise, so a padded value such as " 0 " is still honoured.
    flag = os.environ.get("DOC_AUDI_EMBED_API", "").strip().lower()
    if flag in ("0", "false", "no"):
        return False
    if flag in ("1", "true", "yes"):
        return True
    return bool(os.environ.get("SPACE_ID"))
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _propagate_streamlit_secrets_to_environ(
    keys: tuple[str, ...] = ("HF_TOKEN", "HUGGINGFACE_API_KEY", "HUGGING_FACE_HUB_TOKEN"),
) -> None:
    """Copy Hub tokens from Streamlit secrets into os.environ for the embedded uvicorn child.

    On Hugging Face Streamlit Spaces, repository secrets are often available as ``st.secrets``
    but are not always present in ``os.environ``. ``subprocess.Popen`` only forwards the
    process environment, so the API would miss ``HF_TOKEN`` / ``HUGGINGFACE_API_KEY`` otherwise.

    Args:
        keys: Secret names to copy. Defaults to the three Hub token names that
            were previously hard-coded.

    A key is copied only when the environment has no non-blank value for it and
    the secret itself is non-blank; the stored value is stripped. The function
    is best-effort: it returns silently when Streamlit cannot be imported, and
    it skips any key whose lookup raises.
    """
    try:
        import streamlit as st
    except ImportError:
        return
    secrets = getattr(st, "secrets", None)
    if secrets is None:
        return
    for key in keys:
        # A value already present in the environment takes precedence.
        if (os.environ.get(key) or "").strip():
            continue
        try:
            raw = secrets[key]
        except Exception:
            # Deliberately broad: a missing key or an unreadable secrets store
            # simply means there is nothing to copy for this key.
            continue
        if raw is not None and str(raw).strip():
            os.environ[key] = str(raw).strip()
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _maybe_start_embedded_uvicorn() -> None:
    """Spawn the FastAPI app under uvicorn as a child process when wanted.

    Does nothing unless ``_want_embedded_api()`` is true. Otherwise it copies
    Hub tokens into ``os.environ`` (the child inherits the environment) and
    returns early when something already accepts connections on
    127.0.0.1:8000. If not, it launches ``python -m uvicorn api.main:app`` and
    waits up to about 6 s (120 polls, 50 ms apart) for the port to open.

    The wait is best-effort: the function returns without raising even if the
    port never opens, and it stops waiting as soon as the child has exited.
    An ``atexit`` hook terminates the child that is current at interpreter
    exit, killing it if it does not stop within 10 s.
    """
    global _uvicorn_proc, _cleanup_registered
    if not _want_embedded_api():
        return
    _propagate_streamlit_secrets_to_environ()

    host, port = "127.0.0.1", 8000
    poll_attempts, poll_interval = 120, 0.05

    def _wait_until_listening(child: subprocess.Popen[bytes]) -> None:
        # Poll until the port opens, the child dies, or the attempts run out.
        for _ in range(poll_attempts):
            if _port_accepting_connections(host, port):
                return
            if child.poll() is not None:
                # The child already exited; the port will never open.
                return
            time.sleep(poll_interval)

    if _port_accepting_connections(host, port):
        return
    if _uvicorn_proc is not None and _uvicorn_proc.poll() is None:
        # A child from an earlier call is still starting up; wait for it
        # instead of spawning a second server on the same port.
        _wait_until_listening(_uvicorn_proc)
        return

    _uvicorn_proc = subprocess.Popen(
        [
            sys.executable,
            "-m",
            "uvicorn",
            "api.main:app",
            "--host",
            host,
            "--port",
            str(port),
        ]
    )

    if not _cleanup_registered:

        def _cleanup() -> None:
            # Look the child up at exit time, so a process spawned by a later
            # call is cleaned up too, not only the first one.
            child = _uvicorn_proc
            if child is None or child.poll() is not None:
                return
            child.terminate()
            try:
                child.wait(timeout=10)
            except subprocess.TimeoutExpired:
                child.kill()

        atexit.register(_cleanup)
        _cleanup_registered = True

    _wait_until_listening(_uvicorn_proc)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
_maybe_start_embedded_uvicorn()
|
| 114 |
+
|
| 115 |
+
from streamlit_app import main # noqa: E402 — start API before importing Streamlit stack
|
| 116 |
+
|
| 117 |
+
main()
|
docker-compose.yml
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Requires a project `.env` (copy from `.env.example`) for `env_file` and variable substitution.
|
| 2 |
+
name: docuaudit-ai
|
| 3 |
+
|
| 4 |
+
x-app: &app
|
| 5 |
+
build: .
|
| 6 |
+
image: docuaudit-ai:${IMAGE_TAG:-local}
|
| 7 |
+
|
| 8 |
+
services:
|
| 9 |
+
api:
|
| 10 |
+
<<: *app
|
| 11 |
+
command: uvicorn api.main:app --host 0.0.0.0 --port 8000
|
| 12 |
+
ports:
|
| 13 |
+
- "${API_PORT:-8000}:8000"
|
| 14 |
+
env_file:
|
| 15 |
+
- .env
|
| 16 |
+
environment:
|
| 17 |
+
CHROMA_PERSIST_DIRECTORY: /data/chroma
|
| 18 |
+
AUDIT_DB_PATH: /data/audit.db
|
| 19 |
+
JOBS_DB_PATH: /data/jobs.db
|
| 20 |
+
OLLAMA_BASE_URL: ${OLLAMA_BASE_URL:-http://host.docker.internal:11434}
|
| 21 |
+
volumes:
|
| 22 |
+
- docuaudit_data:/data
|
| 23 |
+
extra_hosts:
|
| 24 |
+
- "host.docker.internal:host-gateway"
|
| 25 |
+
healthcheck:
|
| 26 |
+
test:
|
| 27 |
+
[
|
| 28 |
+
"CMD",
|
| 29 |
+
"python",
|
| 30 |
+
"-c",
|
| 31 |
+
"import urllib.request; urllib.request.urlopen('http://127.0.0.1:8000/health', timeout=5)",
|
| 32 |
+
]
|
| 33 |
+
interval: 15s
|
| 34 |
+
timeout: 5s
|
| 35 |
+
retries: 5
|
| 36 |
+
start_period: 40s
|
| 37 |
+
|
| 38 |
+
streamlit:
|
| 39 |
+
<<: *app
|
| 40 |
+
command: >
|
| 41 |
+
streamlit run app.py
|
| 42 |
+
--server.port=8501
|
| 43 |
+
--server.address=0.0.0.0
|
| 44 |
+
--server.headless=true
|
| 45 |
+
--browser.gatherUsageStats=false
|
| 46 |
+
ports:
|
| 47 |
+
- "${STREAMLIT_PORT:-8501}:8501"
|
| 48 |
+
env_file:
|
| 49 |
+
- .env
|
| 50 |
+
environment:
|
| 51 |
+
DOC_AUDI_API_BASE: http://api:8000
|
| 52 |
+
STREAMLIT_BACKEND_URL: http://api:8000
|
| 53 |
+
depends_on:
|
| 54 |
+
api:
|
| 55 |
+
condition: service_healthy
|
| 56 |
+
|
| 57 |
+
ollama:
|
| 58 |
+
image: ollama/ollama:latest
|
| 59 |
+
profiles: ["ollama"]
|
| 60 |
+
ports:
|
| 61 |
+
- "${OLLAMA_HOST_PORT:-11434}:11434"
|
| 62 |
+
volumes:
|
| 63 |
+
- ollama_data:/root/.ollama
|
| 64 |
+
|
| 65 |
+
volumes:
|
| 66 |
+
docuaudit_data:
|
| 67 |
+
ollama_data:
|
main.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Minimal CLI placeholder (not used by Docker or Hugging Face entrypoints).
|
| 2 |
+
|
| 3 |
+
Production entrypoints: ``api.main:app`` (FastAPI) and ``app.py`` / ``streamlit_app.py`` (UI).
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def main() -> None:
    """Write the placeholder greeting to standard output."""
    greeting = "Hello from doc-audi-ai!"
    print(greeting)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
if __name__ == "__main__":
|
| 13 |
+
main()
|
models/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""API contract models: request payloads and response DTOs."""
|
models/requests.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pydantic request bodies and query-parameter models for the HTTP API.
|
| 2 |
+
|
| 3 |
+
Used by FastAPI route handlers for validation and OpenAPI schema generation.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
from pydantic import BaseModel, ConfigDict, Field, HttpUrl
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class QueryRequest(BaseModel):
    # Request body for grounded question answering.
    # Comments are used instead of a docstring so the generated JSON schema
    # keeps no model description, as before.

    # Unknown keys are rejected rather than silently ignored.
    model_config = ConfigDict(extra="forbid")

    question: str = Field(
        min_length=5,
        max_length=2000,
        description="Natural language question",
    )
    # An explicit null is accepted as well as a name; omitting the field
    # yields "default".
    collection_name: Optional[str] = Field(
        default="default",
        min_length=1,
        max_length=256,
        description="Chroma collection to search",
    )
    top_k: int = Field(
        default=5,
        ge=1,
        le=20,
        description="Number of chunks to retrieve",
    )
    user_id: str = Field(
        default="anonymous",
        max_length=256,
        description="Caller id for audit filtering",
    )
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class SummariseRequest(BaseModel):
    # Request body for collection summarisation.
    # Comments are used instead of a docstring so the generated JSON schema
    # keeps no model description, as before.

    # Unknown keys are rejected rather than silently ignored.
    model_config = ConfigDict(extra="forbid")

    # Unlike QueryRequest.collection_name, this field does not accept null.
    collection_name: str = Field(
        default="default",
        min_length=1,
        max_length=256,
        description="Chroma collection to summarise",
    )
    focus: str | None = Field(
        default=None,
        max_length=8000,
        description="Optional angle or scope for retrieval and the summary",
    )
    user_id: str = Field(
        default="anonymous",
        max_length=256,
        description="Caller id for audit filtering",
    )
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class URLIngestRequest(BaseModel):
    # Request body carrying document URLs to ingest (presumably consumed by
    # the URL ingest route, which is not visible here - verify).
    # Comments are used instead of a docstring so the generated JSON schema
    # keeps no model description, as before.

    # Unknown keys are rejected rather than silently ignored.
    model_config = ConfigDict(extra="forbid")

    # Between 1 and 100 URLs, each validated as an HttpUrl.
    urls: list[HttpUrl] = Field(
        min_length=1,
        max_length=100,
        description="One or more HTTP(S) URLs to PDF, TXT, or Markdown documents",
    )
    # An explicit null is accepted as well as a name; omitting the field
    # yields "default".
    collection_name: Optional[str] = Field(
        default="default",
        min_length=1,
        max_length=256,
        description="Target Chroma collection name",
    )
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class JobsListParams(BaseModel):
    # Pagination parameters for listing ingest jobs.
    # Comments are used instead of a docstring so the generated JSON schema
    # keeps no model description, as before.

    # Unknown keys are rejected rather than silently ignored.
    model_config = ConfigDict(extra="forbid")

    limit: int = Field(
        default=10,
        ge=1,
        le=100,
        description="Max jobs to return",
    )
    offset: int = Field(
        default=0,
        ge=0,
        description="Offset for pagination",
    )
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class AuditListParams(BaseModel):
    # Pagination and filter parameters for listing audit log entries.
    # Comments are used instead of a docstring so the generated JSON schema
    # keeps no model description, as before.

    # Unknown keys are rejected rather than silently ignored.
    model_config = ConfigDict(extra="forbid")

    limit: int = Field(
        default=50,
        ge=1,
        le=100,
        description="Max log entries to return",
    )
    offset: int = Field(
        default=0,
        ge=0,
        description="Offset for pagination",
    )
    user_id: str | None = Field(
        default=None,
        max_length=256,
        description="Filter by user id",
    )
    # The two date bounds are plain strings: this model does not check that
    # they are well-formed ISO 8601 values.
    from_date: str | None = Field(
        default=None,
        description="ISO 8601 datetime lower bound (inclusive) on timestamp",
    )
    to_date: str | None = Field(
        default=None,
        description="ISO 8601 datetime upper bound (inclusive) on timestamp",
    )
|
models/responses.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pydantic response models returned by FastAPI routes.
|
| 2 |
+
|
| 3 |
+
Shared shape: :class:`SourceCitation` appears on ask, summarise, and audit detail responses.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
|
| 8 |
+
from pydantic import BaseModel, Field
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# --- Shared citations (spec-shaped) ---
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SourceCitation(BaseModel):
    # One retrieved chunk cited as evidence for an answer or summary.
    # Comments are used instead of a docstring so the generated JSON schema
    # keeps no model description, as before.
    document_name: str  # source of the chunk
    page_number: int  # page the chunk came from
    chunk_text: str  # text of the retrieved chunk
    relevance_score: float  # retrieval score of the chunk
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# --- Query: ask ---
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class AskQueryResponse(BaseModel):
    # Response body for a grounded question-answering request.
    # Comments are used instead of a docstring so the generated JSON schema
    # keeps no model description, as before.
    query_id: str  # identifier of this query
    question: str  # the question that was asked
    answer: str  # generated answer text
    sources: list[SourceCitation] = Field(default_factory=list)  # cited chunks
    model_used: str  # label of the model that produced the answer
    tokens_used: int  # token count reported for the generation
    response_time_ms: int  # elapsed time in milliseconds
    timestamp: datetime  # when the response was produced
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# --- Query: summarise ---
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class SummariseQueryResponse(BaseModel):
    # Response body for a collection summarisation request.
    # Comments are used instead of a docstring so the generated JSON schema
    # keeps no model description, as before.
    query_id: str  # identifier of this request
    summary: str  # generated summary text
    document_count: int  # count reported for the summarised collection
    sources: list[SourceCitation] = Field(default_factory=list)  # cited chunks
    timestamp: datetime  # when the response was produced
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# --- Ingest ---
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class IngestUploadResponse(BaseModel):
    # Acknowledgement for a file-upload ingest request (the producing route is
    # not visible here - verify against api/routes/ingest.py).
    # Comments are used instead of a docstring so the generated JSON schema
    # keeps no model description, as before.
    job_id: str  # id of the created ingest job
    status: str  # job status string
    total_files: int  # number of files in the request
    filenames: list[str]  # names of the files
    message: str  # human-readable summary
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class UrlIngestResponse(BaseModel):
    # Acknowledgement for a URL ingest request (the producing route is not
    # visible here - verify against api/routes/ingest.py).
    # Comments are used instead of a docstring so the generated JSON schema
    # keeps no model description, as before.
    job_id: str  # id of the created ingest job
    status: str  # job status string
    total_urls: int  # number of URLs in the request
    message: str  # human-readable summary
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class CollectionItem(BaseModel):
    # One collection entry inside IngestCollectionsResponse.collections.
    # Comments are used instead of a docstring so the generated JSON schema
    # keeps no model description, as before.
    name: str  # collection name
    document_count: int  # count reported for the collection
    created_at: datetime | None = None  # optional; defaults to None
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class IngestCollectionsResponse(BaseModel):
    # Listing of collections together with a total count.
    # Comments are used instead of a docstring so the generated JSON schema
    # keeps no model description, as before.
    collections: list[CollectionItem] = Field(default_factory=list)
    total: int  # total count supplied by the caller
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class IngestDeleteCollectionResponse(BaseModel):
    # Result of deleting a collection.
    # Comments are used instead of a docstring so the generated JSON schema
    # keeps no model description, as before.
    message: str  # human-readable confirmation
    documents_removed: int  # count of removed documents
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# --- Jobs ---
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class JobStatusResponse(BaseModel):
    # Status of a single ingest job.
    # Comments are used instead of a docstring so the generated JSON schema
    # keeps no model description, as before.
    job_id: str
    status: str
    total_files: int
    processed_files: int
    failed_files: int
    progress_percent: int
    # Both timestamps are nullable but have no default, so they must always be
    # supplied, even when the value is None.
    started_at: datetime | None
    completed_at: datetime | None
    errors: list[str] = Field(default_factory=list)  # error messages, if any
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class JobListItem(BaseModel):
    # Compact job entry used inside JobListResponse.jobs.
    # Comments are used instead of a docstring so the generated JSON schema
    # keeps no model description, as before.
    job_id: str
    status: str
    total_files: int
    completed_at: datetime | None = None  # optional; defaults to None
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class JobListResponse(BaseModel):
    # One page of ingest jobs together with a total count.
    # Comments are used instead of a docstring so the generated JSON schema
    # keeps no model description, as before.
    jobs: list[JobListItem] = Field(default_factory=list)
    total: int  # total count supplied by the caller
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# --- Audit ---
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class AuditLogEntry(BaseModel):
    # Summary row of one audited query, used inside AuditLogsResponse.logs.
    # Comments are used instead of a docstring so the generated JSON schema
    # keeps no model description, as before.
    query_id: str
    user_id: str
    question: str
    answer_summary: str  # presumably a shortened form of the answer - verify
    sources_count: int
    model_used: str | None  # nullable, but must be supplied
    timestamp: datetime
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class AuditLogsResponse(BaseModel):
    # One page of audit log entries plus pagination fields.
    # Comments are used instead of a docstring so the generated JSON schema
    # keeps no model description, as before.
    logs: list[AuditLogEntry] = Field(default_factory=list)
    total: int
    limit: int
    offset: int
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class AuditLogDetailResponse(BaseModel):
    # Full record of one audited query, including the complete answer and its
    # citations.
    # Comments are used instead of a docstring so the generated JSON schema
    # keeps no model description, as before.
    query_id: str
    user_id: str
    question: str
    full_answer: str
    sources: list[SourceCitation] = Field(default_factory=list)
    # Both fields are nullable but have no default, so they must be supplied.
    model_used: str | None
    tokens_used: int | None
    timestamp: datetime
|
pyproject.toml
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "doc-audi-ai"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "DocuAudit AI: grounded document Q&A and summarisation with an audit trail (FastAPI + Streamlit)"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.11"
|
| 7 |
+
dependencies = [
|
| 8 |
+
"fastapi==0.111.0",
|
| 9 |
+
"langchain==0.2.0",
|
| 10 |
+
"langchain-openai==0.1.7",
|
| 11 |
+
"langchain-community==0.2.0",
|
| 12 |
+
"langchain-chroma==0.1.4",
|
| 13 |
+
"langchain-text-splitters==0.2.0",
|
| 14 |
+
"langchain-anthropic==0.1.15",
|
| 15 |
+
"langchain-ollama==0.1.3",
|
| 16 |
+
"chromadb==0.5.0",
|
| 17 |
+
# Chroma 0.5 calls posthog.capture(distinct_id, event, props); posthog 6+ removed that API (breaks telemetry + spams stderr).
|
| 18 |
+
"posthog>=3.7.0,<4",
|
| 19 |
+
"openai==1.30.1",
|
| 20 |
+
"anthropic==0.28.1",
|
| 21 |
+
"pydantic-settings==2.3.4",
|
| 22 |
+
"pymupdf==1.25.5",
|
| 23 |
+
"python-multipart==0.0.9",
|
| 24 |
+
"aiosqlite>=0.21.0",
|
| 25 |
+
"httpx>=0.27.0",
|
| 26 |
+
"uvicorn[standard]==0.29.0",
|
| 27 |
+
"huggingface-hub>=1.13.0",
|
| 28 |
+
"langchain-huggingface>=0.0.3",
|
| 29 |
+
"streamlit>=1.39.0",
|
| 30 |
+
"pytest>=8.4.2",
|
| 31 |
+
"pytest-asyncio>=1.2.0",
|
| 32 |
+
"onnxruntime==1.23.2 ; sys_platform == 'darwin' and platform_machine == 'x86_64'",
|
| 33 |
+
"torch==2.2.2 ; sys_platform == 'darwin' and platform_machine == 'x86_64'",
|
| 34 |
+
]
|
pytest.ini
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[pytest]
|
| 2 |
+
testpaths = tests
|
| 3 |
+
python_files = test_*.py
|
rag/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""RAG pipeline: load → chunk → embed → store → retrieve → generate.
|
| 2 |
+
|
| 3 |
+
Submodules: :mod:`loader`, :mod:`chunker`, :mod:`embedder`, :mod:`vector_store`,
|
| 4 |
+
:mod:`retriever`, and :mod:`hf_hub_inference` for Hugging Face Hub compatibility.
|
| 5 |
+
"""
|
| 6 |
+
|
rag/chunker.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Split loaded documents into overlapping chunks for embedding.
|
| 2 |
+
|
| 3 |
+
Chunk size and overlap come from :func:`api.config.get_settings`. Each chunk receives
|
| 4 |
+
``chunk_index``, ``source``, and ``page`` metadata.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from langchain_core.documents import Document
|
| 8 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 9 |
+
|
| 10 |
+
from api.config import get_settings
|
| 11 |
+
|
| 12 |
+
def chunk_documents(
    documents: list[Document],
) -> list[Document]:
    """Split ``documents`` into overlapping chunks and stamp chunk metadata.

    Chunk size and overlap are read from the application settings. Every chunk
    gets a ``chunk_index`` that counts across the whole returned list. It also
    gets ``source="unknown"`` and ``page=0``, but only where those keys are
    not already present.
    """
    cfg = get_settings()
    text_splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n", "\n", ". ", " ", ""],
        chunk_size=cfg.chunk_size,
        chunk_overlap=cfg.chunk_overlap,
    )
    pieces = text_splitter.split_documents(documents)
    for position, piece in enumerate(pieces):
        meta = piece.metadata
        meta["chunk_index"] = position
        # Keep any source/page already set on the chunk; fill gaps only.
        meta.setdefault("source", "unknown")
        meta.setdefault("page", 0)
    return pieces
|
| 28 |
+
|
rag/embedder.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Factory for LangChain embedding backends (OpenAI, Ollama, Hugging Face).
|
| 2 |
+
|
| 3 |
+
The active provider is ``Settings.llm_provider``. Used by ingest and query paths when
|
| 4 |
+
opening or querying Chroma collections.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from langchain_core.embeddings import Embeddings
|
| 8 |
+
from langchain_ollama import OllamaEmbeddings
|
| 9 |
+
from langchain_openai import OpenAIEmbeddings
|
| 10 |
+
from pydantic import SecretStr
|
| 11 |
+
|
| 12 |
+
from api.config import get_settings
|
| 13 |
+
from rag.hf_hub_inference import HubInferenceEmbeddings
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def create_embedding_function() -> Embeddings:
|
| 17 |
+
"""Return an ``Embeddings`` implementation matching the configured LLM provider."""
|
| 18 |
+
settings = get_settings()
|
| 19 |
+
provider = settings.llm_provider.lower()
|
| 20 |
+
|
| 21 |
+
if provider == "openai":
|
| 22 |
+
if not settings.openai_api_key:
|
| 23 |
+
raise ValueError("OPENAI_API_KEY is required when LLM_PROVIDER=openai")
|
| 24 |
+
return OpenAIEmbeddings(
|
| 25 |
+
model=settings.openai_embedding_model,
|
| 26 |
+
api_key=SecretStr(settings.openai_api_key),
|
| 27 |
+
)
|
| 28 |
+
if provider == "huggingface":
|
| 29 |
+
if not settings.huggingface_api_key:
|
| 30 |
+
raise ValueError(
|
| 31 |
+
"A Hugging Face token is required when LLM_PROVIDER=huggingface "
|
| 32 |
+
"(set HUGGINGFACE_API_KEY or HF_TOKEN / HUGGING_FACE_HUB_TOKEN on Spaces)."
|
| 33 |
+
)
|
| 34 |
+
return HubInferenceEmbeddings(
|
| 35 |
+
model=settings.huggingface_embedding_model,
|
| 36 |
+
api_token=settings.huggingface_api_key,
|
| 37 |
+
)
|
| 38 |
+
if provider == "ollama":
|
| 39 |
+
return OllamaEmbeddings(
|
| 40 |
+
model=settings.ollama_embedding_model,
|
| 41 |
+
base_url=settings.ollama_base_url,
|
| 42 |
+
)
|
| 43 |
+
raise ValueError(f"Unsupported LLM_PROVIDER: {settings.llm_provider}")
|
| 44 |
+
|
rag/hf_hub_inference.py
ADDED
|
@@ -0,0 +1,380 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hugging Face Inference API via ``huggingface_hub.InferenceClient``.
|
| 2 |
+
|
| 3 |
+
``langchain_huggingface`` 0.0.x uses ``InferenceClient.post()``, which was removed in
|
| 4 |
+
``huggingface_hub`` 1.x. Chat tries ``InferenceClient.chat_completion`` on the primary
|
| 5 |
+
provider, then (for repo ids containing ``mistral`` when primary is not Novita) Novita,
|
| 6 |
+
which often maps those weights to conversational chat only. On router errors or local
|
| 7 |
+
``ValueError`` (Hub sometimes omits ``pipeline_tag``), we fall back to ``text_generation``
|
| 8 |
+
providers, then the classic **api-inference** ``POST /models/{id}`` JSON API.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
from typing import Any, List, Optional
|
| 14 |
+
|
| 15 |
+
import httpx
|
| 16 |
+
import numpy as np
|
| 17 |
+
from langchain_core.embeddings import Embeddings
|
| 18 |
+
from langchain_core.language_models.chat_models import BaseChatModel
|
| 19 |
+
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage
|
| 20 |
+
from langchain_core.outputs import ChatGeneration, ChatResult
|
| 21 |
+
from langchain_core.pydantic_v1 import Field, root_validator
|
| 22 |
+
|
| 23 |
+
from huggingface_hub import InferenceClient, constants
|
| 24 |
+
from huggingface_hub.errors import BadRequestError, HfHubHTTPError
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _lc_messages_to_hf_chat(messages: List[BaseMessage]) -> list[dict[str, str]]:
|
| 28 |
+
"""Map LangChain messages to Hugging Face ``chat_completion`` message dicts."""
|
| 29 |
+
out: list[dict[str, str]] = []
|
| 30 |
+
for m in messages:
|
| 31 |
+
content = m.content if isinstance(m.content, str) else str(m.content)
|
| 32 |
+
if isinstance(m, SystemMessage):
|
| 33 |
+
out.append({"role": "system", "content": content})
|
| 34 |
+
elif isinstance(m, HumanMessage):
|
| 35 |
+
out.append({"role": "user", "content": content})
|
| 36 |
+
elif isinstance(m, AIMessage):
|
| 37 |
+
out.append({"role": "assistant", "content": content})
|
| 38 |
+
elif isinstance(m, ToolMessage):
|
| 39 |
+
out.append({"role": "user", "content": f"[tool result]\n{content}"})
|
| 40 |
+
else:
|
| 41 |
+
out.append({"role": "user", "content": content})
|
| 42 |
+
return out
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _messages_to_text_generation_prompt(repo_id: str, messages: List[BaseMessage]) -> str:
|
| 46 |
+
"""Build a single prompt for causal / text-generation APIs (instruct templates)."""
|
| 47 |
+
blocks: list[str] = []
|
| 48 |
+
for m in messages:
|
| 49 |
+
content = m.content if isinstance(m.content, str) else str(m.content)
|
| 50 |
+
if isinstance(m, SystemMessage):
|
| 51 |
+
blocks.append(content)
|
| 52 |
+
elif isinstance(m, HumanMessage):
|
| 53 |
+
blocks.append(content)
|
| 54 |
+
elif isinstance(m, AIMessage):
|
| 55 |
+
blocks.append(content)
|
| 56 |
+
elif isinstance(m, ToolMessage):
|
| 57 |
+
blocks.append(f"[tool]\n{content}")
|
| 58 |
+
else:
|
| 59 |
+
blocks.append(content)
|
| 60 |
+
body = "\n\n".join(blocks)
|
| 61 |
+
rid = repo_id.lower()
|
| 62 |
+
if "mistral" in rid:
|
| 63 |
+
return f"<s>[INST] {body} [/INST]"
|
| 64 |
+
return f"{body}\n\nAssistant:\n"
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _chat_completion_text_and_usage(out: Any) -> tuple[str, dict[str, int] | None]:
|
| 68 |
+
"""Extract assistant text and optional token usage from ``ChatCompletionOutput``."""
|
| 69 |
+
choices = getattr(out, "choices", None) or []
|
| 70 |
+
if not choices:
|
| 71 |
+
return (str(out).strip(), None)
|
| 72 |
+
msg = getattr(choices[0], "message", None)
|
| 73 |
+
text = (getattr(msg, "content", None) or "").strip() if msg is not None else ""
|
| 74 |
+
|
| 75 |
+
usage_meta: dict[str, int] | None = None
|
| 76 |
+
u = getattr(out, "usage", None)
|
| 77 |
+
if u is not None:
|
| 78 |
+
usage_meta = {}
|
| 79 |
+
tt = getattr(u, "total_tokens", None)
|
| 80 |
+
pt = getattr(u, "prompt_tokens", None)
|
| 81 |
+
ct = getattr(u, "completion_tokens", None)
|
| 82 |
+
if tt is not None:
|
| 83 |
+
usage_meta["total_tokens"] = int(tt)
|
| 84 |
+
if pt is not None:
|
| 85 |
+
usage_meta["input_tokens"] = int(pt)
|
| 86 |
+
if ct is not None:
|
| 87 |
+
usage_meta["output_tokens"] = int(ct)
|
| 88 |
+
if not usage_meta:
|
| 89 |
+
usage_meta = None
|
| 90 |
+
|
| 91 |
+
return text, usage_meta
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def _legacy_api_text_generation(
    model_id: str,
    api_token: str,
    prompt: str,
    *,
    max_new_tokens: int,
    temperature: float,
    stop: list[str] | None,
) -> str:
    """Call the classic ``POST /models/{id}`` inference endpoint directly.

    Args:
        model_id: Model id appended to ``constants.INFERENCE_ENDPOINT``.
        api_token: Bearer token sent in the ``Authorization`` header.
        prompt: Text sent as ``inputs``.
        max_new_tokens: Forwarded as a generation parameter.
        temperature: Forwarded as a generation parameter.
        stop: Optional stop sequences; omitted from the request when empty.

    Returns:
        The stripped ``generated_text`` from the response.

    Raises:
        LegacyInferenceNotFoundError: On HTTP 404 (raised by
            ``_raise_legacy_inference_http_error``).
        httpx.HTTPStatusError: On any other non-success status.
        RuntimeError: If the JSON carries an ``error`` field or has an
            unrecognised shape.
    """
    endpoint = f"{constants.INFERENCE_ENDPOINT.rstrip('/')}/models/{model_id}"

    generation_params: dict[str, Any] = {
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "return_full_text": False,
    }
    if stop:
        generation_params["stop"] = stop

    with httpx.Client(timeout=httpx.Timeout(60.0, read=300.0)) as http:
        resp = http.post(
            endpoint,
            json={"inputs": prompt, "parameters": generation_params},
            headers={"Authorization": f"Bearer {api_token}"},
        )
        try:
            resp.raise_for_status()
        except httpx.HTTPStatusError as exc:
            _raise_legacy_inference_http_error(model_id, exc)
        payload = resp.json()

    # The error check comes first: an error object is a dict too.
    if isinstance(payload, dict) and payload.get("error"):
        raise RuntimeError(str(payload["error"]))

    # Accept either a list of generations or a single generation dict.
    if isinstance(payload, list) and payload:
        head = payload[0]
        if isinstance(head, dict) and "generated_text" in head:
            return str(head["generated_text"]).strip()
    if isinstance(payload, dict) and "generated_text" in payload:
        return str(payload["generated_text"]).strip()

    raise RuntimeError(f"Unexpected legacy inference response: {payload!r}")
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class LegacyInferenceNotFoundError(RuntimeError):
    """Raised when the classic ``api-inference`` endpoint answers 404 for a model id.

    Kept as its own type so callers can re-raise it unchanged instead of
    wrapping it in a generic failure message.
    """
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def _raise_legacy_inference_http_error(model_id: str, exc: httpx.HTTPStatusError) -> None:
|
| 138 |
+
if exc.response.status_code == 404:
|
| 139 |
+
raise LegacyInferenceNotFoundError(
|
| 140 |
+
f"Hugging Face legacy inference returned 404 for model {model_id!r}. "
|
| 141 |
+
"The classic api-inference route often no longer serves this checkpoint, and router chat "
|
| 142 |
+
"can 404 as well depending on provider health. Try "
|
| 143 |
+
"HUGGINGFACE_MODEL=meta-llama/Meta-Llama-3-8B-Instruct (or another id your token can call), "
|
| 144 |
+
"another model id your token can reach, or Ollama/local inference."
|
| 145 |
+
) from exc
|
| 146 |
+
raise exc
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class HubInferenceEmbeddings(Embeddings):
    """Embeddings backed by ``InferenceClient.feature_extraction``.

    Each text is embedded with its own request. Multi-dimensional results
    (token-level features) are mean-pooled into a single vector.
    """

    def __init__(self, *, model: str, api_token: str) -> None:
        """Create the inference client.

        Args:
            model: Hugging Face model id used for feature extraction.
            api_token: Hub token; an empty string is passed on as ``None``.
        """
        self._model = model
        self._client = InferenceClient(model=model, token=api_token or None)

    def _embed_one(self, text: str) -> List[float]:
        """Embed one text and return a flat vector of floats."""
        # Newlines are replaced by spaces before the text is sent.
        cleaned = text.replace("\n", " ")
        raw = self._client.feature_extraction(cleaned, model=self._model)
        vec = np.asarray(raw, dtype=np.float32)
        if vec.ndim > 1:
            # Average over every leading axis, not just axis 0. A result of
            # shape (1, seq, dim) would otherwise stay (seq, dim) and be
            # flattened into a seq*dim vector of the wrong length. For 2-D
            # input this is identical to the previous mean(axis=0).
            vec = vec.mean(axis=tuple(range(vec.ndim - 1)))
        return vec.flatten().tolist()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed every text in ``texts``; an empty list yields an empty list."""
        return [self._embed_one(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query with the same pipeline as documents."""
        return self.embed_documents([text])[0]
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class HubInferenceChatModel(BaseChatModel):
    """Chat model over Hugging Face inference with layered fallbacks.

    ``_generate`` first tries ``chat_completion`` on each client from
    ``_chat_inference_clients``. If every client fails with a recoverable
    error, it falls back to ``text_generation`` on several providers and
    finally to the legacy ``POST /models/{id}`` endpoint.
    """

    # Hugging Face model id, e.g. "org/name".
    repo_id: str = Field(..., description="Hugging Face model id for inference")
    # Hub token; excluded from repr.
    huggingfacehub_api_token: str = Field(..., repr=False)
    temperature: float = Field(default=0.2)
    max_new_tokens: int = Field(default=2048)
    # Normalised by ``_build_client`` to "hf-inference", "auto" or the given id.
    inference_provider: Optional[str] = Field(
        default=None,
        description=(
            "huggingface_hub provider id. Default is hf-inference (avoids Novita-only mappings). "
            "Set to `auto` for router auto-routing (provider=None)."
        ),
    )
    # Primary InferenceClient; built by the validator unless supplied.
    client: Any = Field(default=None, exclude=True)

    class Config:
        """Pydantic v1 config."""

        arbitrary_types_allowed = True

    @root_validator(skip_on_failure=True)
    def _build_client(cls, values: dict) -> dict:
        """Normalise ``inference_provider`` and create the primary client."""
        if values.get("client") is not None:
            return values

        requested = values.get("inference_provider")
        if isinstance(requested, str):
            requested = requested.strip() or None

        # Auto-routing often picks Novita for Mistral instruct; Novita maps that
        # model to "conversational" only, so text_generation fails. An unset
        # provider therefore defaults to HF's own inference proxy.
        if requested is None:
            provider_arg: str | None = "hf-inference"
            recorded = "hf-inference"
        elif requested.lower() == "auto":
            # "auto" is stored as text but passed to the client as None.
            provider_arg = None
            recorded = "auto"
        else:
            provider_arg = recorded = requested

        values["inference_provider"] = recorded
        values["client"] = InferenceClient(
            model=values["repo_id"],
            token=values.get("huggingfacehub_api_token") or None,
            provider=provider_arg,
        )
        return values

    def _chat_inference_clients(self) -> list[InferenceClient]:
        """Return the clients to try for ``chat_completion``, in order.

        The primary client always comes first. A Novita client is added for
        repo ids containing ``mistral`` unless Novita is already the primary
        provider. A router-auto client (``provider=None``) is added when the
        primary provider is ``hf-inference``, because some models are rejected
        by that proxy but served through the router.
        """
        token = self.huggingfacehub_api_token or None
        model_id = self.repo_id
        primary = (self.inference_provider or "").strip().lower()

        extra_providers: list[str | None] = []
        if primary != "novita" and "mistral" in model_id.lower():
            extra_providers.append("novita")
        if primary == "hf-inference":
            extra_providers.append(None)

        return [
            self.client,
            *(
                InferenceClient(model=model_id, token=token, provider=provider)
                for provider in extra_providers
            ),
        ]

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this model type."""
        return "hf-hub-inference"

    @property
    def _identifying_params(self) -> dict[str, Any]:
        """Parameters that identify this model configuration (no secrets)."""
        return {
            "repo_id": self.repo_id,
            "temperature": self.temperature,
            "max_new_tokens": self.max_new_tokens,
            "inference_provider": self.inference_provider,
        }

    def _text_generation_fallback(self, messages: List[BaseMessage], stop: Optional[List[str]]) -> str:
        """Generate through ``text_generation`` providers, then the legacy API.

        Providers are tried in order: the configured one (``None`` for
        ``auto``), then ``hf-inference``, then router auto, with duplicates
        removed. Any provider exception moves on to the next candidate. If the
        legacy endpoint also fails, its exception is raised, chained to the
        last provider failure when there was one.
        """
        prompt = _messages_to_text_generation_prompt(self.repo_id, messages)
        token = self.huggingfacehub_api_token
        model_id = self.repo_id

        configured = (self.inference_provider or "").strip()
        candidates: list[str | None] = []
        if configured.lower() == "auto":
            candidates.append(None)
        elif configured and configured.lower() != "hf-inference":
            candidates.append(configured)
        candidates += ["hf-inference", None]

        # De-duplicate while keeping first occurrences; None uses a sentinel key.
        ordered: list[str | None] = []
        visited: set[str] = set()
        for candidate in candidates:
            marker = "__auto__" if candidate is None else candidate
            if marker not in visited:
                visited.add(marker)
                ordered.append(candidate)

        provider_failure: Exception | None = None
        for provider in ordered:
            try:
                generator = InferenceClient(model=model_id, token=token, provider=provider)
                produced = generator.text_generation(
                    prompt,
                    model=model_id,
                    max_new_tokens=self.max_new_tokens,
                    temperature=self.temperature,
                    stop=stop,
                    return_full_text=False,
                )
                return (produced if isinstance(produced, str) else str(produced)).strip()
            except Exception as exc:
                # Best effort: remember the failure and try the next provider.
                provider_failure = exc
                continue

        try:
            return _legacy_api_text_generation(
                model_id,
                token,
                prompt,
                max_new_tokens=self.max_new_tokens,
                temperature=self.temperature,
                stop=stop,
            )
        except Exception as legacy_exc:
            if provider_failure is None:
                raise legacy_exc
            # Prefer the legacy endpoint error (e.g. explicit 404 guidance) over
            # the last provider text_generation failure (often a task-mapping
            # ValueError).
            raise legacy_exc from provider_failure

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Produce one chat generation, using fallbacks on recoverable errors.

        Raises:
            BadRequestError / HfHubHTTPError: For chat errors that are not
                treated as recoverable.
            LegacyInferenceNotFoundError: Propagated unchanged from the
                text-generation fallback.
            RuntimeError: If the last chat error says the model was not found,
                or if every fallback failed.
        """
        payload = _lc_messages_to_hf_chat(messages)
        last_chat_err: BaseException | None = None

        # Substrings of a 400 response that mean "try the next client".
        # "model_not_found" is deferred to the post-loop check so gated or
        # unknown ids can be explained without masking an earlier recoverable
        # error from another client.
        recoverable_markers = (
            "not a chat model",
            "model_not_supported",
            "not supported by provider",
            "model_not_found",
        )

        for chat_client in self._chat_inference_clients():
            try:
                completion = chat_client.chat_completion(
                    payload,
                    model=self.repo_id,
                    max_tokens=self.max_new_tokens,
                    temperature=self.temperature,
                    stop=stop,
                )
                text, usage_meta = _chat_completion_text_and_usage(completion)
                reply = AIMessage(content=text, usage_metadata=usage_meta)
                return ChatResult(generations=[ChatGeneration(message=reply)])
            except BadRequestError as exc:
                # Must precede the HfHubHTTPError and ValueError handlers:
                # the more specific clause has to win.
                last_chat_err = exc
                detail = str(exc).lower()
                missing_model = "does not exist" in detail and "model" in detail
                if missing_model or any(marker in detail for marker in recoverable_markers):
                    continue
                raise
            except HfHubHTTPError as exc:
                last_chat_err = exc
                # Novita/router may 404 a model or route; try the remaining
                # clients, then the completion fallbacks.
                if getattr(exc.response, "status_code", None) in (404, 410):
                    continue
                raise
            except ValueError as exc:
                # e.g. hf-inference _check_supported_task when the Hub model
                # card has no pipeline_tag.
                last_chat_err = exc
                continue

        if isinstance(last_chat_err, BadRequestError):
            detail = str(last_chat_err).lower()
            if "model_not_found" in detail or ("does not exist" in detail and "model" in detail):
                raise RuntimeError(
                    f"Inference router could not use chat model {self.repo_id!r} "
                    "(common for gated models: open the model page on the Hugging Face Hub, accept the "
                    "license, ensure your API token has read access to that model, then retry)."
                ) from last_chat_err

        try:
            text = self._text_generation_fallback(messages, stop)
        except LegacyInferenceNotFoundError:
            # Already carries actionable guidance; do not wrap it.
            raise
        except Exception as exc:
            hint = (
                f"Hugging Face chat_completion failed for {self.repo_id!r} on all tried providers; "
                "text_generation / legacy fallbacks also failed. "
                "Accept the model license on the Hub, check your token, or set "
                "HUGGINGFACE_INFERENCE_PROVIDER=auto to use only router routing."
            )
            if last_chat_err is None:
                raise RuntimeError(hint) from exc
            raise RuntimeError(f"{hint} Last chat error: {last_chat_err!r}") from exc

        # The text-generation paths report no token usage.
        reply = AIMessage(content=text, usage_metadata=None)
        return ChatResult(generations=[ChatGeneration(message=reply)])
|
rag/loader.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Load raw documents from disk into LangChain ``Document`` objects.
|
| 2 |
+
|
| 3 |
+
Supports PDF (PyMuPDF), plain text, and Markdown. Each document gets ``source`` and
|
| 4 |
+
``page`` metadata for downstream chunking and citations.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
from langchain_core.documents import Document
|
| 10 |
+
from langchain_community.document_loaders import PyMuPDFLoader, TextLoader
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def load_documents(paths: str | list[str]) -> list[Document]:
    """Load the given file(s) into ``Document`` objects.

    Args:
        paths: A single path or a list of paths. ``.pdf`` files use
            ``PyMuPDFLoader``; ``.txt`` and ``.md`` use ``TextLoader`` with
            UTF-8. The extension check is case-insensitive.

    Returns:
        All loaded documents in input order. ``source`` (file name) and
        ``page`` (0) are filled in only where the loader did not set them.

    Raises:
        ValueError: For any other extension, including none at all.
    """
    targets = [paths] if isinstance(paths, str) else paths
    collected: list[Document] = []

    for raw_path in targets:
        file_path = Path(raw_path)
        extension = file_path.suffix.lower()

        if extension not in {".pdf", ".txt", ".md"}:
            raise ValueError(f"Unsupported file type: {extension or 'unknown'}")

        if extension == ".pdf":
            loader = PyMuPDFLoader(str(raw_path))
        else:
            loader = TextLoader(str(raw_path), encoding="utf-8")

        loaded = loader.load()
        for document in loaded:
            # setdefault keeps any value the loader already provided.
            document.metadata.setdefault("source", file_path.name)
            document.metadata.setdefault("page", 0)
        collected.extend(loaded)

    return collected
|
| 35 |
+
|
rag/retriever.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Semantic retrieval and grounded LLM generation for ask and summarise flows.
|
| 2 |
+
|
| 3 |
+
Pipeline: similarity search on Chroma → relevance filter → provider-specific chat model
|
| 4 |
+
→ answer with citations. Prompt templates enforce document-only answers for consulting use.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
|
| 9 |
+
from langchain_chroma import Chroma
|
| 10 |
+
from langchain_core.language_models import BaseChatModel
|
| 11 |
+
from langchain_core.messages import HumanMessage, SystemMessage
|
| 12 |
+
from langchain_ollama import ChatOllama
|
| 13 |
+
from langchain_openai import ChatOpenAI
|
| 14 |
+
from pydantic import SecretStr
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
from langchain_anthropic import ChatAnthropic
|
| 18 |
+
except ImportError:
|
| 19 |
+
ChatAnthropic = None # type: ignore[assignment]
|
| 20 |
+
|
| 21 |
+
from api.config import Settings
|
| 22 |
+
from rag.hf_hub_inference import HubInferenceChatModel
|
| 23 |
+
|
| 24 |
+
# Returned when no retrieved chunk passes the relevance filter, or when the
# model produces an empty answer.
NO_MATCH_ANSWER: str = "I cannot find this information in the uploaded documents."

# Chunks with a relevance score below this value are dropped before prompting.
MIN_RELEVANCE_SCORE: float = 0.15
|
| 26 |
+
|
| 27 |
+
# Prompt for the "ask" flow, taken verbatim from DOCUAUDIT_AI_REQUIREMENTS.md.
# ``{context}`` and ``{question}`` are filled in with ``str.format`` at runtime.
DOCUAUDIT_ASK_TEMPLATE: str = """You are DocuAudit AI, an expert document analyst for consulting environments.

RULES:
1. Answer ONLY based on the provided document excerpts below.
2. If the answer is not in the documents, say: "I cannot find this information in the uploaded documents."
3. ALWAYS cite your sources: mention the document name and page number for every claim.
4. Be precise and professional. This is a high-stakes consulting environment.
5. Do not speculate or add information not present in the documents.

DOCUMENT EXCERPTS:
{context}

QUESTION: {question}

ANSWER (with source citations):
"""
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@dataclass
class RetrievedChunk:
    """A single similarity-search hit plus the metadata used for citations."""

    # Chunk text as stored in the vector store.
    text: str
    # Relevance score reported by the search; None is treated as "keep".
    score: float | None
    # Value of the ``source`` metadata key ("unknown" when absent).
    source: str
    # Value of the ``page`` metadata key, if it could be read as an int.
    page: int | None
    # Value of the ``chunk_index`` metadata key, if it could be read as an int.
    chunk_index: int | None
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def retrieve_chunks(vector_store: Chroma, question: str, k: int) -> list[RetrievedChunk]:
    """Run a top-``k`` relevance-scored search and wrap the hits.

    Args:
        vector_store: Store exposing ``similarity_search_with_relevance_scores``.
        question: Query text.
        k: Number of hits requested.

    Returns:
        One ``RetrievedChunk`` per hit, in the order the store returned them.
    """

    def _as_chunk(document, relevance) -> RetrievedChunk:
        # Metadata may be None or missing keys; fall back to defaults.
        meta = document.metadata or {}
        return RetrievedChunk(
            text=document.page_content,
            score=relevance,
            source=str(meta.get("source", "unknown")),
            page=_to_int_or_none(meta.get("page")),
            chunk_index=_to_int_or_none(meta.get("chunk_index")),
        )

    hits = vector_store.similarity_search_with_relevance_scores(question, k=k)
    return [_as_chunk(document, relevance) for document, relevance in hits]
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# Generic query text for retrieving excerpts that cover a whole document.
SUMMARY_RETRIEVAL_QUERY: str = (
    "Overview of the document: main topics, key definitions, "
    "obligations, risks, and conclusions."
)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def answer_with_grounding(settings: Settings, question: str, chunks: list[RetrievedChunk]) -> tuple[str, int]:
    """Answer ``question`` from the retrieved chunks only.

    Chunks scoring below ``MIN_RELEVANCE_SCORE`` are dropped; chunks with no
    score are kept. If nothing remains, no model call is made.

    Returns:
        ``(answer, tokens)``. The answer is ``NO_MATCH_ANSWER`` when no chunk
        qualifies (tokens ``0``) or when the model reply is empty.
    """
    usable: list[RetrievedChunk] = []
    for candidate in chunks:
        if candidate.score is None or candidate.score >= MIN_RELEVANCE_SCORE:
            usable.append(candidate)
    if not usable:
        return NO_MATCH_ANSWER, 0

    chat_model = _create_chat_model(settings)
    # The whole template, rules included, is sent as a single human message.
    prompt = DOCUAUDIT_ASK_TEMPLATE.format(
        context=_format_context(usable),
        question=question,
    )
    reply = chat_model.invoke([HumanMessage(content=prompt)])

    text = _extract_message_text(reply).strip()
    used_tokens = _extract_usage_tokens(reply)
    if not text:
        return NO_MATCH_ANSWER, used_tokens
    return text, used_tokens
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def summarise_with_grounding(
    settings: Settings,
    *,
    focus: str | None,
    chunks: list[RetrievedChunk],
) -> tuple[str, int]:
    """Summarise the retrieved chunks, optionally steered by ``focus``.

    Args:
        settings: Provider configuration used to build the chat model.
        focus: Optional instruction; blank or ``None`` selects a default
            "main themes" request.
        chunks: Retrieved excerpts, filtered by ``MIN_RELEVANCE_SCORE``
            (chunks with no score are kept).

    Returns:
        ``(summary, tokens)``; ``(NO_MATCH_ANSWER, 0)`` when no chunk
        qualifies, and ``NO_MATCH_ANSWER`` as the summary when the model
        reply is empty.
    """
    usable: list[RetrievedChunk] = []
    for candidate in chunks:
        if candidate.score is None or candidate.score >= MIN_RELEVANCE_SCORE:
            usable.append(candidate)
    if not usable:
        return NO_MATCH_ANSWER, 0

    chat_model = _create_chat_model(settings)
    excerpts = _format_context(usable)

    trimmed_focus = (focus or "").strip()
    request = trimmed_focus or (
        "Summarise the main themes, structure, and important details. Use bullet points where helpful."
    )

    system_prompt = (
        "You write accurate summaries using only the provided document excerpts. "
        "Do not invent facts. If the excerpts are insufficient, say what is missing."
    )
    human_prompt = (
        f"Summary request: {request}\n\n"
        f"Document excerpts:\n{excerpts}\n\n"
        "Return a structured, concise summary grounded in the excerpts above."
    )
    reply = chat_model.invoke(
        [SystemMessage(content=system_prompt), HumanMessage(content=human_prompt)]
    )

    text = _extract_message_text(reply).strip()
    used_tokens = _extract_usage_tokens(reply)
    if not text:
        return NO_MATCH_ANSWER, used_tokens
    return text, used_tokens
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def _create_chat_model(settings: Settings) -> BaseChatModel:
    """Instantiate the chat model for ``settings.llm_provider``.

    Supports ``openai``, ``ollama``, ``anthropic`` and ``huggingface``
    (compared case-insensitively).

    Raises:
        ValueError: If the provider is unknown, a required API key is empty,
            or ``langchain-anthropic`` is not installed for ``anthropic``.
    """
    selected = settings.llm_provider.lower()

    if selected == "ollama":
        return ChatOllama(model=settings.ollama_chat_model, base_url=settings.ollama_base_url)

    if selected == "openai":
        if settings.openai_api_key:
            return ChatOpenAI(model=settings.openai_model, api_key=SecretStr(settings.openai_api_key))
        raise ValueError("OPENAI_API_KEY is required when LLM_PROVIDER=openai")

    if selected == "anthropic":
        # The package check comes before the key check, so a missing
        # dependency is reported first.
        if ChatAnthropic is None:
            raise ValueError("langchain-anthropic is not installed for LLM_PROVIDER=anthropic")
        if settings.anthropic_api_key:
            return ChatAnthropic(model=settings.anthropic_model, api_key=SecretStr(settings.anthropic_api_key))
        raise ValueError("ANTHROPIC_API_KEY is required when LLM_PROVIDER=anthropic")

    if selected == "huggingface":
        if settings.huggingface_api_key:
            return HubInferenceChatModel(
                repo_id=settings.huggingface_model,
                huggingfacehub_api_token=settings.huggingface_api_key,
                temperature=0.2,
                max_new_tokens=2048,
                inference_provider=settings.huggingface_inference_provider,
            )
        raise ValueError(
            "A Hugging Face token is required when LLM_PROVIDER=huggingface "
            "(set HUGGINGFACE_API_KEY or HF_TOKEN / HUGGING_FACE_HUB_TOKEN on Spaces)."
        )

    raise ValueError(f"Unsupported LLM_PROVIDER: {settings.llm_provider}")
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def _format_context(chunks: list[RetrievedChunk]) -> str:
|
| 168 |
+
lines: list[str] = []
|
| 169 |
+
for idx, chunk in enumerate(chunks, start=1):
|
| 170 |
+
lines.append(
|
| 171 |
+
f"[{idx}] source={chunk.source}, page={chunk.page}, chunk={chunk.chunk_index}, score={chunk.score}\n"
|
| 172 |
+
f"{chunk.text}"
|
| 173 |
+
)
|
| 174 |
+
return "\n\n".join(lines)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def _to_int_or_none(value: object) -> int | None:
|
| 178 |
+
try:
|
| 179 |
+
if value is None:
|
| 180 |
+
return None
|
| 181 |
+
return int(value)
|
| 182 |
+
except (TypeError, ValueError):
|
| 183 |
+
return None
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def _extract_usage_tokens(response: object) -> int:
|
| 187 |
+
um = getattr(response, "usage_metadata", None)
|
| 188 |
+
if isinstance(um, dict):
|
| 189 |
+
total = um.get("total_tokens")
|
| 190 |
+
if total is not None:
|
| 191 |
+
return int(total)
|
| 192 |
+
inp = int(um.get("input_tokens", 0) or 0)
|
| 193 |
+
out = int(um.get("output_tokens", 0) or 0)
|
| 194 |
+
return inp + out
|
| 195 |
+
rm = getattr(response, "response_metadata", None) or {}
|
| 196 |
+
if isinstance(rm, dict):
|
| 197 |
+
tu = rm.get("token_usage")
|
| 198 |
+
if isinstance(tu, dict):
|
| 199 |
+
if tu.get("total_tokens") is not None:
|
| 200 |
+
return int(tu["total_tokens"])
|
| 201 |
+
return int(tu.get("prompt_tokens", 0) or 0) + int(tu.get("completion_tokens", 0) or 0)
|
| 202 |
+
return 0
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def _extract_message_text(response: object) -> str:
|
| 206 |
+
content = getattr(response, "content", "")
|
| 207 |
+
if isinstance(content, str):
|
| 208 |
+
return content
|
| 209 |
+
if isinstance(content, list):
|
| 210 |
+
text_parts: list[str] = []
|
| 211 |
+
for item in content:
|
| 212 |
+
if isinstance(item, str):
|
| 213 |
+
text_parts.append(item)
|
| 214 |
+
elif isinstance(item, dict) and "text" in item:
|
| 215 |
+
text_parts.append(str(item["text"]))
|
| 216 |
+
return "\n".join(part for part in text_parts if part)
|
| 217 |
+
return str(content)
|
| 218 |
+
|
rag/vector_store.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ChromaDB persistence and LangChain ``Chroma`` vector store helpers.
|
| 2 |
+
|
| 3 |
+
Collections are named per ingest target; documents are stored with UUID chunk ids.
|
| 4 |
+
Telemetry is disabled at the client level for quieter logs in production.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from datetime import datetime, timezone
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from uuid import uuid4
|
| 10 |
+
|
| 11 |
+
import chromadb
|
| 12 |
+
from chromadb.config import Settings
|
| 13 |
+
from langchain_chroma import Chroma
|
| 14 |
+
from langchain_core.documents import Document
|
| 15 |
+
from langchain_core.embeddings import Embeddings
|
| 16 |
+
|
| 17 |
+
_CHROMA_CLIENT_SETTINGS = Settings(anonymized_telemetry=False)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _utc_now_iso() -> str:
|
| 21 |
+
return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _chroma_client(persist_directory: str) -> chromadb.PersistentClient:
    """Return a persistent Chroma client rooted at ``persist_directory``.

    The directory (and any missing parents) is created first; the client is
    built with the module-level ``_CHROMA_CLIENT_SETTINGS``.
    """
    storage_root = Path(persist_directory)
    storage_root.mkdir(parents=True, exist_ok=True)
    return chromadb.PersistentClient(
        path=persist_directory,
        settings=_CHROMA_CLIENT_SETTINGS,
    )
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_vector_store(
    persist_directory: str,
    collection_name: str,
    embedding_function: Embeddings,
) -> Chroma:
    """Open (or create) a persisted collection and wrap it in a LangChain ``Chroma`` store.

    If looking the collection up fails, it is created with a ``created_at``
    metadata stamp. The returned store is bound to the same persist directory,
    client settings and the supplied embedding function.
    """
    chroma = _chroma_client(persist_directory)
    try:
        chroma.get_collection(name=collection_name)
    except Exception:
        # Lookup failed -- presumably the collection does not exist yet, so
        # create it and record when that happened.
        creation_metadata = {"created_at": _utc_now_iso()}
        chroma.get_or_create_collection(name=collection_name, metadata=creation_metadata)

    store = Chroma(
        collection_name=collection_name,
        embedding_function=embedding_function,
        persist_directory=persist_directory,
        client_settings=_CHROMA_CLIENT_SETTINGS,
    )
    return store
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def add_documents(vector_store: Chroma, chunks: list[Document]) -> list[str]:
    """Embed and insert ``chunks``; return the generated vector ids.

    One fresh UUID4 string is generated per chunk and passed to the store as
    that chunk's id. An empty ``chunks`` list is a no-op that returns ``[]``.
    """
    if not chunks:
        # Nothing to embed or store; skip the round-trip entirely (the Chroma
        # backend is expected to reject empty id lists -- see Chroma docs).
        return []
    document_ids = [str(uuid4()) for _ in chunks]
    vector_store.add_documents(documents=chunks, ids=document_ids)
    return document_ids
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def list_collection_names(persist_directory: str) -> list[str]:
    """Return the names of all collections in ``persist_directory``, sorted ascending."""
    collections = _chroma_client(persist_directory).list_collections()
    names = [collection.name for collection in collections]
    names.sort()
    return names
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def delete_collection(persist_directory: str, collection_name: str) -> int:
    """Delete a collection and report how many documents it held.

    The count is best effort: if the collection cannot be read beforehand, 0
    is reported. The delete call itself is not guarded, so any error it
    raises propagates to the caller.
    """
    client = _chroma_client(persist_directory)
    try:
        doc_total = int(client.get_collection(name=collection_name).count())
    except Exception:
        doc_total = 0
    client.delete_collection(name=collection_name)
    return doc_total
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def collection_document_count(persist_directory: str, collection_name: str) -> int:
    """Return the number of vectors in a collection, or 0 if it cannot be read."""
    client = _chroma_client(persist_directory)
    try:
        collection = client.get_collection(name=collection_name)
        size = int(collection.count())
    except Exception:
        # Any lookup/count failure (e.g. unknown collection) is reported as empty.
        size = 0
    return size
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def collection_created_at(persist_directory: str, collection_name: str) -> str | None:
    """Return the collection's ``created_at`` metadata value as a string, if any.

    Falls back to a ``created`` key when ``created_at`` is missing or empty.
    Returns ``None`` when neither key holds a value, when the metadata is not
    a dict, or when the collection cannot be read.
    """
    client = _chroma_client(persist_directory)
    try:
        collection = client.get_collection(name=collection_name)
        metadata = getattr(collection, "metadata", None) or {}
        if not isinstance(metadata, dict):
            return None
        stamp = metadata.get("created_at") or metadata.get("created")
        return None if stamp is None else str(stamp)
    except Exception:
        # Best effort: an unreadable collection simply has no known timestamp.
        return None
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def ensure_collection_created_at(
    persist_directory: str,
    collection_name: str,
    *,
    fallback: str | None = None,
) -> str | None:
    """Make sure the collection carries a ``created_at`` metadata value.

    Returns the existing value (``created_at``, else ``created``) when one is
    set, without modifying the collection. Otherwise writes ``fallback`` -- or
    the current UTC time when ``fallback`` is empty -- into the collection
    metadata and returns it. Returns ``None`` if the collection cannot be
    looked up.
    """
    client = _chroma_client(persist_directory)
    try:
        collection = client.get_collection(name=collection_name)
    except Exception:
        return None

    existing = getattr(collection, "metadata", None) or {}
    if not isinstance(existing, dict):
        existing = {}

    current = existing.get("created_at") or existing.get("created")
    if current is not None:
        return str(current)

    stamp = fallback if fallback else _utc_now_iso()
    # Preserve every other metadata key; only add/replace created_at.
    collection.modify(metadata={**existing, "created_at": stamp})
    return stamp
|
| 125 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi==0.111.0
|
| 2 |
+
uvicorn[standard]==0.29.0
|
| 3 |
+
pydantic-settings==2.3.4
|
| 4 |
+
langchain==0.2.0
|
| 5 |
+
langchain-openai==0.1.7
|
| 6 |
+
langchain-community==0.2.0
|
| 7 |
+
langchain-chroma==0.1.4
|
| 8 |
+
langchain-text-splitters==0.2.0
|
| 9 |
+
langchain-anthropic==0.1.15
|
| 10 |
+
langchain-ollama==0.1.3
|
| 11 |
+
chromadb==0.5.0
|
| 12 |
+
posthog>=3.7.0,<4
|
| 13 |
+
openai==1.30.1
|
| 14 |
+
anthropic==0.28.1
|
| 15 |
+
pymupdf==1.25.5
|
| 16 |
+
python-multipart==0.0.9
|
| 17 |
+
aiosqlite
|
| 18 |
+
httpx>=0.27.0
|
| 19 |
+
huggingface-hub
|
| 20 |
+
langchain-huggingface
|
| 21 |
+
streamlit>=1.39.0
|
| 22 |
+
pytest>=8.4.2
|
| 23 |
+
pytest-asyncio>=1.2.0
|
sample.txt
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Doc-Audi-AI RAG Smoke Test Document
|
| 2 |
+
|
| 3 |
+
Project: Doc-Audi-AI
|
| 4 |
+
Environment: Lightning AI deployment with Ollama embeddings.
|
| 5 |
+
|
| 6 |
+
This sample document is used to test ingestion and retrieval.
|
| 7 |
+
The system should split this file into chunks, generate embeddings, and store vectors in Chroma.
|
| 8 |
+
|
| 9 |
+
Key facts:
|
| 10 |
+
- The project supports file ingestion for PDF, TXT, and MD formats.
|
| 11 |
+
- The default collection name for tests is "default".
|
| 12 |
+
- A typical retrieval question is: "What is this document about?"
|
| 13 |
+
- Another test question is: "Which file formats are supported?"
|
| 14 |
+
|
| 15 |
+
Expected behavior:
|
| 16 |
+
If ingestion succeeds, querying should return text snippets from this document with relevance scores.
|
storage/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Persistence layer: SQLite audit log and ingest job tracking."""
|
storage/audit_store.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SQLite persistence for query and summarise audit events.
|
| 2 |
+
|
| 3 |
+
Schema is created/migrated on first use. Stores full answers, citation JSON, token usage,
|
| 4 |
+
and optional filters (user_id, date range) for list endpoints.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
from datetime import datetime, timezone
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Any
|
| 11 |
+
from uuid import uuid4
|
| 12 |
+
|
| 13 |
+
import aiosqlite
|
| 14 |
+
|
| 15 |
+
from models.responses import AuditLogDetailResponse, AuditLogEntry, SourceCitation
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _utc_now_iso() -> str:
|
| 19 |
+
return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _parse_ts(value: object) -> datetime:
|
| 23 |
+
if value is None or value == "":
|
| 24 |
+
return datetime.now(timezone.utc)
|
| 25 |
+
s = str(value).strip()
|
| 26 |
+
if s.endswith("Z"):
|
| 27 |
+
s = s[:-1] + "+00:00"
|
| 28 |
+
try:
|
| 29 |
+
dt = datetime.fromisoformat(s)
|
| 30 |
+
if dt.tzinfo is None:
|
| 31 |
+
return dt.replace(tzinfo=timezone.utc)
|
| 32 |
+
return dt
|
| 33 |
+
except ValueError:
|
| 34 |
+
return datetime.now(timezone.utc)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
async def _migrate_audit_columns(conn: aiosqlite.Connection) -> None:
    """Add any ``audit_events`` columns that an older database is missing.

    Reads the current column list via ``PRAGMA table_info`` and issues one
    ``ALTER TABLE ... ADD COLUMN`` per absent column, committing only when at
    least one statement ran.
    """
    # Insertion order is the order in which the ALTERs are applied.
    statements_by_column = {
        "user_id": "ALTER TABLE audit_events ADD COLUMN user_id TEXT NOT NULL DEFAULT 'anonymous'",
        "model_used": "ALTER TABLE audit_events ADD COLUMN model_used TEXT",
        "tokens_used": "ALTER TABLE audit_events ADD COLUMN tokens_used INTEGER",
        "response_time_ms": "ALTER TABLE audit_events ADD COLUMN response_time_ms INTEGER",
        "answer_summary": "ALTER TABLE audit_events ADD COLUMN answer_summary TEXT",
        "kind": "ALTER TABLE audit_events ADD COLUMN kind TEXT NOT NULL DEFAULT 'ask'",
    }

    info_cursor = await conn.execute("PRAGMA table_info(audit_events)")
    # Field 1 of each table_info row is the column name.
    present = {str(info[1]) for info in await info_cursor.fetchall()}

    pending = [ddl for column, ddl in statements_by_column.items() if column not in present]
    for ddl in pending:
        await conn.execute(ddl)
    if pending:
        await conn.commit()
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
async def init_audit_db(db_path: str) -> None:
    """Create the ``audit_events`` table if needed and apply additive column migrations.

    The parent directory of ``db_path`` is created when missing.
    """
    create_table_sql = """
        CREATE TABLE IF NOT EXISTS audit_events (
            event_id TEXT PRIMARY KEY,
            action TEXT NOT NULL,
            question TEXT NOT NULL,
            collection_name TEXT NOT NULL,
            answer TEXT,
            status TEXT NOT NULL,
            message TEXT NOT NULL,
            sources_json TEXT NOT NULL,
            results_json TEXT NOT NULL,
            created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
            user_id TEXT NOT NULL DEFAULT 'anonymous',
            model_used TEXT,
            tokens_used INTEGER,
            response_time_ms INTEGER,
            answer_summary TEXT,
            kind TEXT NOT NULL DEFAULT 'ask'
        )
        """
    target = Path(db_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    async with aiosqlite.connect(target.as_posix()) as conn:
        await conn.execute(create_table_sql)
        await conn.commit()
        # Databases created by an earlier schema may lack the newer columns.
        await _migrate_audit_columns(conn)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def _summary_from_answer(answer: str, max_len: int = 280) -> str:
|
| 92 |
+
text = (answer or "").strip()
|
| 93 |
+
if len(text) <= max_len:
|
| 94 |
+
return text
|
| 95 |
+
return text[: max_len - 1].rstrip() + "…"
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def _sources_to_citations(raw: list[dict[str, Any]]) -> list[SourceCitation]:
    """Convert stored source dicts into ``SourceCitation`` objects.

    Two input shapes are understood:

    * current: ``document_name`` / ``page_number`` / ``chunk_text`` /
      ``relevance_score``;
    * legacy (no ``document_name`` key): ``source``, ``page_number`` or
      ``page``, ``chunk_text`` / ``excerpt`` / ``text``, and
      ``relevance_score`` or ``score``.

    Non-dict items are skipped. Page numbers and scores that cannot be
    converted fall back to ``0`` / ``0.0`` in both shapes, so one malformed
    stored value does not fail the whole conversion. An empty document name
    becomes ``"unknown"``.
    """

    def _int_or_zero(value: object) -> int:
        # Lenient page-number coercion; falsy values count as 0.
        try:
            return int(value or 0)
        except (TypeError, ValueError, OverflowError):
            return 0

    def _float_or_zero(value: object) -> float:
        # Lenient score coercion; falsy values count as 0.0.
        try:
            return float(value or 0.0)
        except (TypeError, ValueError, OverflowError):
            return 0.0

    citations: list[SourceCitation] = []
    for item in raw:
        if not isinstance(item, dict):
            continue
        if "document_name" in item:
            doc = str(item.get("document_name", ""))
            page_raw = item.get("page_number", 0)
            chunk = str(item.get("chunk_text", ""))
            score_raw = item.get("relevance_score", 0.0)
        else:
            doc = str(item.get("source", ""))
            page_raw = item.get("page_number", item.get("page"))
            chunk = str(item.get("chunk_text", item.get("excerpt", item.get("text", ""))))
            score_raw = item.get("relevance_score", item.get("score"))
        citations.append(
            SourceCitation(
                document_name=doc or "unknown",
                page_number=_int_or_zero(page_raw),
                chunk_text=chunk,
                relevance_score=_float_or_zero(score_raw),
            )
        )
    return citations
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
async def persist_query_audit(
    db_path: str,
    *,
    query_id: str,
    action: str,
    user_id: str,
    question: str,
    collection_name: str,
    answer: str,
    sources: list[SourceCitation],
    model_used: str,
    tokens_used: int,
    response_time_ms: int,
    status: str = "success",
    message: str = "ok",
    kind: str = "ask",
) -> str:
    """Insert one audit row for an ask/summarise call and return ``query_id``.

    ``query_id`` is stored as the row's ``event_id``. Sources are serialised
    to JSON, ``results_json`` is always written as ``'[]'``, a shortened
    ``answer_summary`` is derived from ``answer``, and ``created_at`` is set
    to the current UTC time.
    """
    await init_audit_db(db_path)

    insert_sql = """
        INSERT INTO audit_events (
            event_id, action, question, collection_name, answer, status, message,
            sources_json, results_json, created_at, user_id, model_used, tokens_used,
            response_time_ms, answer_summary, kind
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, '[]', ?, ?, ?, ?, ?, ?, ?)
        """
    # Order must match the placeholders above (results_json is a SQL literal).
    row_values = (
        query_id,
        action,
        question,
        collection_name,
        answer,
        status,
        message,
        json.dumps([citation.model_dump(mode="json") for citation in sources]),
        _utc_now_iso(),
        user_id,
        model_used,
        tokens_used,
        response_time_ms,
        _summary_from_answer(answer),
        kind,
    )

    async with aiosqlite.connect(db_path) as conn:
        await conn.execute(insert_sql, row_values)
        await conn.commit()
    return query_id
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
async def count_audit_events(
    db_path: str,
    *,
    user_id: str | None = None,
    from_date: str | None = None,
    to_date: str | None = None,
) -> int:
    """Count audit rows matching the optional user and date-range filters."""
    await init_audit_db(db_path)
    where_sql, bind_values = _audit_filters(user_id, from_date, to_date)
    # where_sql contains only fixed clause text; user values travel as bind parameters.
    query = f"SELECT COUNT(*) AS c FROM audit_events {where_sql}"
    async with aiosqlite.connect(db_path) as conn:
        cursor = await conn.execute(query, bind_values)
        first = await cursor.fetchone()
    if first:
        return int(first[0])
    return 0
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def _audit_filters(user_id: str | None, from_date: str | None, to_date: str | None) -> tuple[str, list[Any]]:
|
| 201 |
+
clauses: list[str] = []
|
| 202 |
+
params: list[Any] = []
|
| 203 |
+
if user_id:
|
| 204 |
+
clauses.append("user_id = ?")
|
| 205 |
+
params.append(user_id)
|
| 206 |
+
if from_date:
|
| 207 |
+
clauses.append("datetime(created_at) >= datetime(?)")
|
| 208 |
+
params.append(from_date)
|
| 209 |
+
if to_date:
|
| 210 |
+
clauses.append("datetime(created_at) <= datetime(?)")
|
| 211 |
+
params.append(to_date)
|
| 212 |
+
if not clauses:
|
| 213 |
+
return "", []
|
| 214 |
+
return "WHERE " + " AND ".join(clauses), params
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
async def list_audit_events(
    db_path: str,
    *,
    limit: int,
    offset: int,
    user_id: str | None = None,
    from_date: str | None = None,
    to_date: str | None = None,
) -> tuple[list[AuditLogEntry], int]:
    """Return one page of audit entries plus the total number of matching rows.

    Rows are ordered newest first (``created_at`` descending, then ``rowid``
    descending) and filtered by the optional user id and ISO datetime bounds.
    A row whose ``sources_json`` is not valid JSON (or not a JSON list) is
    listed with ``sources_count=0`` rather than failing the whole page.
    """
    await init_audit_db(db_path)
    where, fparams = _audit_filters(user_id, from_date, to_date)
    total = await count_audit_events(db_path, user_id=user_id, from_date=from_date, to_date=to_date)
    async with aiosqlite.connect(db_path) as conn:
        conn.row_factory = aiosqlite.Row
        # `where` contains only fixed clause text; values are bound as parameters.
        cursor = await conn.execute(
            f"""
            SELECT event_id, user_id, question, answer, answer_summary, sources_json, model_used, created_at
            FROM audit_events
            {where}
            ORDER BY datetime(created_at) DESC, rowid DESC
            LIMIT ? OFFSET ?
            """,
            [*fparams, limit, offset],
        )
        rows = await cursor.fetchall()

    logs: list[AuditLogEntry] = []
    for row in rows:
        try:
            src_raw = json.loads(row["sources_json"] or "[]")
        except (json.JSONDecodeError, TypeError):
            # Corrupt cell: keep the entry visible, just without a source count.
            src_raw = []
        if not isinstance(src_raw, list):
            src_raw = []
        summary_cell = row["answer_summary"]
        summary_text = str(summary_cell).strip() if summary_cell else ""
        if not summary_text:
            # Rows written before answer_summary existed: derive it on the fly.
            summary_text = _summary_from_answer(str(row["answer"] or ""))
        logs.append(
            AuditLogEntry(
                query_id=str(row["event_id"]),
                user_id=str(row["user_id"] or "anonymous"),
                question=str(row["question"]),
                answer_summary=summary_text,
                sources_count=len(src_raw),
                model_used=row["model_used"],
                timestamp=_parse_ts(row["created_at"]),
            )
        )
    return logs, total
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
async def get_audit_event(db_path: str, query_id: str) -> AuditLogDetailResponse | None:
    """Return the full audit record for ``query_id``, or ``None`` if no row matches.

    ``sources_json`` is decoded into citations; a value that is not valid
    JSON (or not a JSON list) yields an empty citation list instead of
    raising.
    """
    await init_audit_db(db_path)
    async with aiosqlite.connect(db_path) as conn:
        conn.row_factory = aiosqlite.Row
        cursor = await conn.execute(
            """
            SELECT event_id, user_id, question, answer, sources_json, model_used, tokens_used, created_at
            FROM audit_events
            WHERE event_id = ?
            """,
            (query_id,),
        )
        row = await cursor.fetchone()
    if row is None:
        return None

    try:
        src_raw = json.loads(row["sources_json"] or "[]")
    except (json.JSONDecodeError, TypeError):
        # Corrupt cell: still return the record, just without citations.
        src_raw = []
    if not isinstance(src_raw, list):
        src_raw = []
    citations = _sources_to_citations(src_raw)
    return AuditLogDetailResponse(
        query_id=str(row["event_id"]),
        user_id=str(row["user_id"] or "anonymous"),
        question=str(row["question"]),
        full_answer=str(row["answer"] or ""),
        sources=citations,
        model_used=row["model_used"],
        tokens_used=row["tokens_used"],
        timestamp=_parse_ts(row["created_at"]),
    )
|
storage/job_store.py
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SQLite tracking for asynchronous document ingest jobs.
|
| 2 |
+
|
| 3 |
+
Jobs move through ``queued`` → ``processing`` → ``completed`` or ``failed``. Progress
|
| 4 |
+
fields support multi-file batches and per-file error messages.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
from datetime import datetime, timezone
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Any
|
| 11 |
+
from uuid import uuid4
|
| 12 |
+
|
| 13 |
+
import aiosqlite
|
| 14 |
+
|
| 15 |
+
from models.responses import JobListItem, JobStatusResponse
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _utc_now_iso() -> str:
|
| 19 |
+
return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
async def _migrate_jobs_columns(conn: aiosqlite.Connection) -> None:
    """Add any ``ingest_jobs`` columns an older database is missing, then backfill filenames.

    Missing columns are detected via ``PRAGMA table_info`` and added with
    ``ALTER TABLE``; a commit is issued only if something was added. The
    filename backfill runs on every call.
    """
    # Insertion order is the order in which the ALTERs are applied.
    statements_by_column = {
        "total_files": "ALTER TABLE ingest_jobs ADD COLUMN total_files INTEGER NOT NULL DEFAULT 1",
        "processed_files": "ALTER TABLE ingest_jobs ADD COLUMN processed_files INTEGER NOT NULL DEFAULT 0",
        "failed_files": "ALTER TABLE ingest_jobs ADD COLUMN failed_files INTEGER NOT NULL DEFAULT 0",
        "filenames_json": "ALTER TABLE ingest_jobs ADD COLUMN filenames_json TEXT NOT NULL DEFAULT '[]'",
        "errors_json": "ALTER TABLE ingest_jobs ADD COLUMN errors_json TEXT NOT NULL DEFAULT '[]'",
        "started_at": "ALTER TABLE ingest_jobs ADD COLUMN started_at TEXT",
        "completed_at": "ALTER TABLE ingest_jobs ADD COLUMN completed_at TEXT",
    }

    info_cursor = await conn.execute("PRAGMA table_info(ingest_jobs)")
    # Field 1 of each table_info row is the column name.
    present = {str(info[1]) for info in await info_cursor.fetchall()}

    pending = [ddl for column, ddl in statements_by_column.items() if column not in present]
    for ddl in pending:
        await conn.execute(ddl)
    if pending:
        await conn.commit()
    await _backfill_job_filenames(conn)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
async def _backfill_job_filenames(conn: aiosqlite.Connection) -> None:
    """Populate ``filenames_json`` from the single ``filename`` column where it is empty.

    For each job whose ``filenames_json`` decodes to something falsy (or is
    not valid JSON) and whose ``filename`` is set, store ``[filename]`` and
    raise ``total_files`` to at least 1. Note that this sets
    ``conn.row_factory`` on the connection it is given.
    """
    update_sql = """
        UPDATE ingest_jobs
        SET filenames_json = ?, total_files = CASE WHEN total_files IS NULL OR total_files < 1 THEN 1 ELSE total_files END
        WHERE job_id = ?
        """
    conn.row_factory = aiosqlite.Row
    cursor = await conn.execute("SELECT job_id, filename, filenames_json, total_files FROM ingest_jobs")
    for job in await cursor.fetchall():
        try:
            names = json.loads(job["filenames_json"] or "[]")
        except json.JSONDecodeError:
            names = []
        if names or not job["filename"]:
            # Already populated, or nothing to backfill from.
            continue
        await conn.execute(update_sql, (json.dumps([job["filename"]]), job["job_id"]))
    await conn.commit()
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
async def init_jobs_db(db_path: str) -> None:
    """Create the ``ingest_jobs`` table if needed and apply additive column migrations.

    The parent directory of ``db_path`` is created when missing.
    """
    create_table_sql = """
        CREATE TABLE IF NOT EXISTS ingest_jobs (
            job_id TEXT PRIMARY KEY,
            status TEXT NOT NULL,
            collection_name TEXT NOT NULL,
            filename TEXT NOT NULL,
            message TEXT NOT NULL DEFAULT '',
            document_ids_json TEXT NOT NULL DEFAULT '[]',
            created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
            updated_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
            total_files INTEGER NOT NULL DEFAULT 1,
            processed_files INTEGER NOT NULL DEFAULT 0,
            failed_files INTEGER NOT NULL DEFAULT 0,
            filenames_json TEXT NOT NULL DEFAULT '[]',
            errors_json TEXT NOT NULL DEFAULT '[]',
            started_at TEXT,
            completed_at TEXT
        )
        """
    target = Path(db_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    async with aiosqlite.connect(target.as_posix()) as conn:
        await conn.execute(create_table_sql)
        await conn.commit()
        # Databases created by an earlier schema may lack the newer columns.
        await _migrate_jobs_columns(conn)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
async def create_ingest_job(
    db_path: str,
    *,
    collection_name: str,
    filenames: list[str],
) -> str:
    """Insert a new ``queued`` ingest job and return its generated ``job_id``.

    The first filename is stored in the single-value ``filename`` column, the
    full list in ``filenames_json``, and its length in ``total_files``.

    Raises:
        ValueError: if ``filenames`` is empty.
    """
    if not filenames:
        raise ValueError("filenames must not be empty")

    job_id = str(uuid4())
    insert_sql = """
        INSERT INTO ingest_jobs (
            job_id, status, collection_name, filename, message, document_ids_json,
            total_files, processed_files, failed_files, filenames_json, errors_json
        ) VALUES (?, 'queued', ?, ?, '', '[]', ?, 0, 0, ?, '[]')
        """
    row_values = (job_id, collection_name, filenames[0], len(filenames), json.dumps(filenames))

    await init_jobs_db(db_path)
    async with aiosqlite.connect(db_path) as conn:
        await conn.execute(insert_sql, row_values)
        await conn.commit()
    return job_id
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
async def mark_job_processing(db_path: str, job_id: str) -> None:
    """Move a job to ``processing`` and record when it started.

    ``started_at`` is only written if it is still NULL, so calling this again
    keeps the original start time.
    """
    await init_jobs_db(db_path)
    update_sql = """
        UPDATE ingest_jobs
        SET status = 'processing', message = 'Ingestion in progress.', started_at = COALESCE(started_at, ?),
            updated_at = CURRENT_TIMESTAMP
        WHERE job_id = ?
        """
    now_iso = _utc_now_iso()
    async with aiosqlite.connect(db_path) as conn:
        await conn.execute(update_sql, (now_iso, job_id))
        await conn.commit()
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
async def update_job_progress(
    db_path: str,
    job_id: str,
    *,
    processed_files: int,
    failed_files: int,
    errors: list[str],
    message: str | None = None,
) -> None:
    """Overwrite a job's progress counters and error list.

    ``message`` replaces the stored message only when it is not ``None``;
    otherwise the existing message is kept.
    """
    await init_jobs_db(db_path)
    update_sql = """
        UPDATE ingest_jobs
        SET processed_files = ?, failed_files = ?, errors_json = ?,
            message = COALESCE(?, message), updated_at = CURRENT_TIMESTAMP
        WHERE job_id = ?
        """
    bind_values = (processed_files, failed_files, json.dumps(errors), message, job_id)
    async with aiosqlite.connect(db_path) as conn:
        await conn.execute(update_sql, bind_values)
        await conn.commit()
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
async def complete_ingest_job(
    db_path: str,
    job_id: str,
    *,
    document_ids: list[str],
    message: str,
) -> None:
    """Mark a job ``completed``, storing its document ids, message and completion time."""
    await init_jobs_db(db_path)
    update_sql = """
        UPDATE ingest_jobs
        SET status = 'completed', message = ?, document_ids_json = ?,
            completed_at = ?, updated_at = CURRENT_TIMESTAMP
        WHERE job_id = ?
        """
    finished_at = _utc_now_iso()
    bind_values = (message, json.dumps(document_ids), finished_at, job_id)
    async with aiosqlite.connect(db_path) as conn:
        await conn.execute(update_sql, bind_values)
        await conn.commit()
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
async def fail_ingest_job(db_path: str, job_id: str, *, message: str, errors: list[str] | None = None) -> None:
    """Mark a job ``failed``, storing its message, error list and completion time.

    When ``errors`` is ``None`` or empty, the error list defaults to
    ``[message]``.
    """
    await init_jobs_db(db_path)
    update_sql = """
        UPDATE ingest_jobs
        SET status = 'failed', message = ?, errors_json = ?, completed_at = ?,
            updated_at = CURRENT_TIMESTAMP
        WHERE job_id = ?
        """
    finished_at = _utc_now_iso()
    error_list = errors if errors else [message]
    bind_values = (message, json.dumps(error_list), finished_at, job_id)
    async with aiosqlite.connect(db_path) as conn:
        await conn.execute(update_sql, bind_values)
        await conn.commit()
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
async def get_job_status(db_path: str, job_id: str) -> JobStatusResponse | None:
    """Return the status DTO for ``job_id``, or ``None`` if no such job exists.

    ``progress_percent`` is ``(processed + failed) / total`` as a rounded
    percentage clamped to 0..100; a non-positive ``total_files`` is treated
    as 1. ``errors_json`` that is not valid JSON is surfaced as a single
    raw-text error entry instead of raising, so status polling keeps working.
    """
    await init_jobs_db(db_path)
    async with aiosqlite.connect(db_path) as conn:
        conn.row_factory = aiosqlite.Row
        cursor = await conn.execute(
            """
            SELECT job_id, status, total_files, processed_files, failed_files, errors_json,
                   started_at, completed_at, message
            FROM ingest_jobs
            WHERE job_id = ?
            """,
            (job_id,),
        )
        row = await cursor.fetchone()
    if row is None:
        return None

    data = dict(row)
    total = int(data["total_files"] or 0)
    processed = int(data["processed_files"] or 0)
    failed = int(data["failed_files"] or 0)
    denom = total if total > 0 else 1  # avoid division by zero
    progress = int(min(100, max(0, round((processed + failed) / denom * 100))))

    raw_errors = data.get("errors_json")
    try:
        errors = json.loads(raw_errors or "[]")
    except (json.JSONDecodeError, TypeError):
        # Corrupt cell: show whatever was stored rather than failing the lookup.
        errors = [str(raw_errors)]
    if not isinstance(errors, list):
        errors = [str(errors)]
    errors_str = [str(e) for e in errors]

    return JobStatusResponse(
        job_id=str(data["job_id"]),
        status=str(data["status"]),
        total_files=total,
        processed_files=processed,
        failed_files=failed,
        progress_percent=progress,
        started_at=_parse_dt(data.get("started_at")),
        completed_at=_parse_dt(data.get("completed_at")),
        errors=errors_str,
    )
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
async def earliest_job_created_at_for_collection(db_path: str, collection_name: str) -> str | None:
    """Return the smallest ``created_at`` among a collection's ingest jobs.

    The value is returned as the string stored by SQLite; ``None`` is returned
    when the collection has no job rows.
    """
    await init_jobs_db(db_path)
    query = """
        SELECT MIN(created_at) AS earliest
        FROM ingest_jobs
        WHERE collection_name = ?
        """
    async with aiosqlite.connect(db_path) as db:
        db.row_factory = aiosqlite.Row
        result = await db.execute(query, (collection_name,))
        record = await result.fetchone()
    earliest = None if record is None else record["earliest"]
    return None if earliest is None else str(earliest)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
async def list_ingest_jobs(db_path: str, *, limit: int, offset: int) -> tuple[list[JobListItem], int]:
    """Return one page of job summaries plus the total number of jobs.

    Jobs are ordered by ``updated_at`` (newest first), with ``rowid`` as the
    tie-breaker; ``limit`` and ``offset`` select the page.
    """
    await init_jobs_db(db_path)
    count_sql = "SELECT COUNT(*) AS c FROM ingest_jobs"
    page_sql = """
        SELECT job_id, status, total_files, completed_at
        FROM ingest_jobs
        ORDER BY datetime(updated_at) DESC, rowid DESC
        LIMIT ? OFFSET ?
        """
    async with aiosqlite.connect(db_path) as db:
        db.row_factory = aiosqlite.Row
        count_cursor = await db.execute(count_sql)
        count_row = await count_cursor.fetchone()
        page_cursor = await db.execute(page_sql, (limit, offset))
        page_rows = await page_cursor.fetchall()

    job_count = int(count_row["c"]) if count_row else 0
    summaries: list[JobListItem] = []
    for record in page_rows:
        summaries.append(
            JobListItem(
                job_id=str(record["job_id"]),
                status=str(record["status"]),
                total_files=int(record["total_files"] or 0),
                completed_at=_parse_dt(record["completed_at"]),
            )
        )
    return summaries, job_count
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def _parse_dt(value: object) -> datetime | None:
    """Parse a stored timestamp into a timezone-aware ``datetime``.

    Returns ``None`` for ``None``, blank text, or text that is not ISO-8601.
    A trailing ``Z`` is read as UTC, and timestamps without an offset are
    tagged as UTC.
    """
    if value is None or value == "":
        return None
    text = str(value).strip()
    if not text:
        return None
    # Spell a "Z" suffix as an explicit offset before handing it to fromisoformat.
    iso_text = f"{text[:-1]}+00:00" if text.endswith("Z") else text
    try:
        parsed = datetime.fromisoformat(iso_text)
    except ValueError:
        return None
    if parsed.tzinfo is not None:
        return parsed
    return parsed.replace(tzinfo=timezone.utc)
|
streamlit_app.py
ADDED
|
@@ -0,0 +1,513 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Streamlit UI for doc-audi-ai — talks to the FastAPI backend only."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
import httpx
|
| 10 |
+
import streamlit as st
|
| 11 |
+
|
| 12 |
+
# Backend URL used until the sidebar field overrides it.
DEFAULT_API_BASE = os.environ.get("DOC_AUDI_API_BASE", "http://127.0.0.1:8000")

# httpx read timeout for Ask/Summarise: embeddings + LLM on CPU or cold Ollama often exceeds 10 minutes.
_HTTP_READ_TIMEOUT_DEFAULT_S = 3600.0
_HTTP_READ_TIMEOUT_MIN_S = 60.0
_HTTP_READ_TIMEOUT_MAX_S = 7200.0


def _http_read_timeout_seconds() -> float:
    """Return the HTTP read timeout in seconds, taken from ``DOC_AUDI_HTTP_READ_TIMEOUT``.

    A missing, unparseable, or NaN value falls back to the default; the result
    is then clamped to the allowed minimum/maximum range.
    """
    import math  # function-scope import: the file's import block is left untouched

    raw = os.environ.get(
        "DOC_AUDI_HTTP_READ_TIMEOUT",
        str(int(_HTTP_READ_TIMEOUT_DEFAULT_S)),
    )
    try:
        read_s = float(raw)
    except ValueError:
        read_s = _HTTP_READ_TIMEOUT_DEFAULT_S
    if math.isnan(read_s):
        # float("nan") parses fine but compares false against everything, so the
        # clamp below would silently turn it into the minimum timeout.
        read_s = _HTTP_READ_TIMEOUT_DEFAULT_S
    return max(_HTTP_READ_TIMEOUT_MIN_S, min(read_s, _HTTP_READ_TIMEOUT_MAX_S))
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _http_timeout() -> httpx.Timeout:
    """Build the httpx timeout used by this UI's client.

    Only the read phase is configurable (via ``_http_read_timeout_seconds``);
    connect, write, and pool limits are fixed.
    """
    return httpx.Timeout(
        connect=20.0,
        read=_http_read_timeout_seconds(),
        write=120.0,
        pool=30.0,
    )
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _fmt_timeout_hint() -> str:
    """Return a Markdown hint naming the effective read timeout and how to change it."""
    effective = int(_http_read_timeout_seconds())
    lowest = int(_HTTP_READ_TIMEOUT_MIN_S)
    highest = int(_HTTP_READ_TIMEOUT_MAX_S)
    pieces = [
        f"The UI stops waiting after **{effective}s** per request (set **DOC_AUDI_HTTP_READ_TIMEOUT**, ",
        f"allowed **{lowest}–{highest}** s). ",
        "Ensure `ollama serve` is running; cold models or CPU inference can exceed a few minutes.",
    ]
    return "".join(pieces)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _api_base() -> str:
    """Return the backend base URL without a trailing slash.

    The sidebar value in ``st.session_state`` wins unless it is missing or
    blank, in which case ``DEFAULT_API_BASE`` is used.
    """
    configured = st.session_state.get("api_base")
    candidate = "" if configured is None else str(configured).strip()
    return (candidate or DEFAULT_API_BASE).rstrip("/")
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _client() -> httpx.Client:
    """Create an httpx client bound to the current API base URL and timeouts."""
    base_url = _api_base()
    timeout = _http_timeout()
    return httpx.Client(base_url=base_url, timeout=timeout)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _fmt_api_error(exc: httpx.HTTPStatusError) -> str:
    """Format an HTTP error response as a one-line message for the UI.

    Args:
        exc: Error whose ``response`` carries the status code and body.

    Returns:
        ``"HTTP <code>"`` followed, when available, by the body's ``detail``
        field. A list-valued ``detail`` is rendered as ``loc: msg`` entries
        joined by ``"; "``. Bodies that are not JSON objects are shown as the
        first 500 characters of the raw response text.
    """
    response = exc.response
    prefix = f"HTTP {response.status_code}"
    try:
        body = response.json()
    except Exception:
        return f"{prefix}: {response.text[:500]}"
    if not isinstance(body, dict):
        # Valid JSON that is not an object (list, string, number) has no
        # "detail" key to read; fall back to the raw text instead of raising.
        return f"{prefix}: {response.text[:500]}"

    detail = body.get("detail")
    if isinstance(detail, list):
        parts = []
        for item in detail:
            if isinstance(item, dict):
                loc = item.get("loc", ())
                msg = item.get("msg", "")
                parts.append(f"{'/'.join(str(x) for x in loc)}: {msg}")
            else:
                parts.append(str(item))
        return f"{prefix}: " + "; ".join(parts)
    if detail is not None:
        return f"{prefix}: {detail}"
    return prefix
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _fmt_request_error(exc: httpx.RequestError) -> str:
    """Turn a transport-level httpx error into a Markdown message for the UI.

    Read timeouts, connect timeouts, and connection failures get tailored
    advice; other timeouts and remaining request errors get a generic message
    that names the exception type.
    """
    backend = _api_base()
    if isinstance(exc, httpx.ReadTimeout):
        message = (
            f"**Read timeout** — `{backend}` did not send a full response in time (embeddings/LLM can be slow). "
            f"{_fmt_timeout_hint()}"
        )
    elif isinstance(exc, httpx.ConnectTimeout):
        message = (
            f"**Connect timeout** — could not open TCP to `{backend}` in time. "
            "Confirm the FastAPI process is listening (`uv run uvicorn api.main:app --host 0.0.0.0 --port 8000`)."
        )
    elif isinstance(exc, httpx.ConnectError):
        message = (
            f"**Connection failed** — nothing is accepting HTTP at `{backend}`: {exc}. "
            "Start the API, or fix **API base URL** / **`DOC_AUDI_API_BASE`** (use `http://127.0.0.1:8000` from the same machine, not `0.0.0.0`)."
        )
    elif isinstance(exc, httpx.TimeoutException):
        message = f"**Timeout** ({type(exc).__name__}): {exc}. {_fmt_timeout_hint()}"
    else:
        message = f"**Request error** ({type(exc).__name__}): {exc}. Backend: `{backend}`."
    return message
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def _post_query_ask(
    client: httpx.Client,
    *,
    question: str,
    collection_name: str,
    top_k: int = 5,
    user_id: str = "anonymous",
) -> httpx.Response:
    """Send a question to ``POST /query/ask``, retrying on ``POST /query`` after a 404.

    The question is stripped of surrounding whitespace; the response is
    returned without checking its status.
    """
    payload: dict[str, object] = dict(
        question=question.strip(),
        collection_name=collection_name,
        top_k=top_k,
        user_id=user_id,
    )
    response = client.post("/query/ask", json=payload)
    if response.status_code != 404:
        return response
    # Older servers expose the same operation under /query.
    return client.post("/query", json=payload)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def _get_audit_logs(
    client: httpx.Client,
    *,
    limit: int,
    offset: int,
    user_id: str | None = None,
    from_date: str | None = None,
    to_date: str | None = None,
) -> httpx.Response:
    """Fetch a page of audit entries from ``/audit/logs``, retrying on ``/audit`` after a 404.

    Optional filters are sent only when they are non-empty; the response is
    returned without checking its status.
    """
    query: dict[str, object] = {"limit": limit, "offset": offset}
    optional_filters = {"user_id": user_id, "from_date": from_date, "to_date": to_date}
    query.update({name: value for name, value in optional_filters.items() if value})
    response = client.get("/audit/logs", params=query)
    if response.status_code != 404:
        return response
    return client.get("/audit", params=query)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def _get_audit_event_detail(client: httpx.Client, event_id: str) -> httpx.Response:
    """Fetch one audit entry from ``/audit/logs/<id>``, retrying on ``/audit/<id>`` after a 404.

    Args:
        client: HTTP client bound to the backend.
        event_id: Query/event identifier, possibly typed by the user.

    Returns:
        The last response received; its status is not checked here.
    """
    from urllib.parse import quote  # function-scope import: the file's import block is left untouched

    # The id can come from a free-text field, so encode it as a single path
    # segment; otherwise "?", "#" or "/" would change the request target.
    segment = quote(event_id, safe="")
    r = client.get(f"/audit/logs/{segment}")
    if r.status_code == 404:
        r = client.get(f"/audit/{segment}")
    return r
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def _health_check() -> tuple[bool, str]:
    """Call ``GET /health`` and report ``(ok, message)``.

    On success the message is the stringified JSON payload; on failure it is
    the formatted HTTP/transport error, or ``str(exc)`` for anything else.
    """
    try:
        with _client() as http:
            response = http.get("/health")
            response.raise_for_status()
            outcome = (True, str(response.json()))
    except httpx.HTTPStatusError as status_err:
        outcome = (False, _fmt_api_error(status_err))
    except httpx.RequestError as transport_err:
        outcome = (False, _fmt_request_error(transport_err))
    except Exception as other_err:
        outcome = (False, str(other_err))
    return outcome
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class _ReportErrors:
    """Context manager that renders errors from an API call in the page instead of raising.

    HTTP status errors and transport errors are shown with ``st.error``; any
    other ``Exception`` is shown with ``st.exception``. Exceptions that are not
    ``Exception`` subclasses propagate unchanged.
    """

    def __enter__(self) -> "_ReportErrors":
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        if exc is None:
            return False
        if isinstance(exc, httpx.HTTPStatusError):
            st.error(_fmt_api_error(exc))
        elif isinstance(exc, httpx.RequestError):
            st.error(_fmt_request_error(exc))
        elif isinstance(exc, Exception):
            st.exception(exc)
        else:
            return False
        return True  # handled: suppress so the rest of the page still renders


def _path_segment(value: str) -> str:
    """Strip *value* and percent-encode it so it is sent as exactly one URL path segment."""
    from urllib.parse import quote  # function-scope import: the file's import block is left untouched

    return quote(value.strip(), safe="")


def _show_queued_job(out: dict[str, Any]) -> None:
    """Show an ingest response and remember its ``job_id`` for the Jobs tab."""
    st.success(out.get("message", "Queued"))
    st.json(out)
    if out.get("job_id"):
        st.session_state["last_job_id"] = out["job_id"]


def _show_sources(ans: dict[str, Any], *, empty_note: str | None = None) -> None:
    """Show the ``sources`` list and the raw response of a query answer."""
    src = ans.get("sources") or []
    if src:
        with st.expander(f"Sources ({len(src)})"):
            st.json(src)
    elif empty_note:
        st.caption(empty_note)
    with st.expander("Raw response (debug)"):
        st.json(ans)


def _render_sidebar() -> None:
    """Render the backend URL field, the timeout caption, and the connection test."""
    st.subheader("Backend")
    st.text_input(
        "API base URL",
        key="api_base",
        placeholder=DEFAULT_API_BASE,
        help=f"Default: {DEFAULT_API_BASE}. Clear the field to use the default.",
    )
    st.caption(
        f"Ask/Summarise wait up to **{int(_http_read_timeout_seconds())}s** per request "
        f"(env `DOC_AUDI_HTTP_READ_TIMEOUT`, range {int(_HTTP_READ_TIMEOUT_MIN_S)}–{int(_HTTP_READ_TIMEOUT_MAX_S)})."
    )
    if st.button("Test connection"):
        ok, msg = _health_check()
        if ok:
            st.success(msg)
        else:
            st.error(msg)


def _render_upload_tab() -> None:
    """Render file upload, URL ingest, collection listing, and collection deletion."""
    st.subheader("Upload document")
    col_u1, col_u2 = st.columns(2)
    with col_u1:
        up_collection = st.text_input("Collection", value="default", key="up_col")
        uploaded = st.file_uploader("PDF, TXT, or Markdown", type=["pdf", "txt", "md"], key="up_file")
    with col_u2:
        if st.button("Submit upload", key="btn_upload", disabled=uploaded is None):
            if uploaded is None:
                st.warning("Choose a file first.")
            else:
                with _ReportErrors():
                    files = {"files": (uploaded.name, uploaded.getvalue(), uploaded.type or "application/octet-stream")}
                    data = {"collection_name": up_collection}
                    with _client() as c:
                        r = c.post("/ingest/upload", files=files, data=data)
                        r.raise_for_status()
                        out = r.json()
                    _show_queued_job(out)

    st.subheader("Ingest from URL")
    url_col = st.columns([3, 1])
    with url_col[0]:
        ingest_url = st.text_input("Document URL (http/https)", key="ingest_url")
    with url_col[1]:
        url_collection = st.text_input("Collection", value="default", key="url_col")
    if st.button("Queue URL ingest", key="btn_url"):
        if not ingest_url.strip():
            st.warning("Enter a URL.")
        else:
            with _ReportErrors():
                with _client() as c:
                    r = c.post(
                        "/ingest/url",
                        json={"urls": [ingest_url.strip()], "collection_name": url_collection},
                    )
                    r.raise_for_status()
                    out = r.json()
                _show_queued_job(out)

    st.subheader("Collections")
    if st.button("Refresh collections", key="btn_collections"):
        with _ReportErrors():
            with _client() as c:
                r = c.get("/ingest/collections")
                r.raise_for_status()
                cols = r.json()
            rows = cols.get("collections", [])
            st.write(f"{cols.get('total', len(rows))} collection(s).")
            if rows:
                st.dataframe(rows, hide_index=True, use_container_width=True)
            else:
                st.info("No collections yet.")

    del_name = st.text_input("Delete collection name (optional)", key="del_col")
    if st.button("Delete collection", key="btn_del_col"):
        if not del_name.strip():
            st.warning("Enter a collection name.")
        else:
            with _ReportErrors():
                with _client() as c:
                    r = c.delete(f"/ingest/collection/{_path_segment(del_name)}")
                    r.raise_for_status()
                    del_body = r.json()
                st.success(del_body.get("message", "Deleted"))
                if "documents_removed" in del_body:
                    st.caption(f"Documents removed: **{del_body['documents_removed']}**")


def _render_jobs_tab() -> None:
    """Render the job list, single-job lookup, and the polling loop."""
    st.subheader("Job list")
    j1, j2 = st.columns(2)
    with j1:
        j_limit = st.number_input("Limit", min_value=1, max_value=100, value=20, key="j_lim")
    with j2:
        j_offset = st.number_input("Offset", min_value=0, value=0, key="j_off")
    if st.button("List jobs", key="btn_jobs"):
        with _ReportErrors():
            with _client() as c:
                r = c.get("/jobs", params={"limit": int(j_limit), "offset": int(j_offset)})
                r.raise_for_status()
                payload = r.json()
            jobs: list[dict[str, Any]] = payload.get("jobs", [])
            st.caption(f"Total jobs (matching filters): **{payload.get('total', len(jobs))}**")
            if jobs:
                st.dataframe(jobs, hide_index=True, use_container_width=True)
            else:
                st.info("No jobs in this window.")

    st.subheader("Job detail")
    default_job = st.session_state.get("last_job_id", "")
    job_id = st.text_input("Job ID", value=default_job, key="job_id_in")
    c1, c2 = st.columns(2)
    with c1:
        fetch_job = st.button("Fetch job", key="btn_job_one")
    with c2:
        poll_job = st.button("Poll until completed/failed", key="btn_job_poll")

    if fetch_job and job_id.strip():
        with _ReportErrors():
            with _client() as c:
                r = c.get(f"/jobs/{_path_segment(job_id)}")
                r.raise_for_status()
                detail = r.json()
            st.json(detail)

    if poll_job and job_id.strip():
        status_ph = st.empty()
        with _ReportErrors():
            with _client() as c:
                for i in range(120):
                    r = c.get(f"/jobs/{_path_segment(job_id)}")
                    r.raise_for_status()
                    body = r.json()
                    status = body.get("status", "")
                    status_ph.write(f"Poll {i + 1}: **{status}** — {body.get('progress_percent', 0)}%")
                    if status in ("completed", "failed"):
                        st.json(body)
                        break
                    time.sleep(1)
                else:
                    # Loop ran out without reaching a terminal status: show the last snapshot.
                    status_ph.write("Stopped after 120 attempts (~2 min).")
                    st.json(body)


def _render_ask_tab() -> None:
    """Render the question form and the answer with its sources."""
    st.subheader("Ask a question")
    q_col = st.text_input("Collection", value="default", key="ask_col")
    question = st.text_area("Question", height=120, key="ask_q")
    if st.button("Ask", key="btn_ask"):
        if not question.strip():
            st.warning("Enter a question.")
        else:
            with _ReportErrors():
                with st.spinner(
                    "Calling the API (embeddings + LLM can take several minutes on a slow machine; "
                    "ensure Ollama is running). Timeout is controlled by DOC_AUDI_HTTP_READ_TIMEOUT…"
                ):
                    with _client() as c:
                        r = _post_query_ask(
                            c,
                            question=question,
                            collection_name=q_col,
                        )
                        r.raise_for_status()
                        ans = r.json()
                st.success(f"Query id: `{ans.get('query_id', '')}`")
                if ans.get("answer"):
                    st.markdown("### Answer")
                    st.markdown(ans["answer"])
                else:
                    st.warning(
                        "The API returned no **answer** text. "
                        "Check the collection has ingested chunks, LLM env, and expand **Raw response** below."
                    )
                _show_sources(
                    ans,
                    empty_note="No sources in this response (empty retrieval or model returned nothing).",
                )


def _render_summarise_tab() -> None:
    """Render the collection summary form and its result."""
    st.subheader("Summarise collection")
    s_col = st.text_input("Collection", value="default", key="sum_col")
    focus = st.text_input("Optional focus / angle", value="", key="sum_focus")
    if st.button("Summarise", key="btn_sum"):
        with _ReportErrors():
            body: dict[str, Any] = {"collection_name": s_col}
            if focus.strip():
                body["focus"] = focus.strip()
            with st.spinner("Calling summarise (can take 1–2 minutes on a cold model)…"):
                with _client() as c:
                    r = c.post("/query/summarise", json=body)
                    r.raise_for_status()
                    ans = r.json()
            st.success(f"Query id: `{ans.get('query_id', '')}` · documents: **{ans.get('document_count', '')}**")
            summary_text = ans.get("summary") or ans.get("answer")
            if summary_text:
                st.markdown("### Summary")
                st.markdown(summary_text)
            else:
                st.warning("No summary text in the response; see **Raw response** below.")
            _show_sources(ans)


def _render_audit_tab() -> None:
    """Render the audit log listing and the single-event detail lookup."""
    st.subheader("Audit log")
    a1, a2 = st.columns(2)
    with a1:
        a_limit = st.number_input("Limit", min_value=1, max_value=100, value=20, key="a_lim")
    with a2:
        a_offset = st.number_input("Offset", min_value=0, value=0, key="a_off")
    if st.button("List audit events", key="btn_audit_list"):
        with _ReportErrors():
            with _client() as c:
                r = _get_audit_logs(
                    c,
                    limit=int(a_limit),
                    offset=int(a_offset),
                )
                r.raise_for_status()
                payload = r.json()
            events = payload.get("logs", payload.get("events", []))
            st.caption(f"Total matching: **{payload.get('total', len(events))}**")
            if events:
                st.dataframe(events, hide_index=True, use_container_width=True)
                ids = [
                    event.get("query_id") or event.get("event_id")
                    for event in events
                    if isinstance(event, dict) and (event.get("query_id") or event.get("event_id"))
                ]
                if ids:
                    # Kept in session state so the detail selector survives the rerun.
                    st.session_state["_audit_ids"] = ids
            else:
                st.info("No audit events.")

    st.subheader("Audit event detail")
    ids_for_select = st.session_state.get("_audit_ids", [])
    pick = ""
    if ids_for_select:
        pick = st.selectbox("Event ID", options=[""] + list(ids_for_select), key="audit_pick")
    manual_id = st.text_input("Or enter query / event ID", key="audit_manual")
    # A manually typed id takes precedence over the selector.
    ev_id = (manual_id.strip() or (pick or "").strip()).strip()
    if st.button("Load detail", key="btn_audit_detail") and ev_id:
        with _ReportErrors():
            with _client() as c:
                r = _get_audit_event_detail(c, ev_id)
                r.raise_for_status()
                st.json(r.json())


def main() -> None:
    """Render the Streamlit page: header, backend sidebar, and the five feature tabs."""
    st.set_page_config(page_title="doc-audi-ai", layout="wide")
    if "api_base" not in st.session_state:
        st.session_state.api_base = DEFAULT_API_BASE

    st.title("doc-audi-ai")
    st.caption("Ingest, query, and audit via the FastAPI backend.")
    st.caption(f"Requests go to: `{_api_base()}`")

    with st.sidebar:
        _render_sidebar()

    tab_upload, tab_jobs, tab_ask, tab_sum, tab_audit = st.tabs(
        ["Upload", "Jobs", "Ask", "Summarise", "Audit"]
    )
    with tab_upload:
        _render_upload_tab()
    with tab_jobs:
        _render_jobs_tab()
    with tab_ask:
        _render_ask_tab()
    with tab_sum:
        _render_summarise_tab()
    with tab_audit:
        _render_audit_tab()


if __name__ == "__main__":
    main()
|
tests/conftest.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pytest fixtures: isolated temp DB/Chroma paths and a patched FastAPI test client."""
|
| 2 |
+
|
| 3 |
+
import sys
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import pytest
|
| 7 |
+
from fastapi.testclient import TestClient
|
| 8 |
+
|
| 9 |
+
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
| 10 |
+
if str(PROJECT_ROOT) not in sys.path:
|
| 11 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 12 |
+
|
| 13 |
+
from api.config import Settings
|
| 14 |
+
from api.main import app
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@pytest.fixture
def test_settings(tmp_path) -> Settings:
    """Settings whose Chroma directory and SQLite files live under pytest's ``tmp_path``."""
    storage_paths = {
        "chroma_persist_directory": str(tmp_path / "chroma"),
        "audit_db_path": str(tmp_path / "audit.db"),
        "jobs_db_path": str(tmp_path / "jobs.db"),
    }
    return Settings(
        llm_provider="ollama",
        max_file_size_mb=1,
        top_k_results=3,
        **storage_paths,
    )
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@pytest.fixture
def settings(test_settings) -> Settings:
    """Expose ``test_settings`` under the name ``settings`` for tests that request it that way."""
    return test_settings
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@pytest.fixture
def client(test_settings, monkeypatch):
    """Yield a ``TestClient`` for the app with ``get_settings`` patched to the test settings.

    The patch is applied to ``api.main`` and to each route module, since each
    of them is patched by its own module path.
    """
    monkeypatch.setattr("api.main.get_settings", lambda: test_settings)
    route_targets = [
        f"api.routes.{module_name}.get_settings"
        for module_name in ("ingest", "query", "audit", "jobs")
    ]
    for target in route_targets:
        monkeypatch.setattr(target, lambda ts=test_settings: ts)
    with TestClient(app) as test_client:
        yield test_client
|
tests/test_audit.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for audit log list, detail, filters, and post-query persistence."""
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
from unittest.mock import AsyncMock
|
| 5 |
+
from uuid import uuid4
|
| 6 |
+
|
| 7 |
+
import pytest
|
| 8 |
+
from fastapi.testclient import TestClient
|
| 9 |
+
|
| 10 |
+
from api.config import Settings
|
| 11 |
+
from api.main import app
|
| 12 |
+
from models.responses import SourceCitation
|
| 13 |
+
from rag.retriever import RetrievedChunk
|
| 14 |
+
from storage.audit_store import persist_query_audit
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _seed_audit(settings: Settings, question: str = "What are key risks?", user_id: str = "analyst_001") -> str:
    """Write one ``ask`` audit row and return the query id generated for it.

    Args:
        settings: Supplies ``audit_db_path``, the database the row is written to.
        question: Question text stored on the row.
        user_id: User id stored on the row.

    Returns:
        The random UUID string used as ``query_id`` for the new row.
    """
    new_query_id = str(uuid4())
    citation = SourceCitation(
        document_name="report.pdf",
        page_number=3,
        chunk_text="Risk disclosure excerpt.",
        relevance_score=0.9,
    )
    pending_write = persist_query_audit(
        settings.audit_db_path,
        query_id=new_query_id,
        action="query",
        user_id=user_id,
        question=question,
        collection_name="default",
        answer="Grounded answer text for audit trail.",
        sources=[citation],
        model_used="ollama:llama3.1:8b",
        tokens_used=120,
        response_time_ms=50,
        kind="ask",
    )
    # The helper is synchronous, so the awaitable is driven to completion here.
    asyncio.run(pending_write)
    return new_query_id
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def test_audit_logs_and_detail_success(client, settings):
    """A seeded audit row is listed by /audit/logs and readable via its detail URL."""
    seeded_id = _seed_audit(settings)

    # List endpoint: the seeded row must be among the returned entries.
    listing = client.get("/audit/logs?limit=10&offset=0")
    assert listing.status_code == 200
    listing_body = listing.json()
    assert "logs" in listing_body
    assert listing_body["total"] >= 1
    assert any(row["query_id"] == seeded_id for row in listing_body["logs"])

    # Detail endpoint: fields must match what ``_seed_audit`` stored.
    detail_reply = client.get(f"/audit/logs/{seeded_id}")
    assert detail_reply.status_code == 200
    record = detail_reply.json()
    assert record["query_id"] == seeded_id
    assert record["question"] == "What are key risks?"
    assert record["full_answer"] == "Grounded answer text for audit trail."
    stored_sources = record["sources"]
    assert len(stored_sources) == 1
    assert stored_sources[0]["document_name"] == "report.pdf"
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def test_audit_logs_filter_by_user_id(client, settings):
    """Filtering /audit/logs by ``user_id`` returns only that user's rows."""
    wanted_id = _seed_audit(settings, question="Q one", user_id="user_a")
    _seed_audit(settings, question="Q two", user_id="user_b")

    filters = {"user_id": "user_a", "limit": 50, "offset": 0}
    reply = client.get("/audit/logs", params=filters)
    assert reply.status_code == 200

    rows = reply.json()["logs"]
    returned_ids = {row["query_id"] for row in rows}
    assert wanted_id in returned_ids
    assert all(row["user_id"] == "user_a" for row in rows)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def test_audit_logs_filter_by_from_date(client, settings):
    """A freshly written row is absent when ``from_date`` is set to the year 2099.

    NOTE(review): presumably the store stamps rows with the write time; confirm
    against ``persist_query_audit``.
    """
    row_id = str(uuid4())
    pending_write = persist_query_audit(
        settings.audit_db_path,
        query_id=row_id,
        action="query",
        user_id="u",
        question="Future dated row",
        collection_name="default",
        answer="A",
        sources=[],
        model_used="m",
        tokens_used=0,
        response_time_ms=1,
        kind="ask",
    )
    asyncio.run(pending_write)

    filters = {"from_date": "2099-01-01T00:00:00Z", "limit": 50, "offset": 0}
    reply = client.get("/audit/logs", params=filters)
    assert reply.status_code == 200
    returned_ids = {row["query_id"] for row in reply.json()["logs"]}
    assert row_id not in returned_ids
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def test_audit_logs_filter_by_to_date(client, settings):
    """A freshly written row is absent when ``to_date`` is set to the year 2000 (upper bound)."""
    row_id = str(uuid4())
    pending_write = persist_query_audit(
        settings.audit_db_path,
        query_id=row_id,
        action="query",
        user_id="u",
        question="Recent row",
        collection_name="default",
        answer="B",
        sources=[],
        model_used="m",
        tokens_used=0,
        response_time_ms=1,
        kind="ask",
    )
    asyncio.run(pending_write)

    filters = {"to_date": "2000-01-01T00:00:00Z", "limit": 50, "offset": 0}
    reply = client.get("/audit/logs", params=filters)
    assert reply.status_code == 200
    returned_ids = {row["query_id"] for row in reply.json()["logs"]}
    assert row_id not in returned_ids
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def test_ask_is_logged_after_query_ask(client, monkeypatch):
    """POST /query/ask with stubbed retrieval and generation, then read the audit row back."""
    chunks = [
        RetrievedChunk(
            text="Audit trail test chunk.",
            score=0.9,
            source="audit-test.txt",
            page=1,
            chunk_index=0,
        )
    ]
    grounded_reply = ("Answer stored in audit.", 11)

    # Replace the embedding, vector-store, retrieval and generation names used by the route module.
    monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
    monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
    monkeypatch.setattr("api.routes.query.retrieve_chunks", lambda *_: chunks)
    monkeypatch.setattr("api.routes.query.answer_with_grounding", lambda *_: grounded_reply)

    request_body = {
        "question": "What should appear in the audit log?",
        "collection_name": "default",
        "user_id": "audit_user",
    }
    ask_reply = client.post("/query/ask", json=request_body)
    assert ask_reply.status_code == 200
    logged_id = ask_reply.json()["query_id"]

    audit_reply = client.get(f"/audit/logs/{logged_id}")
    assert audit_reply.status_code == 200
    record = audit_reply.json()
    assert record["user_id"] == "audit_user"
    assert record["full_answer"] == "Answer stored in audit."
    assert record["question"] == "What should appear in the audit log?"
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def test_summarise_is_logged_after_query_summarise(client, monkeypatch):
    """POST /query/summarise with stubbed dependencies, then read the audit row back."""
    chunks = [
        RetrievedChunk(
            text="Summary source chunk.",
            score=0.85,
            source="summary.md",
            page=2,
            chunk_index=0,
        )
    ]
    summary_reply = ("Collection summary for audit.", 7)

    monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
    monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
    monkeypatch.setattr("api.routes.query.retrieve_chunks", lambda *_: chunks)
    monkeypatch.setattr("api.routes.query.summarise_with_grounding", lambda *_, **__: summary_reply)
    monkeypatch.setattr("api.routes.query.collection_document_count", lambda *_: 2)

    request_body = {"collection_name": "default", "focus": "key themes", "user_id": "sum_user"}
    summarise_reply = client.post("/query/summarise", json=request_body)
    assert summarise_reply.status_code == 200
    logged_id = summarise_reply.json()["query_id"]

    audit_reply = client.get(f"/audit/logs/{logged_id}")
    assert audit_reply.status_code == 200
    record = audit_reply.json()
    assert record["full_answer"] == "Collection summary for audit."
    assert record["user_id"] == "sum_user"
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def test_audit_logs_validation_error_for_bad_limit(client):
    """``limit=0`` on /audit/logs is answered with HTTP 422."""
    reply = client.get("/audit/logs?limit=0&offset=0")
    assert reply.status_code == 422
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def test_audit_detail_not_found(client):
    """An unknown query id yields 404 with a "not found" detail message."""
    reply = client.get("/audit/logs/does-not-exist")
    assert reply.status_code == 404
    detail_text = reply.json()["detail"]
    assert "not found" in detail_text.lower()
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def test_audit_logs_returns_500_on_store_failure(settings, monkeypatch):
    """If listing audit events raises, GET /audit/logs responds with HTTP 500."""

    def provide_settings():
        return settings

    failing_list = AsyncMock(side_effect=RuntimeError("audit store failure"))

    monkeypatch.setattr("api.main.get_settings", provide_settings)
    monkeypatch.setattr("api.routes.audit.get_settings", provide_settings)
    monkeypatch.setattr("api.routes.audit.list_audit_events", failing_list)

    # raise_server_exceptions=False makes the client return the error response
    # instead of re-raising the RuntimeError inside the test.
    with TestClient(app, raise_server_exceptions=False) as test_client:
        reply = test_client.get("/audit/logs")

    assert reply.status_code == 500
|
tests/test_config.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Settings behaviour for Hugging Face Spaces and Hub tokens."""
|
| 2 |
+
|
| 3 |
+
from api.config import Settings
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def test_space_id_without_llm_provider_env_uses_huggingface_and_hf_token(monkeypatch):
    """With SPACE_ID set and no LLM_PROVIDER, Settings picks huggingface and reads HF_TOKEN."""
    for unset_name in ("LLM_PROVIDER", "HUGGINGFACE_API_KEY"):
        monkeypatch.delenv(unset_name, raising=False)
    monkeypatch.setenv("SPACE_ID", "author/repo")
    monkeypatch.setenv("HF_TOKEN", "hf_test_token")

    # _env_file=None is passed so only the patched environment is consulted, not a .env file.
    resolved = Settings(_env_file=None)

    assert resolved.llm_provider == "huggingface"
    assert resolved.huggingface_api_key == "hf_test_token"
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def test_space_id_respects_explicit_llm_provider_ollama(monkeypatch):
    """An explicit ``LLM_PROVIDER=ollama`` is kept even when SPACE_ID is set."""
    monkeypatch.delenv("HUGGINGFACE_API_KEY", raising=False)
    for env_name, env_value in (("SPACE_ID", "author/repo"), ("LLM_PROVIDER", "ollama")):
        monkeypatch.setenv(env_name, env_value)

    resolved = Settings(_env_file=None)

    assert resolved.llm_provider == "ollama"
|
tests/test_health.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Smoke test for the liveness endpoint."""
|
| 2 |
+
|
| 3 |
+
def test_health_returns_ok(client):
    """GET /health answers 200 with ``status == "ok"`` plus ``app`` and ``version`` keys."""
    reply = client.get("/health")
    assert reply.status_code == 200

    payload = reply.json()
    assert payload["status"] == "ok"
    for required_key in ("app", "version"):
        assert required_key in payload
|
tests/test_ingest.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for ``/ingest`` upload, URL ingest, and collection management."""
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
from unittest.mock import AsyncMock
|
| 5 |
+
|
| 6 |
+
from api.routes import ingest as ingest_route
|
| 7 |
+
from storage.job_store import create_ingest_job, mark_job_processing
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def test_upload_queues_job_success(client, monkeypatch):
    """Uploading one text file returns a queued-job payload carrying the stubbed job id."""
    create_job_stub = AsyncMock(return_value="job-123")
    run_job_stub = AsyncMock(return_value=None)
    monkeypatch.setattr("api.routes.ingest.create_ingest_job", create_job_stub)
    monkeypatch.setattr("api.routes.ingest.run_ingest_job", run_job_stub)

    upload = ("sample.txt", b"hello world", "text/plain")
    reply = client.post(
        "/ingest/upload",
        data={"collection_name": "default"},
        files=[("files", upload)],
    )

    assert reply.status_code == 200
    payload = reply.json()
    assert payload["status"] == "queued"
    assert payload["job_id"] == "job-123"
    assert payload["total_files"] == 1
    assert payload["filenames"] == ["sample.txt"]
    assert "Poll /jobs/job-123" in payload["message"]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def test_upload_rejects_unsupported_extension(client):
    """A ``.csv`` upload is refused with HTTP 400 and an "Unsupported file type" detail."""
    csv_upload = ("sample.csv", b"a,b\n1,2", "text/csv")
    reply = client.post(
        "/ingest/upload",
        data={"collection_name": "default"},
        files=[("files", csv_upload)],
    )

    assert reply.status_code == 400
    detail_text = reply.json()["detail"]
    assert "Unsupported file type" in detail_text
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def test_upload_rejects_oversized_file(client):
    """A 2 MiB upload is refused with HTTP 413 and a "too large" detail.

    NOTE(review): assumes the test settings cap uploads below 2 MiB; confirm in conftest.
    """
    two_mib = 2 * 1024 * 1024
    big_upload = ("large.txt", b"x" * two_mib, "text/plain")
    reply = client.post(
        "/ingest/upload",
        data={"collection_name": "default"},
        files=[("files", big_upload)],
    )

    assert reply.status_code == 413
    detail_text = reply.json()["detail"]
    assert "too large" in detail_text.lower()
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def test_upload_returns_500_on_job_creation_error(client, monkeypatch):
    """If job creation raises, the upload endpoint answers 500 and echoes the error text."""
    failing_create = AsyncMock(side_effect=RuntimeError("job store unavailable"))
    run_job_stub = AsyncMock(return_value=None)
    monkeypatch.setattr("api.routes.ingest.create_ingest_job", failing_create)
    monkeypatch.setattr("api.routes.ingest.run_ingest_job", run_job_stub)

    upload = ("sample.txt", b"hello", "text/plain")
    reply = client.post(
        "/ingest/upload",
        data={"collection_name": "default"},
        files=[("files", upload)],
    )

    assert reply.status_code == 500
    detail_text = reply.json()["detail"]
    assert "job store unavailable" in detail_text
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def test_download_request_headers_sec_compliant():
    """The download header builder echoes the User-Agent and sets the encoding/accept headers."""
    user_agent = "DocuAudit AI test@example.com"
    built = ingest_route._download_request_headers(user_agent)

    assert built["User-Agent"] == "DocuAudit AI test@example.com"
    assert built["Accept-Encoding"] == "gzip, deflate"
    assert "application/pdf" in built["Accept"]
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def test_ingest_url_rejects_non_http_scheme(client, monkeypatch):
    """A 400 raised by the URL downloader is passed through by POST /ingest/url.

    NOTE(review): the posted URL is itself https; the rejection comes entirely
    from the mocked ``_download_url_to_temp``, so the route's real scheme
    check is not exercised here.
    """
    scheme_error = ingest_route.HTTPException(
        status_code=400,
        detail="Only http and https URLs are supported.",
    )
    monkeypatch.setattr(
        "api.routes.ingest._download_url_to_temp",
        AsyncMock(side_effect=scheme_error),
    )

    request_body = {"urls": ["https://example.com/file.txt"], "collection_name": "default"}
    reply = client.post("/ingest/url", json=request_body)

    assert reply.status_code == 400
    assert "http and https" in reply.json()["detail"]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def test_upload_pdf_queues_job_with_job_id(client, monkeypatch):
    """Uploading a single PDF returns the stubbed job id and the uploaded filename."""
    create_job_stub = AsyncMock(return_value="pdf-job-99")
    run_job_stub = AsyncMock(return_value=None)
    monkeypatch.setattr("api.routes.ingest.create_ingest_job", create_job_stub)
    monkeypatch.setattr("api.routes.ingest.run_ingest_job", run_job_stub)

    pdf_upload = ("brief.pdf", b"%PDF-1.4 minimal", "application/pdf")
    reply = client.post(
        "/ingest/upload",
        data={"collection_name": "default"},
        files=[("files", pdf_upload)],
    )

    assert reply.status_code == 200
    payload = reply.json()
    assert payload["job_id"] == "pdf-job-99"
    assert payload["filenames"] == ["brief.pdf"]
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def test_list_collections_backfills_created_at_from_jobs(client, test_settings, monkeypatch):
    """GET /ingest/collections reports a ``created_at`` even when the collection has none stored."""
    # Each entry replaces one name used by the ingest route module; the
    # collection's own created_at is stubbed to None to force the fallback.
    replacements = {
        "api.routes.ingest.list_collection_names": lambda *_: ["default"],
        "api.routes.ingest.collection_document_count": lambda *_: 3,
        "api.routes.ingest.collection_created_at": lambda *_: None,
        "api.routes.ingest.earliest_job_created_at_for_collection": AsyncMock(
            return_value="2026-05-21 07:05:38"
        ),
        "api.routes.ingest.ensure_collection_created_at": lambda *_a, **_k: "2026-05-21T07:05:38Z",
    }
    for target, stub in replacements.items():
        monkeypatch.setattr(target, stub)

    reply = client.get("/ingest/collections")
    assert reply.status_code == 200

    payload = reply.json()
    assert payload["total"] == 1
    first_collection = payload["collections"][0]
    assert first_collection["name"] == "default"
    assert first_collection["document_count"] == 3
    assert first_collection["created_at"] is not None
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def test_job_status_polling_after_real_job_create(client, test_settings):
    """A job created and marked processing in the real store is reported by GET /jobs/{id}."""
    jobs_db = test_settings.jobs_db_path
    new_job = create_ingest_job(
        jobs_db,
        collection_name="default",
        filenames=["sample.txt"],
    )
    job_id = asyncio.run(new_job)
    asyncio.run(mark_job_processing(jobs_db, job_id))

    reply = client.get(f"/jobs/{job_id}")
    assert reply.status_code == 200

    payload = reply.json()
    assert payload["job_id"] == job_id
    assert payload["status"] == "processing"
    assert payload["total_files"] == 1
    for required_key in ("progress_percent", "errors"):
        assert required_key in payload
|
tests/test_jobs.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for ingest job listing and status endpoints."""
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
|
| 5 |
+
from storage.job_store import create_ingest_job, update_job_progress
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def test_get_job_status_returns_spec_shape(client, test_settings):
    """GET /jobs/{id} exposes counts, a bounded progress percentage and an error list."""
    jobs_db = test_settings.jobs_db_path
    job_id = asyncio.run(
        create_ingest_job(
            jobs_db,
            collection_name="default",
            filenames=["report.pdf", "notes.txt"],
        )
    )
    # Record that one of the two files has been handled without failures.
    progress_write = update_job_progress(
        jobs_db,
        job_id,
        processed_files=1,
        failed_files=0,
        errors=[],
        message="Processing first file",
    )
    asyncio.run(progress_write)

    reply = client.get(f"/jobs/{job_id}")
    assert reply.status_code == 200

    payload = reply.json()
    allowed_states = ("queued", "processing", "completed", "failed")
    assert payload["job_id"] == job_id
    assert payload["status"] in allowed_states
    assert payload["total_files"] == 2
    assert payload["processed_files"] == 1
    assert payload["failed_files"] == 0
    assert 0 <= payload["progress_percent"] <= 100
    assert isinstance(payload["errors"], list)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def test_list_jobs_includes_total(client, test_settings):
    """GET /jobs returns a ``total`` count and includes a newly created job."""
    new_job = create_ingest_job(
        test_settings.jobs_db_path,
        collection_name="default",
        filenames=["sample.txt"],
    )
    created_id = asyncio.run(new_job)

    paging = {"limit": 10, "offset": 0}
    reply = client.get("/jobs", params=paging)
    assert reply.status_code == 200

    payload = reply.json()
    assert payload["total"] >= 1
    assert any(job["job_id"] == created_id for job in payload["jobs"])
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def test_get_job_not_found_returns_404(client):
    """An unknown job id yields 404 with a "not found" detail message."""
    reply = client.get("/jobs/nonexistent-job-id")
    assert reply.status_code == 404
    detail_text = reply.json()["detail"]
    assert "not found" in detail_text.lower()
|
tests/test_query.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for ``/query/ask``, ``/query/summarise``, and legacy ``POST /query``."""
|
| 2 |
+
|
| 3 |
+
from unittest.mock import AsyncMock
|
| 4 |
+
|
| 5 |
+
from rag.retriever import NO_MATCH_ANSWER, RetrievedChunk
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def test_ask_returns_grounded_answer_with_sources(client, monkeypatch):
    """POST /query/ask returns the stubbed answer, its token count and one source citation."""
    chunks = [
        RetrievedChunk(
            text="Audi has strategic EV expansion plans.",
            score=0.92,
            source="strategy.md",
            page=1,
            chunk_index=0,
        )
    ]
    grounded_reply = ("Audi is expanding EV investment.", 42)

    monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
    monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
    monkeypatch.setattr("api.routes.query.retrieve_chunks", lambda *_: chunks)
    monkeypatch.setattr("api.routes.query.answer_with_grounding", lambda *_: grounded_reply)
    monkeypatch.setattr("api.routes.query.persist_query_audit", AsyncMock(return_value="evt-1"))

    request_body = {
        "question": "What is Audi doing in EV markets worldwide?",
        "collection_name": "default",
        "top_k": 3,
        "user_id": "tester",
    }
    reply = client.post("/query/ask", json=request_body)

    assert reply.status_code == 200
    payload = reply.json()
    assert payload["answer"] == "Audi is expanding EV investment."
    assert "query_id" in payload
    assert payload["question"].startswith("What is Audi")

    cited = payload["sources"]
    assert len(cited) == 1
    assert cited[0]["document_name"] == "strategy.md"
    assert cited[0]["page_number"] == 1

    assert payload["tokens_used"] == 42
    assert "response_time_ms" in payload
    assert "model_used" in payload
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def test_ask_respects_top_k_in_retrieve_call(client, monkeypatch):
    """The ``top_k`` value from the request body reaches ``retrieve_chunks`` as ``k``."""
    observed: dict[str, object] = {}

    def capture_retrieve(vs, question, k):
        # Remember the requested k and behave like an empty retrieval.
        observed["k"] = k
        return []

    monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
    monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
    monkeypatch.setattr("api.routes.query.retrieve_chunks", capture_retrieve)
    monkeypatch.setattr("api.routes.query.answer_with_grounding", lambda *_: ("No match answer", 0))
    monkeypatch.setattr("api.routes.query.persist_query_audit", AsyncMock())

    request_body = {
        "question": "What is known about the topic here?",
        "collection_name": "default",
        "top_k": 7,
    }
    reply = client.post("/query/ask", json=request_body)

    assert reply.status_code == 200
    assert observed.get("k") == 7
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def test_ask_empty_collection_returns_no_match_message(client, monkeypatch):
    """When retrieval yields no chunks, /query/ask answers with ``NO_MATCH_ANSWER`` and no sources."""
    monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
    monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
    monkeypatch.setattr("api.routes.query.retrieve_chunks", lambda *_: [])
    monkeypatch.setattr("api.routes.query.persist_query_audit", AsyncMock())

    request_body = {
        "question": "What does the document say about revenue?",
        "collection_name": "default",
        "top_k": 5,
    }
    reply = client.post("/query/ask", json=request_body)

    assert reply.status_code == 200
    payload = reply.json()
    assert payload["answer"] == NO_MATCH_ANSWER
    assert payload["sources"] == []
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def test_ask_low_relevance_chunks_returns_no_match_message(client, monkeypatch):
    """A single chunk scored 0.05 still produces the ``NO_MATCH_ANSWER`` reply."""
    weak_hit = RetrievedChunk(
        text="Unrelated fragment.",
        score=0.05,
        source="noise.txt",
        page=1,
        chunk_index=0,
    )
    low_score_chunks = [weak_hit]

    # answer_with_grounding is deliberately left unpatched in this test.
    monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
    monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
    monkeypatch.setattr("api.routes.query.retrieve_chunks", lambda *_: low_score_chunks)
    monkeypatch.setattr("api.routes.query.persist_query_audit", AsyncMock())

    request_body = {"question": "What are the key risk factors?", "collection_name": "default"}
    reply = client.post("/query/ask", json=request_body)

    assert reply.status_code == 200
    assert reply.json()["answer"] == NO_MATCH_ANSWER
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def test_ask_returns_422_for_invalid_payload(client):
    """A body without ``question`` is rejected with HTTP 422."""
    body_without_question = {"collection_name": "default"}
    reply = client.post("/query/ask", json=body_without_question)
    assert reply.status_code == 422
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def test_ask_returns_422_for_short_question(client):
    """The two-character question ``"hi"`` is rejected with HTTP 422."""
    request_body = {"question": "hi", "collection_name": "default"}
    reply = client.post("/query/ask", json=request_body)
    assert reply.status_code == 422
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def test_ask_returns_500_when_retrieval_fails(client, monkeypatch):
    """A RuntimeError from retrieval becomes HTTP 500 with the error text in ``detail``."""

    def failing_retrieve(*_):
        # Stand-in for retrieve_chunks that always fails.
        raise RuntimeError("retrieval failed")

    monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
    monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
    monkeypatch.setattr("api.routes.query.retrieve_chunks", failing_retrieve)

    request_body = {"question": "What happened in the documents?", "collection_name": "default"}
    reply = client.post("/query/ask", json=request_body)

    assert reply.status_code == 500
    assert "retrieval failed" in reply.json()["detail"]
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def test_summarise_returns_summary_payload(client, monkeypatch):
    """POST /query/summarise returns the stubbed summary, document count and one source."""
    chunks = [
        RetrievedChunk(
            text="Revenue grew year over year.",
            score=0.9,
            source="report.txt",
            page=2,
            chunk_index=0,
        )
    ]
    summary_reply = ("Executive summary text.", 25)

    monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
    monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
    monkeypatch.setattr("api.routes.query.retrieve_chunks", lambda *_: chunks)
    monkeypatch.setattr("api.routes.query.summarise_with_grounding", lambda *_, **__: summary_reply)
    monkeypatch.setattr("api.routes.query.collection_document_count", lambda *_: 3)
    monkeypatch.setattr("api.routes.query.persist_query_audit", AsyncMock())

    request_body = {"collection_name": "default", "focus": "financial highlights", "user_id": "analyst"}
    reply = client.post("/query/summarise", json=request_body)

    assert reply.status_code == 200
    payload = reply.json()
    assert payload["summary"] == "Executive summary text."
    assert payload["document_count"] == 3
    assert "query_id" in payload

    cited = payload["sources"]
    assert len(cited) == 1
    assert cited[0]["document_name"] == "report.txt"
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def test_summarise_returns_500_when_audit_persist_fails(client, monkeypatch):
    """If writing the audit row raises, /query/summarise answers 500 with the error text."""
    chunks = [
        RetrievedChunk(
            text="Revenue and risks are discussed in the report.",
            score=0.88,
            source="report.txt",
            page=None,
            chunk_index=2,
        )
    ]
    failing_persist = AsyncMock(side_effect=RuntimeError("audit write failed"))

    monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
    monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
    monkeypatch.setattr("api.routes.query.retrieve_chunks", lambda *_: chunks)
    monkeypatch.setattr("api.routes.query.summarise_with_grounding", lambda *_, **__: ("Summary output", 10))
    monkeypatch.setattr("api.routes.query.collection_document_count", lambda *_: 5)
    monkeypatch.setattr("api.routes.query.persist_query_audit", failing_persist)

    request_body = {"collection_name": "default", "focus": "summarise risks", "user_id": "u1"}
    reply = client.post("/query/summarise", json=request_body)

    assert reply.status_code == 500
    assert "audit write failed" in reply.json()["detail"]
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def test_legacy_query_endpoint_matches_ask(client, monkeypatch):
    """Legacy ``POST /query`` returns the same answer as ``POST /query/ask`` for one payload."""
    chunks = [
        RetrievedChunk(
            text="Clause about indemnity.",
            score=0.8,
            source="contract.md",
            page=4,
            chunk_index=1,
        )
    ]
    grounded_reply = ("Indemnity is capped.", 5)

    monkeypatch.setattr("api.routes.query.create_embedding_function", lambda: object())
    monkeypatch.setattr("api.routes.query.get_vector_store", lambda **_: object())
    monkeypatch.setattr("api.routes.query.retrieve_chunks", lambda *_: chunks)
    monkeypatch.setattr("api.routes.query.answer_with_grounding", lambda *_: grounded_reply)
    monkeypatch.setattr("api.routes.query.persist_query_audit", AsyncMock())

    payload = {
        "question": "What are the indemnity limits in the contract?",
        "collection_name": "default",
        "top_k": 3,
    }
    # Send the identical body to both routes, new route first.
    ask_reply = client.post("/query/ask", json=payload)
    legacy_reply = client.post("/query", json=payload)

    assert ask_reply.status_code == 200
    assert legacy_reply.status_code == 200

    legacy_body = legacy_reply.json()
    assert legacy_body["answer"] == ask_reply.json()["answer"]
    assert "query_id" in legacy_body
    assert legacy_body["sources"][0]["document_name"] == "contract.md"
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
workers/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Background workers (ingest pipeline)."""
|
workers/ingest_worker.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Background ingest worker invoked from FastAPI ``BackgroundTasks``.
|
| 2 |
+
|
| 3 |
+
For each temp file: load → chunk → embed → add to Chroma, then update job progress in SQLite.
|
| 4 |
+
Temp files are always deleted in a ``finally`` block.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import asyncio
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
from rag.chunker import chunk_documents
|
| 11 |
+
from rag.embedder import create_embedding_function
|
| 12 |
+
from rag.loader import load_documents
|
| 13 |
+
from rag.vector_store import add_documents, get_vector_store
|
| 14 |
+
from storage.job_store import (
|
| 15 |
+
complete_ingest_job,
|
| 16 |
+
fail_ingest_job,
|
| 17 |
+
mark_job_processing,
|
| 18 |
+
update_job_progress,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _ingest_one_file_sync(temp_path: str, collection_name: str, chroma_persist_directory: str) -> tuple[list[str], int]:
    """Run the blocking load -> chunk -> store pipeline for a single path.

    Args:
        temp_path: Path handed to ``load_documents``.
        collection_name: Collection name passed to ``get_vector_store``.
        chroma_persist_directory: Persist directory passed to ``get_vector_store``.

    Returns:
        A pair of (value returned by ``add_documents``, number of chunks).

    Raises:
        ValueError: If chunking produced nothing to ingest.
    """
    pieces = chunk_documents(load_documents(temp_path))
    if not pieces:
        raise ValueError("No content to ingest.")

    # The embedding function and the store are only built once there is
    # something to add.
    store = get_vector_store(
        persist_directory=chroma_persist_directory,
        collection_name=collection_name,
        embedding_function=create_embedding_function(),
    )
    return add_documents(store, pieces), len(pieces)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
async def run_ingest_job(
    job_id: str,
    files: list[tuple[str, str]],
    collection_name: str,
    jobs_db_path: str,
    chroma_persist_directory: str,
) -> None:
    """Process one or more temp files for a single ingest job.

    Each file is ingested in a worker thread via ``_ingest_one_file_sync``;
    after every file the job's progress record is updated. The job is marked
    failed when no file could be ingested (or when anything outside the
    per-file ingest raises), and completed otherwise.

    Args:
        job_id: Identifier of the job whose record is updated.
        files: ``(temp_path, display_name)`` pairs to ingest.
        collection_name: Target collection passed through to the ingest step.
        jobs_db_path: Location passed to the job-store helpers.
        chroma_persist_directory: Persist directory passed to the ingest step.

    Every temp path in ``files`` is removed before this coroutine finishes
    once processing has started, including when the job aborts early.
    """
    all_doc_ids: list[str] = []
    errors: list[str] = []
    processed = 0
    failed = 0
    if not files:
        await fail_ingest_job(jobs_db_path, job_id, message="No files to ingest.")
        return

    try:
        await mark_job_processing(jobs_db_path, job_id)
        for temp_path, display_name in files:
            # Only the ingest itself is guarded here: a failure while
            # recording progress must not be reported as a failure of the
            # file (which would count it as both processed and failed).
            try:
                doc_ids, num_chunks = await asyncio.to_thread(
                    _ingest_one_file_sync,
                    temp_path,
                    collection_name,
                    chroma_persist_directory,
                )
            except Exception as exc:
                failed += 1
                errors.append(f"{display_name}: {exc}")
                progress_message = f"Failed on {display_name}: {exc}"
            else:
                all_doc_ids.extend(doc_ids)
                processed += 1
                progress_message = f"Ingested {display_name} ({num_chunks} chunks)."
            finally:
                Path(temp_path).unlink(missing_ok=True)

            await update_job_progress(
                jobs_db_path,
                job_id,
                processed_files=processed,
                failed_files=failed,
                errors=errors,
                message=progress_message,
            )

        if processed == 0:
            await fail_ingest_job(
                jobs_db_path,
                job_id,
                message="All files failed ingestion.",
                errors=errors,
            )
            return

        chunk_note = f"{len(all_doc_ids)} chunk vector(s) across {processed} file(s)."
        await complete_ingest_job(
            jobs_db_path,
            job_id,
            document_ids=all_doc_ids,
            message=f"Ingestion completed. {chunk_note}",
        )
    except Exception as exc:
        await fail_ingest_job(jobs_db_path, job_id, message=str(exc), errors=errors + [str(exc)])
    finally:
        # Sweep every temp file: if the job aborted before or during the loop
        # (e.g. the job store raised), files not yet reached would otherwise
        # be left on disk. Already-removed paths are a no-op.
        for leftover_path, _ in files:
            try:
                Path(leftover_path).unlink(missing_ok=True)
            except OSError:
                # Best-effort cleanup; one undeletable path must not stop the
                # remaining files from being removed.
                pass
|