|
|
""" |
|
|
Tests for the YFinanceDataFetcher class in src/yfinance.py |
|
|
|
|
|
These tests verify the core functionality of the YFinanceDataFetcher class, including: |
|
|
1. Initialization and configuration |
|
|
2. Data fetching and caching |
|
|
3. Error handling |
|
|
4. Data format and structure |
|
|
|
|
|
The tests use mocking to avoid actual API calls and to provide consistent test data. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import time |
|
|
from unittest.mock import MagicMock, patch |
|
|
|
|
|
import pandas as pd |
|
|
import pytest |
|
|
|
|
|
|
|
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) |
|
|
|
|
|
from src.yfinance import YFinanceDataFetcher |
|
|
|
|
|
|
|
|
from tests.test_data.mock_stock_data import get_real_beta, get_real_data |
|
|
|
|
|
|
|
|
@pytest.fixture |
|
|
def mock_ticker(): |
|
|
"""Create a mock Ticker object for yfinance.""" |
|
|
mock = MagicMock() |
|
|
|
|
|
|
|
|
df = get_real_data("AAPL", "1y") |
|
|
mock.history.return_value = df |
|
|
|
|
|
return mock |
|
|
|
|
|
|
|
|
@pytest.fixture |
|
|
def mock_spy_ticker(): |
|
|
"""Create a mock Ticker object for SPY data.""" |
|
|
mock = MagicMock() |
|
|
|
|
|
|
|
|
df = get_real_data("SPY", "1y") |
|
|
mock.history.return_value = df |
|
|
|
|
|
return mock |
|
|
|
|
|
|
|
|
@pytest.fixture |
|
|
def mock_empty_ticker(): |
|
|
"""Create a mock Ticker object with no data.""" |
|
|
mock = MagicMock() |
|
|
mock.history.return_value = pd.DataFrame() |
|
|
return mock |
|
|
|
|
|
|
|
|
@pytest.fixture |
|
|
def temp_cache_dir(tmpdir): |
|
|
"""Create a temporary directory for cache files.""" |
|
|
cache_dir = tmpdir.mkdir("test_cache") |
|
|
return str(cache_dir) |
|
|
|
|
|
|
|
|
@pytest.fixture |
|
|
def sample_dataframe(): |
|
|
"""Create a sample DataFrame with the expected structure using real data.""" |
|
|
return get_real_data("AAPL", "1y").head(5) |
|
|
|
|
|
|
|
|
class TestYFinanceDataFetcherInitialization: |
|
|
"""Tests for YFinanceDataFetcher initialization and configuration.""" |
|
|
|
|
|
def test_init_with_default_cache_dir(self): |
|
|
"""Test initialization with default cache directory.""" |
|
|
fetcher = YFinanceDataFetcher() |
|
|
assert fetcher.cache_dir == ".cache_yf" |
|
|
assert fetcher.cache_ttl == 86400 |
|
|
|
|
|
def test_init_with_custom_cache_dir(self, temp_cache_dir): |
|
|
"""Test initialization with custom cache directory.""" |
|
|
fetcher = YFinanceDataFetcher(cache_dir=temp_cache_dir) |
|
|
assert fetcher.cache_dir == temp_cache_dir |
|
|
assert os.path.exists(temp_cache_dir) |
|
|
|
|
|
def test_init_with_custom_ttl(self): |
|
|
"""Test initialization with custom cache TTL.""" |
|
|
fetcher = YFinanceDataFetcher(cache_ttl=3600) |
|
|
assert fetcher.cache_ttl == 3600 |
|
|
|
|
|
|
|
|
class TestDataFetching: |
|
|
"""Tests for data fetching functionality.""" |
|
|
|
|
|
def test_fetch_data_api_call(self, mock_ticker, temp_cache_dir): |
|
|
"""Test fetching data from API.""" |
|
|
with patch("yfinance.Ticker", return_value=mock_ticker): |
|
|
fetcher = YFinanceDataFetcher(cache_dir=temp_cache_dir) |
|
|
df = fetcher.fetch_data("AAPL", period="1y") |
|
|
|
|
|
|
|
|
assert isinstance(df, pd.DataFrame) |
|
|
assert len(df) > 0 |
|
|
|
|
|
|
|
|
required_columns = ["Open", "High", "Low", "Close", "Volume"] |
|
|
for col in required_columns: |
|
|
assert col in df.columns, f"Column {col} not found in DataFrame" |
|
|
|
|
|
assert df.index.name == "date" |
|
|
assert pd.api.types.is_datetime64_dtype(df.index) |
|
|
|
|
|
def test_fetch_data_cache_creation(self, mock_ticker, temp_cache_dir): |
|
|
"""Test that data is cached after fetching.""" |
|
|
with patch("yfinance.Ticker", return_value=mock_ticker): |
|
|
fetcher = YFinanceDataFetcher(cache_dir=temp_cache_dir) |
|
|
fetcher.fetch_data("AAPL", period="1y") |
|
|
|
|
|
|
|
|
cache_file = os.path.join(temp_cache_dir, "AAPL_1y_1d.csv") |
|
|
assert os.path.exists(cache_file) |
|
|
|
|
|
def test_fetch_data_from_cache(self, mock_ticker, temp_cache_dir, sample_dataframe): |
|
|
"""Test fetching data from cache.""" |
|
|
|
|
|
cache_file = os.path.join(temp_cache_dir, "AAPL_1y_1d.csv") |
|
|
sample_dataframe.to_csv(cache_file) |
|
|
|
|
|
|
|
|
os.utime(cache_file, (time.time(), time.time())) |
|
|
|
|
|
with patch("yfinance.Ticker", return_value=mock_ticker) as mock_yf: |
|
|
fetcher = YFinanceDataFetcher(cache_dir=temp_cache_dir) |
|
|
df = fetcher.fetch_data("AAPL", period="1y") |
|
|
|
|
|
|
|
|
mock_yf.assert_not_called() |
|
|
|
|
|
|
|
|
pd.testing.assert_frame_equal(df, sample_dataframe) |
|
|
|
|
|
def test_fetch_data_expired_cache( |
|
|
self, mock_ticker, temp_cache_dir, sample_dataframe |
|
|
): |
|
|
"""Test fetching data with expired cache.""" |
|
|
|
|
|
cache_file = os.path.join(temp_cache_dir, "AAPL_1y_1d.csv") |
|
|
sample_dataframe.to_csv(cache_file) |
|
|
|
|
|
|
|
|
old_time = time.time() - 100000 |
|
|
os.utime(cache_file, (old_time, old_time)) |
|
|
|
|
|
with patch("yfinance.Ticker", return_value=mock_ticker) as mock_yf: |
|
|
fetcher = YFinanceDataFetcher(cache_dir=temp_cache_dir) |
|
|
fetcher.fetch_data("AAPL", period="1y") |
|
|
|
|
|
|
|
|
mock_yf.assert_called_once() |
|
|
|
|
|
def test_fetch_market_data(self, mock_spy_ticker, temp_cache_dir): |
|
|
"""Test fetching market data.""" |
|
|
with patch("yfinance.Ticker", return_value=mock_spy_ticker): |
|
|
fetcher = YFinanceDataFetcher(cache_dir=temp_cache_dir) |
|
|
|
|
|
|
|
|
df = fetcher.fetch_market_data(market_index="SPY", period="1y") |
|
|
|
|
|
|
|
|
assert isinstance(df, pd.DataFrame) |
|
|
assert len(df) > 0 |
|
|
|
|
|
|
|
|
required_columns = ["Open", "High", "Low", "Close", "Volume"] |
|
|
for col in required_columns: |
|
|
assert col in df.columns, f"Column {col} not found in DataFrame" |
|
|
|
|
|
|
|
|
with patch.object(YFinanceDataFetcher, "beta_period", "6m"): |
|
|
df_default = fetcher.fetch_market_data(market_index="SPY") |
|
|
assert isinstance(df_default, pd.DataFrame) |
|
|
assert len(df_default) > 0 |
|
|
|
|
|
|
|
|
class TestErrorHandling: |
|
|
"""Tests for error handling in YFinanceDataFetcher.""" |
|
|
|
|
|
def test_empty_data_response(self, mock_empty_ticker, temp_cache_dir): |
|
|
"""Test handling of empty data responses.""" |
|
|
with patch("yfinance.Ticker", return_value=mock_empty_ticker): |
|
|
fetcher = YFinanceDataFetcher(cache_dir=temp_cache_dir) |
|
|
with pytest.raises(ValueError, match="No historical data found"): |
|
|
fetcher.fetch_data("INVALID", period="1y") |
|
|
|
|
|
def test_network_error_with_fallback(self, temp_cache_dir, sample_dataframe): |
|
|
"""Test fallback to expired cache on network error.""" |
|
|
|
|
|
cache_file = os.path.join(temp_cache_dir, "AAPL_1y_1d.csv") |
|
|
sample_dataframe.to_csv(cache_file) |
|
|
|
|
|
|
|
|
old_time = time.time() - 100000 |
|
|
os.utime(cache_file, (old_time, old_time)) |
|
|
|
|
|
|
|
|
with patch("yfinance.Ticker", side_effect=Exception("Network error")): |
|
|
fetcher = YFinanceDataFetcher(cache_dir=temp_cache_dir) |
|
|
df = fetcher.fetch_data("AAPL", period="1y") |
|
|
|
|
|
|
|
|
pd.testing.assert_frame_equal(df, sample_dataframe) |
|
|
|
|
|
def test_network_error_without_fallback(self, temp_cache_dir): |
|
|
"""Test network error without cache fallback raises exception.""" |
|
|
|
|
|
with patch("yfinance.Ticker", side_effect=Exception("Network error")): |
|
|
fetcher = YFinanceDataFetcher(cache_dir=temp_cache_dir) |
|
|
with pytest.raises(ValueError): |
|
|
fetcher.fetch_data("AAPL", period="1y") |
|
|
|
|
|
|
|
|
class TestDataFormat: |
|
|
"""Tests for data format and structure.""" |
|
|
|
|
|
def test_date_parsing(self, mock_ticker, temp_cache_dir): |
|
|
"""Test that dates are properly parsed and set as index.""" |
|
|
with patch("yfinance.Ticker", return_value=mock_ticker): |
|
|
fetcher = YFinanceDataFetcher(cache_dir=temp_cache_dir) |
|
|
df = fetcher.fetch_data("AAPL", period="1y") |
|
|
|
|
|
|
|
|
assert pd.api.types.is_datetime64_dtype(df.index) |
|
|
assert df.index.name == "date" |
|
|
|
|
|
|
|
|
def test_column_renaming(self, mock_ticker, temp_cache_dir): |
|
|
"""Test that columns are properly renamed.""" |
|
|
with patch("yfinance.Ticker", return_value=mock_ticker): |
|
|
fetcher = YFinanceDataFetcher(cache_dir=temp_cache_dir) |
|
|
df = fetcher.fetch_data("AAPL", period="1y") |
|
|
|
|
|
|
|
|
required_columns = ["Open", "High", "Low", "Close", "Volume"] |
|
|
for col in required_columns: |
|
|
assert col in df.columns, f"Column {col} not found in DataFrame" |
|
|
|
|
|
def test_data_sorting(self, mock_ticker, temp_cache_dir): |
|
|
"""Test that data is sorted by date in ascending order.""" |
|
|
with patch("yfinance.Ticker", return_value=mock_ticker): |
|
|
fetcher = YFinanceDataFetcher(cache_dir=temp_cache_dir) |
|
|
df = fetcher.fetch_data("AAPL", period="1y") |
|
|
|
|
|
|
|
|
assert df.index.is_monotonic_increasing |
|
|
|
|
|
|
|
|
class TestPeriodHandling: |
|
|
"""Tests for period handling in YFinanceDataFetcher.""" |
|
|
|
|
|
def test_period_mapping(self): |
|
|
"""Test period mapping to yfinance format.""" |
|
|
fetcher = YFinanceDataFetcher() |
|
|
|
|
|
|
|
|
assert fetcher._map_period_to_yfinance("1y") == "1y" |
|
|
assert fetcher._map_period_to_yfinance("5y") == "5y" |
|
|
assert fetcher._map_period_to_yfinance("1d") == "1d" |
|
|
|
|
|
|
|
|
assert fetcher._map_period_to_yfinance("2y") == "2y" |
|
|
assert fetcher._map_period_to_yfinance("3y") == "5y" |
|
|
assert fetcher._map_period_to_yfinance("6m") == "6mo" |
|
|
assert fetcher._map_period_to_yfinance("3m") == "3mo" |
|
|
|
|
|
|
|
|
assert fetcher._map_period_to_yfinance("invalid") == "1y" |
|
|
|
|
|
|
|
|
class TestBetaCalculation: |
|
|
"""Tests for beta calculation using the YFinanceDataFetcher.""" |
|
|
|
|
|
def test_beta_calculation(self, mock_ticker, mock_spy_ticker, temp_cache_dir): |
|
|
"""Test beta calculation with mock data.""" |
|
|
with patch( |
|
|
"yfinance.Ticker", |
|
|
side_effect=lambda ticker: mock_spy_ticker |
|
|
if ticker == "SPY" |
|
|
else mock_ticker, |
|
|
): |
|
|
fetcher = YFinanceDataFetcher(cache_dir=temp_cache_dir) |
|
|
|
|
|
|
|
|
stock_data = fetcher.fetch_data("AAPL", period="1y") |
|
|
market_data = fetcher.fetch_market_data("SPY", period="1y") |
|
|
|
|
|
|
|
|
stock_returns = stock_data["Close"].pct_change().dropna() |
|
|
market_returns = market_data["Close"].pct_change().dropna() |
|
|
|
|
|
|
|
|
common_dates = stock_returns.index.intersection(market_returns.index) |
|
|
stock_returns = stock_returns.loc[common_dates] |
|
|
market_returns = market_returns.loc[common_dates] |
|
|
|
|
|
|
|
|
covariance = stock_returns.cov(market_returns) |
|
|
market_variance = market_returns.var() |
|
|
beta = covariance / market_variance |
|
|
|
|
|
|
|
|
get_real_beta("AAPL") |
|
|
|
|
|
|
|
|
|
|
|
assert 0.5 < beta < 2.0, f"Beta {beta} is outside reasonable range" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
pytest.main(["-v", "test_yfinance.py"]) |
|
|
|