fraudshield-1 / fraudshield_env.py
DevikaJ2005's picture
Finalize RL-first environment and explorer UI
30533d1
"""FraudShield partial-observability environment implementation."""
from __future__ import annotations
import copy
import uuid
from typing import Any, Dict, List
from data_loader import FraudDataLoader
from models import (
ActionTypeEnum,
CaseScreenEnum,
CaseSummary,
EpisodeState,
FraudCheckAction,
FraudCheckObservation,
QueueCaseCard,
ResetResult,
ResolutionEnum,
Reward,
StepResult,
TaskDifficulty,
)
TASK_CONFIG: Dict[TaskDifficulty, Dict[str, Any]] = {
TaskDifficulty.EASY: {
"source_task": "easy",
"num_cases": 1,
"max_steps": 6,
"sla_limit": 5,
"ideal_steps": 3,
"investigation_budget": 1,
"minimum_fetches_for_bonus": 1,
"description": "Single low-noise case with strong visible cues and one fetch budget.",
},
TaskDifficulty.MEDIUM: {
"source_task": "medium",
"num_cases": 1,
"max_steps": 8,
"sla_limit": 6,
"ideal_steps": 5,
"investigation_budget": 2,
"minimum_fetches_for_bonus": 1,
"description": "Single mixed-signal case that requires at least one investigation before routing.",
},
TaskDifficulty.HARD: {
"source_task": "hard",
"num_cases": 2,
"max_steps": 14,
"sla_limit": 11,
"ideal_steps": 9,
"investigation_budget": 3,
"minimum_fetches_for_bonus": 1,
"description": "Two misleading linked cases where graph evidence is usually required.",
},
}
FETCH_ACTIONS = {
ActionTypeEnum.FETCH_CUSTOMER_PROFILE,
ActionTypeEnum.FETCH_MERCHANT_PROFILE,
ActionTypeEnum.FETCH_NETWORK_GRAPH,
ActionTypeEnum.CHECK_POLICY,
}
class FraudShieldEnvironment:
"""OpenEnv-compatible fraud-investigation environment."""
def __init__(self, data_path: str = "data", seed: int = 42):
self.seed = seed
self.data_loader = FraudDataLoader(data_path=data_path, seed=seed)
self.data_loaded = False
self.episode_id = ""
self.current_task = TaskDifficulty.EASY
self.step_count = 0
self.cumulative_reward = 0.0
self.is_done = False
self.current_screen = CaseScreenEnum.QUEUE
self.active_case_id = ""
self.workflow_cases: Dict[str, Dict[str, Any]] = {}
self.case_state: Dict[str, Dict[str, Any]] = {}
self.case_order: List[str] = []
self.audit_log: List[Dict[str, Any]] = []
self.invalid_action_count = 0
self.redundant_action_count = 0
self.note_spam_count = 0
self.case_counts = {task: TASK_CONFIG[task]["num_cases"] for task in TaskDifficulty}
self.max_steps = {task: TASK_CONFIG[task]["max_steps"] for task in TaskDifficulty}
self.sla_limit = {task: TASK_CONFIG[task]["sla_limit"] for task in TaskDifficulty}
self.last_episode_summary: Dict[str, Any] = {}
def load_data(self) -> bool:
"""Load the deterministic committed snapshot."""
self.data_loaded = self.data_loader.load_data()
return self.data_loaded
def load_kaggle_data(self) -> bool:
"""Backward-compatible alias for older validation scripts."""
return self.load_data()
def ensure_data_loaded(self) -> None:
"""Load data lazily for local and remote execution."""
if not self.data_loaded and not self.load_data():
raise RuntimeError("FraudShield data bundle could not be loaded.")
def reset(self, task: str = "easy") -> ResetResult:
"""Start a new fraud-investigation episode."""
self.ensure_data_loaded()
self.current_task = TaskDifficulty(task)
config = TASK_CONFIG[self.current_task]
self.episode_id = f"ep_{uuid.uuid4().hex[:8]}"
self.step_count = 0
self.cumulative_reward = 0.0
self.is_done = False
self.current_screen = CaseScreenEnum.QUEUE
self.invalid_action_count = 0
self.redundant_action_count = 0
self.note_spam_count = 0
self.audit_log = []
self.last_episode_summary = {}
self.workflow_cases = self._build_workflow_cases(self.current_task)
self.case_order = list(self.workflow_cases.keys())
self.active_case_id = self.case_order[0]
self.case_state = {
case_id: {
"status": "triage",
"reviewed": False,
"revealed_evidence": {},
"note_count": 0,
"notes": [],
"policy_checked": False,
"resolution": None,
"resolved": False,
"resolution_correct": False,
"policy_compliant": False,
"invalid_actions": 0,
"redundant_actions": 0,
"anti_hacking_hits": 0,
"action_history": [],
"fetches_used": 0,
"fetch_budget_remaining": config["investigation_budget"],
}
for case_id in self.case_order
}
info = {
"episode_id": self.episode_id,
"task": self.current_task.value,
"num_cases": config["num_cases"],
"max_steps": config["max_steps"],
"sla_limit": config["sla_limit"],
"investigation_budget": config["investigation_budget"],
"description": config["description"],
"workflow_views": [screen.value for screen in CaseScreenEnum],
"data_snapshot": self.data_loader.get_bundle_summary(),
}
return ResetResult(observation=self._build_observation(), info=info)
def step(self, action: FraudCheckAction) -> StepResult:
"""Apply a single investigation or resolution action."""
if self.is_done:
raise RuntimeError("Episode is done. Call reset() to start a new episode.")
if action.case_id not in self.workflow_cases:
reward = self._apply_action_outcome(
action=action,
case_id=self.active_case_id or self.case_order[0],
base_value=-0.35,
reason=f"Unknown case_id '{action.case_id}'.",
valid_action=False,
anti_hacking=True,
)
return self._build_step_result(reward, {"valid_action": False, "error": "unknown_case"})
if action.action_type != ActionTypeEnum.REVIEW_TRANSACTION and action.case_id != self.active_case_id:
reward = self._apply_action_outcome(
action=action,
case_id=action.case_id,
base_value=-0.18,
reason="Only review_transaction may switch focus to another case.",
valid_action=False,
)
return self._build_step_result(reward, {"valid_action": False, "error": "inactive_case"})
case_id = action.case_id
state = self.case_state[case_id]
if state["resolved"] and action.action_type != ActionTypeEnum.REVIEW_TRANSACTION:
reward = self._apply_action_outcome(
action=action,
case_id=case_id,
base_value=-0.22,
reason="Resolved cases cannot be modified; move to an open case instead.",
valid_action=False,
)
return self._build_step_result(reward, {"valid_action": False, "error": "resolved_case"})
if action.action_type == ActionTypeEnum.REVIEW_TRANSACTION:
reward = self._handle_review(case_id, action)
elif action.action_type == ActionTypeEnum.FETCH_CUSTOMER_PROFILE:
reward = self._handle_fetch(case_id, action, "customer_profile", CaseScreenEnum.CUSTOMER_PROFILE)
elif action.action_type == ActionTypeEnum.FETCH_MERCHANT_PROFILE:
reward = self._handle_fetch(case_id, action, "merchant_profile", CaseScreenEnum.MERCHANT_PROFILE)
elif action.action_type == ActionTypeEnum.FETCH_NETWORK_GRAPH:
reward = self._handle_fetch(case_id, action, "network_graph", CaseScreenEnum.CASE_CONSOLE)
elif action.action_type == ActionTypeEnum.CHECK_POLICY:
reward = self._handle_fetch(case_id, action, "policy_guide", CaseScreenEnum.POLICY_ESCALATION)
elif action.action_type == ActionTypeEnum.ADD_CASE_NOTE:
reward = self._handle_add_note(case_id, action)
elif action.action_type == ActionTypeEnum.RESOLVE_CASE:
reward = self._handle_resolve(case_id, action)
else: # pragma: no cover - enum already constrains values
reward = self._apply_action_outcome(
action=action,
case_id=case_id,
base_value=-0.25,
reason=f"Unsupported action_type '{action.action_type.value}'.",
valid_action=False,
)
return self._build_step_result(
reward,
{
"valid_action": reward.value > -0.999 and not reward.reason.startswith("Unknown case_id"),
"active_case_id": self.active_case_id,
"resolved_case_ids": self._resolved_case_ids(),
"remaining_sla": self._remaining_sla(),
"remaining_steps": self._remaining_steps(),
},
)
def state(self) -> EpisodeState:
"""Return the current episode state."""
return EpisodeState(
episode_id=self.episode_id,
task_name=self.current_task,
current_screen=self.current_screen,
active_case_id=self.active_case_id,
step_count=self.step_count,
remaining_steps=self._remaining_steps(),
remaining_sla=self._remaining_sla(),
cumulative_reward=round(self.cumulative_reward, 4),
is_done=self.is_done,
resolved_case_ids=self._resolved_case_ids(),
unresolved_case_ids=self._unresolved_case_ids(),
notes_written_by_case={case_id: data["note_count"] for case_id, data in self.case_state.items()},
evidence_keys_by_case={
case_id: sorted(data["revealed_evidence"].keys()) for case_id, data in self.case_state.items()
},
policy_checked_case_ids=sorted(
case_id for case_id, data in self.case_state.items() if data["policy_checked"]
),
resolution_by_case={
case_id: data["resolution"]
for case_id, data in self.case_state.items()
if data["resolution"] is not None
},
invalid_action_count=self.invalid_action_count,
redundant_action_count=self.redundant_action_count,
)
def get_episode_report(self) -> Dict[str, Any]:
"""Return a deterministic grading report for the current or completed episode."""
case_reports = [self._build_case_report(case_id) for case_id in self.case_order]
case_count = max(1, len(case_reports))
resolution_accuracy = sum(1.0 if report["resolution_correct"] else 0.0 for report in case_reports) / case_count
evidence_coverage = sum(report["evidence_coverage"] for report in case_reports) / case_count
policy_compliance = sum(1.0 if report["policy_compliant"] else 0.0 for report in case_reports) / case_count
workflow_completion = (
sum(report["workflow_completion"] for report in case_reports) / case_count if case_reports else 0.0
)
overstep_penalty = max(0, self.step_count - TASK_CONFIG[self.current_task]["ideal_steps"]) * 0.05
efficiency = max(
0.0,
1.0
- (self.invalid_action_count * 0.12)
- (self.redundant_action_count * 0.08)
- (self.note_spam_count * 0.06)
- overstep_penalty,
)
if self.current_task == TaskDifficulty.HARD:
reviewed_count = sum(1.0 if report["reviewed"] else 0.0 for report in case_reports)
network_count = sum(
1.0 if "network_graph" in report["revealed_evidence"] else 0.0 for report in case_reports
)
resolved_count = sum(1.0 if report["submitted_resolution"] else 0.0 for report in case_reports)
link_consistency = min(
1.0,
0.25 * reviewed_count + 0.25 * network_count + 0.25 * resolved_count + 0.25 * resolution_accuracy * 2,
)
else:
link_consistency = 1.0
summary = {
"episode_id": self.episode_id,
"task": self.current_task.value,
"step_count": self.step_count,
"max_steps": TASK_CONFIG[self.current_task]["max_steps"],
"remaining_sla": self._remaining_sla(),
"cumulative_reward": round(self.cumulative_reward, 4),
"invalid_action_count": self.invalid_action_count,
"redundant_action_count": self.redundant_action_count,
"note_spam_count": self.note_spam_count,
"case_summaries": case_reports,
"metrics": {
"resolution_accuracy": round(resolution_accuracy, 4),
"evidence_coverage": round(evidence_coverage, 4),
"policy_compliance": round(policy_compliance, 4),
"workflow_completion": round(workflow_completion, 4),
"efficiency": round(efficiency, 4),
"link_consistency": round(link_consistency, 4),
},
"audit_log": list(self.audit_log),
}
self.last_episode_summary = summary
return summary
def _build_workflow_cases(self, task: TaskDifficulty) -> Dict[str, Dict[str, Any]]:
if task == TaskDifficulty.EASY:
source_case = self._select_easy_case(self.data_loader.get_task_cases("easy"))
return {
"easy_case_01": self._make_workflow_case(
case_id="easy_case_01",
raw_case=source_case,
queue_reason="High-value purchase queued for a quick manual review.",
correct_resolution=ResolutionEnum.BLOCK,
required_tools={"transaction_review", "case_note"},
useful_tools={"merchant_profile"},
policy_required=False,
linked_case_ids=[],
role="single",
)
}
if task == TaskDifficulty.MEDIUM:
source_case = self._select_medium_case(self.data_loader.get_task_cases("medium"))
return {
"medium_case_01": self._make_workflow_case(
case_id="medium_case_01",
raw_case=source_case,
queue_reason="Mixed signals triggered review and supporting evidence is needed.",
correct_resolution=ResolutionEnum.REQUEST_DOCS,
required_tools={"transaction_review", "customer_profile", "policy_guide", "case_note"},
useful_tools={"merchant_profile"},
policy_required=True,
linked_case_ids=[],
role="single",
)
}
hard_cases = self._select_hard_pair(self.data_loader.get_task_cases("hard"))
primary = self._make_workflow_case(
case_id="hard_case_primary",
raw_case=hard_cases[0],
queue_reason="Operational anomaly spike triggered a higher-touch review.",
correct_resolution=ResolutionEnum.ESCALATE,
required_tools={"transaction_review", "network_graph", "merchant_profile", "policy_guide", "case_note"},
useful_tools=set(),
policy_required=True,
linked_case_ids=["hard_case_secondary"],
role="primary",
)
secondary = self._make_workflow_case(
case_id="hard_case_secondary",
raw_case=hard_cases[1],
queue_reason="A related anomaly surfaced in the same review wave.",
correct_resolution=ResolutionEnum.BLOCK,
required_tools={"transaction_review", "network_graph", "customer_profile", "policy_guide", "case_note"},
useful_tools=set(),
policy_required=True,
linked_case_ids=["hard_case_primary"],
role="secondary",
)
return {primary["case_id"]: primary, secondary["case_id"]: secondary}
def _select_easy_case(self, cases: List[Dict[str, Any]]) -> Dict[str, Any]:
fraud_cases = [case for case in cases if case["label"] == "fraud"]
return max(fraud_cases, key=lambda case: (case["risk_score"], case["business_cost"]))
def _select_medium_case(self, cases: List[Dict[str, Any]]) -> Dict[str, Any]:
candidates = [
case
for case in cases
if case["label"] == "legitimate"
and case["transaction_data"]["previous_fraud_flags"] >= 1
and case["transaction_data"]["seller_chargeback_rate_30d"] >= 0.04
]
if not candidates:
candidates = [case for case in cases if case["label"] == "legitimate"]
return max(candidates, key=lambda case: (case["risk_score"], case["business_cost"]))
def _select_hard_pair(self, cases: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
groups: Dict[str, List[Dict[str, Any]]] = {}
for case in cases:
seller_id = case["transaction_data"]["seller_id"]
groups.setdefault(seller_id, []).append(case)
linked_groups = [
group
for group in groups.values()
if len(group) >= 2 and all(case["label"] == "fraud" for case in group)
]
if not linked_groups:
linked_groups = [cases[:2]]
chosen_group = max(linked_groups, key=lambda group: sum(case["business_cost"] for case in group))
ordered = sorted(chosen_group, key=lambda case: (case["business_cost"], case["risk_score"]), reverse=True)
return ordered[:2]
def _make_workflow_case(
self,
case_id: str,
raw_case: Dict[str, Any],
queue_reason: str,
correct_resolution: ResolutionEnum,
required_tools: set[str],
useful_tools: set[str],
policy_required: bool,
linked_case_ids: List[str],
role: str,
) -> Dict[str, Any]:
transaction = copy.deepcopy(raw_case["transaction_data"])
history = copy.deepcopy(raw_case["historical_context"])
risk_score = float(raw_case["risk_score"])
business_cost = float(raw_case["business_cost"])
shipping_country = transaction["shipping_address"]
device_country = transaction["device_country"]
geo_mismatch = shipping_country != device_country
timestamp = transaction["timestamp"]
queue_card = {
"case_id": case_id,
"priority": self._priority_label(risk_score, business_cost),
"queue_reason": queue_reason,
"visible_risk_band": "review",
"status": "triage",
"linked_case_ids": [],
}
transaction_review = {
"view": CaseScreenEnum.CASE_CONSOLE.value,
"summary": self._transaction_summary(transaction, geo_mismatch),
"facts": {
"amount_usd": transaction["amount"],
"item_category": transaction["item_category"],
"timestamp": timestamp,
"shipping_country": shipping_country,
"device_country": device_country,
"payment_method": transaction["payment_method"],
"shipping_speed": transaction["shipping_speed"],
"same_address_orders_24h": transaction["same_address_orders_24h"],
},
}
customer_profile = {
"view": CaseScreenEnum.CUSTOMER_PROFILE.value,
"summary": self._customer_summary(transaction),
"facts": {
"buyer_account_age_days": transaction["buyer_account_age_days"],
"buyer_disputes_90d": transaction["buyer_disputes_90d"],
"is_repeat_buyer": transaction["is_repeat_buyer"],
},
}
merchant_profile = {
"view": CaseScreenEnum.MERCHANT_PROFILE.value,
"summary": self._merchant_summary(transaction),
"facts": {
"seller_account_age_days": transaction["seller_account_age_days"],
"seller_avg_rating": transaction["seller_avg_rating"],
"num_seller_reviews": transaction["num_seller_reviews"],
"seller_chargeback_rate_30d": transaction["seller_chargeback_rate_30d"],
},
}
network_graph = {
"view": CaseScreenEnum.CASE_CONSOLE.value,
"summary": self._network_summary(role, linked_case_ids, history),
"facts": {
"shared_device_accounts_24h": transaction["shared_device_accounts_24h"],
"previous_fraud_flags": transaction["previous_fraud_flags"],
"cluster_alert_score": history["cluster_alert_score"],
"linked_cards_7d": history["linked_cards_7d"],
"linked_case_ids": list(linked_case_ids),
},
}
policy_guide = {
"view": CaseScreenEnum.POLICY_ESCALATION.value,
"summary": self._policy_summary(policy_required, business_cost),
"facts": {
"policy_required": policy_required,
"note_required": True,
"request_docs_on_unresolved_conflict": self.current_task != TaskDifficulty.EASY,
"escalate_if_cluster_and_loss": role == "primary" and business_cost >= 1.35,
"high_loss_threshold": 1.35,
},
}
return {
"case_id": case_id,
"raw_case": raw_case,
"transaction": transaction,
"history": history,
"risk_score": risk_score,
"business_cost": business_cost,
"correct_resolution": correct_resolution,
"required_tools": set(required_tools),
"useful_tools": set(useful_tools),
"policy_required": policy_required,
"linked_case_ids": list(linked_case_ids),
"role": role,
"queue_card": queue_card,
"hidden_flags": {
"payment_ops_high_risk": self._high_risk_payment_ops(transaction, geo_mismatch),
"network_high_risk": history["cluster_alert_score"] >= 0.7,
},
"evidence_catalog": {
"transaction_review": transaction_review,
"customer_profile": customer_profile,
"merchant_profile": merchant_profile,
"network_graph": network_graph,
"policy_guide": policy_guide,
},
}
def _priority_label(self, risk_score: float, business_cost: float) -> str:
if risk_score >= 0.68 or business_cost >= 1.45:
return "P1"
if risk_score >= 0.5 or business_cost >= 1.1:
return "P2"
return "P3"
def _transaction_summary(self, transaction: Dict[str, Any], geo_mismatch: bool) -> str:
return (
f"Payment {transaction['payment_method']}; shipping {transaction['shipping_speed']}; "
f"same-address orders={transaction['same_address_orders_24h']}; geo mismatch={geo_mismatch}."
)
def _customer_summary(self, transaction: Dict[str, Any]) -> str:
return (
f"Buyer age {transaction['buyer_account_age_days']}d; disputes {transaction['buyer_disputes_90d']}; "
f"repeat buyer={transaction['is_repeat_buyer']}."
)
def _merchant_summary(self, transaction: Dict[str, Any]) -> str:
return (
f"Seller rating {transaction['seller_avg_rating']:.2f}; reviews {transaction['num_seller_reviews']}; "
f"chargeback rate {transaction['seller_chargeback_rate_30d']:.3f}."
)
def _network_summary(self, role: str, linked_case_ids: List[str], history: Dict[str, Any]) -> str:
if not linked_case_ids:
return (
f"Graph review surfaced cluster score {history['cluster_alert_score']:.2f} "
"with no immediately visible linked cases."
)
if role == "primary":
return (
f"Graph review surfaced a cluster score of {history['cluster_alert_score']:.2f} "
"and a shared-entity pattern worth escalation review."
)
return (
f"Graph review surfaced a cluster score of {history['cluster_alert_score']:.2f} "
"and a related-activity pattern on this case."
)
def _policy_summary(self, policy_required: bool, business_cost: float) -> str:
if not policy_required:
return "Policy allows direct approve or block decisions once a note is added."
if business_cost >= 1.35:
return "Policy recommends escalation when hidden network risk and business impact are both elevated."
return "Policy recommends requesting documents or holding the case when signals remain mixed."
def _high_risk_payment_ops(self, transaction: Dict[str, Any], geo_mismatch: bool) -> bool:
return bool(
transaction["payment_method"] in {"prepaid_card", "gift_card", "crypto_gateway"}
or transaction["shipping_speed"] in {"same-day", "overnight"}
or transaction["same_address_orders_24h"] >= 5
or geo_mismatch
)
def _handle_review(self, case_id: str, action: FraudCheckAction) -> Reward:
self.active_case_id = case_id
self.current_screen = CaseScreenEnum.CASE_CONSOLE
state = self.case_state[case_id]
case = self.workflow_cases[case_id]
if state["reviewed"]:
return self._apply_action_outcome(
action=action,
case_id=case_id,
base_value=-0.05,
reason="Transaction review was already completed for this case.",
evidence_key="transaction_review",
redundant=True,
)
state["reviewed"] = True
state["status"] = "in_review"
state["revealed_evidence"]["transaction_review"] = case["evidence_catalog"]["transaction_review"]
bonus = 0.08 if case["hidden_flags"]["payment_ops_high_risk"] else 0.0
reason = "Transaction review revealed the operational transaction trace."
if bonus > 0:
reason += " The review surfaced high-risk payment or fulfillment signals."
return self._apply_action_outcome(
action=action,
case_id=case_id,
base_value=0.04 + bonus,
reason=reason,
evidence_key="transaction_review",
)
def _handle_fetch(
self,
case_id: str,
action: FraudCheckAction,
evidence_key: str,
screen: CaseScreenEnum,
) -> Reward:
self.active_case_id = case_id
state = self.case_state[case_id]
case = self.workflow_cases[case_id]
if not state["reviewed"]:
return self._apply_action_outcome(
action=action,
case_id=case_id,
base_value=-0.14,
reason="Open the transaction review before pulling deeper evidence.",
valid_action=False,
)
self.current_screen = screen
if evidence_key in state["revealed_evidence"]:
return self._apply_action_outcome(
action=action,
case_id=case_id,
base_value=-0.05,
reason=f"{evidence_key} was already fetched for this case.",
evidence_key=evidence_key,
redundant=True,
)
over_budget = state["fetch_budget_remaining"] <= 0
if not over_budget:
state["fetch_budget_remaining"] -= 1
state["fetches_used"] += 1
state["revealed_evidence"][evidence_key] = case["evidence_catalog"][evidence_key]
state["status"] = "investigating"
if evidence_key == "policy_guide":
state["policy_checked"] = True
useful = evidence_key in case["required_tools"] or evidence_key in case["useful_tools"]
base_value = 0.05 if useful else 0.0
reason = f"{evidence_key} revealed new hidden evidence."
if evidence_key == "network_graph" and case["hidden_flags"]["network_high_risk"]:
base_value += 0.08
reason = "Network graph revealed high-risk cluster evidence before the final decision."
if over_budget:
base_value -= 0.03
reason += " The fetch happened after the investigation budget was exhausted."
return self._apply_action_outcome(
action=action,
case_id=case_id,
base_value=base_value,
reason=reason,
evidence_key=evidence_key,
)
def _handle_add_note(self, case_id: str, action: FraudCheckAction) -> Reward:
self.active_case_id = case_id
self.current_screen = CaseScreenEnum.CASE_CONSOLE
state = self.case_state[case_id]
if not state["reviewed"]:
return self._apply_action_outcome(
action=action,
case_id=case_id,
base_value=-0.16,
reason="Notes are only allowed after the transaction has been reviewed.",
valid_action=False,
)
assert action.note_text is not None # validated by Pydantic
normalized_note = action.note_text.strip().lower()
existing_notes = [note.lower() for note in state["notes"]]
if normalized_note in existing_notes or len(existing_notes) >= 2:
self.note_spam_count += 1
return self._apply_action_outcome(
action=action,
case_id=case_id,
base_value=-0.12,
reason="Repeated or low-value note spam is penalized.",
anti_hacking=True,
redundant=True,
)
state["notes"].append(action.note_text.strip())
state["note_count"] += 1
state["status"] = "documented"
return self._apply_action_outcome(
action=action,
case_id=case_id,
base_value=0.09,
reason="Added a case note that documents the investigation state.",
evidence_key="case_note",
)
def _handle_resolve(self, case_id: str, action: FraudCheckAction) -> Reward:
self.active_case_id = case_id
self.current_screen = CaseScreenEnum.POLICY_ESCALATION
state = self.case_state[case_id]
case = self.workflow_cases[case_id]
if not state["reviewed"]:
return self._apply_action_outcome(
action=action,
case_id=case_id,
base_value=-0.35,
reason="Cannot resolve a case before reviewing the transaction.",
valid_action=False,
)
assert action.resolution is not None # validated by Pydantic
missing_required = [
tool for tool in case["required_tools"] if tool not in self._completed_tool_markers(case_id)
]
note_missing = state["note_count"] == 0
policy_missing = case["policy_required"] and not state["policy_checked"]
no_fetch_evidence = (
self.current_task in {TaskDifficulty.MEDIUM, TaskDifficulty.HARD} and state["fetches_used"] == 0
)
used_investigation_bonus = (
self.current_task in {TaskDifficulty.MEDIUM, TaskDifficulty.HARD}
and state["fetches_used"] >= TASK_CONFIG[self.current_task]["minimum_fetches_for_bonus"]
)
correct = action.resolution == case["correct_resolution"]
policy_compliant = correct and (not policy_missing)
base_value = 0.72 if correct and not missing_required and not note_missing else 0.38 if correct else -0.72
if policy_missing and correct:
base_value -= 0.28
if note_missing:
base_value -= 0.20
if missing_required and correct:
base_value -= 0.16 * min(2, len(missing_required))
if no_fetch_evidence:
base_value -= 0.10
if correct and used_investigation_bonus:
base_value += 0.15
reason_parts = []
if correct:
reason_parts.append("Resolution matched the hidden correct routing.")
else:
reason_parts.append(
f"Incorrect routing: expected {case['correct_resolution'].value}, got {action.resolution.value}."
)
if policy_missing:
reason_parts.append("Policy was not checked before resolution.")
if note_missing:
reason_parts.append("A case note was required before closure.")
if missing_required:
reason_parts.append(f"Missing required workflow steps: {', '.join(sorted(missing_required))}.")
if no_fetch_evidence:
reason_parts.append("Medium and hard cases require at least one investigation fetch before resolution.")
if correct and used_investigation_bonus:
reason_parts.append("The route also earned the investigation-use bonus.")
state["resolution"] = action.resolution
state["resolved"] = True
state["resolution_correct"] = correct
state["policy_compliant"] = policy_compliant
state["status"] = "resolved"
unresolved = self._unresolved_case_ids()
if not unresolved:
self.current_screen = CaseScreenEnum.QUEUE
else:
self.active_case_id = unresolved[0]
self.current_screen = CaseScreenEnum.QUEUE
return self._apply_action_outcome(
action=action,
case_id=case_id,
base_value=base_value,
reason=" ".join(reason_parts),
resolution=action.resolution,
ground_truth_resolution=case["correct_resolution"],
is_correct=correct,
policy_compliant=policy_compliant,
)
def _completed_tool_markers(self, case_id: str) -> set[str]:
state = self.case_state[case_id]
completed = set(state["revealed_evidence"].keys())
if state["note_count"] > 0:
completed.add("case_note")
return completed
def _apply_action_outcome(
self,
action: FraudCheckAction,
case_id: str,
base_value: float,
reason: str,
evidence_key: str | None = None,
resolution: ResolutionEnum | None = None,
ground_truth_resolution: ResolutionEnum | None = None,
is_correct: bool | None = None,
policy_compliant: bool | None = None,
valid_action: bool = True,
redundant: bool = False,
anti_hacking: bool = False,
) -> Reward:
state = self.case_state.get(case_id)
if state is not None:
state["action_history"].append(action.action_type.value)
if not valid_action:
state["invalid_actions"] += 1
self.invalid_action_count += 1
if redundant:
state["redundant_actions"] += 1
self.redundant_action_count += 1
if anti_hacking:
state["anti_hacking_hits"] += 1
action_cost = 0.02 if action.action_type in FETCH_ACTIONS else 0.0
if action.action_type == ActionTypeEnum.ADD_CASE_NOTE:
action_cost = 0.01
projected_step = self.step_count + 1
sla_penalty = 0.06 * max(0, projected_step - self.sla_limit[self.current_task])
reward_value = max(-1.0, min(1.0, base_value - action_cost - sla_penalty))
reward = Reward(
value=round(reward_value, 4),
reason=reason,
action_type=action.action_type,
case_id=case_id,
action_cost=round(action_cost, 4),
sla_penalty=round(sla_penalty, 4),
evidence_key=evidence_key,
resolution=resolution,
ground_truth_resolution=ground_truth_resolution,
is_correct=is_correct,
policy_compliant=policy_compliant,
anti_hacking_triggered=anti_hacking,
)
self.step_count = projected_step
self.cumulative_reward += reward.value
if self.step_count >= self.max_steps[self.current_task]:
self.is_done = True
if not self._unresolved_case_ids():
self.is_done = True
self.audit_log.append(
{
"step": self.step_count,
"case_id": case_id,
"action_type": action.action_type.value,
"reward": reward.value,
"reason": reason,
"screen": self.current_screen.value,
"resolved_case_ids": self._resolved_case_ids(),
}
)
return reward
def _build_step_result(self, reward: Reward, extra_info: Dict[str, Any]) -> StepResult:
observation = self._build_observation()
if self.is_done:
self.last_episode_summary = self.get_episode_report()
info = {
"episode_id": self.episode_id,
"task": self.current_task.value,
**extra_info,
}
return StepResult(observation=observation, reward=reward, done=self.is_done, info=info)
def _build_observation(self) -> FraudCheckObservation:
case_id = self.active_case_id or self.case_order[0]
case = self.workflow_cases[case_id]
state = self.case_state[case_id]
is_triage_only = not state["reviewed"] and self.step_count == 0
case_summary = CaseSummary(
case_id=case_id,
status=state["status"],
queue_reason=case["queue_card"]["queue_reason"],
visible_risk_band="review",
amount_usd=float(case["transaction"]["amount"]),
merchant_region="masked" if not state["reviewed"] else case["transaction"]["shipping_address"],
evidence_collected=sorted(state["revealed_evidence"].keys()),
note_added=state["note_count"] > 0,
)
visible_panels = ["triage_summary"] if is_triage_only else ["triage_summary", "evidence_panel"]
if state["reviewed"]:
visible_panels.append(self.current_screen.value.lower().replace(" ", "_"))
visible_panels.extend(sorted(state["revealed_evidence"].keys()))
if state["note_count"] > 0:
visible_panels.append("case_notes")
linked_case_ids = []
if "network_graph" in state["revealed_evidence"]:
linked_case_ids = list(case["linked_case_ids"])
queue_items: List[QueueCaseCard] = []
if state["reviewed"]:
queue_items = [
QueueCaseCard(
case_id=workflow_case["case_id"],
priority=workflow_case["queue_card"]["priority"],
queue_reason=workflow_case["queue_card"]["queue_reason"],
visible_risk_band="review",
status=self.case_state[workflow_case["case_id"]]["status"],
linked_case_ids=[],
)
for workflow_case in self.workflow_cases.values()
]
return FraudCheckObservation(
case_id=case_id,
task_name=self.current_task,
current_screen=self.current_screen,
visible_panels=visible_panels,
revealed_evidence=copy.deepcopy(state["revealed_evidence"]),
linked_case_ids=linked_case_ids,
remaining_steps=self._remaining_steps(),
remaining_sla=self._remaining_sla(),
note_required=state["note_count"] == 0,
allowed_actions=self._allowed_actions(case_id),
queue_items=queue_items,
case_summary=case_summary,
episode_step=self.step_count,
app_context={
"item_category": case["transaction"]["item_category"],
"timestamp": case["transaction"]["timestamp"],
"investigation_budget_remaining": state["fetch_budget_remaining"],
"available_investigations": sorted(action.value for action in FETCH_ACTIONS),
"task_description": TASK_CONFIG[self.current_task]["description"],
},
)
def _allowed_actions(self, case_id: str) -> List[ActionTypeEnum]:
if self.is_done:
return []
state = self.case_state[case_id]
if state["resolved"]:
return [ActionTypeEnum.REVIEW_TRANSACTION] if self._unresolved_case_ids() else []
if not state["reviewed"]:
return [ActionTypeEnum.REVIEW_TRANSACTION]
return [
ActionTypeEnum.REVIEW_TRANSACTION,
ActionTypeEnum.FETCH_CUSTOMER_PROFILE,
ActionTypeEnum.FETCH_MERCHANT_PROFILE,
ActionTypeEnum.FETCH_NETWORK_GRAPH,
ActionTypeEnum.CHECK_POLICY,
ActionTypeEnum.ADD_CASE_NOTE,
ActionTypeEnum.RESOLVE_CASE,
]
def _remaining_steps(self) -> int:
return max(0, self.max_steps[self.current_task] - self.step_count)
def _remaining_sla(self) -> int:
return max(0, self.sla_limit[self.current_task] - self.step_count)
def _resolved_case_ids(self) -> List[str]:
return [case_id for case_id in self.case_order if self.case_state[case_id]["resolved"]]
def _unresolved_case_ids(self) -> List[str]:
return [case_id for case_id in self.case_order if not self.case_state[case_id]["resolved"]]
def _build_case_report(self, case_id: str) -> Dict[str, Any]:
case = self.workflow_cases[case_id]
state = self.case_state[case_id]
required_tools = case["required_tools"]
completed_tools = self._completed_tool_markers(case_id)
coverage = len(required_tools & completed_tools) / max(1, len(required_tools))
workflow_completion = (
(1.0 if state["reviewed"] else 0.0)
+ (1.0 if state["note_count"] > 0 else 0.0)
+ (1.0 if state["resolved"] else 0.0)
+ coverage
) / 4.0
return {
"case_id": case_id,
"queue_reason": case["queue_card"]["queue_reason"],
"correct_resolution": case["correct_resolution"].value,
"submitted_resolution": state["resolution"].value if state["resolution"] else None,
"reviewed": state["reviewed"],
"note_count": state["note_count"],
"policy_checked": state["policy_checked"],
"revealed_evidence": sorted(state["revealed_evidence"].keys()),
"resolution_correct": state["resolution_correct"],
"policy_compliant": state["policy_compliant"],
"invalid_actions": state["invalid_actions"],
"redundant_actions": state["redundant_actions"],
"anti_hacking_hits": state["anti_hacking_hits"],
"linked_case_ids": list(case["linked_case_ids"]),
"evidence_coverage": round(coverage, 4),
"workflow_completion": round(workflow_completion, 4),
}