# Source snapshot: revision 7169bc5 (7,887 bytes).
"""
Test cases for technical analysis module.
Tests edge cases for technical indicators.
"""
import unittest
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from src.technical_analysis import (
calculate_sma,
calculate_ema,
calculate_rsi,
calculate_macd,
calculate_bollinger_bands
)
class TestTechnicalAnalysis(unittest.TestCase):
    """Test cases for technical analysis functions.

    Covers normal data, insufficient data, empty input and a few
    directional scenarios for SMA, EMA, RSI, MACD and Bollinger Bands.
    """

    # Fixed seed so the synthetic price series -- and therefore every
    # assertion derived from it -- is identical on every run.
    RANDOM_SEED = 42

    def setUp(self):
        """Build a 100-day daily price series: linear uptrend plus noise."""
        dates = pd.date_range(start='2024-01-01', periods=100, freq='D')
        # A seeded generator replaces the global np.random state, which made
        # these tests non-deterministic and failures non-reproducible.
        rng = np.random.default_rng(self.RANDOM_SEED)
        base_price = 100
        trend = np.linspace(0, 20, 100)
        noise = rng.normal(0, 2, 100)
        self.prices = pd.Series(base_price + trend + noise, index=dates)

    # ========================================================================
    # SMA TESTS
    # ========================================================================

    def test_calculate_sma_normal(self):
        """Test SMA calculation with normal data."""
        sma = calculate_sma(self.prices, 20)
        self.assertEqual(len(sma), len(self.prices))
        self.assertFalse(sma.iloc[:19].notna().any())  # First 19 should be NaN
        self.assertTrue(sma.iloc[19:].notna().all())  # Rest should have values
        # By definition the final SMA value is the mean of the last 20 prices.
        self.assertAlmostEqual(sma.iloc[-1], self.prices.iloc[-20:].mean())

    def test_calculate_sma_short_data(self):
        """Test SMA with insufficient data."""
        short_prices = self.prices[:10]
        sma = calculate_sma(short_prices, 20)
        self.assertTrue(sma.isna().all())  # All should be NaN

    def test_calculate_sma_single_value(self):
        """Test SMA with single price point."""
        single_price = pd.Series([100])
        sma = calculate_sma(single_price, 1)
        self.assertEqual(sma.iloc[0], 100)

    def test_calculate_sma_empty(self):
        """Test SMA with empty series."""
        empty_prices = pd.Series([], dtype=float)
        sma = calculate_sma(empty_prices, 20)
        self.assertEqual(len(sma), 0)

    # ========================================================================
    # EMA TESTS
    # ========================================================================

    def test_calculate_ema_normal(self):
        """Test EMA calculation with normal data."""
        ema = calculate_ema(self.prices, 12)
        self.assertEqual(len(ema), len(self.prices))
        self.assertTrue(ema.notna().any())  # Should have some values

    def test_calculate_ema_empty(self):
        """Test EMA with empty series."""
        empty_prices = pd.Series([], dtype=float)
        ema = calculate_ema(empty_prices, 12)
        self.assertEqual(len(ema), 0)

    # ========================================================================
    # RSI TESTS
    # ========================================================================

    def test_calculate_rsi_normal(self):
        """Test RSI calculation with normal data."""
        rsi = calculate_rsi(self.prices, 14)
        self.assertIsInstance(rsi, float)
        self.assertGreaterEqual(rsi, 0)
        self.assertLessEqual(rsi, 100)

    def test_calculate_rsi_insufficient_data(self):
        """Test RSI with insufficient data."""
        short_prices = self.prices[:10]
        rsi = calculate_rsi(short_prices, 14)
        self.assertEqual(rsi, 50.0)  # Should return neutral

    def test_calculate_rsi_all_gains(self):
        """Test RSI with all positive changes (should be high)."""
        increasing_prices = pd.Series(list(range(100, 116)))  # 100..115
        rsi = calculate_rsi(increasing_prices, 14)
        self.assertGreater(rsi, 50)  # Should be above neutral

    def test_calculate_rsi_all_losses(self):
        """Test RSI with all negative changes (should be low)."""
        decreasing_prices = pd.Series(list(range(115, 99, -1)))  # 115..100
        rsi = calculate_rsi(decreasing_prices, 14)
        self.assertLess(rsi, 50)  # Should be below neutral

    def test_calculate_rsi_empty(self):
        """Test RSI with empty series."""
        empty_prices = pd.Series([], dtype=float)
        rsi = calculate_rsi(empty_prices, 14)
        self.assertEqual(rsi, 50.0)

    # ========================================================================
    # MACD TESTS
    # ========================================================================

    def test_calculate_macd_normal(self):
        """Test MACD calculation with normal data."""
        macd = calculate_macd(self.prices)
        self.assertIn("macd", macd)
        self.assertIn("signal", macd)
        self.assertIn("histogram", macd)
        self.assertIn("trend", macd)
        self.assertIn(macd["trend"], ["bullish", "bearish"])

    def test_calculate_macd_insufficient_data(self):
        """Test MACD with insufficient data."""
        short_prices = self.prices[:20]  # Need at least 26+9=35 for full MACD
        macd = calculate_macd(short_prices)
        self.assertEqual(macd["trend"], "neutral")
        self.assertEqual(macd["macd"], 0.0)

    def test_calculate_macd_empty(self):
        """Test MACD with empty series."""
        empty_prices = pd.Series([], dtype=float)
        macd = calculate_macd(empty_prices)
        self.assertEqual(macd["trend"], "neutral")

    # ========================================================================
    # BOLLINGER BANDS TESTS
    # ========================================================================

    def test_calculate_bollinger_bands_normal(self):
        """Test Bollinger Bands calculation with normal data."""
        bands = calculate_bollinger_bands(self.prices, 20)
        self.assertIn("upper", bands)
        self.assertIn("middle", bands)
        self.assertIn("lower", bands)
        self.assertIn("current_price", bands)
        self.assertIn("position", bands)
        self.assertGreater(bands["upper"], bands["middle"])
        self.assertLess(bands["lower"], bands["middle"])

    def test_calculate_bollinger_bands_insufficient_data(self):
        """Test Bollinger Bands with insufficient data."""
        short_prices = self.prices[:10]
        bands = calculate_bollinger_bands(short_prices, 20)
        self.assertEqual(bands["upper"], bands["middle"])
        self.assertEqual(bands["lower"], bands["middle"])
        self.assertEqual(bands["position"], "neutral")

    def test_calculate_bollinger_bands_empty(self):
        """Test Bollinger Bands with empty series."""
        empty_prices = pd.Series([], dtype=float)
        bands = calculate_bollinger_bands(empty_prices, 20)
        self.assertEqual(bands["position"], "neutral")
        self.assertEqual(bands["current_price"], 0)

    def test_calculate_bollinger_bands_overbought(self):
        """Test Bollinger Bands when price is above upper band."""
        # A steady linear ramp never leaves a 2-sigma band: its last price
        # sits only ~1.7 std above the window mean. A flat series ending in
        # one spike does. A single outlier in a 20-bar window has a z-score
        # of ~4.25 (sample std) or ~4.36 (population std), independent of
        # the spike size. NOTE(review): assumes the default band width is
        # below ~4 std (conventionally 2).
        spike_prices = pd.Series([100.0] * 29 + [130.0])
        bands = calculate_bollinger_bands(spike_prices, 20)
        # Assert the precondition instead of guarding on it, so the test
        # cannot silently pass without checking anything.
        self.assertGreater(bands["current_price"], bands["upper"])
        self.assertEqual(bands["position"], "overbought")

    def test_calculate_bollinger_bands_oversold(self):
        """Test Bollinger Bands when price is below lower band."""
        # Mirror image of the overbought case: flat prices, then one sharp
        # drop whose z-score (~-4.25) is far outside a 2-sigma lower band.
        drop_prices = pd.Series([100.0] * 29 + [70.0])
        bands = calculate_bollinger_bands(drop_prices, 20)
        self.assertLess(bands["current_price"], bands["lower"])
        self.assertEqual(bands["position"], "oversold")
# Allow running this module directly (`python test_technical_analysis.py`)
# in addition to discovery by a test runner.
if __name__ == '__main__':
    unittest.main()