# InvoiceAgenticAI/agents/risk_agent.py
"""Risk Assessment Agent for Invoice Processing"""
import os
import json
import re
from typing import Dict, Any, List
import google.generativeai as genai
from dotenv import load_dotenv
import numpy as np
from datetime import datetime, timedelta
from statistics import mean
import time
from agents.base_agent import BaseAgent
from state import (
InvoiceProcessingState, RiskAssessment, RiskLevel,
ValidationStatus, ProcessingStatus
)
from utils.logger import StructuredLogger
load_dotenv()
from collections import defaultdict
class APIKeyBalancer:
    """Least-errored / least-used selector over a pool of API keys.

    Usage and error counters are persisted to ``SAVE_FILE`` so that key
    rotation decisions survive process restarts.
    """

    SAVE_FILE = "key_stats.json"

    def __init__(self, keys):
        # Drop falsy entries (e.g. unset environment variables resolving to
        # None) so a missing key can never be selected and handed to the SDK.
        self.keys = [k for k in keys if k]
        self.usage = defaultdict(int)
        self.errors = defaultdict(int)
        self.load()

    def load(self):
        """Restore persisted counters; a missing or corrupt file is ignored."""
        if os.path.exists(self.SAVE_FILE):
            try:
                # `with` closes the handle (previously leaked via bare open()).
                with open(self.SAVE_FILE) as fh:
                    data = json.load(fh)
            except (OSError, json.JSONDecodeError):
                return  # best-effort: start from empty stats
            self.usage.update(data.get("usage", {}))
            self.errors.update(data.get("errors", {}))

    def save(self):
        """Persist current counters, closing the file deterministically."""
        with open(self.SAVE_FILE, "w") as fh:
            json.dump({
                "usage": dict(self.usage),
                "errors": dict(self.errors),
            }, fh)

    def get_best_key(self):
        """Return the key with the fewest errors (ties broken by usage count)."""
        best_key = min(self.keys, key=lambda k: (self.errors[k], self.usage[k]))
        self.usage[best_key] += 1
        self.save()
        return best_key

    def report_error(self, key):
        """Record a failed call against *key* and persist the counters."""
        self.errors[key] += 1
        self.save()
# Pool of Gemini API keys read from the environment.  Unset variables are
# filtered out so the balancer never selects a missing (None) key.
balancer = APIKeyBalancer([
    key
    for key in (
        os.getenv("GEMINI_API_KEY_1"),
        os.getenv("GEMINI_API_KEY_2"),
        os.getenv("GEMINI_API_KEY_3"),
        os.getenv("GEMINI_API_KEY_5"),
        os.getenv("GEMINI_API_KEY_6"),
    )
    if key
])
class RiskAgent(BaseAgent):
    """Agent responsible for risk assessment, fraud detection, and compliance checking"""

    def __init__(self, config: Dict[str, Any] = None):
        """Configure Gemini with the least-loaded API key and set up metrics.

        Args:
            config: Optional agent configuration forwarded to BaseAgent.
        """
        super().__init__("risk_agent", config)
        self.logger = StructuredLogger("risk_agent")
        self.api_key = balancer.get_best_key()
        # Security fix: the key was previously printed to stdout — secrets
        # must never reach console output or logs.
        genai.configure(api_key=self.api_key)
        self.model = genai.GenerativeModel("gemini-2.0-flash")
        # --- Metrics tracking ---
        self.execution_history: List[Dict[str, Any]] = []
        self.max_history = 50  # keep last 50 runs
def generate(self, prompt):
try:
response = self.model.generate_content(prompt)
return response
except Exception as e:
balancer.report_error(self.api_key)
raise
def _validate_preconditions(self, state: InvoiceProcessingState, workflow_type) -> bool:
return bool(state.invoice_data and state.validation_result)
def _validate_postconditions(self, state: InvoiceProcessingState) -> bool:
return bool(state.risk_assessment and state.risk_assessment.risk_score is not None)
async def execute(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState:
start_time = time.time()
success = False
try:
if not self._validate_preconditions(state, workflow_type):
state.overall_status = ProcessingStatus.FAILED
self._log_decision(state, "Risk Assessment Analysis Failed", "Preconditions not met", confidence=0.0)
invoice_data = state.invoice_data
validation_result = state.validation_result
base_score = await self._calculate_base_risk_score(invoice_data, validation_result)
print("base_score:",base_score)
fraud_indicators = await self._detect_fraud_indicators(invoice_data, validation_result)
print("fraud_indicators:",fraud_indicators)
compliance_issues = await self._check_compliance(invoice_data, state)
print("compliance_issues:",compliance_issues)
ai_assessment = await self._ai_risk_assessment(invoice_data, validation_result, fraud_indicators)
print("ai_assessment:",ai_assessment)
combined_score = self._combine_risk_factors(base_score, fraud_indicators, compliance_issues, ai_assessment)
print("combined_score:",combined_score)
risk_level = self._determine_risk_level(combined_score)
print("risk_level:",risk_level)
recommendation = self._generate_recommendation(risk_level, fraud_indicators, compliance_issues, validation_result)
print("recommendation:", recommendation)
state.risk_assessment = RiskAssessment(
risk_level = risk_level,
risk_score = combined_score,
fraud_indicators = fraud_indicators,
compliance_issues = compliance_issues,
recommendation = recommendation["action"],
reason = recommendation["reason"],
requires_human_review = recommendation["requires_human_review"]
)
state.current_agent = "risk_agent"
state.overall_status = ProcessingStatus.IN_PROGRESS
success = True
self._log_decision(
state,
"Risk Assessment Successful",
"PDF text successfully verified by Risk Agent and checked by AI",
combined_score,
state.process_id
)
return state
finally:
duration_ms = round((time.time() - start_time) * 1000, 2)
self._record_execution(success, duration_ms)
    async def _calculate_base_risk_score(self, invoice_data, validation_result) -> float:
        """
        Calculates an intelligent risk score (0.0–1.0) based on validation results,
        invoice metadata, and contextual financial factors.

        Each numbered section adds a fixed penalty to ``score``; the sum is
        capped at 1.0 and rounded to 3 decimals.  ``async`` is kept for
        interface symmetry with the other pipeline steps even though nothing
        is awaited here.
        """
        score = 0.0
        # --- 1. Validation & PO related risks ---
        if validation_result:
            # Penalty scales with how badly PO validation went.
            if validation_result.validation_status == ValidationStatus.INVALID:
                score += 0.4
            elif validation_result.validation_status == ValidationStatus.PARTIAL_MATCH:
                score += 0.25
            elif validation_result.validation_status == ValidationStatus.MISSING_PO:
                score += 0.3
            # Core mismatch signals
            if not validation_result.amount_match:
                score += 0.2
            if not validation_result.rate_match:
                score += 0.15
            if not validation_result.quantity_match:
                score += 0.1
            # Low confidence from validation adds risk: linear penalty only
            # below the 0.5 threshold (max +0.15 at confidence 0).
            if validation_result.confidence_score is not None:
                score += (0.5 - validation_result.confidence_score) * 0.3 if validation_result.confidence_score < 0.5 else 0
        # --- 2. Invoice amount-based risk ---
        if invoice_data and invoice_data.total is not None:
            total = invoice_data.total
            if total > 1_000_000:
                score += 0.4  # Extremely high-value invoices
            elif total > 100_000:
                score += 0.25
            elif total > 10_000:
                score += 0.1
            elif total < 10:
                score += 0.15  # Suspiciously small invoice
        # --- 3. Temporal risks (based on due date) ---
        if invoice_data and getattr(invoice_data, "due_date", None):
            try:
                score += self._calculate_due_date_risk(invoice_data.due_date)
            except Exception:
                pass  # Graceful degradation if due_date is invalid
        # --- 4. Vendor / Customer risks ---
        if invoice_data and getattr(invoice_data, "customer_name", None):
            name = invoice_data.customer_name.lower()
            # Placeholder-style names get a moderate penalty; explicit
            # red-flag keywords a larger one (elif: only one branch applies).
            if "new_vendor" in name or "test" in name or "demo" in name:
                score += 0.2
            elif any(flag in name for flag in ["fraud", "fake", "invalid"]):
                score += 0.3
        # --- 5. Data reliability / extraction confidence ---
        if invoice_data and getattr(invoice_data, "extraction_confidence", None) is not None:
            conf = invoice_data.extraction_confidence
            if conf < 0.5:
                score += 0.2
            elif conf < 0.7:
                score += 0.1
        # --- 6. Currency and metadata anomalies ---
        # NOTE: getattr(None, ...) returns the default, so a missing
        # invoice_data safely falls back to "USD" here.
        currency = getattr(invoice_data, "currency", "USD") or "USD"
        if currency.upper() not in {"USD", "EUR", "GBP", "INR"}:
            score += 0.15  # uncommon currencies add risk
        # Normalize score within [0, 1.0]
        return round(min(score, 1.0), 3)
def _calculate_due_date_risk(self, due_date_str: str) -> float:
try:
due_date = self._parse_date(due_date_str)
days_until_due = (due_date - datetime.utcnow().date()).days
if days_until_due < 0:
return 0.2
elif days_until_due < 5:
return 0.1
return 0.0
except Exception:
return 0.05
def _parse_date(self, date_str: str) -> datetime.date:
return datetime.strptime(date_str,"%Y-%m-%d").date()
async def _detect_fraud_indicators(self, invoice_data, validation_result) -> List[str]:
"""
Performs intelligent fraud detection on the given invoice and validation results.
Returns a list of detected fraud indicators.
"""
indicators = []
# 1. PO / Validation mismatches
if validation_result:
if not validation_result.po_found:
indicators.append("No matching Purchase Order found")
if not validation_result.amount_match:
indicators.append("Amount discrepancy detected")
if not validation_result.rate_match:
indicators.append("Rate inconsistency with Purchase Order")
if not validation_result.quantity_match:
indicators.append("Quantity mismatch detected")
if validation_result.confidence_score is not None and validation_result.confidence_score < 0.6:
indicators.append(f"Low validation confidence ({validation_result.confidence_score:.2f})")
# 2. Vendor / Customer anomalies
customer_name = getattr(invoice_data, "customer_name", "") or ""
if "test" in customer_name.lower() or "demo" in customer_name.lower():
indicators.append("Suspicious vendor name (Test/Demo account)")
if "new_vendor" in customer_name.lower():
indicators.append("First-time or unverified vendor")
if any(keyword in customer_name.lower() for keyword in ["fraud", "fake", "invalid"]):
indicators.append("Vendor flagged with risky keywords")
# 3. Amount-level risk signals
total = getattr(invoice_data, "total", 0.0) or 0.0
if total > 1_000_000:
indicators.append(f"Unusually high invoice total (${total:,.2f})")
elif total < 10:
indicators.append(f"Suspiciously low invoice total (${total:,.2f})")
# 4. Date anomalies
due_date = getattr(invoice_data, "due_date", None)
invoice_date = getattr(invoice_data, "invoice_date", None)
if invoice_date and due_date and (due_date - invoice_date).days < 0:
indicators.append("Due date earlier than invoice date (possible manipulation)")
elif invoice_date and due_date and (due_date - invoice_date).days < 3:
indicators.append("Unusually short payment window")
# 5. Duplicate or pattern-based red flags
if invoice_data.invoice_number and invoice_data.invoice_number.lower().startswith("dup-"):
indicators.append("Possible duplicate invoice ID pattern")
if invoice_data.file_name and "copy" in invoice_data.file_name.lower():
indicators.append("Invoice filename suggests duplication")
# 6. Confidence anomalies (AI extraction)
if invoice_data.extraction_confidence is not None and invoice_data.extraction_confidence < 0.5:
indicators.append(f"Low extraction confidence ({invoice_data.extraction_confidence:.2f}) — possible OCR tampering")
# 7. Currency or unusual metadata patterns
if getattr(invoice_data, "currency", "").upper() not in {"USD", "EUR", "GBP", "INR"}:
indicators.append(f"Uncommon currency code: {invoice_data.currency}")
return indicators
async def _check_compliance(self, invoice_data, state: InvoiceProcessingState) -> List[str]:
"""
Performs a multi-layer compliance check on invoice and state integrity.
Returns a list of detected compliance issues.
"""
issues = []
# 1. Invoice integrity checks
if not invoice_data.invoice_number:
issues.append("Missing invoice number")
if not invoice_data.customer_name:
issues.append("Missing customer name")
if not invoice_data.total or invoice_data.total <= 0:
issues.append("Invalid or missing total amount")
if not invoice_data.due_date:
issues.append("Missing due date")
# 2. Item-level verification
if not invoice_data.item_details or len(invoice_data.item_details) == 0:
issues.append("No item details present")
else:
for item in invoice_data.item_details:
if not getattr(item, "item_name", None):
issues.append("Item missing name")
if getattr(item, "quantity", 1) <= 0:
issues.append(f"Invalid quantity for item '{item.item_name or 'Unknown'}'")
# 3. Confidence & quality checks
if invoice_data.extraction_confidence and invoice_data.extraction_confidence < 0.7:
issues.append(f"Low extraction confidence ({invoice_data.extraction_confidence:.2f})")
# 4. Workflow state checks
if not getattr(state, "approval_chain", True):
issues.append("Approval chain incomplete")
if getattr(state, "escalation_required", False):
issues.append("Escalation required before payment")
if getattr(state, "human_review_required", False):
issues.append("Pending human review")
# 5. Audit consistency
if len(state.audit_trail) == 0:
issues.append("No audit trail entries found")
# # 6. Optional receipt confirmation
# if not getattr(invoice_data, "receipt_confirmed", False):
# issues.append("Missing receipt confirmation")
# 7. Risk-based compliance (if risk assessment exists)
if state.risk_assessment and state.risk_assessment.risk_score >= 0.7:
issues.append(f"High risk score detected ({state.risk_assessment.risk_score:.2f})")
return issues
async def _ai_risk_assessment(
self,
invoice_data,
validation_result,
fraud_indicators: List[str]
) -> Dict[str, Any]:
"""
Uses a Generative AI model (Gemini) to assess risk level based on
structured invoice data, validation results, and detected fraud indicators.
Returns:
dict: {
"risk_score": float between 0–1,
"reason": str (explanation for the score)
}
"""
self.logger.logger.info("[RiskAgent] Running AI-based risk assessment...")
# model_name = "gemini-2.5-flash"
result = {"risk_score": 0.0, "reason": "Default – AI assessment not available"}
try:
# Initialize model
# model = genai.GenerativeModel(model_name)
# --- Construct dynamic and context-rich prompt ---
prompt = f"""
You are a financial risk analysis model for invoice fraud detection.
Carefully analyze the following details:
INVOICE DATA:
{invoice_data}
VALIDATION RESULT:
{validation_result}
DETECTED FRAUD INDICATORS:
{fraud_indicators}
TASK:
1. Assess overall risk of this invoice being fraudulent or non-compliant.
2. Provide reasoning.
3. Respond **only in JSON** with keys:
- "risk_score": a float between 0 and 1 (higher = higher risk)
- "reason": short explanation of what contributed to this score.
EXAMPLES:
{{
"risk_score": 0.85,
"reason": "High amount mismatch, new vendor, and unusual currency"
}}
{{
"risk_score": 0.25,
"reason": "Valid PO and consistent totals, low fraud signals"
}}
"""
import asyncio
# --- Model call ---
response = self.generate(prompt)
# response = await asyncio.to_thread(model.generate_content, prompt)
# --- Clean and parse ---
raw_text = getattr(response, "text", "") or ""
cleaned_json = self._clean_json_response(raw_text)
ai_output = json.loads(cleaned_json)
# --- Validate AI output ---
score = float(ai_output.get("risk_score", 0.0))
reason = str(ai_output.get("reason", "No reason provided"))
# Clamp score between 0–1 for safety
result = {
"risk_score": max(0.0, min(score, 1.0)),
"reason": reason.strip()[:400] # limit for logs
}
self.logger.logger.info(
f"[RiskAgent] AI Risk Assessment completed: score={result['risk_score']}, reason={result['reason']}"
)
except json.JSONDecodeError as e:
self.logger.logger.warning(f"[RiskAgent] JSON parsing failed: {e}")
result["reason"] = "AI response could not be parsed"
except Exception as e:
self.logger.logger.error(f"[RiskAgent] AI assessment error: {e}", exc_info=True)
result["reason"] = "Fallback to base risk model"
return result
def _clean_json_response(self, text: str) -> str:
text = re.sub(r'^[^{]*','',text)
text = re.sub(r'[^}]*$','',text)
return text
def _combine_risk_factors(
self,
base_score: float,
fraud_indicators: List[str],
compliance_issues: List[str],
ai_assessment: Dict[str, Any]
) -> float:
"""
Combines multiple risk components (base, fraud, compliance, and AI analysis)
into a single normalized risk score between 0.0 and 1.0.
Weighting strategy:
- Base Score: foundation derived from deterministic checks
- Fraud Indicators: +0.1 per flag (max +0.3)
- Compliance Issues: +0.05 per issue (max +0.2)
- AI Risk Score: contributes 40–50% of total weight
Returns:
float: final risk score clamped to [0, 1]
"""
try:
# Extract and normalize AI risk
ai_score = float(ai_assessment.get("risk_score", 0.0))
ai_score = max(0.0, min(ai_score, 1.0))
# --- Weighted contributions ---
fraud_contrib = min(len(fraud_indicators) * 0.1, 0.3)
compliance_contrib = min(len(compliance_issues) * 0.05, 0.2)
ai_contrib = 0.5 * ai_score if ai_score > 0 else 0.2 * base_score
combined = base_score + fraud_contrib + compliance_contrib + ai_contrib
# Cap at 1.0 for safety
final_score = round(min(combined, 1.0), 3)
self.logger.logger.info(
f"[RiskAgent] Combined risk computed: base={base_score}, "
f"fraud_flags={len(fraud_indicators)}, compliance_flags={len(compliance_issues)}, "
f"ai_score={ai_score}, final={final_score}"
)
return final_score
except Exception as e:
self.logger.logger.error(f"[RiskAgent] Error combining risk factors: {e}", exc_info=True)
return min(base_score + 0.2, 1.0) # fallback conservative estimate
def _determine_risk_level(self, risk_score: float) -> RiskLevel:
if risk_score<0.3:
return RiskLevel.LOW
elif risk_score<0.6:
return RiskLevel.MEDIUM
elif risk_score<0.8:
return RiskLevel.HIGH
return RiskLevel.CRITICAL
def _generate_recommendation(
self,
risk_level: RiskLevel,
fraud_indicators: List[str],
compliance_issues: List[str],
validation_result
) -> Dict[str, Any]:
"""
Generate a structured recommendation (approve, escalate, or reject)
based on overall risk, fraud, and compliance outcomes.
Decision Logic:
- HIGH / CRITICAL risk → escalate for human review
- INVALID validation → reject
- Medium risk with minor issues → escalate
- Otherwise → approve
Returns:
Dict[str, Any]: {
'action': str, # 'approve', 'escalate', or 'reject'
'reason': str, # Explanation summary
'requires_human_review': bool
}
"""
try:
# --- Determine key flags ---
has_fraud = bool(fraud_indicators)
has_compliance_issues = bool(compliance_issues)
validation_invalid = (
validation_result and validation_result.validation_status == ValidationStatus.INVALID
)
# --- Decision Logic ---
if validation_invalid:
action = "reject"
requires_review = True
reason = "Validation failed: " + "; ".join(fraud_indicators + compliance_issues or ["Invalid invoice data"])
elif risk_level in [RiskLevel.HIGH, RiskLevel.CRITICAL]:
action = "escalate"
requires_review = True
reason = f"High risk level detected ({risk_level.value}). Issues: " + "; ".join(fraud_indicators + compliance_issues or ["Potential anomalies"])
elif has_fraud or has_compliance_issues:
action = "escalate"
requires_review = True
reason = "Minor irregularities found: " + "; ".join(fraud_indicators + compliance_issues)
else:
action = "approve"
requires_review = False
reason = "All checks passed; invoice appears valid and compliant."
# --- Structured Output ---
recommendation = {
"action": action,
"reason": reason,
"requires_human_review": requires_review,
}
self.logger.logger.info(
f"[DecisionAgent] Recommendation generated: {recommendation}"
)
return recommendation
except Exception as e:
self.logger.logger.error(f"[DecisionAgent] Error generating recommendation: {e}", exc_info=True)
# Safe fallback
return {
"action": "escalate",
"reason": "Error during recommendation generation",
"requires_human_review": True,
}
def _record_execution(self, success: bool, duration_ms: float):
self.execution_history.append({
# "timestamp": datetime.utcnow().isoformat(),
"success": success,
"duration_ms": duration_ms,
})
# Keep recent N only
if len(self.execution_history) > self.max_history:
self.execution_history.pop(0)
async def health_check(self) -> Dict[str, Any]:
total_runs = len(self.execution_history)
if total_runs == 0:
return {
"Agent": "Risk Agent ⚠️",
"Executions": 0,
"Success Rate (%)": 0.0,
"Avg Duration (ms)": 0.0,
"Total Failures": 0,
"Status": "idle",
# "Timestamp": datetime.utcnow().isoformat()
}
metrics_data = {}
executions = 0
success_rate = 0.0
avg_duration = 0.0
failures = 0
last_run = None
# 1. Try to get live metrics from state
# print("(self.state)-------",self.metrics)
# print("self.state.agent_metrics-------", self.state.agent_metrics)
if self.metrics:
executions = self.metrics["processed"]
avg_duration = self.metrics["avg_latency_ms"]
failures = self.metrics["errors"]
last_run = self.metrics["last_run_at"]
success_rate = (executions - failures) / (executions+1e-8)
# 2. API connectivity check
gemini_ok = bool(self.api_key)
api_status = "🟢 Active" if gemini_ok else "🔴 Missing Key"
# 3. Health logic
overall_status = "🟢 Healthy"
if not gemini_ok or failures > 3:
overall_status = "🟠 Degraded"
if executions > 0 and success_rate < 0.5:
overall_status = "🔴 Unhealthy"
successes = sum(1 for e in self.execution_history if e["success"])
failures = total_runs - successes
avg_duration = round(mean(e["duration_ms"] for e in self.execution_history), 2)
success_rate = round((successes / (total_runs+1e-8)) * 100, 2)
return {
"Agent": "Risk Agent ⚠️",
"Executions": total_runs,
"Success Rate (%)": success_rate,
"Avg Duration (ms)": avg_duration,
"API Status": api_status,
"Total Failures": failures,
"Last Run": str(last_run) if last_run else "Not applicable",
# "Timestamp": datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S UTC"),
"Overall Health": overall_status,
}