Spaces:
Sleeping
Sleeping
| """Module D: Attribution & Grounding — verifies that RAG answers are grounded in the document.""" | |
| from __future__ import annotations | |
| import json | |
| import numpy as np | |
| import anthropic | |
| from kvl.ingestor import Document | |
# Prompt template for the LLM grounding judge used by evaluate().
# It is filled with str.format(context=..., answer=...); the doubled braces
# around the JSON schema are format escapes that render as literal "{" / "}".
# The JSON keys requested here (grounding_fraction, hallucination_detected,
# grounded_claims, ungrounded_claims, reason) are the keys evaluate() reads
# from the parsed reply, so keep the two in sync when editing either side.
# NOTE(review): document and answer text are interpolated verbatim, so any
# instructions embedded in them reach the judge model unchanged — consider
# delimiting them if the inputs can be untrusted.
| _GROUNDING_PROMPT = """Analyze whether an AI-generated answer is properly grounded in the provided source document. | |
| Source document excerpt: | |
| {context} | |
| AI-generated answer: | |
| {answer} | |
| Evaluate grounding quality: | |
| 1. What fraction of specific claims in the answer can be traced to the source document? (0.0-1.0) | |
| 2. Does the answer avoid hallucinating facts not in the document? | |
| 3. Are the answer's assertions supported by evidence in the document? | |
| Return ONLY JSON: | |
| {{ | |
| "grounding_fraction": <float 0-1>, | |
| "hallucination_detected": <bool>, | |
| "grounded_claims": ["list of claims that are in the document"], | |
| "ungrounded_claims": ["list of claims NOT found in the document"], | |
| "reason": "one sentence summary" | |
| }}""" | |
| def _call_claude(client: anthropic.Anthropic, prompt: str) -> str: | |
| msg = client.messages.create( | |
| model="claude-sonnet-4-6", | |
| max_tokens=1024, | |
| messages=[{"role": "user", "content": prompt}], | |
| system="You are an expert fact-checker assessing source attribution in AI-generated text.", | |
| ) | |
| return msg.content[0].text.strip() | |
| def _semantic_overlap(answer: str, context: str, embedder) -> float: | |
| """Cosine similarity between answer and document context as a grounding signal.""" | |
| embs = embedder.encode([answer, context], normalize_embeddings=True, show_progress_bar=False) | |
| return float(np.dot(embs[0], embs[1])) | |
def _strip_code_fence(raw: str) -> str:
    """Return *raw* without a surrounding Markdown code fence (e.g. ```json ... ```)."""
    text = raw.strip()
    if text.startswith("```"):
        # Drop the opening fence line (it may carry a language tag such as
        # ``json``), then everything from the last closing fence onwards.
        _, _, text = text.partition("\n")
        text = text.rsplit("```", 1)[0]
    return text.strip()


def _parse_judgment(raw: str) -> dict | None:
    """Decode the judge reply into a dict, or return ``None`` if no JSON object can be recovered.

    The whole fence-stripped reply is tried first. If that fails, the span from
    the first ``{`` to the last ``}`` is tried, which recovers replies where the
    model wrapped the JSON object in prose.
    """
    text = _strip_code_fence(raw)
    candidates = [text]
    start, end = text.find("{"), text.rfind("}")
    if 0 <= start < end:
        candidates.append(text[start : end + 1])
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        # A JSON array/string/number is valid JSON but not a usable judgment.
        if isinstance(parsed, dict):
            return parsed
    return None


def _as_fraction(value, default: float = 0.5) -> float:
    """Coerce a judge-reported fraction to a finite float clamped to [0, 1].

    Numeric strings such as ``"0.8"`` are accepted. ``None``, non-numeric
    values and NaN/infinity (which ``json.loads`` accepts by default) fall back
    to *default*.
    """
    import math

    try:
        fraction = float(value)
    except (TypeError, ValueError):
        return default
    if not math.isfinite(fraction):
        return default
    return min(1.0, max(0.0, fraction))


def _as_bool(value) -> bool:
    """Interpret a judge-reported flag so that the string ``"false"`` does not count as true."""
    if isinstance(value, str):
        return value.strip().lower() in {"true", "yes", "1"}
    return bool(value)


def _as_claims(value) -> list:
    """Normalise a judge-reported claim list: lists pass through, a lone string is wrapped, anything else becomes ``[]``."""
    if isinstance(value, list):
        return value
    if isinstance(value, str) and value.strip():
        return [value]
    return []


def evaluate(client: anthropic.Anthropic, doc: Document, generation_results: dict, embedder, progress_cb=None, max_workers: int = 6) -> dict:
    """Return a grounding score (0-100) using outputs from the generation module.

    Each entry of ``generation_results["details"]`` with a non-empty
    ``rag_answer`` is judged by Claude against the first 4000 whitespace-separated
    words of ``doc.raw``. It is also compared with that context via
    ``_semantic_overlap``. Per answer::

        combined = clamp(0.7 * grounding_fraction + 0.3 * semantic_similarity - penalty, 0, 1)

    where ``penalty`` is 0.2 when the judge reports a hallucination. The module
    score is the mean ``combined`` value scaled to 0-100.

    Args:
        client: Anthropic client used for the judge calls.
        doc: Source document; only ``doc.raw`` is read.
        generation_results: Generation-module output; only its ``"details"``
            list is read, each item accessed with ``.get("question")`` /
            ``.get("rag_answer")``.
        embedder: Passed through to ``_semantic_overlap``.
        progress_cb: Optional callable receiving a single status string.
        max_workers: Thread-pool size for the parallel judge calls. It is capped
            at the number of items and raised to at least 1; ``None`` uses the
            executor's default.

    Returns:
        Dict with ``"score"`` (int 0-100; 50 when nothing could be assessed),
        ``"details"`` (one record per assessed answer) and ``"summary"`` (str).

    Raises:
        Any exception from the judge call or the embedder propagates out of
        ``pool.map`` when its result is collected.
    """
    from concurrent.futures import ThreadPoolExecutor

    details_list = generation_results.get("details", [])
    if not details_list:
        return {"score": 50, "details": [], "summary": "No generation results to assess grounding."}

    # Bound the prompt size: only the first 4000 words of the document are shown to the judge.
    context = " ".join(doc.raw.split()[:4000])

    def _check_grounding(item):
        """Judge one generation record; return its detail dict, or None when it has no answer."""
        rag_answer = item.get("rag_answer", "")
        if not rag_answer:
            return None
        raw = _call_claude(client, _GROUNDING_PROMPT.format(context=context, answer=rag_answer))
        judgment = _parse_judgment(raw)
        if judgment is None:
            judgment = {"grounding_fraction": 0.5, "hallucination_detected": False, "reason": "Parse error."}
        # Normalise model-reported fields before doing arithmetic on them.
        llm_grounding = _as_fraction(judgment.get("grounding_fraction", 0.5))
        hallucinated = _as_bool(judgment.get("hallucination_detected", False))
        # NOTE(review): embedding models commonly truncate long inputs, so this
        # may effectively compare against only the start of the context — confirm.
        semantic_sim = _semantic_overlap(rag_answer, context, embedder)
        hallucination_penalty = 0.2 if hallucinated else 0.0
        combined = max(0.0, min(1.0, (0.7 * llm_grounding + 0.3 * semantic_sim) - hallucination_penalty))
        return {
            "question": item.get("question", ""),
            "answer": rag_answer,
            "grounding_fraction": llm_grounding,
            "semantic_similarity": round(semantic_sim, 3),
            "hallucination_detected": hallucinated,
            "grounded_claims": _as_claims(judgment.get("grounded_claims", [])),
            "ungrounded_claims": _as_claims(judgment.get("ungrounded_claims", [])),
            "reason": str(judgment.get("reason", "")),
            "combined_score": round(combined, 3),
        }

    if progress_cb:
        progress_cb(f"Checking grounding for {len(details_list)} answers in parallel...")
    # ThreadPoolExecutor rejects max_workers <= 0, and threads beyond the item count would sit idle.
    workers = None if max_workers is None else max(1, min(max_workers, len(details_list)))
    with ThreadPoolExecutor(max_workers=workers) as pool:
        raw_results = list(pool.map(_check_grounding, details_list))

    results = [r for r in raw_results if r is not None]
    grounding_scores = [r["combined_score"] for r in results]
    if not grounding_scores:
        return {"score": 50, "details": results, "summary": "No grounding assessments completed."}
    avg_grounding = sum(grounding_scores) / len(grounding_scores)
    score = round(avg_grounding * 100)
    hallucinations = sum(1 for r in results if r.get("hallucination_detected"))
    return {
        "score": score,
        "details": results,
        "summary": f"Average grounding: {score}/100. Hallucinations detected in {hallucinations}/{len(results)} answers.",
    }