# stocks/tests/test_backtester.py
# Author: Arrechenash — "Initial Commit" (b2a37ab)
"""Minimal tests for backtester CSV export functionality."""
from unittest.mock import patch
import pandas as pd
import pytest
from core.backtester import Backtester
@pytest.fixture
def sample_daily_df():
    """Provide six consecutive calendar days (2024-03-13 .. 2024-03-18) of daily OHLCV.

    Every field grows by a fixed step per row, so row ``i`` is easy to reason
    about in assertions: open = 100 + i, high = 105 + i, low = 98 + i,
    close = 103 + i, volume = 1_000_000 + 100_000 * i.
    """
    days = pd.date_range("2024-03-13", "2024-03-18", freq="D")
    steps = range(len(days))
    columns = {
        "timestamp": days,
        "open": [100.0 + i for i in steps],
        "high": [105.0 + i for i in steps],
        "low": [98.0 + i for i in steps],
        "close": [103.0 + i for i in steps],
        "volume": [1_000_000 + 100_000 * i for i in steps],
    }
    return pd.DataFrame(columns)
@pytest.fixture
def sample_minute_bars():
    """Provide flat 5-minute bars for one session: 2024-03-15, 09:30-16:00 America/New_York.

    Both endpoints are included. Every bar carries the same values:
    open 103.0, high 104.0, low 102.0, close 103.5, volume 10000.
    """
    bar_times = pd.date_range(
        "2024-03-15 09:30", "2024-03-15 16:00", freq="5min", tz="America/New_York"
    )
    n_bars = len(bar_times)
    constant_bar = {"open": 103.0, "high": 104.0, "low": 102.0, "close": 103.5, "volume": 10000}
    return pd.DataFrame(
        {column: [value] * n_bars for column, value in constant_bar.items()},
        index=bar_times,
    )
@pytest.fixture
def backtester():
    """Provide a fresh ``Backtester`` for each test, constructed with no arguments."""
    instance = Backtester()
    return instance
def mock_backtester_trade(backtester, daily_df, day_idx, extra_data=None, minute_bars=None, offset=0):
    """Run ``backtester._simulate_trade_impl`` for one day with minute data mocked out.

    ``_simulate_trade_impl`` needs minute data to produce a trade. This helper
    therefore patches ``backtester.data_manager.get_minute_data`` to return
    ``minute_bars`` for the duration of the call. The entry condition is always
    true and the exit condition is ``close > 0``, so the focus stays on the
    fields copied into the trade result (e.g. the OHLCV history).

    Args:
        backtester: Object under test; must expose ``data_manager`` and
            ``_simulate_trade_impl``.
        daily_df: Daily bars with a ``timestamp`` column.
        day_idx: Row of ``daily_df`` to simulate; also passed as ``original_idx``.
        extra_data: Per-trade metadata forwarded unchanged (defaults to ``{}``).
        minute_bars: Frame returned by the patched ``get_minute_data``. Defaults
            to three 5-minute bars starting 09:30 America/New_York on the
            selected day.
        offset: Forwarded to ``_simulate_trade_impl``. Per the tests below, it
            controls how many prior days of ``dN_*`` history are emitted.

    Returns:
        Whatever ``_simulate_trade_impl`` returns (indexed as a dict by the tests).
    """
    if extra_data is None:
        extra_data = {}
    date_str = daily_df["timestamp"].iloc[day_idx].strftime("%Y-%m-%d")
    if minute_bars is None:
        # Anchor the default bars to the selected day instead of a fixed date,
        # so the mocked minute data matches ``date_str`` for any ``day_idx``.
        minute_bars = pd.DataFrame(
            {
                "open": [103.0, 103.5, 104.0],
                "high": [104.0, 104.5, 105.0],
                "low": [102.0, 102.5, 103.0],
                "close": [103.5, 104.0, 104.5],
                "volume": [10000, 15000, 20000],
            },
            index=pd.date_range(f"{date_str} 09:30", periods=3, freq="5min", tz="America/New_York"),
        )
    # The patch is undone when the ``with`` block exits, even if the call raises.
    with patch.object(backtester.data_manager, "get_minute_data", return_value=minute_bars):
        return backtester._simulate_trade_impl(
            original_idx=day_idx,
            symbol="TEST",
            date_str=date_str,
            compiled_entry=compile("True", "<entry>", "eval"),
            compiled_exit=compile("close > 0", "<exit>", "eval"),
            safe_globals={"__builtins__": None},
            ctx_template={},
            extra_data=extra_data,
            daily_df=daily_df,
            day_idx=day_idx,
            offset=offset,
        )
class TestOHLCVExtraction:
    """Checks the per-day OHLCV history and metadata fields derived from ``daily_df``."""

    def test_d0_d1_d2_values_correct(self, backtester, sample_daily_df):
        """With offset=2, d0/d1/d2 hold the OHLCV of the selected day and the two days before it."""
        trade_data = mock_backtester_trade(backtester, sample_daily_df, day_idx=2, offset=2)
        # day_idx=2 is 2024-03-15; dN refers to the row N positions earlier.
        expected = {
            # d0 -> 2024-03-15 (row 2, the selected day)
            "d0_open": 102.0,
            "d0_high": 107.0,
            "d0_low": 100.0,
            "d0_close": 105.0,
            "d0_volume": 1200000,
            # d1 -> 2024-03-14 (row 1)
            "d1_open": 101.0,
            "d1_high": 106.0,
            "d1_low": 99.0,
            "d1_close": 104.0,
            "d1_volume": 1100000,
            # d2 -> 2024-03-13 (row 0)
            "d2_open": 100.0,
            "d2_high": 105.0,
            "d2_low": 98.0,
            "d2_close": 103.0,
            "d2_volume": 1000000,
        }
        for column, value in expected.items():
            assert trade_data[column] == value

    def test_only_d0_for_offset_zero(self, backtester, sample_daily_df):
        """With offset=0 only the d0_* columns are emitted."""
        trade_data = mock_backtester_trade(backtester, sample_daily_df, day_idx=2, offset=0)
        for field in ("open", "high", "low", "close", "volume"):
            assert f"d0_{field}" in trade_data
        for missing in ("d1_open", "d2_open"):
            assert missing not in trade_data

    def test_avg_vol10_calculated_correctly(self, backtester, sample_daily_df):
        """avg_vol10 averages the volumes of the days *before* the selected day."""
        trade_data = mock_backtester_trade(backtester, sample_daily_df, day_idx=2, offset=2)
        # Only rows 0 and 1 precede day_idx=2; the selected day's volume is excluded.
        expected_avg = (1_000_000 + 1_100_000) / 2
        assert trade_data["avg_vol10"] == expected_avg

    def test_prev_close_is_day_before_selected(self, backtester, sample_daily_df):
        """prev_close is the close of the day immediately before the selected day."""
        trade_data = mock_backtester_trade(backtester, sample_daily_df, day_idx=2, offset=2)
        # Row 1 (2024-03-14) closed at 104.0.
        assert trade_data["prev_close"] == 104.0

    def test_metadata_passed_through(self, backtester, sample_daily_df):
        """Metadata supplied via extra_data is copied into the trade result."""
        extra_data = {"market_cap": 1e9, "country": "US", "industry": "Tech", "sector": "Software"}
        trade_data = mock_backtester_trade(backtester, sample_daily_df, day_idx=2, extra_data=extra_data, offset=2)
        for key, value in extra_data.items():
            assert trade_data[key] == value

    def test_original_date_from_extra_data(self, backtester, sample_daily_df):
        """original_date is taken from extra_data."""
        extra_data = {"original_date": "2024-03-13"}
        trade_data = mock_backtester_trade(backtester, sample_daily_df, day_idx=2, extra_data=extra_data, offset=2)
        assert trade_data["original_date"] == extra_data["original_date"]
class TestCSVExportColumns:
    """Checks that the trade-result fields written to the CSV have sane values."""

    def test_trade_entry_exit_profit(self, backtester, sample_daily_df):
        """The symbol is set, entry/exit prices are positive, and profit/ret are present."""
        trade_data = mock_backtester_trade(backtester, sample_daily_df, day_idx=2, offset=2)
        assert trade_data["symbol"] == "TEST"
        for price_key in ("entry", "exit"):
            assert trade_data[price_key] > 0
        for result_key in ("profit", "ret"):
            assert result_key in trade_data

    def test_premarket_and_intraday(self, backtester, sample_daily_df):
        """Each premarket and intraday high is at least its matching low."""
        trade_data = mock_backtester_trade(backtester, sample_daily_df, day_idx=2, offset=2)
        # Pairs checked in order: (pm_high, pm_low), then (hod, lod).
        for high_key, low_key in (("pm_high", "pm_low"), ("hod", "lod")):
            assert trade_data[high_key] >= trade_data[low_key]