# johnaness's picture
# Deploy OStock FastAPI backend to HF Space (Docker SDK, port 7860)
# 4be2d4d
"""
์ฃผ์‹ ์˜ˆ์ธก ๋ชจ๋ธ ํ•™์Šต ๋ฐ ์ตœ์ ํ™” ์Šคํฌ๋ฆฝํŠธ
"""
import sys
import os
import argparse
import numpy as np
import json
import pickle
import pandas as pd
from pathlib import Path
from sklearn.preprocessing import LabelEncoder
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.optimization.grid_search import run_optimization_pipeline
from src.optimization.utils import get_project_root
def _load_json_or_empty(path):
    """Return the parsed JSON object at *path*, or ``{}`` if the file is absent."""
    if not path.exists():
        return {}
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def load_processed_data(tickers):
    """Load the preprocessed dataset together with its encoder info and metadata.

    All files are looked up below ``get_project_root() / "data"``; the file-name
    prefix is the ticker string with commas replaced by underscores.

    Args:
        tickers: Comma-separated ticker symbols (e.g. ``"NFLX,TSLA"``).

    Returns:
        A ``(processed_data, encoder_info, metadata)`` tuple. ``encoder_info``
        and ``metadata`` are empty dicts when their JSON files do not exist.
        ``(None, None, None)`` is returned when the processed pickle is missing.
    """
    # str.replace is a no-op when there is no comma, so no pre-check is needed.
    tickers_path = tickers.replace(',', '_')

    data_dir = get_project_root() / "data"
    processed_dir = data_dir / "processed"
    processed_path = processed_dir / f"{tickers_path}_processed.pkl"
    encoder_path = processed_dir / f"{tickers_path}_encoder_info.json"
    metadata_path = processed_dir / f"{tickers_path}_metadata.json"

    if not processed_path.exists():
        print(f"오류: 전처리된 데이터 파일이 없습니다: {processed_path}")
        return None, None, None

    # NOTE(security): unpickling can execute arbitrary code. Only load files
    # that were produced by this project's own preprocessing step.
    with open(processed_path, 'rb') as f:
        processed_data = pickle.load(f)

    # Attach the raw CSV data under the 'data' key when it is available.
    raw_data_path = data_dir / f"{tickers_path}_data.csv"
    if raw_data_path.exists():
        processed_data['data'] = pd.read_csv(raw_data_path)

    # Both sidecar files are optional.
    encoder_info = _load_json_or_empty(encoder_path)
    metadata = _load_json_or_empty(metadata_path)

    # Print only the essential dataset information.
    print(f"\n데이터셋: {', '.join(metadata.get('tickers', []))}")
    print(f"기간: {metadata.get('start_date', '')} ~ {metadata.get('end_date', '')}")
    print(f"특성 수: {metadata.get('feature_count', '')}, 윈도우: {metadata.get('window_size', '')}")

    return processed_data, encoder_info, metadata
def _build_arg_parser():
    """Create the command-line parser for the optimization script."""
    parser = argparse.ArgumentParser(description="주식 예측 모델 최적화 및 평가")
    parser.add_argument('--tickers', type=str, default='NFLX,TSLA,NVDA,AMD,INTC',
                        help='대상 종목 (콤마로 구분)')
    # --save / --visualize default to True, so on their own they could never be
    # switched off; the paired --no-* flags write False to the same destination.
    parser.add_argument('--save', action='store_true', default=True,
                        help='최적 모델 저장 여부')
    parser.add_argument('--no-save', dest='save', action='store_false',
                        help='최적 모델을 저장하지 않음')
    parser.add_argument('--metric', type=str, default='combined_score',
                        choices=['combined_score', 'avg_ticker_sharpe', 'sharpe_ratio', 'total_return'],
                        help='최적화 기준 지표')
    parser.add_argument('--output', type=str, default='grid_search_results.pkl',
                        help='결과 파일 경로')
    parser.add_argument('--model_output', type=str, default='best_contime',
                        help='TensorFlow Lite 모델 저장 경로')
    parser.add_argument('--visualize', action='store_true', default=True,
                        help='시각화 실행 여부')
    parser.add_argument('--no-visualize', dest='visualize', action='store_false',
                        help='시각화를 실행하지 않음')
    return parser


def _build_ticker_encoder(encoder_info, metadata, tickers_arg):
    """Fit a LabelEncoder on the tickers from encoder info, metadata, or the CLI."""
    ticker_encoder = LabelEncoder()
    if 'ticker_encoder' in encoder_info:
        # The JSON mapping is keyed by stringified integer index; order the
        # tickers by that index before fitting.
        ticker_mapping = encoder_info['ticker_encoder']
        ticker_list = [ticker_mapping[key] for key in sorted(ticker_mapping, key=int)]
        # NOTE(review): LabelEncoder.fit sorts its classes, so the resulting
        # codes equal the stored indices only if the stored mapping was itself
        # in sorted order - confirm against the preprocessing step.
    else:
        # Fall back to the metadata ticker list, then to the CLI argument.
        ticker_list = metadata.get('tickers', tickers_arg.split(','))
    ticker_encoder.fit(ticker_list)
    return ticker_encoder


def main(argv=None):
    """Run model optimization from the command line and print a summary.

    Args:
        argv: Optional list of command-line arguments. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, as before.

    Returns:
        The value returned by ``run_optimization_pipeline``, or ``None`` when
        the preprocessed data could not be loaded.
    """
    args = _build_arg_parser().parse_args(argv)

    # Load the preprocessed data.
    processed_data, encoder_info, metadata = load_processed_data(args.tickers)
    if processed_data is None:
        print("전처리된 데이터를 로드할 수 없습니다.")
        return

    ticker_encoder = _build_ticker_encoder(encoder_info, metadata, args.tickers)

    # Fix NumPy's global random seed before running the pipeline.
    np.random.seed(42)

    print("\n모델 최적화 시작...")
    print(f"최적화 기준: {args.metric}")

    # Results and the exported model are written under models/results.
    results_dir = Path(get_project_root()) / "models" / "results"
    results_dir.mkdir(parents=True, exist_ok=True)
    output_path = results_dir / args.output
    model_output = results_dir / args.model_output

    optimization_results = run_optimization_pipeline(
        data_dict=processed_data,
        ticker_encoder=ticker_encoder,
        metric=args.metric,
        output_path=output_path,
        save=args.save,
        model_output=model_output,
        sector_industry_df=processed_data.get('sector_industry_df'),
        run_visualizations=args.visualize,
    )
    print("\n최적화 완료!")

    # Summarize the best configuration; tolerate a pipeline that returns None.
    best_config = (optimization_results or {}).get('best_config', {})
    if best_config:
        print("\n최적 설정 요약:")
        for key, value in best_config.items():
            print(f" {key}: {value}")

    return optimization_results
# Run the optimization CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()