# HallucinationFirewall / core/firewall.py
# Commit 1bbe15b (Ram-090): Restructure project with FastAPI + React frontend and Railway deployment
"""
Firewall Decision Engine for VDHF
Controls output delivery based on verification score and threshold.
"""
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from enum import Enum
from config.settings import FIREWALL_THRESHOLD
from core.verifier import VerificationResult, ClaimVerifier
from core.claim_extractor import Claim
class FirewallDecision(Enum):
    """Verdict the firewall can reach for a generated response.

    Members:
        PASS: the response may be delivered to the user.
        REGENERATE: the response is withheld; prompt refinement and
            regeneration should be triggered instead.
    """

    # Response is safe to deliver.
    PASS = "PASS"

    # Trigger prompt refinement.
    REGENERATE = "REGENERATE"
@dataclass
class FirewallResult:
    """Outcome of evaluating one response's verification results.

    Fields are plain data; ``is_safe`` and ``__str__`` are derived views.
    """

    # Verdict reached by the decision engine.
    decision: FirewallDecision
    # Fraction of claims judged supported.
    support_ratio: float
    # Minimum support ratio that was required for PASS.
    threshold: float
    # Claim counts; supported + unsupported == total when built by the engine.
    total_claims: int
    supported_claims: int
    unsupported_claims: int
    # Per-claim verification details the counts were derived from.
    verification_results: List[VerificationResult]

    @property
    def is_safe(self) -> bool:
        """Whether the firewall allowed the response through (decision PASS)."""
        passed = self.decision == FirewallDecision.PASS
        return passed

    def __str__(self) -> str:
        """Return a three-line, human-readable summary of this result."""
        if self.is_safe:
            status = "✓ PASSED"
        else:
            status = "✗ BLOCKED"
        summary = (
            f"Firewall Decision: {status}\n"
            f" Support Ratio: {self.support_ratio:.2%} (threshold: {self.threshold:.2%})\n"
            f" Claims: {self.supported_claims}/{self.total_claims} supported"
        )
        return summary
class FirewallScoringModule:
    """
    Firewall Scoring Module
    Computes global verification score for the response.
    Metric:
        SupportRatio = (Number of SUPPORTED claims) / (Total number of claims)

    An empty list of verification results scores 1.0: nothing was claimed,
    so nothing is unsupported. Both public methods share that policy through
    ``compute_support_ratio``.
    """

    def __init__(self, threshold: float = FIREWALL_THRESHOLD):
        # Kept so callers can read it; the scoring methods below do not use it.
        self.threshold = threshold

    def compute_support_ratio(
        self,
        verification_results: List[VerificationResult]
    ) -> float:
        """
        Compute the support ratio for verification results.

        Args:
            verification_results: List of VerificationResult objects

        Returns:
            Support ratio (0.0 to 1.0); 1.0 when the list is empty.
        """
        if not verification_results:
            return 1.0  # No claims = safe
        supported = sum(1 for r in verification_results if r.is_supported)
        return supported / len(verification_results)

    def get_detailed_scores(
        self,
        verification_results: List[VerificationResult]
    ) -> Dict[str, Any]:
        """
        Get detailed scoring breakdown.

        Args:
            verification_results: List of VerificationResult objects

        Returns:
            Dictionary with keys 'support_ratio', 'total_claims',
            'supported_claims', 'unsupported_claims', 'average_similarity',
            'average_entailment', 'min_similarity' and 'max_similarity'.
            Every similarity/entailment statistic is 0.0 for an empty list.
        """
        total = len(verification_results)
        supported = sum(1 for r in verification_results if r.is_supported)
        similarity_scores = [r.similarity_score for r in verification_results]
        entailment_scores = [r.entailment_score for r in verification_results]
        return {
            # Delegate so the empty-input policy lives in exactly one place.
            'support_ratio': self.compute_support_ratio(verification_results),
            'total_claims': total,
            'supported_claims': supported,
            'unsupported_claims': total - supported,
            'average_similarity': self._mean(similarity_scores),
            'average_entailment': self._mean(entailment_scores),
            'min_similarity': min(similarity_scores) if similarity_scores else 0.0,
            'max_similarity': max(similarity_scores) if similarity_scores else 0.0,
        }

    @staticmethod
    def _mean(values: List[float]) -> float:
        """Arithmetic mean of ``values``, or 0.0 when there are none."""
        return sum(values) / len(values) if values else 0.0
class FirewallDecisionEngine:
    """
    Firewall Decision Engine
    Purpose:
        Control output delivery based on verification score.
    Decision Logic:
        If SupportRatio >= tau:
            Allow the response to pass to the user
        Else:
            Trigger prompt refinement and regeneration
    """

    def __init__(self, threshold: float = FIREWALL_THRESHOLD):
        # tau in the decision rule above.
        self.threshold = threshold
        self.scoring_module = FirewallScoringModule(threshold)

    def evaluate(
        self,
        verification_results: List[VerificationResult]
    ) -> FirewallResult:
        """
        Evaluate whether response should pass the firewall.

        Args:
            verification_results: List of VerificationResult objects

        Returns:
            FirewallResult with decision PASS when the support ratio is at
            least ``self.threshold``, otherwise REGENERATE. An empty list
            has ratio 1.0 and so passes whenever threshold <= 1.0.
        """
        support_ratio = self.scoring_module.compute_support_ratio(verification_results)
        decision = (
            FirewallDecision.PASS
            if support_ratio >= self.threshold
            else FirewallDecision.REGENERATE
        )

        total = len(verification_results)
        supported = sum(1 for r in verification_results if r.is_supported)
        return FirewallResult(
            decision=decision,
            support_ratio=support_ratio,
            threshold=self.threshold,
            total_claims=total,
            supported_claims=supported,
            unsupported_claims=total - supported,
            verification_results=verification_results
        )

    def get_unsupported_claims(
        self,
        firewall_result: FirewallResult
    ) -> List[Claim]:
        """
        Get list of unsupported claims for prompt refinement.

        Args:
            firewall_result: FirewallResult object

        Returns:
            Claims whose verification result is not supported, in the order
            they appear in ``firewall_result.verification_results``.
        """
        return [
            result.claim
            for result in firewall_result.verification_results
            if not result.is_supported
        ]

    def get_supported_claims(
        self,
        firewall_result: FirewallResult
    ) -> List[Claim]:
        """
        Get list of supported claims.

        Args:
            firewall_result: FirewallResult object

        Returns:
            Claims whose verification result is supported, in the order
            they appear in ``firewall_result.verification_results``.
        """
        return [
            result.claim
            for result in firewall_result.verification_results
            if result.is_supported
        ]

    def get_verified_evidence(
        self,
        firewall_result: FirewallResult
    ) -> List[str]:
        """
        Get evidence that supports verified claims.

        Args:
            firewall_result: FirewallResult object

        Returns:
            De-duplicated ``best_evidence`` values of supported claims (empty
            or missing evidence is skipped), in first-seen order.
        """
        # dict.fromkeys de-duplicates while preserving insertion order. A set
        # would yield string-hash order, which changes between interpreter
        # runs (hash randomization) and would make refinement prompts
        # non-deterministic.
        return list(dict.fromkeys(
            result.best_evidence
            for result in firewall_result.verification_results
            if result.is_supported and result.best_evidence
        ))
class HallucinationFirewall:
    """
    Complete Hallucination Firewall
    Combines verification and decision-making into a single interface.
    """

    def __init__(
        self,
        similarity_threshold: Optional[float] = None,
        firewall_threshold: float = FIREWALL_THRESHOLD
    ):
        """
        Build the verifier and decision engine.

        Args:
            similarity_threshold: Threshold handed to ClaimVerifier. None
                means "use config.settings.SIMILARITY_THRESHOLD"; any explicit
                value, including 0.0, is used as given.
            firewall_threshold: Minimum support ratio required for PASS.
        """
        from config.settings import SIMILARITY_THRESHOLD

        # Test for None instead of using `or`: `0.0 or DEFAULT` would silently
        # replace a caller's explicit 0.0 with the configured default.
        if similarity_threshold is None:
            similarity_threshold = SIMILARITY_THRESHOLD
        self.similarity_threshold = similarity_threshold
        self.firewall_threshold = firewall_threshold
        self.verifier = ClaimVerifier(self.similarity_threshold)
        self.decision_engine = FirewallDecisionEngine(firewall_threshold)

    def check_response(
        self,
        claims: List[Claim],
        evidence_list: List
    ) -> FirewallResult:
        """
        Check a response through the firewall.

        Args:
            claims: Extracted claims from response
            evidence_list: Retrieved evidence, passed unchanged to
                ClaimVerifier.verify_all_claims (element type is defined
                there, not in this file).

        Returns:
            FirewallResult with decision
        """
        verification_results = self.verifier.verify_all_claims(claims, evidence_list)
        return self.decision_engine.evaluate(verification_results)

    def get_refinement_info(
        self,
        firewall_result: FirewallResult
    ) -> Dict[str, Any]:
        """
        Get information needed for prompt refinement.

        Args:
            firewall_result: FirewallResult from check_response

        Returns:
            Dictionary with keys 'unsupported_claims', 'supported_claims',
            'verified_evidence', 'support_ratio' and 'needs_regeneration'
            (True when the result's decision is not PASS).
        """
        engine = self.decision_engine
        return {
            'unsupported_claims': engine.get_unsupported_claims(firewall_result),
            'supported_claims': engine.get_supported_claims(firewall_result),
            'verified_evidence': engine.get_verified_evidence(firewall_result),
            'support_ratio': firewall_result.support_ratio,
            'needs_regeneration': not firewall_result.is_safe
        }