AlanRex commited on
Commit
ac0c9bd
·
verified ·
1 Parent(s): 66a7471

Update model_predictor.py

Browse files
Files changed (1) hide show
  1. model_predictor.py +233 -1
model_predictor.py CHANGED
@@ -190,4 +190,236 @@ class XGBoostModel:
190
  print("=== 漲幅預測結果 ===")
191
  for key, value in result.items():
192
  days = key.split('_')[2][1:] # 提取天數
193
- direction = "↗️ 上漲" if value > 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  print("=== 漲幅預測結果 ===")
191
  for key, value in result.items():
192
  days = key.split('_')[2][1:] # 提取天數
193
+ direction = "上漲" if value > 0 else "下跌"
194
+ print(f" {days}日後預測: {value:+.2f}% ({direction})")
195
+
196
+ return result
197
+
198
+ except Exception as e:
199
+ print(f"預測過程中發生錯誤:{e}")
200
+ import traceback
201
+ traceback.print_exc()
202
+ return None
203
+
204
+ def predict_single_timeframe(self, model_name, input_df, days):
205
+ """
206
+ 預測特定時間框架的漲幅
207
+
208
+ Args:
209
+ model_name (str): 模型名稱
210
+ input_df (pd.DataFrame): 輸入特徵
211
+ days (int): 預測天數 (1, 5, 10, 20)
212
+
213
+ Returns:
214
+ float: 預測的漲幅百分比
215
+ """
216
+ try:
217
+ predictions = self.predict(model_name, input_df)
218
+ if predictions is None:
219
+ return None
220
+
221
+ # 根據天數選擇對應的預測結果
222
+ target_key = f'Change_pct_t{days}_pred'
223
+
224
+ if target_key in predictions:
225
+ return predictions[target_key]
226
+ else:
227
+ print(f"警告:找不到 {days} 日預測結果")
228
+ return None
229
+
230
+ except Exception as e:
231
+ print(f"單一時間框架預測時發生錯誤:{e}")
232
+ return None
233
+
234
+ def get_prediction_confidence(self, input_df):
235
+ """
236
+ 評估預測的信心度
237
+
238
+ Args:
239
+ input_df (pd.DataFrame): 輸入特徵
240
+
241
+ Returns:
242
+ float: 信心度 (0-1)
243
+ """
244
+ try:
245
+ # 基於特徵完整性和質量評估信心度
246
+ feature_completeness = 0
247
+ total_features = len(self.feature_columns)
248
+
249
+ for feature in self.feature_columns:
250
+ if feature in input_df.columns:
251
+ value = input_df[feature].iloc[0]
252
+ if not pd.isna(value) and value != 0:
253
+ feature_completeness += 1
254
+
255
+ completeness_ratio = feature_completeness / total_features
256
+
257
+ # 基於數據質量調整信心度
258
+ base_confidence = max(0.5, completeness_ratio)
259
+
260
+ # 如果重要特徵缺失,降低信心度
261
+ important_features = ['close', 'return_t-1', 'MA5_close']
262
+ missing_important = 0
263
+ for feature in important_features:
264
+ if feature not in input_df.columns or pd.isna(input_df[feature].iloc[0]):
265
+ missing_important += 1
266
+
267
+ if missing_important > 0:
268
+ base_confidence *= (1 - missing_important * 0.1)
269
+
270
+ return min(0.9, max(0.3, base_confidence))
271
+
272
+ except Exception as e:
273
+ print(f"計算信心度時發生錯誤:{e}")
274
+ return 0.5
275
+
276
+ def validate_input(self, input_df):
277
+ """
278
+ 驗證輸入數據的有效性
279
+
280
+ Args:
281
+ input_df (pd.DataFrame): 輸入特徵
282
+
283
+ Returns:
284
+ tuple: (是否有效, 錯誤訊息列表)
285
+ """
286
+ errors = []
287
+
288
+ try:
289
+ # 檢查是否為空
290
+ if input_df.empty:
291
+ errors.append("輸入數據為空")
292
+
293
+ # 檢查必要特徵
294
+ required_features = ['close', 'return_t-1']
295
+ for feature in required_features:
296
+ if feature not in input_df.columns:
297
+ errors.append(f"缺少必要特徵:{feature}")
298
+ elif pd.isna(input_df[feature].iloc[0]):
299
+ errors.append(f"必要特徵包含空值:{feature}")
300
+
301
+ # 檢查數據合理性
302
+ if 'close' in input_df.columns:
303
+ close_price = input_df['close'].iloc[0]
304
+ if close_price <= 0:
305
+ errors.append(f"收盤價不合理:{close_price}")
306
+
307
+ if 'return_t-1' in input_df.columns:
308
+ return_val = input_df['return_t-1'].iloc[0]
309
+ if abs(return_val) > 0.5: # 單日漲跌幅超過50%可能有問題
310
+ errors.append(f"報酬率異常:{return_val:.3f}")
311
+
312
+ return len(errors) == 0, errors
313
+
314
+ except Exception as e:
315
+ errors.append(f"驗證過程發生錯誤:{e}")
316
+ return False, errors
317
+
318
+ def get_feature_importance(self):
319
+ """
320
+ 獲取特徵重要性
321
+
322
+ Returns:
323
+ dict: 特徵重要性字典
324
+ """
325
+ try:
326
+ if self.model is None:
327
+ return None
328
+
329
+ # 獲取特徵重要性
330
+ importance_scores = self.model.feature_importances_
331
+
332
+ # 創建特徵重要性字典
333
+ importance_dict = {}
334
+ for i, feature in enumerate(self.feature_columns):
335
+ if i < len(importance_scores):
336
+ importance_dict[feature] = float(importance_scores[i])
337
+
338
+ # 按重要性排序
339
+ sorted_importance = dict(sorted(importance_dict.items(),
340
+ key=lambda x: x[1],
341
+ reverse=True))
342
+
343
+ return sorted_importance
344
+
345
+ except Exception as e:
346
+ print(f"獲取特徵重要性時發生錯誤:{e}")
347
+ return None
348
+
349
+ def explain_prediction(self, input_df, predictions):
350
+ """
351
+ 解釋預測結果
352
+
353
+ Args:
354
+ input_df (pd.DataFrame): 輸入特徵
355
+ predictions (dict): 預測結果
356
+
357
+ Returns:
358
+ str: 解釋文本
359
+ """
360
+ try:
361
+ explanation = []
362
+ explanation.append("=== 預測解釋 ===")
363
+
364
+ # 分析主要驅動因素
365
+ feature_importance = self.get_feature_importance()
366
+ if feature_importance:
367
+ explanation.append("主要影響因素:")
368
+ top_features = list(feature_importance.keys())[:3]
369
+ for feature in top_features:
370
+ if feature in input_df.columns:
371
+ value = input_df[feature].iloc[0]
372
+ importance = feature_importance[feature]
373
+ explanation.append(f" - {feature}: {value:.4f} (重要性: {importance:.3f})")
374
+
375
+ # 分析預測趨勢
376
+ explanation.append("\n預測趨勢分析:")
377
+ for key, value in predictions.items():
378
+ days = key.split('_')[2][1:]
379
+ trend = "看漲" if value > 1 else "看跌" if value < -1 else "持平"
380
+ explanation.append(f" - {days}日: {value:+.2f}% ({trend})")
381
+
382
+ return "\n".join(explanation)
383
+
384
+ except Exception as e:
385
+ return f"解釋生成失敗: {e}"
386
+
387
+ # 範例使用方式
388
+ if __name__ == "__main__":
389
+ # 初始化模型
390
+ model = XGBoostModel()
391
+
392
+ # 準備測試數據
393
+ test_data = pd.DataFrame({
394
+ 'close': [150.0],
395
+ 'return_t-1': [0.02],
396
+ 'return_t-5': [0.05],
397
+ 'MA5_close': [148.0],
398
+ 'volatility_5d': [0.025],
399
+ 'volume_ratio_5d': [1.2],
400
+ 'MACD_diff': [0.5],
401
+ 'dji_return_t-1': [0.01],
402
+ 'sox_return_t-1': [0.015],
403
+ 'NEWS': [0.1]
404
+ })
405
+
406
+ print("測試模型預測器...")
407
+ print("輸入特徵:")
408
+ print(test_data)
409
+
410
+ # 進行預測
411
+ predictions = model.predict('xgboost_model_v1_3_percentage_output', test_data)
412
+
413
+ if predictions:
414
+ print("\n預測成功!")
415
+ print("結果說明:輸出為相對於當前價格的漲幅百分比")
416
+
417
+ # 解釋預測
418
+ explanation = model.explain_prediction(test_data, predictions)
419
+ print(f"\n{explanation}")
420
+
421
+ # 計算信心度
422
+ confidence = model.get_prediction_confidence(test_data)
423
+ print(f"\n預測信心度: {confidence:.2%}")
424
+ else:
425
+ print("預測失敗!")