|
|
""" |
|
|
Test cases for technical analysis module. |
|
|
Tests edge cases for technical indicators. |
|
|
""" |
|
|
|
|
|
import unittest |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
from datetime import datetime, timedelta |
|
|
from src.technical_analysis import ( |
|
|
calculate_sma, |
|
|
calculate_ema, |
|
|
calculate_rsi, |
|
|
calculate_macd, |
|
|
calculate_bollinger_bands |
|
|
) |
|
|
|
|
|
|
|
|
class TestTechnicalAnalysis(unittest.TestCase):
    """Sanity and edge-case tests for the technical analysis indicators.

    Covers ``calculate_sma``, ``calculate_ema``, ``calculate_rsi``,
    ``calculate_macd`` and ``calculate_bollinger_bands`` with normal,
    short, empty and extreme inputs.
    """

    def setUp(self):
        """Build a reproducible 100-day price series with an upward trend."""
        dates = pd.date_range(start='2024-01-01', periods=100, freq='D')

        base_price = 100
        trend = np.linspace(0, 20, 100)
        # A locally seeded generator keeps the fixture identical on every
        # run (so failures are reproducible) without touching NumPy's
        # global random state.
        rng = np.random.default_rng(42)
        noise = rng.normal(0, 2, 100)
        prices = base_price + trend + noise
        self.prices = pd.Series(prices, index=dates)

    # ------------------------------------------------------------------
    # Simple moving average
    # ------------------------------------------------------------------

    def test_calculate_sma_normal(self):
        """SMA(20) keeps the input length; only the first 19 values are NaN."""
        sma = calculate_sma(self.prices, 20)
        self.assertEqual(len(sma), len(self.prices))
        self.assertFalse(sma.iloc[:19].notna().any())
        self.assertTrue(sma.iloc[19:].notna().all())

    def test_calculate_sma_short_data(self):
        """SMA is expected to be all-NaN when the series is shorter than the window."""
        short_prices = self.prices.iloc[:10]
        sma = calculate_sma(short_prices, 20)
        self.assertTrue(sma.isna().all())

    def test_calculate_sma_single_value(self):
        """SMA with a window of 1 over a single point returns that point."""
        single_price = pd.Series([100])
        sma = calculate_sma(single_price, 1)
        self.assertEqual(sma.iloc[0], 100)

    def test_calculate_sma_empty(self):
        """SMA of an empty series is expected to be empty."""
        empty_prices = pd.Series([], dtype=float)
        sma = calculate_sma(empty_prices, 20)
        self.assertEqual(len(sma), 0)

    # ------------------------------------------------------------------
    # Exponential moving average
    # ------------------------------------------------------------------

    def test_calculate_ema_normal(self):
        """EMA(12) keeps the input length and yields at least one value."""
        ema = calculate_ema(self.prices, 12)
        self.assertEqual(len(ema), len(self.prices))
        self.assertTrue(ema.notna().any())

    def test_calculate_ema_empty(self):
        """EMA of an empty series is expected to be empty."""
        empty_prices = pd.Series([], dtype=float)
        ema = calculate_ema(empty_prices, 12)
        self.assertEqual(len(ema), 0)

    # ------------------------------------------------------------------
    # Relative strength index
    # ------------------------------------------------------------------

    def test_calculate_rsi_normal(self):
        """RSI(14) is a float within the closed range [0, 100]."""
        rsi = calculate_rsi(self.prices, 14)
        self.assertIsInstance(rsi, float)
        self.assertGreaterEqual(rsi, 0)
        self.assertLessEqual(rsi, 100)

    def test_calculate_rsi_insufficient_data(self):
        """RSI is expected to fall back to 50.0 when there is too little data."""
        short_prices = self.prices.iloc[:10]
        rsi = calculate_rsi(short_prices, 14)
        self.assertEqual(rsi, 50.0)

    def test_calculate_rsi_all_gains(self):
        """RSI over strictly increasing prices should be above 50."""
        increasing_prices = pd.Series([100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115])
        rsi = calculate_rsi(increasing_prices, 14)
        self.assertGreater(rsi, 50)

    def test_calculate_rsi_all_losses(self):
        """RSI over strictly decreasing prices should be below 50."""
        decreasing_prices = pd.Series([115, 114, 113, 112, 111, 110, 109, 108, 107, 106, 105, 104, 103, 102, 101, 100])
        rsi = calculate_rsi(decreasing_prices, 14)
        self.assertLess(rsi, 50)

    def test_calculate_rsi_empty(self):
        """RSI of an empty series is expected to fall back to 50.0."""
        empty_prices = pd.Series([], dtype=float)
        rsi = calculate_rsi(empty_prices, 14)
        self.assertEqual(rsi, 50.0)

    # ------------------------------------------------------------------
    # MACD
    # ------------------------------------------------------------------

    def test_calculate_macd_normal(self):
        """MACD result exposes all four keys and a directional trend."""
        macd = calculate_macd(self.prices)
        self.assertIn("macd", macd)
        self.assertIn("signal", macd)
        self.assertIn("histogram", macd)
        self.assertIn("trend", macd)
        self.assertIn(macd["trend"], ["bullish", "bearish"])

    def test_calculate_macd_insufficient_data(self):
        """MACD on 20 points is expected to be neutral with a zero MACD value."""
        short_prices = self.prices.iloc[:20]
        macd = calculate_macd(short_prices)
        self.assertEqual(macd["trend"], "neutral")
        self.assertEqual(macd["macd"], 0.0)

    def test_calculate_macd_empty(self):
        """MACD of an empty series is expected to report a neutral trend."""
        empty_prices = pd.Series([], dtype=float)
        macd = calculate_macd(empty_prices)
        self.assertEqual(macd["trend"], "neutral")

    # ------------------------------------------------------------------
    # Bollinger Bands
    # ------------------------------------------------------------------

    def test_calculate_bollinger_bands_normal(self):
        """Bands expose all keys and satisfy lower < middle < upper."""
        bands = calculate_bollinger_bands(self.prices, 20)
        self.assertIn("upper", bands)
        self.assertIn("middle", bands)
        self.assertIn("lower", bands)
        self.assertIn("current_price", bands)
        self.assertIn("position", bands)
        self.assertGreater(bands["upper"], bands["middle"])
        self.assertLess(bands["lower"], bands["middle"])

    def test_calculate_bollinger_bands_insufficient_data(self):
        """With too little data the bands are expected to collapse onto the middle."""
        short_prices = self.prices.iloc[:10]
        bands = calculate_bollinger_bands(short_prices, 20)
        self.assertEqual(bands["upper"], bands["middle"])
        self.assertEqual(bands["lower"], bands["middle"])
        self.assertEqual(bands["position"], "neutral")

    def test_calculate_bollinger_bands_empty(self):
        """Empty input is expected to give a neutral position and zero price."""
        empty_prices = pd.Series([], dtype=float)
        bands = calculate_bollinger_bands(empty_prices, 20)
        self.assertEqual(bands["position"], "neutral")
        self.assertEqual(bands["current_price"], 0)

    def test_calculate_bollinger_bands_overbought(self):
        """A close above the upper band must be reported as overbought.

        A steady ramp never leaves the bands, so the fixture is a flat
        series ending in one upward spike. Within the last 20-point window
        that spike sits more than four standard deviations above the mean.
        NOTE(review): this assumes the band width is a small multiple of
        the rolling standard deviation (commonly 2) - confirm against
        ``calculate_bollinger_bands``.
        """
        high_prices = pd.Series([100.0] * 29 + [150.0])
        bands = calculate_bollinger_bands(high_prices, 20)

        # Guard against a vacuous pass: the fixture has to breach the band.
        self.assertGreater(
            bands["current_price"],
            bands["upper"],
            "fixture should close above the upper band",
        )
        self.assertEqual(bands["position"], "overbought")

    def test_calculate_bollinger_bands_oversold(self):
        """A close below the lower band must be reported as oversold.

        Mirror image of the overbought case: a flat series ending in one
        downward spike, more than four standard deviations below the mean
        of the last 20-point window.
        NOTE(review): this assumes the band width is a small multiple of
        the rolling standard deviation (commonly 2) - confirm against
        ``calculate_bollinger_bands``.
        """
        low_prices = pd.Series([100.0] * 29 + [50.0])
        bands = calculate_bollinger_bands(low_prices, 20)

        # Guard against a vacuous pass: the fixture has to breach the band.
        self.assertLess(
            bands["current_price"],
            bands["lower"],
            "fixture should close below the lower band",
        )
        self.assertEqual(bands["position"], "oversold")
|
|
|
|
|
|
|
|
# Run the whole suite through unittest's command-line runner when this file
# is executed directly; importing the module does not trigger a test run.
if __name__ == "__main__":
    unittest.main()
|
|
|
|
|
|