hft-env / Inference.py
JonathanShiju12's picture
Upload folder using huggingface_hub
5834de0 verified
from collections import defaultdict
from dataclasses import asdict, dataclass
from pydantic import BaseModel
import websockets
import httpx
import argparse, asyncio, json, os, sys
from dotenv import load_dotenv
from openai import AsyncOpenAI
import re
from typing import Any, Awaitable, Callable, Dict, Optional
try:
from models import HftAction, HftObservation, tasks
except ModuleNotFoundError:
from hft.models import HftAction, HftObservation, tasks
try:
from openenv.core.env_server.http_server import (
WSErrorResponse,
WSObservationResponse,
)
from openenv.core.client_types import StepResult
except Exception as e: # pragma: no cover
raise ImportError(
"openenv is required for the web interface. Install dependencies with '\n uv sync\n'"
) from e
load_dotenv()
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
MODEL_NAME = os.getenv("MODEL_NAME")
VERBOSE = os.getenv("VERBOSE", "false")
_transcript = ""
_rewards = defaultdict(list)
_rewards_avg = defaultdict(list)
SYSTEM_PROMPT = """
**Role:** You are an expert High-Frequency Trading (HFT) Market Making agent. Your objective is to manage a 1,000-unit trade execution task within a Limit Order Book (LOB) simulation. You must maximize a reward function that balances Execution PnL, Passive Fill Incentives, and Execution Urgency.
**Core Mechanics:**
* **Inventory Management (inv):** You must avoid holding large positions (Long or Short) for too long. Large inventory triggers a risk penalty.
* **Urgency (remaining):** You have a target of 1,000 units. If you are behind the linear execution pace (TWAP), your reward decreases.
* **Passive vs. Active:** You receive a bonus for "Passive" fills (resting limit orders) and a penalty for "Active" fills (hitting the spread). Use Active fills only to neutralize dangerous inventory.
* **The Spread:** You should typically quote at the Best Bid and Best Ask to "capture the spread."
**Action Schema:**
You must return a JSON list of actions or `null` to skip a turn.
* **Limit Order:** `{"type": "limit", "side": "buy"|"ask", "price": float, "size": int}`
* **Market Order:** `{"type": "market", "side": "buy"|"ask", "size": int}`
* **Cancel Order:** `{"type": "cancel", "order_id": "string"}`
**Note:** Active orders are provided as [id, side, price, size]. Side 'B' is Buy, 'A' is Ask.
**Strategy Guidelines:**
* **Join the Best:** If the spread is wide, "penny-jump" the best bid/ask to get priority.
* **Inventory Skew:** If you are Long (+), lower your Ask price to encourage a fill. If you are Short (-), raise your Bid price to cover.
* **Terminal Liquidation:** As $t$ approaches **1.0**, use Market orders to zero out your inventory to avoid the heavy terminal liquidation penalty.
**CRITICAL CONSTRAINTS:**
1. **Output Format:** Return **ONLY** a JSON list. Do not include any prose, markdown commentary, or explanations.
2. **Cardinality:** Your list **MUST NOT** contain more than **one** order of the same type. You are permitted a maximum of one `limit` order, one `market` order, and one `cancel` order per response.
3. **Idle:** Return `null` only if the volume goal is met and inventory is 0.
**You cannot cancel more than one single order at a time**
"""
VALID_TYPES = {"limit", "market", "cancel"}
REQUIRED_KEYS = {
"limit": {"type", "side", "price", "size"},
"market": {"type", "side", "size"},
"cancel": {"type", "order_id"},
}
VALID_SIDES = {"buy", "ask"}
temperature = 0.0
top_p = 0.8
max_tokens = 100
@dataclass
class HftParams:
max_steps: Optional[int] = None
tick_size: Optional[float] = None
inventory: Optional[int] = None
cash: Optional[float] = None
arrival_price: Optional[float] = None
target_shares: Optional[int] = None
def to_env_dict(self) -> Dict[str, str]:
"""Returns a dict of non-None values as strings."""
return {k.upper(): str(v) for k, v in self.__dict__.items() if v is not None}
class WebSocketError(Exception):
pass
def log_transcript(message: str) -> None:
global _transcript
_transcript += f"{message}\n"
print(message, flush=True)
def validate_task_name(task: str) -> None:
if task not in tasks:
log_transcript(f"Invalid task name: {task}. Enter one of: {tasks}")
sys.exit(1)
async def call_llm_for_actions(
client: AsyncOpenAI, user_prompt: str
) -> Optional[list[dict]]:
response = await client.chat.completions.create(
model=MODEL_NAME,
messages=[
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": user_prompt},
],
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
)
llm_reply = response.choices[0].message.content
actions = extract_actions(llm_reply)
log_transcript(f"LLM response: {llm_reply}")
log_transcript(f"Parsed actions: {actions}")
return actions
def actions_to_hft_action(actions: Optional[list[dict]]) -> HftAction:
limit_buy_price = None
limit_buy_size = None
limit_ask_price = None
limit_ask_size = None
market_buy_size = None
market_ask_size = None
cancel_order_id = None
if actions not in (None, []):
for action in actions:
if action["type"] == "limit":
if action["side"] == "buy":
limit_buy_price = action["price"]
limit_buy_size = action["size"]
elif action["side"] == "ask":
limit_ask_price = action["price"]
limit_ask_size = action["size"]
elif action["type"] == "market":
if action["side"] == "buy":
market_buy_size = action["size"]
elif action["side"] == "ask":
market_ask_size = action["size"]
elif action["type"] == "cancel":
cancel_order_id = action["order_id"]
return HftAction(
limit_buy_price=limit_buy_price,
limit_buy_size=limit_buy_size,
limit_ask_price=limit_ask_price,
limit_ask_size=limit_ask_size,
market_buy_size=market_buy_size,
market_ask_size=market_ask_size,
cancel_order_id=cancel_order_id,
)
def normalize_observation(
observation: Dict[str, Any],
**kwargs,
) -> HftObservation:
return HftObservation(**observation, **kwargs)
def build_user_prompt(
observation: HftObservation | Dict[str, Any], **kwargs
) -> tuple[HftObservation, str]:
obs_model = normalize_observation(observation, **kwargs)
if getattr(obs_model, "history", None):
mid = obs_model.history[-1].get("mid")
if mid is not None and hasattr(obs_model, "active_orders"):
obs_model.active_orders = compress_active_orders(
obs_model.active_orders, mid
)
return obs_model, parse_obs(obs_model)
async def run_episode_loop(
client: AsyncOpenAI,
task: str,
initial_observation: HftObservation | Dict[str, Any],
initial_done: bool,
step_fn: Callable[[HftAction], Awaitable[HftObservation | Dict[str, Any]]],
) -> None:
_rewards[task] = []
observation, user_prompt = build_user_prompt(initial_observation)
done = initial_done
while not done:
actions = await call_llm_for_actions(client, user_prompt)
action_payload = actions_to_hft_action(actions)
observation = await step_fn(action_payload)
if not isinstance(observation, dict):
observation = asdict(observation)["observation"].model_dump()
observation, user_prompt = build_user_prompt(observation)
else:
observation, user_prompt = build_user_prompt(
observation.get("observation"),
reward=observation.get("reward"),
done=observation.get("done"),
)
reward = observation.reward
done = observation.done
history = observation.history
_rewards[task].append(reward)
log_transcript(
f"Step reward: {reward}, Done: {done}"
if VERBOSE.lower() == "false"
else f"Step reward: {reward}, Done: {done}, History: {history}"
)
def coerce_action(obj: dict) -> Optional[dict]:
"""Validate and coerce a single action dict into a clean, typed action."""
if not isinstance(obj, dict):
return None
# Normalize keys
obj = {k.strip().lower(): v for k, v in obj.items()}
action_type = str(obj.get("type", "")).strip().lower()
if action_type not in VALID_TYPES:
return None
required = REQUIRED_KEYS[action_type]
if not required.issubset(obj.keys()):
return None
if action_type == "limit":
side = str(obj["side"]).strip().lower()
if side not in VALID_SIDES:
return None
return {
"type": "limit",
"side": side,
"price": float(obj["price"]),
"size": int(obj["size"]),
}
if action_type == "market":
side = str(obj["side"]).strip().lower()
if side not in VALID_SIDES:
return None
return {
"type": "market",
"side": side,
"size": int(obj["size"]),
}
if action_type == "cancel":
return {
"type": "cancel",
"order_id": str(obj["order_id"]).strip(),
}
return None
def extract_actions(response: str) -> Optional[list[HftAction]]:
"""
Extract a list of valid LOB actions from any LLM response string.
Handles:
- Clean JSON arrays / objects
- Markdown code fences (```json ... ``` or ``` ... ```)
- null / None responses (idle signal)
- Partial prose with embedded JSON fragments
- Single objects that should be wrapped in a list
Returns:
- List of validated action dicts, or
- None (idle — volume complete, inventory zero)
"""
if response is None:
return None
text = response.strip()
if text.lower() in {"null", "none", ""}:
return None
text = re.sub(r"```(?:json)?\s*", "", text, flags=re.IGNORECASE).strip()
text = text.rstrip("`").strip()
json_object_pattern = re.compile(r"\{[^{}]*\}", re.DOTALL)
json_array_pattern = re.compile(r"\[.*?\]", re.DOTALL)
candidates = []
try:
parsed = json.loads(text)
if parsed is None:
return None
if isinstance(parsed, list):
candidates = parsed
elif isinstance(parsed, dict):
candidates = [parsed]
except json.JSONDecodeError:
pass
if not candidates:
for match in json_array_pattern.finditer(text):
try:
parsed = json.loads(match.group())
if isinstance(parsed, list):
candidates = parsed
break
except json.JSONDecodeError:
continue
if not candidates:
for match in json_object_pattern.finditer(text):
try:
obj = json.loads(match.group())
if isinstance(obj, dict):
candidates.append(obj)
except json.JSONDecodeError:
continue
actions = []
for item in candidates:
action = coerce_action(item)
if action is not None:
actions.append(action)
if not actions:
return None
return actions
def parse_obs(obs: HftObservation | Dict[str, Any]) -> str:
"""Convert HftObservation to a clean, JSON-serializable dict for LLM input."""
if isinstance(obs, HftObservation):
obs_dict = obs.model_dump()
elif isinstance(obs, dict):
obs_dict = obs
else:
raise ValueError("Observation must be HftObservation or dict")
return json.dumps(obs_dict)
def compress_active_orders(orders, mid_price, tick_window=0.10):
"""
Filters orders within a price window and flattens them.
Window of 0.10 means +/- 10 ticks from mid for a 0.01 tick size.
"""
compressed = []
for o in orders:
# 1. Skip orders with no ID
if o["id"] is None:
continue
# 2. Filter by distance to mid-price
if abs(o["price"] - mid_price) <= tick_window:
# 3. Use 1-character codes for side: Buy=B, Ask=A
side_code = "B" if o["side"] == "buy" else "A"
compressed.append(
[o["id"], side_code, round(o["price"], 2), int(o["size"])]
)
return compressed
async def health_check(http_url: str) -> bool:
async with httpx.AsyncClient(base_url=http_url, timeout=60) as client:
try:
response = await client.get(f"{http_url}/health")
return response.status_code == 200
except Exception as e:
print(f"Health check failed: {e}")
return False
async def run_docker_task(
tasks: list, pause: int, hftParams: HftParams, client: AsyncOpenAI
):
from client import HftEnv
try:
for task in tasks:
env = None
validate_task_name(task)
log_transcript(f"Running task: {task}")
try:
env = await HftEnv.from_docker_image(
"ghcr.io/jonathan-shiju/hft-env:latest",
env_vars={
"TASK": task,
**hftParams.to_env_dict(),
},
)
reset_response: StepResult = await env.reset()
observation = asdict(reset_response)
observation = observation["observation"].model_dump()
log_transcript(f"Initial observation: {observation}")
def docker_step(action_payload: HftAction) -> HftObservation:
return env.step(action_payload)
await run_episode_loop(
client=client,
task=task,
initial_observation=observation,
initial_done=reset_response.done,
step_fn=docker_step,
)
_rewards_avg[task] = (
sum(_rewards[task]) / len(_rewards[task]) if _rewards[task] else 0
)
log_transcript(
f"Episode finished for task: {task}, Reward: {_rewards_avg[task]}"
)
await asyncio.sleep(pause)
finally:
if env is not None:
await env.close()
log_transcript("All tasks completed.")
except Exception as e:
log_transcript(f"Error running Docker task: {e}")
sys.exit(1)
async def check_ws_error(
response: WSObservationResponse | WSErrorResponse,
) -> Optional[str]:
if response.get("type") == "error":
raise WebSocketError(
response.get("message", "Unknown WebSocket error")
+ f" (code: {response.get('code', 'N/A')})"
+ f" errors: {response.get('errors', '')}"
)
return None
async def run_online_single_task(
client: AsyncOpenAI, task: str, params: HftParams, base_url: str
):
async with websockets.connect(
uri=f"wss://{base_url.lstrip('https://')}/ws"
) as websocket:
try:
await websocket.send(
json.dumps(
{
"type": "reset",
"data": {
"task_name": task,
"max_steps": params.max_steps,
"tick_size": params.tick_size,
"inventory": params.inventory,
"cash": params.cash,
"arrival_price": params.arrival_price,
"target_shares": params.target_shares,
},
}
)
)
reset_response: WSObservationResponse | WSErrorResponse = json.loads(
await websocket.recv()
)
await check_ws_error(reset_response)
observation = reset_response.get("data")
done = observation.get("done", False)
log_transcript(f"Initial observation: {observation}")
async def ws_step(action_payload: HftAction) -> Dict[str, Any]:
await websocket.send(
json.dumps(
{
"type": "step",
"data": {
**action_payload.model_dump(),
},
}
)
)
step_response: WSObservationResponse | WSErrorResponse = json.loads(
await websocket.recv()
)
await check_ws_error(step_response)
return step_response.get("data", {})
observation = observation.get("observation", {})
await run_episode_loop(
client=client,
task=task,
initial_observation=observation,
initial_done=done,
step_fn=ws_step,
)
await websocket.send(json.dumps({"type": "close", "data": {}}))
_rewards_avg[task] = (
sum(_rewards[task]) / len(_rewards[task]) if _rewards[task] else 0
)
log_transcript(
f"Episode finished for task: {task}, Reward: {_rewards_avg[task]}"
)
except WebSocketError as e:
log_transcript(f"WebSocket error: {e}")
sys.exit(1)
except Exception as e:
log_transcript(f"Error during WebSocket interaction: {e}")
sys.exit(1)
async def run_online_tasks(
client: AsyncOpenAI, tasks: list, pause: int, params: HftParams, base_url: str
):
for task in tasks:
validate_task_name(task)
log_transcript(f"Running task: {task}")
await run_online_single_task(client, task, params, base_url)
await asyncio.sleep(pause)
async def run_task(
client: AsyncOpenAI,
tasks: list,
pause: int,
args: argparse.Namespace,
base_url: str,
):
hftParams = HftParams(
max_steps=args.max_steps,
tick_size=args.tick_size,
inventory=args.inventory,
cash=args.cash,
arrival_price=args.arrival_price,
target_shares=args.target_shares,
)
log_transcript(
"Using parameters - "
f"max_steps: {hftParams.max_steps}, tick_size: {hftParams.tick_size}, "
f"inventory: {hftParams.inventory}, cash: {hftParams.cash}, "
f"arrival_price: {hftParams.arrival_price}, target_shares: {hftParams.target_shares}"
)
use_docker = not await health_check(base_url)
if use_docker:
log_transcript(
"HFSpace server is not healthy. Diverting to docker based environment."
)
await run_docker_task(tasks, pause, hftParams, client)
else:
log_transcript("HFSpace server is healthy. Running inference against it.")
await run_online_tasks(client, tasks, pause, hftParams, base_url)
log_transcript("All tasks completed.")
async def main():
parser = argparse.ArgumentParser(
description="Run inference against the Hft environment using OpenAI API."
)
parser.add_argument(
"--task",
type=str,
nargs="+",
default=["all"],
help="Task name to run (basic_execution, false_signal, conflicting_signal, flash_crash, or all)",
)
parser.add_argument(
"--base-url",
type=str,
default="https://jonathanshiju12-hft-env.hf.space",
help="Base URL for the Hft API (default: https://jonathanshiju12-hft-env.hf.space)",
)
parser.add_argument(
"--pause",
type=int,
default=3,
help="Seconds to pause between steps (default: 3)",
)
parser.add_argument(
"--max-steps",
type=int,
default=None,
help="Maximum steps to run for each task ",
)
parser.add_argument(
"--tick-size",
type=float,
default=None,
help="Tick size for the market simulation",
)
parser.add_argument(
"--inventory",
type=int,
default=None,
help="Starting inventory for the agent",
)
parser.add_argument(
"--cash", type=float, default=None, help="Starting cash for the agent"
)
parser.add_argument(
"--arrival-price",
type=float,
default=None,
help="Arrival price for the market simulation",
)
parser.add_argument(
"--target-shares",
type=int,
default=None,
help="Target shares to execute for the agent",
)
args = parser.parse_args()
tasks_to_run = (
args.task
if "all" not in args.task
else ["basic_execution", "false_signal", "conflicting_signal", "flash_crash"]
)
if not API_KEY and not API_BASE_URL:
raise ValueError("API_KEY and API_BASE_URL must be set when VERBOSE is True")
try:
client = AsyncOpenAI(base_url=API_BASE_URL, api_key=API_KEY)
except Exception as e:
print(f"Error initializing OpenAI client: {e}")
sys.exit(1)
await run_task(client, tasks_to_run, args.pause, args, args.base_url)
with open("transcript.txt", "w") as f:
f.write(_transcript)
with open("baseline_scores.json", "w") as f:
baseline_scores = {
"avg": _rewards_avg,
"all": _rewards,
}
json.dump(baseline_scores, f, indent=4)
if __name__ == "__main__":
asyncio.run(main())