Spaces:
Sleeping
Sleeping
| """ | |
| Firewall Decision Engine for VDHF | |
| Controls output delivery based on verification score and threshold. | |
| """ | |
| from typing import List, Dict, Any, Optional | |
| from dataclasses import dataclass | |
| from enum import Enum | |
| from config.settings import FIREWALL_THRESHOLD | |
| from core.verifier import VerificationResult, ClaimVerifier | |
| from core.claim_extractor import Claim | |
class FirewallDecision(Enum):
    """Possible outcomes of a firewall evaluation."""

    # The response is safe to deliver to the user.
    PASS = "PASS"
    # The response is held back; prompt refinement / regeneration is triggered.
    REGENERATE = "REGENERATE"
@dataclass
class FirewallResult:
    """Result of a firewall evaluation.

    Declared as a dataclass so it can be built with keyword arguments
    (``FirewallResult(decision=..., support_ratio=..., ...)``).
    """

    # Decision taken by the firewall for this response.
    decision: FirewallDecision
    # Support ratio that was compared against ``threshold``.
    support_ratio: float
    # Threshold in force when the decision was taken.
    threshold: float
    # Number of verification results that were evaluated.
    total_claims: int
    # Number of results flagged as supported.
    supported_claims: int
    # Number of results not flagged as supported.
    unsupported_claims: int
    # Per-claim verification results the decision was derived from.
    verification_results: List[VerificationResult]

    @property
    def is_safe(self) -> bool:
        """Return True when the decision is ``FirewallDecision.PASS``.

        Exposed as a property because it is read as an attribute
        (``result.is_safe``); a plain method would make that expression a
        bound method object, which is always truthy.
        """
        return self.decision == FirewallDecision.PASS

    def __str__(self) -> str:
        """Return a three-line, human-readable summary of the evaluation."""
        status = "✓ PASSED" if self.is_safe else "✗ BLOCKED"
        return (
            f"Firewall Decision: {status}\n"
            f" Support Ratio: {self.support_ratio:.2%} (threshold: {self.threshold:.2%})\n"
            f" Claims: {self.supported_claims}/{self.total_claims} supported"
        )
class FirewallScoringModule:
    """
    Firewall Scoring Module

    Computes the global verification score for a response.

    Metric:
        SupportRatio = (Number of SUPPORTED claims) / (Total number of claims)
    """

    def __init__(self, threshold: float = FIREWALL_THRESHOLD):
        """Store the firewall threshold.

        Args:
            threshold: Threshold value kept on the instance as ``threshold``.
        """
        self.threshold = threshold

    @staticmethod
    def _mean(values: List[float]) -> float:
        """Return the arithmetic mean of *values*, or 0.0 when it is empty."""
        return sum(values) / len(values) if values else 0.0

    def compute_support_ratio(
        self,
        verification_results: List[VerificationResult]
    ) -> float:
        """
        Compute the support ratio for verification results.

        Args:
            verification_results: List of VerificationResult objects

        Returns:
            Fraction of results whose ``is_supported`` is truthy
            (0.0 to 1.0); 1.0 when there are no results.
        """
        if not verification_results:
            return 1.0  # No claims = safe
        supported = sum(1 for r in verification_results if r.is_supported)
        return supported / len(verification_results)

    def get_detailed_scores(
        self,
        verification_results: List[VerificationResult]
    ) -> Dict[str, Any]:
        """
        Get detailed scoring breakdown.

        Args:
            verification_results: List of VerificationResult objects

        Returns:
            Dictionary with the support ratio, claim counts, the average
            similarity and entailment scores, and the min/max similarity.
            Averages and min/max are 0.0 when there are no results.
        """
        total = len(verification_results)
        supported = sum(1 for r in verification_results if r.is_supported)
        similarity_scores = [r.similarity_score for r in verification_results]
        entailment_scores = [r.entailment_score for r in verification_results]

        return {
            # Delegated so the "no claims = 1.0" convention lives in one place.
            'support_ratio': self.compute_support_ratio(verification_results),
            'total_claims': total,
            'supported_claims': supported,
            'unsupported_claims': total - supported,
            'average_similarity': self._mean(similarity_scores),
            'average_entailment': self._mean(entailment_scores),
            'min_similarity': min(similarity_scores) if similarity_scores else 0.0,
            'max_similarity': max(similarity_scores) if similarity_scores else 0.0,
        }
class FirewallDecisionEngine:
    """
    Firewall Decision Engine

    Purpose:
        Control output delivery based on verification score.

    Decision Logic:
        If SupportRatio >= tau:
            Allow the response to pass to the user
        Else:
            Trigger prompt refinement and regeneration
    """

    def __init__(self, threshold: float = FIREWALL_THRESHOLD):
        """Create the engine and its scoring module.

        Args:
            threshold: Minimum support ratio (tau) required for PASS.
        """
        self.threshold = threshold
        self.scoring_module = FirewallScoringModule(threshold)

    def evaluate(
        self,
        verification_results: List[VerificationResult]
    ) -> FirewallResult:
        """
        Evaluate whether response should pass the firewall.

        Args:
            verification_results: List of VerificationResult objects

        Returns:
            FirewallResult object with the decision, the support ratio,
            the threshold used, and the supported/unsupported counts.
        """
        support_ratio = self.scoring_module.compute_support_ratio(verification_results)

        # A ratio exactly equal to the threshold passes.
        if support_ratio >= self.threshold:
            decision = FirewallDecision.PASS
        else:
            decision = FirewallDecision.REGENERATE

        total = len(verification_results)
        supported = sum(1 for r in verification_results if r.is_supported)

        return FirewallResult(
            decision=decision,
            support_ratio=support_ratio,
            threshold=self.threshold,
            total_claims=total,
            supported_claims=supported,
            unsupported_claims=total - supported,
            verification_results=verification_results,
        )

    def get_unsupported_claims(
        self,
        firewall_result: FirewallResult
    ) -> List[Claim]:
        """
        Get list of unsupported claims for prompt refinement.

        Args:
            firewall_result: FirewallResult object

        Returns:
            Claims of the results whose ``is_supported`` is falsy, in
            the order of ``firewall_result.verification_results``.
        """
        return [
            result.claim
            for result in firewall_result.verification_results
            if not result.is_supported
        ]

    def get_supported_claims(
        self,
        firewall_result: FirewallResult
    ) -> List[Claim]:
        """
        Get list of supported claims.

        Args:
            firewall_result: FirewallResult object

        Returns:
            Claims of the results whose ``is_supported`` is truthy, in
            the order of ``firewall_result.verification_results``.
        """
        return [
            result.claim
            for result in firewall_result.verification_results
            if result.is_supported
        ]

    def get_verified_evidence(
        self,
        firewall_result: FirewallResult
    ) -> List[str]:
        """
        Get evidence that supports verified claims.

        Args:
            firewall_result: FirewallResult object

        Returns:
            Distinct, non-empty ``best_evidence`` values of the supported
            results, in first-seen order.
        """
        # dict.fromkeys de-duplicates while keeping insertion order, so the
        # returned list is stable across runs (list(set(...)) is not).
        return list(dict.fromkeys(
            result.best_evidence
            for result in firewall_result.verification_results
            if result.is_supported and result.best_evidence
        ))
class HallucinationFirewall:
    """
    Complete Hallucination Firewall

    Combines verification and decision-making into a single interface.
    """

    def __init__(
        self,
        similarity_threshold: Optional[float] = None,
        firewall_threshold: float = FIREWALL_THRESHOLD
    ):
        """Build the claim verifier and the decision engine.

        Args:
            similarity_threshold: Threshold handed to ``ClaimVerifier``.
                ``None`` selects ``config.settings.SIMILARITY_THRESHOLD``.
            firewall_threshold: Threshold handed to ``FirewallDecisionEngine``.
        """
        from config.settings import SIMILARITY_THRESHOLD

        # Test against None explicitly: an explicit 0.0 is a legitimate value
        # and must not be replaced by the configured default.
        if similarity_threshold is None:
            similarity_threshold = SIMILARITY_THRESHOLD

        self.similarity_threshold = similarity_threshold
        self.firewall_threshold = firewall_threshold
        self.verifier = ClaimVerifier(self.similarity_threshold)
        self.decision_engine = FirewallDecisionEngine(firewall_threshold)

    def check_response(
        self,
        claims: List[Claim],
        evidence_list: List
    ) -> FirewallResult:
        """
        Check a response through the firewall.

        Args:
            claims: Extracted claims from response
            evidence_list: Retrieved evidence

        Returns:
            FirewallResult with decision
        """
        # Verify all claims, then let the decision engine score the results.
        verification_results = self.verifier.verify_all_claims(claims, evidence_list)
        return self.decision_engine.evaluate(verification_results)

    def get_refinement_info(
        self,
        firewall_result: FirewallResult
    ) -> Dict[str, Any]:
        """
        Get information needed for prompt refinement.

        Args:
            firewall_result: FirewallResult from check_response

        Returns:
            Dictionary with the unsupported claims, supported claims,
            verified evidence, support ratio, and a ``needs_regeneration``
            flag that is True when the decision is not PASS.
        """
        engine = self.decision_engine
        return {
            'unsupported_claims': engine.get_unsupported_claims(firewall_result),
            'supported_claims': engine.get_supported_claims(firewall_result),
            'verified_evidence': engine.get_verified_evidence(firewall_result),
            'support_ratio': firewall_result.support_ratio,
            # Compare the decision itself so the flag does not depend on how
            # ``is_safe`` is exposed on the result object.
            'needs_regeneration': firewall_result.decision != FirewallDecision.PASS,
        }