File size: 2,139 Bytes
3bce488
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import pytest
import pandas as pd
import numpy as np
from src.processing.features import calculate_sma, calculate_rsi, calculate_macd, process_data
from src.processing.split import split_data

# Sample Data Fixture
@pytest.fixture
def sample_data():
    data = {
        'timestamp': pd.date_range(start='2023-01-01', periods=100),
        'close': np.random.rand(100) * 100
    }
    return pd.DataFrame(data)

def test_calculate_sma(sample_data):
    """Test Simple Moving Average calculation."""
    window = 20
    sma = calculate_sma(sample_data, window)
    assert len(sma) == 100
    assert sma.iloc[0:window-1].isna().all() # First window-1 should be NaN
    assert not sma.iloc[window:].isna().any()

def test_calculate_rsi(sample_data):
    """Test RSI calculation."""
    rsi = calculate_rsi(sample_data)
    assert len(rsi) == 100
    assert rsi.min() >= 0
    assert rsi.max() <= 100

def test_calculate_macd(sample_data):
    """Test MACD calculation."""
    macd, signal = calculate_macd(sample_data)
    assert len(macd) == 100
    assert len(signal) == 100
    assert not macd.isna().all()

def test_split_data(sample_data):
    """Test data splitting."""
    train, test = split_data(sample_data, test_size=0.2)
    assert len(train) == 80
    assert len(test) == 20
    # Ensure no overlap and correct order
    assert train['timestamp'].max() < test['timestamp'].min()

def test_process_data_structure(tmp_path):
    """Test process_data function output structure."""
    # Create a dummy CSV
    df = pd.DataFrame({
        'timestamp': pd.date_range(start='2023-01-01', periods=60),
        'close': [100 + i for i in range(60)] # Linear uptrend
    })
    input_file = tmp_path / "test_input.csv"
    df.to_csv(input_file, index=False)
    
    processed_df = process_data(str(input_file))
    
    expected_columns = ['sma_20', 'sma_50', 'rsi', 'macd', 'target_direction', 'target_price']
    for col in expected_columns:
        assert col in processed_df.columns
    
    # Check if NaNs from rolling windows are dropped
    # SMA_50 needs 50 points, so we expect some data loss
    assert len(processed_df) < 60