AlanRex commited on
Commit
4edf1bc
·
verified ·
1 Parent(s): 595c0d3

Delete model_predictor.py

Browse files
Files changed (1) hide show
  1. model_predictor.py +0 -308
model_predictor.py DELETED
@@ -1,308 +0,0 @@
1
- # model_predictor.py - 支援漲幅百分比輸出的XGBoost模型預測器
2
- # 修改版本:輸出改為漲幅百分比而非絕對價格
3
-
4
- import os
5
- import pandas as pd
6
- import numpy as np
7
- import xgboost as xgb
8
- from sklearn.preprocessing import StandardScaler
9
- import pickle
10
- import joblib
11
-
12
- class XGBoostModel:
13
- """
14
- 用於載入和使用預先訓練好的 XGBoost 模型的類別。
15
- """
16
- # 使用類別變數儲存所有可用的模型名稱及其對應的檔案名稱
17
- MODELS = {
18
- 'xgboost_model': 'xgboost_model.json'
19
- }
20
-
21
- def __init__(self, default_model='xgboost_model'):
22
- """
23
- 初始化時自動載入預設模型。
24
- """
25
- self.current_model_name = default_model
26
- self.model = self._load_model(self.current_model_name)
27
-
28
- def _load_model(self, model_name):
29
- """
30
- 從檔案載入 XGBoost 模型。
31
- """
32
- if model_name not in self.MODELS:
33
- raise ValueError(f"找不到模型 '{model_name}'。可用的模型名稱:{list(self.MODELS.keys())}")
34
-
35
- filename = self.MODELS[model_name]
36
- try:
37
- # 建立一個新的 XGBoost 模型實例
38
- model = xgb.XGBRegressor()
39
- # 使用 XGBoost 內建的 load_model 方法載入檔案
40
- model.load_model(filename)
41
- print(f"成功載入模型檔案: {filename}")
42
- return model
43
- except Exception as e:
44
- raise FileNotFoundError(f"無法在本地找到或載入模型檔案 '{filename}':{e}")
45
-
46
- def predict(self, model_name, input_df):
47
- """
48
- 使用載入的模型進行預測。
49
-
50
- Args:
51
- model_name (str): 要使用的模型名稱。
52
- input_df (pd.DataFrame): 包含特徵數據的 DataFrame,應只有一筆紀錄。
53
-
54
- Returns:
55
- dict: 包含四個預測目標的預測結果字典。
56
- {'Change_pct_t1_pred': float, 'Change_pct_t5_pred': float, ...}
57
- """
58
- # 如果請求的模型名稱與目前載入的不同,則動態載入
59
- if model_name != self.current_model_name:
60
- self.model = self._load_model(model_name)
61
- self.current_model_name = model_name
62
-
63
- # 進行預測
64
- # model.predict 會回傳一個 numpy 陣列,形狀為 (n_samples, n_targets)
65
- # 在我們的案例中,n_samples=1, n_targets=4
66
- predictions = self.model.predict(input_df)
67
-
68
- # 【【核心修正】】
69
- # 您的模型是多輸出模型,預測結果是一個包含4個值的陣列。
70
- # 我們需要將這個陣列轉換為一個包含各預測目標的字典,以便 app.py 使用。
71
- # predictions[0] 會取得第一筆樣本的所有預測值 (一個有4個元素的陣列)
72
- if predictions.ndim == 2 and predictions.shape[0] > 0:
73
- pred_values = predictions[0]
74
- elif predictions.ndim == 1:
75
- pred_values = predictions
76
- else:
77
- raise ValueError("預測結果的格式不符合預期。")
78
-
79
- result = {
80
- 'Change_pct_t1_pred': pred_values[0],
81
- 'Change_pct_t5_pred': pred_values[1],
82
- 'Change_pct_t10_pred': pred_values[2],
83
- 'Change_pct_t20_pred': pred_values[3]
84
- }
85
- return result
86
-
87
- def predict_single_timeframe(self, model_name, input_df, days):
88
- """
89
- 預測特定時間框架的漲幅
90
-
91
- Args:
92
- model_name (str): 模型名稱
93
- input_df (pd.DataFrame): 輸入特徵
94
- days (int): 預測天數 (1, 5, 10, 20)
95
-
96
- Returns:
97
- float: 預測的漲幅百分比
98
- """
99
- try:
100
- predictions = self.predict(model_name, input_df)
101
- if predictions is None:
102
- return None
103
-
104
- # 根據天數選擇對應的預測結果
105
- target_key = f'Change_pct_t{days}_pred'
106
-
107
- if target_key in predictions:
108
- return predictions[target_key]
109
- else:
110
- print(f"警告:找不到 {days} 日預測結果")
111
- return None
112
-
113
- except Exception as e:
114
- print(f"單一時間框架預測時發生錯誤:{e}")
115
- return None
116
-
117
- def get_prediction_confidence(self, input_df):
118
- """
119
- 評估預測的信心度
120
-
121
- Args:
122
- input_df (pd.DataFrame): 輸入特徵
123
-
124
- Returns:
125
- float: 信心度 (0-1)
126
- """
127
- try:
128
- # 基於特徵完整性和質量評估信心度
129
- feature_completeness = 0
130
- total_features = len(self.feature_columns)
131
-
132
- for feature in self.feature_columns:
133
- if feature in input_df.columns:
134
- value = input_df[feature].iloc[0]
135
- if not pd.isna(value) and value != 0:
136
- feature_completeness += 1
137
-
138
- completeness_ratio = feature_completeness / total_features
139
-
140
- # 基於數據質量調整信心度
141
- base_confidence = max(0.5, completeness_ratio)
142
-
143
- # 如果重要特徵缺失,降低信心度
144
- important_features = ['close', 'return_t-1', 'MA5_close']
145
- missing_important = 0
146
- for feature in important_features:
147
- if feature not in input_df.columns or pd.isna(input_df[feature].iloc[0]):
148
- missing_important += 1
149
-
150
- if missing_important > 0:
151
- base_confidence *= (1 - missing_important * 0.1)
152
-
153
- return min(0.9, max(0.3, base_confidence))
154
-
155
- except Exception as e:
156
- print(f"計算信心度時發生錯誤:{e}")
157
- return 0.5
158
-
159
- def validate_input(self, input_df):
160
- """
161
- 驗證輸入數據的有效性
162
-
163
- Args:
164
- input_df (pd.DataFrame): 輸入特徵
165
-
166
- Returns:
167
- tuple: (是否有效, 錯誤訊息列表)
168
- """
169
- errors = []
170
-
171
- try:
172
- # 檢查是否為空
173
- if input_df.empty:
174
- errors.append("輸入數據為空")
175
-
176
- # 檢查必要特徵
177
- required_features = ['close', 'return_t-1']
178
- for feature in required_features:
179
- if feature not in input_df.columns:
180
- errors.append(f"缺少必要特徵:{feature}")
181
- elif pd.isna(input_df[feature].iloc[0]):
182
- errors.append(f"必要特徵包含空值:{feature}")
183
-
184
- # 檢查數據合理性
185
- if 'close' in input_df.columns:
186
- close_price = input_df['close'].iloc[0]
187
- if close_price <= 0:
188
- errors.append(f"收盤價不合理:{close_price}")
189
-
190
- if 'return_t-1' in input_df.columns:
191
- return_val = input_df['return_t-1'].iloc[0]
192
- if abs(return_val) > 0.5: # 單日漲跌幅超過50%可能有問題
193
- errors.append(f"報酬率異常:{return_val:.3f}")
194
-
195
- return len(errors) == 0, errors
196
-
197
- except Exception as e:
198
- errors.append(f"驗證過程發生錯誤:{e}")
199
- return False, errors
200
-
201
- def get_feature_importance(self):
202
- """
203
- 獲取特徵重要性
204
-
205
- Returns:
206
- dict: 特徵重要性字典
207
- """
208
- try:
209
- if self.model is None:
210
- return None
211
-
212
- # 獲取特徵重要性
213
- importance_scores = self.model.feature_importances_
214
-
215
- # 創建特徵重要性字典
216
- importance_dict = {}
217
- for i, feature in enumerate(self.feature_columns):
218
- if i < len(importance_scores):
219
- importance_dict[feature] = float(importance_scores[i])
220
-
221
- # 按重要性排序
222
- sorted_importance = dict(sorted(importance_dict.items(),
223
- key=lambda x: x[1],
224
- reverse=True))
225
-
226
- return sorted_importance
227
-
228
- except Exception as e:
229
- print(f"獲取特徵重要性時發生錯誤:{e}")
230
- return None
231
-
232
- def explain_prediction(self, input_df, predictions):
233
- """
234
- 解釋預測結果
235
-
236
- Args:
237
- input_df (pd.DataFrame): 輸入特徵
238
- predictions (dict): 預測結果
239
-
240
- Returns:
241
- str: 解釋文本
242
- """
243
- try:
244
- explanation = []
245
- explanation.append("=== 預測解釋 ===")
246
-
247
- # 分析主要驅動因素
248
- feature_importance = self.get_feature_importance()
249
- if feature_importance:
250
- explanation.append("主要影響因素:")
251
- top_features = list(feature_importance.keys())[:3]
252
- for feature in top_features:
253
- if feature in input_df.columns:
254
- value = input_df[feature].iloc[0]
255
- importance = feature_importance[feature]
256
- explanation.append(f" - {feature}: {value:.4f} (重要性: {importance:.3f})")
257
-
258
- # 分析預測趨勢
259
- explanation.append("\n預測趨勢分析:")
260
- for key, value in predictions.items():
261
- days = key.split('_')[2][1:]
262
- trend = "看漲" if value > 1 else "看跌" if value < -1 else "持平"
263
- explanation.append(f" - {days}日: {value:+.2f}% ({trend})")
264
-
265
- return "\n".join(explanation)
266
-
267
- except Exception as e:
268
- return f"解釋生成失敗: {e}"
269
-
270
- # 範例��用方式
271
- if __name__ == "__main__":
272
- # 初始化模型
273
- model = XGBoostModel()
274
-
275
- # 準備測試數據
276
- test_data = pd.DataFrame({
277
- 'close': [150.0],
278
- 'return_t-1': [0.02],
279
- 'return_t-5': [0.05],
280
- 'MA5_close': [148.0],
281
- 'volatility_5d': [0.025],
282
- 'volume_ratio_5d': [1.2],
283
- 'MACD_diff': [0.5],
284
- 'dji_return_t-1': [0.01],
285
- 'sox_return_t-1': [0.015],
286
- 'NEWS': [0.1]
287
- })
288
-
289
- print("測試模型預測器...")
290
- print("輸入特徵:")
291
- print(test_data)
292
-
293
- # 進行預測
294
- predictions = model.predict('xgboost_model', test_data)
295
-
296
- if predictions:
297
- print("\n預測成功!")
298
- print("結果說明:輸出為相對於當前價格的漲幅百分比")
299
-
300
- # 解釋預測
301
- explanation = model.explain_prediction(test_data, predictions)
302
- print(f"\n{explanation}")
303
-
304
- # 計算信心度
305
- confidence = model.get_prediction_confidence(test_data)
306
- print(f"\n預測信心度: {confidence:.2%}")
307
- else:
308
- print("預測失敗!")