Spaces:
Sleeping
Sleeping
File size: 5,139 Bytes
4be2d4d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 | """
Data integration and preprocessing module.
"""
import pandas as pd
import yfinance as yf
from .technical_indicators import add_technical_indicators
from .optimization import run_technical_optimization
from .financial_data import process_financial_data
from .economic_data import get_economic_data
from .hierarchical_embedding import get_industry_data, add_industry_encoding, combine_stocks_for_embedding
from src.config import FRED_API_KEY
def process_stock_data(tickers, start_date, end_date, fred_api_key=None,
                       *, max_common_features=20):
    """
    Run the full stock-data processing pipeline.

    Steps, in order: optimize technical-indicator parameters, download price
    data per ticker and attach the indicators, collect per-ticker financial
    data and track the feature columns shared by every ticker that has any,
    fetch economic data, join everything per ticker, then build the combined
    frame and add the industry encoding.

    Args:
        tickers: Collection of ticker symbols. It is passed to ``len()`` and
            iterated several times, so it must be sized and re-iterable
            (e.g. a list).
        start_date: Start of the period; forwarded unchanged to
            ``run_technical_optimization``, ``yf.download`` and
            ``get_economic_data``.
        end_date: End of the period; forwarded to the same helpers and to
            ``process_financial_data``.
        fred_api_key: Key handed to ``get_economic_data``. When ``None`` the
            module-level ``FRED_API_KEY`` from ``src.config`` is used.
        max_common_features: Upper bound on how many shared financial feature
            columns are joined into each ticker's frame (default 20, the
            previously hard-coded value). ``None`` keeps all of them.

    Returns:
        A ``(final_data, all_data, industry_encoders)`` tuple: the two values
        returned by ``add_industry_encoding`` around the per-ticker dict
        ``all_data`` built here.
    """
    # Fall back to the configured API key when none is supplied.
    if fred_api_key is None:
        fred_api_key = FRED_API_KEY

    print(f"์ฃผ์ ๋ฐ์ดํฐ ์ฒ๋ฆฌ ์์: {len(tickers)}๊ฐ ์ข๋ชฉ")

    # 1. Optimize the technical-indicator parameters.
    optimal_params = run_technical_optimization(tickers, start_date, end_date)

    # 2. Build the per-ticker dataset with the optimized parameters.
    # The same optimal_params dict is applied to every ticker.
    all_data = {}
    for ticker in tickers:
        # NOTE(review): an empty download is not special-cased here; confirm
        # how add_technical_indicators behaves on an empty frame.
        df = yf.download(ticker, start=start_date, end=end_date, auto_adjust=True)
        if isinstance(df.columns, pd.MultiIndex):
            # Flatten the two-level column index by dropping its second level.
            df.columns = df.columns.droplevel(1)

        # Add technical indicators using the optimized parameters.
        df_with_indicators = add_technical_indicators(
            df.copy(),
            ema_params=optimal_params['ema'],
            macd_params=optimal_params['macd'],
            cmf_period=optimal_params['cmf'],
            rsi_params=optimal_params['rsi'],
        )

        # Drop rows with missing values.
        df_with_indicators = df_with_indicators.dropna()
        all_data[ticker] = df_with_indicators
        print(f"{ticker} ๋ฐ์ดํฐ ์ฒ๋ฆฌ ์๋ฃ: {len(df_with_indicators)}ํ")

    # 3. Process and collect financial-statement data.
    print("\n===== ์ฌ๋ฌด์ ํ ๋ฐ์ดํฐ ์ฒ๋ฆฌ ์์ =====\n")
    financial_data_all = {}
    common_features = None

    # Collect financial data for every ticker.
    for ticker in tickers:
        if ticker not in all_data:
            continue

        financial_data = process_financial_data(ticker, all_data, end_date)
        if financial_data is None or financial_data.empty:
            continue

        financial_features = [col for col in financial_data.columns if col != 'Close']
        if not financial_features:
            continue

        financial_data_all[ticker] = financial_data

        # Track the features shared by all tickers seen so far.
        if common_features is None:
            common_features = set(financial_features)
        else:
            common_features = common_features.intersection(financial_features)

    # 4. Fetch economic-indicator data.
    print("\n===== ๊ฒฝ์ ์งํ ๋ฐ์ดํฐ ์ฒ๋ฆฌ ์์ =====\n")
    econ_df = get_economic_data(start_date, end_date, fred_api_key)

    # 5. Final integration (technical indicators + financial statements
    #    + economic indicators).
    print("\n===== ์ต์ข๋ฐ์ดํฐ ํตํฉ =====\n")

    # Select the common financial features.
    if common_features:
        # Sort before slicing: a set of strings has no stable iteration order
        # across interpreter runs, so an unsorted slice would pick a different
        # subset of columns from one run to the next.
        selected_common_features = sorted(common_features)[:max_common_features]
        print(f"๋ชจ๋ ์ข๋ชฉ์ ๊ณตํต์ธ ์ฌ๋ฌด ํน์ฑ ์ค {len(selected_common_features)}๊ฐ ์ ํ")
    else:
        selected_common_features = []
        print("๊ณตํต ์ฌ๋ฌด ํน์ฑ์ด ์์ต๋๋ค.")

    # Integrate the data for each ticker.
    for ticker in tickers:
        if ticker not in all_data:
            continue
        try:
            # Base data (already includes the technical indicators).
            stock_data = all_data[ticker].copy()

            # Add financial-statement data (selected common features only).
            if ticker in financial_data_all and selected_common_features:
                fin_data = financial_data_all[ticker][selected_common_features]
                stock_data = stock_data.join(fin_data, how='left')

            # Add economic-indicator data.
            stock_data = stock_data.join(econ_df, how='left')

            # Fill gaps in every column except 'Close' by linear
            # interpolation, then drop whatever rows still contain NaN.
            for col in stock_data.columns:
                if col != 'Close' and stock_data[col].isna().any():
                    stock_data[col] = stock_data[col].interpolate(method='linear')
            stock_data = stock_data.dropna()

            # Store the final per-ticker frame.
            all_data[ticker] = stock_data
            print(f"{ticker} ๋ฐ์ดํฐ ํตํฉ ์๋ฃ: {stock_data.shape[1]}๊ฐ ํน์ฑ")
        except Exception as e:
            # Best effort: a failed ticker keeps its pre-integration frame in
            # all_data and the pipeline continues with the remaining tickers.
            print(f"{ticker} ๋ฐ์ดํฐ ํตํฉ ์คํจ: {e}")

    # 6. Fetch industry information.
    industry_df = get_industry_data(tickers)

    # 7. Combine the per-ticker data for stock embedding.
    combined_data = combine_stocks_for_embedding(all_data, tickers)

    # 8. Add the industry encoding.
    final_data, industry_encoders = add_industry_encoding(combined_data, industry_df)

    return final_data, all_data, industry_encoders