Spaces:
Sleeping
Sleeping
| """Fraud Risk Agent - Model Contract Implementation | |
| This module implements the fraud-risk-agent model with strict JSON contract. | |
| Decision output: investigate | allow | |
| """ | |
| import json | |
| from typing import Dict, List, Any | |
class FraudRiskAgent:
    """Fraud Risk Decision Agent with formal model contract.

    A claim is scored as a weighted sum of five heuristic indicators. The
    score is then mapped to a risk band ("low" / "medium" / "high") and a
    recommended action ("investigate" / "allow").
    """

    # Indicator weights. This single table is used both to compute the score
    # and to populate the explainability payload, so the reported weights can
    # never drift from the applied ones. The weights sum to 1.0.
    _WEIGHTS = {
        'amount_deviation': 0.25,
        'high_frequency': 0.20,
        'early_claim': 0.15,
        'document_mismatch': 0.25,
        'entity_linkage': 0.15,
    }

    # Indicators whose value is at or below this threshold are omitted from
    # ``top_indicators`` and from the explainability signals.
    _SIGNAL_THRESHOLD = 0.1

    # Human-readable text for each indicator, used in explainability signals.
    _DESCRIPTIONS = {
        'amount_deviation': 'Claim amount significantly differs from average',
        'high_frequency': 'Claimant has high claim frequency',
        'early_claim': 'Claim filed shortly after policy inception',
        'document_mismatch': 'Inconsistencies detected in documentation',
        'entity_linkage': 'Claimant linked to suspicious entities',
    }

    def __init__(self):
        self.model_version = "1.0.0"
        # Reported fraud_score at or above which "investigate" is recommended.
        # NOTE(review): this differs from the "high" band cut-off (0.7), so a
        # score in [0.65, 0.7) is banded "medium" but still "investigate" —
        # confirm that is intended.
        self.decision_threshold = 0.65

    def analyze(self, claim_data: Dict[str, Any]) -> Dict[str, Any]:
        """Analyze a claim and return the decision contract.

        Args:
            claim_data: Structured claim information. Keys read:
                ``amount``, ``average_claim_amount``, ``claimant_history``
                (a dict with ``claim_count``), ``days_since_policy_start``,
                ``document_consistency_score`` and
                ``linked_suspicious_entities``. Missing keys and keys whose
                value is ``None`` both fall back to defaults.

        Returns:
            Model contract (STRICT JSON):
            {
                "fraud_score": float,
                "risk_band": "low | medium | high",
                "top_indicators": list,
                "recommended_action": "investigate | allow",
                "confidence": float,
                "explainability": {
                    "signals": list,
                    "weights": dict
                }
            }
            ``fraud_score`` and ``confidence`` are rounded to 3 decimals.
            ``risk_band`` and ``recommended_action`` are derived from the
            rounded score, so they always agree with the emitted value.
        """
        indicators = self._calculate_indicators(claim_data)
        # Decide on the value we report, so a consumer never sees e.g.
        # fraud_score 0.65 paired with "allow".
        fraud_score = round(self._calculate_fraud_score(indicators), 3)
        if fraud_score >= self.decision_threshold:
            recommended_action = "investigate"
        else:
            recommended_action = "allow"

        return {
            "fraud_score": fraud_score,
            "risk_band": self._determine_risk_band(fraud_score),
            "top_indicators": self._get_top_indicators(indicators, n=5),
            "recommended_action": recommended_action,
            "confidence": round(self._calculate_confidence(indicators), 3),
            "explainability": self._build_explainability(indicators),
        }

    @staticmethod
    def _get_value(data: Dict[str, Any], key: str, default: Any) -> Any:
        """Return ``data[key]``, or *default* when the key is missing or None."""
        value = data.get(key)
        return default if value is None else value

    @staticmethod
    def _clamp01(value: float) -> float:
        """Clamp *value* into the closed interval [0, 1]."""
        return min(max(value, 0.0), 1.0)

    def _calculate_indicators(self, claim_data: Dict[str, Any]) -> Dict[str, float]:
        """Calculate fraud indicators from claim data.

        Args:
            claim_data: Structured claim information (see ``analyze``).

        Returns:
            Mapping of indicator name to value. Every indicator except
            ``amount_deviation`` lies in [0, 1].
        """
        get = self._get_value
        indicators: Dict[str, float] = {}

        # Amount deviation: relative distance from the average claim amount.
        # NOTE(review): unlike the other indicators this is not capped at 1.0,
        # so a claim far above the average can saturate fraud_score on its
        # own. A missing amount (default 0) also yields a deviation of 1.0.
        # Confirm both are intended.
        amount = get(claim_data, 'amount', 0)
        avg_amount = get(claim_data, 'average_claim_amount', 5000)
        indicators['amount_deviation'] = (
            abs(amount - avg_amount) / avg_amount if avg_amount > 0 else 0
        )

        # Frequency: 10 or more prior claims saturates the signal.
        # `or {}` also covers an explicit JSON null for the history.
        history = claim_data.get('claimant_history') or {}
        claim_count = get(history, 'claim_count', 0)
        indicators['high_frequency'] = self._clamp01(claim_count / 10.0)

        # Temporal pattern: claim filed within 30 days of policy start.
        days_since_policy = get(claim_data, 'days_since_policy_start', 365)
        indicators['early_claim'] = 1.0 if days_since_policy < 30 else 0.0

        # Document consistency: a score of 1.0 means no mismatch.
        doc_score = get(claim_data, 'document_consistency_score', 1.0)
        indicators['document_mismatch'] = self._clamp01(1.0 - doc_score)

        # Entity linkage: 5 or more linked suspicious entities saturates.
        linked_entities = get(claim_data, 'linked_suspicious_entities', 0)
        indicators['entity_linkage'] = self._clamp01(linked_entities / 5.0)

        return indicators

    def _calculate_fraud_score(self, indicators: Dict[str, float]) -> float:
        """Return the weighted sum of *indicators*, clamped to [0, 1].

        Indicators absent from *indicators* contribute 0.
        """
        score = sum(
            indicators.get(name, 0) * weight
            for name, weight in self._WEIGHTS.items()
        )
        return self._clamp01(score)

    def _determine_risk_band(self, fraud_score: float) -> str:
        """Map a fraud score to "high" (>= 0.7), "medium" (>= 0.4) or "low"."""
        if fraud_score >= 0.7:
            return "high"
        elif fraud_score >= 0.4:
            return "medium"
        else:
            return "low"

    def _calculate_confidence(self, indicators: Dict[str, float]) -> float:
        """Return a confidence value in [0, 1] for the decision.

        The value is computed as 1 - 2 * (mean squared distance of the
        indicators from 0.5). An empty mapping yields 0.0.
        """
        if not indicators:
            return 0.0
        # NOTE(review): this peaks (1.0) when every indicator equals 0.5 and
        # drops as indicators approach 0 or 1, i.e. it rewards ambiguous
        # signals. If "consistent indicators" was meant, variance around the
        # mean would be the usual measure — confirm the intended formula.
        spread = sum((v - 0.5) ** 2 for v in indicators.values()) / len(indicators)
        return self._clamp01(1.0 - spread * 2)

    def _get_top_indicators(self, indicators: Dict[str, float], n: int = 5) -> List[str]:
        """Return up to *n* indicator names, highest value first.

        Names whose value does not exceed the signal threshold are dropped.
        Ties keep the indicators' original order, because sorting is stable.
        """
        ranked = sorted(indicators.items(), key=lambda item: item[1], reverse=True)
        return [name for name, value in ranked[:n] if value > self._SIGNAL_THRESHOLD]

    def _build_explainability(self, indicators: Dict[str, float]) -> Dict[str, Any]:
        """Build the explainability payload.

        Returns:
            ``{"signals": [...], "weights": {...}}``. ``signals`` lists every
            indicator above the signal threshold with its rounded value and
            description. ``weights`` is a fresh copy of the scoring weights,
            so mutating it cannot affect later scoring.
        """
        signals = [
            {
                "indicator": name,
                "value": round(value, 3),
                "description": self._get_indicator_description(name),
            }
            for name, value in indicators.items()
            if value > self._SIGNAL_THRESHOLD
        ]
        return {
            "signals": signals,
            "weights": dict(self._WEIGHTS),
        }

    def _get_indicator_description(self, indicator: str) -> str:
        """Return a human-readable description, or the name itself if unknown."""
        return self._DESCRIPTIONS.get(indicator, indicator)