"""FraudShield environment implementation.""" from __future__ import annotations import uuid from datetime import datetime from typing import Any, Dict, List from data_loader import FraudDataLoader from models import ( DecisionEnum, EpisodeState, FraudCheckAction, FraudCheckObservation, ResetResult, Reward, StepResult, TaskDifficulty, TransactionData, ) class FraudShieldEnvironment: """OpenEnv-compatible environment for e-commerce fraud review.""" def __init__(self, data_path: str = "data", seed: int = 42): self.seed = seed self.data_loader = FraudDataLoader(data_path=data_path, seed=seed) self.data_loaded = False self.episode_id = "" self.current_task = TaskDifficulty.EASY self.step_count = 0 self.current_transaction_idx = 0 self.cumulative_reward = 0.0 self.correct_predictions = 0 self.is_done = False self.current_cases: List[Dict[str, Any]] = [] self.ground_truth_labels: List[str] = [] self.predictions: List[str] = [] self.confidences: List[float] = [] self.max_steps = { TaskDifficulty.EASY: 24, TaskDifficulty.MEDIUM: 36, TaskDifficulty.HARD: 48, } def load_data(self) -> bool: """Load the committed snapshot or rebuild it from the local public source CSV.""" self.data_loaded = self.data_loader.load_data() return self.data_loaded def load_kaggle_data(self) -> bool: """Backward-compatible wrapper for the previous method name.""" return self.load_data() def ensure_data_loaded(self) -> None: """Load data on demand so server startup can stay simple.""" if not self.data_loaded and not self.load_data(): raise RuntimeError("FraudShield data bundle could not be loaded.") def reset(self, task: str = "easy") -> ResetResult: """Start a fresh episode for a given task difficulty.""" self.ensure_data_loaded() self.episode_id = f"ep_{uuid.uuid4().hex[:8]}" self.current_task = TaskDifficulty(task) self.step_count = 0 self.current_transaction_idx = 0 self.cumulative_reward = 0.0 self.correct_predictions = 0 self.is_done = False self.predictions = [] self.confidences = [] self.current_cases = self.data_loader.get_task_cases(task) self.ground_truth_labels = [case["label"] for case in self.current_cases] self.max_steps[self.current_task] = len(self.current_cases) observation = self._get_observation() info = { "episode_id": self.episode_id, "task": task, "task_focus": observation.historical_context.get("task_focus") if observation.historical_context else None, "data_snapshot": self.data_loader.get_bundle_summary(), "max_steps": self.max_steps[self.current_task], "num_transactions": len(self.current_cases), "fraud_count": sum(1 for label in self.ground_truth_labels if label == "fraud"), "legitimate_count": sum(1 for label in self.ground_truth_labels if label == "legitimate"), } return ResetResult(observation=observation, info=info) def step(self, action: FraudCheckAction) -> StepResult: """Evaluate one agent action and return the next observation.""" if self.is_done: raise RuntimeError("Episode is done. Call reset() to start a new episode.") current_case = self.current_cases[self.current_transaction_idx] expected_transaction_id = current_case["transaction_id"] wrong_transaction_id = action.transaction_id != expected_transaction_id ground_truth = current_case["label"] risk_score = float(current_case["risk_score"]) business_cost = float(current_case["business_cost"]) predicted_label = action.decision.value is_correct = predicted_label == ground_truth and not wrong_transaction_id reward_value, confidence_penalty, reward_reason = self._calculate_reward( predicted_label=predicted_label, ground_truth=ground_truth, confidence=action.confidence, risk_score=risk_score, business_cost=business_cost, wrong_transaction_id=wrong_transaction_id, ) if is_correct: self.correct_predictions += 1 self.predictions.append(predicted_label) self.confidences.append(action.confidence) self.cumulative_reward += reward_value self.step_count += 1 self.current_transaction_idx += 1 self.is_done = self.current_transaction_idx >= len(self.current_cases) reward = Reward( value=reward_value, reason=reward_reason, is_correct=is_correct, ground_truth=DecisionEnum(ground_truth), confidence_penalty=confidence_penalty, business_impact=business_cost, ) observation = self._get_terminal_observation() if self.is_done else self._get_observation() info = { "step": self.step_count, "accuracy_so_far": round(self.correct_predictions / self.step_count, 4), "cumulative_reward": round(self.cumulative_reward, 4), "expected_transaction_id": expected_transaction_id, "wrong_transaction_id": wrong_transaction_id, "risk_score": risk_score, "business_cost": business_cost, } return StepResult(observation=observation, reward=reward, done=self.is_done, info=info) def state(self) -> EpisodeState: """Return the current episode state.""" return EpisodeState( episode_id=self.episode_id, task_name=self.current_task, step_count=self.step_count, transactions_evaluated=self.current_transaction_idx, cumulative_reward=self.cumulative_reward, correct_predictions=self.correct_predictions, is_done=self.is_done, max_steps=self.max_steps[self.current_task], ) def _calculate_reward( self, predicted_label: str, ground_truth: str, confidence: float, risk_score: float, business_cost: float, wrong_transaction_id: bool, ) -> tuple[float, float, str]: """Apply dense reward shaping with business-cost sensitivity.""" is_fraud_case = ground_truth == "fraud" predicted_fraud = predicted_label == "fraud" if is_fraud_case and predicted_fraud: base_reward = 0.68 + (0.16 * business_cost) elif not is_fraud_case and not predicted_fraud: base_reward = 0.54 + (0.06 * (1.2 - min(business_cost, 1.2))) elif is_fraud_case and not predicted_fraud: base_reward = -0.72 - (0.14 * business_cost) else: base_reward = -0.48 - (0.08 * business_cost) target_confidence = risk_score if is_fraud_case else (1.0 - risk_score) confidence_penalty = 0.12 - abs(confidence - target_confidence) * 0.24 if predicted_label != ground_truth: confidence_penalty -= 0.04 + (confidence * 0.06) if wrong_transaction_id: confidence_penalty -= 0.10 reward_value = max(-1.0, min(1.0, base_reward + confidence_penalty)) reason_bits = [ f"predicted={predicted_label}", f"actual={ground_truth}", f"target_confidence={target_confidence:.2f}", ] if wrong_transaction_id: reason_bits.append("action referenced the wrong transaction_id") reward_reason = ", ".join(reason_bits) return reward_value, confidence_penalty, reward_reason def _get_observation(self) -> FraudCheckObservation: """Return the current task observation.""" current_case = self.current_cases[self.current_transaction_idx] return FraudCheckObservation( transaction_id=current_case["transaction_id"], transaction_data=TransactionData(**current_case["transaction_data"]), task_name=self.current_task, episode_step=self.step_count + 1, historical_context=current_case["historical_context"], ) def _get_terminal_observation(self) -> FraudCheckObservation: """Return a terminal observation once the episode completes.""" terminal_transaction = TransactionData( amount=0.0, seller_id="TERMINAL", buyer_id="TERMINAL", item_category="none", item_price=0.0, shipping_address="XX", seller_account_age_days=0, buyer_account_age_days=0, payment_method="none", device_country="XX", timestamp=datetime.utcnow().isoformat(), is_repeat_buyer=False, seller_avg_rating=0.0, num_seller_reviews=0, previous_fraud_flags=0, shipping_speed="none", amount_percentile=0.0, seller_chargeback_rate_30d=0.0, buyer_disputes_90d=0, shared_device_accounts_24h=0, same_address_orders_24h=0, ) return FraudCheckObservation( transaction_id="TERMINAL", transaction_data=terminal_transaction, task_name=self.current_task, episode_step=max(1, self.step_count), historical_context={"episode_done": True}, )