luohoa97's picture
Deploy BitNet-Transformer Trainer
d5b7ee9 verified
"""Risk management — position sizing, stop-loss, drawdown checks."""
from __future__ import annotations
import logging
import math
logger = logging.getLogger(__name__)
def check_market_regime(
    spy_ohlcv: pd.DataFrame,
    period: int = 200,
) -> str:
    """
    Classify the broad market as bullish or bearish from a price history.

    Compares the most recent close against the simple moving average (SMA)
    of the last `period` closes. Intended for a broad index proxy such as
    SPY or QQQ.

    spy_ohlcv: price frame with a "close" column (preferred) or a "Close"
        column; the last row is treated as the most recent bar.
    period: SMA window length in rows (default 200).

    Returns "BULLISH" when the latest close is above the SMA, "BEARISH" when
    it is at or below it, and "UNKNOWN" when there are fewer than `period`
    rows or the SMA cannot be computed (missing closes inside the window).
    """
    if spy_ohlcv.empty or len(spy_ohlcv) < period:
        return "UNKNOWN"

    close_col = "close" if "close" in spy_ohlcv.columns else "Close"
    closes = spy_ohlcv[close_col]

    sma = closes.rolling(window=period).mean().iloc[-1]

    # A missing close anywhere in the window (which includes the latest bar)
    # makes the rolling mean NaN. Any comparison against NaN is False, so
    # without this guard bad data would be silently reported as BEARISH.
    if math.isnan(sma):
        return "UNKNOWN"

    current = closes.iloc[-1]
    return "BULLISH" if current > sma else "BEARISH"
def calculate_position_size(
    portfolio_value: float,
    price: float,
    risk_pct: float = 0.02,
    max_position_pct: float = 0.10,
) -> int:
    """
    Calculate number of shares to buy.

    portfolio_value: total portfolio value used as the sizing base
    price: price per share
    risk_pct: fraction of portfolio to risk per trade (default 2%)
    max_position_pct: cap single position at X% of portfolio

    Returns 0 when `price` or `portfolio_value` is not positive, or when the
    cap cannot cover even one share. Otherwise returns at least 1 share and
    never more than the cap. The result is never negative.
    """
    if price <= 0 or portfolio_value <= 0:
        return 0

    # Shares affordable within the per-trade risk budget.
    risk_budget = portfolio_value * risk_pct
    shares_by_risk = math.floor(risk_budget / price)

    # Shares affordable within the single-position cap.
    max_budget = portfolio_value * max_position_pct
    max_shares = math.floor(max_budget / price)

    if max_shares <= 0:
        # The cap does not cover a single share (or is zero/negative): buy
        # nothing rather than returning a negative count.
        shares = 0
    else:
        # Take the tighter of the two limits, but round a too-small risk
        # budget up to one share since the cap says we can afford it.
        shares = max(1, min(shares_by_risk, max_shares))

    logger.debug(
        "Position size: portfolio=%.0f price=%.2f risk_pct=%.2f max_pos_pct=%.2f → %d shares",
        portfolio_value, price, risk_pct, max_position_pct, shares,
    )
    return shares
def check_stop_loss(
    entry_price: float,
    current_price: float,
    threshold: float = 0.05,
) -> bool:
    """
    Report whether a long position has hit its stop-loss.

    The loss is measured as a fraction of the entry price; the stop triggers
    when that fraction is at least `threshold`. A non-positive entry price
    never triggers.
    """
    if entry_price <= 0:
        return False

    drop = entry_price - current_price
    return drop / entry_price >= threshold
def check_max_drawdown(
    portfolio_values: list[float],
    max_dd: float = 0.15,
) -> bool:
    """
    Report whether the latest portfolio value breaches the drawdown limit.

    Expects a time-ordered list of portfolio values. The drawdown is the
    fractional fall from the highest value in the list to the last value;
    the check trips when it is at least `max_dd`. Fewer than two values, or
    a peak of exactly zero, never trips.
    """
    if len(portfolio_values) < 2:
        return False

    high_water = max(portfolio_values)
    if high_water == 0:
        # Avoid dividing by zero; no meaningful drawdown can be computed.
        return False

    latest = portfolio_values[-1]
    return (high_water - latest) / high_water >= max_dd
def validate_buy(
    symbol: str,
    price: float,
    qty: int,
    cash: float,
    positions: dict,
    max_positions: int = 10,
) -> tuple[bool, str]:
    """
    Check if a BUY order is valid.

    symbol: ticker being bought
    price: price per share
    qty: number of shares to buy
    cash: cash available to spend
    positions: currently held positions, keyed by symbol
    max_positions: maximum number of distinct symbols that may be held

    Returns (True, "OK") when the order can go ahead, otherwise
    (False, reason) with a human-readable explanation.
    """
    # A zero/negative quantity or price makes the cost non-positive, which
    # would slip past the cash check below, so reject those up front.
    if qty <= 0:
        return False, f"Invalid quantity: {qty}"
    if price <= 0:
        return False, f"Invalid price: {price}"

    cost = price * qty
    if cash < cost:
        return False, f"Insufficient cash: need ${cost:.2f}, have ${cash:.2f}"

    # Adding to a symbol that is already held does not use a new slot.
    if len(positions) >= max_positions and symbol not in positions:
        return False, f"Max positions ({max_positions}) reached"

    return True, "OK"
def validate_sell(
    symbol: str,
    qty: int,
    positions: dict,
) -> tuple[bool, str]:
    """
    Check if a SELL order is valid.

    symbol: ticker being sold
    qty: number of shares to sell
    positions: currently held positions, keyed by symbol; each value is
        either a dict with a "qty" key or an object with a `qty` attribute
        (a missing quantity is treated as 0)

    Returns (True, "OK") when the order can go ahead, otherwise
    (False, reason) with a human-readable explanation.
    """
    pos = positions.get(symbol)
    if not pos:
        return False, f"No position in {symbol}"

    # A zero/negative quantity would always pass the holdings check below,
    # so reject it explicitly.
    if qty <= 0:
        return False, f"Invalid quantity: {qty}"

    held = pos.get("qty", 0) if isinstance(pos, dict) else getattr(pos, "qty", 0)
    if held < qty:
        return False, f"Hold {held} shares, cannot sell {qty}"

    return True, "OK"