duongtruongbinh's picture
Add support for multiple PDF uploads and enhanced citation handling
bc2d97e
"""Retrieval, prompts, citations, and grounded answers."""
from __future__ import annotations
from functools import lru_cache
from pathlib import Path
from jinja2 import Environment, FileSystemLoader, StrictUndefined
from src.config import settings
from src.filters import filters_to_qdrant
from src.llm import invoke_llm
from src.schemas import ChunkMetadata, Citation, RagAnswer, RetrievedChunk
from src.store import get_vector_store, scroll_all
PROMPTS_DIR = Path(__file__).parent / "prompts"
ANSWER_TEMPLATE = "answer.jinja2"
def retrieve(
query: str,
k: int | None = None,
filters: dict[str, object] | None = None,
collection_name: str | None = None,
) -> list[RetrievedChunk]:
store = get_vector_store(collection_name=collection_name)
hits = store.similarity_search_with_score(
query=query,
k=k or settings.top_k,
filter=filters_to_qdrant(filters),
)
return [
RetrievedChunk(
text=doc.page_content,
score=float(score),
metadata=ChunkMetadata(**doc.metadata),
)
for doc, score in hits
]
def fetch_all_chunks(
filters: dict[str, object] | None = None,
collection_name: str | None = None,
) -> list[RetrievedChunk]:
"""Scroll every chunk matching the filter, ordered by filename → page → index."""
name = collection_name or settings.qdrant_collection
results: list[RetrievedChunk] = []
for page in scroll_all(name, scroll_filter=filters_to_qdrant(filters)):
for point in page:
payload = point.payload or {}
meta = payload.get("metadata") or {}
text = payload.get("page_content") or ""
if not meta or not text:
continue
results.append(RetrievedChunk(text=text, score=0.0, metadata=ChunkMetadata(**meta)))
results.sort(
key=lambda r: (
r.metadata.filename,
r.metadata.page,
int(r.metadata.chunk_id.rsplit(":", 1)[-1]),
)
)
return results
@lru_cache(maxsize=1)
def _jinja_env() -> Environment:
return Environment(
loader=FileSystemLoader(str(PROMPTS_DIR)),
autoescape=False,
undefined=StrictUndefined,
trim_blocks=True,
lstrip_blocks=True,
)
def render_prompt(template_name: str, **context: object) -> str:
"""Render an arbitrary Jinja template from the prompts directory."""
return _jinja_env().get_template(template_name).render(**context)
def format_citations(chunks: list[RetrievedChunk]) -> list[Citation]:
return [
Citation(
source_index=i,
source_marker=f"S{i}",
filename=c.metadata.filename,
page=c.metadata.page,
source_text=c.text.strip(),
section=c.metadata.section,
chunk_id=c.metadata.chunk_id,
)
for i, c in enumerate(chunks, start=1)
]
def answer(
question: str,
k: int | None = None,
filters: dict[str, object] | None = None,
collection_name: str | None = None,
) -> RagAnswer:
chunks = retrieve(question, k=k, filters=filters, collection_name=collection_name)
if not chunks:
return RagAnswer(
question=question,
answer="Tôi không có đủ thông tin trong ngữ cảnh được cung cấp để trả lời.",
)
prompt = render_prompt(ANSWER_TEMPLATE, question=question, chunks=chunks)
text = invoke_llm(prompt)
return RagAnswer(
question=question,
answer=text.strip(),
citations=format_citations(chunks),
chunks=chunks,
)