AlanRex commited on
Commit
595c0d3
·
verified ·
1 Parent(s): 43505ce

Upload 2 files

Browse files
Files changed (2) hide show
  1. model_predictor (1).py +425 -0
  2. 原本app.py +0 -0
model_predictor (1).py ADDED
@@ -0,0 +1,425 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model_predictor.py - 支援漲幅百分比輸出的XGBoost模型預測器
2
+ # 修改版本:輸出改為漲幅百分比而非絕對價格
3
+
4
+ import os
5
+ import pandas as pd
6
+ import numpy as np
7
+ import xgboost as xgb
8
+ from sklearn.preprocessing import StandardScaler
9
+ import pickle
10
+ import joblib
11
+
12
+ class XGBoostModel:
13
+ def __init__(self):
14
+ """
15
+ 初始化 XGBoost 模型預測器
16
+
17
+ 【重要更新】
18
+ - 模型現在輸出漲幅百分比而非絕對價格
19
+ - 支援 1日、5日、10日、20日的漲幅預測
20
+ """
21
+ self.model = None
22
+ self.scaler = None
23
+ self.feature_columns = [
24
+ 'close', # 前一日收盤價
25
+ 'return_t-1', # 前一日報酬率
26
+ 'return_t-5', # 過去 5 日累積報酬率
27
+ 'MA5_close', # 5 日移動平均價
28
+ 'volatility_5d', # 5 日報酬標準差
29
+ 'volume_ratio_5d', # 今日成交量 ÷ 5 日均量
30
+ 'MACD_diff', # MACD - signal
31
+ 'dji_return_t-1', # 前一日道瓊指數報酬率
32
+ 'sox_return_t-1', # 前一日費半指數報酬率
33
+ 'NEWS' # 新聞情緒分數
34
+ ]
35
+
36
+ # 【新增】輸出目標對應表
37
+ self.output_targets = {
38
+ 1: 'Change_pct_t1_pred', # 1天後漲幅%
39
+ 5: 'Change_pct_t5_pred', # 5天後漲幅%
40
+ 10: 'Change_pct_t10_pred', # 10天後漲幅%
41
+ 20: 'Change_pct_t20_pred' # 20天後漲幅%
42
+ }
43
+
44
+ print("XGBoost 模型預測器初始化完成")
45
+ print("輸出格式:漲幅百分比 (1日, 5日, 10日, 20日)")
46
+
47
+ def load_model(self, model_path):
48
+ """
49
+ 載入預訓練的 XGBoost 模型
50
+
51
+ Args:
52
+ model_path (str): 模型檔案路徑 (.json 格式)
53
+
54
+ Returns:
55
+ bool: 是否成功載入
56
+ """
57
+ try:
58
+ # 檢查模型檔案是否存在
59
+ if not os.path.exists(model_path):
60
+ print(f"錯誤:找不到模型檔案 {model_path}")
61
+ return False
62
+
63
+ # 載入 XGBoost 模型
64
+ self.model = xgb.XGBRegressor()
65
+ self.model.load_model(model_path)
66
+
67
+ print(f"成功載入模型:{model_path}")
68
+ print(f"預期特徵數量:{len(self.feature_columns)}")
69
+
70
+ return True
71
+
72
+ except Exception as e:
73
+ print(f"載入模型時發生錯誤:{e}")
74
+ return False
75
+
76
+ def load_scaler(self, scaler_path):
77
+ """
78
+ 載入特徵標準化器
79
+
80
+ Args:
81
+ scaler_path (str): 標準化器檔案路徑 (.pkl 格式)
82
+
83
+ Returns:
84
+ bool: 是否成功載入
85
+ """
86
+ try:
87
+ if os.path.exists(scaler_path):
88
+ self.scaler = joblib.load(scaler_path)
89
+ print(f"成功載入標準化器:{scaler_path}")
90
+ return True
91
+ else:
92
+ print(f"警告:找不到標準化器檔案 {scaler_path}")
93
+ print("將使用預設標準化器")
94
+ self.scaler = StandardScaler()
95
+ return False
96
+
97
+ except Exception as e:
98
+ print(f"載入標準化器時發生錯誤:{e}")
99
+ self.scaler = StandardScaler()
100
+ return False
101
+
102
+ def preprocess_features(self, input_df):
103
+ """
104
+ 預處理輸入特徵
105
+
106
+ Args:
107
+ input_df (pd.DataFrame): 輸入特徵 DataFrame
108
+
109
+ Returns:
110
+ pd.DataFrame: 預處理後的特徵
111
+ """
112
+ try:
113
+ # 確保輸入包含所有必要特徵
114
+ missing_features = [f for f in self.feature_columns if f not in input_df.columns]
115
+ if missing_features:
116
+ print(f"警告:缺少以下特徵:{missing_features}")
117
+ # 用 0 填補缺少的特徵
118
+ for feature in missing_features:
119
+ input_df[feature] = 0
120
+
121
+ # 按照預期順序重新排列特徵
122
+ input_df = input_df[self.feature_columns]
123
+
124
+ # 處理 NaN 值
125
+ input_df = input_df.fillna(0)
126
+
127
+ # 如果有標準化器,進行標準化
128
+ if self.scaler is not None:
129
+ try:
130
+ # 嘗試使用已訓練的標準化器
131
+ scaled_features = self.scaler.transform(input_df)
132
+ input_df = pd.DataFrame(scaled_features,
133
+ columns=input_df.columns,
134
+ index=input_df.index)
135
+ except Exception as scaler_error:
136
+ print(f"標準化過程發生錯誤:{scaler_error}")
137
+ print("跳過標準化步驟")
138
+
139
+ return input_df
140
+
141
+ except Exception as e:
142
+ print(f"特徵預處理時發生錯誤:{e}")
143
+ return input_df
144
+
145
+ def predict(self, model_name, input_df):
146
+ """
147
+ 進行股價漲幅預測
148
+
149
+ Args:
150
+ model_name (str): 模型名稱(用於載入對應模型)
151
+ input_df (pd.DataFrame): 輸入特徵
152
+
153
+ Returns:
154
+ dict: 預測結果,包含各時間點的漲幅百分比
155
+ """
156
+ try:
157
+ # 載入模型(如果尚未載入)
158
+ if self.model is None:
159
+ model_path = f"{model_name}.json"
160
+ if not self.load_model(model_path):
161
+ return None
162
+
163
+ # 載入標準化器(如果存在)
164
+ if self.scaler is None:
165
+ scaler_path = f"{model_name}_scaler.pkl"
166
+ self.load_scaler(scaler_path)
167
+
168
+ # 預處理特徵
169
+ processed_df = self.preprocess_features(input_df.copy())
170
+
171
+ # 進行預測
172
+ predictions = self.model.predict(processed_df)
173
+
174
+ # 【重要修改】將預測結果格式化為漲幅百分比
175
+ if predictions.ndim == 1:
176
+ # 如果只有一個輸出,假設是 1 日預測
177
+ result = {
178
+ 'Change_pct_t1_pred': float(predictions[0])
179
+ }
180
+ else:
181
+ # 多輸出情況:1日, 5日, 10日, 20日
182
+ result = {
183
+ 'Change_pct_t1_pred': float(predictions[0][0]) if len(predictions[0]) > 0 else 0.0,
184
+ 'Change_pct_t5_pred': float(predictions[0][1]) if len(predictions[0]) > 1 else 0.0,
185
+ 'Change_pct_t10_pred': float(predictions[0][2]) if len(predictions[0]) > 2 else 0.0,
186
+ 'Change_pct_t20_pred': float(predictions[0][3]) if len(predictions[0]) > 3 else 0.0
187
+ }
188
+
189
+ # 輸出預測結果摘要
190
+ print("=== 漲幅預測結果 ===")
191
+ for key, value in result.items():
192
+ days = key.split('_')[2][1:] # 提取天數
193
+ direction = "上漲" if value > 0 else "下跌"
194
+ print(f" {days}日後預測: {value:+.2f}% ({direction})")
195
+
196
+ return result
197
+
198
+ except Exception as e:
199
+ print(f"預測過程中發生錯誤:{e}")
200
+ import traceback
201
+ traceback.print_exc()
202
+ return None
203
+
204
+ def predict_single_timeframe(self, model_name, input_df, days):
205
+ """
206
+ 預測特定時間框架的漲幅
207
+
208
+ Args:
209
+ model_name (str): 模型名稱
210
+ input_df (pd.DataFrame): 輸入特徵
211
+ days (int): 預測天數 (1, 5, 10, 20)
212
+
213
+ Returns:
214
+ float: 預測的漲幅百分比
215
+ """
216
+ try:
217
+ predictions = self.predict(model_name, input_df)
218
+ if predictions is None:
219
+ return None
220
+
221
+ # 根據天數選擇對應的預測結果
222
+ target_key = f'Change_pct_t{days}_pred'
223
+
224
+ if target_key in predictions:
225
+ return predictions[target_key]
226
+ else:
227
+ print(f"警告:找不到 {days} 日預測結果")
228
+ return None
229
+
230
+ except Exception as e:
231
+ print(f"單一時間框架預測時發生錯誤:{e}")
232
+ return None
233
+
234
+ def get_prediction_confidence(self, input_df):
235
+ """
236
+ 評估預測的信心度
237
+
238
+ Args:
239
+ input_df (pd.DataFrame): 輸入特徵
240
+
241
+ Returns:
242
+ float: 信心度 (0-1)
243
+ """
244
+ try:
245
+ # 基於特徵完整性和質量評估信心度
246
+ feature_completeness = 0
247
+ total_features = len(self.feature_columns)
248
+
249
+ for feature in self.feature_columns:
250
+ if feature in input_df.columns:
251
+ value = input_df[feature].iloc[0]
252
+ if not pd.isna(value) and value != 0:
253
+ feature_completeness += 1
254
+
255
+ completeness_ratio = feature_completeness / total_features
256
+
257
+ # 基於數據質量調整信心度
258
+ base_confidence = max(0.5, completeness_ratio)
259
+
260
+ # 如果重要特徵缺失,降低信心度
261
+ important_features = ['close', 'return_t-1', 'MA5_close']
262
+ missing_important = 0
263
+ for feature in important_features:
264
+ if feature not in input_df.columns or pd.isna(input_df[feature].iloc[0]):
265
+ missing_important += 1
266
+
267
+ if missing_important > 0:
268
+ base_confidence *= (1 - missing_important * 0.1)
269
+
270
+ return min(0.9, max(0.3, base_confidence))
271
+
272
+ except Exception as e:
273
+ print(f"計算信心度時發生錯誤:{e}")
274
+ return 0.5
275
+
276
+ def validate_input(self, input_df):
277
+ """
278
+ 驗證輸入數據的有效性
279
+
280
+ Args:
281
+ input_df (pd.DataFrame): 輸入特徵
282
+
283
+ Returns:
284
+ tuple: (是否有效, 錯誤訊息列表)
285
+ """
286
+ errors = []
287
+
288
+ try:
289
+ # 檢查是否為空
290
+ if input_df.empty:
291
+ errors.append("輸入數據為空")
292
+
293
+ # 檢查必要特徵
294
+ required_features = ['close', 'return_t-1']
295
+ for feature in required_features:
296
+ if feature not in input_df.columns:
297
+ errors.append(f"缺少必要特徵:{feature}")
298
+ elif pd.isna(input_df[feature].iloc[0]):
299
+ errors.append(f"必要特徵包含空值:{feature}")
300
+
301
+ # 檢查數據合理性
302
+ if 'close' in input_df.columns:
303
+ close_price = input_df['close'].iloc[0]
304
+ if close_price <= 0:
305
+ errors.append(f"收盤價不合理:{close_price}")
306
+
307
+ if 'return_t-1' in input_df.columns:
308
+ return_val = input_df['return_t-1'].iloc[0]
309
+ if abs(return_val) > 0.5: # 單日漲跌幅超過50%可能有問題
310
+ errors.append(f"報酬率異常:{return_val:.3f}")
311
+
312
+ return len(errors) == 0, errors
313
+
314
+ except Exception as e:
315
+ errors.append(f"驗證過程發生錯誤:{e}")
316
+ return False, errors
317
+
318
+ def get_feature_importance(self):
319
+ """
320
+ 獲取特徵重要性
321
+
322
+ Returns:
323
+ dict: 特徵重要性字典
324
+ """
325
+ try:
326
+ if self.model is None:
327
+ return None
328
+
329
+ # 獲取特徵重要性
330
+ importance_scores = self.model.feature_importances_
331
+
332
+ # 創建特徵重要性字典
333
+ importance_dict = {}
334
+ for i, feature in enumerate(self.feature_columns):
335
+ if i < len(importance_scores):
336
+ importance_dict[feature] = float(importance_scores[i])
337
+
338
+ # 按重要性排序
339
+ sorted_importance = dict(sorted(importance_dict.items(),
340
+ key=lambda x: x[1],
341
+ reverse=True))
342
+
343
+ return sorted_importance
344
+
345
+ except Exception as e:
346
+ print(f"獲取特徵重要性時發生錯誤:{e}")
347
+ return None
348
+
349
+ def explain_prediction(self, input_df, predictions):
350
+ """
351
+ 解釋預測結果
352
+
353
+ Args:
354
+ input_df (pd.DataFrame): 輸入特徵
355
+ predictions (dict): 預測結果
356
+
357
+ Returns:
358
+ str: 解釋文本
359
+ """
360
+ try:
361
+ explanation = []
362
+ explanation.append("=== 預測解釋 ===")
363
+
364
+ # 分析主要驅動因素
365
+ feature_importance = self.get_feature_importance()
366
+ if feature_importance:
367
+ explanation.append("主要影響因素:")
368
+ top_features = list(feature_importance.keys())[:3]
369
+ for feature in top_features:
370
+ if feature in input_df.columns:
371
+ value = input_df[feature].iloc[0]
372
+ importance = feature_importance[feature]
373
+ explanation.append(f" - {feature}: {value:.4f} (重要性: {importance:.3f})")
374
+
375
+ # 分析預測趨勢
376
+ explanation.append("\n預測趨勢分析:")
377
+ for key, value in predictions.items():
378
+ days = key.split('_')[2][1:]
379
+ trend = "看漲" if value > 1 else "看跌" if value < -1 else "持平"
380
+ explanation.append(f" - {days}日: {value:+.2f}% ({trend})")
381
+
382
+ return "\n".join(explanation)
383
+
384
+ except Exception as e:
385
+ return f"解釋生成失敗: {e}"
386
+
387
+ # 範例使用方式
388
+ if __name__ == "__main__":
389
+ # 初始化模型
390
+ model = XGBoostModel()
391
+
392
+ # 準備測試數據
393
+ test_data = pd.DataFrame({
394
+ 'close': [150.0],
395
+ 'return_t-1': [0.02],
396
+ 'return_t-5': [0.05],
397
+ 'MA5_close': [148.0],
398
+ 'volatility_5d': [0.025],
399
+ 'volume_ratio_5d': [1.2],
400
+ 'MACD_diff': [0.5],
401
+ 'dji_return_t-1': [0.01],
402
+ 'sox_return_t-1': [0.015],
403
+ 'NEWS': [0.1]
404
+ })
405
+
406
+ print("測試模型預測器...")
407
+ print("輸入特徵:")
408
+ print(test_data)
409
+
410
+ # 進行預測
411
+ predictions = model.predict('xgboost_model', test_data)
412
+
413
+ if predictions:
414
+ print("\n預測成功!")
415
+ print("結果說明:輸出為相對於當前價���的漲幅百分比")
416
+
417
+ # 解釋預測
418
+ explanation = model.explain_prediction(test_data, predictions)
419
+ print(f"\n{explanation}")
420
+
421
+ # 計算信心度
422
+ confidence = model.get_prediction_confidence(test_data)
423
+ print(f"\n預測信心度: {confidence:.2%}")
424
+ else:
425
+ print("預測失敗!")
原本app.py ADDED
The diff for this file is too large to render. See raw diff