File size: 4,198 Bytes
9cb3002
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import sys
from pathlib import Path
import numpy as np

# Prepend the repository root (the parent of this file's directory) to
# sys.path so the project packages `agents` and `env` can be imported when
# this file is executed directly.
ROOT = Path(__file__).resolve().parents[1]
if not any(entry == str(ROOT) for entry in sys.path):
    sys.path.insert(0, str(ROOT))

from agents.trader import QuantTrader
from env.reward import normalize_reward, compute_raw_reward

OBS_SIZE = 14 + 5 + 5  # observation length: market + portfolio + risk features


def test_sl_tp_calculation_long():
    """Test SL/TP placement for long (buy) positions.

    Calls ``QuantTrader`` with a "bullish" research tuple and a risk tuple
    whose dict carries ``suggested_sl_ratio=0.02`` and ``raw_price=50000``.

    When the trader goes long (``direction == 1``), the following must hold:
    the stop-loss sits below entry, at ``price * (1 - sl_ratio)``, and the
    take-profit sits above entry.

    When the trader picks any other direction, the SL/TP assertions do not
    apply. In that case a skip notice is printed instead of a "passed"
    message, matching the short-position test.
    """
    trader = QuantTrader()
    obs = np.zeros(OBS_SIZE)

    res_data = ("bullish", 0.8, "Signal reasoning")
    fa_data = (0.7, "Sentiment reasoning")

    current_price = 50000.0
    sl_ratio = 0.02
    risk_data = (
        0.5,
        {"suggested_sl_ratio": sl_ratio, "raw_price": current_price},
        "Risk reasoning",
    )

    direction, size, sl, tp, reasoning = trader(obs, res_data, fa_data, risk_data)

    print(f"Long: Direction: {direction}, Size: {size}")
    print(f"Price: {current_price}, SL: {sl}, TP: {tp}")

    if direction != 1:
        # The SL/TP geometry below only applies to a long entry. Say so
        # explicitly rather than reporting a pass that checked nothing.
        print(f"Trader chose direction={direction} instead of long (1). Skipping SL/TP assertions.")
        return

    assert sl < current_price, f"Buy SL should be below entry, got SL={sl}, price={current_price}"
    assert tp > current_price, f"Buy TP should be above entry, got TP={tp}, price={current_price}"
    expected_sl = current_price * (1 - sl_ratio)
    assert abs(sl - expected_sl) < 1e-5, f"Expected SL {expected_sl}, got {sl}"
    print("Long SL/TP test passed!")


def test_sl_tp_calculation_short():
    """Test SL/TP placement for short positions.

    Calls ``QuantTrader`` with a "bearish" research tuple and a risk dict
    suggesting a 0.02 stop-loss ratio at an entry price of 50000.

    If the trader shorts (``direction == 2``), the following must hold: the
    stop-loss lies above entry, at ``price * (1 + ratio)``, and the
    take-profit lies below entry.

    If it does not short, a notice is printed and no assertions are made.
    """
    trader = QuantTrader()
    observation = np.zeros(OBS_SIZE)

    research = ("bearish", 0.9, "Strong bearish signal")
    sentiment = (0.1, "Bearish sentiment")

    entry = 50000.0
    stop_ratio = 0.02
    risk = (
        0.5,
        {"suggested_sl_ratio": stop_ratio, "raw_price": entry},
        "Risk reasoning",
    )

    direction, size, sl, tp, _reasoning = trader(observation, research, sentiment, risk)

    print(f"Short: Direction: {direction}, Size: {size}")
    print(f"Price: {entry}, SL: {sl}, TP: {tp}")

    # Guard clause: the checks below are only meaningful for a short entry.
    if direction != 2:
        print(f"Trader chose direction={direction} instead of short (2). Skipping SL/TP assertions.")
        return

    assert sl > entry, f"Short SL should be ABOVE entry, got SL={sl}, price={entry}"
    assert tp < entry, f"Short TP should be BELOW entry, got TP={tp}, price={entry}"
    target_sl = entry * (1 + stop_ratio)
    assert abs(sl - target_sl) < 1e-5, f"Expected SL {target_sl}, got {sl}"
    print("Short SL/TP test passed!")


def test_reward_normalization():
    """Check ``normalize_reward`` on a few representative inputs.

    The function must preserve the sign of small inputs and keep large
    inputs of either sign within ``[-1, 1]``. It must also map zero to
    (numerically) zero.
    """
    assert normalize_reward(1.0) > 0.0, "Positive reward should be > 0"
    assert normalize_reward(-1.0) < 0.0, "Negative reward should be < 0"

    # Extreme magnitudes on both sides must be capped to the unit interval.
    for extreme in (100.0, -100.0):
        assert -1.0 <= normalize_reward(extreme) <= 1.0, "Capping failed"

    assert abs(normalize_reward(0.0)) < 1e-10, "Zero input should give zero"
    print("Reward normalization test passed!")


def test_directional_reward():
    """Trades aligned with the price trend must out-score a misaligned one.

    All three scenarios share identical profit/risk inputs. Only the trade
    direction (1 = long, 2 = short) and the sign of ``price_trend`` vary.
    """
    common = {
        "profit": 0.001,
        "drawdown": 0.0,
        "volatility": 0.01,
        "sharpe": 0.5,
        "trade_count": 1,
    }

    # Long while the trend is up: direction agrees with the market.
    correct_long = compute_raw_reward(**common, direction=1, price_trend=0.01)
    # Short while the trend is down: direction agrees with the market.
    correct_short = compute_raw_reward(**common, direction=2, price_trend=-0.01)
    # Long while the trend is down: direction fights the market.
    wrong = compute_raw_reward(**common, direction=1, price_trend=-0.01)

    assert correct_long > wrong, (
        f"Correct long direction ({correct_long}) should score higher than wrong ({wrong})"
    )
    assert correct_short > wrong, (
        f"Correct short direction ({correct_short}) should score higher than wrong ({wrong})"
    )
    print(
        f"Directional rewards: correct_long={correct_long:.3f}, "
        f"correct_short={correct_short:.3f}, wrong={wrong:.3f}"
    )
    print("Directional reward test passed!")


if __name__ == "__main__":
    import traceback

    # Run every verification in order and stop at the first failure.
    # Report which check failed, together with its full traceback: the bare
    # str(exception) alone is often empty or ambiguous. Exit non-zero so
    # callers and CI see the failure.
    checks = (
        test_sl_tp_calculation_long,
        test_sl_tp_calculation_short,
        test_reward_normalization,
        test_directional_reward,
    )
    for check in checks:
        try:
            check()
        except Exception as exc:
            print(f"\nVerification failed in {check.__name__}: {exc!r}")
            traceback.print_exc()
            sys.exit(1)
    print("\nAll math verifications passed!")