Insurance_Pilot / app /rag /graph.py
Shoaib-33's picture
Update app/rag/graph.py
5a03ff2 verified
Raw
History Blame Contribute Delete
24.6 kB
import json
import re
from typing import Any
from langgraph.graph import END, StateGraph
from app.core.config import settings
from app.db.sqlite import db
from app.guardrails.pii import PIIGuardrails
from app.guardrails.prompt_injection import detect_prompt_injection, strip_unsafe_retrieved_text
from app.rag.hybrid_retriever import HybridRetriever
from app.rag.reranker import FlashRankReranker
from app.rag.semantic_cache import SemanticAnswerCache
from app.rag.self_rag import SelfRAG
from app.rag.state import RagState
from app.rag.text import new_id, normalize_text
from app.services.groq_llm import GroqLLM
from app.services.memory import ClaimMemoryService
class ClaimsRAGGraph:
def __init__(self) -> None:
self.cache = SemanticAnswerCache()
self.guardrails = PIIGuardrails()
self.retriever = HybridRetriever()
self.reranker = FlashRankReranker()
self.self_rag = SelfRAG()
self.llm = GroqLLM()
self.memory = ClaimMemoryService()
self.graph = self._build_graph()
def run(
self,
query: str,
metadata_filter: dict[str, Any] | None = None,
user_id: str = "default_user",
use_cache: bool = True,
) -> RagState:
request_id = new_id("request")
state: RagState = {
"request_id": request_id,
"query": query,
"user_id": user_id,
"sanitized_query": query,
"retrieval_query": query,
"normalized_query": normalize_text(query),
"memory_context": "",
"iteration": 0,
"cache_hit": False,
"use_cache": use_cache,
"trace": [{"node": "start", "query": query, "metadata_filter": metadata_filter or {}}],
}
if metadata_filter:
state["metadata_filter"] = metadata_filter # type: ignore[typeddict-unknown-key]
result = self.graph.invoke(state)
self._save_trace(result)
return result
def _build_graph(self):
workflow = StateGraph(RagState)
workflow.add_node("planner", self._planner)
workflow.add_node("semantic_cache", self._semantic_cache)
workflow.add_node("guardrails", self._guardrails)
workflow.add_node("load_memory", self._load_memory)
workflow.add_node("direct_response", self._direct_response)
workflow.add_node("retrieve", self._retrieve)
workflow.add_node("rerank", self._rerank)
workflow.add_node("generate", self._generate)
workflow.add_node("critique", self._critique)
workflow.add_node("rewrite", self._rewrite)
workflow.add_node("finalize", self._finalize)
workflow.set_entry_point("planner")
workflow.add_edge("planner", "semantic_cache")
workflow.add_conditional_edges(
"semantic_cache",
self._route_cache,
{"hit": "finalize", "miss": "guardrails"},
)
workflow.add_conditional_edges(
"guardrails",
lambda state: "direct" if state.get("answer") and "override system" in state["answer"] else "memory",
{"direct": "direct_response", "memory": "load_memory"},
)
workflow.add_conditional_edges(
"load_memory",
self._route_retrieval,
{"direct": "direct_response", "out_of_domain": "direct_response", "retrieve": "retrieve"},
)
workflow.add_edge("direct_response", "finalize")
workflow.add_edge("retrieve", "rerank")
workflow.add_edge("rerank", "generate")
workflow.add_edge("generate", "critique")
workflow.add_conditional_edges(
"critique",
self._route_critique,
{"accept": "finalize", "retry": "rewrite"},
)
workflow.add_edge("rewrite", "retrieve")
workflow.add_edge("finalize", END)
return workflow.compile()
def _planner(self, state: RagState) -> RagState:
decision = self.self_rag.grade_retrieval_need(state["query"])
state["should_retrieve"] = bool(decision.get("should_retrieve", True))
state["intent"] = str(decision.get("intent", "claims_question"))
state["risk_level"] = decision.get("risk_level", "medium") # type: ignore[assignment]
self._trace(state, "planner", decision)
return state
def _load_memory(self, state: RagState) -> RagState:
memory_context = self.memory.search(
user_id=state.get("user_id", "default_user"),
query=state["sanitized_query"],
)
state["memory_context"] = memory_context
self._trace(
state,
"load_memory",
{
"langmem_available": self.memory.langmem_available,
"has_memory": memory_context != "No prior memory for this user.",
},
)
return state
def _semantic_cache(self, state: RagState) -> RagState:
if not state.get("use_cache", True):
state["cache_hit"] = False
self._trace(state, "semantic_cache", {"hit": False, "disabled": True})
return state
hit = self.cache.lookup(state["query"])
if hit:
state["cache_hit"] = True
state["answer"] = hit["answer"]
state["confidence"] = hit["confidence"]
state["sources"] = hit["sources"]
has_sources = bool(hit["sources"])
state["self_rag"] = {
"passed": True,
"retrieve": True,
"isrel": has_sources,
"issup": has_sources,
"isuse": bool(hit["answer"]),
"confidence": hit["confidence"],
"issues": ["Served from semantic answer cache."],
}
self._trace(state, "semantic_cache", {"hit": True, "score": hit["score"]})
return state
state["cache_hit"] = False
self._trace(state, "semantic_cache", {"hit": False})
return state
def _guardrails(self, state: RagState) -> RagState:
pii = self.guardrails.sanitize(state["query"])
injection = detect_prompt_injection(state["query"])
state["sanitized_query"] = pii.text
if injection:
state["should_retrieve"] = False
state["answer"] = "I cannot help with requests that try to override system or safety instructions."
state["confidence"] = 0.99
self._trace(
state,
"guardrails",
{
"pii_findings": pii.findings,
"prompt_injection_findings": injection,
"langchain_pii_middlewares": len(self.guardrails.langchain_middlewares),
},
)
return state
def _direct_response(self, state: RagState) -> RagState:
if state.get("answer"):
return state
if state.get("intent") == "out_of_domain":
state["answer"] = self._scope_message()
state["confidence"] = 0.98
state["sources"] = []
state["self_rag"] = {
"passed": True,
"retrieve": False,
"isrel": False,
"issup": False,
"isuse": True,
"confidence": 0.98,
"issues": ["Out-of-domain request blocked before retrieval."],
}
self._trace(state, "direct_response", {"reason": "out_of_domain"})
return state
if state.get("intent") == "smalltalk":
state["answer"] = self._scope_message()
state["confidence"] = 0.9
state["sources"] = []
state["self_rag"] = {
"passed": True,
"retrieve": False,
"isrel": False,
"issup": False,
"isuse": True,
"confidence": 0.9,
"issues": ["Smalltalk redirected to insurance scope."],
}
self._trace(state, "direct_response", {"reason": "smalltalk"})
return state
answer = self.llm.invoke_text(
system=(
"You are an insurance claims support copilot. Answer simple non-policy questions "
"briefly. Do not invent coverage, policy terms, claim outcomes, or payments."
),
user=state["sanitized_query"],
)
state["answer"] = answer
state["confidence"] = 0.72
state["sources"] = []
state["self_rag"] = {
"passed": True,
"confidence": 0.72,
"issues": ["Direct response path. Retrieval was not required."],
}
self._trace(state, "direct_response", {"confidence": state["confidence"]})
return state
def _scope_message(self) -> str:
return (
"I can only help with insurance-related claims, coverage, policy terms, claim "
"documents, and claim procedures. Please ask an insurance claim question."
)
def _retrieve(self, state: RagState) -> RagState:
query = self._prepare_retrieval_query(state)
metadata_filter = state.get("metadata_filter") # type: ignore[typeddict-item]
sources = self.retriever.retrieve(query, metadata_filter=metadata_filter)
cleaned_sources = []
for source in sources:
cleaned_sources.append({**source, "text": strip_unsafe_retrieved_text(source.get("text", ""))})
state["sources"] = cleaned_sources
self._trace(state, "retrieve", {"count": len(cleaned_sources), "retrieval_query": query})
return state
def _rerank(self, state: RagState) -> RagState:
reranked = self.reranker.rerank(
state.get("retrieval_query", state["sanitized_query"]),
state.get("sources", []),
top_k=settings.rerank_top_k,
)
state["reranked_sources"] = reranked
self._trace(state, "rerank", {"count": len(reranked)})
return state
def _prepare_retrieval_query(self, state: RagState) -> str:
if not settings.enable_query_rewrite:
state["retrieval_query"] = state["sanitized_query"]
return state["retrieval_query"]
if state.get("retrieval_query") and state.get("retrieval_query") != state.get("query"):
return state["retrieval_query"]
result = self.llm.invoke_json(
system=(
"You rewrite user insurance questions into concise retrieval queries for a hybrid "
"BM25 + vector RAG system. Preserve all facts from the user. Do not answer the "
"question. Add only helpful insurance terminology that improves retrieval, such as "
"coverage part, exclusion, deductible, endorsement, claim documents, fault, valuation, "
"or policy limit when relevant.\n\n"
"Important retrieval discipline:\n"
"- Keep the rewritten query focused on the user's requested claim issue.\n"
"- Do not add benefits, services, or subtopics the user did not ask about.\n"
"- If the user asks about damage to insured property, focus on the coverage for that "
"damage, required evidence, deductible, and exclusions.\n"
"- If a term in the user question is vague, rewrite it into precise insurance language, "
"but do not change the claim type.\n"
"- Prefer compact keyword-rich wording over a sentence.\n\n"
"Return JSON only with keys: query, changed, rationale."
),
user=(
f"Intent: {state.get('intent', 'unknown')}\n"
f"Original user question:\n{state['sanitized_query']}"
),
fallback={"query": state["sanitized_query"], "changed": False, "rationale": "fallback"},
)
rewritten = str(result.get("query") or state["sanitized_query"]).strip()
if not rewritten:
rewritten = state["sanitized_query"]
state["retrieval_query"] = rewritten
self._trace(
state,
"query_rewrite",
{
"original_query": state["sanitized_query"],
"retrieval_query": rewritten,
"changed": bool(result.get("changed", rewritten != state["sanitized_query"])),
"rationale": str(result.get("rationale", ""))[:300],
},
)
return rewritten
def _generate(self, state: RagState) -> RagState:
sources = state.get("reranked_sources", [])
llm_sources = sources[: settings.max_sources_to_llm]
evidence = "\n\n".join(
(
f"Source {i + 1}: {src.get('source_name', 'unknown')}\n"
f"Text: {src.get('text', '')[: settings.max_evidence_chars_per_source]}"
)
for i, src in enumerate(llm_sources)
)
if state.get("intent") == "general_insurance_concept":
answer = self.llm.invoke_text(
system=(
"You are an insurance education assistant. The user is asking a general "
"insurance concept question, not requesting a claim payment decision. Use only "
"the provided evidence. Answer briefly and clearly in plain language. Do not use "
"the claim triage structure. Do not say Likely covered, Likely not covered, or "
"Needs human review unless the user asks about a claim scenario. Cite sources as "
"[Source 1], [Source 2], etc. Do not use outside source names."
),
user=(
f"Question:\n{state['sanitized_query']}\n\n"
f"Retrieved evidence:\n{evidence}"
),
)
state["answer"] = self._ensure_source_citation(answer, sources)
self._trace(state, "generate", {"source_count": len(sources), "mode": "concept_llm"})
return state
answer = self._generate_claim_json_answer(state, evidence, sources)
state["answer"] = self._ensure_source_citation(answer, sources)
self._trace(state, "generate", {"source_count": len(sources), "mode": "llm"})
return state
def _ensure_source_citation(self, answer: str, sources: list[dict[str, Any]]) -> str:
if not sources or re.search(r"\[Source\s+\d+\]", answer, flags=re.IGNORECASE):
return answer
return answer.rstrip() + " [Source 1]"
def _generate_claim_json_answer(self, state: RagState, evidence: str, sources: list[dict[str, Any]]) -> str:
result = self.llm.invoke_json(
system=(
"You are an insurance claim support AI agent. The user describes a claim scenario. "
"Use only the retrieved evidence. Do not use outside knowledge. Do not invent final "
"payment approval, denial, claim status, policy terms, or source names.\n\n"
"Return JSON only with exactly these keys:\n"
"- decision: one of Likely covered, Likely not covered, Needs human review\n"
"- reason: one or two evidence-grounded sentences\n"
"- missing_evidence: short string listing missing documents/facts or None identified\n"
"- recommended_action: short string with next step, escalation, or review action\n"
"- sources: short string with citations like [Source 1], [Source 2]\n\n"
"Rubric:\n"
"- Likely covered: use only when evidence directly says this cause of loss or scenario is "
"normally covered by the relevant coverage part and the user's facts do not leave a major "
"coverage dependency unresolved.\n"
"- Likely not covered: use only when evidence directly says this cause of loss is excluded, "
"not covered by the standard policy, or requires separate coverage that the user says they "
"do not have.\n"
"- Needs human review: use when payment or coverage depends on unresolved policy-specific "
"facts, endorsements, sublimits, deductibles, fault, valuation, contestability, fraud "
"review, regulatory timing, guaranty fund state limits, settlement amount disputes, or "
"other claim-file details.\n\n"
"If evidence is incomplete, still choose the best triage label and explain what is missing. "
"Allowed citations are only [Source 1], [Source 2], etc."
),
user=(
f"User memory context:\n{state.get('memory_context', 'No prior memory.')}\n\n"
f"Claim scenario:\n{state['sanitized_query']}\n\n"
f"Retrieved evidence:\n{evidence}"
),
fallback={
"decision": "Needs human review",
"reason": "The available evidence is not sufficient to make a final coverage triage.",
"missing_evidence": "Policy-specific details and claim-file documentation.",
"recommended_action": "Escalate for human review with the retrieved evidence.",
"sources": "[Source 1]" if sources else "",
},
)
decision = str(result.get("decision", "Needs human review"))
if decision not in {"Likely covered", "Likely not covered", "Needs human review"}:
decision = "Needs human review"
sources_text = str(result.get("sources", "")).strip()
if sources and not re.search(r"\[Source\s+\d+\]", sources_text, flags=re.IGNORECASE):
sources_text = "[Source 1]"
return (
f"Decision: {decision}\n"
f"Reason: {str(result.get('reason', '')).strip()}\n"
f"Missing evidence: {str(result.get('missing_evidence', '')).strip()}\n"
f"Recommended action: {str(result.get('recommended_action', '')).strip()}\n"
f"Sources: {sources_text}"
)
def _critique(self, state: RagState) -> RagState:
if state.get("intent") == "general_insurance_concept":
sources = state.get("reranked_sources", [])
answer = state.get("answer", "")
critique = {
"passed": bool(sources) and bool(answer),
"retrieve": True,
"isrel": bool(sources),
"issup": bool(sources) and "[source" in answer.lower(),
"isuse": len(answer.strip()) > 20,
"confidence": 0.84 if sources and answer else 0.45,
"relevance_score": 0.85 if sources else 0.0,
"faithfulness_score": 0.8 if sources and "[source" in answer.lower() else 0.35,
"evidence_score": min(0.9, 0.35 + len(sources) * 0.1),
"needs_rewrite": False,
"rewrite_query": None,
"issues": ["General insurance concept answer with retrieved sources."],
}
state["self_rag"] = critique
state["confidence"] = float(critique["confidence"])
self._trace(state, "critique", critique)
return state
critique = self.self_rag.critique(
query=state["sanitized_query"],
answer=state.get("answer", ""),
sources=state.get("reranked_sources", []),
iteration=int(state.get("iteration", 0)),
)
state["self_rag"] = critique
state["confidence"] = float(critique.get("confidence", 0.0))
self._trace(state, "critique", critique)
return state
def _rewrite(self, state: RagState) -> RagState:
critique = state.get("self_rag", {})
fallback_query = critique.get("rewrite_query") or state.get("retrieval_query") or state["sanitized_query"]
result = self.llm.invoke_json(
system=(
"You are rewriting a failed insurance RAG retrieval query for a retry. Use the "
"critique issues and original user question to create a better retrieval query. "
"Preserve the user's facts. Do not answer the question. Return JSON only with "
"keys: query, rationale."
),
user=(
f"Original user question:\n{state['sanitized_query']}\n\n"
f"Previous retrieval query:\n{state.get('retrieval_query', state['sanitized_query'])}\n\n"
f"Critique issues:\n{json.dumps(critique.get('issues', []), ensure_ascii=True)}"
),
fallback={"query": fallback_query, "rationale": "fallback"},
)
rewrite = str(result.get("query") or fallback_query).strip() or str(fallback_query)
state["retrieval_query"] = rewrite
state["iteration"] = int(state.get("iteration", 0)) + 1
self._trace(
state,
"rewrite",
{
"retrieval_query": state["retrieval_query"],
"iteration": state["iteration"],
"rationale": str(result.get("rationale", ""))[:300],
},
)
return state
def _finalize(self, state: RagState) -> RagState:
if state.get("answer"):
sanitized = self.guardrails.sanitize(state["answer"]).text
state["answer"] = self.guardrails.clean_legacy_false_positive_placeholders(sanitized)
if (
state.get("use_cache", True)
and not state.get("cache_hit")
and state.get("answer")
and not self._has_unsupported_citation(state)
):
self.cache.save(
query=state["query"],
answer=state["answer"],
confidence=float(state.get("confidence", 0.0)),
sources=state.get("reranked_sources") or state.get("sources", []),
)
if state.get("answer") and state.get("reranked_sources"):
self.memory.save_interaction(
user_id=state.get("user_id", "default_user"),
query=state["query"],
answer=state["answer"],
critique=state.get("self_rag", {}),
sources=state.get("reranked_sources", []),
)
self._trace(
state,
"finalize",
{
"cache_hit": state.get("cache_hit", False),
"confidence": state.get("confidence", 0.0),
"iterations": state.get("iteration", 0),
},
)
return state
def _has_unsupported_citation(self, state: RagState) -> bool:
answer = state.get("answer", "").lower()
blocked_markers = [
"insurance information institute",
"source: none",
"source: insurance",
"according to standard insurance terminology",
]
return any(marker in answer for marker in blocked_markers)
def _route_cache(self, state: RagState) -> str:
return "hit" if state.get("cache_hit") else "miss"
def _route_retrieval(self, state: RagState) -> str:
if state.get("answer") and "override system" in state["answer"]:
return "direct"
if state.get("intent") == "out_of_domain":
return "out_of_domain"
return "retrieve" if state.get("should_retrieve", True) else "direct"
def _route_critique(self, state: RagState) -> str:
critique = state.get("self_rag", {})
passed = bool(critique.get("passed", False))
confidence = float(critique.get("confidence", 0.0))
iteration = int(state.get("iteration", 0))
if passed and confidence >= 0.68:
return "accept"
if iteration >= settings.self_rag_max_loops:
return "accept"
if critique.get("needs_rewrite", True):
return "retry"
return "accept"
def _trace(self, state: RagState, node: str, payload: dict[str, Any]) -> None:
trace = state.setdefault("trace", [])
trace.append({"node": node, **payload})
def _save_trace(self, state: RagState) -> None:
with db() as conn:
conn.execute(
"""
INSERT OR REPLACE INTO traces(request_id, query, trace_json)
VALUES (?, ?, ?)
""",
(
state["request_id"],
state["query"],
json.dumps(state.get("trace", []), ensure_ascii=True),
),
)