AlanRex commited on
Commit
b3b1c36
·
verified ·
1 Parent(s): 8f928f7

Upload predictor_py.py

Browse files
Files changed (1) hide show
  1. predictor_py.py +125 -0
predictor_py.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """predictor.py
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1AHdt8sCCYTJXRC_ENNsgMFv7rK_R0Oa4
8
+ """
9
+
10
+ # predictor.py
11
+
12
+ import pandas as pd
13
+ import numpy as np
14
+ from tensorflow.keras.models import load_model
15
+ import joblib
16
+ import os
17
+
18
+ # --- 全域變數,定義模型和相關檔案的路徑 ---
19
+ # 這樣可以方便地在同一個地方管理檔案名稱
20
+ MODEL_PATH = 'stock_lstm_model_v2.keras'
21
+ SCALER_PATH = 'data_scaler_v2.gz'
22
+
23
+ # 這個列表是關鍵:它定義了 v2 模型訓練時所使用的所有特徵
24
+ # 順序也必須和訓練時完全一樣
25
+ FEATURES_USED = [
26
+ 'open', 'high', 'low', 'close', 'volume',
27
+ 'MACD', 'KD_K', 'KD_D', 'RSI',
28
+ 'DMI_ADX',
29
+ 'Index'
30
+ ]
31
+ TARGET_FEATURE = 'close' # 預測的目標
32
+
33
+ # --- 內部輔助函式,用於載入元件 ---
34
+ def _load_predictor_components():
35
+ """
36
+ 載入模型和 Scaler。這是一個內部函式,外部不應該直接呼叫。
37
+ """
38
+ # 檢查檔案是否存在
39
+ if not os.path.exists(MODEL_PATH) or not os.path.exists(SCALER_PATH):
40
+ print(f"錯誤:找不到必要的檔案。請確保 '{MODEL_PATH}' 和 '{SCALER_PATH}' 存在於同一個目錄下。")
41
+ return None, None
42
+
43
+ try:
44
+ model = load_model(MODEL_PATH)
45
+ scaler = joblib.load(SCALER_PATH)
46
+ print("模型和 Scaler 載入成功!")
47
+ return model, scaler
48
+ except Exception as e:
49
+ print(f"載入預測器組件時發生錯誤: {e}")
50
+ return None, None
51
+
52
+ # --- 這是您唯一需要呼叫的主函式 ---
53
+ def predict_next_day_price(df_stock_original: pd.DataFrame, df_climate_original: pd.DataFrame) -> (float | None):
54
+ """
55
+ 接收原始的台積電股價 DataFrame 和景氣燈號 DataFrame,
56
+ 執行所有必要的預處理步驟,並回傳下一日的預測收盤價。
57
+
58
+ Args:
59
+ df_stock_original (pd.DataFrame): 包含 'Date' 和其他技術指標的原始股價資料。
60
+ df_climate_original (pd.DataFrame): 包含 'Date' 和 'Index' 的原始景氣燈號資料。
61
+
62
+ Returns:
63
+ float: 預測的收盤價。如果發生錯誤則回傳 None。
64
+ """
65
+ model, scaler = _load_predictor_components()
66
+ if model is None or scaler is None:
67
+ return None
68
+
69
+ # --- 1. 資料整合與預處理 (將模型訓練時的邏輯完全複製於此) ---
70
+ try:
71
+ # 複製一份資料以避免修改到原始傳入的 DataFrame
72
+ df_stock = df_stock_original.copy()
73
+ df_climate = df_climate_original.copy()
74
+
75
+ df_stock['Date'] = pd.to_datetime(df_stock['Date'])
76
+ df_stock.set_index('Date', inplace=True)
77
+ df_climate['Date'] = pd.to_datetime(df_climate['Date'], format='%Y-%m')
78
+ df_climate.set_index('Date', inplace=True)
79
+
80
+ df_merged = pd.merge(df_stock, df_climate, left_index=True, right_index=True, how='left')
81
+ df_merged['Index'].fillna(method='ffill', inplace=True)
82
+
83
+ # 確保所有需要的特徵都存在
84
+ if not all(feature in df_merged.columns for feature in FEATURES_USED):
85
+ missing = [f for f in FEATURES_USED if f not in df_merged.columns]
86
+ print(f"錯誤:輸入的資料中缺少模型所需特徵: {missing}")
87
+ return None
88
+
89
+ df_featured = df_merged[FEATURES_USED].copy()
90
+ df_featured.dropna(inplace=True)
91
+
92
+ # 檢查是否有足夠的資料 (至少需要 look_back 筆)
93
+ look_back = 60
94
+ if len(df_featured) < look_back:
95
+ print(f"錯誤:資料不足。模型需要至少 {look_back} 筆完整的歷史資料來進行預測,但只得到 {len(df_featured)} 筆。")
96
+ return None
97
+
98
+ # --- 2. 準備輸入模型的資料 ---
99
+ # 取得最新的 60 筆資料
100
+ last_60_days = df_featured[-look_back:].values
101
+
102
+ # 使用【已載入的】scaler 來轉換資料
103
+ scaled_last_60_days = scaler.transform(last_60_days)
104
+
105
+ # 重塑為模型期望的形狀: [1, 60, 特徵數量]
106
+ X_predict = np.reshape(scaled_last_60_days, (1, look_back, len(FEATURES_USED)))
107
+
108
+ # --- 3. 進行預測與反標準化 ---
109
+ predicted_price_scaled = model.predict(X_predict)
110
+
111
+ # 建立一個和 scaler 相同結構的 dummy 陣列來進行反標準化
112
+ dummy_array = np.zeros(shape=(1, len(FEATURES_USED)))
113
+
114
+ # 找到 'close' 在特徵列表中的位置
115
+ target_col_index = FEATURES_USED.index(TARGET_FEATURE)
116
+ dummy_array[0, target_col_index] = predicted_price_scaled[0, 0]
117
+
118
+ # 反標準化,取得真實的預測股價
119
+ predicted_price = scaler.inverse_transform(dummy_array)[0, target_col_index]
120
+
121
+ return float(predicted_price)
122
+
123
+ except Exception as e:
124
+ print(f"預測過程中發生未預期的錯誤: {e}")
125
+ return None