# Snapshot metadata: commit 4be2d4d, file size 5,108 bytes.
"""
주식 예측 모델 학습 및 최적화 스크립트
(Stock prediction model training and optimization script.)
"""
import sys
import os
import argparse
import numpy as np
import json
import pickle
import pandas as pd
from pathlib import Path
from sklearn.preprocessing import LabelEncoder
# Add the parent of this script's directory (presumably the project root that
# contains `src/`) to sys.path so the `src.optimization.*` imports below resolve.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.optimization.grid_search import run_optimization_pipeline
from src.optimization.utils import get_project_root
def _read_json_or_empty(path):
    """Return the parsed JSON document at *path*, or ``{}`` if the file does not exist."""
    if not path.exists():
        return {}
    # Explicit UTF-8 (the JSON standard encoding) so loading does not depend on
    # the platform's locale encoding.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def load_processed_data(tickers):
    """
    Load the preprocessed dataset together with its encoder info and metadata.

    File names are derived from *tickers* with commas replaced by underscores,
    e.g. ``"NFLX,TSLA"`` -> ``<root>/data/processed/NFLX_TSLA_processed.pkl``.

    Args:
        tickers: Comma-separated ticker string as given on the command line.

    Returns:
        Tuple ``(processed_data, encoder_info, metadata)``:
        - ``processed_data``: the unpickled object; if ``<root>/data/<prefix>_data.csv``
          exists, its DataFrame is stored under the ``'data'`` key.
        - ``encoder_info`` / ``metadata``: parsed JSON files, or ``{}`` if missing.
        Returns ``(None, None, None)`` when the processed pickle file is missing.
    """
    # str.replace is a no-op when there is no comma, so no conditional is needed.
    tickers_path = tickers.replace(',', '_')
    data_dir = get_project_root() / "data"
    processed_dir = data_dir / "processed"
    processed_path = processed_dir / f"{tickers_path}_processed.pkl"
    encoder_path = processed_dir / f"{tickers_path}_encoder_info.json"
    metadata_path = processed_dir / f"{tickers_path}_metadata.json"

    if not processed_path.exists():
        print(f"오류: 전처리된 데이터 파일이 없습니다: {processed_path}")
        return None, None, None

    # NOTE(review): pickle.load can execute arbitrary code; this assumes the file
    # was produced locally by the project's own preprocessing step.
    with open(processed_path, 'rb') as f:
        processed_data = pickle.load(f)

    # Attach the raw (unprocessed) CSV when it is available.
    raw_data_path = data_dir / f"{tickers_path}_data.csv"
    if raw_data_path.exists():
        processed_data['data'] = pd.read_csv(raw_data_path)

    encoder_info = _read_json_or_empty(encoder_path)
    metadata = _read_json_or_empty(metadata_path)

    # Print only a short summary of the dataset.
    print(f"\n데이터셋: {', '.join(metadata.get('tickers', []))}")
    print(f"기간: {metadata.get('start_date', '')} ~ {metadata.get('end_date', '')}")
    print(f"특성 수: {metadata.get('feature_count', '')}, 윈도우: {metadata.get('window_size', '')}")
    return processed_data, encoder_info, metadata
def _build_ticker_encoder(encoder_info, metadata, tickers_arg):
    """
    Reconstruct the ticker LabelEncoder used by the optimization pipeline.

    Prefers the saved ``encoder_info['ticker_encoder']`` mapping (index -> ticker,
    with string keys as JSON requires). Otherwise fits on ``metadata['tickers']``,
    or, if that is missing or empty, on the comma-separated *tickers_arg*.
    """
    encoder = LabelEncoder()
    mapping = encoder_info.get('ticker_encoder')
    if mapping:
        # JSON object keys are strings; sort them numerically to recover index order.
        ordered = [ticker for _, ticker in sorted((int(idx), ticker) for idx, ticker in mapping.items())]
        # Restore classes_ directly instead of calling fit(): fit() re-sorts labels
        # alphabetically, which would silently renumber tickers if the saved mapping
        # was not in sorted order. An object-dtype array is used because sklearn
        # encodes object labels via a value lookup, so classes_ need not be sorted.
        encoder.classes_ = np.array(ordered, dtype=object)
        return encoder
    fallback = [t.strip() for t in tickers_arg.split(',') if t.strip()]
    encoder.fit(metadata.get('tickers') or fallback)
    return encoder


def main():
    """
    CLI entry point: load preprocessed data, run the grid-search optimization
    pipeline, and print the best configuration found.

    Returns:
        Whatever ``run_optimization_pipeline`` returns, or ``None`` when the
        preprocessed data could not be loaded.
    """
    parser = argparse.ArgumentParser(description="주식 예측 모델 최적화 및 평가")
    parser.add_argument('--tickers', type=str, default='NFLX,TSLA,NVDA,AMD,INTC',
                        help='대상 종목 (콤마로 구분)')
    # Saving and visualization are on by default; --no-save / --no-visualize
    # turn them off (a store_true flag alone can never disable a True default).
    parser.add_argument('--save', dest='save', action='store_true', default=True,
                        help='최적 모델 저장 여부')
    parser.add_argument('--no-save', dest='save', action='store_false',
                        help='최적 모델을 저장하지 않음')
    parser.add_argument('--metric', type=str, default='combined_score',
                        choices=['combined_score', 'avg_ticker_sharpe', 'sharpe_ratio', 'total_return'],
                        help='최적화 기준 지표')
    parser.add_argument('--output', type=str, default='grid_search_results.pkl',
                        help='결과 파일 경로')
    parser.add_argument('--model_output', type=str, default='best_contime',
                        help='TensorFlow Lite 모델 저장 경로')
    parser.add_argument('--visualize', dest='visualize', action='store_true', default=True,
                        help='시각화 실행 여부')
    parser.add_argument('--no-visualize', dest='visualize', action='store_false',
                        help='시각화를 실행하지 않음')
    args = parser.parse_args()

    # Load the preprocessed data.
    processed_data, encoder_info, metadata = load_processed_data(args.tickers)
    if processed_data is None:
        print("전처리된 데이터를 로드할 수 없습니다.")
        return None

    ticker_encoder = _build_ticker_encoder(encoder_info, metadata, args.tickers)

    # Fix NumPy's global random seed for reproducible runs.
    np.random.seed(42)

    print("\n모델 최적화 시작...")
    print(f"최적화 기준: {args.metric}")

    # All result artifacts are written under <project root>/models/results.
    results_dir = Path(get_project_root()) / "models" / "results"
    results_dir.mkdir(parents=True, exist_ok=True)
    output_path = results_dir / args.output
    model_output = results_dir / args.model_output

    optimization_results = run_optimization_pipeline(
        data_dict=processed_data,
        ticker_encoder=ticker_encoder,
        metric=args.metric,
        output_path=output_path,
        save=args.save,
        model_output=model_output,
        sector_industry_df=processed_data.get('sector_industry_df'),
        run_visualizations=args.visualize,
    )
    print("\n최적화 완료!")

    # Summarize the best configuration; tolerate a pipeline that returns None.
    best_config = (optimization_results or {}).get('best_config', {})
    if best_config:
        print("\n최적 설정 요약:")
        for key, value in best_config.items():
            print(f" {key}: {value}")
    return optimization_results
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()