QuantHive / env /state.py
ARKAISW's picture
Hackathon Final Submission: PettingZoo multi-agent arch, GRPO training, docs
9cb3002
"""
State management for the trading environment.
Defines MarketState, PortfolioState, RiskState, and observation construction.
"""
import numpy as np
import pandas as pd
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any
@dataclass
class MarketState:
    """Market data (OHLCV plus indicator columns) and a positional cursor into it.

    Rows of ``prices`` are addressed with ``iloc``; ``current_step`` picks the row
    that all accessors below read from.
    """
    prices: pd.DataFrame  # OHLCV + indicators dataframe
    current_step: int = 0

    def current_row(self) -> pd.Series:
        """Return the row of ``prices`` at position ``current_step``."""
        return self.prices.iloc[self.current_step]

    def current_price(self) -> float:
        """Return the ``close`` value at ``current_step`` as a Python float."""
        bar = self.prices.iloc[self.current_step]
        return float(bar["close"])

    def observation_vector(self) -> np.ndarray:
        """Return the 14 market features of the current row as float32.

        Layout: open/high/low/close divided by close, log1p(volume) / 20,
        RSI / 100, EMA-20 and EMA-50 divided by close, tanh of MACD, signal and
        histogram (each scaled by 100 / close), Bollinger %B,
        min(volatility * 100, 1), ATR divided by close.
        """
        bar = self.current_row()
        px = bar["close"]
        denom = px + 1e-10  # epsilon keeps a zero close from dividing by zero

        relative_ohlc = [bar[name] / denom for name in ("open", "high", "low", "close")]
        scaled_volume = np.log1p(bar["volume"]) / 20.0
        rsi_unit = bar["rsi"] / 100.0
        ema_ratios = [bar["ema_20"] / denom, bar["ema_50"] / denom]
        # tanh squashes the close-relative MACD family into (-1, 1)
        macd_block = [
            np.tanh(bar[name] / denom * 100)
            for name in ("macd", "macd_signal", "macd_hist")
        ]
        band_width = bar["bb_upper"] - bar["bb_lower"] + 1e-10
        percent_b = (px - bar["bb_lower"]) / band_width  # position of close inside the bands
        capped_vol = min(bar["volatility"] * 100, 1.0)
        atr_ratio = bar["atr"] / denom

        vector = [
            *relative_ohlc,
            scaled_volume,
            rsi_unit,
            *ema_ratios,
            *macd_block,
            percent_b,
            capped_vol,
            atr_ratio,
        ]
        return np.array(vector, dtype=np.float32)

    @property
    def feature_size(self) -> int:
        """Length of the vector produced by ``observation_vector``."""
        return 14
@dataclass
class PortfolioState:
    """Tracks cash plus per-ticker positions, entry costs and risk orders.

    Position quantities are signed: positive means long, negative means short.
    Methods taking ``ticker`` default to ``"default"`` for the single-instrument case.
    """
    initial_cash: float = 100_000.0
    # When omitted, cash starts equal to ``initial_cash`` (resolved in __post_init__).
    # A fixed 100_000 default would ignore a custom ``initial_cash`` until reset().
    cash: Optional[float] = None
    positions: Dict[str, float] = field(default_factory=dict)  # ticker -> signed quantity
    avg_costs: Dict[str, float] = field(default_factory=dict)  # ticker -> average entry price
    trade_durations: Dict[str, int] = field(default_factory=dict)  # ticker -> steps held
    trade_history: List[Dict[str, Any]] = field(default_factory=list)
    # Professional risk management: Stop Loss and Take Profit.
    # Format: {ticker: price}; values may be None.
    stop_losses: Dict[str, Optional[float]] = field(default_factory=dict)
    take_profits: Dict[str, Optional[float]] = field(default_factory=dict)

    def __post_init__(self):
        if self.cash is None:
            self.cash = self.initial_cash

    def reset(self):
        """Restore starting cash and clear all positions and per-ticker bookkeeping."""
        self.cash = self.initial_cash
        self.positions = {}
        self.avg_costs = {}
        # Cleared too, so holding times do not carry over into the next episode.
        self.trade_durations = {}
        self.trade_history = []
        self.stop_losses = {}
        self.take_profits = {}

    def total_value(self, current_price: float, ticker: str = "default") -> float:
        """Return cash plus the mark-to-market value of the position in ``ticker``.

        Long or flat: ``cash + qty * current_price``.
        Short: ``cash + |qty| * (avg_cost - current_price)``, i.e. cash plus the
        unrealized short P&L. If no average cost is recorded, ``current_price``
        is used, so the short contributes zero.

        NOTE(review): no margin is added back for shorts. This formula is only a
        correct account value if opening a short leaves ``cash`` unchanged. If the
        environment debits margin from (or credits proceeds to) ``cash``, the
        result is off by ``|qty| * avg_cost``. Confirm against the order logic.
        """
        position_qty = self.positions.get(ticker, 0.0)
        if position_qty >= 0:
            return self.cash + position_qty * current_price
        avg_cost = self.avg_costs.get(ticker, current_price)
        unrealized = abs(position_qty) * (avg_cost - current_price)
        return self.cash + unrealized

    def unrealized_pnl(self, current_price: float, ticker: str = "default") -> float:
        """Return unrealized P&L of the position in ``ticker`` versus its average cost.

        Long (qty > 0) profits when price rises; short (qty < 0) profits when it falls.
        A position with no recorded average cost is valued at ``current_price``
        (zero P&L), matching ``total_value``; a 0.0 fallback would report the
        whole position value as profit or loss.
        """
        position_qty = self.positions.get(ticker, 0.0)
        if abs(position_qty) < 1e-10:
            return 0.0
        avg_entry = self.avg_costs.get(ticker, current_price)
        if position_qty > 0:
            return position_qty * (current_price - avg_entry)
        return abs(position_qty) * (avg_entry - current_price)

    def observation_vector(self, current_price: float, ticker: str = "default") -> np.ndarray:
        """Return 5 portfolio features as float32.

        [cash / initial_cash, long value / total value, total value / initial_cash,
         tanh(10 * unrealized P&L / initial_cash), short value / initial_cash]
        """
        total_val = self.total_value(current_price, ticker)
        position_qty = self.positions.get(ticker, 0.0)
        long_value = max(position_qty, 0.0) * current_price
        short_value = abs(min(position_qty, 0.0)) * current_price
        features = [
            self.cash / (self.initial_cash + 1e-10),  # cash ratio
            # NOTE(review): unbounded if total_val drops to <= 0
            long_value / (total_val + 1e-10),  # long exposure ratio
            total_val / (self.initial_cash + 1e-10),  # portfolio return ratio
            np.tanh(self.unrealized_pnl(current_price, ticker) / (self.initial_cash + 1e-10) * 10),  # normalized PnL
            short_value / (self.initial_cash + 1e-10),  # short exposure ratio
        ]
        return np.array(features, dtype=np.float32)

    @property
    def feature_size(self) -> int:
        """Length of the vector produced by ``observation_vector``."""
        return 5
@dataclass
class RiskState:
    """Tracks risk metrics for one episode: drawdown, return statistics, trade count.

    ``return_history`` holds the sequence of portfolio *values* passed to
    ``update`` (not returns); per-step returns are derived from it on demand.
    """
    peak_value: float = 100_000.0
    current_drawdown: float = 0.0
    max_drawdown: float = 0.0
    # Portfolio values recorded by update(), oldest first (name kept for compatibility).
    return_history: List[float] = field(default_factory=list)
    # Not incremented anywhere in this class; presumably maintained by the caller.
    trade_count: int = 0

    def reset(self, initial_value: float = 100_000.0):
        """Start a new episode with ``initial_value`` as the drawdown peak."""
        self.peak_value = initial_value
        self.current_drawdown = 0.0
        self.max_drawdown = 0.0
        self.return_history = []
        self.trade_count = 0

    def update(self, portfolio_value: float):
        """Record ``portfolio_value`` and refresh the peak and drawdown metrics."""
        self.return_history.append(portfolio_value)
        if portfolio_value > self.peak_value:
            self.peak_value = portfolio_value
        self.current_drawdown = (self.peak_value - portfolio_value) / (self.peak_value + 1e-10)
        self.max_drawdown = max(self.max_drawdown, self.current_drawdown)

    def _step_returns(self) -> np.ndarray:
        """Simple returns between consecutive recorded portfolio values."""
        values = np.asarray(self.return_history, dtype=np.float64)
        return np.diff(values) / (values[:-1] + 1e-10)

    def sharpe_ratio(self, risk_free_rate: float = 0.0) -> float:
        """Per-step (non-annualized) Sharpe ratio over the whole recorded history.

        Returns 0.0 with fewer than two values or when returns are (near) constant.
        """
        if len(self.return_history) < 2:
            return 0.0
        returns = self._step_returns()
        std = np.std(returns)
        if std < 1e-10:
            return 0.0
        return float((np.mean(returns) - risk_free_rate) / (std + 1e-10))

    def return_volatility(self) -> float:
        """Population std of per-step returns over the whole history (not a rolling window)."""
        if len(self.return_history) < 2:
            return 0.0
        return float(np.std(self._step_returns()))

    def observation_vector(self) -> np.ndarray:
        """Return 5 risk features as float32."""
        features = [
            min(self.current_drawdown, 1.0),  # current drawdown, capped at 1
            min(self.max_drawdown, 1.0),  # max drawdown, capped at 1
            np.tanh(self.sharpe_ratio()),  # sharpe squashed into (-1, 1)
            min(self.return_volatility() * 100, 1.0),  # volatility
            min(self.trade_count / 100.0, 1.0),  # normalized trade count
        ]
        return np.array(features, dtype=np.float32)

    @property
    def feature_size(self) -> int:
        """Length of the vector produced by ``observation_vector``."""
        return 5
def get_observation(market: MarketState, portfolio: PortfolioState,
                    risk: RiskState, ticker: str = "default") -> np.ndarray:
    """Build the flat observation vector for one step.

    Order: market features, then portfolio features for ``ticker`` valued at
    the market's current close, then risk features.
    """
    price_now = market.current_price()
    parts = (
        market.observation_vector(),
        portfolio.observation_vector(price_now, ticker),
        risk.observation_vector(),
    )
    return np.concatenate(parts)
def get_observation_size(market: MarketState, portfolio: PortfolioState,
                         risk: RiskState) -> int:
    """Sum of the components' declared ``feature_size`` values.

    Expected to equal the length of the vector built by ``get_observation``.
    """
    return sum(component.feature_size for component in (market, portfolio, risk))