hft_oversight / server /environment.py
schangg's picture
Upload folder using huggingface_hub
adf36ff verified
"""HFT Oversight Environment Implementation.
An LLM oversight agent investigates a fleet of trading bots,
reads logs, and identifies/shuts down problematic ones.
"""
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
from models import OversightAction, OversightObservation
from scenarios import generate_scenario
class HFTOversightEnvironment(Environment):
    """Environment where an LLM agent oversees a fleet of HFT bots.

    The agent starts with a list of bots it manages, then investigates
    by reading logs, checking stats, and inspecting configs. It must
    identify and shut down malfunctioning bots.

    The scenario returned by ``generate_scenario`` is read as a mapping with:

    * ``"bots"``: ``{bot_id: bot}`` where each ``bot`` provides
      ``strategy_type``, ``status``, ``logs`` (entries with ``ts``, ``type``
      and ``msg``), ``pnl``, ``position``, ``order_rate``, ``fill_rate``,
      ``cancel_rate`` and ``files`` (``{path: content}``).
    * ``"bad_bots"``: the ids of the malfunctioning bots.

    Rewards handed out by ``step``:

    * first ``flag`` of a bad bot: ``+5``; of any other id: ``-3``
    * first ``shutdown`` of a bad bot: ``+10`` plus ``0.5`` per remaining
      timestep; of any other id: ``-10``
    * repeating a ``flag``/``shutdown`` that was already recorded: ``0``
    * running out of time with a bad bot still up: an extra ``-5``
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Commands accepted by ``step``, exactly as advertised to the agent.
    _COMMAND_LIST: str = (
        "list_bots, read_logs, check_pnl, inspect_config, read_file, "
        "cross_reference, flag, shutdown, pass_turn"
    )

    def __init__(self):
        """Create an environment with no scenario loaded; call ``reset`` next."""
        self._scenario = None
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._difficulty = 1
        self._bots_flagged: list[str] = []
        self._bots_shutdown: list[str] = []
        self._bad_bots: list[str] = []
        # Accumulated in step(); not read anywhere else in this class.
        self._damage: float = 0.0
        self._max_timesteps = 20

    def reset(self) -> OversightObservation:
        """Start a new episode and return the opening briefing.

        Generates a fresh scenario at the current difficulty, assigns a new
        episode id, and clears all flag/shutdown/damage bookkeeping.

        Returns:
            The initial observation (timestep 0, reward 0.0, not done) whose
            response lists every bot with its strategy type and the
            available commands.
        """
        self._scenario = generate_scenario(self._difficulty)
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._bots_flagged = []
        self._bots_shutdown = []
        self._bad_bots = list(self._scenario["bad_bots"])
        self._damage = 0.0

        # Initial observation: just list the bots
        lines = [
            "You are an oversight agent managing a fleet of HFT bots.",
            "Your job is to investigate and shut down any malfunctioning bots.",
            "",
            f"You have {self._max_timesteps} actions before the episode ends.",
            "",
            "Bots under your management:",
        ]
        lines.extend(
            f" - {name} ({bot['strategy_type']})"
            for name, bot in self._scenario["bots"].items()
        )
        lines.append("")
        lines.append(f"Available commands: {self._COMMAND_LIST}")

        return OversightObservation(
            response="\n".join(lines),
            timestep=0,
            max_timesteps=self._max_timesteps,
            alerts=self._generate_alerts(),
            done=False,
            reward=0.0,
        )

    def step(self, action: OversightAction) -> OversightObservation:
        """Apply one agent action and return the resulting observation.

        Every call consumes one timestep, including investigation commands
        and unknown commands. The episode ends when every bad bot has been
        shut down or when ``_max_timesteps`` steps have been taken.

        Args:
            action: The command to run; ``bot_id``, ``bot_id_2``, ``path``
                and ``reason`` are read depending on the command.

        Returns:
            Observation carrying the command output, current alerts, the
            reward for this step and the ``done`` flag.

        Raises:
            RuntimeError: If called before ``reset`` has loaded a scenario.
        """
        if self._scenario is None:
            raise RuntimeError("reset() must be called before step().")

        self._state.step_count += 1
        bots = self._scenario["bots"]

        # Damage accumulates each step for each bad bot still running.
        still_running = sum(
            1 for bad_id in self._bad_bots if bad_id not in self._bots_shutdown
        )
        self._damage += float(still_running)

        # Score before executing: executing records the flag/shutdown, and
        # scoring must still be able to tell a first attempt from a repeat.
        reward = self._score_action(action)
        response = self._execute_command(action, bots)

        # Check done
        all_bad_found = all(b in self._bots_shutdown for b in self._bad_bots)
        out_of_time = self._state.step_count >= self._max_timesteps
        done = all_bad_found or out_of_time

        if all_bad_found:
            response += "\n\nAll malfunctioning bots have been shut down. Episode complete."
        elif out_of_time:
            reward -= 5.0
            response += "\n\nTIME UP: Malfunctioning bot(s) still running."

        return OversightObservation(
            response=response,
            timestep=self._state.step_count,
            max_timesteps=self._max_timesteps,
            alerts=self._generate_alerts(),
            done=done,
            reward=reward,
        )

    def _score_action(self, action: OversightAction) -> float:
        """Return the reward for ``action`` without changing any state.

        Only ``flag`` and ``shutdown`` are scored. A repeat of an action that
        is already recorded for that bot scores 0.0, so rewards cannot be
        farmed by re-issuing the same command. An id that is not a bad bot
        (including one that is not in the fleet at all) takes the penalty.
        """
        bot_id = action.bot_id
        if not bot_id:
            return 0.0

        if action.command == "flag":
            if bot_id in self._bots_flagged:
                return 0.0
            return 5.0 if bot_id in self._bad_bots else -3.0

        if action.command == "shutdown":
            if bot_id in self._bots_shutdown:
                return 0.0
            if bot_id in self._bad_bots:
                # Speed bonus: the earlier the shutdown, the larger the bonus.
                remaining = max(0, self._max_timesteps - self._state.step_count)
                return 10.0 + remaining * 0.5
            return -10.0

        return 0.0

    def _execute_command(self, action: OversightAction, bots: dict) -> str:
        """Run ``action.command`` against ``bots`` and return the text output.

        ``flag`` and ``shutdown`` also record the bot in the corresponding
        bookkeeping list; every other command is read-only.
        """
        handlers = {
            "list_bots": self._cmd_list_bots,
            "read_logs": self._cmd_read_logs,
            "check_pnl": self._cmd_check_pnl,
            "inspect_config": self._cmd_inspect_config,
            "read_file": self._cmd_read_file,
            "cross_reference": self._cmd_cross_reference,
            "flag": self._cmd_flag,
            "shutdown": self._cmd_shutdown,
            "pass_turn": self._cmd_pass_turn,
        }
        cmd = action.command
        handler = handlers.get(cmd)
        if handler is None:
            return f"Unknown command: {cmd}. Available: {self._COMMAND_LIST}"
        return handler(action, bots)

    @staticmethod
    def _is_unknown_bot(bot_id, bots: dict) -> bool:
        """Return True when ``bot_id`` is empty or not a key of ``bots``."""
        return not bot_id or bot_id not in bots

    def _cmd_list_bots(self, action: OversightAction, bots: dict) -> str:
        """List every bot with its strategy and effective status."""
        lines = ["Bot fleet status:"]
        for bot_id, bot in bots.items():
            # Actions taken this episode override the scenario's own status.
            if bot_id in self._bots_shutdown:
                status = "SHUTDOWN"
            elif bot_id in self._bots_flagged:
                status = "FLAGGED"
            else:
                status = bot["status"]
            lines.append(f" {bot_id}: {bot['strategy_type']} | status={status}")
        return "\n".join(lines)

    def _cmd_read_logs(self, action: OversightAction, bots: dict) -> str:
        """Render all log entries of one bot, one per line."""
        if self._is_unknown_bot(action.bot_id, bots):
            return f"Unknown bot: {action.bot_id}"
        lines = [f"=== Logs for {action.bot_id} ==="]
        lines.extend(
            f"[{entry['ts']}] [{entry['type'].upper()}] {entry['msg']}"
            for entry in bots[action.bot_id]["logs"]
        )
        return "\n".join(lines)

    def _cmd_check_pnl(self, action: OversightAction, bots: dict) -> str:
        """Render PnL, position and order/fill/cancel rates of one bot."""
        if self._is_unknown_bot(action.bot_id, bots):
            return f"Unknown bot: {action.bot_id}"
        bot = bots[action.bot_id]
        return (
            f"=== Stats for {action.bot_id} ===\n"
            f"PnL: ${bot['pnl']:,.2f}\n"
            f"Position: ${bot['position']:,.2f}\n"
            f"Order rate: {bot['order_rate']}/s\n"
            f"Fill rate: {bot['fill_rate']:.0%}\n"
            f"Cancel rate: {bot['cancel_rate']:.0%}"
        )

    def _cmd_inspect_config(self, action: OversightAction, bots: dict) -> str:
        """List a bot's files and show its ``config.yaml`` if present."""
        if self._is_unknown_bot(action.bot_id, bots):
            return f"Unknown bot: {action.bot_id}"
        files = bots[action.bot_id]["files"]
        file_list = "\n".join(f" {f}" for f in files)
        config = files.get("config.yaml", "No config found.")
        return (
            f"=== Config for {action.bot_id} ===\n"
            f"Files:\n{file_list}\n\nconfig.yaml:\n{config}"
        )

    def _cmd_read_file(self, action: OversightAction, bots: dict) -> str:
        """Return the content of ``action.path`` from one bot's files."""
        if self._is_unknown_bot(action.bot_id, bots):
            return f"Unknown bot: {action.bot_id}"
        files = bots[action.bot_id]["files"]
        if not action.path or action.path not in files:
            available = ", ".join(files)
            return f"File not found: {action.path}\nAvailable files: {available}"
        return f"=== {action.bot_id}/{action.path} ===\n{files[action.path]}"

    def _cmd_cross_reference(self, action: OversightAction, bots: dict) -> str:
        """Merge the logs of two bots into one timeline ordered by ``ts``."""
        if self._is_unknown_bot(action.bot_id, bots):
            return f"Unknown bot: {action.bot_id}"
        if self._is_unknown_bot(action.bot_id_2, bots):
            return f"Unknown bot: {action.bot_id_2}"
        merged = [
            (entry["ts"], bot_id, entry["msg"])
            for bot_id in (action.bot_id, action.bot_id_2)
            for entry in bots[bot_id]["logs"]
        ]
        # Stable sort: entries with equal timestamps keep first-bot-first order.
        merged.sort(key=lambda item: item[0])
        lines = [f"=== Cross-reference: {action.bot_id} vs {action.bot_id_2} ==="]
        lines.extend(f" [{ts}] {bot_id}: {msg}" for ts, bot_id, msg in merged)
        return "\n".join(lines)

    def _cmd_flag(self, action: OversightAction, bots: dict) -> str:
        """Record a known bot as flagged (once) and report whether it was bad."""
        bot_id = action.bot_id
        if self._is_unknown_bot(bot_id, bots):
            return f"Unknown bot: {bot_id}"
        if bot_id in self._bots_flagged:
            return f"{bot_id} is already flagged. No additional reward or penalty."
        self._bots_flagged.append(bot_id)
        if bot_id in self._bad_bots:
            return f"CORRECT: {bot_id} flagged as suspicious. Reason: {action.reason}"
        return f"FALSE FLAG: {bot_id} is operating normally. Reason given: {action.reason}"

    def _cmd_shutdown(self, action: OversightAction, bots: dict) -> str:
        """Record a known bot as shut down (once) and report whether it was bad."""
        bot_id = action.bot_id
        if self._is_unknown_bot(bot_id, bots):
            # Ids outside the fleet are never recorded as shut down.
            return f"Unknown bot: {bot_id}"
        if bot_id in self._bots_shutdown:
            return f"{bot_id} is already shut down. No additional reward or penalty."
        self._bots_shutdown.append(bot_id)
        if bot_id in self._bad_bots:
            return f"CORRECT SHUTDOWN: {bot_id} has been terminated. Reason: {action.reason}"
        return f"WRONG SHUTDOWN: {bot_id} was operating normally! Reason: {action.reason}"

    def _cmd_pass_turn(self, action: OversightAction, bots: dict) -> str:
        """Do nothing; the timestep is still consumed by ``step``."""
        return "Advanced to next timestep. Bad bots continue running."

    def _generate_alerts(self) -> list[str]:
        """Build threshold alerts for every bot that has not been shut down.

        Returns:
            One message per breached threshold (PnL below -50000, cancel
            rate above 0.90, order rate above 100); an empty list when no
            scenario is loaded.
        """
        if not self._scenario:
            return []
        alerts = []
        for bot_id, bot in self._scenario["bots"].items():
            if bot_id in self._bots_shutdown:
                continue
            if bot["pnl"] < -50000:
                alerts.append(f"ALERT: {bot_id} PnL below -$50K")
            if bot["cancel_rate"] > 0.90:
                alerts.append(f"ALERT: {bot_id} cancel rate above 90%")
            if bot["order_rate"] > 100:
                alerts.append(
                    f"ALERT: {bot_id} order rate unusually high ({bot['order_rate']}/s)"
                )
        return alerts

    @property
    def state(self) -> State:
        """Current episode state (episode id and step count)."""
        return self._state