Spaces:
Sleeping
Sleeping
| """FraudShield partial-observability environment implementation.""" | |
| from __future__ import annotations | |
| import copy | |
| import uuid | |
| from typing import Any, Dict, List | |
| from data_loader import FraudDataLoader | |
| from models import ( | |
| ActionTypeEnum, | |
| CaseScreenEnum, | |
| CaseSummary, | |
| EpisodeState, | |
| FraudCheckAction, | |
| FraudCheckObservation, | |
| QueueCaseCard, | |
| ResetResult, | |
| ResolutionEnum, | |
| Reward, | |
| StepResult, | |
| TaskDifficulty, | |
| ) | |
# Per-difficulty episode settings, keyed by TaskDifficulty. How each field is
# used by FraudShieldEnvironment:
#   source_task               - not read by the code visible in this file;
#                               presumably mirrors the data loader's task name
#                               (TODO confirm)
#   num_cases                 - reported in reset() info and cached in case_counts
#   max_steps                 - the episode is marked done once step_count reaches it
#   sla_limit                 - each step beyond this count costs an extra 0.06 reward
#   ideal_steps               - each step beyond it lowers the "efficiency" metric by 0.05
#   investigation_budget      - initial per-case fetch budget; a fetch made after it
#                               is used up costs an extra 0.03
#   minimum_fetches_for_bonus - fetches needed on medium/hard cases before a correct
#                               resolution earns the +0.15 bonus
#   description               - human-readable text surfaced in reset() info
TASK_CONFIG: Dict[TaskDifficulty, Dict[str, Any]] = {
    TaskDifficulty.EASY: {
        "source_task": "easy",
        "num_cases": 1,
        "max_steps": 6,
        "sla_limit": 5,
        "ideal_steps": 3,
        "investigation_budget": 1,
        "minimum_fetches_for_bonus": 1,
        "description": "Single low-noise case with strong visible cues and one fetch budget.",
    },
    TaskDifficulty.MEDIUM: {
        "source_task": "medium",
        "num_cases": 1,
        "max_steps": 8,
        "sla_limit": 6,
        "ideal_steps": 5,
        "investigation_budget": 2,
        "minimum_fetches_for_bonus": 1,
        "description": "Single mixed-signal case that requires at least one investigation before routing.",
    },
    TaskDifficulty.HARD: {
        "source_task": "hard",
        "num_cases": 2,
        "max_steps": 14,
        "sla_limit": 11,
        "ideal_steps": 9,
        "investigation_budget": 3,
        "minimum_fetches_for_bonus": 1,
        "description": "Two misleading linked cases where graph evidence is usually required.",
    },
}

# Action types that are charged the 0.02 per-action cost when a reward is
# computed; each of them is routed to the evidence-fetch handler by step().
FETCH_ACTIONS = {
    ActionTypeEnum.FETCH_CUSTOMER_PROFILE,
    ActionTypeEnum.FETCH_MERCHANT_PROFILE,
    ActionTypeEnum.FETCH_NETWORK_GRAPH,
    ActionTypeEnum.CHECK_POLICY,
}
| class FraudShieldEnvironment: | |
| """OpenEnv-compatible fraud-investigation environment.""" | |
def __init__(self, data_path: str = "data", seed: int = 42):
    """Set up an idle environment with empty episode bookkeeping.

    The data loader object is constructed here, but ``data_loaded`` starts
    False and every per-episode field holds a placeholder until ``reset()``
    runs.

    Args:
        data_path: Location handed to ``FraudDataLoader``.
        seed: Seed stored on the instance and handed to ``FraudDataLoader``.
    """

    def per_task(setting: str) -> Dict[TaskDifficulty, Any]:
        # Project one TASK_CONFIG column into a difficulty -> value lookup.
        return {difficulty: TASK_CONFIG[difficulty][setting] for difficulty in TaskDifficulty}

    # Data source.
    self.seed = seed
    self.data_loader = FraudDataLoader(data_path=data_path, seed=seed)
    self.data_loaded = False

    # Episode identity and progress.
    self.episode_id = ""
    self.current_task = TaskDifficulty.EASY
    self.step_count = 0
    self.cumulative_reward = 0.0
    self.is_done = False
    self.current_screen = CaseScreenEnum.QUEUE
    self.active_case_id = ""

    # Per-case containers, filled in by reset().
    self.workflow_cases: Dict[str, Dict[str, Any]] = {}
    self.case_state: Dict[str, Dict[str, Any]] = {}
    self.case_order: List[str] = []
    self.audit_log: List[Dict[str, Any]] = []

    # Episode-wide penalty counters.
    self.invalid_action_count = 0
    self.redundant_action_count = 0
    self.note_spam_count = 0

    # Static per-difficulty limits copied out of TASK_CONFIG.
    self.case_counts = per_task("num_cases")
    self.max_steps = per_task("max_steps")
    self.sla_limit = per_task("sla_limit")

    self.last_episode_summary: Dict[str, Any] = {}
def load_data(self) -> bool:
    """Ask the data loader to load its snapshot and remember the outcome.

    Returns:
        Whatever ``data_loader.load_data()`` returned; the same value is
        stored on ``self.data_loaded``.
    """
    outcome = self.data_loader.load_data()
    self.data_loaded = outcome
    return outcome
def load_kaggle_data(self) -> bool:
    """Legacy name kept for older validation scripts; delegates to ``load_data()``."""
    loaded = self.load_data()
    return loaded
def ensure_data_loaded(self) -> None:
    """Load the data bundle on first use.

    Raises:
        RuntimeError: If the data was not loaded yet and ``load_data()``
            reports failure.
    """
    if self.data_loaded:
        return
    if self.load_data():
        return
    raise RuntimeError("FraudShield data bundle could not be loaded.")
def reset(self, task: str = "easy") -> ResetResult:
    """Begin a fresh episode for the requested difficulty.

    Args:
        task: Difficulty name, converted with ``TaskDifficulty(task)``; a
            value the enum rejects propagates that conversion error.

    Returns:
        A ``ResetResult`` holding the first observation plus an ``info``
        dict with the episode id, limits, description, screen names and the
        data loader's bundle summary.

    Raises:
        RuntimeError: If the data bundle cannot be loaded.
    """
    self.ensure_data_loaded()

    self.current_task = TaskDifficulty(task)
    settings = TASK_CONFIG[self.current_task]

    def blank_case_state() -> Dict[str, Any]:
        # Called once per case so no two cases share a list or dict.
        return {
            "status": "triage",
            "reviewed": False,
            "revealed_evidence": {},
            "note_count": 0,
            "notes": [],
            "policy_checked": False,
            "resolution": None,
            "resolved": False,
            "resolution_correct": False,
            "policy_compliant": False,
            "invalid_actions": 0,
            "redundant_actions": 0,
            "anti_hacking_hits": 0,
            "action_history": [],
            "fetches_used": 0,
            "fetch_budget_remaining": settings["investigation_budget"],
        }

    # Episode identity and counters.
    self.episode_id = f"ep_{uuid.uuid4().hex[:8]}"
    self.step_count = 0
    self.cumulative_reward = 0.0
    self.is_done = False
    self.current_screen = CaseScreenEnum.QUEUE
    self.invalid_action_count = 0
    self.redundant_action_count = 0
    self.note_spam_count = 0
    self.audit_log = []
    self.last_episode_summary = {}

    # Cases for this difficulty; the first one built becomes the active case.
    self.workflow_cases = self._build_workflow_cases(self.current_task)
    self.case_order = list(self.workflow_cases.keys())
    self.active_case_id = self.case_order[0]
    self.case_state = {case_id: blank_case_state() for case_id in self.case_order}

    info: Dict[str, Any] = {
        "episode_id": self.episode_id,
        "task": self.current_task.value,
    }
    for key in ("num_cases", "max_steps", "sla_limit", "investigation_budget", "description"):
        info[key] = settings[key]
    info["workflow_views"] = [screen.value for screen in CaseScreenEnum]
    info["data_snapshot"] = self.data_loader.get_bundle_summary()

    return ResetResult(observation=self._build_observation(), info=info)
def step(self, action: FraudCheckAction) -> StepResult:
    """Apply a single investigation or resolution action.

    Three guard checks run before dispatch, and each ends the step with a
    penalty: an unknown ``case_id``, a non-review action aimed at a case
    other than the active one, and a non-review action on an already
    resolved case. Any other action is routed to its handler.

    Args:
        action: The agent's action. ``case_id`` and ``action_type`` are read
            here; the remaining fields are read by the individual handlers.

    Returns:
        A ``StepResult``. ``info["valid_action"]`` is False exactly when the
        action was rejected, i.e. when it was counted in
        ``invalid_action_count``.

    Raises:
        RuntimeError: If the episode has already finished.
    """
    if self.is_done:
        raise RuntimeError("Episode is done. Call reset() to start a new episode.")

    if action.case_id not in self.workflow_cases:
        # The penalty is booked against the active case because the
        # requested one does not exist.
        reward = self._apply_action_outcome(
            action=action,
            case_id=self.active_case_id or self.case_order[0],
            base_value=-0.35,
            reason=f"Unknown case_id '{action.case_id}'.",
            valid_action=False,
            anti_hacking=True,
        )
        return self._build_step_result(reward, {"valid_action": False, "error": "unknown_case"})

    if action.action_type != ActionTypeEnum.REVIEW_TRANSACTION and action.case_id != self.active_case_id:
        reward = self._apply_action_outcome(
            action=action,
            case_id=action.case_id,
            base_value=-0.18,
            reason="Only review_transaction may switch focus to another case.",
            valid_action=False,
        )
        return self._build_step_result(reward, {"valid_action": False, "error": "inactive_case"})

    case_id = action.case_id
    state = self.case_state[case_id]
    if state["resolved"] and action.action_type != ActionTypeEnum.REVIEW_TRANSACTION:
        reward = self._apply_action_outcome(
            action=action,
            case_id=case_id,
            base_value=-0.22,
            reason="Resolved cases cannot be modified; move to an open case instead.",
            valid_action=False,
        )
        return self._build_step_result(reward, {"valid_action": False, "error": "resolved_case"})

    # Handlers reject some actions themselves (for example a fetch before the
    # transaction review). The reward value cannot tell those apart from
    # valid actions, so validity is read from the invalid-action counter.
    invalid_before = self.invalid_action_count

    # Evidence key and destination screen for each fetch-style action.
    fetch_targets = {
        ActionTypeEnum.FETCH_CUSTOMER_PROFILE: ("customer_profile", CaseScreenEnum.CUSTOMER_PROFILE),
        ActionTypeEnum.FETCH_MERCHANT_PROFILE: ("merchant_profile", CaseScreenEnum.MERCHANT_PROFILE),
        ActionTypeEnum.FETCH_NETWORK_GRAPH: ("network_graph", CaseScreenEnum.CASE_CONSOLE),
        ActionTypeEnum.CHECK_POLICY: ("policy_guide", CaseScreenEnum.POLICY_ESCALATION),
    }

    if action.action_type == ActionTypeEnum.REVIEW_TRANSACTION:
        reward = self._handle_review(case_id, action)
    elif action.action_type in fetch_targets:
        evidence_key, screen = fetch_targets[action.action_type]
        reward = self._handle_fetch(case_id, action, evidence_key, screen)
    elif action.action_type == ActionTypeEnum.ADD_CASE_NOTE:
        reward = self._handle_add_note(case_id, action)
    elif action.action_type == ActionTypeEnum.RESOLVE_CASE:
        reward = self._handle_resolve(case_id, action)
    else:  # pragma: no cover - enum already constrains values
        reward = self._apply_action_outcome(
            action=action,
            case_id=case_id,
            base_value=-0.25,
            reason=f"Unsupported action_type '{action.action_type.value}'.",
            valid_action=False,
        )

    return self._build_step_result(
        reward,
        {
            "valid_action": self.invalid_action_count == invalid_before,
            "active_case_id": self.active_case_id,
            "resolved_case_ids": self._resolved_case_ids(),
            "remaining_sla": self._remaining_sla(),
            "remaining_steps": self._remaining_steps(),
        },
    )
def state(self) -> EpisodeState:
    """Snapshot the episode's progress as an ``EpisodeState``.

    Per-case data is summarised rather than exposed directly: note counts,
    the sorted names of revealed evidence, the cases whose policy was
    checked, and the resolutions submitted so far.
    """
    notes_written: Dict[str, Any] = {}
    evidence_keys: Dict[str, List[str]] = {}
    resolutions: Dict[str, Any] = {}
    policy_checked: List[str] = []

    for case_id, progress in self.case_state.items():
        notes_written[case_id] = progress["note_count"]
        evidence_keys[case_id] = sorted(progress["revealed_evidence"].keys())
        if progress["policy_checked"]:
            policy_checked.append(case_id)
        # Unresolved cases are left out of the resolution map.
        if progress["resolution"] is not None:
            resolutions[case_id] = progress["resolution"]

    return EpisodeState(
        episode_id=self.episode_id,
        task_name=self.current_task,
        current_screen=self.current_screen,
        active_case_id=self.active_case_id,
        step_count=self.step_count,
        remaining_steps=self._remaining_steps(),
        remaining_sla=self._remaining_sla(),
        cumulative_reward=round(self.cumulative_reward, 4),
        is_done=self.is_done,
        resolved_case_ids=self._resolved_case_ids(),
        unresolved_case_ids=self._unresolved_case_ids(),
        notes_written_by_case=notes_written,
        evidence_keys_by_case=evidence_keys,
        policy_checked_case_ids=sorted(policy_checked),
        resolution_by_case=resolutions,
        invalid_action_count=self.invalid_action_count,
        redundant_action_count=self.redundant_action_count,
    )
def get_episode_report(self) -> Dict[str, Any]:
    """Build the grading report for the current or completed episode.

    Returns:
        A dict with episode counters, one case report per case (in
        ``case_order``), six metrics rounded to four decimals, and a copy of
        the audit log. The same dict is cached on ``last_episode_summary``.
    """
    reports = [self._build_case_report(case_id) for case_id in self.case_order]
    divisor = max(1, len(reports))  # keeps the averages defined with no cases

    def share(flag: str) -> float:
        # Fraction of cases whose report carries a truthy ``flag``.
        return sum(1.0 if report[flag] else 0.0 for report in reports) / divisor

    def mean(field: str) -> float:
        # Average of a numeric report field across cases.
        return sum(report[field] for report in reports) / divisor

    resolution_accuracy = share("resolution_correct")
    evidence_coverage = mean("evidence_coverage")
    policy_compliance = share("policy_compliant")
    workflow_completion = mean("workflow_completion") if reports else 0.0

    # Efficiency starts at 1.0 and loses a fixed amount per bad action and
    # per step taken beyond the task's ideal step count.
    extra_steps = max(0, self.step_count - TASK_CONFIG[self.current_task]["ideal_steps"])
    overstep_penalty = extra_steps * 0.05
    efficiency = (
        1.0
        - (self.invalid_action_count * 0.12)
        - (self.redundant_action_count * 0.08)
        - (self.note_spam_count * 0.06)
        - overstep_penalty
    )
    efficiency = max(0.0, efficiency)

    link_consistency = 1.0
    if self.current_task == TaskDifficulty.HARD:
        reviewed_count = sum(1.0 if report["reviewed"] else 0.0 for report in reports)
        network_count = sum(1.0 if "network_graph" in report["revealed_evidence"] else 0.0 for report in reports)
        resolved_count = sum(1.0 if report["submitted_resolution"] else 0.0 for report in reports)
        # NOTE(review): each term is a per-case count, not a fraction, so the
        # sum can reach the 1.0 cap before any case is resolved - confirm
        # this is the intended weighting.
        link_consistency = min(
            1.0,
            0.25 * reviewed_count + 0.25 * network_count + 0.25 * resolved_count + 0.25 * resolution_accuracy * 2,
        )

    metric_values = (
        ("resolution_accuracy", resolution_accuracy),
        ("evidence_coverage", evidence_coverage),
        ("policy_compliance", policy_compliance),
        ("workflow_completion", workflow_completion),
        ("efficiency", efficiency),
        ("link_consistency", link_consistency),
    )

    summary = {
        "episode_id": self.episode_id,
        "task": self.current_task.value,
        "step_count": self.step_count,
        "max_steps": TASK_CONFIG[self.current_task]["max_steps"],
        "remaining_sla": self._remaining_sla(),
        "cumulative_reward": round(self.cumulative_reward, 4),
        "invalid_action_count": self.invalid_action_count,
        "redundant_action_count": self.redundant_action_count,
        "note_spam_count": self.note_spam_count,
        "case_summaries": reports,
        "metrics": {name: round(value, 4) for name, value in metric_values},
        "audit_log": list(self.audit_log),
    }
    self.last_episode_summary = summary
    return summary
def _build_workflow_cases(self, task: TaskDifficulty) -> Dict[str, Dict[str, Any]]:
    """Create the workflow case(s) for one difficulty, keyed by case id.

    EASY and MEDIUM each yield one case; any other difficulty yields the
    linked primary/secondary pair built from the "hard" source cases.
    """
    loader = self.data_loader

    if task == TaskDifficulty.EASY:
        easy_source = self._select_easy_case(loader.get_task_cases("easy"))
        easy_case = self._make_workflow_case(
            case_id="easy_case_01",
            raw_case=easy_source,
            queue_reason="High-value purchase queued for a quick manual review.",
            correct_resolution=ResolutionEnum.BLOCK,
            required_tools={"transaction_review", "case_note"},
            useful_tools={"merchant_profile"},
            policy_required=False,
            linked_case_ids=[],
            role="single",
        )
        return {"easy_case_01": easy_case}

    if task == TaskDifficulty.MEDIUM:
        medium_source = self._select_medium_case(loader.get_task_cases("medium"))
        medium_case = self._make_workflow_case(
            case_id="medium_case_01",
            raw_case=medium_source,
            queue_reason="Mixed signals triggered review and supporting evidence is needed.",
            correct_resolution=ResolutionEnum.REQUEST_DOCS,
            required_tools={"transaction_review", "customer_profile", "policy_guide", "case_note"},
            useful_tools={"merchant_profile"},
            policy_required=True,
            linked_case_ids=[],
            role="single",
        )
        return {"medium_case_01": medium_case}

    # Hard task: two cases that point at each other.
    pair = self._select_hard_pair(loader.get_task_cases("hard"))
    primary = self._make_workflow_case(
        case_id="hard_case_primary",
        raw_case=pair[0],
        queue_reason="Operational anomaly spike triggered a higher-touch review.",
        correct_resolution=ResolutionEnum.ESCALATE,
        required_tools={"transaction_review", "network_graph", "merchant_profile", "policy_guide", "case_note"},
        useful_tools=set(),
        policy_required=True,
        linked_case_ids=["hard_case_secondary"],
        role="primary",
    )
    secondary = self._make_workflow_case(
        case_id="hard_case_secondary",
        raw_case=pair[1],
        queue_reason="A related anomaly surfaced in the same review wave.",
        correct_resolution=ResolutionEnum.BLOCK,
        required_tools={"transaction_review", "network_graph", "customer_profile", "policy_guide", "case_note"},
        useful_tools=set(),
        policy_required=True,
        linked_case_ids=["hard_case_primary"],
        role="secondary",
    )
    built: Dict[str, Dict[str, Any]] = {}
    for workflow_case in (primary, secondary):
        built[workflow_case["case_id"]] = workflow_case
    return built
| def _select_easy_case(self, cases: List[Dict[str, Any]]) -> Dict[str, Any]: | |
| fraud_cases = [case for case in cases if case["label"] == "fraud"] | |
| return max(fraud_cases, key=lambda case: (case["risk_score"], case["business_cost"])) | |
| def _select_medium_case(self, cases: List[Dict[str, Any]]) -> Dict[str, Any]: | |
| candidates = [ | |
| case | |
| for case in cases | |
| if case["label"] == "legitimate" | |
| and case["transaction_data"]["previous_fraud_flags"] >= 1 | |
| and case["transaction_data"]["seller_chargeback_rate_30d"] >= 0.04 | |
| ] | |
| if not candidates: | |
| candidates = [case for case in cases if case["label"] == "legitimate"] | |
| return max(candidates, key=lambda case: (case["risk_score"], case["business_cost"])) | |
| def _select_hard_pair(self, cases: List[Dict[str, Any]]) -> List[Dict[str, Any]]: | |
| groups: Dict[str, List[Dict[str, Any]]] = {} | |
| for case in cases: | |
| seller_id = case["transaction_data"]["seller_id"] | |
| groups.setdefault(seller_id, []).append(case) | |
| linked_groups = [ | |
| group | |
| for group in groups.values() | |
| if len(group) >= 2 and all(case["label"] == "fraud" for case in group) | |
| ] | |
| if not linked_groups: | |
| linked_groups = [cases[:2]] | |
| chosen_group = max(linked_groups, key=lambda group: sum(case["business_cost"] for case in group)) | |
| ordered = sorted(chosen_group, key=lambda case: (case["business_cost"], case["risk_score"]), reverse=True) | |
| return ordered[:2] | |
def _make_workflow_case(
    self,
    case_id: str,
    raw_case: Dict[str, Any],
    queue_reason: str,
    correct_resolution: ResolutionEnum,
    required_tools: set[str],
    useful_tools: set[str],
    policy_required: bool,
    linked_case_ids: List[str],
    role: str,
) -> Dict[str, Any]:
    """Assemble the full workflow record for one case.

    The record bundles the grading truth (``correct_resolution``, required
    and useful tools, hidden flags), the queue card, and an evidence catalog
    with one entry per evidence key: ``transaction_review``,
    ``customer_profile``, ``merchant_profile``, ``network_graph`` and
    ``policy_guide``. Each entry has ``view``, ``summary`` and ``facts``.

    Args:
        case_id: Identifier used as the key for this case.
        raw_case: Source case; ``transaction_data``, ``historical_context``,
            ``risk_score`` and ``business_cost`` are read from it.
        queue_reason: Text shown on the queue card.
        correct_resolution: The resolution graded as correct.
        required_tools: Evidence keys / markers required before resolving.
        useful_tools: Evidence keys that count as useful but optional.
        policy_required: Whether a policy check is required for compliance.
        linked_case_ids: Ids of related cases, exposed via the network graph.
        role: ``"single"``, ``"primary"`` or ``"secondary"``.

    Returns:
        The workflow case dict.
    """
    # Deep copies keep this case's data independent of the source record.
    txn = copy.deepcopy(raw_case["transaction_data"])
    context = copy.deepcopy(raw_case["historical_context"])
    risk = float(raw_case["risk_score"])
    cost = float(raw_case["business_cost"])

    ship_to = txn["shipping_address"]
    device_at = txn["device_country"]
    mismatch = ship_to != device_at
    occurred_at = txn["timestamp"]
    high_loss = 1.35

    def txn_fields(*names: str) -> Dict[str, Any]:
        # Copy the named transaction fields under their own names, in order.
        return {name: txn[name] for name in names}

    def entry(screen: CaseScreenEnum, summary: str, facts: Dict[str, Any]) -> Dict[str, Any]:
        # Common shape shared by every evidence catalog entry.
        return {"view": screen.value, "summary": summary, "facts": facts}

    # The queue card never lists linked cases; they only appear in the
    # network graph evidence.
    queue_card = {
        "case_id": case_id,
        "priority": self._priority_label(risk, cost),
        "queue_reason": queue_reason,
        "visible_risk_band": "review",
        "status": "triage",
        "linked_case_ids": [],
    }

    evidence_catalog = {
        "transaction_review": entry(
            CaseScreenEnum.CASE_CONSOLE,
            self._transaction_summary(txn, mismatch),
            {
                "amount_usd": txn["amount"],
                "item_category": txn["item_category"],
                "timestamp": occurred_at,
                "shipping_country": ship_to,
                "device_country": device_at,
                **txn_fields("payment_method", "shipping_speed", "same_address_orders_24h"),
            },
        ),
        "customer_profile": entry(
            CaseScreenEnum.CUSTOMER_PROFILE,
            self._customer_summary(txn),
            txn_fields("buyer_account_age_days", "buyer_disputes_90d", "is_repeat_buyer"),
        ),
        "merchant_profile": entry(
            CaseScreenEnum.MERCHANT_PROFILE,
            self._merchant_summary(txn),
            txn_fields(
                "seller_account_age_days",
                "seller_avg_rating",
                "num_seller_reviews",
                "seller_chargeback_rate_30d",
            ),
        ),
        "network_graph": entry(
            CaseScreenEnum.CASE_CONSOLE,
            self._network_summary(role, linked_case_ids, context),
            {
                **txn_fields("shared_device_accounts_24h", "previous_fraud_flags"),
                "cluster_alert_score": context["cluster_alert_score"],
                "linked_cards_7d": context["linked_cards_7d"],
                "linked_case_ids": list(linked_case_ids),
            },
        ),
        "policy_guide": entry(
            CaseScreenEnum.POLICY_ESCALATION,
            self._policy_summary(policy_required, cost),
            {
                "policy_required": policy_required,
                "note_required": True,
                "request_docs_on_unresolved_conflict": self.current_task != TaskDifficulty.EASY,
                "escalate_if_cluster_and_loss": role == "primary" and cost >= 1.35,
                "high_loss_threshold": high_loss,
            },
        ),
    }

    hidden_flags = {
        "payment_ops_high_risk": self._high_risk_payment_ops(txn, mismatch),
        "network_high_risk": context["cluster_alert_score"] >= 0.7,
    }

    return {
        "case_id": case_id,
        "raw_case": raw_case,
        "transaction": txn,
        "history": context,
        "risk_score": risk,
        "business_cost": cost,
        "correct_resolution": correct_resolution,
        "required_tools": set(required_tools),
        "useful_tools": set(useful_tools),
        "policy_required": policy_required,
        "linked_case_ids": list(linked_case_ids),
        "role": role,
        "queue_card": queue_card,
        "hidden_flags": hidden_flags,
        "evidence_catalog": evidence_catalog,
    }
| def _priority_label(self, risk_score: float, business_cost: float) -> str: | |
| if risk_score >= 0.68 or business_cost >= 1.45: | |
| return "P1" | |
| if risk_score >= 0.5 or business_cost >= 1.1: | |
| return "P2" | |
| return "P3" | |
| def _transaction_summary(self, transaction: Dict[str, Any], geo_mismatch: bool) -> str: | |
| return ( | |
| f"Payment {transaction['payment_method']}; shipping {transaction['shipping_speed']}; " | |
| f"same-address orders={transaction['same_address_orders_24h']}; geo mismatch={geo_mismatch}." | |
| ) | |
| def _customer_summary(self, transaction: Dict[str, Any]) -> str: | |
| return ( | |
| f"Buyer age {transaction['buyer_account_age_days']}d; disputes {transaction['buyer_disputes_90d']}; " | |
| f"repeat buyer={transaction['is_repeat_buyer']}." | |
| ) | |
| def _merchant_summary(self, transaction: Dict[str, Any]) -> str: | |
| return ( | |
| f"Seller rating {transaction['seller_avg_rating']:.2f}; reviews {transaction['num_seller_reviews']}; " | |
| f"chargeback rate {transaction['seller_chargeback_rate_30d']:.3f}." | |
| ) | |
| def _network_summary(self, role: str, linked_case_ids: List[str], history: Dict[str, Any]) -> str: | |
| if not linked_case_ids: | |
| return ( | |
| f"Graph review surfaced cluster score {history['cluster_alert_score']:.2f} " | |
| "with no immediately visible linked cases." | |
| ) | |
| if role == "primary": | |
| return ( | |
| f"Graph review surfaced a cluster score of {history['cluster_alert_score']:.2f} " | |
| "and a shared-entity pattern worth escalation review." | |
| ) | |
| return ( | |
| f"Graph review surfaced a cluster score of {history['cluster_alert_score']:.2f} " | |
| "and a related-activity pattern on this case." | |
| ) | |
| def _policy_summary(self, policy_required: bool, business_cost: float) -> str: | |
| if not policy_required: | |
| return "Policy allows direct approve or block decisions once a note is added." | |
| if business_cost >= 1.35: | |
| return "Policy recommends escalation when hidden network risk and business impact are both elevated." | |
| return "Policy recommends requesting documents or holding the case when signals remain mixed." | |
| def _high_risk_payment_ops(self, transaction: Dict[str, Any], geo_mismatch: bool) -> bool: | |
| return bool( | |
| transaction["payment_method"] in {"prepaid_card", "gift_card", "crypto_gateway"} | |
| or transaction["shipping_speed"] in {"same-day", "overnight"} | |
| or transaction["same_address_orders_24h"] >= 5 | |
| or geo_mismatch | |
| ) | |
def _handle_review(self, case_id: str, action: FraudCheckAction) -> Reward:
    """Open the transaction review for a case and make it the active case.

    The first review reveals the ``transaction_review`` evidence and earns
    0.04, plus 0.08 when the case's hidden payment/fulfilment flag is set.
    A repeated review is recorded as redundant with a -0.05 base value.
    """
    self.active_case_id = case_id
    self.current_screen = CaseScreenEnum.CASE_CONSOLE

    progress = self.case_state[case_id]
    case = self.workflow_cases[case_id]

    if progress["reviewed"]:
        return self._apply_action_outcome(
            action=action,
            case_id=case_id,
            base_value=-0.05,
            reason="Transaction review was already completed for this case.",
            evidence_key="transaction_review",
            redundant=True,
        )

    progress["reviewed"] = True
    progress["status"] = "in_review"
    progress["revealed_evidence"]["transaction_review"] = case["evidence_catalog"]["transaction_review"]

    message = "Transaction review revealed the operational transaction trace."
    extra = 0.0
    if case["hidden_flags"]["payment_ops_high_risk"]:
        extra = 0.08
        message += " The review surfaced high-risk payment or fulfillment signals."

    return self._apply_action_outcome(
        action=action,
        case_id=case_id,
        base_value=0.04 + extra,
        reason=message,
        evidence_key="transaction_review",
    )
def _handle_fetch(
    self,
    case_id: str,
    action: FraudCheckAction,
    evidence_key: str,
    screen: CaseScreenEnum,
) -> Reward:
    """Reveal one piece of deeper evidence for a case.

    Args:
        case_id: Case the fetch applies to; it becomes the active case.
        action: The originating action, forwarded to the reward builder.
        evidence_key: Key into the case's ``evidence_catalog``.
        screen: Screen to switch to once the fetch is allowed.

    Returns:
        The reward for the fetch. A fetch before the transaction review is
        invalid, a repeated fetch is redundant, and a fetch beyond the
        investigation budget still reveals the evidence but costs 0.03 more.
    """
    self.active_case_id = case_id
    progress = self.case_state[case_id]
    case = self.workflow_cases[case_id]

    if not progress["reviewed"]:
        return self._apply_action_outcome(
            action=action,
            case_id=case_id,
            base_value=-0.14,
            reason="Open the transaction review before pulling deeper evidence.",
            valid_action=False,
        )

    self.current_screen = screen
    revealed = progress["revealed_evidence"]

    if evidence_key in revealed:
        return self._apply_action_outcome(
            action=action,
            case_id=case_id,
            base_value=-0.05,
            reason=f"{evidence_key} was already fetched for this case.",
            evidence_key=evidence_key,
            redundant=True,
        )

    # The budget never goes below zero; overspending is penalised instead.
    budget_exhausted = progress["fetch_budget_remaining"] <= 0
    if not budget_exhausted:
        progress["fetch_budget_remaining"] -= 1
    progress["fetches_used"] += 1

    revealed[evidence_key] = case["evidence_catalog"][evidence_key]
    progress["status"] = "investigating"
    if evidence_key == "policy_guide":
        progress["policy_checked"] = True

    relevant = any(evidence_key in case[bucket] for bucket in ("required_tools", "useful_tools"))
    value = 0.05 if relevant else 0.0
    message = f"{evidence_key} revealed new hidden evidence."

    if evidence_key == "network_graph" and case["hidden_flags"]["network_high_risk"]:
        value += 0.08
        message = "Network graph revealed high-risk cluster evidence before the final decision."

    if budget_exhausted:
        value -= 0.03
        message += " The fetch happened after the investigation budget was exhausted."

    return self._apply_action_outcome(
        action=action,
        case_id=case_id,
        base_value=value,
        reason=message,
        evidence_key=evidence_key,
    )
def _handle_add_note(self, case_id: str, action: FraudCheckAction) -> Reward:
    """Attach an analyst note to a case.

    A note before the transaction review is invalid. A note whose stripped,
    lower-cased text matches an earlier note, or any note once two are
    stored, counts as note spam. Otherwise the stripped text is stored and
    the case status becomes ``"documented"``.
    """
    self.active_case_id = case_id
    self.current_screen = CaseScreenEnum.CASE_CONSOLE
    progress = self.case_state[case_id]

    if not progress["reviewed"]:
        return self._apply_action_outcome(
            action=action,
            case_id=case_id,
            base_value=-0.16,
            reason="Notes are only allowed after the transaction has been reviewed.",
            valid_action=False,
        )

    assert action.note_text is not None  # validated by Pydantic
    cleaned = action.note_text.strip()
    fingerprint = cleaned.lower()
    stored_notes = progress["notes"]

    is_duplicate = any(fingerprint == earlier.lower() for earlier in stored_notes)
    at_note_cap = len(stored_notes) >= 2
    if is_duplicate or at_note_cap:
        self.note_spam_count += 1
        return self._apply_action_outcome(
            action=action,
            case_id=case_id,
            base_value=-0.12,
            reason="Repeated or low-value note spam is penalized.",
            anti_hacking=True,
            redundant=True,
        )

    stored_notes.append(cleaned)
    progress["note_count"] += 1
    progress["status"] = "documented"
    return self._apply_action_outcome(
        action=action,
        case_id=case_id,
        base_value=0.09,
        reason="Added a case note that documents the investigation state.",
        evidence_key="case_note",
    )
def _handle_resolve(self, case_id: str, action: FraudCheckAction) -> Reward:
    """Close a case with the submitted resolution and score the decision.

    Scoring starts at 0.72 for a correct resolution with every required
    step and a note in place, 0.38 for a correct resolution with gaps, and
    -0.72 for an incorrect one. Deductions and the investigation bonus are
    then applied in a fixed order, each adding a sentence to the reason.

    Resolving before the transaction review is invalid and leaves the case
    open. Otherwise the case is marked resolved, the screen returns to the
    queue, and the first unresolved case (if any) becomes active.
    """
    self.active_case_id = case_id
    self.current_screen = CaseScreenEnum.POLICY_ESCALATION
    progress = self.case_state[case_id]
    case = self.workflow_cases[case_id]

    if not progress["reviewed"]:
        return self._apply_action_outcome(
            action=action,
            case_id=case_id,
            base_value=-0.35,
            reason="Cannot resolve a case before reviewing the transaction.",
            valid_action=False,
        )

    assert action.resolution is not None  # validated by Pydantic

    # Gaps in the workflow at the moment of resolution.
    completed = self._completed_tool_markers(case_id)
    missing_required = [tool for tool in case["required_tools"] if tool not in completed]
    note_missing = progress["note_count"] == 0
    policy_missing = case["policy_required"] and not progress["policy_checked"]

    fetch_matters = self.current_task in {TaskDifficulty.MEDIUM, TaskDifficulty.HARD}
    no_fetch_evidence = fetch_matters and progress["fetches_used"] == 0
    used_investigation_bonus = (
        fetch_matters
        and progress["fetches_used"] >= TASK_CONFIG[self.current_task]["minimum_fetches_for_bonus"]
    )

    correct = action.resolution == case["correct_resolution"]
    policy_compliant = correct and (not policy_missing)

    reasons = []
    if not correct:
        value = -0.72
        reasons.append(
            f"Incorrect routing: expected {case['correct_resolution'].value}, got {action.resolution.value}."
        )
    else:
        value = 0.38 if (missing_required or note_missing) else 0.72
        reasons.append("Resolution matched the hidden correct routing.")

    # Each adjustment is applied in this order; some deductions only apply
    # to correct resolutions, but the explanation is always recorded.
    if policy_missing:
        if correct:
            value -= 0.28
        reasons.append("Policy was not checked before resolution.")
    if note_missing:
        value -= 0.20
        reasons.append("A case note was required before closure.")
    if missing_required:
        if correct:
            value -= 0.16 * min(2, len(missing_required))
        reasons.append(f"Missing required workflow steps: {', '.join(sorted(missing_required))}.")
    if no_fetch_evidence:
        value -= 0.10
        reasons.append("Medium and hard cases require at least one investigation fetch before resolution.")
    if correct and used_investigation_bonus:
        value += 0.15
        reasons.append("The route also earned the investigation-use bonus.")

    progress["resolution"] = action.resolution
    progress["resolved"] = True
    progress["resolution_correct"] = correct
    progress["policy_compliant"] = policy_compliant
    progress["status"] = "resolved"

    still_open = self._unresolved_case_ids()
    if still_open:
        self.active_case_id = still_open[0]
    self.current_screen = CaseScreenEnum.QUEUE

    return self._apply_action_outcome(
        action=action,
        case_id=case_id,
        base_value=value,
        reason=" ".join(reasons),
        resolution=action.resolution,
        ground_truth_resolution=case["correct_resolution"],
        is_correct=correct,
        policy_compliant=policy_compliant,
    )
| def _completed_tool_markers(self, case_id: str) -> set[str]: | |
| state = self.case_state[case_id] | |
| completed = set(state["revealed_evidence"].keys()) | |
| if state["note_count"] > 0: | |
| completed.add("case_note") | |
| return completed | |
def _apply_action_outcome(
    self,
    action: FraudCheckAction,
    case_id: str,
    base_value: float,
    reason: str,
    evidence_key: str | None = None,
    resolution: ResolutionEnum | None = None,
    ground_truth_resolution: ResolutionEnum | None = None,
    is_correct: bool | None = None,
    policy_compliant: bool | None = None,
    valid_action: bool = True,
    redundant: bool = False,
    anti_hacking: bool = False,
) -> Reward:
    """Turn a handler's verdict into a ``Reward`` and advance the episode.

    Every step passes through here exactly once. The method updates the
    per-case and episode-wide counters, subtracts the action cost and any
    SLA penalty from ``base_value``, clamps the result to [-1.0, 1.0],
    advances ``step_count``, decides whether the episode is done, and
    appends an audit log entry.

    Args:
        action: The action being scored.
        case_id: Case whose counters and history are updated.
        base_value: Reward before action cost and SLA penalty.
        reason: Explanation stored on the reward and in the audit log.
        evidence_key: Evidence touched by the action, if any.
        resolution: Submitted resolution, for resolve actions.
        ground_truth_resolution: Correct resolution, for resolve actions.
        is_correct: Whether the submitted resolution was correct.
        policy_compliant: Whether the resolution was policy compliant.
        valid_action: False counts the action as invalid.
        redundant: True counts the action as redundant.
        anti_hacking: True records an anti-hacking hit on the case.

    Returns:
        The ``Reward`` for this step.
    """
    action_type = action.action_type

    progress = self.case_state.get(case_id)
    if progress is not None:
        progress["action_history"].append(action_type.value)
        if not valid_action:
            progress["invalid_actions"] += 1
            self.invalid_action_count += 1
        if redundant:
            progress["redundant_actions"] += 1
            self.redundant_action_count += 1
        if anti_hacking:
            progress["anti_hacking_hits"] += 1

    # Notes cost 0.01, fetch-style actions 0.02, everything else is free.
    if action_type == ActionTypeEnum.ADD_CASE_NOTE:
        action_cost = 0.01
    elif action_type in FETCH_ACTIONS:
        action_cost = 0.02
    else:
        action_cost = 0.0

    # The SLA penalty grows by 0.06 for every step past the task's limit.
    next_step = self.step_count + 1
    steps_over_sla = max(0, next_step - self.sla_limit[self.current_task])
    sla_penalty = 0.06 * steps_over_sla

    raw_value = base_value - action_cost - sla_penalty
    clamped_value = max(-1.0, min(1.0, raw_value))

    reward = Reward(
        value=round(clamped_value, 4),
        reason=reason,
        action_type=action_type,
        case_id=case_id,
        action_cost=round(action_cost, 4),
        sla_penalty=round(sla_penalty, 4),
        evidence_key=evidence_key,
        resolution=resolution,
        ground_truth_resolution=ground_truth_resolution,
        is_correct=is_correct,
        policy_compliant=policy_compliant,
        anti_hacking_triggered=anti_hacking,
    )

    self.step_count = next_step
    self.cumulative_reward += reward.value

    # The episode ends on the step budget or once no case is left open.
    out_of_steps = self.step_count >= self.max_steps[self.current_task]
    nothing_left = not self._unresolved_case_ids()
    if out_of_steps or nothing_left:
        self.is_done = True

    audit_entry = {
        "step": self.step_count,
        "case_id": case_id,
        "action_type": action_type.value,
        "reward": reward.value,
        "reason": reason,
        "screen": self.current_screen.value,
        "resolved_case_ids": self._resolved_case_ids(),
    }
    self.audit_log.append(audit_entry)
    return reward
def _build_step_result(self, reward: Reward, extra_info: Dict[str, Any]) -> StepResult:
    """Package the current observation and ``reward`` into a ``StepResult``.

    The observation is built first. If the episode is done, the episode
    report is stored in ``self.last_episode_summary``. The ``info`` payload
    starts with the episode id and task name, and entries from
    ``extra_info`` are merged on top, overriding those keys on collision.
    """
    snapshot = self._build_observation()
    if self.is_done:
        self.last_episode_summary = self.get_episode_report()

    details = {"episode_id": self.episode_id, "task": self.current_task.value}
    details.update(extra_info)

    return StepResult(
        observation=snapshot,
        reward=reward,
        done=self.is_done,
        info=details,
    )
| def _build_observation(self) -> FraudCheckObservation: | |
# Assemble the FraudCheckObservation for the active case: a case summary,
# the list of visible panels, a copy of the revealed evidence, queue cards,
# the remaining step/SLA counters and the currently allowed actions.
# NOTE(review): this listing has lost its indentation, so the nesting of the
# `if` blocks below cannot be confirmed from here; verify against the repo
# before restructuring.
# Fall back to the first case in queue order when no case is active.
| case_id = self.active_case_id or self.case_order[0] | |
| case = self.workflow_cases[case_id] | |
| state = self.case_state[case_id] | |
# Triage-only view: the case is unreviewed AND no step has been taken yet.
| is_triage_only = not state["reviewed"] and self.step_count == 0 | |
# The summary always reports the fixed risk band "review"; merchant_region is
# "masked" until the case is reviewed and is then filled from the
# transaction's shipping_address field.
| case_summary = CaseSummary( | |
| case_id=case_id, | |
| status=state["status"], | |
| queue_reason=case["queue_card"]["queue_reason"], | |
| visible_risk_band="review", | |
| amount_usd=float(case["transaction"]["amount"]), | |
| merchant_region="masked" if not state["reviewed"] else case["transaction"]["shipping_address"], | |
| evidence_collected=sorted(state["revealed_evidence"].keys()), | |
| note_added=state["note_count"] > 0, | |
| ) | |
# Panel list: the triage summary alone in the triage-only view, otherwise the
# triage summary plus "evidence_panel". For a reviewed case the current screen
# name is appended (lower-cased, spaces replaced by underscores), followed by
# the sorted revealed-evidence keys and, when a note exists, "case_notes".
# NOTE(review): whether the evidence-key extend and the note check are also
# gated on state["reviewed"] is not recoverable from this listing.
| visible_panels = ["triage_summary"] if is_triage_only else ["triage_summary", "evidence_panel"] | |
| if state["reviewed"]: | |
| visible_panels.append(self.current_screen.value.lower().replace(" ", "_")) | |
| visible_panels.extend(sorted(state["revealed_evidence"].keys())) | |
| if state["note_count"] > 0: | |
| visible_panels.append("case_notes") | |
# Linked case ids stay empty until "network_graph" evidence has been revealed.
| linked_case_ids = [] | |
| if "network_graph" in state["revealed_evidence"]: | |
| linked_case_ids = list(case["linked_case_ids"]) | |
# Queue cards are only listed once the observed case has been reviewed. Each
# card carries the fixed "review" band and an empty linked_case_ids list.
| queue_items: List[QueueCaseCard] = [] | |
| if state["reviewed"]: | |
| queue_items = [ | |
| QueueCaseCard( | |
| case_id=workflow_case["case_id"], | |
| priority=workflow_case["queue_card"]["priority"], | |
| queue_reason=workflow_case["queue_card"]["queue_reason"], | |
| visible_risk_band="review", | |
| status=self.case_state[workflow_case["case_id"]]["status"], | |
| linked_case_ids=[], | |
| ) | |
| for workflow_case in self.workflow_cases.values() | |
| ] | |
# revealed_evidence is deep-copied, so the observation does not share the
# environment's internal evidence dict. note_required is True while no note
# has been added. app_context carries the transaction's item category and
# timestamp, the case's remaining fetch budget, the sorted values of
# FETCH_ACTIONS and the task description from TASK_CONFIG.
| return FraudCheckObservation( | |
| case_id=case_id, | |
| task_name=self.current_task, | |
| current_screen=self.current_screen, | |
| visible_panels=visible_panels, | |
| revealed_evidence=copy.deepcopy(state["revealed_evidence"]), | |
| linked_case_ids=linked_case_ids, | |
| remaining_steps=self._remaining_steps(), | |
| remaining_sla=self._remaining_sla(), | |
| note_required=state["note_count"] == 0, | |
| allowed_actions=self._allowed_actions(case_id), | |
| queue_items=queue_items, | |
| case_summary=case_summary, | |
| episode_step=self.step_count, | |
| app_context={ | |
| "item_category": case["transaction"]["item_category"], | |
| "timestamp": case["transaction"]["timestamp"], | |
| "investigation_budget_remaining": state["fetch_budget_remaining"], | |
| "available_investigations": sorted(action.value for action in FETCH_ACTIONS), | |
| "task_description": TASK_CONFIG[self.current_task]["description"], | |
| }, | |
| ) | |
| def _allowed_actions(self, case_id: str) -> List[ActionTypeEnum]: | |
| if self.is_done: | |
| return [] | |
| state = self.case_state[case_id] | |
| if state["resolved"]: | |
| return [ActionTypeEnum.REVIEW_TRANSACTION] if self._unresolved_case_ids() else [] | |
| if not state["reviewed"]: | |
| return [ActionTypeEnum.REVIEW_TRANSACTION] | |
| return [ | |
| ActionTypeEnum.REVIEW_TRANSACTION, | |
| ActionTypeEnum.FETCH_CUSTOMER_PROFILE, | |
| ActionTypeEnum.FETCH_MERCHANT_PROFILE, | |
| ActionTypeEnum.FETCH_NETWORK_GRAPH, | |
| ActionTypeEnum.CHECK_POLICY, | |
| ActionTypeEnum.ADD_CASE_NOTE, | |
| ActionTypeEnum.RESOLVE_CASE, | |
| ] | |
| def _remaining_steps(self) -> int: | |
| return max(0, self.max_steps[self.current_task] - self.step_count) | |
| def _remaining_sla(self) -> int: | |
| return max(0, self.sla_limit[self.current_task] - self.step_count) | |
| def _resolved_case_ids(self) -> List[str]: | |
| return [case_id for case_id in self.case_order if self.case_state[case_id]["resolved"]] | |
| def _unresolved_case_ids(self) -> List[str]: | |
| return [case_id for case_id in self.case_order if not self.case_state[case_id]["resolved"]] | |
| def _build_case_report(self, case_id: str) -> Dict[str, Any]: | |
| case = self.workflow_cases[case_id] | |
| state = self.case_state[case_id] | |
| required_tools = case["required_tools"] | |
| completed_tools = self._completed_tool_markers(case_id) | |
| coverage = len(required_tools & completed_tools) / max(1, len(required_tools)) | |
| workflow_completion = ( | |
| (1.0 if state["reviewed"] else 0.0) | |
| + (1.0 if state["note_count"] > 0 else 0.0) | |
| + (1.0 if state["resolved"] else 0.0) | |
| + coverage | |
| ) / 4.0 | |
| return { | |
| "case_id": case_id, | |
| "queue_reason": case["queue_card"]["queue_reason"], | |
| "correct_resolution": case["correct_resolution"].value, | |
| "submitted_resolution": state["resolution"].value if state["resolution"] else None, | |
| "reviewed": state["reviewed"], | |
| "note_count": state["note_count"], | |
| "policy_checked": state["policy_checked"], | |
| "revealed_evidence": sorted(state["revealed_evidence"].keys()), | |
| "resolution_correct": state["resolution_correct"], | |
| "policy_compliant": state["policy_compliant"], | |
| "invalid_actions": state["invalid_actions"], | |
| "redundant_actions": state["redundant_actions"], | |
| "anti_hacking_hits": state["anti_hacking_hits"], | |
| "linked_case_ids": list(case["linked_case_ids"]), | |
| "evidence_coverage": round(coverage, 4), | |
| "workflow_completion": round(workflow_completion, 4), | |
| } | |