Spaces:
Sleeping
Sleeping
| import json | |
| import re | |
| from typing import Any | |
| from langgraph.graph import END, StateGraph | |
| from app.core.config import settings | |
| from app.db.sqlite import db | |
| from app.guardrails.pii import PIIGuardrails | |
| from app.guardrails.prompt_injection import detect_prompt_injection, strip_unsafe_retrieved_text | |
| from app.rag.hybrid_retriever import HybridRetriever | |
| from app.rag.reranker import FlashRankReranker | |
| from app.rag.semantic_cache import SemanticAnswerCache | |
| from app.rag.self_rag import SelfRAG | |
| from app.rag.state import RagState | |
| from app.rag.text import new_id, normalize_text | |
| from app.services.groq_llm import GroqLLM | |
| from app.services.memory import ClaimMemoryService | |
| class ClaimsRAGGraph: | |
| def __init__(self) -> None: | |
| self.cache = SemanticAnswerCache() | |
| self.guardrails = PIIGuardrails() | |
| self.retriever = HybridRetriever() | |
| self.reranker = FlashRankReranker() | |
| self.self_rag = SelfRAG() | |
| self.llm = GroqLLM() | |
| self.memory = ClaimMemoryService() | |
| self.graph = self._build_graph() | |
| def run( | |
| self, | |
| query: str, | |
| metadata_filter: dict[str, Any] | None = None, | |
| user_id: str = "default_user", | |
| use_cache: bool = True, | |
| ) -> RagState: | |
| request_id = new_id("request") | |
| state: RagState = { | |
| "request_id": request_id, | |
| "query": query, | |
| "user_id": user_id, | |
| "sanitized_query": query, | |
| "retrieval_query": query, | |
| "normalized_query": normalize_text(query), | |
| "memory_context": "", | |
| "iteration": 0, | |
| "cache_hit": False, | |
| "use_cache": use_cache, | |
| "trace": [{"node": "start", "query": query, "metadata_filter": metadata_filter or {}}], | |
| } | |
| if metadata_filter: | |
| state["metadata_filter"] = metadata_filter # type: ignore[typeddict-unknown-key] | |
| result = self.graph.invoke(state) | |
| self._save_trace(result) | |
| return result | |
| def _build_graph(self): | |
| workflow = StateGraph(RagState) | |
| workflow.add_node("planner", self._planner) | |
| workflow.add_node("semantic_cache", self._semantic_cache) | |
| workflow.add_node("guardrails", self._guardrails) | |
| workflow.add_node("load_memory", self._load_memory) | |
| workflow.add_node("direct_response", self._direct_response) | |
| workflow.add_node("retrieve", self._retrieve) | |
| workflow.add_node("rerank", self._rerank) | |
| workflow.add_node("generate", self._generate) | |
| workflow.add_node("critique", self._critique) | |
| workflow.add_node("rewrite", self._rewrite) | |
| workflow.add_node("finalize", self._finalize) | |
| workflow.set_entry_point("planner") | |
| workflow.add_edge("planner", "semantic_cache") | |
| workflow.add_conditional_edges( | |
| "semantic_cache", | |
| self._route_cache, | |
| {"hit": "finalize", "miss": "guardrails"}, | |
| ) | |
| workflow.add_conditional_edges( | |
| "guardrails", | |
| lambda state: "direct" if state.get("answer") and "override system" in state["answer"] else "memory", | |
| {"direct": "direct_response", "memory": "load_memory"}, | |
| ) | |
| workflow.add_conditional_edges( | |
| "load_memory", | |
| self._route_retrieval, | |
| {"direct": "direct_response", "out_of_domain": "direct_response", "retrieve": "retrieve"}, | |
| ) | |
| workflow.add_edge("direct_response", "finalize") | |
| workflow.add_edge("retrieve", "rerank") | |
| workflow.add_edge("rerank", "generate") | |
| workflow.add_edge("generate", "critique") | |
| workflow.add_conditional_edges( | |
| "critique", | |
| self._route_critique, | |
| {"accept": "finalize", "retry": "rewrite"}, | |
| ) | |
| workflow.add_edge("rewrite", "retrieve") | |
| workflow.add_edge("finalize", END) | |
| return workflow.compile() | |
| def _planner(self, state: RagState) -> RagState: | |
| decision = self.self_rag.grade_retrieval_need(state["query"]) | |
| state["should_retrieve"] = bool(decision.get("should_retrieve", True)) | |
| state["intent"] = str(decision.get("intent", "claims_question")) | |
| state["risk_level"] = decision.get("risk_level", "medium") # type: ignore[assignment] | |
| self._trace(state, "planner", decision) | |
| return state | |
| def _load_memory(self, state: RagState) -> RagState: | |
| memory_context = self.memory.search( | |
| user_id=state.get("user_id", "default_user"), | |
| query=state["sanitized_query"], | |
| ) | |
| state["memory_context"] = memory_context | |
| self._trace( | |
| state, | |
| "load_memory", | |
| { | |
| "langmem_available": self.memory.langmem_available, | |
| "has_memory": memory_context != "No prior memory for this user.", | |
| }, | |
| ) | |
| return state | |
| def _semantic_cache(self, state: RagState) -> RagState: | |
| if not state.get("use_cache", True): | |
| state["cache_hit"] = False | |
| self._trace(state, "semantic_cache", {"hit": False, "disabled": True}) | |
| return state | |
| hit = self.cache.lookup(state["query"]) | |
| if hit: | |
| state["cache_hit"] = True | |
| state["answer"] = hit["answer"] | |
| state["confidence"] = hit["confidence"] | |
| state["sources"] = hit["sources"] | |
| has_sources = bool(hit["sources"]) | |
| state["self_rag"] = { | |
| "passed": True, | |
| "retrieve": True, | |
| "isrel": has_sources, | |
| "issup": has_sources, | |
| "isuse": bool(hit["answer"]), | |
| "confidence": hit["confidence"], | |
| "issues": ["Served from semantic answer cache."], | |
| } | |
| self._trace(state, "semantic_cache", {"hit": True, "score": hit["score"]}) | |
| return state | |
| state["cache_hit"] = False | |
| self._trace(state, "semantic_cache", {"hit": False}) | |
| return state | |
| def _guardrails(self, state: RagState) -> RagState: | |
| pii = self.guardrails.sanitize(state["query"]) | |
| injection = detect_prompt_injection(state["query"]) | |
| state["sanitized_query"] = pii.text | |
| if injection: | |
| state["should_retrieve"] = False | |
| state["answer"] = "I cannot help with requests that try to override system or safety instructions." | |
| state["confidence"] = 0.99 | |
| self._trace( | |
| state, | |
| "guardrails", | |
| { | |
| "pii_findings": pii.findings, | |
| "prompt_injection_findings": injection, | |
| "langchain_pii_middlewares": len(self.guardrails.langchain_middlewares), | |
| }, | |
| ) | |
| return state | |
| def _direct_response(self, state: RagState) -> RagState: | |
| if state.get("answer"): | |
| return state | |
| if state.get("intent") == "out_of_domain": | |
| state["answer"] = self._scope_message() | |
| state["confidence"] = 0.98 | |
| state["sources"] = [] | |
| state["self_rag"] = { | |
| "passed": True, | |
| "retrieve": False, | |
| "isrel": False, | |
| "issup": False, | |
| "isuse": True, | |
| "confidence": 0.98, | |
| "issues": ["Out-of-domain request blocked before retrieval."], | |
| } | |
| self._trace(state, "direct_response", {"reason": "out_of_domain"}) | |
| return state | |
| if state.get("intent") == "smalltalk": | |
| state["answer"] = self._scope_message() | |
| state["confidence"] = 0.9 | |
| state["sources"] = [] | |
| state["self_rag"] = { | |
| "passed": True, | |
| "retrieve": False, | |
| "isrel": False, | |
| "issup": False, | |
| "isuse": True, | |
| "confidence": 0.9, | |
| "issues": ["Smalltalk redirected to insurance scope."], | |
| } | |
| self._trace(state, "direct_response", {"reason": "smalltalk"}) | |
| return state | |
| answer = self.llm.invoke_text( | |
| system=( | |
| "You are an insurance claims support copilot. Answer simple non-policy questions " | |
| "briefly. Do not invent coverage, policy terms, claim outcomes, or payments." | |
| ), | |
| user=state["sanitized_query"], | |
| ) | |
| state["answer"] = answer | |
| state["confidence"] = 0.72 | |
| state["sources"] = [] | |
| state["self_rag"] = { | |
| "passed": True, | |
| "confidence": 0.72, | |
| "issues": ["Direct response path. Retrieval was not required."], | |
| } | |
| self._trace(state, "direct_response", {"confidence": state["confidence"]}) | |
| return state | |
| def _scope_message(self) -> str: | |
| return ( | |
| "I can only help with insurance-related claims, coverage, policy terms, claim " | |
| "documents, and claim procedures. Please ask an insurance claim question." | |
| ) | |
| def _retrieve(self, state: RagState) -> RagState: | |
| query = self._prepare_retrieval_query(state) | |
| metadata_filter = state.get("metadata_filter") # type: ignore[typeddict-item] | |
| sources = self.retriever.retrieve(query, metadata_filter=metadata_filter) | |
| cleaned_sources = [] | |
| for source in sources: | |
| cleaned_sources.append({**source, "text": strip_unsafe_retrieved_text(source.get("text", ""))}) | |
| state["sources"] = cleaned_sources | |
| self._trace(state, "retrieve", {"count": len(cleaned_sources), "retrieval_query": query}) | |
| return state | |
| def _rerank(self, state: RagState) -> RagState: | |
| reranked = self.reranker.rerank( | |
| state.get("retrieval_query", state["sanitized_query"]), | |
| state.get("sources", []), | |
| top_k=settings.rerank_top_k, | |
| ) | |
| state["reranked_sources"] = reranked | |
| self._trace(state, "rerank", {"count": len(reranked)}) | |
| return state | |
| def _prepare_retrieval_query(self, state: RagState) -> str: | |
| if not settings.enable_query_rewrite: | |
| state["retrieval_query"] = state["sanitized_query"] | |
| return state["retrieval_query"] | |
| if state.get("retrieval_query") and state.get("retrieval_query") != state.get("query"): | |
| return state["retrieval_query"] | |
| result = self.llm.invoke_json( | |
| system=( | |
| "You rewrite user insurance questions into concise retrieval queries for a hybrid " | |
| "BM25 + vector RAG system. Preserve all facts from the user. Do not answer the " | |
| "question. Add only helpful insurance terminology that improves retrieval, such as " | |
| "coverage part, exclusion, deductible, endorsement, claim documents, fault, valuation, " | |
| "or policy limit when relevant.\n\n" | |
| "Important retrieval discipline:\n" | |
| "- Keep the rewritten query focused on the user's requested claim issue.\n" | |
| "- Do not add benefits, services, or subtopics the user did not ask about.\n" | |
| "- If the user asks about damage to insured property, focus on the coverage for that " | |
| "damage, required evidence, deductible, and exclusions.\n" | |
| "- If a term in the user question is vague, rewrite it into precise insurance language, " | |
| "but do not change the claim type.\n" | |
| "- Prefer compact keyword-rich wording over a sentence.\n\n" | |
| "Return JSON only with keys: query, changed, rationale." | |
| ), | |
| user=( | |
| f"Intent: {state.get('intent', 'unknown')}\n" | |
| f"Original user question:\n{state['sanitized_query']}" | |
| ), | |
| fallback={"query": state["sanitized_query"], "changed": False, "rationale": "fallback"}, | |
| ) | |
| rewritten = str(result.get("query") or state["sanitized_query"]).strip() | |
| if not rewritten: | |
| rewritten = state["sanitized_query"] | |
| state["retrieval_query"] = rewritten | |
| self._trace( | |
| state, | |
| "query_rewrite", | |
| { | |
| "original_query": state["sanitized_query"], | |
| "retrieval_query": rewritten, | |
| "changed": bool(result.get("changed", rewritten != state["sanitized_query"])), | |
| "rationale": str(result.get("rationale", ""))[:300], | |
| }, | |
| ) | |
| return rewritten | |
| def _generate(self, state: RagState) -> RagState: | |
| sources = state.get("reranked_sources", []) | |
| llm_sources = sources[: settings.max_sources_to_llm] | |
| evidence = "\n\n".join( | |
| ( | |
| f"Source {i + 1}: {src.get('source_name', 'unknown')}\n" | |
| f"Text: {src.get('text', '')[: settings.max_evidence_chars_per_source]}" | |
| ) | |
| for i, src in enumerate(llm_sources) | |
| ) | |
| if state.get("intent") == "general_insurance_concept": | |
| answer = self.llm.invoke_text( | |
| system=( | |
| "You are an insurance education assistant. The user is asking a general " | |
| "insurance concept question, not requesting a claim payment decision. Use only " | |
| "the provided evidence. Answer briefly and clearly in plain language. Do not use " | |
| "the claim triage structure. Do not say Likely covered, Likely not covered, or " | |
| "Needs human review unless the user asks about a claim scenario. Cite sources as " | |
| "[Source 1], [Source 2], etc. Do not use outside source names." | |
| ), | |
| user=( | |
| f"Question:\n{state['sanitized_query']}\n\n" | |
| f"Retrieved evidence:\n{evidence}" | |
| ), | |
| ) | |
| state["answer"] = self._ensure_source_citation(answer, sources) | |
| self._trace(state, "generate", {"source_count": len(sources), "mode": "concept_llm"}) | |
| return state | |
| answer = self._generate_claim_json_answer(state, evidence, sources) | |
| state["answer"] = self._ensure_source_citation(answer, sources) | |
| self._trace(state, "generate", {"source_count": len(sources), "mode": "llm"}) | |
| return state | |
| def _ensure_source_citation(self, answer: str, sources: list[dict[str, Any]]) -> str: | |
| if not sources or re.search(r"\[Source\s+\d+\]", answer, flags=re.IGNORECASE): | |
| return answer | |
| return answer.rstrip() + " [Source 1]" | |
| def _generate_claim_json_answer(self, state: RagState, evidence: str, sources: list[dict[str, Any]]) -> str: | |
| result = self.llm.invoke_json( | |
| system=( | |
| "You are an insurance claim support AI agent. The user describes a claim scenario. " | |
| "Use only the retrieved evidence. Do not use outside knowledge. Do not invent final " | |
| "payment approval, denial, claim status, policy terms, or source names.\n\n" | |
| "Return JSON only with exactly these keys:\n" | |
| "- decision: one of Likely covered, Likely not covered, Needs human review\n" | |
| "- reason: one or two evidence-grounded sentences\n" | |
| "- missing_evidence: short string listing missing documents/facts or None identified\n" | |
| "- recommended_action: short string with next step, escalation, or review action\n" | |
| "- sources: short string with citations like [Source 1], [Source 2]\n\n" | |
| "Rubric:\n" | |
| "- Likely covered: use only when evidence directly says this cause of loss or scenario is " | |
| "normally covered by the relevant coverage part and the user's facts do not leave a major " | |
| "coverage dependency unresolved.\n" | |
| "- Likely not covered: use only when evidence directly says this cause of loss is excluded, " | |
| "not covered by the standard policy, or requires separate coverage that the user says they " | |
| "do not have.\n" | |
| "- Needs human review: use when payment or coverage depends on unresolved policy-specific " | |
| "facts, endorsements, sublimits, deductibles, fault, valuation, contestability, fraud " | |
| "review, regulatory timing, guaranty fund state limits, settlement amount disputes, or " | |
| "other claim-file details.\n\n" | |
| "If evidence is incomplete, still choose the best triage label and explain what is missing. " | |
| "Allowed citations are only [Source 1], [Source 2], etc." | |
| ), | |
| user=( | |
| f"User memory context:\n{state.get('memory_context', 'No prior memory.')}\n\n" | |
| f"Claim scenario:\n{state['sanitized_query']}\n\n" | |
| f"Retrieved evidence:\n{evidence}" | |
| ), | |
| fallback={ | |
| "decision": "Needs human review", | |
| "reason": "The available evidence is not sufficient to make a final coverage triage.", | |
| "missing_evidence": "Policy-specific details and claim-file documentation.", | |
| "recommended_action": "Escalate for human review with the retrieved evidence.", | |
| "sources": "[Source 1]" if sources else "", | |
| }, | |
| ) | |
| decision = str(result.get("decision", "Needs human review")) | |
| if decision not in {"Likely covered", "Likely not covered", "Needs human review"}: | |
| decision = "Needs human review" | |
| sources_text = str(result.get("sources", "")).strip() | |
| if sources and not re.search(r"\[Source\s+\d+\]", sources_text, flags=re.IGNORECASE): | |
| sources_text = "[Source 1]" | |
| return ( | |
| f"Decision: {decision}\n" | |
| f"Reason: {str(result.get('reason', '')).strip()}\n" | |
| f"Missing evidence: {str(result.get('missing_evidence', '')).strip()}\n" | |
| f"Recommended action: {str(result.get('recommended_action', '')).strip()}\n" | |
| f"Sources: {sources_text}" | |
| ) | |
| def _critique(self, state: RagState) -> RagState: | |
| if state.get("intent") == "general_insurance_concept": | |
| sources = state.get("reranked_sources", []) | |
| answer = state.get("answer", "") | |
| critique = { | |
| "passed": bool(sources) and bool(answer), | |
| "retrieve": True, | |
| "isrel": bool(sources), | |
| "issup": bool(sources) and "[source" in answer.lower(), | |
| "isuse": len(answer.strip()) > 20, | |
| "confidence": 0.84 if sources and answer else 0.45, | |
| "relevance_score": 0.85 if sources else 0.0, | |
| "faithfulness_score": 0.8 if sources and "[source" in answer.lower() else 0.35, | |
| "evidence_score": min(0.9, 0.35 + len(sources) * 0.1), | |
| "needs_rewrite": False, | |
| "rewrite_query": None, | |
| "issues": ["General insurance concept answer with retrieved sources."], | |
| } | |
| state["self_rag"] = critique | |
| state["confidence"] = float(critique["confidence"]) | |
| self._trace(state, "critique", critique) | |
| return state | |
| critique = self.self_rag.critique( | |
| query=state["sanitized_query"], | |
| answer=state.get("answer", ""), | |
| sources=state.get("reranked_sources", []), | |
| iteration=int(state.get("iteration", 0)), | |
| ) | |
| state["self_rag"] = critique | |
| state["confidence"] = float(critique.get("confidence", 0.0)) | |
| self._trace(state, "critique", critique) | |
| return state | |
| def _rewrite(self, state: RagState) -> RagState: | |
| critique = state.get("self_rag", {}) | |
| fallback_query = critique.get("rewrite_query") or state.get("retrieval_query") or state["sanitized_query"] | |
| result = self.llm.invoke_json( | |
| system=( | |
| "You are rewriting a failed insurance RAG retrieval query for a retry. Use the " | |
| "critique issues and original user question to create a better retrieval query. " | |
| "Preserve the user's facts. Do not answer the question. Return JSON only with " | |
| "keys: query, rationale." | |
| ), | |
| user=( | |
| f"Original user question:\n{state['sanitized_query']}\n\n" | |
| f"Previous retrieval query:\n{state.get('retrieval_query', state['sanitized_query'])}\n\n" | |
| f"Critique issues:\n{json.dumps(critique.get('issues', []), ensure_ascii=True)}" | |
| ), | |
| fallback={"query": fallback_query, "rationale": "fallback"}, | |
| ) | |
| rewrite = str(result.get("query") or fallback_query).strip() or str(fallback_query) | |
| state["retrieval_query"] = rewrite | |
| state["iteration"] = int(state.get("iteration", 0)) + 1 | |
| self._trace( | |
| state, | |
| "rewrite", | |
| { | |
| "retrieval_query": state["retrieval_query"], | |
| "iteration": state["iteration"], | |
| "rationale": str(result.get("rationale", ""))[:300], | |
| }, | |
| ) | |
| return state | |
| def _finalize(self, state: RagState) -> RagState: | |
| if state.get("answer"): | |
| sanitized = self.guardrails.sanitize(state["answer"]).text | |
| state["answer"] = self.guardrails.clean_legacy_false_positive_placeholders(sanitized) | |
| if ( | |
| state.get("use_cache", True) | |
| and not state.get("cache_hit") | |
| and state.get("answer") | |
| and not self._has_unsupported_citation(state) | |
| ): | |
| self.cache.save( | |
| query=state["query"], | |
| answer=state["answer"], | |
| confidence=float(state.get("confidence", 0.0)), | |
| sources=state.get("reranked_sources") or state.get("sources", []), | |
| ) | |
| if state.get("answer") and state.get("reranked_sources"): | |
| self.memory.save_interaction( | |
| user_id=state.get("user_id", "default_user"), | |
| query=state["query"], | |
| answer=state["answer"], | |
| critique=state.get("self_rag", {}), | |
| sources=state.get("reranked_sources", []), | |
| ) | |
| self._trace( | |
| state, | |
| "finalize", | |
| { | |
| "cache_hit": state.get("cache_hit", False), | |
| "confidence": state.get("confidence", 0.0), | |
| "iterations": state.get("iteration", 0), | |
| }, | |
| ) | |
| return state | |
| def _has_unsupported_citation(self, state: RagState) -> bool: | |
| answer = state.get("answer", "").lower() | |
| blocked_markers = [ | |
| "insurance information institute", | |
| "source: none", | |
| "source: insurance", | |
| "according to standard insurance terminology", | |
| ] | |
| return any(marker in answer for marker in blocked_markers) | |
| def _route_cache(self, state: RagState) -> str: | |
| return "hit" if state.get("cache_hit") else "miss" | |
| def _route_retrieval(self, state: RagState) -> str: | |
| if state.get("answer") and "override system" in state["answer"]: | |
| return "direct" | |
| if state.get("intent") == "out_of_domain": | |
| return "out_of_domain" | |
| return "retrieve" if state.get("should_retrieve", True) else "direct" | |
| def _route_critique(self, state: RagState) -> str: | |
| critique = state.get("self_rag", {}) | |
| passed = bool(critique.get("passed", False)) | |
| confidence = float(critique.get("confidence", 0.0)) | |
| iteration = int(state.get("iteration", 0)) | |
| if passed and confidence >= 0.68: | |
| return "accept" | |
| if iteration >= settings.self_rag_max_loops: | |
| return "accept" | |
| if critique.get("needs_rewrite", True): | |
| return "retry" | |
| return "accept" | |
| def _trace(self, state: RagState, node: str, payload: dict[str, Any]) -> None: | |
| trace = state.setdefault("trace", []) | |
| trace.append({"node": node, **payload}) | |
| def _save_trace(self, state: RagState) -> None: | |
| with db() as conn: | |
| conn.execute( | |
| """ | |
| INSERT OR REPLACE INTO traces(request_id, query, trace_json) | |
| VALUES (?, ?, ?) | |
| """, | |
| ( | |
| state["request_id"], | |
| state["query"], | |
| json.dumps(state.get("trace", []), ensure_ascii=True), | |
| ), | |
| ) | |