ai_exec / src /evaluation /factual_accuracy.py
Chaitanya-aitf's picture
Upload 38 files
45ee481 verified
"""
Factual Accuracy Module
Verify that generated responses align with CEO's documented positions.
Cross-references claims against source blog content.
Example usage:
checker = FactualAccuracyChecker.from_blogs("data/processed/posts.json")
result = checker.check_response("Generated response with claims...")
"""
import json
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
from loguru import logger
try:
from sentence_transformers import SentenceTransformer, util
import numpy as np
SENTENCE_TRANSFORMERS_AVAILABLE = True
except ImportError:
SENTENCE_TRANSFORMERS_AVAILABLE = False
logger.warning("sentence-transformers not available")
@dataclass
class FactualCheckResult:
    """Outcome of checking one generated response against the source corpus."""

    # Fraction of claims counted as verified, in the range 0-1.
    accuracy_score: float
    # One entry per claim that was accepted as verified.
    verified_claims: list = field(default_factory=list)
    # One entry per claim that could not be verified.
    unverified_claims: list = field(default_factory=list)
    # One entry per claim flagged as a possible hallucination.
    potential_hallucinations: list = field(default_factory=list)
    # Source excerpts backing the verified claims.
    source_citations: list = field(default_factory=list)

    def to_dict(self) -> dict:
        """
        Build a compact summary dictionary.

        The score is rounded to four decimals, each claim list is reported
        both as a count and as a truncated sample (10 verified, 10 unverified,
        5 potential hallucinations). ``source_citations`` is not included.
        """
        counts = {
            "accuracy_score": round(self.accuracy_score, 4),
            "num_verified": len(self.verified_claims),
            "num_unverified": len(self.unverified_claims),
            "num_potential_hallucinations": len(self.potential_hallucinations),
        }
        samples = {
            "verified_claims": self.verified_claims[:10],
            "unverified_claims": self.unverified_claims[:10],
            "potential_hallucinations": self.potential_hallucinations[:5],
        }
        # Counts first, samples second: same key order as a single literal.
        return {**counts, **samples}

    def passes_threshold(self, threshold: float = 0.95) -> bool:
        """Return True when the accuracy score is at least ``threshold``."""
        score = self.accuracy_score
        return score >= threshold
class FactualAccuracyChecker:
"""
Check factual accuracy of generated responses.
Compares claims in generated text against source blog content
to identify potential hallucinations or misrepresentations.
Example:
>>> checker = FactualAccuracyChecker.from_blogs("posts.json")
>>> result = checker.check_response("Response with claims...")
>>> print(f"Accuracy: {result.accuracy_score}")
"""
def __init__(
self,
source_texts: list[dict],
embedding_model: Optional[str] = "all-MiniLM-L6-v2",
similarity_threshold: float = 0.7,
):
"""
Initialize the checker.
Args:
source_texts: List of dicts with 'title' and 'content'
embedding_model: Sentence transformer model
similarity_threshold: Threshold for considering a claim verified
"""
self.source_texts = source_texts
self.similarity_threshold = similarity_threshold
# Build source corpus
self.source_corpus = []
for doc in source_texts:
content = doc.get("content", "")
# Split into paragraphs for finer-grained matching
paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
self.source_corpus.extend(paragraphs)
# Load embedding model and encode corpus
self.embedding_model = None
self.corpus_embeddings = None
if SENTENCE_TRANSFORMERS_AVAILABLE and embedding_model:
try:
self.embedding_model = SentenceTransformer(embedding_model)
logger.info(f"Encoding {len(self.source_corpus)} source paragraphs...")
self.corpus_embeddings = self.embedding_model.encode(
self.source_corpus,
convert_to_tensor=True,
show_progress_bar=False,
)
logger.info("Source corpus encoded")
except Exception as e:
logger.warning(f"Failed to load embedding model: {e}")
@classmethod
def from_blogs(
cls,
posts_path: str | Path,
embedding_model: str = "all-MiniLM-L6-v2",
similarity_threshold: float = 0.7,
) -> "FactualAccuracyChecker":
"""
Create checker from parsed blog posts.
Args:
posts_path: Path to posts.json
embedding_model: Sentence transformer model
similarity_threshold: Verification threshold
Returns:
FactualAccuracyChecker instance
"""
with open(posts_path, "r", encoding="utf-8") as f:
posts = json.load(f)
return cls(posts, embedding_model, similarity_threshold)
def check_response(
self,
response: str,
extract_claims: bool = True,
) -> FactualCheckResult:
"""
Check factual accuracy of a generated response.
Args:
response: Generated response text
extract_claims: Whether to extract individual claims
Returns:
FactualCheckResult with accuracy metrics
"""
if extract_claims:
claims = self._extract_claims(response)
else:
# Treat each sentence as a claim
claims = self._split_sentences(response)
if not claims:
return FactualCheckResult(accuracy_score=1.0)
verified = []
unverified = []
hallucinations = []
citations = []
for claim in claims:
is_verified, similarity, source = self._verify_claim(claim)
if is_verified:
verified.append({
"claim": claim,
"similarity": similarity,
"source_excerpt": source[:200] if source else None,
})
if source:
citations.append(source[:100])
else:
# Check if it's a potential hallucination
if self._is_factual_claim(claim):
if similarity < 0.3:
hallucinations.append({
"claim": claim,
"similarity": similarity,
"reason": "No similar content in source",
})
else:
unverified.append({
"claim": claim,
"similarity": similarity,
})
else:
# Opinion or subjective statement - not hallucination
verified.append({
"claim": claim,
"similarity": similarity,
"type": "opinion",
})
# Calculate accuracy
total_factual = len(verified) + len(unverified) + len(hallucinations)
accuracy = len(verified) / total_factual if total_factual > 0 else 1.0
return FactualCheckResult(
accuracy_score=accuracy,
verified_claims=verified,
unverified_claims=unverified,
potential_hallucinations=hallucinations,
source_citations=list(set(citations)),
)
def check_batch(
self,
responses: list[str],
) -> dict:
"""
Check factual accuracy of multiple responses.
Args:
responses: List of generated responses
Returns:
Aggregate metrics
"""
results = [self.check_response(r) for r in responses]
def avg(values):
return sum(values) / len(values) if values else 0
return {
"num_responses": len(results),
"avg_accuracy": avg([r.accuracy_score for r in results]),
"total_verified": sum(len(r.verified_claims) for r in results),
"total_unverified": sum(len(r.unverified_claims) for r in results),
"total_hallucinations": sum(len(r.potential_hallucinations) for r in results),
"pass_rate_0.95": sum(1 for r in results if r.passes_threshold(0.95)) / len(results),
}
def _extract_claims(self, text: str) -> list[str]:
"""Extract factual claims from text."""
sentences = self._split_sentences(text)
claims = []
for sentence in sentences:
# Skip very short sentences
if len(sentence) < 20:
continue
# Skip questions
if sentence.strip().endswith("?"):
continue
# Look for factual indicators
factual_indicators = [
r"\b(is|are|was|were|has|have|had)\b",
r"\b(always|never|every|all|no)\b",
r"\b(percent|%|million|billion)\b",
r"\b(study|research|data|evidence)\b",
r"\b(founded|created|built|developed)\b",
]
is_likely_factual = any(
re.search(pattern, sentence, re.IGNORECASE)
for pattern in factual_indicators
)
if is_likely_factual:
claims.append(sentence)
else:
# Still include as a claim for verification
claims.append(sentence)
return claims
def _split_sentences(self, text: str) -> list[str]:
"""Split text into sentences."""
# Simple sentence splitting
sentences = re.split(r"[.!?]+", text)
return [s.strip() for s in sentences if s.strip()]
def _verify_claim(self, claim: str) -> tuple[bool, float, Optional[str]]:
"""
Verify a claim against source corpus.
Returns:
Tuple of (is_verified, similarity_score, matching_source)
"""
if not self.embedding_model or self.corpus_embeddings is None:
# Fallback to simple text matching
return self._verify_claim_text_match(claim)
try:
claim_embedding = self.embedding_model.encode(
claim, convert_to_tensor=True
)
# Calculate similarities
similarities = util.cos_sim(claim_embedding, self.corpus_embeddings)[0]
max_sim_idx = similarities.argmax().item()
max_similarity = similarities[max_sim_idx].item()
source = self.source_corpus[max_sim_idx] if max_sim_idx < len(self.source_corpus) else None
is_verified = max_similarity >= self.similarity_threshold
return is_verified, max_similarity, source
except Exception as e:
logger.warning(f"Embedding verification failed: {e}")
return self._verify_claim_text_match(claim)
def _verify_claim_text_match(self, claim: str) -> tuple[bool, float, Optional[str]]:
"""Fallback verification using text matching."""
claim_lower = claim.lower()
claim_words = set(re.findall(r"\b\w+\b", claim_lower))
best_match = 0.0
best_source = None
for source in self.source_corpus:
source_lower = source.lower()
source_words = set(re.findall(r"\b\w+\b", source_lower))
# Jaccard similarity
if claim_words and source_words:
intersection = len(claim_words & source_words)
union = len(claim_words | source_words)
similarity = intersection / union
if similarity > best_match:
best_match = similarity
best_source = source
is_verified = best_match >= self.similarity_threshold
return is_verified, best_match, best_source
def _is_factual_claim(self, claim: str) -> bool:
"""Determine if a claim is factual (vs opinion/subjective)."""
opinion_indicators = [
r"\b(i think|i believe|in my opinion|i feel)\b",
r"\b(should|could|might|may)\b",
r"\b(important|interesting|exciting|concerning)\b",
r"\b(best|worst|better|worse)\b",
]
claim_lower = claim.lower()
for pattern in opinion_indicators:
if re.search(pattern, claim_lower):
return False
return True
def main():
    """
    CLI entry point for testing factual accuracy.

    Builds a checker from ``--posts`` and then checks either the single
    ``--response`` text or every entry of the ``--responses-file`` JSON
    (a list whose items are strings or dicts with a "response" key).
    ``--response`` takes precedence when both are given.

    Returns:
        0 after a check was run, 1 when neither input option was provided.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Check factual accuracy of generated responses",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python factual_accuracy.py --posts posts.json --response "Text to check..."
python factual_accuracy.py --posts posts.json --responses-file outputs.json
""",
    )
    parser.add_argument("--posts", required=True, help="Parsed posts JSON path")
    parser.add_argument("--response", help="Single response to check")
    parser.add_argument("--responses-file", help="JSON file with responses")
    parser.add_argument("--threshold", type=float, default=0.7, help="Similarity threshold")

    args = parser.parse_args()

    # Load checker
    print(f"Loading source corpus: {args.posts}")
    checker = FactualAccuracyChecker.from_blogs(
        args.posts,
        similarity_threshold=args.threshold,
    )

    if args.response:
        # Check single response
        result = checker.check_response(args.response)

        print("\n=== Factual Accuracy Check ===")
        print(f"Accuracy score: {result.accuracy_score:.2%}")
        print(f"Verified claims: {len(result.verified_claims)}")
        print(f"Unverified claims: {len(result.unverified_claims)}")
        print(f"Potential hallucinations: {len(result.potential_hallucinations)}")

        if result.potential_hallucinations:
            print("\nPotential hallucinations:")
            for h in result.potential_hallucinations[:3]:
                print(f" - {h['claim'][:100]}...")

        print(f"\nPasses 95% threshold: {result.passes_threshold()}")

    elif args.responses_file:
        # Check batch. Read as UTF-8 explicitly, matching how the posts file
        # is read, instead of relying on the platform default encoding.
        with open(args.responses_file, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Entries may be plain strings or dicts carrying a "response" field.
        responses = [d["response"] if isinstance(d, dict) else d for d in data]
        results = checker.check_batch(responses)

        print("\n=== Batch Accuracy Check ===")
        for key, value in results.items():
            print(f"{key}: {value:.4f}" if isinstance(value, float) else f"{key}: {value}")

    else:
        print("Provide --response or --responses-file")
        return 1

    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    # SystemExit is raised directly because the bare exit() helper is
    # injected by the site module and is not guaranteed to exist
    # (e.g. under `python -S`).
    raise SystemExit(main())