FraudSimulator-AI / models /fraud_risk_agent.py
Bader Alabddan
Add master prompt compliance: models/, data/, docs/, fraud_engine.py
9d20d0b
"""Fraud Risk Agent - Model Contract Implementation
This module implements the fraud-risk-agent model with strict JSON contract.
Decision output: investigate | allow
"""
import json
from typing import Dict, List, Any
class FraudRiskAgent:
    """Fraud Risk Decision Agent with formal model contract.

    Every indicator produced by ``_calculate_indicators`` is normalised to the
    closed range [0, 1]. Because the indicator weights sum to 1.0, the weighted
    fraud score is itself a value in [0, 1].
    """

    # Contribution of each indicator to the fraud score (values sum to 1.0).
    # A copy of this mapping is also published in the explainability payload.
    INDICATOR_WEIGHTS: Dict[str, float] = {
        'amount_deviation': 0.25,
        'high_frequency': 0.20,
        'early_claim': 0.15,
        'document_mismatch': 0.25,
        'entity_linkage': 0.15,
    }

    # Human-readable description of each indicator, used in explainability.
    INDICATOR_DESCRIPTIONS: Dict[str, str] = {
        'amount_deviation': 'Claim amount significantly differs from average',
        'high_frequency': 'Claimant has high claim frequency',
        'early_claim': 'Claim filed shortly after policy inception',
        'document_mismatch': 'Inconsistencies detected in documentation',
        'entity_linkage': 'Claimant linked to suspicious entities',
    }

    # An indicator must be strictly greater than this to be reported as a
    # signal or a top indicator.
    SIGNAL_THRESHOLD = 0.1

    # Inclusive lower bounds of the "high" and "medium" risk bands. These are
    # independent of ``decision_threshold``: with the default 0.65, a score in
    # [0.65, 0.7) is banded "medium" yet still recommended for investigation.
    HIGH_RISK_THRESHOLD = 0.7
    MEDIUM_RISK_THRESHOLD = 0.4

    def __init__(self, decision_threshold: float = 0.65):
        """Create an agent.

        Args:
            decision_threshold: Fraud score at or above which the recommended
                action is "investigate". Defaults to 0.65.
        """
        self.model_version = "1.0.0"
        self.decision_threshold = decision_threshold

    def analyze(self, claim_data: Dict[str, Any]) -> Dict[str, Any]:
        """Analyze claim and return decision contract.

        Args:
            claim_data: Structured claim information. Keys read:
                ``amount``, ``average_claim_amount``,
                ``claimant_history`` (mapping with ``claim_count``),
                ``days_since_policy_start``, ``document_consistency_score``,
                ``linked_suspicious_entities``. Missing keys, or keys whose
                value is ``None``, fall back to the defaults documented in
                ``_calculate_indicators``.

        Returns:
            Model contract (STRICT JSON):
            {
                "fraud_score": float,
                "risk_band": "low | medium | high",
                "top_indicators": list,
                "recommended_action": "investigate | allow",
                "confidence": float,
                "explainability": {
                    "signals": list,
                    "weights": dict
                }
            }
        """
        indicators = self._calculate_indicators(claim_data)
        fraud_score = self._calculate_fraud_score(indicators)

        # The action and band use the unrounded score; only the reported
        # value is rounded.
        recommended_action = (
            "investigate" if fraud_score >= self.decision_threshold else "allow"
        )

        return {
            "fraud_score": round(fraud_score, 3),
            "risk_band": self._determine_risk_band(fraud_score),
            "top_indicators": self._get_top_indicators(indicators, n=5),
            "recommended_action": recommended_action,
            "confidence": round(self._calculate_confidence(indicators), 3),
            "explainability": self._build_explainability(indicators),
        }

    @staticmethod
    def _number(value: Any, default: float) -> float:
        """Return ``value`` as a float, or ``default`` when it is ``None``."""
        return default if value is None else float(value)

    @staticmethod
    def _clamp01(value: float) -> float:
        """Clamp ``value`` into the closed range [0, 1]."""
        return min(max(value, 0.0), 1.0)

    def _calculate_indicators(self, claim_data: Dict[str, Any]) -> Dict[str, float]:
        """Calculate fraud indicators from claim data, each clamped to [0, 1].

        Defaults for missing/None inputs: amount 0, average_claim_amount 5000,
        claim_count 0, days_since_policy_start 365,
        document_consistency_score 1.0, linked_suspicious_entities 0.
        """
        # Amount deviation: relative distance from the average claim amount,
        # capped at 1.0 like every other indicator so one feature cannot
        # saturate the score on its own.
        # NOTE(review): a missing amount defaults to 0, which yields a full
        # 1.0 deviation against a positive average - confirm this is intended.
        amount = self._number(claim_data.get('amount'), 0.0)
        avg_amount = self._number(claim_data.get('average_claim_amount'), 5000.0)
        if avg_amount > 0:
            amount_deviation = abs(amount - avg_amount) / avg_amount
        else:
            amount_deviation = 0.0

        # Frequency signal: saturates at 10 prior claims. `or {}` also covers
        # an explicit None history.
        history = claim_data.get('claimant_history') or {}
        claim_count = self._number(history.get('claim_count'), 0.0)

        # Temporal pattern: binary flag for claims within 30 days of inception.
        days_since_policy = self._number(
            claim_data.get('days_since_policy_start'), 365.0
        )

        # Document consistency: a score of 1.0 means fully consistent.
        doc_score = self._number(claim_data.get('document_consistency_score'), 1.0)

        # Entity linkage: saturates at 5 linked suspicious entities.
        linked_entities = self._number(
            claim_data.get('linked_suspicious_entities'), 0.0
        )

        # Insertion order matters: it breaks ties in _get_top_indicators and
        # orders the explainability signals.
        return {
            'amount_deviation': self._clamp01(amount_deviation),
            'high_frequency': self._clamp01(claim_count / 10.0),
            'early_claim': 1.0 if days_since_policy < 30 else 0.0,
            'document_mismatch': self._clamp01(1.0 - doc_score),
            'entity_linkage': self._clamp01(linked_entities / 5.0),
        }

    def _calculate_fraud_score(self, indicators: Dict[str, float]) -> float:
        """Calculate the weighted fraud score, clamped to [0, 1].

        Indicators absent from ``indicators`` contribute 0.
        """
        score = sum(
            indicators.get(name, 0.0) * weight
            for name, weight in self.INDICATOR_WEIGHTS.items()
        )
        return self._clamp01(score)

    def _determine_risk_band(self, fraud_score: float) -> str:
        """Map a fraud score to "high", "medium" or "low"."""
        if fraud_score >= self.HIGH_RISK_THRESHOLD:
            return "high"
        if fraud_score >= self.MEDIUM_RISK_THRESHOLD:
            return "medium"
        return "low"

    def _calculate_confidence(self, indicators: Dict[str, float]) -> float:
        """Calculate confidence in the decision, clamped to [0, 1].

        Computes ``1 - 2 * mean((v - 0.5) ** 2)`` over the indicator values.
        Returns 0.0 when there are no indicators.
        """
        if not indicators:
            return 0.0
        # NOTE(review): this measures distance from 0.5, not agreement between
        # indicators. It yields 1.0 when every indicator is exactly 0.5 (most
        # ambiguous) and 0.5 when every indicator is 0 or 1. Confirm this is
        # the intended definition of confidence.
        values = list(indicators.values())
        variance = sum((v - 0.5) ** 2 for v in values) / len(values)
        return self._clamp01(1.0 - variance * 2)

    def _get_top_indicators(self, indicators: Dict[str, float], n: int = 5) -> List[str]:
        """Return up to ``n`` indicator names, highest value first.

        Only indicators above ``SIGNAL_THRESHOLD`` are kept. The filter is
        applied after taking the top ``n``, so fewer than ``n`` names may be
        returned. Ties keep their insertion order, because ``sorted`` is stable.
        """
        ranked = sorted(indicators.items(), key=lambda item: item[1], reverse=True)
        return [name for name, value in ranked[:n] if value > self.SIGNAL_THRESHOLD]

    def _build_explainability(self, indicators: Dict[str, float]) -> Dict[str, Any]:
        """Build the explainability payload.

        Returns a dict with:
            signals: one entry per indicator above ``SIGNAL_THRESHOLD``, in
                indicator order, with name, rounded value and description.
            weights: a fresh copy of ``INDICATOR_WEIGHTS``, so callers
                mutating it cannot alter the class-level weights.
        """
        signals = [
            {
                "indicator": name,
                "value": round(value, 3),
                "description": self._get_indicator_description(name),
            }
            for name, value in indicators.items()
            if value > self.SIGNAL_THRESHOLD
        ]
        return {
            "signals": signals,
            "weights": dict(self.INDICATOR_WEIGHTS),
        }

    def _get_indicator_description(self, indicator: str) -> str:
        """Return the human-readable description of ``indicator``.

        Unknown indicator names are returned unchanged.
        """
        return self.INDICATOR_DESCRIPTIONS.get(indicator, indicator)