AlanRex commited on
Commit
f9e7f22
·
verified ·
1 Parent(s): 05993cc

Update model_predictor.py

Browse files
Files changed (1) hide show
  1. model_predictor.py +329 -184
model_predictor.py CHANGED
@@ -1,276 +1,421 @@
1
  # model_predictor.py - 支援漲幅百分比輸出的XGBoost模型預測器
2
  # 修改版本:輸出改為漲幅百分比而非絕對價格
3
 
 
 
4
  import os
5
- import pandas as pd
6
  import numpy as np
 
7
  import xgboost as xgb
8
- from sklearn.preprocessing import StandardScaler
9
- import pickle
10
  import joblib
 
 
11
 
12
  class XGBoostModel:
13
  def __init__(self):
14
  """
15
- 初始化 XGBoost 模型預測器
16
-
17
- 【重要更新】
18
- - 模型現在輸出漲幅百分比而非絕對價格
19
- - 支援 1日、5日、10日、20日的漲幅預測
20
  """
21
- self.model = None
22
- self.scaler = None
23
  self.feature_columns = [
24
- 'close', # 前一日收盤價
25
- 'return_t-1', # 前一日報酬率
26
- 'return_t-5', # 過去 5 日累積報酬率
27
- 'MA5_close', # 5 日移動平均價
28
- 'volatility_5d', # 5 日報酬標準差
29
- 'volume_ratio_5d', # 今日成交量 ÷ 5 日均量
30
- 'MACD_diff', # MACD - signal
31
- 'dji_return_t-1', # 前一日道瓊指數報酬率
32
- 'sox_return_t-1', # 前一日費半指數報酬率
33
- 'NEWS' # 新聞情緒分數
 
 
 
 
34
  ]
35
 
36
- # 【新增】輸出目標對應表
37
- self.output_targets = {
38
- 1: 'Change_pct_t1_pred', # 1天後漲幅%
39
- 5: 'Change_pct_t5_pred', # 5天後漲幅%
40
- 10: 'Change_pct_t10_pred', # 10天後漲幅%
41
- 20: 'Change_pct_t20_pred' # 20天後漲幅%
42
  }
43
 
44
- print("XGBoost 模型預測器初始化完成")
45
- print("輸出格式:漲幅百分比 (1日, 5日, 10日, 20日)")
 
 
 
 
 
46
 
47
- def load_model(self, model_path):
48
  """
49
- 載入預訓練的 XGBoost 模型
 
50
 
51
  Args:
52
- model_path (str): 模型檔案路徑 (.json 格式)
53
-
54
  Returns:
55
- bool: 是否成功載入
56
  """
57
- try:
58
- # 檢查模型檔案是否存在
59
- if not os.path.exists(model_path):
60
- print(f"錯誤:找不到模型檔案 {model_path}")
61
- return False
62
-
63
- # 載入 XGBoost 模型
64
- self.model = xgb.XGBRegressor()
65
- self.model.load_model(model_path)
66
-
67
- print(f"成功載入模型:{model_path}")
68
- print(f"預期特徵數量:{len(self.feature_columns)}")
69
-
70
- return True
71
-
72
- except Exception as e:
73
- print(f"載入模型時發生錯誤:{e}")
74
- return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
- def load_scaler(self, scaler_path):
77
  """
78
- 載入特徵標準化器
79
 
80
  Args:
81
- scaler_path (str): 標準化器檔案路徑 (.pkl 格式)
82
-
83
  Returns:
84
- bool: 是否成功載入
85
  """
86
  try:
87
- if os.path.exists(scaler_path):
88
- self.scaler = joblib.load(scaler_path)
89
- print(f"成功載入標準化器:{scaler_path}")
90
- return True
 
91
  else:
92
- print(f"警告:找不到標準化器檔案 {scaler_path}")
93
- print("將使用預設標準化器")
94
- self.scaler = StandardScaler()
95
  return False
96
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  except Exception as e:
98
- print(f"載入標準化器時發生錯誤:{e}")
99
- self.scaler = StandardScaler()
100
  return False
101
 
102
- def preprocess_features(self, input_df):
103
  """
104
- 預處理輸入特徵
105
 
106
  Args:
107
- input_df (pd.DataFrame): 輸入特徵 DataFrame
 
108
 
109
  Returns:
110
- pd.DataFrame: 預處理後的特徵
111
  """
 
 
 
 
112
  try:
113
- # 確保輸入包含所有必要特徵
114
- missing_features = [f for f in self.feature_columns if f not in input_df.columns]
 
 
 
 
 
 
 
 
 
 
115
  if missing_features:
116
- print(f"警告:缺少以下特徵:{missing_features}")
117
- # 用 0 填補缺少的特徵
118
- for feature in missing_features:
119
- input_df[feature] = 0
120
 
121
- # 按照預期順序重新排列特徵
122
- input_df = input_df[self.feature_columns]
123
 
124
- # 處理 NaN 值
125
- input_df = input_df.fillna(0)
 
 
126
 
127
- # 如果有標準化器,進行標準化
128
  if self.scaler is not None:
129
- try:
130
- # 嘗試使用已訓練的標準化器
131
- scaled_features = self.scaler.transform(input_df)
132
- input_df = pd.DataFrame(scaled_features,
133
- columns=input_df.columns,
134
- index=input_df.index)
135
- except Exception as scaler_error:
136
- print(f"標準化過程發生錯誤:{scaler_error}")
137
- print("跳過標準化步驟")
138
-
139
- return input_df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
  except Exception as e:
142
- print(f"特徵預處理時發生錯誤:{e}")
143
- return input_df
144
 
145
- def predict(self, model_name, input_df):
146
  """
147
- 進行股價漲幅預測
148
 
149
  Args:
150
- model_name (str): 模型名稱(用於載入對應模型)
151
- input_df (pd.DataFrame): 輸入特徵
152
-
 
 
153
  Returns:
154
- dict: 預測結果,包含各時間點的漲幅百分比
155
  """
156
  try:
157
- # 載入模型(如果尚未載入)
158
- if self.model is None:
159
- model_path = f"{model_name}.json"
160
- if not self.load_model(model_path):
161
- return None
162
 
163
- # 載入標準化器(如果存在)
164
- if self.scaler is None:
165
- scaler_path = f"{model_name}_scaler.pkl"
166
- self.load_scaler(scaler_path)
167
 
168
- # 預處理特徵
169
- processed_df = self.preprocess_features(input_df.copy())
170
 
171
- # 進行預測
172
- predictions = self.model.predict(processed_df)
 
 
 
 
 
 
 
173
 
174
- # 【重要修改】將預測結果格式化為漲幅百分比
175
- if predictions.ndim == 1:
176
- # 如果只有一個輸出,假設是 1 日預測
177
- result = {
178
- 'Change_pct_t1_pred': float(predictions[0])
179
- }
 
 
 
 
 
 
180
  else:
181
- # 多輸出情況:1日, 5日, 10日, 20日
182
- result = {
183
- 'Change_pct_t1_pred': float(predictions[0][0]) if len(predictions[0]) > 0 else 0.0,
184
- 'Change_pct_t5_pred': float(predictions[0][1]) if len(predictions[0]) > 1 else 0.0,
185
- 'Change_pct_t10_pred': float(predictions[0][2]) if len(predictions[0]) > 2 else 0.0,
186
- 'Change_pct_t20_pred': float(predictions[0][3]) if len(predictions[0]) > 3 else 0.0
187
- }
188
-
189
- # 輸出預測結果摘要
190
- print("=== 漲幅預測結果 ===")
191
- for key, value in result.items():
192
- days = key.split('_')[2][1:] # 提取天數
193
- direction = "上漲" if value > 0 else "下跌"
194
- print(f" {days}日後預測: {value:+.2f}% ({direction})")
195
-
196
- return result
197
-
198
  except Exception as e:
199
- print(f"預測過程中發生錯誤:{e}")
200
- import traceback
201
- traceback.print_exc()
202
- return None
203
 
204
- def predict_single_timeframe(self, model_name, input_df, days):
205
  """
206
- 預測特定時間框架的漲幅
207
 
208
  Args:
209
- model_name (str): 模型名稱
210
- input_df (pd.DataFrame): 輸入特徵
211
- days (int): 預測天數 (1, 5, 10, 20)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
  Returns:
214
- float: 預測的漲幅百分比
215
  """
 
 
 
216
  try:
217
- predictions = self.predict(model_name, input_df)
218
- if predictions is None:
219
- return None
220
 
221
- # 根據天數選擇對應的預測結果
222
- target_key = f'Change_pct_t{days}_pred'
 
 
 
 
 
223
 
224
- if target_key in predictions:
225
- return predictions[target_key]
226
- else:
227
- print(f"警告:找不到 {days} 日預測結果")
228
- return None
229
-
230
  except Exception as e:
231
- print(f"單一時間框架預測時發生錯誤:{e}")
232
- return None
233
 
234
- def get_prediction_confidence(self, input_df):
235
  """
236
- 評估預測的信心度
237
 
238
  Args:
239
- input_df (pd.DataFrame): 輸入特徵
240
 
241
  Returns:
242
- float: 信心度 (0-1)
243
  """
244
  try:
245
- # 基於特徵完整性和質量評估信心度
246
- feature_completeness = 0
247
- total_features = len(self.feature_columns)
248
-
249
- for feature in self.feature_columns:
250
- if feature in input_df.columns:
251
- value = input_df[feature].iloc[0]
252
- if not pd.isna(value) and value != 0:
253
- feature_completeness += 1
254
 
255
- completeness_ratio = feature_completeness / total_features
 
256
 
257
- # 基於數據質量調整信心度
258
- base_confidence = max(0.5, completeness_ratio)
259
 
260
- # 如果重要特徵缺失,降低信心度
261
- important_features = ['close', 'return_t-1', 'MA5_close']
262
- missing_important = 0
263
- for feature in important_features:
264
- if feature not in input_df.columns or pd.isna(input_df[feature].iloc[0]):
265
- missing_important += 1
266
 
267
- if missing_important > 0:
268
- base_confidence *= (1 - missing_important * 0.1)
269
 
270
- return min(0.9, max(0.3, base_confidence))
271
 
272
  except Exception as e:
273
- print(f"計算信心度時發生錯誤:{e}")
274
  return 0.5
275
 
276
  def validate_input(self, input_df):
 
1
  # model_predictor.py - 支援漲幅百分比輸出的XGBoost模型預測器
2
  # 修改版本:輸出改為漲幅百分比而非絕對價格
3
 
4
+ # model_predictor.py - 修正版本,對應訓練腳本的確切配置
5
+
6
  import os
 
7
  import numpy as np
8
+ import pandas as pd
9
  import xgboost as xgb
10
+ from sklearn.preprocessing import MinMaxScaler
 
11
  import joblib
12
+ import warnings
13
+ warnings.filterwarnings('ignore')
14
 
15
  class XGBoostModel:
16
  def __init__(self):
17
  """
18
+ 初始化 XGBoost 模型類別
19
+ 根據訓練腳本 xgboost_for_stock_trend_&_prices_prediction_gpu_v_2_1_3.py 的配置
 
 
 
20
  """
21
+ # 根據訓練腳本的 new_feature_columns,確保順序完全一致
 
22
  self.feature_columns = [
23
+ 'close', # 前一日收盤價
24
+ 'return_t-1', # 前一日報酬率
25
+ 'return_t-5', # 過去 5 日累積報酬率
26
+ 'MA5_close', # 5 日移動平均價
27
+ 'volatility_5d', # 5 日報酬標準差
28
+ 'volume_ratio_5d', # 今日成交量 ÷ 5 日均量
29
+ 'MACD_diff', # MACD - signal
30
+ 'dji_return_t-1', # 前一日道瓊指數報酬率
31
+ 'sox_return_t-1', # 前一日費半指數報酬率
32
+ 'NEWS', # 新聞情緒分數
33
+ 'MACDvol', # MACD柱狀圖
34
+ 'RSI_14', # 14日RSI
35
+ 'ADX', # ADX指標
36
+ 'volume_weighted_return' # 成交量加權報酬率
37
  ]
38
 
39
+ # 預測目標對應(根據訓練腳本的 train_y)
40
+ self.prediction_mapping = {
41
+ 'Change_pct_t1_pred': 1, # 1天後漲幅%
42
+ 'Change_pct_t5_pred': 5, # 5天後漲幅%
43
+ 'Change_pct_t10_pred': 10, # 10天後漲幅%
44
+ 'Change_pct_t20_pred': 20 # 20天後漲幅%
45
  }
46
 
47
+ self.model = None
48
+ self.scaler = None
49
+ self.is_model_loaded = False
50
+
51
+ # 模型檔案路徑
52
+ self.model_path = 'xgboost_model.json'
53
+ self.scaler_path = 'feature_scaler.pkl'
54
 
55
+ def create_features_from_stock_data(self, stock_data):
56
  """
57
+ 從股票資料創建所需的特徵
58
+ 完全對應訓練腳本中的 create_new_features 函數
59
 
60
  Args:
61
+ stock_data: yfinance 格式的股票資料 DataFrame
62
+
63
  Returns:
64
+ processed_df: 包含所有特徵的 DataFrame
65
  """
66
+ df = stock_data.copy()
67
+
68
+ # 確保必要的基礎欄位存在
69
+ required_base_columns = ['Close', 'Volume', 'High', 'Low']
70
+ for col in required_base_columns:
71
+ if col not in df.columns:
72
+ raise ValueError(f"缺少必要的基礎欄位: {col}")
73
+
74
+ # 統一欄位名稱(yfinance 使用大寫)
75
+ df['close'] = df['Close']
76
+ df['volume'] = df['Volume']
77
+
78
+ # 1. return_t-1 — 前一日報酬率
79
+ df['return_t-1'] = df['close'].pct_change()
80
+
81
+ # 2. return_t-5 — 過去 5 日累積報酬率
82
+ df['return_t-5'] = (df['close'] / df['close'].shift(5) - 1)
83
+
84
+ # 3. MA5_close — 5 日移動平均價
85
+ df['MA5_close'] = df['close'].rolling(window=5).mean()
86
+
87
+ # 4. volatility_5d — 5 日報酬標準差
88
+ df['volatility_5d'] = df['return_t-1'].rolling(window=5).std()
89
+
90
+ # 5. volume_ratio_5d — 今日成交量 ÷ 5 日均量
91
+ df['volume_5d_avg'] = df['volume'].rolling(window=5).mean()
92
+ df['volume_ratio_5d'] = df['volume'] / df['volume_5d_avg']
93
+
94
+ # 6. MACD_diff — MACD - signal
95
+ exp1 = df['close'].ewm(span=12).mean()
96
+ exp2 = df['close'].ewm(span=26).mean()
97
+ macd_line = exp1 - exp2
98
+ signal_line = macd_line.ewm(span=9).mean()
99
+ df['MACD_diff'] = macd_line - signal_line
100
+
101
+ # 7-8. 美股指數報酬率(需要外部資料,暫設為0)
102
+ df['dji_return_t-1'] = 0.0 # 這需要從外部獲取道瓊指數資料
103
+ df['sox_return_t-1'] = 0.0 # 這需要從外部獲取費半指數資料
104
+
105
+ # 9. NEWS — 新聞情緒分數(需要外部資料,暫設為0)
106
+ df['NEWS'] = 0.0
107
+
108
+ # 10. MACDvol — MACD柱狀圖
109
+ df['MACDvol'] = macd_line - signal_line
110
+
111
+ # 11. RSI_14 — 14日RSI
112
+ delta = df['close'].diff()
113
+ gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
114
+ loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
115
+ rs = gain / loss
116
+ df['RSI_14'] = 100 - (100 / (1 + rs))
117
+
118
+ # 12. ADX — 平均趨向指標
119
+ df['up_move'] = df['High'] - df['High'].shift(1)
120
+ df['down_move'] = df['Low'].shift(1) - df['Low']
121
+ df['+DM'] = np.where((df['up_move'] > df['down_move']) & (df['up_move'] > 0), df['up_move'], 0)
122
+ df['-DM'] = np.where((df['down_move'] > df['up_move']) & (df['down_move'] > 0), df['down_move'], 0)
123
+
124
+ high_low = df['High'] - df['Low']
125
+ high_close_prev = np.abs(df['High'] - df['close'].shift(1))
126
+ low_close_prev = np.abs(df['Low'] - df['close'].shift(1))
127
+ df['TR'] = np.maximum.reduce([high_low, high_close_prev, low_close_prev])
128
+
129
+ df['+DI'] = (df['+DM'].ewm(com=13, adjust=False).mean() / df['TR'].ewm(com=13, adjust=False).mean()) * 100
130
+ df['-DI'] = (df['-DM'].ewm(com=13, adjust=False).mean() / df['TR'].ewm(com=13, adjust=False).mean()) * 100
131
+ df['DX'] = np.abs(df['+DI'] - df['-DI']) / (df['+DI'] + df['-DI']) * 100
132
+ df['ADX'] = df['DX'].ewm(com=13, adjust=False).mean()
133
+
134
+ # 13. volume_weighted_return — 成交量加權報酬率
135
+ df['volume_weighted_return'] = np.abs(df['return_t-1']) * df['volume']
136
+
137
+ # 清理輔助欄位
138
+ cleanup_columns = ['volume_5d_avg', 'up_move', 'down_move', '+DM', '-DM', 'TR', '+DI', '-DI', 'DX']
139
+ df.drop(columns=[col for col in cleanup_columns if col in df.columns], inplace=True)
140
+
141
+ # 填補 NaN 值
142
+ df.fillna(method='ffill', inplace=True)
143
+ df.fillna(0, inplace=True) # 剩餘的 NaN 用 0 填補
144
+
145
+ return df
146
 
147
+ def load_model(self, model_name='xgboost_model'):
148
  """
149
+ 載入訓練好的模型和標準化器
150
 
151
  Args:
152
+ model_name: 模型名稱
153
+
154
  Returns:
155
+ bool: 載入是否成功
156
  """
157
  try:
158
+ # 載入 XGBoost 模型
159
+ if os.path.exists(self.model_path):
160
+ self.model = xgb.XGBRegressor()
161
+ self.model.load_model(self.model_path)
162
+ print(f"成功載入模型: {self.model_path}")
163
  else:
164
+ print(f"警告:模型檔案 {self.model_path} 不存在")
 
 
165
  return False
166
 
167
+ # 嘗試載入標準化器(如果存在)
168
+ if os.path.exists(self.scaler_path):
169
+ self.scaler = joblib.load(self.scaler_path)
170
+ print(f"成功載入標準化器: {self.scaler_path}")
171
+ else:
172
+ print(f"警告:未找到標準化器檔案 {self.scaler_path},將使用原始數據進行預測")
173
+ # 根據訓練腳本,模型沒有使用標準化,所以這是正常的
174
+ self.scaler = None
175
+
176
+ self.is_model_loaded = True
177
+ return True
178
+
179
  except Exception as e:
180
+ print(f"載入模型時發生錯誤: {e}")
 
181
  return False
182
 
183
+ def predict(self, model_name, input_data):
184
  """
185
+ 使用載入的模型進行預測
186
 
187
  Args:
188
+ model_name: 模型名稱(保持接口一致性)
189
+ input_data: 輸入特徵 DataFrame 或 numpy array
190
 
191
  Returns:
192
+ dict: 預測結果字典,包含各時間框架的漲幅百分比
193
  """
194
+ if not self.is_model_loaded:
195
+ if not self.load_model(model_name):
196
+ raise RuntimeError("模型載入失敗,無法進行預測")
197
+
198
  try:
199
+ # 確保輸入是 DataFrame 格式
200
+ if isinstance(input_data, np.ndarray):
201
+ if input_data.shape[1] != len(self.feature_columns):
202
+ raise ValueError(f"輸入特徵數量不匹配。期望: {len(self.feature_columns)}, 實際: {input_data.shape[1]}")
203
+ input_df = pd.DataFrame(input_data, columns=self.feature_columns)
204
+ elif isinstance(input_data, pd.DataFrame):
205
+ input_df = input_data.copy()
206
+ else:
207
+ raise ValueError("輸入數據必須是 DataFrame 或 numpy array")
208
+
209
+ # 確保所有必需的特徵都存在
210
+ missing_features = [col for col in self.feature_columns if col not in input_df.columns]
211
  if missing_features:
212
+ raise ValueError(f"缺少必要的特徵欄位: {missing_features}")
 
 
 
213
 
214
+ # 選擇並排序特徵
215
+ input_features = input_df[self.feature_columns]
216
 
217
+ # 檢查 NaN 值
218
+ if input_features.isnull().any().any():
219
+ print("警告:輸入數據包含 NaN 值,將用 0 填補")
220
+ input_features = input_features.fillna(0)
221
 
222
+ # 應用標準化(如果有的話)
223
  if self.scaler is not None:
224
+ input_features_scaled = self.scaler.transform(input_features)
225
+ else:
226
+ input_features_scaled = input_features.values
227
+
228
+ # 進行預測
229
+ predictions = self.model.predict(input_features_scaled)
230
+
231
+ # 處理預測結果的維度
232
+ if predictions.ndim == 1:
233
+ # 如果是單一樣本的預測,reshape 成 (1, 4)
234
+ if len(predictions) == 4:
235
+ predictions = predictions.reshape(1, -1)
236
+ else:
237
+ raise ValueError(f"預測結果維度不正確: {predictions.shape}")
238
+
239
+ # 確保結果是 (n_samples, 4) 的形狀
240
+ if predictions.shape[1] != 4:
241
+ raise ValueError(f"模型預測輸出維度錯誤,期望 4 個輸出,實際: {predictions.shape[1]}")
242
+
243
+ # 構建預測結果字典(取第一個樣本的預測)
244
+ result = {}
245
+ prediction_keys = ['Change_pct_t1_pred', 'Change_pct_t5_pred', 'Change_pct_t10_pred', 'Change_pct_t20_pred']
246
+
247
+ for i, key in enumerate(prediction_keys):
248
+ result[key] = float(predictions[0, i]) # 取第一個樣本的第 i 個預測
249
+
250
+ return result
251
 
252
  except Exception as e:
253
+ print(f"預測過程中發生錯誤: {e}")
254
+ raise
255
 
256
+ def predict_single_timeframe(self, stock_data, days, news_score=0.0, us_market_data=None):
257
  """
258
+ 預測單一時間框架的漲幅
259
 
260
  Args:
261
+ stock_data: 股票歷史數據 (yfinance格式)
262
+ days: 預測天數 (1, 5, 10, 20)
263
+ news_score: 新聞情緒分數
264
+ us_market_data: 美股市場數據 (可選)
265
+
266
  Returns:
267
+ float: 預測的漲幅百分比
268
  """
269
  try:
270
+ # 創建特徵
271
+ processed_df = self.create_features_from_stock_data(stock_data)
 
 
 
272
 
273
+ # 使用最新的數據點
274
+ latest_data = processed_df.iloc[-1:].copy()
 
 
275
 
276
+ # 更新新聞分數
277
+ latest_data.loc[latest_data.index[0], 'NEWS'] = news_score
278
 
279
+ # 更新美股數據(如果提供)
280
+ if us_market_data:
281
+ if 'DJI' in us_market_data and len(us_market_data) > 1:
282
+ dji_return = (us_market_data['DJI'][-1] - us_market_data['DJI'][-2]) / us_market_data['DJI'][-2]
283
+ latest_data.loc[latest_data.index[0], 'dji_return_t-1'] = dji_return
284
+
285
+ if 'SOX' in us_market_data and len(us_market_data) > 1:
286
+ sox_return = (us_market_data['SOX'][-1] - us_market_data['SOX'][-2]) / us_market_data['SOX'][-2]
287
+ latest_data.loc[latest_data.index[0], 'sox_return_t-1'] = sox_return
288
 
289
+ # 進行預測
290
+ predictions = self.predict('xgboost_model', latest_data)
291
+
292
+ # 根據天數返回對應的預測值
293
+ if days == 1:
294
+ return predictions['Change_pct_t1_pred']
295
+ elif days == 5:
296
+ return predictions['Change_pct_t5_pred']
297
+ elif days == 10:
298
+ return predictions['Change_pct_t10_pred']
299
+ elif days == 20:
300
+ return predictions['Change_pct_t20_pred']
301
  else:
302
+ # 對於其他天數,使用最接近的預測值
303
+ if days <= 3:
304
+ return predictions['Change_pct_t1_pred']
305
+ elif days <= 7:
306
+ return predictions['Change_pct_t5_pred']
307
+ elif days <= 15:
308
+ return predictions['Change_pct_t10_pred']
309
+ else:
310
+ return predictions['Change_pct_t20_pred']
311
+
 
 
 
 
 
 
 
312
  except Exception as e:
313
+ print(f"單一時間框架預測失敗: {e}")
314
+ return 0.0
 
 
315
 
316
+ def validate_input_features(self, input_data):
317
  """
318
+ 驗證輸入特徵的完整性和有效性
319
 
320
  Args:
321
+ input_data: 輸入的特徵數據
322
+
323
+ Returns:
324
+ dict: 驗證結果
325
+ """
326
+ validation_result = {
327
+ 'is_valid': True,
328
+ 'missing_features': [],
329
+ 'invalid_values': [],
330
+ 'warnings': []
331
+ }
332
+
333
+ try:
334
+ if isinstance(input_data, np.ndarray):
335
+ if input_data.shape[1] != len(self.feature_columns):
336
+ validation_result['is_valid'] = False
337
+ validation_result['warnings'].append(f"特徵數量不匹配: 期望{len(self.feature_columns)}, 實際{input_data.shape[1]}")
338
+ return validation_result
339
+
340
+ # 檢查缺失特徵
341
+ if isinstance(input_data, pd.DataFrame):
342
+ missing_features = [col for col in self.feature_columns if col not in input_data.columns]
343
+ if missing_features:
344
+ validation_result['missing_features'] = missing_features
345
+ validation_result['is_valid'] = False
346
+
347
+ # 檢查數值有效性
348
+ for feature in self.feature_columns:
349
+ if feature in input_data.columns:
350
+ if input_data[feature].isnull().any():
351
+ validation_result['invalid_values'].append(f"{feature}: 包含NaN值")
352
+
353
+ if np.isinf(input_data[feature]).any():
354
+ validation_result['invalid_values'].append(f"{feature}: 包含無限值")
355
+
356
+ return validation_result
357
+
358
+ except Exception as e:
359
+ validation_result['is_valid'] = False
360
+ validation_result['warnings'].append(f"驗證過程出錯: {e}")
361
+ return validation_result
362
+
363
+ def get_feature_importance(self):
364
+ """
365
+ 獲取模型的特徵重要性
366
 
367
  Returns:
368
+ dict: 特徵重要性字典
369
  """
370
+ if not self.is_model_loaded:
371
+ return {}
372
+
373
  try:
374
+ importance_scores = self.model.feature_importances_
375
+ importance_dict = {}
 
376
 
377
+ for i, feature in enumerate(self.feature_columns):
378
+ importance_dict[feature] = float(importance_scores[i])
379
+
380
+ # 按重要性排序
381
+ sorted_importance = dict(sorted(importance_dict.items(), key=lambda x: x[1], reverse=True))
382
+
383
+ return sorted_importance
384
 
 
 
 
 
 
 
385
  except Exception as e:
386
+ print(f"獲取特徵重要性失敗: {e}")
387
+ return {}
388
 
389
+ def get_prediction_confidence(self, input_data):
390
  """
391
+ 估算預測信心度
392
 
393
  Args:
394
+ input_data: 輸入特徵數據
395
 
396
  Returns:
397
+ float: 信心度分數 (0-1)
398
  """
399
  try:
400
+ # 基礎信心度檢查
401
+ validation_result = self.validate_input_features(input_data)
 
 
 
 
 
 
 
402
 
403
+ if not validation_result['is_valid']:
404
+ return 0.3 # 數據有問題時給予較低信心度
405
 
406
+ # 根據特徵完整性調整信心度
407
+ base_confidence = 0.7
408
 
409
+ if validation_result['missing_features']:
410
+ base_confidence -= len(validation_result['missing_features']) * 0.05
 
 
 
 
411
 
412
+ if validation_result['invalid_values']:
413
+ base_confidence -= len(validation_result['invalid_values']) * 0.05
414
 
415
+ return max(0.3, min(0.9, base_confidence))
416
 
417
  except Exception as e:
418
+ print(f"計算預測信心度失敗: {e}")
419
  return 0.5
420
 
421
  def validate_input(self, input_df):