knowledge-value-lab / kvl /modules /attribution.py
feedcomposer's picture
Upload folder using huggingface_hub
7cc493d verified
Raw
History Blame Contribute Delete
4.44 kB
"""Module D: Attribution & Grounding — verifies that RAG answers are grounded in the document."""
from __future__ import annotations

import json
import math

import anthropic
import numpy as np

from kvl.ingestor import Document
# Prompt template for the LLM grounding judge.
# It is filled with str.format(context=..., answer=...): `context` is the
# document excerpt and `answer` is the RAG answer under test. The literal JSON
# braces are doubled ({{ }}) so that .format emits single braces instead of
# treating them as placeholders. The requested keys (grounding_fraction,
# hallucination_detected, grounded_claims, ungrounded_claims, reason) are the
# ones evaluate() reads back from the model's reply.
_GROUNDING_PROMPT = """Analyze whether an AI-generated answer is properly grounded in the provided source document.
Source document excerpt:
{context}
AI-generated answer:
{answer}
Evaluate grounding quality:
1. What fraction of specific claims in the answer can be traced to the source document? (0.0-1.0)
2. Does the answer avoid hallucinating facts not in the document?
3. Are the answer's assertions supported by evidence in the document?
Return ONLY JSON:
{{
"grounding_fraction": <float 0-1>,
"hallucination_detected": <bool>,
"grounded_claims": ["list of claims that are in the document"],
"ungrounded_claims": ["list of claims NOT found in the document"],
"reason": "one sentence summary"
}}"""
def _call_claude(client: anthropic.Anthropic, prompt: str) -> str:
    """Send *prompt* to Claude as a single user turn and return the reply text.

    The request uses a fixed fact-checker system prompt and caps the reply at
    1024 tokens. Only the first content block of the response is read, and its
    text is returned with surrounding whitespace stripped.
    """
    request = {
        "model": "claude-sonnet-4-6",
        "max_tokens": 1024,
        "messages": [{"role": "user", "content": prompt}],
        "system": "You are an expert fact-checker assessing source attribution in AI-generated text.",
    }
    response = client.messages.create(**request)
    first_block = response.content[0]
    return first_block.text.strip()
def _semantic_overlap(answer: str, context: str, embedder) -> float:
    """Return the dot product of the embeddings of *answer* and *context*.

    The embedder is asked for normalized vectors, so assuming it honours
    ``normalize_embeddings=True``, the result is their cosine similarity. It is
    used as a secondary grounding signal.
    """
    vectors = embedder.encode([answer, context], normalize_embeddings=True, show_progress_bar=False)
    answer_vec, context_vec = vectors[0], vectors[1]
    similarity = np.dot(answer_vec, context_vec)
    return float(similarity)
def _extract_judgment(raw: str) -> dict:
    """Decode the grounding judge's reply into a dict.

    A leading Markdown code fence is stripped first. The parser then tries the
    whole text and, failing that, the span from the first ``{`` to the last
    ``}``, so a JSON object wrapped in prose is still recovered. Anything that
    does not decode to a JSON *object* (a parse failure, or a bare
    list/number/string) yields the neutral fallback judgment.
    """
    text = raw.strip()
    if text.startswith("```"):
        # Drop the opening fence line (it may carry a tag such as ```json) and the closing fence.
        text = "\n".join(text.split("\n")[1:])
        text = text.rsplit("```", 1)[0]
    candidates = [text]
    start, end = text.find("{"), text.rfind("}")
    if 0 <= start < end:
        candidates.append(text[start:end + 1])
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        # A non-dict would make every later `.get` call raise AttributeError.
        if isinstance(parsed, dict):
            return parsed
    return {"grounding_fraction": 0.5, "hallucination_detected": False, "reason": "Parse error."}


def _coerce_fraction(value, default: float = 0.5) -> float:
    """Convert the judge's ``grounding_fraction`` to a float clamped to [0, 1], or *default* if unusable."""
    try:
        fraction = float(value)
    except (TypeError, ValueError, OverflowError):
        # Covers null, non-numeric strings, lists/dicts and ints too large for a float.
        return default
    if math.isnan(fraction):
        return default
    return min(1.0, max(0.0, fraction))


def _coerce_flag(value) -> bool:
    """Interpret the judge's ``hallucination_detected``; accepts JSON booleans or strings like "true"/"false"."""
    if isinstance(value, str):
        # bool("false") is True, so strings must be compared explicitly.
        return value.strip().lower() in {"true", "yes", "1"}
    return bool(value)


def evaluate(client: anthropic.Anthropic, doc: Document, generation_results: dict, embedder, progress_cb=None, max_workers: int = 6) -> dict:
    """Return grounding score (0-100) using outputs from the generation module.

    Claude judges each entry of ``generation_results["details"]`` that has a
    non-empty ``"rag_answer"`` against the first 4000 whitespace-separated words
    of ``doc.raw``. Each answer's score combines two signals:

    * The judge's grounding fraction (weight 0.7).
    * The embedding similarity between the answer and that excerpt (weight 0.3).

    If the judge flags a hallucination, 0.2 is subtracted, and the result is
    clamped to [0, 1]. The overall score is the mean of the per-answer values
    times 100, rounded.

    Args:
        client: Anthropic client used for the judge calls.
        doc: Source document; only its ``raw`` text is read.
        generation_results: Output of the generation module. Its ``"details"``
            list of dicts is read (keys ``"rag_answer"`` and ``"question"``).
        embedder: Object exposing
            ``encode(texts, normalize_embeddings=..., show_progress_bar=...)``;
            it is forwarded to ``_semantic_overlap``.
        progress_cb: Optional callable that receives one status string before
            judging starts.
        max_workers: Thread-pool size for the concurrent judge calls.

    Returns:
        ``{"score": int, "details": list[dict], "summary": str}``. ``score`` is
        50 when there is nothing to assess.

    Raises:
        Any exception raised by the client or embedder inside a worker thread
        is re-raised here when the results are collected.
    """
    from concurrent.futures import ThreadPoolExecutor

    details_list = generation_results.get("details", [])
    if not details_list:
        return {"score": 50, "details": [], "summary": "No generation results to assess grounding."}

    # Cap the excerpt at 4000 whitespace-separated words; the same excerpt feeds
    # both the LLM judge and the embedding similarity.
    context = " ".join(doc.raw.split()[:4000])

    def _check_grounding(item):
        """Judge one generation entry; return None when it has no RAG answer."""
        rag_answer = item.get("rag_answer", "")
        if not rag_answer:
            return None
        raw = _call_claude(client, _GROUNDING_PROMPT.format(context=context, answer=rag_answer))
        judgment = _extract_judgment(raw)
        # The model may return null, a string or an out-of-range number for the
        # fraction, or a quoted "false" for the flag. Unchecked, these would
        # crash the arithmetic below or wrongly apply the penalty.
        llm_grounding = _coerce_fraction(judgment.get("grounding_fraction", 0.5))
        hallucinated = _coerce_flag(judgment.get("hallucination_detected", False))
        semantic_sim = _semantic_overlap(rag_answer, context, embedder)
        hallucination_penalty = 0.2 if hallucinated else 0.0
        combined = max(0.0, min(1.0, (0.7 * llm_grounding + 0.3 * semantic_sim) - hallucination_penalty))
        return {
            "question": item.get("question", ""),
            "answer": rag_answer,
            "grounding_fraction": llm_grounding,
            "semantic_similarity": round(semantic_sim, 3),
            "hallucination_detected": hallucinated,
            "grounded_claims": judgment.get("grounded_claims", []),
            "ungrounded_claims": judgment.get("ungrounded_claims", []),
            "reason": judgment.get("reason", ""),
            "combined_score": round(combined, 3),
        }

    if progress_cb:
        progress_cb(f"Checking grounding for {len(details_list)} answers in parallel...")
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Executor.map re-raises the first worker exception while iterating.
        raw_results = list(pool.map(_check_grounding, details_list))

    results = [r for r in raw_results if r is not None]
    grounding_scores = [r["combined_score"] for r in results]
    if not grounding_scores:
        return {"score": 50, "details": results, "summary": "No grounding assessments completed."}

    avg_grounding = sum(grounding_scores) / len(grounding_scores)
    score = round(avg_grounding * 100)
    hallucinations = sum(1 for r in results if r.get("hallucination_detected"))
    return {
        "score": score,
        "details": results,
        "summary": f"Average grounding: {score}/100. Hallucinations detected in {hallucinations}/{len(results)} answers.",
    }