ostock-backend / model /scripts /predict.py
johnaness's picture
Deploy OStock FastAPI backend to HF Space (Docker SDK, port 7860)
4be2d4d
"""
์ €์žฅ๋œ TensorFlow Lite ๋ชจ๋ธ๊ณผ ์„ค์ •์„ ์‚ฌ์šฉํ•˜์—ฌ ์ข…๋ชฉ์— ๋Œ€ํ•œ ๋‹ค์Œ ๋‚  ์˜ˆ์ธก ์ˆ˜ํ–‰
"""
import os
import sys
import json
import argparse
# NOTE: tensorflow must be imported before pandas/numpy.
# If pandas/numpy load the OpenMP runtime (libiomp5md.dll) first, initialization of the
# tensorflow native DLL (_pywrap_tensorflow_internal) clashes with it and dies with WinError 1114.
import tensorflow as tf
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from pathlib import Path
import subprocess
import pickle
import warnings
# Ignore UserWarnings raised from tensorflow
warnings.filterwarnings('ignore', category=UserWarning, module='tensorflow')
# Add the parent of this script's directory to the module search path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.optimization.utils import predict_with_tflite
def get_project_root():
    """Return the directory two levels above this file (the parent of the scripts folder)."""
    scripts_dir = Path(__file__).parent
    return scripts_dir.parent
def check_and_data(tickers):
    """
    Make sure a raw CSV exists for every requested ticker.

    `tickers` is a single symbol or several symbols joined by '_'. Each symbol
    is looked up as ``data/<symbol>_data.csv`` under the directory two levels
    above this file (the ``data`` directory is created on demand). Symbols
    without a file are joined by '_' and handed to the sibling ``data.py``
    script, which is run in a child process.

    Returns True when no file was missing or the child process succeeded, and
    False when the child process exited with a non-zero status.
    """
    if '_' in tickers:
        ticker_list = tickers.split('_')
    else:
        ticker_list = [tickers]

    project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    data_dir = Path(project_dir) / 'data'
    data_dir.mkdir(parents=True, exist_ok=True)

    # Collect the symbols that have no CSV yet, keeping the requested order.
    missing_tickers = []
    for ticker in ticker_list:
        csv_path = data_dir / f'{ticker}_data.csv'
        if not csv_path.exists():
            missing_tickers.append(ticker)

    if missing_tickers:
        print(f"๋ˆ„๋ฝ๋œ ์ข…๋ชฉ ๋ฐ์ดํ„ฐ ๊ฐ€์ ธ์˜ค๊ธฐ: {', '.join(missing_tickers)}")
        fetch_cmd = [
            sys.executable,
            str(Path(__file__).parent / 'data.py'),
            '--tickers',
            '_'.join(missing_tickers),
        ]
        try:
            completed = subprocess.run(
                fetch_cmd, check=True, capture_output=True, text=True
            )
        except subprocess.CalledProcessError as e:
            print(f"๋ฐ์ดํ„ฐ ๊ฐ€์ ธ์˜ค๊ธฐ ์‹คํŒจ: {e}")
            print(f"์˜ค๋ฅ˜ ์ถœ๋ ฅ: {e.stderr}")
            return False
        print(completed.stdout)

    return True
def load_model_and_config(model_path, config_path):
    """
    Load the TensorFlow Lite interpreter, the JSON config and optional encoders.

    Args:
        model_path: Path to the TensorFlow Lite model file.
        config_path: Path to the JSON config file.

    Returns:
        A tuple ``(interpreter, config, encoders)``. ``encoders`` is the parsed
        content of ``<model stem>_encoders.json`` located next to the model, or
        None when that file does not exist. If any step fails, the error is
        printed and ``interpreter`` is None; ``config`` and ``encoders`` hold
        whatever was loaded before the failure (None if nothing was).
    """
    # Bind the results up front so the error path can always return them,
    # even when the failure happens before the config has been read
    # (e.g. the config file is missing or is not valid JSON).
    config = None
    encoders = None
    try:
        # Load the config file
        with open(config_path, 'r') as f:
            config = json.load(f)

        # Load the encoder information if it sits next to the model
        model_path_obj = Path(model_path)
        encoder_path = model_path_obj.with_name(f"{model_path_obj.stem}_encoders.json")
        if encoder_path.exists():
            with open(encoder_path, 'r') as f:
                encoders = json.load(f)

        # Load the TensorFlow Lite model
        print(f"๋ชจ๋ธ ๋กœ๋“œ ์ค‘: {Path(model_path).name}")
        interpreter = tf.lite.Interpreter(model_path=str(model_path))
        interpreter.allocate_tensors()
        return interpreter, config, encoders
    except Exception as e:
        print(f"๋ชจ๋ธ ๋กœ๋“œ ์‹คํŒจ: {e}")
        return None, config, encoders
def _select_source(data_dict):
    """Return the first split ('test', 'val', 'train') that holds at least one sequence, else None."""
    for source in ('test', 'val', 'train'):
        key = f'x_{source}'
        if key in data_dict and len(data_dict[key]) > 0:
            return source
    return None


def _build_tflite_inputs(interpreter, last_sequence, last_ticker, sector_id,
                         industry_id, last_time_diff):
    """Build the list of input arrays for the TFLite interpreter, one per input tensor."""
    input_details = interpreter.get_input_details()
    inputs = []
    for detail in input_details:
        # Tensors are matched by substrings of their lower-cased name; any
        # tensor that matches none of them receives the feature sequence.
        name = detail['name'].lower() if hasattr(detail['name'], 'lower') else ""
        if 'time' in name:
            inputs.append(last_time_diff.astype(np.float32))
        elif 'ticker' in name:
            inputs.append(np.array([last_ticker], dtype=np.int32))
        elif 'industry' in name:
            inputs.append(np.array([industry_id], dtype=np.int32))
        elif 'sector' in name:
            inputs.append(np.array([sector_id], dtype=np.int32))
        else:
            seq_data = last_sequence.astype(np.float32)
            # Append a trailing axis when the tensor is 4-D but the data is 3-D.
            if len(detail['shape']) == 4 and len(seq_data.shape) == 3:
                seq_data = np.expand_dims(seq_data, axis=-1)
            inputs.append(seq_data)

    # Fall back to a fixed ordering when the name-based mapping produced
    # nothing or a number of inputs that differs from the tensor count.
    if not inputs or len(inputs) != len(input_details):
        inputs = [
            last_time_diff.astype(np.float32),
            last_sequence.astype(np.float32),
            np.array([last_ticker], dtype=np.int32),
            np.array([industry_id], dtype=np.int32),
            np.array([sector_id], dtype=np.int32)
        ]
    return inputs


def _extract_pred_value(y_pred):
    """Reduce a raw model output (list of arrays, array, or scalar) to one float."""
    if isinstance(y_pred, list):
        # Only the first output of a multi-output model is used.
        value_output = y_pred[0]
        if len(value_output.shape) == 3:
            return float(value_output[0, -1, 0])
        if len(value_output.shape) == 2:
            return float(value_output[0, 0])
        return float(value_output.flatten()[-1])
    if hasattr(y_pred, 'shape'):
        if len(y_pred.shape) == 3:
            return float(y_pred[0, -1, 0])
        if y_pred.shape == (1,):
            return float(y_pred[0])
        if y_pred.shape == (1, 1):
            return float(y_pred[0, 0])
        return float(y_pred.flatten()[-1])
    return float(y_pred)


def predict_next_day(model_path, config_path, ticker, output_file='next_day_prediction.csv'):
    """
    Predict the next-day value for one ticker and save the result as CSV.

    The model and config are loaded with ``load_model_and_config``. The
    preprocessed pickle ``data/processed/<ticker>_processed.pkl`` under the
    project root is read (``data.py`` is run first when it is missing), the
    last sample of the first non-empty split (test, then val, then train) is
    fed to the model, and the scalar prediction is compared with the config
    threshold (``best_threshold``, else ``threshold``, else 0.0):
    BUY when above it, SELL when below its negative, HOLD otherwise.

    Args:
        model_path: Path to the TensorFlow Lite model file.
        config_path: Path to the JSON config file.
        ticker: Ticker symbol to predict.
        output_file: File name of the CSV written into ``<project root>/models``.

    Returns:
        A one-row ``pandas.DataFrame`` with the columns ``ticker``,
        ``prediction_date``, ``predicted_value``, ``threshold``, ``signal``
        and ``confidence``; None when any step fails (the reason is printed).
    """
    # Load the model, config and encoders (the encoders are not used here)
    model_or_interpreter, cfg, _encoders = load_model_and_config(model_path, config_path)
    if model_or_interpreter is None or cfg is None:
        print("๋ชจ๋ธ ๋กœ๋“œ์— ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค.")
        return None

    threshold = cfg.get('best_threshold', cfg.get('threshold', 0.0))

    # Resolve the preprocessed data file
    tickers_path = ticker.replace(',', '_')
    data_dir = get_project_root() / "data"
    processed_dir = data_dir / "processed"
    processed_path = processed_dir / f"{tickers_path}_processed.pkl"

    # Run data.py when the preprocessed file does not exist yet
    if not processed_path.exists():
        script_path = Path(__file__).parent / 'data.py'
        try:
            subprocess.run(
                [sys.executable, str(script_path), '--tickers', ticker],
                check=True, capture_output=True, text=True
            )
        except subprocess.CalledProcessError as e:
            print(f"Failed to prepare data for {ticker}: {e}")
            print(f"stderr: {e.stderr}")
            return None
        except Exception as e:
            # Best effort: any other launch problem is reported, not raised.
            print(f"Failed to prepare data for {ticker}: {e}")
            return None

    try:
        # Load the preprocessed data
        print(f"{ticker} ๋ฐ์ดํ„ฐ ๋กœ๋“œ ์ค‘...")
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # files that were produced locally by this project's own pipeline.
        with open(processed_path, 'rb') as f:
            processed_data = pickle.load(f)

        # The pickle holds either the data dict itself or a tuple whose first item is the dict
        data_dict = processed_data[0] if isinstance(processed_data, tuple) else processed_data

        # Pick the split to take the last sample from (test > val > train)
        source = _select_source(data_dict)
        if source is None:
            print(f"No usable sequences found in {processed_path}")
            return None

        # Last sample of the chosen split
        last_sequence = data_dict[f'x_{source}'][-1:]
        last_ticker = data_dict[f'ticker_{source}'][-1]
        last_time_diff = data_dict[f'time_diffs_{source}'][-1:]

        # Sector/industry ids are optional and default to 0
        sector_id = 0
        industry_id = 0
        if f'sector_{source}' in data_dict:
            sector_id = data_dict[f'sector_{source}'][-1]
        if f'industry_{source}' in data_dict:
            industry_id = data_dict[f'industry_{source}'][-1]

        print(f"์˜ˆ์ธก ์ˆ˜ํ–‰ ์ค‘...")
        if hasattr(model_or_interpreter, 'predict'):
            # Object exposing .predict(): feed the inputs in a fixed order
            inputs = [
                tf.cast(last_sequence, tf.float32),
                tf.cast(np.array([last_ticker]), tf.int32),
                tf.cast(np.array([sector_id]), tf.int32),
                tf.cast(np.array([industry_id]), tf.int32),
                tf.cast(last_time_diff, tf.float32),
            ]
            y_pred_all = model_or_interpreter.predict(inputs, verbose=0)
            y_pred = y_pred_all[0] if isinstance(y_pred_all, list) else y_pred_all
        else:
            # TensorFlow Lite interpreter
            inputs = _build_tflite_inputs(
                model_or_interpreter, last_sequence, last_ticker,
                sector_id, industry_id, last_time_diff
            )
            y_pred = predict_with_tflite(model_or_interpreter, inputs, verbose=False)

        if y_pred is None:
            print(f"Model returned no prediction for {ticker}")
            return None

        pred_value = _extract_pred_value(y_pred)

        signal = 'BUY' if pred_value > threshold else 'SELL' if pred_value < -threshold else 'HOLD'
        confidence = abs(pred_value)

        # Report the result
        print(f"\n===== {ticker} ๋‹ค์Œ ๋‚  ์˜ˆ์ธก =====")
        print(f"์˜ˆ์ธก๊ฐ’: {pred_value:.6f}")
        print(f"์ž„๊ณ„๊ฐ’: {threshold:.6f}")
        print(f"์‹ ํ˜ธ: {signal}")
        print(f"์‹ ๋ขฐ๋„: {confidence:.6f}")

        # Save the result as CSV
        results_df = pd.DataFrame({
            'ticker': [ticker],
            'prediction_date': [datetime.now().strftime('%Y-%m-%d')],
            'predicted_value': [pred_value],
            'threshold': [threshold],
            'signal': [signal],
            'confidence': [confidence]
        })

        # The CSV goes into the models directory under the project root
        models_dir = get_project_root() / "models"
        models_dir.mkdir(exist_ok=True)
        output_path = models_dir / output_file
        results_df.to_csv(output_path, index=False)
        print(f"๊ฒฐ๊ณผ ์ €์žฅ: {output_path}")

        return results_df
    except Exception as e:
        # Keep the best-effort contract (return None) but say what went wrong.
        print(f"Prediction failed for {ticker}: {e}")
        return None
def main():
    """
    Command-line entry point: parse the arguments and predict a single ticker.

    Options: ``--model`` (TFLite model path), ``--config`` (JSON config path),
    ``--tickers`` (required, one ticker) and ``--output`` (file name passed to
    ``predict_next_day``). A ``--tickers`` value containing '_' is treated as
    a multi-ticker request and is rejected before any data is prepared.
    Returns None in every case; outcomes are reported on stdout.
    """
    parser = argparse.ArgumentParser(description="์ €์žฅ๋œ TensorFlow Lite ๋ชจ๋ธ๋กœ ๋‹ค์Œ ๋‚  ์ฃผ๊ฐ€ ์˜ˆ์ธก")
    parser.add_argument('--model', type=str, default='models/best_contime_grid_search.tflite',
                        help='์ €์žฅ๋œ TensorFlow Lite ๋ชจ๋ธ ๊ฒฝ๋กœ')
    parser.add_argument('--config', type=str, default='models/results/best_contime_grid_search_meta.json',
                        help='์ €์žฅ๋œ ์„ค์ • ํŒŒ์ผ ๊ฒฝ๋กœ')
    parser.add_argument('--tickers', type=str, required=True,
                        help='์˜ˆ์ธกํ•  ์ข…๋ชฉ (๋‹จ์ผ ์ข…๋ชฉ)')
    parser.add_argument('--output', type=str, default='predictions.csv',
                        help='์˜ˆ์ธก ๊ฒฐ๊ณผ ์ €์žฅ ๊ฒฝ๋กœ')
    args = parser.parse_args()

    # Reject multi-ticker input first, so that no data is fetched for a
    # request that is going to be refused anyway.
    if '_' in args.tickers:
        print("๋‹จ์ผ ์ข…๋ชฉ๋งŒ ์˜ˆ์ธก ๊ฐ€๋Šฅํ•ฉ๋‹ˆ๋‹ค. ์—ฌ๋Ÿฌ ์ข…๋ชฉ์€ ๊ฐœ๋ณ„์ ์œผ๋กœ ์‹คํ–‰ํ•ด์ฃผ์„ธ์š”.")
        return

    # Make sure the data file exists, fetching it when it is missing
    if not check_and_data(args.tickers):
        print("๋ฐ์ดํ„ฐ ์ค€๋น„์— ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค.")
        return

    # Run the next-day prediction
    predict_next_day(
        model_path=args.model,
        config_path=args.config,
        ticker=args.tickers,
        output_file=args.output
    )
# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()
# ----- Portfolio performance -----
# Test set total return: 0.2360
# Test set Sharpe ratio: 0.0472
# Test set maximum drawdown: -0.1322
# Test set number of trades: 54
# ----- Average per-ticker performance -----
# Test set average ticker return: 0.2360
# Test set average ticker Sharpe ratio: 0.0457