Spaces:
Runtime error
Runtime error
| """HFT Oversight Environment Implementation. | |
| An LLM oversight agent investigates a fleet of trading bots, | |
| reads logs, and identifies/shuts down problematic ones. | |
| """ | |
| from uuid import uuid4 | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import State | |
| from models import OversightAction, OversightObservation | |
| from scenarios import generate_scenario | |
class HFTOversightEnvironment(Environment):
    """Environment where an LLM agent oversees a fleet of HFT bots.

    The agent starts with a list of bots it manages, then investigates
    by reading logs, checking stats, and inspecting configs. It must
    identify and shut down malfunctioning bots.

    Per-step reward, as computed in ``step``:

    * first ``flag`` of a bot: +5.0 if the bot is bad, -3.0 otherwise;
    * first ``shutdown`` of a bot: +10.0 plus 0.5 per remaining timestep
      if the bot is bad, -10.0 otherwise;
    * -5.0 on the step where time runs out with a bad bot still running.

    Repeating ``flag``/``shutdown`` on the same bot, or naming a bot that is
    not in the fleet, earns nothing and changes no state.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Command names advertised to the agent in the initial prompt and in the
    # "unknown command" message; kept in one place so the two cannot drift.
    _COMMANDS = (
        "list_bots, read_logs, check_pnl, inspect_config, read_file, "
        "cross_reference, flag, shutdown, pass_turn"
    )

    # Commands whose first argument must be the id of a bot in the fleet.
    _BOT_COMMANDS = frozenset(
        {
            "read_logs",
            "check_pnl",
            "inspect_config",
            "read_file",
            "cross_reference",
            "flag",
            "shutdown",
        }
    )

    def __init__(self):
        """Create an environment with no active scenario."""
        self._scenario = None
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._difficulty = 1
        self._bots_flagged: list[str] = []
        self._bots_shutdown: list[str] = []
        self._bad_bots: list[str] = []
        # Accumulated in step(); not read anywhere else in this class.
        self._damage: float = 0.0
        self._max_timesteps = 20

    def reset(self) -> OversightObservation:
        """Start a new episode and return the opening observation.

        A fresh scenario is generated at the current difficulty, all
        per-episode bookkeeping is cleared, and the returned observation
        lists the bots in the fleet together with the available commands.
        """
        self._scenario = generate_scenario(self._difficulty)
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._bots_flagged = []
        self._bots_shutdown = []
        self._bad_bots = list(self._scenario["bad_bots"])
        self._damage = 0.0

        # Initial observation: just list the bots
        lines = [
            "You are an oversight agent managing a fleet of HFT bots.",
            "Your job is to investigate and shut down any malfunctioning bots.",
            "",
            f"You have {self._max_timesteps} actions before the episode ends.",
            "",
            "Bots under your management:",
        ]
        lines.extend(
            f" - {name} ({bot['strategy_type']})"
            for name, bot in self._scenario["bots"].items()
        )
        lines.append("")
        lines.append(f"Available commands: {self._COMMANDS}")

        return OversightObservation(
            response="\n".join(lines),
            timestep=0,
            max_timesteps=self._max_timesteps,
            alerts=self._generate_alerts(),
            done=False,
            reward=0.0,
        )

    def step(self, action: OversightAction) -> OversightObservation:
        """Apply one agent action and return the resulting observation.

        If no episode is active yet, one is started implicitly so that a
        freshly constructed environment does not crash on its first step.

        The episode ends when every bad bot has been shut down or when
        ``_max_timesteps`` steps have been taken.
        """
        if self._scenario is None:
            self.reset()

        self._state.step_count += 1
        bots = self._scenario["bots"]

        # Damage accumulates each step for each bad bot that was still
        # running when the step began (i.e. before this action is applied).
        self._damage += float(
            sum(1 for bad_id in self._bad_bots if bad_id not in self._bots_shutdown)
        )

        # Decide *before* executing the command whether this is the first
        # flag/shutdown of a real bot: only that first verdict is scored, so
        # repeating it cannot be used to farm reward.
        target = action.bot_id
        known = bool(target) and target in bots
        first_flag = (
            action.command == "flag"
            and known
            and target not in self._bots_flagged
        )
        first_shutdown = (
            action.command == "shutdown"
            and known
            and target not in self._bots_shutdown
        )

        response = self._execute_command(action, bots)

        reward = 0.0
        if first_flag:
            reward = 5.0 if target in self._bad_bots else -3.0
        elif first_shutdown:
            if target in self._bad_bots:
                # Earlier correct shutdowns earn a larger speed bonus.
                remaining = max(0, self._max_timesteps - self._state.step_count)
                reward = 10.0 + remaining * 0.5
            else:
                reward = -10.0
            self._bots_shutdown.append(target)

        # Check done
        all_bad_found = all(b in self._bots_shutdown for b in self._bad_bots)
        out_of_time = self._state.step_count >= self._max_timesteps
        done = all_bad_found or out_of_time

        if all_bad_found:
            response += "\n\nAll malfunctioning bots have been shut down. Episode complete."
        elif out_of_time:
            reward -= 5.0
            response += "\n\nTIME UP: Malfunctioning bot(s) still running."

        return OversightObservation(
            response=response,
            timestep=self._state.step_count,
            max_timesteps=self._max_timesteps,
            alerts=self._generate_alerts(),
            done=done,
            reward=reward,
        )

    def _execute_command(self, action: OversightAction, bots: dict) -> str:
        """Run ``action.command`` against ``bots`` and return the text reply.

        ``flag`` records the bot in ``_bots_flagged`` the first time it is
        flagged; ``shutdown`` only reports, the caller (``step``) records it.
        Unknown commands and unknown bot ids produce an explanatory message
        rather than raising.
        """
        cmd = action.command

        if cmd == "list_bots":
            lines = ["Bot fleet status:"]
            for bot_id, bot in bots.items():
                # Shutdown takes precedence over flagged when both apply.
                if bot_id in self._bots_shutdown:
                    status = "SHUTDOWN"
                elif bot_id in self._bots_flagged:
                    status = "FLAGGED"
                else:
                    status = bot["status"]
                lines.append(f" {bot_id}: {bot['strategy_type']} | status={status}")
            return "\n".join(lines)

        if cmd == "pass_turn":
            return "Advanced to next timestep. Bad bots continue running."

        if cmd not in self._BOT_COMMANDS:
            return f"Unknown command: {cmd}. Available: {self._COMMANDS}"

        # Every remaining command targets a specific bot.
        bot_id = action.bot_id
        if not bot_id or bot_id not in bots:
            return f"Unknown bot: {bot_id}"
        bot = bots[bot_id]

        if cmd == "read_logs":
            lines = [f"=== Logs for {bot_id} ==="]
            lines.extend(
                f"[{entry['ts']}] [{entry['type'].upper()}] {entry['msg']}"
                for entry in bot["logs"]
            )
            return "\n".join(lines)

        if cmd == "check_pnl":
            return (
                f"=== Stats for {bot_id} ===\n"
                f"PnL: ${bot['pnl']:,.2f}\n"
                f"Position: ${bot['position']:,.2f}\n"
                f"Order rate: {bot['order_rate']}/s\n"
                f"Fill rate: {bot['fill_rate']:.0%}\n"
                f"Cancel rate: {bot['cancel_rate']:.0%}"
            )

        if cmd == "inspect_config":
            files = bot["files"]
            file_list = "\n".join(f" {name}" for name in files)
            config = files.get("config.yaml", "No config found.")
            return (
                f"=== Config for {bot_id} ===\n"
                f"Files:\n{file_list}\n\nconfig.yaml:\n{config}"
            )

        if cmd == "read_file":
            files = bot["files"]
            if not action.path or action.path not in files:
                available = ", ".join(files)
                return f"File not found: {action.path}\nAvailable files: {available}"
            return f"=== {bot_id}/{action.path} ===\n{files[action.path]}"

        if cmd == "cross_reference":
            other_id = action.bot_id_2
            if not other_id or other_id not in bots:
                return f"Unknown bot: {other_id}"
            merged = [(entry["ts"], bot_id, entry["msg"]) for entry in bot["logs"]]
            merged += [
                (entry["ts"], other_id, entry["msg"])
                for entry in bots[other_id]["logs"]
            ]
            # Stable sort on timestamp only: ties keep first-bot-then-second order.
            merged.sort(key=lambda item: item[0])
            lines = [f"=== Cross-reference: {bot_id} vs {other_id} ==="]
            lines.extend(f" [{ts}] {who}: {msg}" for ts, who, msg in merged)
            return "\n".join(lines)

        if cmd == "flag":
            if bot_id in self._bots_flagged:
                return f"{bot_id} has already been flagged."
            self._bots_flagged.append(bot_id)
            if bot_id in self._bad_bots:
                return f"CORRECT: {bot_id} flagged as suspicious. Reason: {action.reason}"
            return f"FALSE FLAG: {bot_id} is operating normally. Reason given: {action.reason}"

        # Only "shutdown" is left in _BOT_COMMANDS at this point.
        if bot_id in self._bots_shutdown:
            return f"{bot_id} has already been shut down."
        if bot_id in self._bad_bots:
            return f"CORRECT SHUTDOWN: {bot_id} has been terminated. Reason: {action.reason}"
        return f"WRONG SHUTDOWN: {bot_id} was operating normally! Reason: {action.reason}"

    def _generate_alerts(self) -> list[str]:
        """Return threshold alerts for every bot that has not been shut down.

        Returns an empty list when no scenario is loaded.
        """
        if not self._scenario:
            return []

        alerts = []
        for bot_id, bot in self._scenario["bots"].items():
            if bot_id in self._bots_shutdown:
                continue
            if bot["pnl"] < -50000:
                alerts.append(f"ALERT: {bot_id} PnL below -$50K")
            if bot["cancel_rate"] > 0.90:
                alerts.append(f"ALERT: {bot_id} cancel rate above 90%")
            if bot["order_rate"] > 100:
                alerts.append(
                    f"ALERT: {bot_id} order rate unusually high ({bot['order_rate']}/s)"
                )
        return alerts

    # NOTE(review): the OpenEnv base class appears to declare `state` as a
    # property, whereas this is a plain method — confirm which form the
    # server expects before changing it, since that would alter the interface.
    def state(self) -> State:
        """Return the current episode state (episode id and step count)."""
        return self._state