# Hugging Face Hub upload header (scrape residue, converted to a comment so
# the file stays valid Python): uploaded by SamSankar — "Upload grader.py",
# commit afe864b (verified).
"""Research-grade hallucination detection and grading system v4.0.
Upgrades in v4.0:
- NLI model: nli-deberta-v3-small (memory-efficient for HF Spaces)
- ROUGE-1/2/L added (Lin 2004)
- BERTScore added via DeBERTa-v3-base (Zhang et al. 2020)
- Alignment score via NLI CrossEncoder (replaces AlignScore — no separate model needed)
- Reward expanded from 6 to 9 components
"""
import re
import math
import logging
from typing import Tuple, Optional, Dict, Any, List, Set
from difflib import SequenceMatcher
from dataclasses import dataclass, field
from enum import Enum
import hashlib
logger = logging.getLogger(__name__)
class HallucinationSeverity(Enum):
    """Severity levels for hallucinations.

    Ordinal scale assigned by detect_hallucination_advanced() from the
    aggregate hallucination score; higher means more damaging.
    """
    NONE = 0      # score <= 0.1: no meaningful hallucination
    MINOR = 1     # score in (0.1, 0.3)
    MODERATE = 2  # score in [0.3, 0.5)
    SEVERE = 3    # score in [0.5, 0.7)
    CRITICAL = 4  # score >= 0.7
class HallucinationType(Enum):
    """Types of hallucinations.

    String-valued categories. Within this module only NONE,
    NUMERICAL_FABRICATION, ENTITY_CONFUSION, FABRICATED_FACT and
    OVERCONFIDENT_WRONG are assigned by detect_hallucination_advanced();
    the remaining members are presumably consumed by callers elsewhere —
    TODO confirm against the rest of the project.
    """
    NONE = "none"
    FABRICATED_FACT = "fabricated_fact"
    FALSE_CITATION = "false_citation"
    OVERCONFIDENT_WRONG = "overconfident_wrong"
    CONTEXT_DRIFT = "context_drift"
    NUMERICAL_FABRICATION = "numerical_fabrication"
    ENTITY_CONFUSION = "entity_confusion"
    TEMPORAL_ERROR = "temporal_error"
    RELATIONSHIP_ERROR = "relationship_error"
# ── Embedding model (sentence-transformers) ───────────────────────────────────
# Lazy-loaded singleton; stays None until _get_embedder() succeeds.
_embedder = None
# True only after a successful load.
_embedder_available = False


def _get_embedder():
    """Lazily load and cache the all-MiniLM-L6-v2 sentence embedder.

    Returns:
        The cached SentenceTransformer on success, or None when
        sentence-transformers is not installed / fails to load; callers
        then fall back to SequenceMatcher-based similarity.

    NOTE(review): only success is cached — a failed load is re-attempted
    (and re-logged) on every call. Presumably intentional so a transient
    failure can recover; confirm.
    """
    global _embedder, _embedder_available
    if _embedder is not None:
        return _embedder
    try:
        from sentence_transformers import SentenceTransformer
        _embedder = SentenceTransformer("all-MiniLM-L6-v2")
        _embedder_available = True
        logger.info("sentence-transformers loaded: all-MiniLM-L6-v2")
    except Exception as e:
        logger.warning(f"sentence-transformers not available ({e}); using SequenceMatcher fallback")
        _embedder = None
        _embedder_available = False
    return _embedder
def _cosine_similarity(a, b) -> float:
try:
import numpy as np
a, b = np.array(a), np.array(b)
denom = (np.linalg.norm(a) * np.linalg.norm(b))
return float(np.dot(a, b) / denom) if denom > 0 else 0.0
except Exception:
# Fallback: manual cosine similarity without numpy
dot_product = sum(x * y for x, y in zip(a, b))
norm_a = sum(x * x for x in a) ** 0.5
norm_b = sum(x * x for x in b) ** 0.5
return dot_product / (norm_a * norm_b) if norm_a * norm_b > 0 else 0.0
# ── NLI cross-encoder ─────────────────────────────────────────────────────────
# Lazy-loaded singleton; stays None until _get_nli() succeeds.
_nli_model = None
# True only after a successful load.
_nli_available = False


def _get_nli():
    """Lazily load and cache the NLI cross-encoder.

    Defaults to cross-encoder/nli-deberta-v3-small (~280 MB); setting the
    environment variable USE_LARGE_NLI=true switches to the large variant
    (~1.5 GB) for local runs.

    Returns:
        The cached CrossEncoder, or None when sentence-transformers is
        missing or loading fails; callers then use heuristic fallbacks.
    """
    global _nli_model, _nli_available
    if _nli_model is not None:
        return _nli_model
    try:
        import os
        from sentence_transformers import CrossEncoder
        # nli-deberta-v3-large (~1.5 GB) causes HF Spaces OOM/restart loops.
        # Default to small (~280 MB). Set USE_LARGE_NLI=true locally if needed.
        _use_large = os.getenv("USE_LARGE_NLI", "false").lower() == "true"
        _model_name = (
            "cross-encoder/nli-deberta-v3-large" if _use_large
            else "cross-encoder/nli-deberta-v3-small"
        )
        _nli_model = CrossEncoder(_model_name)
        _nli_available = True
        logger.info(f"NLI cross-encoder loaded: {_model_name}")
    except Exception as e:
        logger.warning(f"NLI cross-encoder not available ({e}); using heuristic fallback")
        _nli_model = None
        _nli_available = False
    return _nli_model
# ── ROUGE scorer ──────────────────────────────────────────────────────────────
# Lazy singleton populated by _get_rouge(); None until the first successful load.
_rouge_scorer = None
# ── Nemotron / reasoning model support ────────────────────────────────────────
import re as _re
def _strip_thinking(text: str) -> str:
"""
Strip reasoning traces from Nemotron 3 Super and other chain-of-thought
models before grading the actual answer.
Handles: <think>, <reasoning>, and similar tags.
"""
if not text:
return text
text = _re.sub(r"<think>.*?</think>", "", text, flags=_re.DOTALL | _re.IGNORECASE)
text = _re.sub(r"<reasoning>.*?</reasoning>", "", text, flags=_re.DOTALL | _re.IGNORECASE)
text = _re.sub(r"<reflection>.*?</reflection>", "", text, flags=_re.DOTALL | _re.IGNORECASE)
text = text.strip()
return text.strip()
def _get_rouge():
    """Lazily load and cache the rouge-score RougeScorer (ROUGE-1/2/L, stemmed).

    Returns:
        The cached RougeScorer, or None when the rouge-score package is
        not installed — compute_rouge() then reports all-zero scores.

    NOTE(review): only success is cached; a failed load is retried on
    every call.
    """
    global _rouge_scorer
    if _rouge_scorer is not None:
        return _rouge_scorer
    try:
        from rouge_score import rouge_scorer as rs
        _rouge_scorer = rs.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
        logger.info("ROUGE scorer loaded")
    except Exception as e:
        logger.warning(f"rouge-score not available ({e})")
        _rouge_scorer = None
    return _rouge_scorer
def compute_rouge(hypothesis: str, reference: str) -> Dict[str, float]:
    """Compute ROUGE-1, ROUGE-2, ROUGE-L F1 scores.

    Returns all zeros when the scorer is unavailable, either input is
    empty, or scoring raises.
    """
    scorer = _get_rouge()
    if scorer is None or not hypothesis or not reference:
        return {"rouge1": 0.0, "rouge2": 0.0, "rougeL": 0.0}
    try:
        result = scorer.score(reference, hypothesis)
        return {
            key: round(result[key].fmeasure, 4)
            for key in ("rouge1", "rouge2", "rougeL")
        }
    except Exception as e:
        logger.warning(f"ROUGE computation failed: {e}")
        return {"rouge1": 0.0, "rouge2": 0.0, "rougeL": 0.0}
# ── BERTScore ─────────────────────────────────────────────────────────────────
# Tri-state availability flag: None = not yet probed, True/False = probe result.
_bertscore_available = None
# Lazy BERTScorer singleton; see _get_bert_scorer().
_bert_scorer = None


def _check_bertscore():
    """Probe (once) whether the bert-score package can be imported.

    The result is cached in _bertscore_available, so the import is
    attempted at most once per process; compute_bertscore() may later
    flip the flag to False if scoring crashes at runtime.
    """
    global _bertscore_available
    if _bertscore_available is not None:
        return _bertscore_available
    try:
        import bert_score  # noqa: F401
        _bertscore_available = True
        logger.info("BERTScore available")
    except Exception:
        _bertscore_available = False
    return _bertscore_available
def _get_bert_scorer():
    """Lazy singleton BERTScorer — loads roberta-base once and reuses it.

    Silences transformers logging and pins the model to CPU.

    Returns:
        The cached BERTScorer, or None when initialization fails.

    NOTE(review): only success is cached; a failed init is retried on
    every call.
    """
    global _bert_scorer
    if _bert_scorer is not None:
        return _bert_scorer
    try:
        import transformers
        transformers.logging.set_verbosity_error()
        from bert_score import BERTScorer
        _bert_scorer = BERTScorer(
            model_type="roberta-base",
            lang="en",
            device="cpu",
        )
        logger.info("BERTScorer (roberta-base) cached as singleton")
    except Exception as e:
        logger.warning(f"BERTScorer init failed: {e}")
        _bert_scorer = None
    return _bert_scorer
def compute_bertscore(hypothesis: str, reference: str) -> Dict[str, float]:
    """Compute BERTScore P/R/F1 using roberta-base.

    Gracefully returns zeros if bert-score is unavailable or crashes
    (e.g. due to incompatibilities with newer transformers versions).
    """
    global _bertscore_available
    if not _check_bertscore():
        return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
    # Guard against None / non-string inputs that cause internal crashes.
    well_formed = (
        bool(hypothesis)
        and bool(reference)
        and isinstance(hypothesis, str)
        and isinstance(reference, str)
    )
    if not well_formed:
        return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
    try:
        scorer = _get_bert_scorer()
        if scorer is None:
            return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
        precision, recall, f1 = scorer.score(
            [str(hypothesis)], [str(reference)],
            verbose=False,
        )
        return {
            "precision": round(float(precision[0]), 4),
            "recall": round(float(recall[0]), 4),
            "f1": round(float(f1[0]), 4),
        }
    except Exception as e:
        logger.debug(f"BERTScore failed: {e}")
        # Mark as unavailable so future calls skip the import entirely
        _bertscore_available = False
        return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
# ── Alignment Score (NLI-based, replaces AlignScore) ──────────────────────────
# AlignScore needs a separate ~1.3 GB model plus a manual checkpoint download
# from GitHub. Instead we reuse the already-loaded NLI CrossEncoder: its
# (contradiction, entailment, neutral) probabilities for context→answer map
# directly onto a faithfulness score — entailment pushes the score up,
# contradiction pushes it down — at zero extra memory cost.
def compute_alignscore(context: str, answer: str) -> float:
    """Alignment/faithfulness score derived from the NLI CrossEncoder.

    Measures how well the answer is entailed by the context using the
    nli-deberta-v3-small model already loaded by _get_nli(), giving a
    signal equivalent to AlignScore (Zha et al., ACL 2023) without a
    separate model.

    Returns a score in [0, 1]:
        1.0 — answer fully entailed by context (faithful)
        0.5 — neutral / no clear entailment (also the fallback when the
              NLI model is unavailable or inputs are empty)
        0.0 — answer contradicts context
    """
    if not context or not answer:
        return 0.5
    nli = _get_nli()
    if nli is None:
        return 0.5
    try:
        # NLI label order: [contradiction, entailment, neutral]
        raw = nli.predict([(context, answer)], apply_softmax=True)
        if hasattr(raw, 'tolist'):
            raw = raw.tolist()
        if not (isinstance(raw, list) and len(raw) > 0):
            return 0.5
        first = raw[0]
        if not (isinstance(first, list) and len(first) == 3):
            return 0.5
        contradiction, entailment, _neutral = first
        # Map to alignment: entailment → 1.0, neutral → 0.5, contradiction → 0.0
        return float(max(0.0, min(1.0, entailment - contradiction + 0.5)))
    except Exception as e:
        logger.debug(f"NLI alignment score failed: {e}")
        return 0.5
def normalize_text(text: str, preserve_numbers: bool = False) -> str:
    """Normalize text for comparison with advanced preprocessing.

    Lowercases, collapses whitespace, replaces punctuation (except periods)
    with spaces, and — unless preserve_numbers is True — maps every number
    to the literal token NUM so that differing digits compare equal.
    """
    if not text:
        return ""
    # Case-fold, then collapse runs of whitespace.
    normalized = " ".join(text.lower().strip().split())
    # Replace punctuation except word chars, whitespace and periods.
    normalized = re.sub(r'[^\w\s\.]', ' ', normalized)
    if not preserve_numbers:
        # Decimals first so "3.14" collapses to one NUM, not "NUM.NUM".
        normalized = re.sub(r'\b\d+\.\d+\b', 'NUM', normalized)
        normalized = re.sub(r'\b\d+\b', 'NUM', normalized)
    return normalized.strip()
def extract_entities(text: str) -> Set[str]:
    """Extract potential entities (capitalized words, numbers, dates) from text."""
    found: Set[str] = set()
    # Capitalized words — potential proper nouns.
    found.update(re.findall(r'\b[A-Z][a-z]+\b', text))
    # Numbers, optionally followed by a recognized unit.
    found.update(re.findall(r'\b\d+(?:\.\d+)?\s*(?:%|USD|dollars|people|times)?\b', text, re.IGNORECASE))
    # Month-name dates such as "March 2020".
    found.update(re.findall(r'\b(?:January|February|March|April|May|June|July|August|September|October|November|December)\s+\d+\b', text, re.IGNORECASE))
    return found
def extract_key_claims(text: str) -> List[str]:
    """Extract key claims/statements from text for detailed analysis.

    Splits on sentence punctuation and keeps sentences with at least three
    words, at least two of which are not trivial stop words.
    """
    if not text:
        return []
    stop_words = {'the', 'a', 'an', 'is', 'are', 'was', 'were', 'it', 'this', 'that'}

    def _is_claim(sentence: str) -> bool:
        words = sentence.split()
        if len(words) < 3:
            return False
        content = [w for w in words if w.lower() not in stop_words]
        return len(content) >= 2

    return [s.strip() for s in re.split(r'[.!?]+', text) if _is_claim(s.strip())]
# ── Edge Case Handling: Numerical Tolerance ────────────────────────────────────
# Spelled-out number words mapped to their numeric values.
NUMBER_WORDS = {
    "zero": 0, "one": 1, "two": 2, "three": 3, "four": 4, "five": 5,
    "six": 6, "seven": 7, "eight": 8, "nine": 9, "ten": 10,
    "eleven": 11, "twelve": 12, "thirteen": 13, "fourteen": 14, "fifteen": 15,
    "twenty": 20, "thirty": 30, "forty": 40, "fifty": 50,
    "hundred": 100, "thousand": 1000, "million": 1000000,
    "half": 0.5, "quarter": 0.25, "third": 0.333,
}
# Phrases signalling an approximate / hedged quantity (also used by
# detect_hedging).
APPROXIMATION_WORDS = {"approximately", "about", "around", "roughly", "nearly", "almost", "close to", "approx."}


def normalize_numbers(text: str) -> Set[float]:
    """
    Extract and normalize all numbers from text, handling:
    - Digits: "50" -> 50.0
    - Words: "fifty" -> 50.0
    - Percentages: "50%" -> 0.5
    - Fractions: "1/2" -> 0.5
    - Units: "50 dollars" -> 50.0
    Returns set of normalized float values.
    """
    numbers: Set[float] = set()
    # Digit-based numbers (fraction numerators/denominators and raw
    # percentage values are also collected here, as before).
    for n in re.findall(r'\d+(?:\.\d+)?', text):
        try:
            numbers.add(float(n))
        except ValueError:
            pass
    # Word-based numbers. FIX: match whole tokens, not substrings — the
    # previous `word in text_lower` check false-positived on words like
    # "stone" (contains "one") and "often" (contains "ten").
    tokens = set(re.findall(r'[a-z]+', text.lower()))
    for word, value in NUMBER_WORDS.items():
        if word in tokens:
            numbers.add(float(value))
    # Percentages, normalized to decimals.
    for p in re.findall(r'(\d+(?:\.\d+)?)\s*%', text):
        try:
            numbers.add(float(p) / 100.0)
        except ValueError:
            pass
    # Fractions. FIX: ZeroDivisionError is also caught so "1/0" in the
    # text cannot crash the grader (1.0 and 0.0 are still recorded above).
    for num, denom in re.findall(r'(\d+)\s*/\s*(\d+)', text):
        try:
            numbers.add(float(num) / float(denom))
        except (ValueError, ZeroDivisionError):
            pass
    return numbers
def numbers_approx_match(a: float, b: float, tolerance: float = 0.1) -> bool:
    """
    Check if two numbers match within relative tolerance.
    Handles cases like "approximately 50" vs "50" or "48".
    """
    if a == b:
        return True
    # Scale the difference by the larger magnitude (floor avoids div-by-zero).
    scale = max(abs(a), abs(b), 1e-10)
    return abs(a - b) / scale < tolerance


def check_numerical_match(answer_nums: Set[float], truth_nums: Set[float], tolerance: float = 0.1) -> Tuple[bool, float]:
    """
    Check if answer numbers match truth numbers with tolerance.
    Returns: (is_match, match_score)
    - is_match: True if all critical numbers match
    - match_score: 0.0-1.0 indicating match quality
    """
    if not truth_nums:
        # Nothing to verify against — treat as a perfect match.
        return True, 1.0
    if not answer_nums:
        # Numbers were expected but the answer contains none.
        return False, 0.0
    # Count ground-truth numbers that have an approximate partner.
    matched = sum(
        1 for truth_n in truth_nums
        if any(numbers_approx_match(truth_n, ans_n, tolerance) for ans_n in answer_nums)
    )
    match_ratio = matched / len(truth_nums)
    return match_ratio >= 0.8, match_ratio
def detect_hedging(text: str) -> Tuple[bool, float]:
    """
    Detect hedging language in answer.
    Returns: (has_hedging, hedging_intensity)
    - has_hedging: True if hedging detected
    - hedging_intensity: 0.0-1.0 (higher = more hedging)
    """
    lowered = text.lower()
    # Approximation phrases count fully; uncertainty modals count half.
    score = sum(1 for phrase in APPROXIMATION_WORDS if phrase in lowered)
    modal_verbs = ["might", "could", "may", "possibly", "perhaps", "seems"]
    score += 0.5 * sum(1 for modal in modal_verbs if modal in lowered)
    # Saturate: three full hits (or equivalent) is maximal hedging.
    intensity = min(1.0, score / 3.0)
    return intensity > 0, intensity
def handle_ambiguous_answer(
    answer: str,
    ground_truth: str,
    valid_alternatives: List[str] = None
) -> Tuple[float, str]:
    """
    Handle cases where multiple answers may be valid.
    Returns: (score, matched_answer)
    """
    answer_norm = normalize_text(answer)

    def _contains(candidate_norm: str) -> bool:
        # Exact normalized match, or the candidate contained in the answer.
        return answer_norm == candidate_norm or candidate_norm in answer_norm

    # Primary answer gets full credit.
    if _contains(normalize_text(ground_truth)):
        return 1.0, ground_truth
    # Accepted alternatives get near-full credit.
    for alt in (valid_alternatives or []):
        if _contains(normalize_text(alt)):
            return 0.95, alt
    # Last resort: semantic similarity against the primary answer.
    similarity = compute_string_similarity(answer, ground_truth)
    if similarity > 0.8:
        return similarity, ground_truth
    return 0.0, ""
def compute_string_similarity(s1: str, s2: str) -> float:
    """Compute semantic similarity between two strings.

    Uses sentence-transformers (all-MiniLM-L6-v2) cosine similarity when
    available, with SequenceMatcher + Jaccard as a graceful fallback.
    """
    if not s1 or not s2:
        return 0.0
    norm1 = normalize_text(s1)
    norm2 = normalize_text(s2)
    if norm1 == norm2:
        return 1.0
    if norm1 in norm2 or norm2 in norm1:
        return 0.9
    # Preferred path: embedding cosine similarity on the raw strings.
    embedder = _get_embedder()
    if embedder is not None:
        try:
            vecs = embedder.encode([s1, s2], convert_to_numpy=True, show_progress_bar=False)
            score = _cosine_similarity(vecs[0], vecs[1])
            # Clip to [0,1] — cosine can dip slightly negative for unrelated text.
            return max(0.0, min(1.0, float(score)))
        except Exception as e:
            logger.warning(f"Embedding similarity failed: {e}; using fallback")
    # Fallback: best of character-level ratio and word-level Jaccard.
    char_ratio = SequenceMatcher(None, norm1, norm2).ratio()
    words1 = set(norm1.split())
    words2 = set(norm2.split())
    if not words1 or not words2:
        return char_ratio
    jaccard = len(words1 & words2) / len(words1 | words2)
    return max(char_ratio, jaccard)
def check_quote_in_context_advanced(source_quote: str, context: str) -> Tuple[float, Dict[str, Any]]:
    """
    Advanced citation verification with detailed analysis.

    Tries, in order: an exact normalized-substring match (score 1.0), a
    sliding-window fuzzy match over the context (score > 0.7 accepted),
    and a relaxed key-phrase overlap (score 0.5-0.8).

    Returns:
        Tuple of (score, analysis_dict)
    """
    analysis = {
        "exact_match": False,
        "partial_matches": [],
        "best_match_score": 0.0,
        "match_location": None,
        "surrounding_context": "",
        "quote_length": len(source_quote),
        "context_length": len(context)
    }
    if not source_quote or not context:
        return 0.0, analysis
    normalized_quote = normalize_text(source_quote)
    normalized_context = normalize_text(context)
    # Exact match check
    if normalized_quote in normalized_context:
        analysis["exact_match"] = True
        analysis["best_match_score"] = 1.0
        analysis["match_location"] = normalized_context.find(normalized_quote)
        # Get surrounding context
        # NOTE(review): match_location indexes the NORMALIZED context but is
        # used to slice the ORIGINAL context, so the excerpt may be offset
        # when normalization changed lengths — confirm acceptable.
        start = max(0, analysis["match_location"] - 50)
        end = min(len(context), analysis["match_location"] + len(source_quote) + 50)
        analysis["surrounding_context"] = context[start:end]
        return 1.0, analysis
    # Sliding window for fuzzy matching: compare the quote against every
    # context window of the same word length, keeping the best ratio.
    quote_words = normalized_quote.split()
    context_words = normalized_context.split()
    if len(quote_words) == 0:
        return 0.0, analysis
    best_match_score = 0.0
    best_match_window = None
    best_match_idx = 0
    window_size = len(quote_words)
    for i in range(len(context_words) - window_size + 1):
        window = context_words[i:i + window_size]
        window_text = " ".join(window)
        similarity = SequenceMatcher(None, normalized_quote, window_text).ratio()
        if similarity > best_match_score:
            best_match_score = similarity
            best_match_window = window_text
            best_match_idx = i
    if best_match_score > 0.7:
        analysis["partial_matches"].append({
            "text": best_match_window,
            "score": best_match_score,
            "position": best_match_idx
        })
        analysis["best_match_score"] = best_match_score
        # Get surrounding context (char position estimated from word index,
        # assuming one space between normalized words)
        char_pos = sum(len(w) + 1 for w in context_words[:best_match_idx])
        start = max(0, char_pos - 50)
        end = min(len(context), char_pos + len(best_match_window) + 50)
        analysis["surrounding_context"] = context[start:end]
        return best_match_score, analysis
    # Try matching key phrases (relaxed matching)
    quote_key_phrases = [p for p in normalized_quote.split() if len(p) > 3]
    # Strip trailing periods from context words for matching (normalize_text preserves periods)
    context_set = set(w.rstrip('.') for w in normalized_context.split())
    if quote_key_phrases:
        phrase_match_ratio = sum(1 for p in quote_key_phrases if p in context_set) / len(quote_key_phrases)
        if phrase_match_ratio > 0.5:
            # Partial credit: 0.5 base plus up to 0.3 for phrase coverage.
            score = 0.5 + 0.3 * phrase_match_ratio
            analysis["best_match_score"] = score
            analysis["partial_matches"].append({
                "type": "key_phrase_match",
                "ratio": phrase_match_ratio
            })
            return score, analysis
    return 0.0, analysis
def check_factual_accuracy_advanced(answer: str, ground_truth: str, context: str = "") -> Tuple[float, Dict[str, Any]]:
    """
    Advanced factual accuracy checking with semantic understanding.

    Combines exact/containment matching, entity overlap, content-word
    overlap, character-level sequence similarity, and number/negation
    contradiction penalties into a single score in [0, 1].

    Args:
        answer: Candidate answer text.
        ground_truth: Reference answer text.
        context: Unused in this function; kept for signature compatibility.

    Returns:
        Tuple of (score, analysis_dict)
    """
    analysis = {
        "exact_match": False,
        "semantic_similarity": 0.0,
        "entity_overlap": 0.0,
        "key_info_present": False,
        "contradictions": [],
        "answer_length": len(answer),
        "truth_length": len(ground_truth)
    }
    if not answer or not ground_truth:
        return 0.0, analysis
    # First check for number mismatches BEFORE normalization
    # (normalize_text replaces digits with the NUM token, losing values)
    answer_nums = set(re.findall(r'\d+(?:\.\d+)?', answer.lower()))
    truth_nums = set(re.findall(r'\d+(?:\.\d+)?', ground_truth.lower()))
    # If numbers differ significantly, this is a factual error
    number_mismatch = False
    if truth_nums and answer_nums and truth_nums != answer_nums:
        number_mismatch = True
        analysis["number_mismatch"] = True
    normalized_answer = normalize_text(answer)
    normalized_truth = normalize_text(ground_truth)
    # Exact match
    if normalized_answer == normalized_truth:
        analysis["exact_match"] = True
        # But if numbers mismatch, reduce score (texts can normalize equal
        # even when the digits differ, since both become NUM)
        if number_mismatch:
            return 0.2, analysis
        return 1.0, analysis
    # Check if truth is contained in answer
    if normalized_truth in normalized_answer:
        analysis["key_info_present"] = True
        if number_mismatch:
            return 0.3, analysis
        return 0.95, analysis
    # Entity overlap analysis
    answer_entities = extract_entities(answer)
    truth_entities = extract_entities(ground_truth)
    if truth_entities:
        entity_overlap = len(answer_entities & truth_entities) / len(truth_entities)
        analysis["entity_overlap"] = entity_overlap
    else:
        analysis["entity_overlap"] = 1.0  # No entities to check
    # Word-level analysis
    truth_words = set(normalized_truth.split())
    answer_words = set(normalized_answer.split())
    # Remove stop words for content comparison
    stop_words = {'the', 'a', 'an', 'is', 'are', 'was', 'were', 'be', 'been',
                  'being', 'have', 'has', 'had', 'do', 'does', 'did', 'will',
                  'would', 'could', 'should', 'may', 'might', 'must', 'to', 'of',
                  'in', 'for', 'on', 'with', 'at', 'by', 'from', 'as', 'into'}
    truth_content = truth_words - stop_words
    answer_content = answer_words - stop_words
    if truth_content:
        content_overlap = len(answer_content & truth_content) / len(truth_content)
        analysis["key_info_present"] = content_overlap > 0.5
    else:
        content_overlap = 1.0
    # Sequence similarity
    sequence_sim = SequenceMatcher(None, normalized_answer, normalized_truth).ratio()
    # Entity overlap is critical - if different entities, penalize heavily
    entity_penalty = 1.0 if analysis["entity_overlap"] > 0.5 else 0.3
    # Check for numerical contradiction BEFORE normalization (e.g., $50,000 vs $30,000)
    # NOTE(review): this duplicates the extraction at the top of the
    # function; the values are identical.
    answer_nums = set(re.findall(r'\d+(?:\.\d+)?', answer.lower()))
    truth_nums = set(re.findall(r'\d+(?:\.\d+)?', ground_truth.lower()))
    number_mismatch = truth_nums and answer_nums and answer_nums != truth_nums
    number_penalty = 0.2 if number_mismatch else 1.0
    # Combine scores
    semantic_similarity = max(
        content_overlap * entity_penalty * number_penalty,
        sequence_sim * 0.8 * entity_penalty * number_penalty,
        analysis["entity_overlap"] * 0.7 * number_penalty
    )
    # Apply contradiction penalty for conflicting info
    if sequence_sim < 0.5 and content_overlap < 0.5:
        semantic_similarity *= 0.3  # Strong penalty for wrong answers
    analysis["semantic_similarity"] = semantic_similarity
    # Check for contradictions (simple heuristic: negation words)
    negation_words = {'not', 'no', 'never', 'none', 'neither', 'nobody', 'nothing'}
    has_negation_answer = any(w in answer_words for w in negation_words)
    has_negation_truth = any(w in truth_words for w in negation_words)
    if has_negation_answer != has_negation_truth:
        analysis["contradictions"].append("negation_mismatch")
        # NOTE(review): analysis["semantic_similarity"] was recorded above
        # and is NOT updated with this penalty, so the dict and the returned
        # score can disagree — confirm intended.
        semantic_similarity *= 0.5
    # Number mismatch already handled above
    if number_mismatch:
        analysis["number_mismatch"] = True
        semantic_similarity = min(semantic_similarity, 0.3)
    return semantic_similarity, analysis
def detect_hallucination_advanced(
    answer: str,
    context: str,
    ground_truth: str = "",
    confidence: float = 0.5
) -> Tuple[float, HallucinationType, HallucinationSeverity, Dict[str, Any]]:
    """
    Advanced hallucination detection with type classification and severity scoring.

    Aggregates several additive signals: mismatch with the ground truth
    (strongest, when provided), low coverage of the answer's content words
    by the context, entities and numbers in the answer that never appear in
    the context, and overconfidence on poorly grounded answers. The summed
    score is capped at 1.0 and mapped to a type and a severity level.

    Args:
        answer: Candidate answer text.
        context: Source passage the answer should be grounded in.
        ground_truth: Optional reference answer.
        confidence: Model's self-reported confidence in [0, 1].

    Returns:
        Tuple of (hallucination_score, hallucination_type, severity, analysis)
    """
    analysis = {
        "word_coverage": 0.0,
        "entity_hallucination": 0.0,
        "numerical_fabrication": 0.0,
        "temporal_errors": 0.0,
        "relationship_errors": 0.0,
        "confidence_mismatch": 0.0,
        "answer_truth_overlap": 0.0
    }
    if not answer:
        return 0.0, HallucinationType.NONE, HallucinationSeverity.NONE, analysis
    normalized_answer = normalize_text(answer)
    normalized_context = normalize_text(context)
    # Stop words
    stop_words = {
        'the', 'a', 'an', 'is', 'are', 'was', 'were', 'be', 'been',
        'being', 'have', 'has', 'had', 'do', 'does', 'did', 'will',
        'would', 'could', 'should', 'may', 'might', 'must', 'shall',
        'can', 'need', 'dare', 'ought', 'used', 'to', 'of', 'in',
        'for', 'on', 'with', 'at', 'by', 'from', 'as', 'into',
        'through', 'during', 'before', 'after', 'above', 'below',
        'between', 'under', 'and', 'but', 'or', 'yet', 'so',
        'if', 'because', 'although', 'though', 'while', 'where',
        'when', 'that', 'which', 'who', 'whom', 'whose', 'what',
        'this', 'these', 'those', 'i', 'you', 'he', 'she', 'it',
        'we', 'they', 'them', 'their', 'there', 'then', 'than'
    }
    # Word coverage analysis: fraction of the answer's content words that
    # also occur in the context.
    answer_words = set(normalized_answer.split())
    context_words = set(normalized_context.split())
    content_words = answer_words - stop_words
    context_content = context_words - stop_words
    if content_words:
        words_in_context = content_words & context_content
        analysis["word_coverage"] = len(words_in_context) / len(content_words)
    else:
        analysis["word_coverage"] = 1.0
    # Entity hallucination check - more aggressive detection
    answer_entities = extract_entities(answer)
    context_entities = extract_entities(context)
    novel_entities = answer_entities - context_entities
    missing_entities = context_entities - answer_entities
    if answer_entities:
        # Base entity hallucination ratio
        analysis["entity_hallucination"] = len(novel_entities) / len(answer_entities)
        # Boost if there are ANY novel entities (fabricated facts)
        if novel_entities:
            analysis["entity_hallucination"] = max(analysis["entity_hallucination"], 0.5)
        # Also penalize if context entities are missing from answer
        if missing_entities and context_entities:
            analysis["entity_hallucination"] += 0.2 * (len(missing_entities) / len(context_entities))
        analysis["entity_hallucination"] = min(1.0, analysis["entity_hallucination"])
    else:
        analysis["entity_hallucination"] = 0.0
    # Numerical fabrication check — extract numbers from ORIGINAL text
    # (normalize_text replaces numbers with NUM, so we must check before normalization)
    answer_numbers = set(re.findall(r'\d+(?:\.\d+)?', answer))
    context_numbers = set(re.findall(r'\d+(?:\.\d+)?', context))
    novel_numbers = answer_numbers - context_numbers
    if answer_numbers:
        analysis["numerical_fabrication"] = min(1.0, len(novel_numbers) / len(answer_numbers))
    else:
        analysis["numerical_fabrication"] = 0.0
    # Compute base hallucination score (additive, capped below)
    hallucination_score = 0.0
    # Check for contradiction with ground truth (most important signal)
    if ground_truth:
        truth_sim, _ = check_factual_accuracy_advanced(answer, ground_truth, "")
        analysis["answer_truth_overlap"] = truth_sim
        # Strong penalty for wrong answers
        if truth_sim < 0.5:  # Answer doesn't match ground truth
            hallucination_score += 0.4 * (1.0 - truth_sim)
        if truth_sim < 0.3:  # Answer contradicts truth
            hallucination_score += 0.3
            analysis["contradiction_with_truth"] = True
        # Additional penalty for very low similarity
        if truth_sim < 0.2:
            hallucination_score += 0.2
    # Low word coverage indicates hallucination
    if analysis["word_coverage"] < 0.5:
        hallucination_score += 0.3 * (1.0 - analysis["word_coverage"])
    # Entity hallucination is strong signal - boost the contribution
    if analysis["entity_hallucination"] > 0.3:
        hallucination_score += 0.4 * analysis["entity_hallucination"]
    elif analysis["entity_hallucination"] > 0.1:
        hallucination_score += 0.2 * analysis["entity_hallucination"]
    # Numerical fabrication
    if analysis["numerical_fabrication"] > 0:
        hallucination_score += 0.35 * analysis["numerical_fabrication"]
    # Confidence mismatch penalty (overconfident wrong answers)
    if (analysis["word_coverage"] < 0.5 or analysis["entity_hallucination"] > 0.3) and confidence > 0.7:
        analysis["confidence_mismatch"] = confidence - 0.5
        hallucination_score += 0.3 * analysis["confidence_mismatch"]
    # Cap at 1.0
    hallucination_score = min(1.0, hallucination_score)
    # Classify hallucination type (first matching rule wins)
    hallucination_type = HallucinationType.NONE
    if analysis["numerical_fabrication"] > 0.5:
        hallucination_type = HallucinationType.NUMERICAL_FABRICATION
    elif analysis["entity_hallucination"] > 0.5:
        hallucination_type = HallucinationType.ENTITY_CONFUSION
    elif analysis.get("answer_truth_overlap", 1.0) < 0.3:
        hallucination_type = HallucinationType.FABRICATED_FACT
    elif analysis["word_coverage"] < 0.3:
        hallucination_type = HallucinationType.FABRICATED_FACT
    elif analysis["confidence_mismatch"] > 0.3:
        hallucination_type = HallucinationType.OVERCONFIDENT_WRONG
    # Determine severity from the capped score (thresholds mirror
    # HallucinationSeverity's documented ranges)
    severity = HallucinationSeverity.NONE
    if hallucination_score >= 0.7:
        severity = HallucinationSeverity.CRITICAL
    elif hallucination_score >= 0.5:
        severity = HallucinationSeverity.SEVERE
    elif hallucination_score >= 0.3:
        severity = HallucinationSeverity.MODERATE
    elif hallucination_score > 0.1:
        severity = HallucinationSeverity.MINOR
    return hallucination_score, hallucination_type, severity, analysis
def compute_calibration_error(confidence: float, correctness: float) -> float:
    """
    Compute calibration error between confidence and actual correctness.

    Perfect calibration: confidence == correctness.
    Overconfidence (confidence > correctness) is the dangerous direction
    and is penalized 1.5x; underconfidence is safe but inefficient.
    """
    gap = confidence - correctness
    error = abs(gap)
    if gap > 0:
        # Overconfident: add half the gap again on top of the base error.
        error += 0.5 * gap
    return min(1.0, error)
def compute_expected_calibration_error(
    confidence_history: List[float],
    correctness_history: List[float],
    num_bins: int = 10
) -> float:
    """
    Compute Expected Calibration Error (ECE) with confidence binning.

    ECE measures how well-calibrated confidence estimates are across all
    predictions: samples are grouped into equal-width confidence bins, and
    each bin contributes its |accuracy - confidence| gap weighted by the
    fraction of samples it holds. Lower is better; perfect calibration is 0.

    Args:
        confidence_history: List of confidence scores (0-1)
        correctness_history: List of correctness scores (0-1)
        num_bins: Number of confidence bins (default 10)

    Returns:
        ECE score (0-1, lower is better)

    Reference: Guo et al., "On Calibration of Modern Neural Networks" (ICML 2017)
    """
    if not confidence_history or not correctness_history:
        return 0.0
    try:
        import numpy as np
        conf = np.array(confidence_history)
        corr = np.array(correctness_history)
        edges = np.linspace(0, 1, num_bins + 1)
        total = len(conf)
        ece = 0.0
        for i in range(num_bins):
            lo, hi = edges[i], edges[i + 1]
            if i == num_bins - 1:
                # Last bin is closed on the right so confidence 1.0 is counted.
                in_bin = (conf >= lo) & (conf <= hi)
            else:
                in_bin = (conf >= lo) & (conf < hi)
            count = in_bin.sum()
            if count == 0:
                continue
            gap = abs(corr[in_bin].mean() - conf[in_bin].mean())
            ece += (count / total) * gap
        return float(min(1.0, ece))
    except Exception:
        # numpy unavailable: fall back to the mean absolute gap.
        if len(confidence_history) == 0:
            return 0.0
        pairs = zip(confidence_history, correctness_history)
        return sum(abs(c - r) for c, r in pairs) / len(confidence_history)
def compute_semantic_consistency(answer: str, context: str, ground_truth: str) -> Tuple[float, Dict[str, Any]]:
    """
    Compute semantic consistency between answer, context, and ground truth.

    When the NLI cross-encoder is available (and both context and
    ground_truth are non-empty):
    - Scores entailment/contradiction for the (context, answer) and
      (ground_truth, answer) pairs and combines them with length-dependent
      weights; contradiction scales the result down.

    Falls back to embedding cosine similarity + claim overlap + negation
    heuristics when the cross-encoder is not installed or inference fails.

    Returns:
        Tuple of (consistency_score in [0, 1], analysis_dict)
    """
    analysis = {
        "context_answer_similarity": 0.0,
        "truth_answer_similarity": 0.0,
        "key_claim_overlap": 0.0,
        "contradiction_detected": False,
        "entailment_score": 0.0,
        "nli_used": False,
    }
    if not answer:
        return 0.0, analysis
    # ── NLI cross-encoder path ────────────────────────────────────────────
    nli = _get_nli()
    if nli is not None and context and ground_truth:
        try:
            # Pairs: (premise, hypothesis). Labels: contradiction=0, entailment=1, neutral=2
            pairs = [(context, answer), (ground_truth, answer)]
            scores = nli.predict(pairs, apply_softmax=True)
            # scores shape: (2, 3) — [contradiction, entailment, neutral]
            ctx_entail = float(scores[0][1])  # context entails answer
            ctx_contra = float(scores[0][0])  # context contradicts answer
            truth_entail = float(scores[1][1])  # ground_truth entails answer
            truth_contra = float(scores[1][0])
            analysis["entailment_score"] = (ctx_entail + truth_entail) / 2.0
            analysis["contradiction_detected"] = (ctx_contra > 0.5) or (truth_contra > 0.5)
            analysis["nli_used"] = True
            # Consistency = weighted entailment, penalised by contradiction
            # Short answers (<=10 words) get low context-entailment from NLI models
            # because a paragraph doesn't naturally "entail" a single word.
            # Weight truth-entailment more for short answers to compensate.
            answer_word_count = len(answer.split())
            if answer_word_count <= 5:
                # Short answer: trust truth-entailment primarily
                truth_weight = 0.8
                ctx_weight = 0.2
            elif answer_word_count <= 15:
                truth_weight = 0.6
                ctx_weight = 0.4
            else:
                truth_weight = 0.5
                ctx_weight = 0.5
            consistency_score = (ctx_entail * ctx_weight + truth_entail * truth_weight)
            if analysis["contradiction_detected"]:
                consistency_score *= max(0.1, 1.0 - max(ctx_contra, truth_contra))
            analysis["context_answer_similarity"] = ctx_entail
            analysis["truth_answer_similarity"] = truth_entail
            return max(0.0, min(1.0, consistency_score)), analysis
        except Exception as e:
            logger.warning(f"NLI inference failed: {e}; falling back to similarity")
    # ── Embedding similarity path ─────────────────────────────────────────
    analysis["context_answer_similarity"] = compute_string_similarity(answer, context)
    analysis["truth_answer_similarity"] = compute_string_similarity(answer, ground_truth)
    answer_claims = extract_key_claims(answer)
    context_claims = extract_key_claims(context)
    if answer_claims and context_claims:
        # Fraction of answer claims with at least one similar context claim.
        matching = sum(
            1 for ac in answer_claims
            if any(compute_string_similarity(ac, cc) > 0.6 for cc in context_claims)
        )
        analysis["key_claim_overlap"] = matching / len(answer_claims)
    # Improved negation-based contradiction detection:
    # Check for negation asymmetry around the same key noun phrase
    negation_re = re.compile(r"\b(not|no|never|none|neither|isn't|aren't|wasn't|weren't|doesn't|don't)\b")
    answer_negated = bool(negation_re.search(answer.lower()))
    truth_negated = bool(negation_re.search(ground_truth.lower()))
    # Only flag contradiction if they share vocabulary but negate differently
    shared_words = set(normalize_text(answer).split()) & set(normalize_text(ground_truth).split())
    if answer_negated != truth_negated and len(shared_words) >= 2:
        analysis["contradiction_detected"] = True
    consistency_score = (
        0.4 * analysis["key_claim_overlap"] +
        0.3 * analysis["context_answer_similarity"] +
        0.3 * analysis["truth_answer_similarity"]
    )
    if analysis["contradiction_detected"]:
        consistency_score *= 0.5
    return max(0.0, min(1.0, consistency_score)), analysis
# Refusal phrase table hoisted to module level so it is built once, not on
# every call. Matching is case-insensitive substring containment.
_REFUSAL_PHRASES: Tuple[str, ...] = (
    "i don't know",
    "i cannot answer",
    "i can't answer",
    "i am unable to answer",
    "i'm unable to answer",
    "not mentioned",
    "not provided",
    "not in the context",
    "not in context",
    "cannot be determined",
    "cannot determine",
    "i cannot determine",
    "i can't determine",
    "can't be determined",
    "insufficient information",
    "not enough information",
    "no information",
    "the context does not",
    "the document does not",
    "i cannot find",
    "i can't find",
    "not stated",
    "not specified",
    "unknown",
)
# Words that mark a very short answer as non-committal even without an
# explicit refusal phrase.
_SHORT_UNCERTAIN_WORDS: Tuple[str, ...] = ("unknown", "unclear", "uncertain", "not sure")
def is_refusal_answer(answer: str) -> Tuple[bool, float]:
    """
    Detect if the answer is a proper refusal ("I don't know" style response).
    Returns:
        Tuple of (is_refusal, confidence_score)
        - is_refusal: True if answer is a refusal
        - confidence_score: 0.5 for empty answers, 0.75 for short refusals
          (< 100 chars), 0.65 for long refusals, 0.70 for very short
          non-committal answers, 0.0 when not a refusal
    """
    if not answer:
        # Empty answer is treated as an implicit (weak) refusal.
        return True, 0.5
    answer_lower = answer.lower().strip()
    for phrase in _REFUSAL_PHRASES:
        if phrase in answer_lower:
            # A short refusal is the best-calibrated form of "I don't know";
            # a long one likely buries the refusal in extra prose.
            return True, 0.75 if len(answer_lower) < 100 else 0.65
    # Very short non-committal answers without an explicit refusal phrase
    if len(answer_lower) < 15 and any(w in answer_lower for w in _SHORT_UNCERTAIN_WORDS):
        return True, 0.70
    return False, 0.0
def calculate_reward(
    answer: str,
    confidence: float,
    source_quote: str,
    context: str,
    ground_truth: str,
    difficulty_level: str = "intermediate",
    difficulty: Optional[str] = None,
    previous_performance: float = 0.5,
    recent_rewards: Optional[list] = None,
    reward_weights: Optional[Dict[str, float]] = None
) -> Tuple[float, Dict[str, Any]]:
    """
    Calculate the comprehensive multi-factor reward for a graded answer.

    Main entry point for reward calculation, combining (default weights):
    1. Factual correctness (35%)
    2. Source grounding (20%)
    3. Citation accuracy (10%)
    4. Confidence calibration (10%)
    5. Semantic consistency (10%)
    6. Hallucination penalty (10%)
    7. ROUGE combined (2%)
    8. BERTScore F1 (2%)
    9. Alignment score (1%)
    plus a difficulty multiplier and a consistency bonus.

    Refusal answers ("I don't know" style) short-circuit the component
    pipeline: a refusal to an unanswerable question earns the refusal
    reward (plus a low-confidence bonus), while a refusal to an answerable
    question returns a flat 0.3 (underconfidence).

    Args:
        answer: The AI's answer. Any <think>...</think> reasoning blocks
            are stripped before grading.
        confidence: AI's self-reported confidence level (0-1).
        source_quote: Quote cited from context.
        context: The source document.
        ground_truth: The correct answer. Markers such as "not mentioned"
            or "unanswerable" flag the question as unanswerable.
        difficulty_level: Question difficulty ("beginner" .. "expert").
        difficulty: Alias for difficulty_level (README compatibility);
            overrides it when provided.
        previous_performance: Running performance metric in [0, 1] used for
            the consistency bonus.
        recent_rewards: Optional list of recent reward values; when
            non-empty, its mean overrides previous_performance.
        reward_weights: Optional custom component weights, keyed like the
            defaults below.

    Returns:
        Tuple of (total_reward, info_dict). total_reward is clamped to
        [0, 1]; info_dict carries every component score, per-component
        contribution, and detailed analysis dicts.
    """
    # Resolve parameter aliases for README compatibility
    if difficulty is not None:
        difficulty_level = difficulty
    if recent_rewards is not None and len(recent_rewards) > 0:
        previous_performance = sum(recent_rewards) / len(recent_rewards)
    # Strip <think> blocks from Nemotron 3 Super and other reasoning models
    # before any grading — the thinking trace is not part of the answer
    answer = _strip_thinking(answer)
    if source_quote:
        source_quote = _strip_thinking(source_quote)
    # Check if this is a refusal answer ("I don't know" style)
    is_refusal, refusal_confidence_score = is_refusal_answer(answer)
    # For adversarial questions (ground_truth indicates unanswerable), reward proper refusals
    ground_truth_lower = ground_truth.lower() if ground_truth else ""
    is_unanswerable = any(marker in ground_truth_lower for marker in [
        "not mentioned", "not in context", "unknown", "unanswerable",
        "cannot be determined", "insufficient", "no information"
    ])
    # If this is an unanswerable question and the agent properly refused,
    # short-circuit with a fixed, fully-populated info dict.
    if is_unanswerable and is_refusal:
        refusal_reward = refusal_confidence_score
        # Keeping confidence low on an unanswerable question earns a bonus.
        if confidence <= 0.5:
            refusal_reward = min(1.0, refusal_reward + 0.15)
        return refusal_reward, {
            "correctness": refusal_reward,
            "grounding": 1.0,
            "calibration": 1.0 if confidence <= 0.5 else 0.7,
            "semantic_consistency": 1.0,
            "hallucination_score": 0.0,
            "hallucination_penalty": 1.0,
            "is_hallucination": False,
            "hallucination_type": "none",
            "hallucination_severity": "NONE",
            "is_refusal": True,
            "is_unanswerable": True,
            "total_reward": refusal_reward,
            "feedback": "Properly refused to answer unanswerable question.",
            "confidence": confidence,
        }
    # If refusal but question IS answerable, it's underconfident
    if is_refusal and not is_unanswerable:
        return 0.3, {
            "correctness": 0.0,
            "grounding": 0.5,
            "calibration": 0.5 if confidence <= 0.3 else 0.3,
            "semantic_consistency": 0.5,
            "hallucination_score": 0.0,
            "hallucination_penalty": 1.0,
            "is_hallucination": False,
            "hallucination_type": "none",
            "hallucination_severity": "NONE",
            "is_refusal": True,
            "is_unanswerable": False,
            "total_reward": 0.3,
            "feedback": "Underconfident refusal — answer exists in context.",
            "confidence": confidence,
        }
    # Default weights - tuned for proper reward calibration
    # Grounded correct answers should receive 0.6+ rewards
    # Incorrect but grounded answers should receive < 0.4
    if reward_weights is None:
        reward_weights = {
            "factual_correctness": 0.35,
            "source_grounding": 0.20,
            "citation_accuracy": 0.10,
            "confidence_calibration": 0.10,
            "semantic_consistency": 0.10,
            "hallucination_penalty": 0.10,
            "rouge_score": 0.02,
            "bertscore": 0.02,
            "alignscore": 0.01,
        }
    # Component 1: Factual correctness
    correctness, correctness_analysis = check_factual_accuracy_advanced(answer, ground_truth, context)
    # Component 2 & 3: Source grounding and citation accuracy
    grounding_score, citation_analysis = check_quote_in_context_advanced(source_quote, context)
    # Component 4: Confidence calibration (1 - calibration error)
    calibration_error = compute_calibration_error(confidence, correctness)
    calibration_score = 1.0 - calibration_error
    # Component 5: Semantic consistency
    semantic_score, semantic_analysis = compute_semantic_consistency(answer, context, ground_truth)
    # Component 6: Hallucination detection
    hallucination_score, hallucination_type, hallucination_severity, hallucination_analysis = \
        detect_hallucination_advanced(answer, context, ground_truth, confidence)
    hallucination_penalty_score = 1.0 - hallucination_score
    # Component 7: ROUGE (weighted toward longest-common-subsequence overlap)
    rouge_scores = compute_rouge(answer, ground_truth)
    rouge_combined = (
        0.2 * rouge_scores["rouge1"] +
        0.3 * rouge_scores["rouge2"] +
        0.5 * rouge_scores["rougeL"]
    )
    # Component 8: BERTScore
    bs_scores = compute_bertscore(answer, ground_truth)
    bertscore_f1 = bs_scores["f1"]
    # Component 9: AlignScore
    align_score = compute_alignscore(context, answer)
    # Factual correctness gate: if answer is factually wrong, cap the reward
    # This prevents high rewards for well-grounded but incorrect answers
    # A wrong answer should still get some credit for being grounded (partial credit)
    factual_cap = min(1.0, 0.40 + 0.60 * correctness)  # Minimum 0.40 for grounded wrong answers
    # Grounding contribution is reduced for incorrect answers
    # but still gives significant credit for being grounded
    effective_grounding = grounding_score * (0.7 + 0.3 * correctness)
    # Calculate base reward: weighted sum of all nine components, with the
    # similarity-style components capped by factual_cap.
    base_reward = (
        reward_weights["factual_correctness"] * correctness +
        reward_weights["source_grounding"] * effective_grounding +
        reward_weights["citation_accuracy"] * min(citation_analysis.get("best_match_score", 0.0), factual_cap) +
        reward_weights["confidence_calibration"] * calibration_score +
        reward_weights["semantic_consistency"] * min(semantic_score, factual_cap) +
        reward_weights["hallucination_penalty"] * hallucination_penalty_score +
        reward_weights.get("rouge_score", 0.02) * min(rouge_combined, factual_cap) +
        reward_weights.get("bertscore", 0.02) * min(bertscore_f1, factual_cap) +
        reward_weights.get("alignscore", 0.01) * min(align_score, factual_cap)
    )
    # Difficulty bonus: harder questions scale the base reward up,
    # easier ones scale it down
    difficulty_multipliers = {
        "beginner": 0.9,
        "intermediate": 1.0,
        "advanced": 1.1,
        "expert": 1.2
    }
    difficulty_multiplier = difficulty_multipliers.get(difficulty_level.lower(), 1.0)
    # Consistency bonus (for maintaining good performance): scales linearly
    # from 0.0 at previous_performance 0.7 up to 0.05 at 1.0
    consistency_bonus = 0.0
    if previous_performance > 0.7:
        consistency_bonus = 0.05 * (previous_performance - 0.7) / 0.3
    # Apply adjustments
    total_reward = base_reward * difficulty_multiplier + consistency_bonus
    total_reward = max(0.0, min(1.0, total_reward))  # Clamp to [0, 1]
    # Determine if hallucination occurred
    is_hallucination = hallucination_score > 0.5
    # Build comprehensive info dict
    info = {
        # Core scores
        "correctness": correctness,
        "grounding": grounding_score,
        "calibration": calibration_score,
        "semantic_consistency": semantic_score,
        "hallucination_score": hallucination_score,
        "hallucination_penalty": hallucination_penalty_score,
        # Classification (enum members serialized to plain strings)
        "is_hallucination": is_hallucination,
        "hallucination_type": hallucination_type.value,
        "hallucination_severity": hallucination_severity.name,
        # Reward breakdown
        "total_reward": total_reward,
        "base_reward": base_reward,
        "difficulty_multiplier": difficulty_multiplier,
        "consistency_bonus": consistency_bonus,
        # Component contributions (matching actual reward calculation)
        "components": {
            "correctness_contrib": reward_weights["factual_correctness"] * correctness,
            "grounding_contrib": reward_weights["source_grounding"] * effective_grounding,
            "citation_contrib": reward_weights["citation_accuracy"] * min(citation_analysis.get("best_match_score", 0.0), factual_cap),
            "calibration_contrib": reward_weights["confidence_calibration"] * calibration_score,
            "semantic_contrib": reward_weights["semantic_consistency"] * min(semantic_score, factual_cap),
            "hallucination_contrib": reward_weights["hallucination_penalty"] * hallucination_penalty_score,
        },
        # Research-grade metrics (v4.0)
        "rouge": rouge_scores,
        "rouge_combined": round(rouge_combined, 4),
        "bertscore": bs_scores,
        "alignscore": align_score,
        # Component contributions (matching actual reward calculation with factual_cap)
        "rouge_contrib": reward_weights.get("rouge_score", 0.02) * min(rouge_combined, factual_cap),
        "bertscore_contrib": reward_weights.get("bertscore", 0.02) * min(bertscore_f1, factual_cap),
        "alignscore_contrib": reward_weights.get("alignscore", 0.01) * min(align_score, factual_cap),
        # Detailed analyses
        "correctness_analysis": correctness_analysis,
        "citation_analysis": citation_analysis,
        "semantic_analysis": semantic_analysis,
        "hallucination_analysis": hallucination_analysis,
        # Hallucination explanation (human-readable)
        "hallucination_explanation": explain_hallucination(hallucination_analysis) if is_hallucination else "",
        # Confidence info
        "confidence": confidence,
        "calibration_error": calibration_error,
    }
    return total_reward, info
def generate_feedback(
    answer: str,
    ground_truth: str,
    is_hallucination: bool,
    hallucination_type: HallucinationType,
    hallucination_severity: HallucinationSeverity,
    grounding_score: float,
    correctness: float,
    calibration_score: float,
    total_reward: float,
    hallucination_analysis: Dict[str, Any] = None
) -> str:
    """Generate detailed, actionable feedback with hallucination explanation.

    Assembles a single space-joined string from: a correctness tier, a
    grounding tier, an optional hallucination section (type, severity, and
    per-signal details from hallucination_analysis), optional calibration
    remarks, and an overall summary keyed to total_reward.

    NOTE(review): the ``answer`` and ``ground_truth`` parameters are part of
    the public signature but are not used by the current implementation.
    """
    def tiered(score: float, high: str, mid: str, low: str) -> str:
        # Three-way message selection: > 0.8 → high, > 0.5 → mid, else low.
        if score > 0.8:
            return high
        return mid if score > 0.5 else low

    parts: List[str] = []
    # Correctness feedback
    parts.append(tiered(
        correctness,
        "Excellent! Answer is factually accurate.",
        "Answer is partially correct but could be improved.",
        "Answer is factually incorrect.",
    ))
    # Grounding feedback
    parts.append(tiered(
        grounding_score,
        "Source citation is verified in context.",
        "Source citation partially matches context.",
        "WARNING: Source citation NOT found in context.",
    ))
    # Hallucination feedback with per-signal explanation
    if is_hallucination:
        severity_str = hallucination_severity.name.lower()
        type_str = hallucination_type.value.replace("_", " ")
        parts.append(f"HALLUCINATION DETECTED ({severity_str}): {type_str}.")
        if hallucination_analysis:
            details = hallucination_analysis
            if details.get("entity_hallucination", 0) > 0.3:
                entities = details.get("novel_entities", [])
                if entities:
                    parts.append(f"Fabricated entities: {', '.join(list(entities)[:3])}.")
            if details.get("numerical_fabrication", 0) > 0.3:
                parts.append("Numbers in answer not found in context.")
            if details.get("word_coverage", 1.0) < 0.5:
                parts.append(f"Only {int(details.get('word_coverage', 0) * 100)}% of answer words appear in context.")
            if details.get("confidence_mismatch", 0) > 0.2:
                parts.append("Confidence too high for answer quality.")
        if hallucination_severity in (HallucinationSeverity.SEVERE, HallucinationSeverity.CRITICAL):
            parts.append("This is a serious hallucination that significantly undermines trust.")
    # Calibration feedback (mid-range scores get no remark)
    if calibration_score > 0.8:
        parts.append("Confidence level is well-calibrated.")
    elif calibration_score < 0.5:
        parts.append("WARNING: Confidence is poorly calibrated to accuracy.")
    # Overall summary, first matching threshold wins
    for threshold, summary in (
        (0.8, "Overall: OUTSTANDING performance!"),
        (0.6, "Overall: Good performance with room for improvement."),
        (0.4, "Overall: Adequate but needs significant improvement."),
    ):
        if total_reward > threshold:
            parts.append(summary)
            break
    else:
        parts.append("Overall: Poor performance - review and recalibrate.")
    return " ".join(parts)
def explain_hallucination(hallucination_analysis: Dict[str, Any]) -> str:
    """
    Generate a human-readable explanation of why hallucination was detected.
    Returns a concise explanation suitable for debugging or user feedback,
    joining every triggered signal with " | ".
    """
    if not hallucination_analysis:
        return "No hallucination analysis available."
    # Each rule: (analysis key, default value, threshold, flag-when-above?,
    # message template). Keys missing from the analysis fall back to their
    # "benign" default so they never trigger.
    rules = (
        ("entity_hallucination", 0, 0.5, True,
         "Entity hallucination ({:.0%}): Answer contains names/entities not in source."),
        ("numerical_fabrication", 0, 0.3, True,
         "Numerical fabrication ({:.0%}): Numbers invented or misstated."),
        ("word_coverage", 1.0, 0.5, False,
         "Low word coverage ({:.0%}): Many answer words not in context."),
        ("answer_truth_overlap", 1.0, 0.3, False,
         "Ground truth mismatch ({:.0%}): Answer differs from correct answer."),
        ("confidence_mismatch", 0, 0.3, True,
         "Overconfidence ({:.0%}): Confidence exceeds answer quality."),
    )
    reasons: List[str] = []
    for key, default, threshold, flag_above, template in rules:
        value = hallucination_analysis.get(key, default)
        triggered = value > threshold if flag_above else value < threshold
        if triggered:
            reasons.append(template.format(value))
    if not reasons:
        return "Hallucination detected but specific cause unclear."
    return " | ".join(reasons)
def generate_feedback_from_info(info: Dict[str, Any]) -> str:
    """Convenience wrapper: generate feedback from the info dict returned by calculate_reward.

    Fixes vs. the previous version:
    - reads the "grounding"/"calibration" keys that calculate_reward actually
      emits (the old "grounding_score"/"calibration_score" keys never exist,
      so those scores were always 0.0);
    - converts the string-serialized hallucination type/severity back to
      their enum types, since generate_feedback accesses ``.value``/``.name``
      on them (passing the raw strings would raise AttributeError);
    - forwards "hallucination_analysis" so the detailed signals appear.

    Args:
        info: The info dict returned by calculate_reward. "answer" and
            "ground_truth" are typically absent and default to "".

    Returns:
        Human-readable feedback string.
    """
    # calculate_reward stores HallucinationType.value (a string); accept
    # either the string or an enum member for robustness.
    h_type = info.get("hallucination_type", HallucinationType.NONE)
    if isinstance(h_type, str):
        h_type = HallucinationType(h_type)
    # Severity is stored as the enum member's .name (e.g. "NONE").
    h_severity = info.get("hallucination_severity", HallucinationSeverity.NONE)
    if isinstance(h_severity, str):
        h_severity = HallucinationSeverity[h_severity]
    return generate_feedback(
        answer=info.get("answer", ""),
        ground_truth=info.get("ground_truth", ""),
        is_hallucination=info.get("is_hallucination", False),
        hallucination_type=h_type,
        hallucination_severity=h_severity,
        # Fall back to the legacy key names for backward compatibility.
        grounding_score=info.get("grounding", info.get("grounding_score", 0.0)),
        correctness=info.get("correctness", 0.0),
        calibration_score=info.get("calibration", info.get("calibration_score", 0.0)),
        total_reward=info.get("total_reward", 0.0),
        hallucination_analysis=info.get("hallucination_analysis"),
    )
def get_reward_breakdown(info: Dict[str, Any]) -> Dict[str, Any]:
    """Extract the six headline reward components from a calculate_reward info dict.

    Returns a plain dict whose keys mirror the RewardBreakdown field names
    (it does NOT construct the dataclass; the previous version imported
    ``RewardBreakdown`` from ``models`` without ever using it).

    Args:
        info: The info dict returned by calculate_reward. Missing keys
            default to 0.0.

    Returns:
        Dict with keys: factual_correctness, source_grounding,
        citation_accuracy, confidence_calibration, semantic_consistency,
        hallucination_penalty, total.
    """
    return {
        "factual_correctness": info.get("correctness", 0.0),
        "source_grounding": info.get("grounding", 0.0),
        "citation_accuracy": info.get("citation_analysis", {}).get("best_match_score", 0.0),
        "confidence_calibration": info.get("calibration", 0.0),
        "semantic_consistency": info.get("semantic_consistency", 0.0),
        "hallucination_penalty": info.get("hallucination_penalty", 0.0),
        "total": info.get("total_reward", 0.0),
    }