# fraudshield-1 / models.py
# DevikaJ2005's picture
# Finalize RL-first environment and explorer UI
# 30533d1
"""Typed models for the FraudShield FraudOps environment."""
from __future__ import annotations
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, ConfigDict, Field, model_validator
class TaskDifficulty(str, Enum):
    """Difficulty tiers of the graded tasks.

    The ``str`` mixin makes every member compare equal to its raw string
    value, so ``TaskDifficulty.EASY == "easy"`` holds.
    """

    EASY = "easy"
    MEDIUM = "medium"
    HARD = "hard"
class ActionTypeEnum(str, Enum):
    """Investigation actions an agent may submit.

    Members are ``str`` subclasses, so each one compares equal to its
    wire-format string value.
    """

    # Actions that carry neither ``note_text`` nor ``resolution``
    # (see ``FraudCheckAction.validate_payload``).
    REVIEW_TRANSACTION = "review_transaction"
    FETCH_CUSTOMER_PROFILE = "fetch_customer_profile"
    FETCH_MERCHANT_PROFILE = "fetch_merchant_profile"
    FETCH_NETWORK_GRAPH = "fetch_network_graph"
    CHECK_POLICY = "check_policy"

    # Requires ``note_text`` on the action payload.
    ADD_CASE_NOTE = "add_case_note"

    # Requires ``resolution`` and ``reasoning`` on the action payload.
    RESOLVE_CASE = "resolve_case"
class ResolutionEnum(str, Enum):
    """Routing outcomes that can close a case.

    Members are ``str`` subclasses and compare equal to their string values.
    """

    APPROVE = "approve"
    BLOCK = "block"
    HOLD = "hold"
    REQUEST_DOCS = "request_docs"
    ESCALATE = "escalate"
class CaseScreenEnum(str, Enum):
    """Workflow screens the environment can report as the current view.

    Values are human-readable screen titles (note the spaces and the ``&``),
    not snake_case identifiers.
    """

    QUEUE = "Queue"
    CASE_CONSOLE = "Case Console"
    CUSTOMER_PROFILE = "Customer Profile"
    MERCHANT_PROFILE = "Merchant Profile"
    POLICY_ESCALATION = "Policy & Escalation"
class QueueCaseCard(BaseModel):
    """A single queue entry as shown before any deeper investigation."""

    # --- required fields -------------------------------------------------
    case_id: str = Field(
        ...,
        description="Unique review case identifier.",
    )
    priority: str = Field(
        ...,
        description="Queue priority label.",
    )
    queue_reason: str = Field(
        ...,
        description="Short visible reason the case entered the queue.",
    )
    visible_risk_band: str = Field(
        ...,
        description="Queue-only coarse risk label.",
    )
    status: str = Field(
        ...,
        description="Case status shown in the queue.",
    )

    # --- optional fields -------------------------------------------------
    # ``default_factory`` gives every instance its own empty list.
    linked_case_ids: List[str] = Field(
        default_factory=list,
        description="Related cases if visible.",
    )
class CaseSummary(BaseModel):
    """High-level snapshot of the case the agent is currently working on."""

    # Identity and workflow position.
    case_id: str = Field(
        ...,
        description="Active case identifier.",
    )
    status: str = Field(
        ...,
        description="Current workflow status.",
    )

    # Information visible without revealing hidden labels.
    queue_reason: str = Field(
        ...,
        description="Short queue reason shown to the agent.",
    )
    visible_risk_band: str = Field(
        ...,
        description="Coarse risk band visible without hidden labels.",
    )
    # Negative amounts are rejected by the ``ge=0.0`` constraint.
    amount_usd: float = Field(
        ...,
        description="Transaction amount visible from the queue or console.",
        ge=0.0,
    )
    merchant_region: str = Field(
        ...,
        description="Shipping region or country code.",
    )

    # Investigation progress so far.
    evidence_collected: List[str] = Field(
        default_factory=list,
        description="Evidence bundle keys already revealed.",
    )
    note_added: bool = Field(
        ...,
        description="Whether a case note has already been written.",
    )
class FraudCheckAction(BaseModel):
    """Action submitted by an agent to the fraud-investigation environment.

    Payload rules enforced by :meth:`validate_payload`:

    * ``add_case_note`` requires ``note_text`` of at least 12 characters
      (after stripping) and forbids ``resolution``.
    * ``resolve_case`` requires ``resolution`` and a ``reasoning`` of at
      least 12 characters (after stripping), and forbids ``note_text``.
    * Every other action type must carry neither ``note_text`` nor
      ``resolution``.

    After validation ``reasoning`` is stored stripped, and ``note_text`` is
    stored stripped or as ``None`` when it was missing or blank.
    """

    # Keep enum members (not their raw values) on validated instances.
    model_config = ConfigDict(use_enum_values=False)

    case_id: str = Field(..., description="Target case identifier for the action.")
    action_type: ActionTypeEnum = Field(..., description="Enterprise tool action to execute.")
    note_text: Optional[str] = Field(
        default=None,
        max_length=600,
        description="Case note text when action_type='add_case_note'.",
    )
    resolution: Optional[ResolutionEnum] = Field(
        default=None,
        description="Final routing outcome when action_type='resolve_case'.",
    )
    reasoning: str = Field(
        default="",
        max_length=500,
        description="Short rationale for the selected action.",
    )

    @model_validator(mode="after")
    def validate_payload(self) -> "FraudCheckAction":
        """Check the payload against ``action_type`` and normalize text fields.

        Returns:
            The same instance, with ``reasoning`` and ``note_text`` stripped.

        Raises:
            ValueError: If a required field for the action type is missing or
                too short, or if a field that the action type forbids is set.
        """
        reasoning = self.reasoning.strip()
        # Empty and whitespace-only notes both mean "no note supplied".
        # Without the trailing ``or None`` a note of "   " would strip to ""
        # and be reported as an included note on tool/resolve actions, while
        # "" itself was accepted.
        note_text = (self.note_text or "").strip() or None

        if self.action_type == ActionTypeEnum.ADD_CASE_NOTE:
            if not note_text or len(note_text) < 12:
                raise ValueError("note_text must be at least 12 characters when action_type='add_case_note'")
            if self.resolution is not None:
                raise ValueError("add_case_note actions must not include resolution")
        elif self.action_type == ActionTypeEnum.RESOLVE_CASE:
            if self.resolution is None:
                raise ValueError("resolution is required when action_type='resolve_case'")
            if len(reasoning) < 12:
                raise ValueError("reasoning must be at least 12 characters when action_type='resolve_case'")
            if note_text is not None:
                raise ValueError("resolve_case actions must not include note_text")
        else:
            # Remaining action types are plain tool calls without a payload.
            if note_text is not None or self.resolution is not None:
                raise ValueError("tool actions must not include note_text or resolution")

        # Persist the normalized text so downstream code sees trimmed values.
        self.reasoning = reasoning
        self.note_text = note_text
        return self
class FraudCheckObservation(BaseModel):
    """What the agent sees at reset and after each step."""

    # Enum-typed fields keep their enum members rather than raw values.
    model_config = ConfigDict(use_enum_values=False)

    # Where the agent currently is.
    case_id: str = Field(
        ...,
        description="Currently active case identifier.",
    )
    task_name: TaskDifficulty = Field(
        ...,
        description="Current task difficulty.",
    )
    current_screen: CaseScreenEnum = Field(
        ...,
        description="Current workflow view.",
    )
    visible_panels: List[str] = Field(
        ...,
        description="Currently visible panels on the active screen.",
    )

    # Evidence and relations revealed so far.
    revealed_evidence: Dict[str, Dict[str, Any]] = Field(
        default_factory=dict,
        description="Evidence bundles revealed for the active case.",
    )
    linked_case_ids: List[str] = Field(
        default_factory=list,
        description="Related case identifiers.",
    )

    # Budgets; both are constrained to be non-negative.
    remaining_steps: int = Field(
        ...,
        description="Remaining total step budget for the episode.",
        ge=0,
    )
    remaining_sla: int = Field(
        ...,
        description="Remaining SLA budget before penalties grow.",
        ge=0,
    )

    # Workflow guidance.
    note_required: bool = Field(
        ...,
        description="Whether the current case still requires a note before resolution.",
    )
    allowed_actions: List[ActionTypeEnum] = Field(
        ...,
        description="Actions currently considered valid.",
    )
    queue_items: List[QueueCaseCard] = Field(
        default_factory=list,
        description="Optional queue context for open cases.",
    )
    case_summary: CaseSummary = Field(
        ...,
        description="Summary of the active case.",
    )

    # NOTE(review): the description says "1-based" but ``ge=0`` admits 0;
    # confirm which convention the environment actually uses.
    episode_step: int = Field(
        ...,
        description="Current 1-based step count within the episode.",
        ge=0,
    )
    app_context: Dict[str, Any] = Field(
        default_factory=dict,
        description="Extra app metadata such as workflow hints or policy flags already visible.",
    )
class Reward(BaseModel):
    """Per-step reward together with the diagnostics that explain it."""

    # Enum-typed fields keep their enum members rather than raw values.
    model_config = ConfigDict(use_enum_values=False)

    # Core reward signal; values outside [-1, 1] fail validation.
    value: float = Field(
        ...,
        description="Step reward in the closed interval [-1, 1].",
        ge=-1.0,
        le=1.0,
    )
    reason: str = Field(
        ...,
        description="Human-readable explanation for the reward assignment.",
    )

    # What the reward refers to.
    action_type: ActionTypeEnum = Field(
        ...,
        description="Action family that produced the reward.",
    )
    case_id: str = Field(
        ...,
        description="Case affected by the action.",
    )

    # Cost components; both default to 0.0 and carry no range constraint.
    action_cost: float = Field(
        default=0.0,
        description="Explicit cost applied to the action.",
    )
    sla_penalty: float = Field(
        default=0.0,
        description="Penalty applied for burning SLA budget.",
    )

    # Optional details, ``None`` unless supplied.
    evidence_key: Optional[str] = Field(
        default=None,
        description="Evidence key affected by the action, if any.",
    )
    resolution: Optional[ResolutionEnum] = Field(
        default=None,
        description="Resolution submitted, if any.",
    )
    ground_truth_resolution: Optional[ResolutionEnum] = Field(
        default=None,
        description="Hidden correct resolution once a final decision has been made.",
    )
    is_correct: Optional[bool] = Field(
        default=None,
        description="Whether the final case routing was correct.",
    )
    policy_compliant: Optional[bool] = Field(
        default=None,
        description="Whether the final routing also complied with revealed policy constraints.",
    )
    anti_hacking_triggered: bool = Field(
        default=False,
        description="Whether the reward reflects anti-hacking or anti-spam penalties.",
    )
class EpisodeState(BaseModel):
    """Complete episode snapshot, as returned by ``state()``."""

    # Enum-typed fields keep their enum members rather than raw values.
    model_config = ConfigDict(use_enum_values=False)

    # Episode identity and current focus.
    episode_id: str = Field(
        ...,
        description="Current episode identifier.",
    )
    task_name: TaskDifficulty = Field(
        ...,
        description="Current task difficulty.",
    )
    current_screen: CaseScreenEnum = Field(
        ...,
        description="Current app screen.",
    )
    active_case_id: str = Field(
        ...,
        description="Currently focused case.",
    )

    # Counters and budgets; the integer ones must be non-negative.
    step_count: int = Field(
        ...,
        description="Number of actions taken so far.",
        ge=0,
    )
    remaining_steps: int = Field(
        ...,
        description="Remaining total step budget.",
        ge=0,
    )
    remaining_sla: int = Field(
        ...,
        description="Remaining SLA budget.",
        ge=0,
    )
    cumulative_reward: float = Field(
        ...,
        description="Total reward accumulated this episode.",
    )
    is_done: bool = Field(
        ...,
        description="Whether the episode has terminated.",
    )

    # Per-case bookkeeping; each container defaults to a fresh empty value.
    resolved_case_ids: List[str] = Field(
        default_factory=list,
        description="Case IDs already resolved.",
    )
    unresolved_case_ids: List[str] = Field(
        default_factory=list,
        description="Case IDs still open.",
    )
    notes_written_by_case: Dict[str, int] = Field(
        default_factory=dict,
        description="Number of notes written for each case.",
    )
    evidence_keys_by_case: Dict[str, List[str]] = Field(
        default_factory=dict,
        description="Revealed evidence bundle keys per case.",
    )
    policy_checked_case_ids: List[str] = Field(
        default_factory=list,
        description="Case IDs where the policy tool has been consulted.",
    )
    resolution_by_case: Dict[str, ResolutionEnum] = Field(
        default_factory=dict,
        description="Submitted resolutions for already resolved cases.",
    )

    # Misuse counters, starting at zero.
    invalid_action_count: int = Field(
        default=0,
        description="Number of invalid-order actions taken.",
        ge=0,
    )
    redundant_action_count: int = Field(
        default=0,
        description="Number of redundant fetch/note actions taken.",
        ge=0,
    )
class StepResult(BaseModel):
    """Bundle returned by one environment step: observation, reward, done flag, info."""

    observation: FraudCheckObservation = Field(
        ...,
        description="Next observation after the submitted action.",
    )
    reward: Reward = Field(
        ...,
        description="Reward assigned to the submitted action.",
    )
    done: bool = Field(
        ...,
        description="Whether the episode terminated after this step.",
    )
    # Free-form diagnostics; defaults to a fresh empty dict.
    info: Dict[str, Any] = Field(
        default_factory=dict,
        description="Extra runtime diagnostics.",
    )
class ResetResult(BaseModel):
    """Bundle returned by an environment reset: first observation plus metadata."""

    observation: FraudCheckObservation = Field(
        ...,
        description="Initial observation for the new episode.",
    )
    # Free-form metadata; defaults to a fresh empty dict.
    info: Dict[str, Any] = Field(
        default_factory=dict,
        description="Episode metadata.",
    )