Spaces:
Sleeping
Sleeping
"""
Run a next-day prediction for a ticker using a saved TensorFlow Lite model and its configuration.
"""
| import os | |
| import sys | |
| import json | |
| import argparse | |
# NOTE: tensorflow must be imported before pandas/numpy.
# If pandas/numpy load the OpenMP runtime (libiomp5md.dll) first, initializing
# tensorflow's native DLL (_pywrap_tensorflow_internal) conflicts with the
# already-loaded runtime and the process dies with WinError 1114.
| import tensorflow as tf | |
| import pandas as pd | |
| import numpy as np | |
| from datetime import datetime, timedelta | |
| from pathlib import Path | |
| import subprocess | |
| import pickle | |
| import warnings | |
# Ignore UserWarnings raised from modules whose name matches 'tensorflow'.
warnings.filterwarnings(action='ignore', category=UserWarning, module='tensorflow')
# Append the parent of this script's directory (the project root) to sys.path
# so that the `src` package can be imported.
_script_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(_script_dir))
del _script_dir
| from src.optimization.utils import predict_with_tflite | |
def get_project_root():
    """Return the project root directory: the parent of this script's directory.

    ``__file__`` is made absolute first. With a relative ``__file__`` (e.g. the
    ``__main__`` script on Python < 3.9 run as ``python predict.py``),
    ``Path(__file__).parent.parent`` would collapse to ``.`` instead of the
    real parent directory. Using ``os.path.abspath`` also matches how
    ``check_and_data`` locates the same directory.
    """
    return Path(os.path.abspath(__file__)).parent.parent
def check_and_data(tickers):
    """
    Make sure the raw CSV exists for every requested ticker, fetching missing
    ones by running the sibling ``data.py`` script.

    Args:
        tickers: Ticker string. Several tickers may be joined with ``'_'``;
            ``','`` is treated the same way (it is normalised to ``'_'``, as
            ``predict_next_day`` does when building file names). Empty tokens
            are ignored.

    Returns:
        True if every ``<project root>/data/<TICKER>_data.csv`` already existed
        (or there was nothing to check) or ``data.py`` exited successfully.
        False if ``data.py`` exited with an error or could not be launched.
    """
    # Normalise ',' to '_' and drop empty tokens such as those from "" or "A__B".
    ticker_list = [t for t in tickers.replace(',', '_').split('_') if t]
    if not ticker_list:
        # Nothing to check; don't launch data.py with an empty ticker.
        return True

    # <project root>/data, where the project root is the parent of this script's directory.
    data_dir = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) / 'data'
    data_dir.mkdir(parents=True, exist_ok=True)
    missing_tickers = [
        ticker for ticker in ticker_list
        if not (data_dir / f'{ticker}_data.csv').exists()
    ]
    if not missing_tickers:
        return True

    print(f"๋๋ฝ๋ ์ข ๋ชฉ ๋ฐ์ดํฐ ๊ฐ์ ธ์ค๊ธฐ: {', '.join(missing_tickers)}")
    missing_str = '_'.join(missing_tickers)
    script_path = Path(__file__).parent / 'data.py'
    try:
        result = subprocess.run(
            [sys.executable, str(script_path), '--tickers', missing_str],
            check=True, capture_output=True, text=True,
        )
    except subprocess.CalledProcessError as e:
        print(f"๋ฐ์ดํฐ ๊ฐ์ ธ์ค๊ธฐ ์คํจ: {e}")
        print(f"์ค๋ฅ ์ถ๋ ฅ: {e.stderr}")
        return False
    except OSError as e:
        # The interpreter/script could not be executed at all.
        print(f"Failed to launch {script_path}: {e}")
        return False
    print(result.stdout)
    return True
def load_model_and_config(model_path, config_path):
    """
    Load the TensorFlow Lite interpreter, its JSON config, and optional encoders.

    The encoder file is looked up next to the model as
    ``<model stem>_encoders.json`` and is optional.

    Args:
        model_path: Path to the ``.tflite`` model file.
        config_path: Path to the JSON config (meta) file.

    Returns:
        ``(interpreter, config, encoders)``. ``encoders`` is None when no encoder
        file exists. On any failure the interpreter is None, and ``config`` /
        ``encoders`` are whatever had been loaded before the failure (None if
        nothing was loaded).
    """
    # Pre-bind so the except branch can always return them. Previously a
    # failure while opening/parsing the config raised UnboundLocalError here.
    config = None
    encoders = None
    try:
        # JSON is UTF-8 by spec; don't depend on the platform locale encoding.
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        model_path_obj = Path(model_path)
        encoder_path = model_path_obj.with_name(f"{model_path_obj.stem}_encoders.json")
        if encoder_path.exists():
            with open(encoder_path, 'r', encoding='utf-8') as f:
                encoders = json.load(f)
        print(f"๋ชจ๋ธ ๋ก๋ ์ค: {model_path_obj.name}")
        interpreter = tf.lite.Interpreter(model_path=str(model_path))
        interpreter.allocate_tensors()
        return interpreter, config, encoders
    except Exception as e:
        print(f"๋ชจ๋ธ ๋ก๋ ์คํจ: {e}")
        return None, config, encoders
def _select_latest_sample(data_dict):
    """
    Pick the most recent sample from a preprocessed data dict.

    The first non-empty split among test, val and train (in that order) is used.

    Args:
        data_dict: Mapping with ``x_<split>``, ``ticker_<split>`` and
            ``time_diffs_<split>`` entries, optionally ``sector_<split>`` and
            ``industry_<split>``.

    Returns:
        ``(last_sequence, last_ticker, last_time_diff, sector_id, industry_id)``,
        or None when no split contains any sequences. Missing sector/industry
        ids default to 0.
    """
    for source in ('test', 'val', 'train'):
        x_key = f'x_{source}'
        if x_key in data_dict and len(data_dict[x_key]) > 0:
            break
    else:
        return None
    # `[-1:]` keeps a leading axis of length 1; `[-1]` takes the last id.
    last_sequence = data_dict[f'x_{source}'][-1:]
    last_ticker = data_dict[f'ticker_{source}'][-1]
    last_time_diff = data_dict[f'time_diffs_{source}'][-1:]
    sector_key = f'sector_{source}'
    industry_key = f'industry_{source}'
    sector_id = data_dict[sector_key][-1] if sector_key in data_dict else 0
    industry_id = data_dict[industry_key][-1] if industry_key in data_dict else 0
    return last_sequence, last_ticker, last_time_diff, sector_id, industry_id


def _build_tflite_inputs(interpreter, last_sequence, last_ticker, last_time_diff,
                         sector_id, industry_id):
    """
    Build the TFLite input list in the order of ``interpreter.get_input_details()``.

    Each input tensor is matched by substring of its lower-cased name:
    'time' -> time diffs, 'ticker' -> ticker id, 'industry' -> industry id,
    'sector' -> sector id, anything else -> the feature sequence.
    """
    input_details = interpreter.get_input_details()
    inputs = []
    for detail in input_details:
        name = detail['name'].lower() if hasattr(detail['name'], 'lower') else ""
        if 'time' in name:
            inputs.append(last_time_diff.astype(np.float32))
        elif 'ticker' in name:
            inputs.append(np.array([last_ticker], dtype=np.int32))
        elif 'industry' in name:
            inputs.append(np.array([industry_id], dtype=np.int32))
        elif 'sector' in name:
            inputs.append(np.array([sector_id], dtype=np.int32))
        else:
            seq_data = last_sequence.astype(np.float32)
            # Add a trailing axis when the tensor is rank 4 but the data is rank 3.
            if len(detail['shape']) == 4 and len(seq_data.shape) == 3:
                seq_data = np.expand_dims(seq_data, axis=-1)
            inputs.append(seq_data)
    # The loop appends exactly one input per tensor, so this fallback can only
    # trigger when the model reports no input tensors.
    if not inputs or len(inputs) != len(input_details):
        inputs = [
            last_time_diff.astype(np.float32),
            last_sequence.astype(np.float32),
            np.array([last_ticker], dtype=np.int32),
            np.array([industry_id], dtype=np.int32),
            np.array([sector_id], dtype=np.int32),
        ]
    return inputs


def _extract_prediction_value(y_pred):
    """
    Reduce a raw model output to a single float prediction.

    - list of outputs: use the first output; rank 3 -> ``[0, -1, 0]``,
      rank 2 -> ``[0, 0]``, otherwise the last flattened element.
    - array: rank 3 -> ``[0, -1, 0]``, shape ``(1,)`` -> ``[0]``,
      shape ``(1, 1)`` -> ``[0, 0]``, otherwise the last flattened element.
    - anything else: ``float(y_pred)``.

    NOTE(review): the list and array paths treat general 2-D outputs
    differently (first vs. last element); confirm which is intended.
    """
    if isinstance(y_pred, list):
        value_output = y_pred[0]
        if len(value_output.shape) == 3:
            return float(value_output[0, -1, 0])
        if len(value_output.shape) == 2:
            return float(value_output[0, 0])
        return float(value_output.flatten()[-1])
    if hasattr(y_pred, 'shape'):
        if len(y_pred.shape) == 3:
            return float(y_pred[0, -1, 0])
        if y_pred.shape == (1,):
            return float(y_pred[0])
        if y_pred.shape == (1, 1):
            return float(y_pred[0, 0])
        return float(y_pred.flatten()[-1])
    return float(y_pred)


def predict_next_day(model_path, config_path, ticker, output_file='next_day_prediction.csv'):
    """
    Predict the next-day value for one ticker and save the result as CSV.

    Steps:
    1. Load the model and config.
    2. Make sure ``<project root>/data/processed/<ticker>_processed.pkl``
       exists, running ``data.py`` if it does not.
    3. Take the most recent sequence from the preprocessed data and run inference.
    4. Map the prediction to a signal: BUY if it is above the threshold, SELL if
       it is below ``-threshold``, otherwise HOLD.
    5. Write a one-row CSV to ``<project root>/models/<output_file>``.

    Args:
        model_path: Path to the ``.tflite`` model.
        config_path: Path to the JSON config. ``best_threshold`` (falling back
            to ``threshold``, then 0.0) is used as the signal threshold.
        ticker: Ticker symbol. Commas are replaced with ``'_'`` when building
            the processed-data file name.
        output_file: CSV file name, resolved relative to the ``models`` directory.

    Returns:
        A one-row ``pandas.DataFrame`` with the prediction, or None on failure.
    """
    model_or_interpreter, cfg, _encoders = load_model_and_config(model_path, config_path)
    if model_or_interpreter is None or cfg is None:
        print("๋ชจ๋ธ ๋ก๋์ ์คํจํ์ต๋๋ค.")
        return None
    threshold = cfg.get('best_threshold', cfg.get('threshold', 0.0))

    # Processed files are named with '_' even if the ticker string uses ','.
    tickers_path = ticker.replace(',', '_')
    processed_path = get_project_root() / "data" / "processed" / f"{tickers_path}_processed.pkl"

    # Generate the preprocessed file with data.py when it is missing.
    if not processed_path.exists():
        script_path = Path(__file__).parent / 'data.py'
        try:
            subprocess.run(
                [sys.executable, str(script_path), '--tickers', ticker],
                check=True, capture_output=True, text=True,
            )
        except subprocess.CalledProcessError as e:
            # Previously swallowed by a bare `except:`; report why data.py failed.
            print(f"data.py failed for {ticker}: {e}")
            print(e.stderr)
            return None
        except OSError as e:
            print(f"Could not run {script_path}: {e}")
            return None
        if not processed_path.exists():
            print(f"Processed data not found after running data.py: {processed_path}")
            return None

    try:
        print(f"{ticker} ๋ฐ์ดํฐ ๋ก๋ ์ค...")
        # NOTE(security): pickle.load can execute arbitrary code. This file is
        # expected to be produced by the project's own data.py; never point it
        # at untrusted input.
        with open(processed_path, 'rb') as f:
            processed_data = pickle.load(f)
        # The pickle may be a tuple whose first element is the data dict.
        data_dict = processed_data[0] if isinstance(processed_data, tuple) else processed_data

        sample = _select_latest_sample(data_dict)
        if sample is None:
            print(f"No sequences found in {processed_path}")
            return None
        last_sequence, last_ticker, last_time_diff, sector_id, industry_id = sample

        print("์์ธก ์ํ ์ค...")
        if hasattr(model_or_interpreter, 'predict'):
            # Objects exposing .predict (e.g. a Keras model) get a fixed input order.
            inputs = [
                tf.cast(last_sequence, tf.float32),
                tf.cast(np.array([last_ticker]), tf.int32),
                tf.cast(np.array([sector_id]), tf.int32),
                tf.cast(np.array([industry_id]), tf.int32),
                tf.cast(last_time_diff, tf.float32),
            ]
            y_pred_all = model_or_interpreter.predict(inputs, verbose=0)
            y_pred = y_pred_all[0] if isinstance(y_pred_all, list) else y_pred_all
        else:
            inputs = _build_tflite_inputs(
                model_or_interpreter, last_sequence, last_ticker,
                last_time_diff, sector_id, industry_id,
            )
            y_pred = predict_with_tflite(model_or_interpreter, inputs, verbose=False)
        if y_pred is None:
            print("Model returned no prediction.")
            return None

        pred_value = _extract_prediction_value(y_pred)
        signal = 'BUY' if pred_value > threshold else 'SELL' if pred_value < -threshold else 'HOLD'
        confidence = abs(pred_value)

        print(f"\n===== {ticker} ๋ค์ ๋ ์์ธก =====")
        print(f"์์ธก๊ฐ: {pred_value:.6f}")
        print(f"์๊ณ๊ฐ: {threshold:.6f}")
        print(f"์ ํธ: {signal}")
        print(f"์ ๋ขฐ๋: {confidence:.6f}")

        results_df = pd.DataFrame({
            'ticker': [ticker],
            'prediction_date': [datetime.now().strftime('%Y-%m-%d')],
            'predicted_value': [pred_value],
            'threshold': [threshold],
            'signal': [signal],
            'confidence': [confidence],
        })

        output_path = get_project_root() / "models" / output_file
        # Create the models directory (and any sub-directory in output_file).
        output_path.parent.mkdir(parents=True, exist_ok=True)
        results_df.to_csv(output_path, index=False)
        print(f"๊ฒฐ๊ณผ ์ ์ฅ: {output_path}")
        return results_df
    except Exception as e:
        # Boundary for the prediction step: report instead of silently returning None.
        print(f"Prediction failed for {ticker}: {e}")
        return None
def main():
    """
    CLI entry point: parse arguments, ensure the ticker's data exists, then run
    ``predict_next_day``.

    Only a single ticker is supported. Input containing ``'_'`` or ``','``
    (both treated as ticker separators elsewhere in this script) is rejected
    before any data is fetched.
    """
    parser = argparse.ArgumentParser(description="์ ์ฅ๋ TensorFlow Lite ๋ชจ๋ธ๋ก ๋ค์ ๋ ์ฃผ๊ฐ ์์ธก")
    parser.add_argument('--model', type=str, default='models/best_contime_grid_search.tflite',
                        help='์ ์ฅ๋ TensorFlow Lite ๋ชจ๋ธ ๊ฒฝ๋ก')
    parser.add_argument('--config', type=str, default='models/results/best_contime_grid_search_meta.json',
                        help='์ ์ฅ๋ ์ค์ ํ์ผ ๊ฒฝ๋ก')
    parser.add_argument('--tickers', type=str, required=True,
                        help='์์ธกํ ์ข ๋ชฉ (๋จ์ผ ์ข ๋ชฉ)')
    parser.add_argument('--output', type=str, default='predictions.csv',
                        help='์์ธก ๊ฒฐ๊ณผ ์ ์ฅ ๊ฒฝ๋ก')
    args = parser.parse_args()

    # Reject multi-ticker input first so no data is downloaded for a request
    # that will be refused anyway.
    if '_' in args.tickers or ',' in args.tickers:
        print("๋จ์ผ ์ข ๋ชฉ๋ง ์์ธก ๊ฐ๋ฅํฉ๋๋ค. ์ฌ๋ฌ ์ข ๋ชฉ์ ๊ฐ๋ณ์ ์ผ๋ก ์คํํด์ฃผ์ธ์.")
        return

    # Fetch the raw data file if it is missing.
    if not check_and_data(args.tickers):
        print("๋ฐ์ดํฐ ์ค๋น์ ์คํจํ์ต๋๋ค.")
        return

    result = predict_next_day(
        model_path=args.model,
        config_path=args.config,
        ticker=args.tickers,
        output_file=args.output,
    )
    if result is None:
        print(f"Prediction failed for {args.tickers}.")
# Execute the CLI only when this file is run directly, not when it is imported.
if __name__ == '__main__':
    main()
# ----- Portfolio performance -----
# Test set total return: 0.2360
# Test set Sharpe ratio: 0.0472
# Test set max drawdown: -0.1322
# Test set number of trades: 54
# ----- Average per-ticker performance -----
# Test set average per-ticker return: 0.2360
# Test set average per-ticker Sharpe ratio: 0.0457