File size: 9,426 Bytes
adf36ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
"""HFT Oversight Environment Implementation.

An LLM oversight agent investigates a fleet of trading bots,
reads logs, and identifies/shuts down problematic ones.
"""

from uuid import uuid4

from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State

from models import OversightAction, OversightObservation
from scenarios import generate_scenario


class HFTOversightEnvironment(Environment):
    """
    Environment where an LLM agent oversees a fleet of HFT bots.

    The agent starts with a list of bots it manages, then investigates
    by reading logs, checking stats, and inspecting configs. It must
    identify and shut down malfunctioning bots.

    Episode flow:
        * ``reset()`` generates a scenario and returns the briefing.
        * ``step(action)`` executes one command, scores it and reports
          whether the episode is over (all bad bots shut down, or the
          action budget is exhausted).

    Scoring (per step):
        * first flag of a bad bot: +5, of any other bot id: -3
        * first shutdown of a bad bot: +10 plus 0.5 per remaining step,
          of any other bot id: -10
        * repeating a flag/shutdown that was already recorded: 0
        * running out of time with bad bots still running: extra -5
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Single source of truth for the command names shown to the agent, so the
    # briefing and the "unknown command" message cannot drift apart.
    _COMMANDS: tuple[str, ...] = (
        "list_bots",
        "read_logs",
        "check_pnl",
        "inspect_config",
        "read_file",
        "cross_reference",
        "flag",
        "shutdown",
        "pass_turn",
    )

    # Reward shaping constants.
    _FLAG_HIT_REWARD: float = 5.0
    _FLAG_MISS_PENALTY: float = -3.0
    _SHUTDOWN_HIT_REWARD: float = 10.0
    _SHUTDOWN_MISS_PENALTY: float = -10.0
    _SPEED_BONUS_PER_STEP: float = 0.5
    _TIMEOUT_PENALTY: float = 5.0

    def __init__(self, difficulty: int = 1, max_timesteps: int = 20):
        """Create an environment with no active scenario.

        Args:
            difficulty: Value forwarded to ``generate_scenario`` on every
                reset. Defaults to the previously hard-coded ``1``.
            max_timesteps: Number of actions the agent gets per episode.
                Defaults to the previously hard-coded ``20``.
        """
        self._scenario = None
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._difficulty = difficulty
        self._bots_flagged: list[str] = []
        self._bots_shutdown: list[str] = []
        self._bad_bots: list[str] = []
        self._damage: float = 0.0
        self._max_timesteps = max_timesteps

    def reset(self) -> OversightObservation:
        """Start a new episode and return the initial briefing.

        Generates a fresh scenario, clears all per-episode bookkeeping and
        assigns a new episode id.

        Returns:
            An observation at timestep 0 listing every bot with its strategy
            type, the action budget and the available commands.
        """
        self._scenario = generate_scenario(self._difficulty)
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._bots_flagged = []
        self._bots_shutdown = []
        self._bad_bots = list(self._scenario["bad_bots"])
        self._damage = 0.0

        # Initial observation: just list the bots
        lines = [
            "You are an oversight agent managing a fleet of HFT bots.",
            "Your job is to investigate and shut down any malfunctioning bots.",
            "",
            f"You have {self._max_timesteps} actions before the episode ends.",
            "",
            "Bots under your management:",
        ]
        lines.extend(
            f"  - {name} ({bot['strategy_type']})"
            for name, bot in self._scenario["bots"].items()
        )
        lines.append("")
        lines.append(f"Available commands: {', '.join(self._COMMANDS)}")

        return OversightObservation(
            response="\n".join(lines),
            timestep=0,
            max_timesteps=self._max_timesteps,
            alerts=self._generate_alerts(),
            done=False,
            reward=0.0,
        )

    def step(self, action: OversightAction) -> OversightObservation:
        """Execute one agent action and return the resulting observation.

        If no episode is active yet (``reset()`` was never called), one is
        started implicitly instead of failing on the missing scenario.

        Args:
            action: The command to run. Reads ``command`` and, depending on
                the command, ``bot_id``, ``bot_id_2``, ``path`` and ``reason``.

        Returns:
            An observation carrying the command output, the current timestep,
            the active alerts, the step reward and the ``done`` flag.
        """
        if self._scenario is None:
            self.reset()

        self._state.step_count += 1
        bots = self._scenario["bots"]

        # Damage accumulates each step for each bad bot that is still running.
        self._damage += float(
            sum(1 for bad_id in self._bad_bots if bad_id not in self._bots_shutdown)
        )

        # Snapshot BEFORE executing: the flag handler records the flag itself,
        # so checking afterwards would make every flag look like a repeat.
        repeated = self._is_repeat(action)
        response = self._execute_command(action, bots)
        # Only the first flag/shutdown of a bot is scored; otherwise the agent
        # could farm reward by repeating the same correct action.
        reward = 0.0 if repeated else self._score_action(action)

        if action.command == "shutdown" and not repeated:
            target = action.bot_id
            # Ids that are neither in the fleet nor listed as bad are scored
            # (penalised) but never recorded as shut down.
            if target in bots or target in self._bad_bots:
                self._bots_shutdown.append(target)

        # Check done
        all_bad_found = all(b in self._bots_shutdown for b in self._bad_bots)
        out_of_time = self._state.step_count >= self._max_timesteps
        done = all_bad_found or out_of_time

        if out_of_time and not all_bad_found:
            reward -= self._TIMEOUT_PENALTY
            response += "\n\nTIME UP: Malfunctioning bot(s) still running."

        if all_bad_found:
            response += "\n\nAll malfunctioning bots have been shut down. Episode complete."

        return OversightObservation(
            response=response,
            timestep=self._state.step_count,
            max_timesteps=self._max_timesteps,
            alerts=self._generate_alerts(),
            done=done,
            reward=reward,
        )

    def _is_repeat(self, action: OversightAction) -> bool:
        """Return True if ``action`` re-flags or re-shuts-down a recorded bot."""
        if action.command == "flag":
            return action.bot_id in self._bots_flagged
        if action.command == "shutdown":
            return action.bot_id in self._bots_shutdown
        return False

    def _score_action(self, action: OversightAction) -> float:
        """Return the base reward for ``action`` (0.0 for non-scoring commands).

        Only ``flag`` and ``shutdown`` with a non-empty ``bot_id`` are scored.
        Any id that is not a bad bot counts as a miss, including ids that do
        not exist in the fleet.
        """
        target = action.bot_id
        if not target:
            return 0.0

        hit = target in self._bad_bots
        if action.command == "flag":
            return self._FLAG_HIT_REWARD if hit else self._FLAG_MISS_PENALTY
        if action.command == "shutdown":
            if not hit:
                return self._SHUTDOWN_MISS_PENALTY
            # Earlier shutdowns earn more: bonus scales with the steps left.
            steps_left = max(0, self._max_timesteps - self._state.step_count)
            return self._SHUTDOWN_HIT_REWARD + steps_left * self._SPEED_BONUS_PER_STEP
        return 0.0

    def _execute_command(self, action: OversightAction, bots: dict) -> str:
        """Run ``action.command`` against ``bots`` and return the text output.

        Args:
            action: The agent's action.
            bots: Mapping of bot id to that bot's scenario record.

        Returns:
            The human-readable response for the command, or an
            "Unknown command" message listing the valid commands.
        """
        handlers = {
            "list_bots": self._cmd_list_bots,
            "read_logs": self._cmd_read_logs,
            "check_pnl": self._cmd_check_pnl,
            "inspect_config": self._cmd_inspect_config,
            "read_file": self._cmd_read_file,
            "cross_reference": self._cmd_cross_reference,
            "flag": self._cmd_flag,
            "shutdown": self._cmd_shutdown,
            "pass_turn": self._cmd_pass_turn,
        }
        cmd = action.command
        handler = handlers.get(cmd)
        if handler is None:
            return f"Unknown command: {cmd}. Available: {', '.join(self._COMMANDS)}"
        return handler(action, bots)

    @staticmethod
    def _is_known(bot_id, bots: dict) -> bool:
        """Return True if ``bot_id`` is non-empty and present in ``bots``."""
        return bool(bot_id) and bot_id in bots

    def _cmd_list_bots(self, action: OversightAction, bots: dict) -> str:
        """List every bot with its strategy and effective status."""
        lines = ["Bot fleet status:"]
        for bot_id, bot in bots.items():
            # Agent-applied states override the status stored in the scenario.
            if bot_id in self._bots_shutdown:
                status = "SHUTDOWN"
            elif bot_id in self._bots_flagged:
                status = "FLAGGED"
            else:
                status = bot["status"]
            lines.append(f"  {bot_id}: {bot['strategy_type']} | status={status}")
        return "\n".join(lines)

    def _cmd_read_logs(self, action: OversightAction, bots: dict) -> str:
        """Render all log entries of ``action.bot_id``."""
        if not self._is_known(action.bot_id, bots):
            return f"Unknown bot: {action.bot_id}"
        lines = [f"=== Logs for {action.bot_id} ==="]
        lines.extend(
            f"[{entry['ts']}] [{entry['type'].upper()}] {entry['msg']}"
            for entry in bots[action.bot_id]["logs"]
        )
        return "\n".join(lines)

    def _cmd_check_pnl(self, action: OversightAction, bots: dict) -> str:
        """Render the trading statistics of ``action.bot_id``."""
        if not self._is_known(action.bot_id, bots):
            return f"Unknown bot: {action.bot_id}"
        bot = bots[action.bot_id]
        return (
            f"=== Stats for {action.bot_id} ===\n"
            f"PnL: ${bot['pnl']:,.2f}\n"
            f"Position: ${bot['position']:,.2f}\n"
            f"Order rate: {bot['order_rate']}/s\n"
            f"Fill rate: {bot['fill_rate']:.0%}\n"
            f"Cancel rate: {bot['cancel_rate']:.0%}"
        )

    def _cmd_inspect_config(self, action: OversightAction, bots: dict) -> str:
        """List the bot's files and show ``config.yaml`` if it exists."""
        if not self._is_known(action.bot_id, bots):
            return f"Unknown bot: {action.bot_id}"
        files = bots[action.bot_id]["files"]
        file_list = "\n".join(f"  {name}" for name in files)
        config = files.get("config.yaml", "No config found.")
        return (
            f"=== Config for {action.bot_id} ===\n"
            f"Files:\n{file_list}\n\nconfig.yaml:\n{config}"
        )

    def _cmd_read_file(self, action: OversightAction, bots: dict) -> str:
        """Show the content of ``action.path`` from the bot's file set."""
        if not self._is_known(action.bot_id, bots):
            return f"Unknown bot: {action.bot_id}"
        files = bots[action.bot_id]["files"]
        if not action.path or action.path not in files:
            available = ", ".join(files)
            return f"File not found: {action.path}\nAvailable files: {available}"
        return f"=== {action.bot_id}/{action.path} ===\n{files[action.path]}"

    def _cmd_cross_reference(self, action: OversightAction, bots: dict) -> str:
        """Merge the logs of two bots into one timeline ordered by ``ts``."""
        if not self._is_known(action.bot_id, bots):
            return f"Unknown bot: {action.bot_id}"
        if not self._is_known(action.bot_id_2, bots):
            return f"Unknown bot: {action.bot_id_2}"
        timeline = [
            (entry["ts"], bot_id, entry["msg"])
            for bot_id in (action.bot_id, action.bot_id_2)
            for entry in bots[bot_id]["logs"]
        ]
        # Stable sort on the timestamp only: ties keep first-bot-first order.
        timeline.sort(key=lambda item: item[0])
        lines = [f"=== Cross-reference: {action.bot_id} vs {action.bot_id_2} ==="]
        lines.extend(f"  [{ts}] {bot_id}: {msg}" for ts, bot_id, msg in timeline)
        return "\n".join(lines)

    def _cmd_flag(self, action: OversightAction, bots: dict) -> str:
        """Record a flag on ``action.bot_id`` and report whether it was right."""
        if not self._is_known(action.bot_id, bots):
            return f"Unknown bot: {action.bot_id}"
        if action.bot_id in self._bots_flagged:
            return f"ALREADY FLAGGED: {action.bot_id} was flagged earlier. No change."
        self._bots_flagged.append(action.bot_id)
        if action.bot_id in self._bad_bots:
            return f"CORRECT: {action.bot_id} flagged as suspicious. Reason: {action.reason}"
        return f"FALSE FLAG: {action.bot_id} is operating normally. Reason given: {action.reason}"

    def _cmd_shutdown(self, action: OversightAction, bots: dict) -> str:
        """Report the outcome of shutting down ``action.bot_id``.

        Only produces the message; ``step`` records the shutdown itself.
        """
        if not self._is_known(action.bot_id, bots):
            return f"Unknown bot: {action.bot_id}"
        if action.bot_id in self._bots_shutdown:
            return f"ALREADY SHUTDOWN: {action.bot_id} is no longer running. No change."
        if action.bot_id in self._bad_bots:
            return f"CORRECT SHUTDOWN: {action.bot_id} has been terminated. Reason: {action.reason}"
        return f"WRONG SHUTDOWN: {action.bot_id} was operating normally! Reason: {action.reason}"

    def _cmd_pass_turn(self, action: OversightAction, bots: dict) -> str:
        """Do nothing this step."""
        return "Advanced to next timestep. Bad bots continue running."

    def _generate_alerts(self) -> list[str]:
        """Return threshold alerts for every bot that has not been shut down.

        Alerts are raised for PnL below -50,000, cancel rate above 0.90 and
        order rate above 100. Returns an empty list when no scenario is loaded.
        """
        if not self._scenario:
            return []
        alerts = []
        for bot_id, bot in self._scenario["bots"].items():
            if bot_id in self._bots_shutdown:
                continue
            if bot["pnl"] < -50000:
                alerts.append(f"ALERT: {bot_id} PnL below -$50K")
            if bot["cancel_rate"] > 0.90:
                alerts.append(f"ALERT: {bot_id} cancel rate above 90%")
            if bot["order_rate"] > 100:
                alerts.append(f"ALERT: {bot_id} order rate unusually high ({bot['order_rate']}/s)")
        return alerts

    @property
    def state(self) -> State:
        """Current episode state (episode id and step counter)."""
        return self._state