File size: 8,457 Bytes
c0db7bb
 
 
 
 
 
 
81e1efb
c0db7bb
81e1efb
 
c0db7bb
 
81e1efb
 
dfd1faa
81e1efb
c0db7bb
 
 
 
 
 
81e1efb
 
 
 
 
 
 
 
c0db7bb
81e1efb
 
 
 
 
 
 
 
 
 
 
c0db7bb
 
 
 
 
81e1efb
c0db7bb
 
81e1efb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfd1faa
 
 
 
81e1efb
dfd1faa
 
 
 
81e1efb
 
 
 
 
c0db7bb
81e1efb
c0db7bb
 
 
81e1efb
 
 
 
 
 
 
 
 
 
 
 
 
 
c0db7bb
 
81e1efb
 
 
 
 
c0db7bb
81e1efb
c0db7bb
 
81e1efb
c0db7bb
81e1efb
 
c0db7bb
 
81e1efb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c0db7bb
 
81e1efb
 
 
 
 
 
 
c0db7bb
 
 
 
81e1efb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
AML Investigator Environment Implementation.

An OpenEnv-compatible mock financial system that forces the agent to
explore a massive transaction graph using a strict budget.
"""

import json
import os
import sys
from pathlib import Path
from uuid import uuid4

from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State

try:
    from ..models import (
        AmlAction, 
        AmlObservation,
        QueryTransactions,
        SearchTransactions,
        GetKYCRecord,
        SubmitDecision
    )
except ImportError:
    from models import (
        AmlAction, 
        AmlObservation,
        QueryTransactions,
        SearchTransactions,
        GetKYCRecord,
        SubmitDecision
    )

# Default API call budget, applied when the environment is constructed.
# reset() replaces it with a task-specific budget. The limit forces the
# agent to be efficient.
MAX_BUDGET = 15

class AmlEnvironment(Environment):
    """Mock AML investigation environment backed by JSON fixture files.

    The environment loads entities, accounts and transactions into memory
    once at construction time. Each call to :meth:`step` consumes one unit
    of a per-episode budget and runs one of four tools (query transactions,
    search transactions, fetch a KYC record, submit a decision). Invalid
    tool inputs are reported back through ``error_message`` on the returned
    observation rather than raised.
    """

    # NOTE(review): read by the OpenEnv server; exact semantics are defined
    # by the framework, not by this file.
    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self) -> None:
        """Initialize episode state and load the mock database.

        The JSON fixtures are read from ``<parent of this file's
        directory>/data``. If anything goes wrong while loading, the error
        is printed to stderr and all three stores are left empty so the
        environment can still be constructed.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._reset_count = 0
        self.budget_remaining = MAX_BUDGET
        self.alert_details = ""

        # Resolve the data directory relative to this file, not the CWD.
        base_dir = Path(__file__).resolve().parent.parent
        data_dir = base_dir / "data"

        try:
            ent_list = self._load_records(data_dir / "entities.json", "entities")
            # Keyed by id for O(1) lookups in step().
            self.entities_db = {e["entity_id"]: e for e in ent_list}

            acc_list = self._load_records(data_dir / "accounts.json", "accounts")
            self.accounts_db = {a["account_id"]: a for a in acc_list}

            txn_list = self._load_records(
                data_dir / "transactions.json", "transactions"
            )
            # Sort transactions by timestamp to ensure deterministic pagination
            self.transactions_db = sorted(
                txn_list, key=lambda x: x.get("timestamp", "")
            )

            print(
                f"[AML-ENV] Loaded {len(self.entities_db)} entities, {len(self.accounts_db)} accounts, {len(self.transactions_db)} transactions.",
                file=sys.stderr,
            )
        except Exception as e:
            # Deliberate best-effort: a missing or malformed fixture must not
            # prevent the server from starting. All stores are reset together
            # so a partial load never leaves them inconsistent.
            print(
                f"[AML-ENV ERROR] Failed to load data from {data_dir}. Ensure JSON files exist. Error: {e}",
                file=sys.stderr,
            )
            self.entities_db = {}
            self.accounts_db = {}
            self.transactions_db = []

    @staticmethod
    def _load_records(path, top_level_key: str):
        """Return the records stored in the JSON file at *path*.

        The file may hold either a bare list of records or an object that
        wraps the list under *top_level_key*. An object without that key is
        returned unchanged.
        """
        # JSON is read as UTF-8 explicitly so loading does not depend on the
        # platform's locale encoding.
        with open(path, "r", encoding="utf-8") as f:
            payload = json.load(f)
        if isinstance(payload, dict):
            return payload.get(top_level_key, payload)
        return payload

    def reset(self, task: str = "aml_easy", **kwargs) -> AmlObservation:
        """Start a new episode for *task* and return the initial observation.

        Args:
            task: Task identifier. Known values are ``"aml_easy"``,
                ``"aml_medium"`` and ``"aml_hard"``; any other value gets
                the default budget and a generic alert message.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            An observation carrying the alert text and the full budget, with
            no previous action, ``done=False`` and ``reward=0.0``.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._reset_count += 1
        budgets = {
            "aml_easy": 5,   # Barely enough to make 2 queries and submit
            "aml_medium": 10,  # Requires some pagination/exploration
            "aml_hard": 20,   # Deep investigation required
        }
        self.budget_remaining = budgets.get(task, MAX_BUDGET)

        # The mission strings match our yaml descriptions
        alerts = {
            "aml_easy": "System Alert: ACC-9001 flagged for high-risk transfer. Investigate and submit decision.",
            "aml_medium": "System Alert: ACC-9010 flagged for anomalous deposit velocity. Investigate for structuring.",
            "aml_hard": "System Alert: ACC-9021 flagged for large outbound transfer. Investigate for hidden circular loops.",
        }
        self.alert_details = alerts.get(task, f"System Alert: Investigate task {task}.")

        return AmlObservation(
            alert_details=self.alert_details,
            budget_remaining=self.budget_remaining,
            last_action=None,
            last_action_result=None,
            error_message=None,
            done=False,
            reward=0.0,
        )

    def _transactions_for(self, account_id) -> list:
        """Return every loaded transaction that involves *account_id*.

        Results keep the order of ``self.transactions_db`` (sorted by
        timestamp at load time).

        Raises:
            ValueError: If *account_id* is not a known account.
        """
        if account_id not in self.accounts_db:
            raise ValueError(f"Account '{account_id}' not found in registry.")
        # .get() so one malformed record cannot fail the whole query.
        return [
            t for t in self.transactions_db
            if t.get("sender_account") == account_id
            or t.get("receiver_account") == account_id
        ]

    def _execute_tool(self, tool) -> tuple:
        """Run a single tool and return ``(result_data, done)``.

        Raises:
            ValueError: On unknown accounts/entities, negative pagination
                values, or a tool type this environment does not handle.
        """
        # TOOL 1: Query Transactions (paginated)
        if isinstance(tool, QueryTransactions):
            matches = self._transactions_for(tool.account_id)
            offset, limit = tool.offset, tool.limit
            # Negative values would silently trigger Python's from-the-end
            # slicing and return misleading pages.
            if offset < 0 or limit < 0:
                raise ValueError("offset and limit must be non-negative.")
            return matches[offset:offset + limit], False

        # TOOL 2: Search Transactions (case-insensitive memo substring)
        if isinstance(tool, SearchTransactions):
            matches = self._transactions_for(tool.account_id)
            keyword = tool.keyword.lower()
            # `or ""` covers both a missing memo and an explicit JSON null.
            hits = [
                t for t in matches
                if keyword in str(t.get("memo_text") or "").lower()
            ]
            return hits, False

        # TOOL 3: Get KYC Record
        if isinstance(tool, GetKYCRecord):
            ent_id = tool.entity_id
            if ent_id not in self.entities_db:
                raise ValueError(f"Entity '{ent_id}' not found in registry.")
            return self.entities_db[ent_id], False

        # TOOL 4: Submit Decision (ends the episode)
        if isinstance(tool, SubmitDecision):
            summary = f"Decision '{tool.decision}' recorded with evidence: {tool.evidence_links}"
            return summary, True

        # Previously an unrecognized tool fell through with no result and no
        # error; report it so the agent can correct itself.
        raise ValueError(f"Unsupported action type: {type(tool).__name__}")

    def step(self, action: AmlAction) -> AmlObservation:
        """Execute one agent action and return the resulting observation.

        Every call increments the step counter and carries a flat ``-0.02``
        reward. While budget remains, the call consumes one unit of budget
        and runs the requested tool; tool failures are caught and returned
        as an ``"API Error: ..."`` string in ``error_message``. The episode
        ends (``done=True``) when a decision is submitted or when the budget
        reaches zero. Calls made after the budget is exhausted are refused
        without running any tool.

        Args:
            action: Wrapper whose ``action`` attribute holds the tool call.

        Returns:
            Observation with the tool result or error, the remaining budget,
            the ``done`` flag and the step reward.
        """
        self._state.step_count += 1

        # Default step penalty to penalize random infinite looping
        reward = -0.02

        # Extract the specific tool from the Union wrapper
        tool = action.action
        tool_name = getattr(tool, "action_type", type(tool).__name__)

        # Enforce the budget strictly: once exhausted, no further tool runs
        # and the counter never goes negative.
        if self.budget_remaining <= 0:
            return AmlObservation(
                alert_details=self.alert_details,
                budget_remaining=self.budget_remaining,
                last_action=tool_name,
                last_action_result=None,
                error_message="Investigation Budget Exhausted. Forced Termination.",
                done=True,
                reward=reward,
            )

        self.budget_remaining -= 1

        done = False
        result_data = None
        error_msg = None

        try:
            result_data, done = self._execute_tool(tool)
        except Exception as e:
            # "Errors are data" - We catch bad inputs and feed them back to the agent
            error_msg = f"API Error: {str(e)}"

        # Check for budget failure
        if self.budget_remaining <= 0 and not done:
            done = True
            error_msg = error_msg or "Investigation Budget Exhausted. Forced Termination."

        return AmlObservation(
            alert_details=self.alert_details,
            budget_remaining=self.budget_remaining,
            last_action=tool_name,
            last_action_result=result_data,
            error_message=error_msg,
            done=done,
            reward=reward,
        )

    @property
    def state(self) -> State:
        """Return the current episode state (episode id and step count)."""
        return self._state