"""Gymnasium trading environment for offline RL data collection. This environment produces episodes from the parquet data. It is used both for: 1. Online validation / replay of trained policies 2. Generating offline transitions for IQL training """ from __future__ import annotations from dataclasses import dataclass from typing import Any import gymnasium as gym import numpy as np import pandas as pd from .constants import ( ACTIONS, ACTION_INDEX_BY_NAME, DEFAULT_EPISODE_SPAN_DAYS, DEFAULT_EPISODE_STRIDE_DAYS, DEFAULT_HISTORY_LENGTH, DEFAULT_REBALANCE_TOLERANCE, DRAWDOWN_LIMIT, MARKET_FEATURE_COLUMNS, PORTFOLIO_FEATURE_COLUMNS, STARTING_CASH, ) @dataclass class FillResult: executed: bool action_name: str side: str | None shares: float = 0.0 price: float = 0.0 cost: float = 0.0 fee: float = 0.0 reason: str = "" @dataclass class PositionState: side: str | None = None # "YES" or "NO" or None shares: float = 0.0 cost_basis_usdc: float = 0.0 avg_entry_price: float = 0.0 steps_held: int = 0 def is_open(self) -> bool: return self.side is not None and self.shares > 0 def reset(self): self.side = None self.shares = 0.0 self.cost_basis_usdc = 0.0 self.avg_entry_price = 0.0 self.steps_held = 0 @dataclass class PortfolioState: starting_cash: float = STARTING_CASH cash: float = STARTING_CASH peak_equity: float = STARTING_CASH max_drawdown_fraction: float = 0.0 realized_pnl: float = 0.0 fill_count: int = 0 position: PositionState = None def __post_init__(self): if self.position is None: self.position = PositionState() def equity(self, yes_bid: float, no_bid: float) -> float: if self.position.is_open(): bid = yes_bid if self.position.side == "YES" else no_bid liquidation_value = self.position.shares * bid return self.cash + liquidation_value return self.cash def mark_to_liquidation(self, yes_bid: float, no_bid: float) -> float: if self.position.is_open(): bid = yes_bid if self.position.side == "YES" else no_bid return self.position.shares * bid return 0.0 def update_drawdown(self, yes_bid: float, no_bid: float) -> float: current = self.equity(yes_bid=yes_bid, no_bid=no_bid) if current > self.peak_equity: self.peak_equity = current if self.peak_equity > 1e-12: dd = max(0.0, (self.peak_equity - current) / self.peak_equity) self.max_drawdown_fraction = max(self.max_drawdown_fraction, dd) return self.max_drawdown_fraction def _to_utc_timestamp(value) -> pd.Timestamp: ts = pd.Timestamp(value) if ts.tzinfo is None: return ts.tz_localize("UTC") return ts.tz_convert("UTC") @dataclass class EpisodeData: episode_index: int episode_start_day: pd.Timestamp episode_end_day: pd.Timestamp episode_days: tuple[pd.Timestamp, ...] frame: pd.DataFrame market_features_raw: np.ndarray class BTCTradingEnv(gym.Env[np.ndarray, int]): """Gymnasium environment for BTC 5m trading. Observation = flattened [history of market features] + [current portfolio state]. Action = discrete action ID from ACTIONS. """ metadata = {"render_modes": []} def __init__( self, df: pd.DataFrame, *, market_feature_columns: list[str] | None = None, starting_cash: float = STARTING_CASH, market_feature_mean: np.ndarray | None = None, market_feature_std: np.ndarray | None = None, episode_days: list[pd.Timestamp] | None = None, history_length: int = DEFAULT_HISTORY_LENGTH, episode_span_days: int = DEFAULT_EPISODE_SPAN_DAYS, episode_stride_days: int = DEFAULT_EPISODE_STRIDE_DAYS, max_position_cost_fraction_of_equity: float | None = 0.50, max_position_steps: int | None = None, soft_drawdown_increment_penalty: float = 0.50, rebalance_tolerance: float = DEFAULT_REBALANCE_TOLERANCE, risk_lambda: float = 1.0, cvar_alpha: float = 0.05, taker_fee_rate: float = 0.072, ) -> None: super().__init__() self.market_feature_columns = market_feature_columns or list(MARKET_FEATURE_COLUMNS) self.starting_cash = float(starting_cash) self.market_feature_mean = market_feature_mean self.market_feature_std = market_feature_std self.history_length = max(1, int(history_length)) self.episode_span_days = max(1, int(episode_span_days)) self.episode_stride_days = max(1, int(episode_stride_days)) self.max_position_cost_fraction_of_equity = ( None if max_position_cost_fraction_of_equity is None else float(max_position_cost_fraction_of_equity) ) self.max_position_steps = None if max_position_steps is None else int(max_position_steps) self.soft_drawdown_increment_penalty = float(soft_drawdown_increment_penalty) self.rebalance_tolerance = max(0.0, float(rebalance_tolerance)) self.risk_lambda = float(risk_lambda) self.cvar_alpha = float(cvar_alpha) self.taker_fee_rate = float(taker_fee_rate) working = df.copy() working["episode_day"] = pd.to_datetime(working["episode_day"], utc=True) feature_defaults = { "funding_rate": 0.0, "funding_rate_prev": 0.0, "oi_delta_5m": 0.0, "oi_delta_15m": 0.0, "oi_delta_60m": 0.0, "long_short_ratio": 1.0, "label_up": 0.0, } for col in self.market_feature_columns: if col not in working.columns: working[col] = feature_defaults.get(col, 0.0) elif working[col].isna().all(): working[col] = feature_defaults.get(col, 0.0) else: working[col] = working[col].fillna(feature_defaults.get(col, 0.0)) if "label_up" in working.columns: working["label_up"] = working["label_up"].fillna(0.0) available_days = list(episode_days or sorted(pd.Index(working["episode_day"].drop_duplicates()))) available_days = [_to_utc_timestamp(day) for day in available_days] self._episodes: list[EpisodeData] = [] for start_idx in range(0, len(available_days), self.episode_stride_days): window_days = tuple(available_days[start_idx:start_idx + self.episode_span_days]) if not window_days: continue frame = ( working.loc[working["episode_day"].isin(window_days)] .sort_values(["start_time", "obs_pos"]) .reset_index(drop=True) ) if frame.empty: continue self._episodes.append(EpisodeData( episode_index=len(self._episodes), episode_start_day=window_days[0], episode_end_day=window_days[-1], episode_days=window_days, frame=frame, market_features_raw=np.nan_to_num( frame[self.market_feature_columns].to_numpy(dtype=np.float32), nan=0.0, posinf=0.0, neginf=0.0, ), )) if not self._episodes: raise ValueError("BTCTradingEnv requires at least one non-empty episode.") obs_size = (self.history_length * len(self.market_feature_columns)) + len(PORTFOLIO_FEATURE_COLUMNS) self.action_space = gym.spaces.Discrete(len(ACTIONS)) self.observation_space = gym.spaces.Box( low=-np.inf, high=np.inf, shape=(obs_size,), dtype=np.float32, ) self._rng = np.random.default_rng() self._episode_index = 0 self._episode: EpisodeData | None = None self._cursor = 0 self._portfolio = PortfolioState(starting_cash=self.starting_cash, cash=self.starting_cash) @property def portfolio(self) -> PortfolioState: return self._portfolio def current_row(self) -> pd.Series: if self._episode is None: raise RuntimeError("Environment not reset.") return self._episode.frame.iloc[self._cursor] def _normalize_market_features(self, raw: np.ndarray) -> np.ndarray: raw = np.nan_to_num(raw.astype(np.float32), nan=0.0, posinf=0.0, neginf=0.0) if self.market_feature_mean is None or self.market_feature_std is None: return raw normalized = ((raw - self.market_feature_mean) / self.market_feature_std).astype(np.float32) return np.nan_to_num(normalized, nan=0.0, posinf=0.0, neginf=0.0) def _position_fraction(self, yes_bid: float, no_bid: float) -> float: equity = self._portfolio.equity(yes_bid=yes_bid, no_bid=no_bid) if equity <= 1e-12: return 0.0 return self._portfolio.mark_to_liquidation(yes_bid=yes_bid, no_bid=no_bid) / equity def _portfolio_features(self, row: pd.Series) -> np.ndarray: yes_bid = float(row["yes_bid_validated"]) no_bid = float(row["no_bid_validated"]) equity = self._portfolio.equity(yes_bid=yes_bid, no_bid=no_bid) mark_value = self._portfolio.mark_to_liquidation(yes_bid=yes_bid, no_bid=no_bid) position = self._portfolio.position position_side = 0.0 if position.side == "YES": position_side = 1.0 elif position.side == "NO": position_side = -1.0 unrealized = mark_value - position.cost_basis_usdc if position.is_open() else 0.0 position_fraction = self._position_fraction(yes_bid=yes_bid, no_bid=no_bid) return np.array([ self._portfolio.cash / self.starting_cash, equity / self.starting_cash, self._portfolio.max_drawdown_fraction, position_side, position_fraction, position.shares, position.avg_entry_price, unrealized / self.starting_cash, position.steps_held / 5.0, ], dtype=np.float32) def _history_features(self) -> np.ndarray: if self._episode is None: raise RuntimeError("Environment not reset.") history_vectors: list[np.ndarray] = [] first_index = max(0, self._cursor - self.history_length + 1) for idx in range(first_index, self._cursor + 1): history_vectors.append(self._normalize_market_features(self._episode.market_features_raw[idx])) while len(history_vectors) < self.history_length: history_vectors.insert(0, np.zeros(len(self.market_feature_columns), dtype=np.float32)) return np.concatenate(history_vectors, dtype=np.float32) def _get_observation(self) -> np.ndarray: row = self.current_row() market_history = self._history_features() portfolio = self._portfolio_features(row) return np.concatenate([market_history, portfolio], dtype=np.float32) def _compute_fee(self, shares: float, price: float) -> float: """Compute taker fee for a trade. Quadratic fee model: fee = shares × fee_rate × price × (1 - price). Highest near price=0.50, zero at price=0 or 1. Matches Crypto/BTC prediction market fee schedule. """ if self.taker_fee_rate <= 0: return 0.0 return float(shares) * self.taker_fee_rate * price * (1.0 - price) def _buy_to_target(self, *, row: pd.Series, side: str, target_fraction: float, equity_reference: float) -> FillResult: ask = float(row["yes_ask_validated"] if side == "YES" else row["no_ask_validated"]) budget = min(self._portfolio.cash, max(0.0, float(target_fraction) * float(equity_reference))) if self.max_position_cost_fraction_of_equity is not None: cap = float(self.max_position_cost_fraction_of_equity) * float(equity_reference) budget = min(budget, cap) if ask <= 0 or budget <= 0: return FillResult(False, "BUY", side, reason="INVALID_PRICE") shares = budget / ask fee = self._compute_fee(shares, ask) total_cost = budget + fee if total_cost > self._portfolio.cash: # Adjust to what we can actually afford shares = self._portfolio.cash / (ask + self.taker_fee_rate * ask * (1.0 - ask)) budget = shares * ask fee = self._compute_fee(shares, ask) total_cost = budget + fee new_cost = self._portfolio.position.cost_basis_usdc + budget new_shares = self._portfolio.position.shares + shares new_avg = new_cost / new_shares if new_shares > 0 else 0 self._portfolio.position.side = side self._portfolio.position.shares = new_shares self._portfolio.position.cost_basis_usdc = new_cost self._portfolio.position.avg_entry_price = new_avg self._portfolio.cash -= total_cost self._portfolio.fill_count += 1 return FillResult(True, "BUY", side, shares=shares, price=ask, cost=budget, fee=fee) def _liquidate(self, *, row: pd.Series) -> FillResult: position = self._portfolio.position if not position.is_open(): return FillResult(False, "LIQUIDATE", None, reason="NO_POSITION") side = position.side shares = position.shares yes_bid = float(row["yes_bid_validated"]) no_bid = float(row["no_bid_validated"]) bid = yes_bid if side == "YES" else no_bid proceeds = shares * bid fee = self._compute_fee(shares, bid) net_proceeds = proceeds - fee pnl = net_proceeds - position.cost_basis_usdc self._portfolio.cash += net_proceeds self._portfolio.realized_pnl += pnl position.reset() self._portfolio.fill_count += 1 return FillResult(True, "LIQUIDATE", side, shares=shares, price=bid, fee=fee) def _apply_action(self, action_id: int) -> FillResult: action = ACTIONS[int(action_id)] row = self.current_row() yes_bid = float(row["yes_bid_validated"]) no_bid = float(row["no_bid_validated"]) equity = self._portfolio.equity(yes_bid=yes_bid, no_bid=no_bid) current_fraction = self._position_fraction(yes_bid=yes_bid, no_bid=no_bid) target_fraction = float(action.target_fraction) if self.max_position_cost_fraction_of_equity is not None and target_fraction > 0.0: target_fraction = min(target_fraction, float(self.max_position_cost_fraction_of_equity)) if action.kind == "hold": return FillResult(False, action.name, None, reason="HOLD") if int(row.get("obs_pos", 0)) == 4: # The settlement row resolves any existing inventory in `step`. # Opening or changing positions here would let a policy buy using # same-row outcome information, which is not executable. return FillResult(False, action.name, action.target_side, reason="SETTLEMENT_ROW") if target_fraction <= 0.0: # FLAT action: liquidate if open if self._portfolio.position.is_open(): return self._liquidate(row=row) return FillResult(False, action.name, None, reason="ALREADY_FLAT") # Side switch: liquidate first if (self._portfolio.position.is_open() and self._portfolio.position.side != action.target_side): self._liquidate(row=row) equity = self._portfolio.equity(yes_bid=yes_bid, no_bid=no_bid) current_fraction = self._position_fraction(yes_bid=yes_bid, no_bid=no_bid) delta = target_fraction - current_fraction if abs(delta) <= self.rebalance_tolerance: return FillResult(False, action.name, action.target_side, reason="AT_TARGET") if delta > 0: return self._buy_to_target( row=row, side=action.target_side, target_fraction=delta, equity_reference=equity, ) else: # Reduce: sell fraction of current position sell_frac = min(1.0, abs(delta) / max(current_fraction, 1e-12)) if sell_frac <= 0: return FillResult(False, action.name, action.target_side, reason="NO_REDUCTION") position = self._portfolio.position yes_bid = float(row["yes_bid_validated"]) no_bid = float(row["no_bid_validated"]) bid = yes_bid if position.side == "YES" else no_bid shares_to_sell = position.shares * sell_frac proceeds = shares_to_sell * bid fee = self._compute_fee(shares_to_sell, bid) net_proceeds = proceeds - fee cost_portion = position.cost_basis_usdc * sell_frac pnl = net_proceeds - cost_portion self._portfolio.cash += net_proceeds self._portfolio.realized_pnl += pnl position.shares -= shares_to_sell position.cost_basis_usdc -= cost_portion side = position.side if position.shares <= 1e-6: position.reset() self._portfolio.fill_count += 1 return FillResult(True, "REDUCE", side, shares=shares_to_sell, price=bid, fee=fee) def _compute_risk_sensitive_reward(self, equity_before: float, equity_after: float, prev_drawdown: float, current_drawdown: float) -> float: """Compute risk-sensitive reward based on PnL, drawdown, and CVaR-like penalty.""" # Bounded log-equity return. This matches the offline dataset reward # scaling and prevents rare near-expiry binary payouts from dominating # the critic target scale. equity_floor = max(1e-6, 0.01 * self.starting_cash) before = max(float(equity_before), equity_floor) after = max(float(equity_after), equity_floor) pnl_reward = float(np.log(after / before)) pnl_reward = max(-2.0, min(2.0, pnl_reward)) # Drawdown increment penalty dd_penalty = self.soft_drawdown_increment_penalty * max(0.0, current_drawdown - prev_drawdown) # Risk penalty: heavily penalise large drawdown increments risk_penalty = self.risk_lambda * max(0.0, current_drawdown - prev_drawdown) ** 2 # CVaR-like: penalise tail losses (amplify negative returns during drawdown) if pnl_reward < 0: cvar_penalty = self.risk_lambda * abs(pnl_reward) * (1.0 + current_drawdown) else: cvar_penalty = 0.0 reward = pnl_reward - dd_penalty - risk_penalty - cvar_penalty return max(-4.0, min(4.0, reward)) def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> tuple[np.ndarray, dict[str, Any]]: super().reset(seed=seed) if seed is not None: self._rng = np.random.default_rng(seed) options = options or {} if "episode_index" in options: episode_index = int(options["episode_index"]) else: episode_index = int(self._rng.integers(0, len(self._episodes))) starting_cash = float(options.get("starting_cash", self.starting_cash)) starting_peak = float(options.get("starting_peak_equity", starting_cash)) starting_dd = float(options.get("starting_max_drawdown_fraction", 0.0)) self._episode_index = episode_index self._episode = self._episodes[episode_index] self._cursor = 0 self._portfolio = PortfolioState( starting_cash=self.starting_cash, cash=starting_cash, peak_equity=starting_peak, max_drawdown_fraction=starting_dd, ) obs = self._get_observation() row = self.current_row() info = { "episode_index": self._episode_index, "episode_day": row["episode_day"].isoformat(), "episode_start_day": self._episode.episode_start_day.isoformat(), "episode_end_day": self._episode.episode_end_day.isoformat(), } return obs, info def step(self, action_id: int) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]: row = self.current_row() yes_bid = float(row["yes_bid_validated"]) no_bid = float(row["no_bid_validated"]) equity_before = self._portfolio.equity(yes_bid=yes_bid, no_bid=no_bid) prev_drawdown = self._portfolio.max_drawdown_fraction fill = self._apply_action(int(action_id)) # Check settlement: obs_pos==4 means end of 5m window settlement_pnl = 0.0 if int(row.get("obs_pos", 0)) == 4 and self._portfolio.position.is_open(): # Settle position based on label label_up = int(row.get("label_up", 0)) position = self._portfolio.position # YES pays $1 if UP (label_up==1), $0 if DOWN # NO pays $1 if DOWN (label_up==0), $0 if UP if position.side == "YES": proceeds = position.shares * (1.0 if label_up == 1 else 0.0) else: # NO proceeds = position.shares * (1.0 if label_up == 0 else 0.0) settlement_pnl = proceeds - position.cost_basis_usdc self._portfolio.cash += proceeds self._portfolio.realized_pnl += settlement_pnl position.reset() self._portfolio.fill_count += 1 equity_after = self._portfolio.equity(yes_bid=yes_bid, no_bid=no_bid) current_drawdown = self._portfolio.update_drawdown(yes_bid=yes_bid, no_bid=no_bid) reward = self._compute_risk_sensitive_reward( equity_before, equity_after, prev_drawdown, current_drawdown, ) # Drawdown breach terminated = current_drawdown >= DRAWDOWN_LIMIT if terminated: if self._portfolio.position.is_open(): self._liquidate(row=row) reward -= 1.0 is_last_row = self._cursor >= len(self._episode.frame) - 1 truncated = is_last_row and not terminated info: dict[str, Any] = { "episode_day": row["episode_day"].isoformat(), "episode_start_day": self._episode.episode_start_day.isoformat(), "episode_end_day": self._episode.episode_end_day.isoformat(), "action_name": ACTIONS[int(action_id)].name, "fill_executed": fill.executed, "fill_fee": fill.fee, "equity_before": equity_before, "equity_after": equity_after, "reward": reward, } if terminated or truncated: info["episode_summary"] = { "ending_equity": self._portfolio.equity(yes_bid=yes_bid, no_bid=no_bid), "total_pnl": self._portfolio.equity(yes_bid=yes_bid, no_bid=no_bid) - self.starting_cash, "realized_pnl": self._portfolio.realized_pnl, "max_drawdown_fraction": self._portfolio.max_drawdown_fraction, "fill_count": self._portfolio.fill_count, } obs = np.zeros(self.observation_space.shape, dtype=np.float32) else: self._cursor += 1 obs = self._get_observation() return obs, float(reward), terminated, truncated, info