"""Retrieval, fusion, reranking, context expansion and answer generation helpers."""

import json
import re
import time
from pathlib import Path
from typing import Any

from langchain_chroma import Chroma
from langchain_core.documents import Document
from pydantic import BaseModel, Field
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder

from core.config import (
    DEFAULT_BM25_CANDIDATE_K,
    DEFAULT_CONTEXT_WINDOW,
    DEFAULT_ENABLE_QUERY_TRANSFORM,
    DEFAULT_MAX_EXPANDED_CHUNKS,
    DEFAULT_MIN_GROUNDED_CHUNKS,
    DEFAULT_MIN_GROUNDED_RERANK_SCORE,
    DEFAULT_RERANK_CANDIDATE_K,
    DEFAULT_RRF_K,
    DEFAULT_SUBQUERY_COUNT,
    REGISTRY_PATH,
)
from rag.prompt import build_citation_repair_prompt, build_rag_prompt, rewrite_query


class QueryRewriteResult(BaseModel):
    """Structured-output schema for the query rewriting step."""

    # Rewritten retrieval queries proposed by the model.
    sub_queries: list[str] = Field(default_factory=list)


class AnswerDraftResult(BaseModel):
    """Structured-output schema for a drafted answer."""

    # Answer text; empty when the model produced nothing.
    answer: str = ""


def load_vectorstore(persist_dir, embeddings):
    """Load a Chroma vector store from the given directory."""
    return Chroma(
        embedding_function=embeddings,
        persist_directory=persist_dir,
    )


def retrieve_documents(vectorstore, query, k=4, reranker=None, candidate_k=DEFAULT_RERANK_CANDIDATE_K):
    """Retrieve candidate chunks with embeddings first, then optionally rerank them.

    Fetches ``max(k, candidate_k)`` dense hits, removes duplicate contents and
    returns the first ``k`` ``(doc, score)`` pairs, or the reranker's top ``k``
    when a reranker is supplied.
    """
    docs = vectorstore.similarity_search_with_score(query, k=max(k, candidate_k))
    unique_docs = deduplicate_documents(docs)

    if reranker is None:
        return unique_docs[:k]

    return rerank_documents(query, unique_docs, reranker, top_n=k)


def retrieve_documents_with_query_transform(
    vectorstore,
    query,
    k=4,
    reranker=None,
    bm25_index=None,
    query_transformer=None,
    enable_query_transform=DEFAULT_ENABLE_QUERY_TRANSFORM,
    candidate_k=DEFAULT_RERANK_CANDIDATE_K,
    bm25_candidate_k=DEFAULT_BM25_CANDIDATE_K,
    rrf_k=DEFAULT_RRF_K,
    subquery_count=DEFAULT_SUBQUERY_COUNT,
    metadata_filter=None,
    include_debug=False,
):
    """Retrieve with optional query transformation.

    The important part is that reranking happens once on the merged candidate
    set, using the original user query as the final ranking signal.

    Returns:
        A list of ``(doc, score)`` pairs, or ``(pairs, debug)`` when
        ``include_debug`` is true.
    """
    search_queries = [query]
    debug = {
        "search_queries": [],
        "dense_hits": {},
        "bm25_hits": {},
        "fused_hits": [],
        "reranked_hits": [],
        "stage_counts": {},
    }

    if enable_query_transform and query_transformer is not None:
        sub_queries = generate_sub_queries(query, query_transformer, max_queries=subquery_count)
        search_queries.extend(sub_queries)

    debug["search_queries"] = search_queries

    candidate_documents, fusion_debug = collect_candidate_documents(
        vectorstore,
        search_queries,
        bm25_index=bm25_index,
        candidate_k=candidate_k,
        bm25_candidate_k=bm25_candidate_k,
        rrf_k=rrf_k,
        metadata_filter=metadata_filter,
        include_debug=include_debug,
    )

    if include_debug:
        debug.update(fusion_debug)
        debug["stage_counts"] = {
            "search_queries": len(search_queries),
            "fused_candidates": len(candidate_documents),
        }

    # Without a reranker the fused order is final; otherwise rerank once against
    # the original user query (not the rewritten sub-queries).
    if reranker is None:
        top_documents = candidate_documents[:k]
    else:
        top_documents = rerank_documents(query, candidate_documents, reranker, top_n=k)

    if include_debug:
        debug["reranked_hits"] = build_hit_debug(top_documents)
        debug["stage_counts"]["final_hits"] = len(top_documents)
        return top_documents, debug

    return top_documents


def deduplicate_documents(documents):
    """Drop ``(doc, score)`` pairs whose stripped page content was already seen."""
    unique_docs = []
    seen_content = set()

    for doc, score in documents:
        content = doc.page_content.strip()
        if content in seen_content:
            continue
        unique_docs.append((doc, score))
        seen_content.add(content)

    return unique_docs


def extract_response_text(response):
    """Return the stripped text of a model response, or ``""`` when none is found.

    Reads ``response.content`` (a string, or a list of strings / dicts with a
    ``"text"`` key) and falls back to ``response.text`` when it is a string.
    """
    if response is None:
        return ""

    content = getattr(response, "content", None)

    if isinstance(content, list):
        text_parts = []
        for part in content:
            if isinstance(part, str):
                text_parts.append(part.strip())
            elif isinstance(part, dict) and "text" in part:
                text_parts.append(str(part["text"]).strip())
        return " ".join(part for part in text_parts if part).strip()

    if isinstance(content, str):
        return content.strip()

    text = getattr(response, "text", "")
    if isinstance(text, str):
        return text.strip()

    return ""


def generate_sub_queries(query, query_transformer, max_queries=DEFAULT_SUBQUERY_COUNT):
    """Ask the query transformer for retrieval sub-queries.

    Tries structured output first; if that is unavailable, fails, or yields
    nothing usable, invokes the transformer directly and parses the reply as
    JSON and then as a plain (optionally numbered) list of lines.

    Exceptions raised by the direct ``query_transformer.invoke`` call are not
    caught here and propagate to the caller.
    """
    print(f"[Query Transform] Start | Query='{query}' | Max Queries={max_queries}")
    prompt = rewrite_query(query)
    print("[Query Transform] Prompt built.")

    structured_transformer = get_structured_query_transformer(query_transformer)
    if structured_transformer is not None:
        print("[Query Transform] Structured transformer available. Starting structured invoke...")
        started_at = time.perf_counter()
        try:
            structured_result = structured_transformer.invoke(prompt)
            elapsed = time.perf_counter() - started_at
            print(f"[Query Transform] Structured invoke completed in {elapsed:.2f}s.")
        except Exception:
            # Best effort: a failed structured call falls through to the direct path.
            elapsed = time.perf_counter() - started_at
            print(f"[Query Transform] Structured invoke failed after {elapsed:.2f}s.")
            structured_result = None

        sub_queries = parse_structured_sub_queries(structured_result, query, max_queries=max_queries)
        if sub_queries:
            print(f"[Query Transform] Structured parse returned {len(sub_queries)} query(s): {sub_queries}")
            return sub_queries
        print("[Query Transform] Structured parse returned no usable queries.")
    else:
        print("[Query Transform] Structured transformer unavailable.")

    print("[Query Transform] Falling back to direct invoke...")
    fallback_started_at = time.perf_counter()
    response = query_transformer.invoke(prompt)
    fallback_elapsed = time.perf_counter() - fallback_started_at
    print(f"[Query Transform] Direct invoke completed in {fallback_elapsed:.2f}s.")

    response_text = extract_response_text(response)
    print(f"[Query Transform] Extracted response text length: {len(response_text)}")

    json_sub_queries = parse_sub_queries_json(response_text, query, max_queries=max_queries)
    if json_sub_queries:
        print(f"[Query Transform] JSON parse returned {len(json_sub_queries)} query(s): {json_sub_queries}")
        return json_sub_queries

    # Last resort: treat each line as a candidate, stripping list numbering
    # ("1.", "2)", "3 -") and skipping a "Sub-queries" header line.
    line_candidates = (
        re.sub(r"^\s*\d+[\).\-\s]+", "", line).strip()
        for line in response_text.splitlines()
    )
    sub_queries = _clean_sub_queries(
        (
            candidate
            for candidate in line_candidates
            if candidate and not candidate.lower().startswith("sub-queries")
        ),
        query,
        max_queries,
    )

    print(f"[Query Transform] Line parse returned {len(sub_queries)} query(s): {sub_queries}")
    return sub_queries


def get_structured_query_transformer(query_transformer):
    """Return a structured-output wrapper for the transformer, or ``None``.

    ``None`` is returned when the transformer is missing, has no
    ``with_structured_output`` attribute, or building the wrapper raises.
    """
    if query_transformer is None:
        return None

    with_structured_output = getattr(query_transformer, "with_structured_output", None)
    if with_structured_output is None:
        return None

    try:
        # Explicitly use function_calling method for better compatibility
        # and much faster response times with Gemma models.
        return with_structured_output(QueryRewriteResult, method="function_calling")
    except Exception:
        return None


def _clean_sub_queries(raw_sub_queries, query, max_queries):
    """Strip, validate and case-insensitively de-duplicate candidate sub-queries.

    Non-string items are skipped. Collection stops as soon as the result holds
    at least ``max_queries`` entries (checked after each append).
    """
    cleaned_sub_queries = []
    seen_queries = set()

    for item in raw_sub_queries:
        if not isinstance(item, str):
            continue
        cleaned = item.strip()
        if not is_valid_retrieval_query(cleaned, original_query=query):
            continue
        dedupe_key = cleaned.lower()
        if dedupe_key in seen_queries:
            continue
        cleaned_sub_queries.append(cleaned)
        seen_queries.add(dedupe_key)
        if len(cleaned_sub_queries) >= max_queries:
            break

    return cleaned_sub_queries


def parse_structured_sub_queries(response: Any, query, max_queries=DEFAULT_SUBQUERY_COUNT):
    """Extract cleaned sub-queries from a structured-output response.

    Accepts a ``QueryRewriteResult``, a dict with ``"sub_queries"``, any object
    exposing ``sub_queries``, or a raw string (parsed as JSON). Returns ``[]``
    when no list of sub-queries can be found.
    """
    if isinstance(response, QueryRewriteResult):
        raw_sub_queries = response.sub_queries
    elif isinstance(response, dict):
        raw_sub_queries = response.get("sub_queries")
    else:
        raw_sub_queries = getattr(response, "sub_queries", None)

    if raw_sub_queries is None and isinstance(response, str):
        return parse_sub_queries_json(response, query, max_queries=max_queries)

    if not isinstance(raw_sub_queries, list):
        return []

    return _clean_sub_queries(raw_sub_queries, query, max_queries)


def parse_sub_queries_json(response_text, query, max_queries=DEFAULT_SUBQUERY_COUNT):
    """Parse the first ``{...}`` span in the text as JSON and clean its ``sub_queries``.

    Returns ``[]`` when no braces are found, the span is not valid JSON, or
    ``sub_queries`` is not a list.
    """
    json_match = re.search(r"\{[\s\S]*\}", response_text)
    if not json_match:
        return []

    try:
        payload = json.loads(json_match.group(0))
    except json.JSONDecodeError:
        return []

    raw_sub_queries = payload.get("sub_queries")
    if not isinstance(raw_sub_queries, list):
        return []

    return _clean_sub_queries(raw_sub_queries, query, max_queries)


def is_valid_retrieval_query(candidate, *, original_query):
    """Return whether a candidate is usable as a retrieval query.

    Rejects empty candidates, candidates equal to the original query
    (case-insensitive, surrounding whitespace ignored) and text that matches a
    blocked prefix or phrase (refusals / meta commentary).
    """
    if not candidate:
        return False

    normalized = candidate.strip()
    lowered = normalized.lower()
    # Strip the original too; otherwise an echoed query slips through whenever
    # the user's query carries leading or trailing whitespace.
    if not normalized or lowered == (original_query or "").strip().lower():
        return False

    blocked_prefixes = (
        "i'm sorry",
        "i am sorry",
        "please provide",
        "no retrieval queries needed",
        "the query",
    )
    blocked_phrases = (
        "does not contain enough information",
        "does not contain any information",
        "please provide a specific question",
        "please provide a topic",
        "is a greeting",
        "cannot be converted into retrieval queries",
        "no retrieval queries needed",
    )

    if lowered.startswith(blocked_prefixes):
        return False

    return not any(phrase in lowered for phrase in blocked_phrases)


def collect_candidate_documents(
    vectorstore,
    queries,
    bm25_index=None,
    candidate_k=DEFAULT_RERANK_CANDIDATE_K,
    bm25_candidate_k=DEFAULT_BM25_CANDIDATE_K,
    rrf_k=DEFAULT_RRF_K,
    metadata_filter=None,
    include_debug=False,
):
    """Run dense (and optionally BM25) search per query and fuse the rankings.

    Returns:
        ``(fused_documents, debug)`` where ``fused_documents`` is a list of
        ``(doc, fusion_score)`` pairs; ``debug`` is only filled in when
        ``include_debug`` is true.
    """
    fused_rankings = []
    debug = {
        "dense_hits": {},
        "bm25_hits": {},
        "dense_total_hits": 0,
        "bm25_total_hits": 0,
    }

    chroma_filter = build_chroma_filter(metadata_filter)
    # Over-fetch when filtering: part of the filter (page_range) is applied
    # after the vector search, which can discard some of the fetched hits.
    dense_fetch_k = candidate_k * 4 if metadata_filter else candidate_k

    for query in queries:
        dense_docs = filter_documents_by_metadata(
            deduplicate_documents(
                vectorstore.similarity_search_with_score(
                    query,
                    k=dense_fetch_k,
                    filter=chroma_filter,
                )
            ),
            metadata_filter,
        )
        dense_docs = dense_docs[:candidate_k]
        fused_rankings.append(dense_docs)
        if include_debug:
            debug["dense_hits"][query] = build_hit_debug(dense_docs)
            debug["dense_total_hits"] += len(dense_docs)

        if bm25_index is not None:
            bm25_docs = retrieve_bm25_documents(
                query,
                bm25_index,
                top_k=bm25_candidate_k,
                metadata_filter=metadata_filter,
            )
            fused_rankings.append(bm25_docs)
            if include_debug:
                debug["bm25_hits"][query] = build_hit_debug(bm25_docs)
                debug["bm25_total_hits"] += len(bm25_docs)

    fused_documents = fuse_ranked_documents(fused_rankings, rrf_k=rrf_k)
    if include_debug:
        debug["fused_hits"] = build_hit_debug(fused_documents)

    return fused_documents, debug


def load_bm25_index(chunk_registry=None, vectorstore=None):
    """Build an in-memory BM25 index over the corpus, or ``None`` if it is empty.

    Returns a dict with ``"documents"``, ``"tokenized_documents"`` and the
    ``"retriever"`` (``BM25Okapi``).
    """
    documents = load_corpus_documents(chunk_registry=chunk_registry, vectorstore=vectorstore)
    tokenized_documents = [tokenize_for_bm25(doc.page_content) for doc in documents]

    if not tokenized_documents:
        return None

    return {
        "documents": documents,
        "tokenized_documents": tokenized_documents,
        "retriever": BM25Okapi(tokenized_documents),
    }


def load_corpus_documents(chunk_registry=None, vectorstore=None):
    """Load corpus documents from the registry, falling back to the vector store."""
    documents = documents_from_registry(chunk_registry or {})
    if documents:
        return documents

    if vectorstore is not None:
        return documents_from_vectorstore(vectorstore)

    return []


def documents_from_registry(chunk_registry):
    """Build documents from the registry's ``by_chunk_id`` map, sorted by chunk id.

    A document whose metadata lacks ``chunk_id`` gets the registry key assigned.
    """
    documents = []
    records = chunk_registry.get("by_chunk_id", {})

    for chunk_id, record in records.items():
        document = _document_from_record(record)
        if document is None:
            continue
        if document.metadata.get("chunk_id") is None:
            document.metadata["chunk_id"] = chunk_id
        documents.append(document)

    documents.sort(key=lambda doc: doc.metadata.get("chunk_id", ""))
    return documents


def documents_from_vectorstore(vectorstore):
    """Read every stored document and its metadata from the vector store's collection."""
    # NOTE(review): relies on the private ``_collection`` attribute of the store.
    collection = vectorstore._collection.get(include=["documents", "metadatas"])
    documents = []

    for page_content, metadata in zip(collection.get("documents", []), collection.get("metadatas", [])):
        documents.append(
            Document(
                page_content=page_content,
                metadata=metadata or {},
            )
        )

    return documents


def tokenize_for_bm25(text):
    """Lowercase the text and split it into ``\\w+`` tokens."""
    return re.findall(r"\w+", text.lower())


def _copy_document(doc):
    """Return a new ``Document`` with the same content and a copied metadata dict."""
    return Document(page_content=doc.page_content, metadata=dict(doc.metadata))


def retrieve_bm25_documents(
    query,
    bm25_index,
    top_k=DEFAULT_BM25_CANDIDATE_K,
    metadata_filter=None,
):
    """Return up to ``top_k`` positively scored BM25 hits that pass the metadata filter.

    The metadata filter is applied while walking the ranking, before truncation,
    so filtered-out chunks do not consume slots of the ``top_k`` budget.
    Returned documents are copies of the indexed ones.
    """
    tokens = tokenize_for_bm25(query)
    if not tokens:
        return []

    scores = bm25_index["retriever"].get_scores(tokens)
    ranked_indexes = sorted(
        range(len(scores)),
        key=lambda index: scores[index],
        reverse=True,
    )

    corpus = bm25_index["documents"]
    documents = []
    for index in ranked_indexes:
        if len(documents) >= top_k:
            break
        score = float(scores[index])
        if score <= 0:
            # Ranking is descending, so nothing after this point can score higher.
            break
        doc = corpus[index]
        if not _matches_metadata_filter(doc, metadata_filter):
            continue
        # Later stages write fusion/rerank scores into metadata; hand out a copy
        # so those per-query values never leak into the shared index.
        documents.append((_copy_document(doc), score))

    return documents


def _matches_metadata_filter(doc, metadata_filter):
    """Return whether the document's metadata satisfies every filter entry.

    ``page_range`` is matched against ``page + 1`` with optional inclusive
    ``start`` / ``end`` bounds; every other key requires equality.
    """
    if not metadata_filter:
        return True

    for key, expected_value in metadata_filter.items():
        if key == "page_range":
            page_value = doc.metadata.get("page")
            if page_value is None:
                return False
            # presumably stored pages are 0-based and the range is 1-based — TODO confirm
            page_number = int(page_value) + 1
            start = expected_value.get("start")
            end = expected_value.get("end")
            if start is not None and page_number < start:
                return False
            if end is not None and page_number > end:
                return False
            continue

        if doc.metadata.get(key) != expected_value:
            return False

    return True


def filter_documents_by_metadata(documents, metadata_filter):
    """Keep only the ``(doc, score)`` pairs whose document passes the metadata filter."""
    if not metadata_filter:
        return documents

    return [
        (doc, score)
        for doc, score in documents
        if _matches_metadata_filter(doc, metadata_filter)
    ]


def build_chroma_filter(metadata_filter):
    """Translate the metadata filter into a Chroma ``where`` clause.

    ``page_range`` is omitted here (it is applied after the search). A single
    condition is returned as ``{key: value}``; several conditions are combined
    under ``"$and"`` because Chroma accepts only one top-level field or
    operator. Returns ``None`` when nothing remains to filter on.
    """
    if not metadata_filter:
        return None

    conditions = [
        {key: value}
        for key, value in metadata_filter.items()
        if key != "page_range"
    ]

    if not conditions:
        return None
    if len(conditions) == 1:
        return conditions[0]
    return {"$and": conditions}


def fuse_ranked_documents(rankings, rrf_k=DEFAULT_RRF_K):
    """Fuse several rankings with reciprocal rank fusion.

    Each appearance at 1-based ``rank`` adds ``1 / (rrf_k + rank)`` to the
    chunk's score. Chunks are keyed by ``chunk_id`` (or stripped content when
    absent). The fused score is also stored in ``metadata["fusion_score"]``.

    Returns:
        ``(doc, fused_score)`` pairs sorted by descending fused score.
    """
    fused_scores = {}
    fused_documents = {}

    for ranking in rankings:
        for rank, (doc, _) in enumerate(ranking, start=1):
            chunk_id = doc.metadata.get("chunk_id") or doc.page_content.strip()
            fused_documents[chunk_id] = doc
            fused_scores[chunk_id] = fused_scores.get(chunk_id, 0.0) + 1.0 / (rrf_k + rank)

    rerank_candidates = []
    for chunk_id, fused_score in sorted(fused_scores.items(), key=lambda item: item[1], reverse=True):
        doc = fused_documents[chunk_id]
        doc.metadata["fusion_score"] = fused_score
        rerank_candidates.append((doc, fused_score))

    return rerank_candidates


def build_hit_debug(documents):
    """Summarize ``(doc, score)`` pairs as small dicts for debug output."""
    return [
        {
            "chunk_id": doc.metadata.get("chunk_id", "unknown"),
            "source": doc.metadata.get("file_name", "Unknown"),
            "page": doc.metadata.get("page"),
            "score": round(float(score), 4),
            "preview": doc.page_content[:180].replace("\n", " "),
        }
        for doc, score in documents
    ]


def load_reranker(model_name):
    """Instantiate a ``CrossEncoder`` reranker for the given model name."""
    return CrossEncoder(model_name)


def rerank_documents(query, documents, reranker, top_n):
    """Score ``(query, content)`` pairs with the reranker and keep the best ``top_n``.

    Stores ``rerank_score`` in each kept document's metadata and, when not
    already present, the incoming score as ``retrieval_score``.

    Returns:
        ``(doc, rerank_score)`` pairs sorted by descending rerank score.
    """
    if not documents:
        return []

    pairs = [(query, doc.page_content) for doc, _ in documents]
    rerank_scores = reranker.predict(pairs)

    reranked = [
        (doc, float(rerank_score), float(original_score))
        for (doc, original_score), rerank_score in zip(documents, rerank_scores)
    ]
    reranked.sort(key=lambda item: item[1], reverse=True)

    top_documents = []
    for doc, rerank_score, original_score in reranked[:top_n]:
        doc.metadata["rerank_score"] = rerank_score
        if doc.metadata.get("retrieval_score") is None:
            doc.metadata["retrieval_score"] = original_score
        top_documents.append((doc, rerank_score))

    return top_documents


def load_chunk_registry(registry_path=REGISTRY_PATH):
    """Load the chunk registry JSON, or an empty registry when the file is missing."""
    registry_file = Path(registry_path)
    if not registry_file.exists():
        return {"by_document": {}, "by_chunk_id": {}}

    return json.loads(registry_file.read_text(encoding="utf-8"))


def _document_from_record(record):
    """Build a ``Document`` from a registry record, or ``None`` for an empty record."""
    if not record:
        return None

    return Document(
        page_content=record.get("page_content") or "",
        # Copy so that window bookkeeping written into the document's metadata
        # can never end up in the loaded registry.
        metadata=dict(record.get("metadata") or {}),
    )


def expand_with_context_window(
    documents,
    chunk_registry,
    window_size=DEFAULT_CONTEXT_WINDOW,
    max_expanded_chunks=DEFAULT_MAX_EXPANDED_CHUNKS,
):
    """Expand each retrieved chunk with neighbors from the same document.

    Returns the input unchanged when ``window_size <= 0``; otherwise at most
    ``max_expanded_chunks`` ``(doc, score)`` pairs, where neighbors carry the
    score of the hit that pulled them in.
    """
    if window_size <= 0:
        return documents

    # The registry gives us O(1)-style neighbor lookup by document and chunk index.
    # Chroma finds the relevant chunk; the registry finds the chunks around it.
    by_document = chunk_registry.get("by_document", {})
    expanded_documents = []
    seen_chunk_ids = set()

    for hit_order, (doc, score) in enumerate(documents):
        hit_group = _expand_single_hit(
            doc=doc,
            score=score,
            hit_order=hit_order,
            by_document=by_document,
            seen_chunk_ids=seen_chunk_ids,
            window_size=window_size,
        )
        expanded_documents.extend(hit_group)

        if len(expanded_documents) >= max_expanded_chunks:
            break

    # Final ordering preserves which retrieval hit came first, then keeps chunks
    # in document order inside that hit's local window.
    expanded_documents.sort(key=_hit_order_key)
    return expanded_documents[:max_expanded_chunks]


def _expand_single_hit(doc, score, hit_order, by_document, seen_chunk_ids, window_size):
    """Return the hit plus its registry neighbors within ``window_size`` chunks.

    Chunks whose id is already in ``seen_chunk_ids`` are skipped; newly added
    ids are recorded there. If the registry cannot supply the hit's own chunk,
    the retrieved document itself is included so the hit is not lost.
    """
    document_id = doc.metadata.get("document_id")
    chunk_index = doc.metadata.get("chunk_index")

    if document_id is None or chunk_index is None:
        return _include_unindexed_hit(doc, score, hit_order, seen_chunk_ids)

    document_chunks = by_document.get(document_id, {})
    center_index = int(chunk_index)
    hit_group = []
    center_resolved = False

    # Build a local window around the matched chunk so the LLM sees nearby
    # supporting text instead of one isolated chunk.
    for neighbor_index in range(center_index - window_size, center_index + window_size + 1):
        neighbor = _load_neighbor_chunk(document_chunks, neighbor_index)
        if neighbor is None:
            continue

        neighbor_chunk_id = neighbor.metadata.get("chunk_id")
        if not neighbor_chunk_id:
            continue

        if neighbor_index == center_index:
            # The registry knows the hit itself (even if an earlier hit's window
            # already emitted it), so no fallback is needed.
            center_resolved = True

        if neighbor_chunk_id in seen_chunk_ids:
            continue

        _mark_window_metadata(
            neighbor=neighbor,
            original_doc=doc,
            hit_order=hit_order,
            center_index=center_index,
            neighbor_index=neighbor_index,
        )
        hit_group.append((neighbor, score))
        seen_chunk_ids.add(neighbor_chunk_id)

    if not center_resolved:
        # Missing or stale registry entry: keep the retrieved chunk itself
        # rather than silently dropping the hit.
        hit_group.extend(_include_unindexed_hit(doc, score, hit_order, seen_chunk_ids))

    # Keep each hit group ordered like the original document: previous, hit, next.
    hit_group.sort(key=_document_position_key)
    return hit_group


def _include_unindexed_hit(doc, score, hit_order, seen_chunk_ids):
    """Return the hit alone (window offset 0), or ``[]`` if it has no chunk id or was seen."""
    chunk_id = doc.metadata.get("chunk_id")
    if not chunk_id or chunk_id in seen_chunk_ids:
        return []

    _mark_window_metadata(
        neighbor=doc,
        original_doc=doc,
        hit_order=hit_order,
        center_index=0,
        neighbor_index=0,
    )
    seen_chunk_ids.add(chunk_id)
    return [(doc, score)]


def _load_neighbor_chunk(document_chunks, neighbor_index):
    """Look up a chunk by its index (stringified key) and build a document from it."""
    neighbor_record = document_chunks.get(str(neighbor_index))
    if neighbor_record is None:
        return None
    return _document_from_record(neighbor_record)


def _mark_window_metadata(neighbor, original_doc, hit_order, center_index, neighbor_index):
    """Record on the neighbor whether it is the hit, its hit order and its window offset."""
    neighbor.metadata["is_retrieved_hit"] = (
        neighbor.metadata.get("chunk_id") == original_doc.metadata.get("chunk_id")
    )
    neighbor.metadata["hit_order"] = hit_order
    neighbor.metadata["window_offset"] = neighbor_index - center_index


def _sortable_chunk_index(doc):
    """Return the document's chunk index as an int, or ``-1`` when missing or ``None``."""
    chunk_index = doc.metadata.get("chunk_index")
    return -1 if chunk_index is None else int(chunk_index)


def _document_position_key(item):
    """Sort key: (document id, chunk index)."""
    doc, _ = item
    return (
        doc.metadata.get("document_id", ""),
        _sortable_chunk_index(doc),
    )


def _hit_order_key(item):
    """Sort key: (hit order, document id, chunk index)."""
    doc, _ = item
    return (
        int(doc.metadata.get("hit_order", 9999)),
        doc.metadata.get("document_id", ""),
        _sortable_chunk_index(doc),
    )


def format_context(documents):
    """Format retrieved chunks into one context string with source numbers.

    Sources are numbered from 1 in input order; the ``Score`` line shows the
    metadata ``rerank_score`` when present, otherwise the tuple score.
    """
    parts = []

    for index, (doc, score) in enumerate(documents, start=1):
        file_name = doc.metadata.get("file_name", "unknown")
        page = doc.metadata.get("page")
        chunk_id = doc.metadata.get("chunk_id", "unknown")
        document_id = doc.metadata.get("document_id", "unknown")
        page_label = page if page is not None else "n/a"
        is_retrieved_hit = doc.metadata.get("is_retrieved_hit", False)
        window_offset = doc.metadata.get("window_offset", 0)
        rerank_score = doc.metadata.get("rerank_score")
        score_label = rerank_score if rerank_score is not None else score

        parts.append(
            (
                f"Source [{index}]\n"
                f"File: {file_name}\n"
                f"Document ID: {document_id}\n"
                f"Chunk ID: {chunk_id}\n"
                f"Page: {page_label}\n"
                f"Matched Hit: {'yes' if is_retrieved_hit else 'no'}\n"
                f"Window Offset: {window_offset}\n"
                f"Score: {score_label}\n"
                f"Content: {doc.page_content}"
            )
        )

    return "\n\n".join(parts)


def build_citation_sources(documents):
    """Build one citation record per ``(doc, score)`` pair, numbered from 1."""
    return [
        {
            "number": index,
            "source": doc.metadata.get("file_name", "Unknown"),
            "page": doc.metadata.get("page"),
            "chunk_id": doc.metadata.get("chunk_id", "unknown"),
            "retrieval_score": doc.metadata.get("retrieval_score", score),
            "rerank_score": doc.metadata.get("rerank_score", score),
            "content": doc.page_content,
        }
        for index, (doc, score) in enumerate(documents, start=1)
    ]


def check_grounding_evidence(
    retrieved_documents,
    expanded_documents,
    min_rerank_score=DEFAULT_MIN_GROUNDED_RERANK_SCORE,
    min_expanded_chunks=DEFAULT_MIN_GROUNDED_CHUNKS,
):
    """Decide whether retrieval produced enough evidence to answer.

    Fails with reason ``no_retrieved_chunks``, ``too_few_expanded_chunks`` or
    ``low_rerank_score`` (also used when the first hit has no rerank score);
    otherwise passes. The report always includes the counts and top score.
    """
    if not retrieved_documents:
        return {
            "passed": False,
            "reason": "no_retrieved_chunks",
            "top_rerank_score": None,
            "retrieved_count": 0,
            "expanded_count": len(expanded_documents),
        }

    top_rerank_score = _top_rerank_score(retrieved_documents)

    if len(expanded_documents) < min_expanded_chunks:
        reason = "too_few_expanded_chunks"
    elif top_rerank_score is None or top_rerank_score < min_rerank_score:
        reason = "low_rerank_score"
    else:
        reason = "passed"

    return {
        "passed": reason == "passed",
        "reason": reason,
        "top_rerank_score": top_rerank_score,
        "retrieved_count": len(retrieved_documents),
        "expanded_count": len(expanded_documents),
    }


def _cited_or_repaired(query, context, answer, llm, valid_sources):
    """Return the answer if its citations are valid, else the repair attempt's result."""
    if answer_has_valid_citations(answer, valid_sources):
        return answer

    return repair_answer_with_citations(
        query=query,
        context=context,
        draft_answer=answer,
        llm=llm,
        valid_sources=valid_sources,
    )


def generate_answer(query, context, llm):
    """Generate an answer from the retrieved context.

    Tries structured output first, then a plain invocation. An answer whose
    citations are missing or invalid goes through one repair attempt; if the
    plain answer cannot be repaired it is returned as is.
    """
    if not context.strip():
        return "No relevant information found in the documents."

    if llm is None:
        return "Language model is not available."

    valid_sources = extract_valid_source_numbers(context)
    structured_llm = get_structured_answer_llm(llm)

    if structured_llm is not None:
        try:
            structured_response = structured_llm.invoke(build_rag_prompt(context, query))
        except Exception:
            # Best effort: fall through to the plain invocation below.
            structured_response = None

        structured_answer = extract_structured_answer(structured_response)
        if structured_answer:
            final_answer = _cited_or_repaired(query, context, structured_answer, llm, valid_sources)
            if final_answer:
                return final_answer

    response = llm.invoke(build_rag_prompt(context, query))
    answer = extract_response_text(response)

    if not answer:
        return "No response from language model."

    return _cited_or_repaired(query, context, answer, llm, valid_sources) or answer


def repair_answer_with_citations(query, context, draft_answer, llm, valid_sources):
    """Ask the model to fix the draft's citations.

    Returns the draft unchanged when there are no valid sources, the repaired
    answer when its citations validate, and ``None`` otherwise.
    """
    if not valid_sources:
        return draft_answer

    repair_prompt = build_citation_repair_prompt(
        context=context,
        question=query,
        answer=draft_answer,
        valid_sources=valid_sources,
    )
    repaired_response = llm.invoke(repair_prompt)
    repaired_answer = extract_response_text(repaired_response)

    if not repaired_answer:
        return None

    if answer_has_valid_citations(repaired_answer, valid_sources):
        return repaired_answer

    return None


def get_structured_answer_llm(llm):
    """Return a structured-output wrapper producing ``AnswerDraftResult``, or ``None``."""
    with_structured_output = getattr(llm, "with_structured_output", None)
    if with_structured_output is None:
        return None

    try:
        return with_structured_output(AnswerDraftResult)
    except Exception:
        return None


def extract_structured_answer(response):
    """Pull the stripped answer text out of a structured response.

    Handles ``AnswerDraftResult``, dicts with ``"answer"`` and objects with an
    ``answer`` attribute; anything else goes through ``extract_response_text``.
    """
    if response is None:
        return ""

    if isinstance(response, AnswerDraftResult):
        return response.answer.strip()

    if isinstance(response, dict):
        return str(response.get("answer", "")).strip()

    response_answer = getattr(response, "answer", None)
    if response_answer is not None:
        return str(response_answer).strip()

    return extract_response_text(response)


def extract_valid_source_numbers(context):
    """Return the sorted, unique source numbers found as ``Source [N]`` in the context."""
    numbers = re.findall(r"Source \[(\d+)\]", context)
    return sorted({int(number) for number in numbers})


def extract_citation_numbers(answer):
    """Return every ``[N]`` citation number in the answer, in order of appearance."""
    numbers = re.findall(r"\[(\d+)\]", answer)
    return [int(number) for number in numbers]


def answer_has_valid_citations(answer, valid_sources):
    """Return whether the answer cites at least one source and only valid ones."""
    citation_numbers = extract_citation_numbers(answer)
    if not citation_numbers:
        return False

    valid_source_set = set(valid_sources)
    return all(number in valid_source_set for number in citation_numbers)


def _top_rerank_score(documents):
    """Return the first document's ``rerank_score`` as a float, or ``None`` if absent."""
    if not documents:
        return None

    doc, _ = documents[0]
    rerank_score = doc.metadata.get("rerank_score")
    return None if rerank_score is None else float(rerank_score)