"""Baseline inference runner for the HFT limit-order-book environment.

Each observation is rendered to JSON and sent to an LLM through an
OpenAI-compatible API; the reply is parsed into an ``HftAction`` and stepped
through the environment. Episodes run over WebSocket against the hosted HF
Space when its ``/health`` endpoint answers 200, and otherwise against the
``ghcr.io/jonathan-shiju/hft-env`` Docker image. ``main`` writes the full
transcript to ``transcript.txt`` and the rewards to ``baseline_scores.json``.
"""

import argparse
import asyncio
import json
import os
import re
import sys
from collections import defaultdict
from dataclasses import asdict, dataclass
from typing import Any, Awaitable, Callable, Dict, Optional

import httpx
import websockets
from dotenv import load_dotenv
from openai import AsyncOpenAI
from pydantic import BaseModel

try:
    from models import HftAction, HftObservation, tasks
except ModuleNotFoundError:
    from hft.models import HftAction, HftObservation, tasks

try:
    from openenv.core.env_server.http_server import (
        WSErrorResponse,
        WSObservationResponse,
    )
    from openenv.core.client_types import StepResult
except Exception as e:  # pragma: no cover
    raise ImportError(
        "openenv is required for the web interface. Install dependencies with '\n uv sync\n'"
    ) from e

load_dotenv()

API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
MODEL_NAME = os.getenv("MODEL_NAME")
VERBOSE = os.getenv("VERBOSE", "false")

# VERBOSE values (case-insensitive) that disable per-step history logging.
_FALSY_FLAGS = {"", "0", "false", "no", "off"}

_transcript = ""
# task name -> list of per-step rewards (may contain None if the env sent none)
_rewards = defaultdict(list)
# task name -> mean of the non-None rewards recorded for that task
_rewards_avg = defaultdict(list)

SYSTEM_PROMPT = """
**Role:** You are an expert High-Frequency Trading (HFT) Market Making agent. Your objective is to manage a 1,000-unit trade execution task within a Limit Order Book (LOB) simulation. You must maximize a reward function that balances Execution PnL, Passive Fill Incentives, and Execution Urgency.

**Core Mechanics:**
* **Inventory Management (inv):** You must avoid holding large positions (Long or Short) for too long. Large inventory triggers a risk penalty.
* **Urgency (remaining):** You have a target of 1,000 units. If you are behind the linear execution pace (TWAP), your reward decreases.
* **Passive vs. Active:** You receive a bonus for "Passive" fills (resting limit orders) and a penalty for "Active" fills (hitting the spread). Use Active fills only to neutralize dangerous inventory.
* **The Spread:** You should typically quote at the Best Bid and Best Ask to "capture the spread."

**Action Schema:**
You must return a JSON list of actions or `null` to skip a turn.
* **Limit Order:** `{"type": "limit", "side": "buy"|"ask", "price": float, "size": int}`
* **Market Order:** `{"type": "market", "side": "buy"|"ask", "size": int}`
* **Cancel Order:** `{"type": "cancel", "order_id": "string"}`

**Note:** Active orders are provided as [id, side, price, size]. Side 'B' is Buy, 'A' is Ask.

**Strategy Guidelines:**
* **Join the Best:** If the spread is wide, "penny-jump" the best bid/ask to get priority.
* **Inventory Skew:** If you are Long (+), lower your Ask price to encourage a fill. If you are Short (-), raise your Bid price to cover.
* **Terminal Liquidation:** As $t$ approaches **1.0**, use Market orders to zero out your inventory to avoid the heavy terminal liquidation penalty.

**CRITICAL CONSTRAINTS:**
1. **Output Format:** Return **ONLY** a JSON list. Do not include any prose, markdown commentary, or explanations.
2. **Cardinality:** Your list **MUST NOT** contain more than **one** order of the same type. You are permitted a maximum of one `limit` order, one `market` order, and one `cancel` order per response.
3. **Idle:** Return `null` only if the volume goal is met and inventory is 0.
**You cannot cancel more than one single order at a time**
"""

VALID_TYPES = {"limit", "market", "cancel"}
REQUIRED_KEYS = {
    "limit": {"type", "side", "price", "size"},
    "market": {"type", "side", "size"},
    "cancel": {"type", "order_id"},
}
VALID_SIDES = {"buy", "ask"}

# Sampling settings used for every LLM call.
temperature = 0.0
top_p = 0.8
max_tokens = 100

# Regexes used by extract_actions, compiled once instead of on every call.
_CODE_FENCE_PATTERN = re.compile(r"```(?:json)?\s*", re.IGNORECASE)
# Flat (non-nested) JSON objects embedded in surrounding prose.
_JSON_OBJECT_PATTERN = re.compile(r"\{[^{}]*\}", re.DOTALL)
# Shortest bracketed span; enough because action dicts contain no lists.
_JSON_ARRAY_PATTERN = re.compile(r"\[.*?\]", re.DOTALL)


@dataclass
class HftParams:
    """Optional simulation overrides; ``None`` fields are left out of ``to_env_dict``."""

    max_steps: Optional[int] = None
    tick_size: Optional[float] = None
    inventory: Optional[int] = None
    cash: Optional[float] = None
    arrival_price: Optional[float] = None
    target_shares: Optional[int] = None

    def to_env_dict(self) -> Dict[str, str]:
        """Returns a dict of non-None values as strings."""
        return {k.upper(): str(v) for k, v in self.__dict__.items() if v is not None}


class WebSocketError(Exception):
    """Raised when the environment server answers with an ``"error"`` message."""


def log_transcript(message: str) -> None:
    """Print *message* and append it to the transcript that ``main`` saves."""
    global _transcript
    _transcript += f"{message}\n"
    print(message, flush=True)


def validate_task_name(task: str) -> None:
    """Exit the process with status 1 if *task* is not a known task name."""
    if task not in tasks:
        log_transcript(f"Invalid task name: {task}. Enter one of: {tasks}")
        sys.exit(1)


def _verbose_enabled() -> bool:
    """Return True unless VERBOSE is empty or one of the common 'off' spellings."""
    return VERBOSE.strip().lower() not in _FALSY_FLAGS


async def call_llm_for_actions(
    client: AsyncOpenAI, user_prompt: str
) -> Optional[list[dict]]:
    """Ask the LLM for the next actions and parse its reply.

    Args:
        client: OpenAI-compatible async client.
        user_prompt: JSON-rendered observation sent as the user message.

    Returns:
        Validated action dicts, or ``None`` when the reply holds no usable action.
    """
    response = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ],
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
    )
    llm_reply = response.choices[0].message.content
    actions = extract_actions(llm_reply)
    log_transcript(f"LLM response: {llm_reply}")
    log_transcript(f"Parsed actions: {actions}")
    return actions


def actions_to_hft_action(actions: Optional[list[dict]]) -> HftAction:
    """Fold validated action dicts into one ``HftAction``.

    Fields with no matching action stay ``None``. If several actions map to
    the same field, the last one wins.
    """
    limit_buy_price = None
    limit_buy_size = None
    limit_ask_price = None
    limit_ask_size = None
    market_buy_size = None
    market_ask_size = None
    cancel_order_id = None

    for action in actions or []:
        if action["type"] == "limit":
            if action["side"] == "buy":
                limit_buy_price = action["price"]
                limit_buy_size = action["size"]
            elif action["side"] == "ask":
                limit_ask_price = action["price"]
                limit_ask_size = action["size"]
        elif action["type"] == "market":
            if action["side"] == "buy":
                market_buy_size = action["size"]
            elif action["side"] == "ask":
                market_ask_size = action["size"]
        elif action["type"] == "cancel":
            cancel_order_id = action["order_id"]

    return HftAction(
        limit_buy_price=limit_buy_price,
        limit_buy_size=limit_buy_size,
        limit_ask_price=limit_ask_price,
        limit_ask_size=limit_ask_size,
        market_buy_size=market_buy_size,
        market_ask_size=market_ask_size,
        cancel_order_id=cancel_order_id,
    )


def normalize_observation(
    observation: HftObservation | Dict[str, Any] | None,
    **kwargs,
) -> HftObservation:
    """Build an ``HftObservation`` from a raw observation plus overrides.

    Keyword arguments (e.g. ``reward``/``done`` taken from a step envelope)
    override same-named keys in *observation* instead of raising a
    duplicate-keyword ``TypeError``. An ``HftObservation`` instance is
    re-validated from its ``model_dump()``; ``None`` is treated as ``{}``.
    """
    if isinstance(observation, HftObservation):
        observation = observation.model_dump()
    fields = {**(observation or {}), **kwargs}
    return HftObservation(**fields)


def build_user_prompt(
    observation: HftObservation | Dict[str, Any], **kwargs
) -> tuple[HftObservation, str]:
    """Normalize *observation* and render it as the LLM user prompt.

    When the last history entry has a ``mid`` price, ``active_orders`` is
    replaced by its compact form from ``compress_active_orders``.

    Returns:
        The normalized observation model and its JSON prompt string.
    """
    obs_model = normalize_observation(observation, **kwargs)
    if getattr(obs_model, "history", None):
        mid = obs_model.history[-1].get("mid")
        if mid is not None and hasattr(obs_model, "active_orders"):
            obs_model.active_orders = compress_active_orders(
                obs_model.active_orders, mid
            )
    return obs_model, parse_obs(obs_model)


async def run_episode_loop(
    client: AsyncOpenAI,
    task: str,
    initial_observation: HftObservation | Dict[str, Any],
    initial_done: bool,
    step_fn: Callable[[HftAction], Awaitable[Any]],
) -> None:
    """Run LLM-chosen actions until the environment reports ``done``.

    ``step_fn`` may return either a dict envelope
    ``{"observation": ..., "reward": ..., "done": ...}`` (WebSocket path) or
    an object with an ``observation`` model attribute (the Docker path's
    ``StepResult``). Each step's reward is appended to ``_rewards[task]``.
    """
    _rewards[task] = []
    observation, user_prompt = build_user_prompt(initial_observation)
    done = initial_done

    while not done:
        actions = await call_llm_for_actions(client, user_prompt)
        action_payload = actions_to_hft_action(actions)
        step_output = await step_fn(action_payload)

        if isinstance(step_output, dict):
            observation, user_prompt = build_user_prompt(
                step_output.get("observation"),
                reward=step_output.get("reward"),
                done=step_output.get("done"),
            )
        else:
            observation, user_prompt = build_user_prompt(
                step_output.observation.model_dump()
            )

        reward = observation.reward
        done = observation.done
        history = observation.history
        _rewards[task].append(reward)
        log_transcript(
            f"Step reward: {reward}, Done: {done}, History: {history}"
            if _verbose_enabled()
            else f"Step reward: {reward}, Done: {done}"
        )


def _average_reward(task: str) -> float:
    """Store and return the mean of the non-None rewards recorded for *task* (0 if none)."""
    rewards = [r for r in _rewards[task] if r is not None]
    _rewards_avg[task] = sum(rewards) / len(rewards) if rewards else 0
    return _rewards_avg[task]


def coerce_action(obj: dict) -> Optional[dict]:
    """Validate and coerce a single action dict into a clean, typed action.

    Keys are stripped and lower-cased. Returns ``None`` when the type or side
    is unknown, a required key is missing, or a price/size cannot be
    converted to a finite number. The LLM is untrusted, so a malformed action
    is dropped rather than aborting the episode.
    """
    if not isinstance(obj, dict):
        return None

    obj = {str(k).strip().lower(): v for k, v in obj.items()}
    action_type = str(obj.get("type", "")).strip().lower()
    if action_type not in VALID_TYPES:
        return None
    if not REQUIRED_KEYS[action_type].issubset(obj.keys()):
        return None

    if action_type == "cancel":
        return {"type": "cancel", "order_id": str(obj["order_id"]).strip()}

    side = str(obj["side"]).strip().lower()
    if side not in VALID_SIDES:
        return None

    try:
        if action_type == "limit":
            return {
                "type": "limit",
                "side": side,
                "price": float(obj["price"]),
                "size": int(obj["size"]),
            }
        return {
            "type": "market",
            "side": side,
            "size": int(obj["size"]),
        }
    except (TypeError, ValueError, OverflowError):
        # e.g. "price": "best" or "size": Infinity from the LLM.
        return None


def extract_actions(response: Optional[str]) -> Optional[list[dict]]:
    """
    Extract a list of valid LOB actions from any LLM response string.

    Handles:
    - Clean JSON arrays / objects
    - Markdown code fences (```json ... ``` or ``` ... ```)
    - null / None responses (idle signal)
    - Partial prose with embedded JSON fragments
    - Single objects that should be wrapped in a list

    Returns:
    - List of validated action dicts (see ``coerce_action``), or
    - None when the reply is empty/null or contains no valid action
    """
    if response is None:
        return None

    text = response.strip()
    if text.lower() in {"null", "none", ""}:
        return None

    text = _CODE_FENCE_PATTERN.sub("", text).strip()
    text = text.rstrip("`").strip()

    candidates: list[Any] = []

    # 1. The whole (de-fenced) reply is valid JSON.
    try:
        parsed = json.loads(text)
        if parsed is None:
            return None
        if isinstance(parsed, list):
            candidates = parsed
        elif isinstance(parsed, dict):
            candidates = [parsed]
    except json.JSONDecodeError:
        pass  # fall through to fragment scanning below

    # 2. First parseable JSON array embedded in prose.
    if not candidates:
        for match in _JSON_ARRAY_PATTERN.finditer(text):
            try:
                parsed = json.loads(match.group())
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, list):
                candidates = parsed
                break

    # 3. Every parseable flat JSON object embedded in prose.
    if not candidates:
        for match in _JSON_OBJECT_PATTERN.finditer(text):
            try:
                obj = json.loads(match.group())
            except json.JSONDecodeError:
                continue
            if isinstance(obj, dict):
                candidates.append(obj)

    actions = [a for a in map(coerce_action, candidates) if a is not None]
    return actions or None


def parse_obs(obs: HftObservation | Dict[str, Any]) -> str:
    """Serialize an observation (model or dict) to the JSON string sent to the LLM."""
    if isinstance(obs, HftObservation):
        obs_dict = obs.model_dump()
    elif isinstance(obs, dict):
        obs_dict = obs
    else:
        raise ValueError("Observation must be HftObservation or dict")
    return json.dumps(obs_dict)


def compress_active_orders(orders, mid_price, tick_window=0.10):
    """
    Filters orders within a price window and flattens them.

    Orders without an id are skipped. Each kept order becomes
    ``[id, side_code, price rounded to 2 dp, int(size)]`` with side code
    ``"B"`` for buy and ``"A"`` otherwise. A window of 0.10 means +/- 10
    ticks from mid for a 0.01 tick size.
    """
    compressed = []
    for o in orders:
        if o["id"] is None:
            continue
        if abs(o["price"] - mid_price) <= tick_window:
            side_code = "B" if o["side"] == "buy" else "A"
            compressed.append(
                [o["id"], side_code, round(o["price"], 2), int(o["size"])]
            )
    return compressed


async def health_check(http_url: str) -> bool:
    """Return True iff ``GET <http_url>/health`` answers 200; any failure counts as unhealthy."""
    async with httpx.AsyncClient(base_url=http_url, timeout=60) as client:
        try:
            # Relative path: httpx joins it onto base_url without a double slash.
            response = await client.get("/health")
        except Exception as e:  # best-effort probe; caller falls back to Docker
            print(f"Health check failed: {e}")
            return False
        return response.status_code == 200


async def run_docker_task(
    tasks: list, pause: int, hftParams: HftParams, client: AsyncOpenAI
):
    """Run each task against a freshly started Docker environment.

    The container is closed after every task. Any exception is logged and
    the process exits with status 1.
    """
    # Same local/package fallback as the models import at the top of the file.
    try:
        from client import HftEnv
    except ModuleNotFoundError:
        from hft.client import HftEnv

    try:
        for task in tasks:
            env = None
            validate_task_name(task)
            log_transcript(f"Running task: {task}")
            try:
                env = await HftEnv.from_docker_image(
                    "ghcr.io/jonathan-shiju/hft-env:latest",
                    env_vars={
                        "TASK": task,
                        **hftParams.to_env_dict(),
                    },
                )
                reset_response: StepResult = await env.reset()
                observation = reset_response.observation.model_dump()
                log_transcript(f"Initial observation: {observation}")

                await run_episode_loop(
                    client=client,
                    task=task,
                    initial_observation=observation,
                    initial_done=reset_response.done,
                    step_fn=env.step,
                )
                avg_reward = _average_reward(task)
                log_transcript(
                    f"Episode finished for task: {task}, Reward: {avg_reward}"
                )
                await asyncio.sleep(pause)
            finally:
                if env is not None:
                    await env.close()
    except Exception as e:
        log_transcript(f"Error running Docker task: {e}")
        sys.exit(1)


async def check_ws_error(
    response: WSObservationResponse | WSErrorResponse,
) -> Optional[str]:
    """Raise ``WebSocketError`` if *response* is an ``"error"`` message; else return None."""
    if response.get("type") == "error":
        # f-string so a non-str "message" cannot cause a TypeError here.
        raise WebSocketError(
            f"{response.get('message', 'Unknown WebSocket error')}"
            f" (code: {response.get('code', 'N/A')})"
            f" errors: {response.get('errors', '')}"
        )
    return None


def _to_ws_url(base_url: str) -> str:
    """Map an http(s) base URL to its ``/ws`` WebSocket endpoint (https->wss, http->ws).

    Uses prefix removal rather than ``str.lstrip``, which strips a character
    set and would mangle hosts starting with h/t/p/s.
    """
    url = base_url.rstrip("/")
    if url.startswith("https://"):
        return f"wss://{url.removeprefix('https://')}/ws"
    if url.startswith("http://"):
        return f"ws://{url.removeprefix('http://')}/ws"
    return f"wss://{url}/ws"


async def run_online_single_task(
    client: AsyncOpenAI, task: str, params: HftParams, base_url: str
):
    """Run one task over the server's WebSocket API.

    Any failure, including failing to connect, is logged and the process
    exits with status 1.
    """
    try:
        async with websockets.connect(uri=_to_ws_url(base_url)) as websocket:
            await websocket.send(
                json.dumps(
                    {
                        "type": "reset",
                        "data": {
                            "task_name": task,
                            "max_steps": params.max_steps,
                            "tick_size": params.tick_size,
                            "inventory": params.inventory,
                            "cash": params.cash,
                            "arrival_price": params.arrival_price,
                            "target_shares": params.target_shares,
                        },
                    }
                )
            )
            reset_response: WSObservationResponse | WSErrorResponse = json.loads(
                await websocket.recv()
            )
            await check_ws_error(reset_response)
            observation = reset_response.get("data")
            done = observation.get("done", False)
            log_transcript(f"Initial observation: {observation}")

            async def ws_step(action_payload: HftAction) -> Dict[str, Any]:
                """Send one step message and return the response's ``data`` envelope."""
                await websocket.send(
                    json.dumps(
                        {
                            "type": "step",
                            "data": {
                                **action_payload.model_dump(),
                            },
                        }
                    )
                )
                step_response: WSObservationResponse | WSErrorResponse = json.loads(
                    await websocket.recv()
                )
                await check_ws_error(step_response)
                return step_response.get("data", {})

            await run_episode_loop(
                client=client,
                task=task,
                initial_observation=observation.get("observation", {}),
                initial_done=done,
                step_fn=ws_step,
            )
            await websocket.send(json.dumps({"type": "close", "data": {}}))
            avg_reward = _average_reward(task)
            log_transcript(f"Episode finished for task: {task}, Reward: {avg_reward}")
    except WebSocketError as e:
        log_transcript(f"WebSocket error: {e}")
        sys.exit(1)
    except Exception as e:
        log_transcript(f"Error during WebSocket interaction: {e}")
        sys.exit(1)


async def run_online_tasks(
    client: AsyncOpenAI, tasks: list, pause: int, params: HftParams, base_url: str
):
    """Run each task sequentially over WebSocket, sleeping *pause* seconds between tasks."""
    for task in tasks:
        validate_task_name(task)
        log_transcript(f"Running task: {task}")
        await run_online_single_task(client, task, params, base_url)
        await asyncio.sleep(pause)


async def run_task(
    client: AsyncOpenAI,
    tasks: list,
    pause: int,
    args: argparse.Namespace,
    base_url: str,
):
    """Build ``HftParams`` from CLI args and run all tasks online or via Docker."""
    hftParams = HftParams(
        max_steps=args.max_steps,
        tick_size=args.tick_size,
        inventory=args.inventory,
        cash=args.cash,
        arrival_price=args.arrival_price,
        target_shares=args.target_shares,
    )
    log_transcript(
        "Using parameters - "
        f"max_steps: {hftParams.max_steps}, tick_size: {hftParams.tick_size}, "
        f"inventory: {hftParams.inventory}, cash: {hftParams.cash}, "
        f"arrival_price: {hftParams.arrival_price}, target_shares: {hftParams.target_shares}"
    )
    use_docker = not await health_check(base_url)
    if use_docker:
        log_transcript(
            "HFSpace server is not healthy. Diverting to docker based environment."
        )
        await run_docker_task(tasks, pause, hftParams, client)
    else:
        log_transcript("HFSpace server is healthy. Running inference against it.")
        await run_online_tasks(client, tasks, pause, hftParams, base_url)
    log_transcript("All tasks completed.")


async def main():
    """CLI entry point: parse args, run the tasks, then write transcript and scores."""
    parser = argparse.ArgumentParser(
        description="Run inference against the Hft environment using OpenAI API."
    )
    parser.add_argument(
        "--task",
        type=str,
        nargs="+",
        default=["all"],
        help="Task name to run (basic_execution, false_signal, conflicting_signal, flash_crash, or all)",
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default="https://jonathanshiju12-hft-env.hf.space",
        help="Base URL for the Hft API (default: https://jonathanshiju12-hft-env.hf.space)",
    )
    parser.add_argument(
        "--pause",
        type=int,
        default=3,
        help="Seconds to pause between tasks (default: 3)",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=None,
        help="Maximum steps to run for each task ",
    )
    parser.add_argument(
        "--tick-size",
        type=float,
        default=None,
        help="Tick size for the market simulation",
    )
    parser.add_argument(
        "--inventory",
        type=int,
        default=None,
        help="Starting inventory for the agent",
    )
    parser.add_argument(
        "--cash", type=float, default=None, help="Starting cash for the agent"
    )
    parser.add_argument(
        "--arrival-price",
        type=float,
        default=None,
        help="Arrival price for the market simulation",
    )
    parser.add_argument(
        "--target-shares",
        type=int,
        default=None,
        help="Target shares to execute for the agent",
    )
    args = parser.parse_args()

    tasks_to_run = (
        args.task
        if "all" not in args.task
        else ["basic_execution", "false_signal", "conflicting_signal", "flash_crash"]
    )

    # API_BASE_URL always has a default, so only the key and model can be missing.
    # AsyncOpenAI itself falls back to OPENAI_API_KEY when api_key is None.
    if not API_KEY and not os.getenv("OPENAI_API_KEY"):
        raise ValueError("HF_TOKEN or API_KEY must be set")
    if not MODEL_NAME:
        raise ValueError("MODEL_NAME must be set")

    try:
        client = AsyncOpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    except Exception as e:
        print(f"Error initializing OpenAI client: {e}")
        sys.exit(1)

    await run_task(client, tasks_to_run, args.pause, args, args.base_url)

    # Explicit UTF-8: LLM replies may contain non-ASCII text.
    with open("transcript.txt", "w", encoding="utf-8") as f:
        f.write(_transcript)

    with open("baseline_scores.json", "w", encoding="utf-8") as f:
        baseline_scores = {
            "avg": _rewards_avg,
            "all": _rewards,
        }
        json.dump(baseline_scores, f, indent=4)


if __name__ == "__main__":
    asyncio.run(main())