Mayank Chugh commited on
Commit
bdfb32d
·
1 Parent(s): d19142f

Implement Milestone 8 by adding new endpoints for URL ingestion and collection management. Introduce `httpx` as a dependency for handling URL downloads, and enhance the API with endpoints for listing and deleting collections. Update request and response models to support new functionalities, and refactor existing routes for improved clarity and organization.

Browse files
api/routes/audit.py CHANGED
@@ -7,16 +7,19 @@ from models.requests import AuditListParams
7
  from models.responses import AuditDetailResponse, AuditEvent, AuditListResponse
8
  from storage.audit_store import get_audit_event, list_audit_events
9
 
 
10
  def _audit_list_params(
11
  limit: Annotated[int, Query(ge=1, le=100)] = 10,
12
  offset: Annotated[int, Query(ge=0)] = 0,
13
  ) -> AuditListParams:
14
  return AuditListParams(limit=limit, offset=offset)
15
 
16
- router = APIRouter(tags=["audit"])
17
 
18
- @router.get("/audit", response_model=AuditListResponse)
19
- async def audit_list(
 
 
 
20
  params: Annotated[AuditListParams, Depends(_audit_list_params)],
21
  ) -> AuditListResponse:
22
  settings = get_settings()
@@ -29,14 +32,14 @@ async def audit_list(
29
  )
30
 
31
 
32
- @router.get("/audit/{event_id}", response_model=AuditDetailResponse)
33
- async def audit_detail(event_id: str) -> AuditDetailResponse:
34
  settings = get_settings()
35
- event = await get_audit_event(settings.audit_db_path, event_id)
36
  if event is None:
37
  raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Audit event not found.")
38
  return AuditDetailResponse(
39
  status="success",
40
  message="Audit event retrieved.",
41
  event=event,
42
- )
 
7
  from models.responses import AuditDetailResponse, AuditEvent, AuditListResponse
8
  from storage.audit_store import get_audit_event, list_audit_events
9
 
10
+
11
  def _audit_list_params(
12
  limit: Annotated[int, Query(ge=1, le=100)] = 10,
13
  offset: Annotated[int, Query(ge=0)] = 0,
14
  ) -> AuditListParams:
15
  return AuditListParams(limit=limit, offset=offset)
16
 
 
17
 
18
+ router = APIRouter(prefix="/audit", tags=["audit"])
19
+
20
+
21
+ @router.get("/logs", response_model=AuditListResponse)
22
+ async def audit_logs(
23
  params: Annotated[AuditListParams, Depends(_audit_list_params)],
24
  ) -> AuditListResponse:
25
  settings = get_settings()
 
32
  )
33
 
34
 
35
+ @router.get("/logs/{query_id}", response_model=AuditDetailResponse)
36
+ async def audit_log_detail(query_id: str) -> AuditDetailResponse:
37
  settings = get_settings()
38
+ event = await get_audit_event(settings.audit_db_path, query_id)
39
  if event is None:
40
  raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Audit event not found.")
41
  return AuditDetailResponse(
42
  status="success",
43
  message="Audit event retrieved.",
44
  event=event,
45
+ )
api/routes/ingest.py CHANGED
@@ -1,11 +1,20 @@
1
  from pathlib import Path
2
  from tempfile import NamedTemporaryFile
3
  from typing import Annotated
 
4
 
 
5
  from fastapi import APIRouter, BackgroundTasks, File, Form, HTTPException, UploadFile, status
6
 
7
  from api.config import get_settings
8
- from models.responses import IngestUploadResponse
 
 
 
 
 
 
 
9
  from storage.job_store import create_ingest_job
10
  from workers.ingest_worker import run_ingest_job
11
 
@@ -13,6 +22,13 @@ router = APIRouter(prefix="/ingest", tags=["ingest"])
13
 
14
  _SUPPORTED_EXTENSIONS = frozenset({".pdf", ".txt", ".md"})
15
 
 
 
 
 
 
 
 
16
 
17
  def _validate_file(file: UploadFile, max_bytes: int) -> str:
18
  filename = (file.filename or "").strip()
@@ -38,6 +54,84 @@ def _validate_file(file: UploadFile, max_bytes: int) -> str:
38
  return suffix
39
 
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  @router.post("/upload", response_model=IngestUploadResponse)
42
  async def upload_endpoint(
43
  background_tasks: BackgroundTasks,
@@ -87,3 +181,83 @@ async def upload_endpoint(
87
  raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
88
  finally:
89
  await file.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from pathlib import Path
2
  from tempfile import NamedTemporaryFile
3
  from typing import Annotated
4
+ from urllib.parse import unquote, urlparse
5
 
6
+ import httpx
7
  from fastapi import APIRouter, BackgroundTasks, File, Form, HTTPException, UploadFile, status
8
 
9
  from api.config import get_settings
10
+ from models.requests import IngestUrlRequest
11
+ from models.responses import (
12
+ IngestCollectionsResponse,
13
+ IngestDeleteCollectionResponse,
14
+ IngestUploadResponse,
15
+ CollectionItem,
16
+ )
17
+ from rag.vector_store import delete_collection, list_collection_names
18
  from storage.job_store import create_ingest_job
19
  from workers.ingest_worker import run_ingest_job
20
 
 
22
 
23
  _SUPPORTED_EXTENSIONS = frozenset({".pdf", ".txt", ".md"})
24
 
25
+ _CONTENT_TYPE_SUFFIX: dict[str, str] = {
26
+ "application/pdf": ".pdf",
27
+ "text/plain": ".txt",
28
+ "text/markdown": ".md",
29
+ "text/x-markdown": ".md",
30
+ }
31
+
32
 
33
  def _validate_file(file: UploadFile, max_bytes: int) -> str:
34
  filename = (file.filename or "").strip()
 
54
  return suffix
55
 
56
 
57
+ def _suffix_from_url_path(url: str) -> str | None:
58
+ path = urlparse(url).path
59
+ suffix = Path(unquote(path)).suffix.lower()
60
+ return suffix if suffix in _SUPPORTED_EXTENSIONS else None
61
+
62
+
63
+ def _suffix_from_content_type(content_type: str | None) -> str | None:
64
+ if not content_type:
65
+ return None
66
+ base = content_type.split(";")[0].strip().lower()
67
+ return _CONTENT_TYPE_SUFFIX.get(base)
68
+
69
+
70
+ def _display_name_from_url(url: str, suffix: str) -> str:
71
+ name = Path(unquote(urlparse(url).path)).name.strip()
72
+ if not name or name in {"/", "."}:
73
+ return f"download{suffix}"
74
+ if Path(name).suffix.lower() not in _SUPPORTED_EXTENSIONS:
75
+ return f"{name}{suffix}" if not name.endswith(suffix) else name
76
+ return name
77
+
78
+
79
+ async def _download_url_to_temp(url: str, max_bytes: int) -> tuple[str, str]:
80
+ parsed = urlparse(url)
81
+ if parsed.scheme not in ("http", "https"):
82
+ raise HTTPException(
83
+ status_code=status.HTTP_400_BAD_REQUEST,
84
+ detail="Only http and https URLs are supported.",
85
+ )
86
+
87
+ timeout = httpx.Timeout(60.0, connect=10.0)
88
+ limits = httpx.Limits(max_keepalive_connections=5, max_connections=5)
89
+ headers = {"User-Agent": "doc-audi-ai/ingest"}
90
+
91
+ try:
92
+ async with httpx.AsyncClient(timeout=timeout, limits=limits, follow_redirects=True) as client:
93
+ async with client.stream("GET", url, headers=headers) as response:
94
+ response.raise_for_status()
95
+ content_type = response.headers.get("content-type")
96
+ suffix = _suffix_from_url_path(url) or _suffix_from_content_type(content_type)
97
+ if not suffix:
98
+ raise HTTPException(
99
+ status_code=status.HTTP_400_BAD_REQUEST,
100
+ detail=(
101
+ "Could not determine file type from the URL path or Content-Type. "
102
+ "Provide a .pdf, .txt, or .md resource with matching content-type."
103
+ ),
104
+ )
105
+
106
+ display_name = _display_name_from_url(url, suffix)
107
+ total = 0
108
+ with NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
109
+ temp_path = tmp.name
110
+ async for chunk in response.aiter_bytes(chunk_size=65536):
111
+ total += len(chunk)
112
+ if total > max_bytes:
113
+ raise HTTPException(
114
+ status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
115
+ detail=f"Download too large. Max allowed is {max_bytes // (1024 * 1024)}MB.",
116
+ )
117
+ tmp.write(chunk)
118
+ except HTTPException:
119
+ raise
120
+ except httpx.HTTPStatusError as exc:
121
+ code = exc.response.status_code if exc.response else "unknown"
122
+ raise HTTPException(
123
+ status_code=status.HTTP_502_BAD_GATEWAY,
124
+ detail=f"Remote server returned HTTP {code}.",
125
+ ) from exc
126
+ except httpx.RequestError as exc:
127
+ raise HTTPException(
128
+ status_code=status.HTTP_502_BAD_GATEWAY,
129
+ detail=f"Failed to download URL: {exc}",
130
+ ) from exc
131
+
132
+ return temp_path, display_name
133
+
134
+
135
  @router.post("/upload", response_model=IngestUploadResponse)
136
  async def upload_endpoint(
137
  background_tasks: BackgroundTasks,
 
181
  raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
182
  finally:
183
  await file.close()
184
+
185
+
186
+ @router.post("/url", response_model=IngestUploadResponse)
187
+ async def ingest_url_endpoint(
188
+ background_tasks: BackgroundTasks,
189
+ payload: IngestUrlRequest,
190
+ ) -> IngestUploadResponse:
191
+ settings = get_settings()
192
+ max_bytes = settings.max_file_size_mb * 1024 * 1024
193
+ url_str = str(payload.url).strip()
194
+ temp_path = ""
195
+ try:
196
+ temp_path, display_name = await _download_url_to_temp(url_str, max_bytes)
197
+
198
+ job_id = await create_ingest_job(
199
+ settings.jobs_db_path,
200
+ collection_name=payload.collection_name,
201
+ filename=display_name,
202
+ )
203
+
204
+ background_tasks.add_task(
205
+ run_ingest_job,
206
+ job_id,
207
+ temp_path,
208
+ payload.collection_name,
209
+ settings.jobs_db_path,
210
+ settings.chroma_persist_directory,
211
+ )
212
+
213
+ return IngestUploadResponse(
214
+ status="queued",
215
+ message=f"Ingestion job accepted. Poll GET /jobs/{job_id} for status.",
216
+ job_id=job_id,
217
+ document_ids=[],
218
+ )
219
+ except HTTPException:
220
+ if temp_path:
221
+ Path(temp_path).unlink(missing_ok=True)
222
+ raise
223
+ except Exception as exc:
224
+ if temp_path:
225
+ Path(temp_path).unlink(missing_ok=True)
226
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
227
+
228
+
229
+ @router.get("/collections", response_model=IngestCollectionsResponse)
230
+ async def list_collections_endpoint() -> IngestCollectionsResponse:
231
+ settings = get_settings()
232
+ try:
233
+ names = list_collection_names(settings.chroma_persist_directory)
234
+ except Exception as exc:
235
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
236
+ items = [CollectionItem(name=n) for n in names]
237
+ return IngestCollectionsResponse(
238
+ status="success",
239
+ message=f"Found {len(items)} collection(s).",
240
+ collections=items,
241
+ )
242
+
243
+
244
+ @router.delete("/collection/{collection_name}", response_model=IngestDeleteCollectionResponse)
245
+ async def delete_collection_endpoint(collection_name: str) -> IngestDeleteCollectionResponse:
246
+ if not collection_name.strip():
247
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="collection_name is required.")
248
+ settings = get_settings()
249
+ name = collection_name.strip()
250
+ try:
251
+ existing = list_collection_names(settings.chroma_persist_directory)
252
+ if name not in existing:
253
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Collection not found.")
254
+ delete_collection(settings.chroma_persist_directory, name)
255
+ except HTTPException:
256
+ raise
257
+ except Exception as exc:
258
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
259
+ return IngestDeleteCollectionResponse(
260
+ status="success",
261
+ message=f"Deleted collection '{name}'.",
262
+ collection_name=name,
263
+ )
api/routes/query.py CHANGED
@@ -1,31 +1,29 @@
1
  from fastapi import APIRouter, HTTPException, status
2
 
3
  from api.config import get_settings
4
- from models.requests import QueryRequest
5
  from models.responses import QueryResponse, QueryResultItem, QuerySourceItem
6
  from rag.embedder import create_embedding_function
7
- from rag.retriever import answer_with_grounding, retrieve_chunks
 
 
 
 
 
 
8
  from rag.vector_store import get_vector_store
9
  from storage.audit_store import persist_query_audit
10
 
11
- router = APIRouter(tags=["query"])
12
 
13
 
14
- @router.post("/query", response_model=QueryResponse)
15
- async def query_endpoint(payload: QueryRequest) -> QueryResponse:
16
- settings = get_settings()
17
- try:
18
- embedding_function = create_embedding_function()
19
- vector_store = get_vector_store(
20
- persist_directory=settings.chroma_persist_directory,
21
- collection_name=payload.collection_name,
22
- embedding_function=embedding_function,
23
- )
24
- chunks = retrieve_chunks(vector_store, payload.question, settings.top_k_results)
25
- answer = answer_with_grounding(settings, payload.question, chunks)
26
- except Exception as exc:
27
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
28
-
29
  results = [QueryResultItem(text=chunk.text, score=chunk.score) for chunk in chunks]
30
  sources = [
31
  QuerySourceItem(
@@ -37,20 +35,84 @@ async def query_endpoint(payload: QueryRequest) -> QueryResponse:
37
  )
38
  for chunk in chunks
39
  ]
40
- response = QueryResponse(
41
  status="success",
42
- message=f"Retrieved {len(results)} chunks from '{payload.collection_name}' and generated grounded answer.",
43
  answer=answer,
44
  sources=sources,
45
  results=results,
46
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  try:
48
  await persist_query_audit(
49
  settings.audit_db_path,
 
50
  question=payload.question,
51
  collection_name=payload.collection_name,
52
  response=response,
53
  )
54
  except Exception as exc:
55
  raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
56
- return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from fastapi import APIRouter, HTTPException, status
2
 
3
  from api.config import get_settings
4
+ from models.requests import QueryRequest, SummariseRequest
5
  from models.responses import QueryResponse, QueryResultItem, QuerySourceItem
6
  from rag.embedder import create_embedding_function
7
+ from rag.retriever import (
8
+ SUMMARY_RETRIEVAL_QUERY,
9
+ RetrievedChunk,
10
+ answer_with_grounding,
11
+ retrieve_chunks,
12
+ summarise_with_grounding,
13
+ )
14
  from rag.vector_store import get_vector_store
15
  from storage.audit_store import persist_query_audit
16
 
17
+ router = APIRouter(prefix="/query", tags=["query"])
18
 
19
 
20
+ def _response_from_chunks(
21
+ *,
22
+ collection_name: str,
23
+ chunks: list[RetrievedChunk],
24
+ answer: str,
25
+ message: str,
26
+ ) -> QueryResponse:
 
 
 
 
 
 
 
 
27
  results = [QueryResultItem(text=chunk.text, score=chunk.score) for chunk in chunks]
28
  sources = [
29
  QuerySourceItem(
 
35
  )
36
  for chunk in chunks
37
  ]
38
+ return QueryResponse(
39
  status="success",
40
+ message=message,
41
  answer=answer,
42
  sources=sources,
43
  results=results,
44
  )
45
+
46
+
47
+ @router.post("/ask", response_model=QueryResponse)
48
+ async def ask_endpoint(payload: QueryRequest) -> QueryResponse:
49
+ settings = get_settings()
50
+ try:
51
+ embedding_function = create_embedding_function()
52
+ vector_store = get_vector_store(
53
+ persist_directory=settings.chroma_persist_directory,
54
+ collection_name=payload.collection_name,
55
+ embedding_function=embedding_function,
56
+ )
57
+ chunks = retrieve_chunks(vector_store, payload.question, settings.top_k_results)
58
+ answer = answer_with_grounding(settings, payload.question, chunks)
59
+ except Exception as exc:
60
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
61
+
62
+ response = _response_from_chunks(
63
+ collection_name=payload.collection_name,
64
+ chunks=chunks,
65
+ answer=answer,
66
+ message=(
67
+ f"Retrieved {len(chunks)} chunks from '{payload.collection_name}' and generated a grounded answer."
68
+ ),
69
+ )
70
  try:
71
  await persist_query_audit(
72
  settings.audit_db_path,
73
+ action="query",
74
  question=payload.question,
75
  collection_name=payload.collection_name,
76
  response=response,
77
  )
78
  except Exception as exc:
79
  raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
80
+ return response
81
+
82
+
83
+ @router.post("/summarise", response_model=QueryResponse)
84
+ async def summarise_endpoint(payload: SummariseRequest) -> QueryResponse:
85
+ settings = get_settings()
86
+ retrieval_query = (payload.focus or "").strip() or SUMMARY_RETRIEVAL_QUERY
87
+ audit_question = payload.focus.strip() if payload.focus and payload.focus.strip() else "Summarise collection"
88
+ try:
89
+ embedding_function = create_embedding_function()
90
+ vector_store = get_vector_store(
91
+ persist_directory=settings.chroma_persist_directory,
92
+ collection_name=payload.collection_name,
93
+ embedding_function=embedding_function,
94
+ )
95
+ chunks = retrieve_chunks(vector_store, retrieval_query, settings.top_k_results)
96
+ answer = summarise_with_grounding(settings, focus=payload.focus, chunks=chunks)
97
+ except Exception as exc:
98
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
99
+
100
+ response = _response_from_chunks(
101
+ collection_name=payload.collection_name,
102
+ chunks=chunks,
103
+ answer=answer,
104
+ message=(
105
+ f"Retrieved {len(chunks)} chunks from '{payload.collection_name}' and generated a grounded summary."
106
+ ),
107
+ )
108
+ try:
109
+ await persist_query_audit(
110
+ settings.audit_db_path,
111
+ action="summarise",
112
+ question=audit_question,
113
+ collection_name=payload.collection_name,
114
+ response=response,
115
+ )
116
+ except Exception as exc:
117
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
118
+ return response
models/requests.py CHANGED
@@ -1,11 +1,44 @@
1
- from pydantic import BaseModel, ConfigDict, Field
2
 
3
- class QueryRequest(BaseModel):
4
 
 
5
  model_config = ConfigDict(extra="forbid")
6
 
7
  question: str = Field(min_length=1, max_length=8000, description="The question to ask the document")
8
- collection_name: str = Field(default="default", min_length=1, max_length=256, description="The name of the collection to ask the question from")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  class IngestUploadRequest(BaseModel):
11
  model_config = ConfigDict(extra="forbid")
 
1
+ from pydantic import BaseModel, ConfigDict, Field, HttpUrl
2
 
 
3
 
4
+ class QueryRequest(BaseModel):
5
  model_config = ConfigDict(extra="forbid")
6
 
7
  question: str = Field(min_length=1, max_length=8000, description="The question to ask the document")
8
+ collection_name: str = Field(
9
+ default="default",
10
+ min_length=1,
11
+ max_length=256,
12
+ description="The name of the collection to ask the question from",
13
+ )
14
+
15
+
16
+ class SummariseRequest(BaseModel):
17
+ model_config = ConfigDict(extra="forbid")
18
+
19
+ collection_name: str = Field(
20
+ default="default",
21
+ min_length=1,
22
+ max_length=256,
23
+ description="Chroma collection to summarise from",
24
+ )
25
+ focus: str | None = Field(
26
+ default=None,
27
+ max_length=8000,
28
+ description="Optional angle or scope for retrieval and the summary (e.g. 'contract payment terms')",
29
+ )
30
+
31
+
32
+ class IngestUrlRequest(BaseModel):
33
+ model_config = ConfigDict(extra="forbid")
34
+
35
+ url: HttpUrl = Field(description="HTTP(S) URL to a PDF, TXT, or Markdown document")
36
+ collection_name: str = Field(
37
+ default="default",
38
+ min_length=1,
39
+ max_length=256,
40
+ description="Target Chroma collection name",
41
+ )
42
 
43
  class IngestUploadRequest(BaseModel):
44
  model_config = ConfigDict(extra="forbid")
models/responses.py CHANGED
@@ -25,6 +25,22 @@ class IngestUploadResponse(BaseModel):
25
  job_id: str
26
  document_ids: list[str] = Field(default_factory=list)
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  class JobSummary(BaseModel):
29
  job_id: str
30
  status: str
 
25
  job_id: str
26
  document_ids: list[str] = Field(default_factory=list)
27
 
28
+
29
+ class CollectionItem(BaseModel):
30
+ name: str
31
+
32
+
33
+ class IngestCollectionsResponse(BaseModel):
34
+ status: str
35
+ message: str
36
+ collections: list[CollectionItem] = Field(default_factory=list)
37
+
38
+
39
+ class IngestDeleteCollectionResponse(BaseModel):
40
+ status: str
41
+ message: str
42
+ collection_name: str
43
+
44
  class JobSummary(BaseModel):
45
  job_id: str
46
  status: str
pyproject.toml CHANGED
@@ -20,6 +20,7 @@ dependencies = [
20
  "pymupdf==1.24.3",
21
  "python-multipart==0.0.9",
22
  "aiosqlite>=0.21.0",
 
23
  "uvicorn[standard]==0.29.0",
24
  "huggingface-hub>=1.13.0",
25
  "langchain-huggingface>=0.0.3",
 
20
  "pymupdf==1.24.3",
21
  "python-multipart==0.0.9",
22
  "aiosqlite>=0.21.0",
23
+ "httpx>=0.27.0",
24
  "uvicorn[standard]==0.29.0",
25
  "huggingface-hub>=1.13.0",
26
  "langchain-huggingface>=0.0.3",
rag/retriever.py CHANGED
@@ -57,6 +57,11 @@ def retrieve_chunks(vector_store: Chroma, question: str, k: int) -> list[Retriev
57
  return chunks
58
 
59
 
 
 
 
 
 
60
  def answer_with_grounding(settings: Settings, question: str, chunks: list[RetrievedChunk]) -> str:
61
  ranked_chunks = [chunk for chunk in chunks if chunk.score is None or chunk.score >= MIN_RELEVANCE_SCORE]
62
  if not ranked_chunks:
@@ -84,6 +89,43 @@ def answer_with_grounding(settings: Settings, question: str, chunks: list[Retrie
84
  return answer or NO_MATCH_ANSWER
85
 
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  def _create_chat_model(settings: Settings) -> BaseChatModel:
88
  provider = settings.llm_provider.lower()
89
 
 
57
  return chunks
58
 
59
 
60
+ SUMMARY_RETRIEVAL_QUERY = (
61
+ "Overview of the document: main topics, key definitions, obligations, risks, and conclusions."
62
+ )
63
+
64
+
65
  def answer_with_grounding(settings: Settings, question: str, chunks: list[RetrievedChunk]) -> str:
66
  ranked_chunks = [chunk for chunk in chunks if chunk.score is None or chunk.score >= MIN_RELEVANCE_SCORE]
67
  if not ranked_chunks:
 
89
  return answer or NO_MATCH_ANSWER
90
 
91
 
92
+ def summarise_with_grounding(
93
+ settings: Settings,
94
+ *,
95
+ focus: str | None,
96
+ chunks: list[RetrievedChunk],
97
+ ) -> str:
98
+ ranked_chunks = [chunk for chunk in chunks if chunk.score is None or chunk.score >= MIN_RELEVANCE_SCORE]
99
+ if not ranked_chunks:
100
+ return NO_MATCH_ANSWER
101
+
102
+ llm = _create_chat_model(settings)
103
+ prompt_context = _format_context(ranked_chunks)
104
+ user_instruction = (
105
+ focus.strip()
106
+ if focus and focus.strip()
107
+ else "Summarise the main themes, structure, and important details. Use bullet points where helpful."
108
+ )
109
+ messages = [
110
+ SystemMessage(
111
+ content=(
112
+ "You write accurate summaries using only the provided document excerpts. "
113
+ "Do not invent facts. If the excerpts are insufficient, say what is missing."
114
+ )
115
+ ),
116
+ HumanMessage(
117
+ content=(
118
+ f"Summary request: {user_instruction}\n\n"
119
+ f"Document excerpts:\n{prompt_context}\n\n"
120
+ "Return a structured, concise summary grounded in the excerpts above."
121
+ )
122
+ ),
123
+ ]
124
+ response = llm.invoke(messages)
125
+ answer = _extract_message_text(response).strip()
126
+ return answer or NO_MATCH_ANSWER
127
+
128
+
129
  def _create_chat_model(settings: Settings) -> BaseChatModel:
130
  provider = settings.llm_provider.lower()
131
 
rag/vector_store.py CHANGED
@@ -1,9 +1,10 @@
1
  from pathlib import Path
2
  from uuid import uuid4
3
 
 
 
4
  from langchain_core.documents import Document
5
  from langchain_core.embeddings import Embeddings
6
- from langchain_chroma import Chroma
7
 
8
 
9
  def get_vector_store(
@@ -24,3 +25,15 @@ def add_documents(vector_store: Chroma, chunks: list[Document]) -> list[str]:
24
  vector_store.add_documents(documents=chunks, ids=document_ids)
25
  return document_ids
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from pathlib import Path
2
  from uuid import uuid4
3
 
4
+ import chromadb
5
+ from langchain_chroma import Chroma
6
  from langchain_core.documents import Document
7
  from langchain_core.embeddings import Embeddings
 
8
 
9
 
10
  def get_vector_store(
 
25
  vector_store.add_documents(documents=chunks, ids=document_ids)
26
  return document_ids
27
 
28
+
29
+ def list_collection_names(persist_directory: str) -> list[str]:
30
+ Path(persist_directory).mkdir(parents=True, exist_ok=True)
31
+ client = chromadb.PersistentClient(path=persist_directory)
32
+ return sorted(c.name for c in client.list_collections())
33
+
34
+
35
+ def delete_collection(persist_directory: str, collection_name: str) -> None:
36
+ Path(persist_directory).mkdir(parents=True, exist_ok=True)
37
+ client = chromadb.PersistentClient(path=persist_directory)
38
+ client.delete_collection(name=collection_name)
39
+
requirements.txt CHANGED
@@ -14,5 +14,6 @@ anthropic==0.28.1
14
  pymupdf==1.24.3
15
  python-multipart==0.0.9
16
  aiosqlite
 
17
  huggingface-hub
18
  langchain-huggingface
 
14
  pymupdf==1.24.3
15
  python-multipart==0.0.9
16
  aiosqlite
17
+ httpx>=0.27.0
18
  huggingface-hub
19
  langchain-huggingface
storage/audit_store.py CHANGED
@@ -34,6 +34,7 @@ async def init_audit_db(db_path: str) -> None:
34
  async def persist_query_audit(
35
  db_path: str,
36
  *,
 
37
  question: str,
38
  collection_name: str,
39
  response: QueryResponse,
@@ -49,7 +50,7 @@ async def persist_query_audit(
49
  """,
50
  (
51
  event_id,
52
- "query",
53
  question,
54
  collection_name,
55
  response.answer,
 
34
  async def persist_query_audit(
35
  db_path: str,
36
  *,
37
+ action: str,
38
  question: str,
39
  collection_name: str,
40
  response: QueryResponse,
 
50
  """,
51
  (
52
  event_id,
53
+ action,
54
  question,
55
  collection_name,
56
  response.answer,
uv.lock CHANGED
@@ -537,6 +537,7 @@ dependencies = [
537
  { name = "anthropic" },
538
  { name = "chromadb" },
539
  { name = "fastapi" },
 
540
  { name = "huggingface-hub" },
541
  { name = "langchain" },
542
  { name = "langchain-anthropic" },
@@ -561,6 +562,7 @@ requires-dist = [
561
  { name = "anthropic", specifier = "==0.28.1" },
562
  { name = "chromadb", specifier = "==0.5.0" },
563
  { name = "fastapi", specifier = "==0.111.0" },
 
564
  { name = "huggingface-hub", specifier = ">=1.13.0" },
565
  { name = "langchain", specifier = "==0.2.0" },
566
  { name = "langchain-anthropic", specifier = "==0.1.15" },
@@ -970,17 +972,18 @@ wheels = [
970
 
971
  [[package]]
972
  name = "httpx"
973
- version = "0.28.1"
974
  source = { registry = "https://pypi.org/simple" }
975
  dependencies = [
976
  { name = "anyio" },
977
  { name = "certifi" },
978
  { name = "httpcore" },
979
  { name = "idna" },
 
980
  ]
981
- sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406, upload-time = "2024-12-06T15:37:23.222Z" }
982
  wheels = [
983
- { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" },
984
  ]
985
 
986
  [[package]]
 
537
  { name = "anthropic" },
538
  { name = "chromadb" },
539
  { name = "fastapi" },
540
+ { name = "httpx" },
541
  { name = "huggingface-hub" },
542
  { name = "langchain" },
543
  { name = "langchain-anthropic" },
 
562
  { name = "anthropic", specifier = "==0.28.1" },
563
  { name = "chromadb", specifier = "==0.5.0" },
564
  { name = "fastapi", specifier = "==0.111.0" },
565
+ { name = "httpx", specifier = ">=0.27.0" },
566
  { name = "huggingface-hub", specifier = ">=1.13.0" },
567
  { name = "langchain", specifier = "==0.2.0" },
568
  { name = "langchain-anthropic", specifier = "==0.1.15" },
 
972
 
973
  [[package]]
974
  name = "httpx"
975
+ version = "0.27.0"
976
  source = { registry = "https://pypi.org/simple" }
977
  dependencies = [
978
  { name = "anyio" },
979
  { name = "certifi" },
980
  { name = "httpcore" },
981
  { name = "idna" },
982
+ { name = "sniffio" },
983
  ]
984
+ sdist = { url = "https://files.pythonhosted.org/packages/5c/2d/3da5bdf4408b8b2800061c339f240c1802f2e82d55e50bd39c5a881f47f0/httpx-0.27.0.tar.gz", hash = "sha256:a0cb88a46f32dc874e04ee956e4c2764aba2aa228f650b06788ba6bda2962ab5", size = 126413, upload-time = "2024-02-21T13:07:52.434Z" }
985
  wheels = [
986
+ { url = "https://files.pythonhosted.org/packages/41/7b/ddacf6dcebb42466abd03f368782142baa82e08fc0c1f8eaa05b4bae87d5/httpx-0.27.0-py3-none-any.whl", hash = "sha256:71d5465162c13681bff01ad59b2cc68dd838ea1f10e51574bac27103f00c91a5", size = 75590, upload-time = "2024-02-21T13:07:50.455Z" },
987
  ]
988
 
989
  [[package]]