| from typing import Any, TypedDict |
|
|
| from agents.services import ( |
| AnswerResult, |
| answer_is_grounded, |
| build_citations, |
| build_metadata_filter, |
| build_sources, |
| ) |
| from app.actions import ingest_saved_document |
| from agents.research_graph import run_research_agent |
| from rag.registry import load_chunk_registry |
| from rag.retrieve import ( |
| check_grounding_evidence, |
| expand_with_context_window, |
| format_context, |
| generate_answer, |
| generate_sub_queries, |
| load_bm25_index, |
| load_vectorstore, |
| retrieve_documents_with_query_transform, |
| ) |
| from core.config import DB_DIR, URL_IMPORT_DIR |
| from rag.ingest import add_documents_to_vectorstore |
| from rag.web_ingest import build_research_import_payload, save_url_import |
|
|
| try: |
| from langgraph.graph import END, StateGraph |
| except ImportError: |
| END = None |
| StateGraph = None |
|
|
|
|
| class RagGraphState(TypedDict, total=False): |
| query: str |
| selected_file: str |
| selected_file_type: str |
| page_start: str |
| page_end: str |
| debug_mode: bool |
| metadata_filter: dict[str, Any] | None |
| retrieved_documents: list[Any] |
| expanded_documents: list[Any] |
| grounding: dict[str, Any] | None |
| next_action: str | None |
| decision_reason: str | None |
| retry_count: int |
| refined_query: str | None |
| web_research_attempted: bool |
| web_research_performed: bool |
| refreshed_vectorstore: Any |
| refreshed_chunk_registry: dict[str, Any] | None |
| refreshed_bm25_index: Any |
| answer: str | None |
| citations: list[dict] |
| sources: list[dict] |
| debug_data: dict[str, Any] | None |
|
|
|
|
def build_initial_graph_state(
    query,
    *,
    selected_file="",
    selected_file_type="",
    page_start="",
    page_end="",
    debug_mode=False,
):
    """Return the starting state dictionary for one graph run.

    The caller-supplied request fields are copied in as given; every
    pipeline-managed field starts at its empty value (``None``, ``0``,
    ``False`` or a fresh empty list).
    """
    # Fields that come straight from the caller.
    initial_state = dict(
        query=query,
        selected_file=selected_file,
        selected_file_type=selected_file_type,
        page_start=page_start,
        page_end=page_end,
        debug_mode=debug_mode,
    )
    # Fields the graph nodes fill in as the run progresses.
    initial_state.update(
        metadata_filter=None,
        retrieved_documents=[],
        expanded_documents=[],
        grounding=None,
        next_action=None,
        decision_reason=None,
        retry_count=0,
        refined_query=None,
        web_research_attempted=False,
        web_research_performed=False,
        refreshed_vectorstore=None,
        refreshed_chunk_registry=None,
        refreshed_bm25_index=None,
        answer=None,
        citations=[],
        sources=[],
        debug_data=None,
    )
    return initial_state
|
|
|
|
def retrieve_node(
    state: RagGraphState,
    *,
    vectorstore,
    reranker,
    bm25_index,
    llm,
    retrieval_k,
    rerank_candidate_k,
    bm25_candidate_k,
    enable_query_transform,
) -> RagGraphState:
    """Retrieve documents for the current query and record them on the state.

    The refreshed vector store / BM25 index from the state are used when
    present, otherwise the ones passed in. A refined query, when present,
    replaces the original query and switches query transformation off for
    this pass.

    Returns a copy of ``state`` with ``retrieved_documents`` and
    ``debug_data`` replaced (``debug_data`` is ``None`` when debug mode is off).
    """
    refined_query = state.get("refined_query")
    debug_enabled = state.get("debug_mode", False)
    active_filter = state.get("metadata_filter")

    store = state.get("refreshed_vectorstore") or vectorstore
    keyword_index = state.get("refreshed_bm25_index") or bm25_index
    search_query = refined_query or state["query"]
    transform_this_pass = enable_query_transform and not refined_query

    print(f"\n[RAG Agent] Node: retrieve | Query: '{search_query}'")
    print(
        f"[RAG Agent] Retrieve Context | Retry Count: {state.get('retry_count', 0)}"
        f" | Has Refined Query: {bool(refined_query)}"
        f" | Query Transform Enabled: {transform_this_pass}"
    )

    next_state = dict(state)
    print("[RAG Agent] Retrieve Step: calling retrieve_documents_with_query_transform...")
    outcome = retrieve_documents_with_query_transform(
        store,
        search_query,
        k=retrieval_k,
        reranker=reranker,
        bm25_index=keyword_index,
        query_transformer=llm,
        enable_query_transform=transform_this_pass,
        candidate_k=rerank_candidate_k,
        bm25_candidate_k=bm25_candidate_k,
        metadata_filter=active_filter,
        include_debug=debug_enabled,
    )
    print("[RAG Agent] Retrieve Step: retrieve_documents_with_query_transform completed.")

    if not debug_enabled:
        hits = outcome
        debug_payload = None
    else:
        # With include_debug on, the retriever hands back (documents, debug info).
        hits, retrieval_debug = outcome
        debug_payload = dict(state.get("debug_data") or {})
        debug_payload.update(retrieval_debug)
        debug_payload["metadata_filter"] = active_filter
        # Reset the grounding record; it has not been evaluated for this pass yet.
        debug_payload["grounding"] = {
            "stage": "retrieval",
            "passed": None,
            "reason": "not_checked",
        }

    print(f"[RAG Agent] Retrieved {len(hits)} documents.")
    next_state["retrieved_documents"] = hits
    next_state["debug_data"] = debug_payload
    return next_state
|
|
|
|
def expand_context_node(
    state: RagGraphState,
    *,
    chunk_registry,
    context_window,
    max_expanded_chunks,
) -> RagGraphState:
    """Expand the retrieved documents with neighbouring chunks.

    The refreshed chunk registry from the state is used when present,
    otherwise ``chunk_registry``.

    Returns a copy of ``state`` with ``expanded_documents`` set. In debug
    mode the expanded hits and their count are also recorded in
    ``debug_data``.
    """
    print(f"[RAG Agent] Node: expand_context | Window Size: {context_window}")
    updated_state = dict(state)
    active_chunk_registry = state.get("refreshed_chunk_registry") or chunk_registry
    expanded_documents = expand_with_context_window(
        state.get("retrieved_documents", []),
        active_chunk_registry,
        window_size=context_window,
        max_expanded_chunks=max_expanded_chunks,
    )
    updated_state["expanded_documents"] = expanded_documents

    debug_data = updated_state.get("debug_data")
    if state.get("debug_mode", False) and debug_data is not None:
        debug_data["expanded_hits"] = expanded_documents
        # The retrieval debug payload is produced elsewhere; do not let a
        # missing "stage_counts" entry turn debug bookkeeping into a crash.
        if debug_data.get("stage_counts") is None:
            debug_data["stage_counts"] = {}
        debug_data["stage_counts"]["expanded_context"] = len(expanded_documents)

    return updated_state
|
|
|
|
def check_grounding_node(
    state: RagGraphState,
    *,
    min_grounded_rerank_score,
    min_grounded_chunks,
) -> RagGraphState:
    """Evaluate whether the retrieved evidence is strong enough to answer.

    Passes the retrieved and expanded documents to
    ``check_grounding_evidence`` together with the two thresholds and
    stores the result under ``grounding`` on a copy of ``state``. In debug
    mode the result is also merged into ``debug_data["grounding"]``.
    """
    print("[RAG Agent] Node: check_grounding | Evaluating evidence...")
    next_state = dict(state)

    retrieved = state.get("retrieved_documents", [])
    expanded = state.get("expanded_documents", [])
    verdict = check_grounding_evidence(
        retrieved,
        expanded,
        min_rerank_score=min_grounded_rerank_score,
        min_expanded_chunks=min_grounded_chunks,
    )
    next_state["grounding"] = verdict

    debug_payload = next_state.get("debug_data") if state.get("debug_mode", False) else None
    if debug_payload is not None:
        grounding_debug = debug_payload["grounding"]
        grounding_debug.update(verdict)
        # Keep the stage label even if the result carried its own "stage" key.
        grounding_debug["stage"] = "retrieval"

    outcome_label = "PASS" if verdict["passed"] else "FAIL"
    top_score = verdict.get("top_rerank_score")
    print(f"[RAG Agent] Grounding Result: {outcome_label} (Score: {top_score})")
    return next_state
|
|
|
|
def route_after_grounding(state: RagGraphState) -> str:
    """Map the decided ``next_action`` to the name of the next graph node.

    Any action other than ``answer``, ``retry_retrieval`` or
    ``web_research`` (including a missing one) routes to
    ``fallback_answer``.
    """
    node_for_action = {
        "answer": "generate_answer",
        "retry_retrieval": "refine_query",
        "web_research": "web_research",
    }
    chosen_action = state.get("next_action")
    print(f"[RAG Agent] Routing Decision: {chosen_action}")
    return node_for_action.get(chosen_action, "fallback_answer")
|
|
|
|
def decide_local_action(grounding, *, has_any_local_evidence, retry_count):
    """Pick the next step that can be resolved with local documents only.

    Returns an ``(action, reason)`` pair:

    * ``("answer", "grounding_passed")`` when the grounding check passed;
    * ``("retry_retrieval", "partial_local_evidence_retry_<n>")`` when there
      is some local evidence and fewer than two retries were used;
    * ``(None, None)`` when no local step applies.
    """
    if grounding.get("passed"):
        return "answer", "grounding_passed"

    retry_allowed = has_any_local_evidence and retry_count < 2
    if not retry_allowed:
        return None, None

    attempt_number = retry_count + 1
    return "retry_retrieval", f"partial_local_evidence_retry_{attempt_number}"
|
|
|
|
def decide_terminal_action(*, enable_research, web_research_attempted, web_research_performed, retry_count):
    """Pick the step to take once local retrieval can no longer help.

    Returns an ``(action, reason)`` pair. The action is ``"web_research"``
    when research is enabled and has not been attempted yet; otherwise it
    is ``"fallback"`` with a reason describing why nothing else is left.
    """
    if enable_research and not web_research_attempted:
        return "web_research", "local_failed_switching_to_web"

    # Everything below ends in the fallback answer; only the reason differs.
    if web_research_attempted and not web_research_performed:
        reason = "web_research_failed_no_documents_ingested"
    elif web_research_performed:
        reason = "web_research_also_failed"
    elif retry_count >= 2:
        reason = "local_retry_cap_reached_research_disabled"
    else:
        reason = "no_evidence_found_research_disabled"
    return "fallback", reason
|
|
|
|
def decide_next_action_node(state: RagGraphState, *, enable_research: bool = False) -> RagGraphState:
    """Decide what the graph does after the grounding check.

    A local decision (answer or retry retrieval) is tried first via
    ``decide_local_action``; when it yields no action,
    ``decide_terminal_action`` chooses between web research and the
    fallback answer.

    Returns a copy of ``state`` with ``next_action`` and
    ``decision_reason`` set; in debug mode both, plus the research flag,
    are mirrored into ``debug_data``.
    """
    next_state = dict(state)
    grounding_result = state.get("grounding") or {}
    retrieved = state.get("retrieved_documents", [])
    expanded = state.get("expanded_documents", [])

    retries_used = state.get("retry_count", 0)
    research_attempted = state.get("web_research_attempted", False)
    research_performed = state.get("web_research_performed", False)
    evidence_present = len(retrieved) > 0 or len(expanded) > 0
    print(
        f"[RAG Agent] Decision State | retry_count={retries_used} | "
        f"web_research_attempted={research_attempted} | "
        f"web_research_performed={research_performed} | "
        f"has_any_local_evidence={evidence_present} | "
        f"enable_research={enable_research}"
    )

    action, reason = decide_local_action(
        grounding_result,
        has_any_local_evidence=evidence_present,
        retry_count=retries_used,
    )
    if action is None:
        # Nothing more can be done locally: research or give up.
        action, reason = decide_terminal_action(
            enable_research=enable_research,
            web_research_attempted=research_attempted,
            web_research_performed=research_performed,
            retry_count=retries_used,
        )

    print(f"[RAG Agent] Decision: {action} | Reason: {reason}")
    next_state["next_action"] = action
    next_state["decision_reason"] = reason

    debug_payload = next_state.get("debug_data") if state.get("debug_mode", False) else None
    if debug_payload is not None:
        debug_payload["next_action"] = action
        debug_payload["decision_reason"] = reason
        debug_payload["enable_research_flag"] = enable_research

    return next_state
|
|
|
|
def refine_query_node(state: RagGraphState, *, llm) -> RagGraphState:
    """Produce a refined query for the next retrieval attempt.

    Asks ``generate_sub_queries`` for a single rewrite of the original
    query. When no usable rewrite comes back, the original query is kept.

    Returns a copy of ``state`` with ``refined_query`` set and
    ``retry_count`` incremented by one; in debug mode both values are
    mirrored into ``debug_data``.
    """
    attempt_number = state.get("retry_count", 0) + 1
    print(f"[RAG Agent] Node: refine_query | Attempting to improve query (Retry {attempt_number})")
    next_state = dict(state)

    original_query = state["query"]
    print(f"[RAG Agent] Refine Step: original query='{original_query}'")
    print(f"[RAG Agent] Refine Step: decision reason='{state.get('decision_reason', 'unknown')}'")
    print("[RAG Agent] Refine Step: calling shared query transform...")
    candidates = generate_sub_queries(
        original_query,
        llm,
        max_queries=1,
    )
    print(f"[RAG Agent] Refine Step: shared query transform returned {len(candidates)} candidate(s): {candidates}")

    # Use the first candidate only when one exists and is non-empty.
    if candidates and candidates[0]:
        new_query = candidates[0]
    else:
        new_query = original_query
    next_state["refined_query"] = new_query
    next_state["retry_count"] = attempt_number

    print(f"[RAG Agent] Refined Query: '{new_query}'")
    debug_payload = next_state.get("debug_data") if state.get("debug_mode") else None
    if debug_payload is not None:
        debug_payload["refined_query"] = new_query
        debug_payload["retry_count"] = attempt_number

    return next_state
|
|
|
|
# Research sources whose text is shorter than this are not worth importing.
_MIN_RESEARCH_CONTENT_CHARS = 50


def _save_research_sources(sources):
    """Save usable research sources as URL imports; return (saved_paths, skipped_count)."""
    saved_paths = []
    skipped_results = 0
    for rs in sources:
        content = rs.get("content") or rs.get("snippet") or ""
        if len(content) < _MIN_RESEARCH_CONTENT_CHARS:
            print(f"[RAG Agent] Web Research Skip: content too short for '{rs.get('title', 'unknown')}'")
            skipped_results += 1
            continue

        payload = build_research_import_payload(rs)
        try:
            saved_path = save_url_import(payload, URL_IMPORT_DIR)
            print(f"[RAG Agent] Web Research Saved Import: {saved_path.name}")
            saved_paths.append(saved_path)
        except Exception as exc:
            # Best effort: one source failing to save must not abort the rest.
            print(f"[RAG Agent] Web Research Save FAILED for '{rs.get('url', '')}': {exc}")
    return saved_paths, skipped_results


def _ingest_research_imports(saved_paths, embedding_function):
    """Ingest each saved import into the vector store; return how many succeeded."""
    ingested_paths = 0
    for saved_path in saved_paths:
        try:
            ingest_saved_document(
                saved_path,
                add_documents_to_vectorstore=add_documents_to_vectorstore,
                embeddings=embedding_function,
            )
            ingested_paths += 1
            print(f"[RAG Agent] Web Research Ingested: {saved_path.name}")
        except Exception as exc:
            # Best effort: keep going so the other documents still get ingested.
            print(f"[RAG Agent] Web Research Ingest FAILED for '{saved_path.name}': {exc}")
    return ingested_paths


def web_research_node(
    state: RagGraphState,
    *,
    llm,
    vectorstore,
) -> RagGraphState:
    """Run the research agent and ingest what it finds into the local store.

    Steps: run ``run_research_agent`` on the original query, save every
    source with enough text as a URL import, ingest the saved imports, and
    - when at least one document was ingested - reload the vector store,
    chunk registry and BM25 index so the next retrieval pass sees them.

    Returns a copy of ``state`` with ``web_research_attempted`` set to
    ``True``, ``web_research_performed`` set to whether anything was
    ingested, ``retry_count`` reset to ``0`` and, after a successful
    ingest, the three ``refreshed_*`` handles. In debug mode the counters
    are mirrored into ``debug_data``.
    """
    print(f"[RAG Agent] Node: web_research | Handing off to Research Agent for topic: '{state['query']}'")
    updated_state = dict(state)

    research_result = run_research_agent(
        topic=state["query"],
        llm=llm,
        debug_mode=state.get("debug_mode", False),
    )

    # A research run with no sources may report None rather than an empty list.
    saved_paths, skipped_results = _save_research_sources(research_result.sources or [])

    # Resolved once up front: it does not change per file and is needed again
    # below when the vector store is reloaded.
    embedding_function = getattr(vectorstore, "embeddings", None) or getattr(vectorstore, "_embedding_function", None)
    ingested_paths = _ingest_research_imports(saved_paths, embedding_function)

    if ingested_paths:
        refreshed_vectorstore = load_vectorstore(str(DB_DIR), embedding_function)
        refreshed_chunk_registry = load_chunk_registry()
        refreshed_bm25_index = load_bm25_index(
            chunk_registry=refreshed_chunk_registry,
            vectorstore=refreshed_vectorstore,
        )
        updated_state["refreshed_vectorstore"] = refreshed_vectorstore
        updated_state["refreshed_chunk_registry"] = refreshed_chunk_registry
        updated_state["refreshed_bm25_index"] = refreshed_bm25_index
        print(
            f"[RAG Agent] Web Research Refresh Complete | "
            f"Saved: {len(saved_paths)} | Ingested: {ingested_paths} | Skipped: {skipped_results}"
        )
    else:
        print(
            f"[RAG Agent] Web Research Produced No Ingested Documents | "
            f"Saved: {len(saved_paths)} | Ingested: 0 | Skipped: {skipped_results}"
        )

    updated_state["web_research_attempted"] = True
    updated_state["web_research_performed"] = ingested_paths > 0
    # The graph routes back to retrieval next; give it a fresh retry budget.
    updated_state["retry_count"] = 0

    debug_data = updated_state.get("debug_data")
    if state.get("debug_mode", False) and debug_data is not None:
        debug_data["web_research_attempted"] = True
        debug_data["web_research_performed"] = updated_state["web_research_performed"]
        debug_data["web_docs_saved"] = len(saved_paths)
        debug_data["web_docs_ingested"] = ingested_paths
        debug_data["web_docs_skipped"] = skipped_results

    print(
        f"[RAG Agent] Web Research Complete | "
        f"web_research_attempted={updated_state['web_research_attempted']} | "
        f"web_research_performed={updated_state['web_research_performed']} | "
        f"returning_to=retrieve"
    )

    return updated_state
|
|
|
|
def fallback_answer_node(
    state: RagGraphState,
    *,
    grounded_fallback_message,
) -> RagGraphState:
    """Finish the run with the fallback message instead of a generated answer.

    Returns a copy of ``state`` whose ``answer`` is
    ``grounded_fallback_message``, whose ``citations`` are empty and whose
    ``sources`` are built from the retrieved documents. In debug mode the
    taken path is recorded in ``debug_data["graph_path"]``.
    """
    print("[RAG Agent] Node: fallback_answer | Providing final fallback response.")
    retrieved = state.get("retrieved_documents", [])
    result_state = {
        **state,
        "sources": build_sources(retrieved),
        "answer": grounded_fallback_message,
        "citations": [],
    }

    debug_payload = result_state.get("debug_data") if state.get("debug_mode", False) else None
    if debug_payload is not None:
        debug_payload["graph_path"] = "fallback_answer"
    return result_state
|
|
|
|
def generate_answer_node(
    state: RagGraphState,
    *,
    llm,
    grounded_fallback_message,
) -> RagGraphState:
    """Generate an answer from the expanded context and validate it.

    The answer is produced by ``generate_answer`` from the formatted
    expanded documents and then checked with ``answer_is_grounded``. A
    failed check replaces the answer with ``grounded_fallback_message`` and
    clears the citations; a passed check keeps the answer and builds its
    citations.

    Returns a copy of ``state`` with ``sources``, ``answer`` and
    ``citations`` set. In debug mode the taken path and the outcome of the
    answer-stage check are recorded in ``debug_data``.
    """
    print("[RAG Agent] Node: generate_answer | Creating grounded response...")
    result_state = dict(state)
    expanded = state.get("expanded_documents", [])
    result_state["sources"] = build_sources(state.get("retrieved_documents", []))

    debug_payload = result_state.get("debug_data") if state.get("debug_mode", False) else None
    if debug_payload is not None:
        debug_payload["graph_path"] = "generate_answer"

    context = format_context(expanded)
    answer = generate_answer(state["query"], context, llm)
    retrieval_grounding = state.get("grounding") or {}

    def record_answer_check(passed, reason):
        # Overwrite the retrieval-stage grounding record with the answer-stage one.
        if debug_payload is None:
            return
        debug_payload["grounding"] = {
            "stage": "answer",
            "passed": passed,
            "reason": reason,
            "top_rerank_score": retrieval_grounding.get("top_rerank_score"),
            "retrieved_count": retrieval_grounding.get("retrieved_count"),
            "expanded_count": retrieval_grounding.get("expanded_count"),
        }

    if answer_is_grounded(answer, context):
        print("[RAG Agent] Final Answer Validation PASSED.")
        record_answer_check(True, "answer_is_grounded")
        result_state["answer"] = answer
        result_state["citations"] = build_citations(expanded, answer)
    else:
        print("[RAG Agent] Final Answer Validation FAILED. Falling back.")
        record_answer_check(False, "citation_validation_failed")
        result_state["answer"] = grounded_fallback_message
        result_state["citations"] = []
    return result_state
|
|
|
|
def prepare_input_node(state: RagGraphState) -> RagGraphState:
    """Start a graph session by deriving the metadata filter from the request.

    Returns a copy of ``state`` with ``metadata_filter`` set to the result
    of ``build_metadata_filter`` for the selected file, file type and page
    range (each defaulting to an empty string when absent).
    """
    print(f"\n--- Starting RAG Graph Session: '{state['query'][:50]}...' ---")
    filter_fields = ("selected_file", "selected_file_type", "page_start", "page_end")
    filter_arguments = {field: state.get(field, "") for field in filter_fields}

    next_state = dict(state)
    next_state["metadata_filter"] = build_metadata_filter(**filter_arguments)
    return next_state
|
|
|
|
def finalize_response_node(state: RagGraphState) -> RagGraphState:
    """Log the end of the session and return an unchanged copy of ``state``."""
    final_state = dict(state)
    print("--- RAG Graph Session Complete ---\n")
    return final_state
|
|
|
|
def build_rag_graph(
    *,
    vectorstore=None,
    chunk_registry=None,
    reranker=None,
    bm25_index=None,
    llm=None,
    retrieval_k=4,
    rerank_candidate_k=8,
    bm25_candidate_k=8,
    context_window=1,
    max_expanded_chunks=12,
    min_grounded_rerank_score=1.0,
    min_grounded_chunks=1,
    grounded_fallback_message="I don't have enough support in the current documents to answer that confidently.",
    enable_query_transform=True,
    enable_research: bool = False,
):
    """Build and compile the RAG state graph.

    Flow: ``prepare_input -> retrieve -> expand_context -> check_grounding
    -> decide_next_action``, then a conditional branch to
    ``generate_answer``, ``refine_query``, ``web_research`` or
    ``fallback_answer``. ``refine_query`` and ``web_research`` loop back to
    ``retrieve``; both answer nodes lead to ``finalize_response`` and then
    to ``END``.

    All keyword arguments are bound into the node functions of this module.

    Raises:
        RuntimeError: if ``langgraph`` could not be imported.
    """
    if StateGraph is None or END is None:
        raise RuntimeError("langgraph is not installed. Install requirements before using the graph.")

    # Bind the build-time configuration into single-argument node callables.
    def run_retrieve(state):
        return retrieve_node(
            state,
            vectorstore=vectorstore,
            reranker=reranker,
            bm25_index=bm25_index,
            llm=llm,
            retrieval_k=retrieval_k,
            rerank_candidate_k=rerank_candidate_k,
            bm25_candidate_k=bm25_candidate_k,
            enable_query_transform=enable_query_transform,
        )

    def run_expand_context(state):
        return expand_context_node(
            state,
            chunk_registry=chunk_registry,
            context_window=context_window,
            max_expanded_chunks=max_expanded_chunks,
        )

    def run_check_grounding(state):
        return check_grounding_node(
            state,
            min_grounded_rerank_score=min_grounded_rerank_score,
            min_grounded_chunks=min_grounded_chunks,
        )

    def run_decide_next_action(state):
        return decide_next_action_node(state, enable_research=enable_research)

    def run_refine_query(state):
        return refine_query_node(state, llm=llm)

    def run_web_research(state):
        return web_research_node(state, llm=llm, vectorstore=vectorstore)

    def run_generate_answer(state):
        return generate_answer_node(
            state,
            llm=llm,
            grounded_fallback_message=grounded_fallback_message,
        )

    def run_fallback_answer(state):
        return fallback_answer_node(
            state,
            grounded_fallback_message=grounded_fallback_message,
        )

    workflow = StateGraph(RagGraphState)

    node_table = (
        ("prepare_input", prepare_input_node),
        ("retrieve", run_retrieve),
        ("expand_context", run_expand_context),
        ("check_grounding", run_check_grounding),
        ("decide_next_action", run_decide_next_action),
        ("refine_query", run_refine_query),
        ("web_research", run_web_research),
        ("generate_answer", run_generate_answer),
        ("fallback_answer", run_fallback_answer),
        ("finalize_response", finalize_response_node),
    )
    for node_name, handler in node_table:
        workflow.add_node(node_name, handler)

    workflow.set_entry_point("prepare_input")

    # Straight-line part of the pipeline, up to the routing decision.
    linear_edges = (
        ("prepare_input", "retrieve"),
        ("retrieve", "expand_context"),
        ("expand_context", "check_grounding"),
        ("check_grounding", "decide_next_action"),
    )
    for source, target in linear_edges:
        workflow.add_edge(source, target)

    # route_after_grounding returns the target node name itself.
    branch_targets = ("generate_answer", "refine_query", "web_research", "fallback_answer")
    workflow.add_conditional_edges(
        "decide_next_action",
        route_after_grounding,
        {target: target for target in branch_targets},
    )

    # Loops back into retrieval, then the two ways of finishing.
    closing_edges = (
        ("refine_query", "retrieve"),
        ("web_research", "retrieve"),
        ("generate_answer", "finalize_response"),
        ("fallback_answer", "finalize_response"),
        ("finalize_response", END),
    )
    for source, target in closing_edges:
        workflow.add_edge(source, target)

    return workflow.compile()
|
|
|
|
def run_rag_graph_answer(
    query,
    *,
    vectorstore,
    chunk_registry,
    reranker,
    bm25_index,
    llm,
    retrieval_k,
    rerank_candidate_k,
    bm25_candidate_k,
    context_window,
    max_expanded_chunks,
    min_grounded_rerank_score,
    min_grounded_chunks,
    grounded_fallback_message,
    enable_query_transform,
    selected_file="",
    selected_file_type="",
    page_start="",
    page_end="",
    debug_mode=False,
    enable_research: bool = False,
):
    """Answer ``query`` by running the RAG graph end to end.

    Builds the graph with ``build_rag_graph``, seeds it with
    ``build_initial_graph_state`` and invokes it once.

    Returns:
        An ``AnswerResult`` carrying the final answer, sources, citations
        and debug data. When the final state has no answer,
        ``grounded_fallback_message`` is returned instead. In debug mode
        ``debug_data["pipeline_mode"]`` is set to ``"langgraph_rag"``.
    """
    # Worst case is 32 node executions: prepare_input, three retrieval cycles
    # of four nodes with two refine_query hops, web_research, three more
    # cycles with two hops, then fallback_answer and finalize_response.
    # LangGraph's documented default limit of 25 would abort that run with
    # GraphRecursionError before the fallback answer, so raise it with headroom.
    recursion_limit = 64

    graph = build_rag_graph(
        vectorstore=vectorstore,
        chunk_registry=chunk_registry,
        reranker=reranker,
        bm25_index=bm25_index,
        llm=llm,
        retrieval_k=retrieval_k,
        rerank_candidate_k=rerank_candidate_k,
        bm25_candidate_k=bm25_candidate_k,
        context_window=context_window,
        max_expanded_chunks=max_expanded_chunks,
        min_grounded_rerank_score=min_grounded_rerank_score,
        min_grounded_chunks=min_grounded_chunks,
        grounded_fallback_message=grounded_fallback_message,
        enable_query_transform=enable_query_transform,
        enable_research=enable_research,
    )
    initial_state = build_initial_graph_state(
        query,
        selected_file=selected_file,
        selected_file_type=selected_file_type,
        page_start=page_start,
        page_end=page_end,
        debug_mode=debug_mode,
    )
    final_state = graph.invoke(initial_state, config={"recursion_limit": recursion_limit})

    debug_data = final_state.get("debug_data")
    if debug_mode and debug_data is not None:
        debug_data["pipeline_mode"] = "langgraph_rag"

    # The initial state pre-populates "answer" with None, so a plain
    # .get(..., default) would never fall back; check for None explicitly.
    answer = final_state.get("answer")
    if answer is None:
        answer = grounded_fallback_message

    return AnswerResult(
        answer=answer,
        sources=final_state.get("sources") or [],
        citations=final_state.get("citations") or [],
        debug_data=debug_data,
    )
|
|