Spaces:
Sleeping
Sleeping
"""
Stock prediction model training and optimization script.
"""
| import sys | |
| import os | |
| import argparse | |
| import numpy as np | |
| import json | |
| import pickle | |
| import pandas as pd | |
| from pathlib import Path | |
| from sklearn.preprocessing import LabelEncoder | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from src.optimization.grid_search import run_optimization_pipeline | |
| from src.optimization.utils import get_project_root | |
def load_processed_data(tickers):
    """Load the preprocessed dataset together with its encoder info and metadata.

    All files are looked up under ``<project root>/data``; the file-name
    prefix is ``tickers`` with every comma replaced by an underscore.

    Args:
        tickers: Comma-separated ticker symbols, e.g. ``"NFLX,TSLA"``.

    Returns:
        A ``(processed_data, encoder_info, metadata)`` tuple.
        ``encoder_info`` and ``metadata`` are empty dicts when their JSON
        files do not exist. ``(None, None, None)`` is returned when the
        ``*_processed.pkl`` file is missing.

    NOTE(review): the message literals below look mis-encoded (UTF-8 text
    decoded with another code page). They are kept exactly as found; restore
    them from the original source if it is available.
    """
    # str.replace is already a no-op when the string contains no comma.
    tickers_path = tickers.replace(',', '_')

    # Wrap in Path (as main() does) so this keeps working even if
    # get_project_root() hands back a plain string.
    data_dir = Path(get_project_root()) / "data"
    processed_dir = data_dir / "processed"
    processed_path = processed_dir / f"{tickers_path}_processed.pkl"
    encoder_path = processed_dir / f"{tickers_path}_encoder_info.json"
    metadata_path = processed_dir / f"{tickers_path}_metadata.json"

    if not processed_path.exists():
        print(f"์ค๋ฅ: ์ ์ฒ๋ฆฌ๋ ๋ฐ์ดํฐ ํ์ผ์ด ์์ต๋๋ค: {processed_path}")
        return None, None, None

    # SECURITY: pickle.load executes arbitrary code embedded in the file.
    # Only load pickles produced by this project's own preprocessing step.
    with open(processed_path, 'rb') as f:
        processed_data = pickle.load(f)

    # Attach the raw CSV, when present, under the 'data' key.
    raw_data_path = data_dir / f"{tickers_path}_data.csv"
    if raw_data_path.exists():
        processed_data['data'] = pd.read_csv(raw_data_path)

    encoder_info = _load_json_if_exists(encoder_path)
    metadata = _load_json_if_exists(metadata_path)

    # Print a short summary of the dataset.
    print(f"\n๋ฐ์ดํฐ์ : {', '.join(metadata.get('tickers', []))}")
    print(f"๊ธฐ๊ฐ: {metadata.get('start_date', '')} ~ {metadata.get('end_date', '')}")
    print(f"ํน์ฑ ์: {metadata.get('feature_count', '')}, ์๋์ฐ: {metadata.get('window_size', '')}")

    return processed_data, encoder_info, metadata


def _load_json_if_exists(path):
    """Return the parsed JSON content of ``path``, or ``{}`` if it is absent."""
    if not path.exists():
        return {}
    # JSON is UTF-8 by specification; without an explicit encoding, open()
    # falls back to the locale's encoding and can mis-decode non-ASCII text.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)
def main():
    """Parse CLI arguments, load the preprocessed data and run the optimization.

    Returns:
        The value returned by ``run_optimization_pipeline``, or ``None`` when
        the preprocessed data could not be loaded.

    NOTE(review): the message and help literals look mis-encoded (UTF-8 text
    decoded with another code page). They are kept exactly as found; restore
    them from the original source if it is available.
    """
    args = _build_arg_parser().parse_args()

    # Load the preprocessed data; bail out early when it is missing.
    processed_data, encoder_info, metadata = load_processed_data(args.tickers)
    if processed_data is None:
        print("์ ์ฒ๋ฆฌ๋ ๋ฐ์ดํฐ๋ฅผ ๋ก๋ํ ์ ์์ต๋๋ค.")
        return

    ticker_encoder = _build_ticker_encoder(encoder_info, metadata, args.tickers)

    # Fix NumPy's global random seed before the pipeline runs.
    np.random.seed(42)

    print("\n๋ชจ๋ธ ์ต์ ํ ์์...")
    print(f"์ต์ ํ ๊ธฐ์ค: {args.metric}")

    # Both the results file and the model output go under models/results.
    results_dir = Path(get_project_root()) / "models" / "results"
    results_dir.mkdir(parents=True, exist_ok=True)
    output_path = results_dir / args.output
    model_output = results_dir / args.model_output

    optimization_results = run_optimization_pipeline(
        data_dict=processed_data,
        ticker_encoder=ticker_encoder,
        metric=args.metric,
        output_path=output_path,
        save=args.save,
        model_output=model_output,
        sector_industry_df=processed_data.get('sector_industry_df'),
        run_visualizations=args.visualize,
    )
    print("\n์ต์ ํ ์๋ฃ!")

    # Summarize the best configuration, if the pipeline reported one.
    best_config = optimization_results.get('best_config', {})
    if best_config:
        print("\n์ต์ ์ค์ ์์ฝ:")
        for key, value in best_config.items():
            print(f" {key}: {value}")
    return optimization_results


def _build_arg_parser():
    """Create the command-line parser for this script."""
    parser = argparse.ArgumentParser(description="์ฃผ์ ์์ธก ๋ชจ๋ธ ์ต์ ํ ๋ฐ ํ๊ฐ")
    parser.add_argument('--tickers', type=str, default='NFLX,TSLA,NVDA,AMD,INTC',
                        help='๋์ ์ข ๋ชฉ (์ฝค๋ง๋ก ๊ตฌ๋ถ)')
    parser.add_argument('--save', action='store_true', default=True,
                        help='์ต์ ๋ชจ๋ธ ์ ์ฅ ์ฌ๋ถ')
    # --save is store_true with default=True, so by itself it can never be
    # switched off; --no-save writes False to the same destination.
    parser.add_argument('--no-save', dest='save', action='store_false',
                        help='Do not save the best model (saving is on by default).')
    parser.add_argument('--metric', type=str, default='combined_score',
                        choices=['combined_score', 'avg_ticker_sharpe', 'sharpe_ratio', 'total_return'],
                        help='์ต์ ํ ๊ธฐ์ค ์งํ')
    parser.add_argument('--output', type=str, default='grid_search_results.pkl',
                        help='๊ฒฐ๊ณผ ํ์ผ ๊ฒฝ๋ก')
    parser.add_argument('--model_output', type=str, default='best_contime',
                        help='TensorFlow Lite ๋ชจ๋ธ ์ ์ฅ ๊ฒฝ๋ก')
    parser.add_argument('--visualize', action='store_true', default=True,
                        help='์๊ฐํ ์คํ ์ฌ๋ถ')
    # Same reasoning as --no-save: provide a way to turn the default off.
    parser.add_argument('--no-visualize', dest='visualize', action='store_false',
                        help='Skip the visualization step (it runs by default).')
    return parser


def _build_ticker_encoder(encoder_info, metadata, tickers):
    """Fit a LabelEncoder on the ticker names known for this dataset.

    Tickers come from ``encoder_info['ticker_encoder']`` when present,
    otherwise from ``metadata['tickers']``, otherwise from splitting the
    comma-separated ``tickers`` string.
    """
    ticker_encoder = LabelEncoder()
    if 'ticker_encoder' in encoder_info:
        # The stored mapping is read as {index-as-string: ticker}; order the
        # tickers by their integer index.
        # NOTE(review): LabelEncoder.fit orders classes itself (sorted unique
        # values per the scikit-learn docs), so the fitted codes equal the
        # stored indices only if the stored mapping was alphabetical - confirm.
        ticker_mapping = encoder_info['ticker_encoder']
        indexed = sorted((int(i), ticker) for i, ticker in ticker_mapping.items())
        ticker_list = [ticker for _, ticker in indexed]
    else:
        ticker_list = metadata.get('tickers', tickers.split(','))
    ticker_encoder.fit(ticker_list)
    return ticker_encoder
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()