AlanRex commited on
Commit
66a7471
·
verified ·
1 Parent(s): 06e2021

Update model_predictor.py

Browse files
Files changed (1) hide show
  1. model_predictor.py +182 -34
model_predictor.py CHANGED
@@ -1,45 +1,193 @@
1
- import xgboost as xgb
 
 
 
2
  import pandas as pd
 
 
 
 
 
3
 
4
  class XGBoostModel:
5
- # 使用類別變數儲存所有可用的模型名稱及其對應的檔案名稱
6
- MODELS = {
7
- 'xgboost_model': 'xgboost_model.json'
8
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- def __init__(self, default_model='xgboost_model'):
11
- # 建立物件時,自動載入預設模型
12
- self.current_model_name = default_model
13
- self.model = self._load_model(self.current_model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- def _load_model(self, model_name):
16
- if model_name not in self.MODELS:
17
- raise ValueError(f"找不到模型 '{model_name}'。可用的模型名稱:{list(self.MODELS.keys())}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- filename = self.MODELS[model_name]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  try:
21
- # 建立一個新的 XGBoost 模型實例
22
- model = xgb.XGBRegressor()
23
- # 使用 XGBoost 內建的 load_model 方法載入檔案
24
- model.load_model(filename)
25
- return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  except Exception as e:
27
- raise FileNotFoundError(f"無法在本地找到或載入模型檔案 '{filename}':{e}")
 
28
 
29
  def predict(self, model_name, input_df):
30
- # 如果請求的模型名稱與目前載入的不同,則動態載入
31
- if model_name != self.current_model_name:
32
- self.model = self._load_model(model_name)
33
- self.current_model_name = model_name
34
 
35
- # 進行預測
36
- predictions = self.model.predict(input_df)
37
-
38
- # 將預測結果轉換為字典
39
- result = {
40
- 'Close_t0_pred': predictions[0][0],
41
- 'Close_t5_pred': predictions[0][1],
42
- 'Close_t10_pred': predictions[0][2],
43
- 'Close_t20_pred': predictions[0][3]
44
- }
45
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model_predictor.py - 支援漲幅百分比輸出的XGBoost模型預測器
2
+ # 修改版本:輸出改為漲幅百分比而非絕對價格
3
+
4
+ import os
5
  import pandas as pd
6
+ import numpy as np
7
+ import xgboost as xgb
8
+ from sklearn.preprocessing import StandardScaler
9
+ import pickle
10
+ import joblib
11
 
12
  class XGBoostModel:
13
+ def __init__(self):
14
+ """
15
+ 初始化 XGBoost 模型預測器
16
+
17
+ 【重要更新】
18
+ - 模型現在輸出漲幅百分比而非絕對價格
19
+ - 支援 1日、5日、10日、20日的漲幅預測
20
+ """
21
+ self.model = None
22
+ self.scaler = None
23
+ self.feature_columns = [
24
+ 'close', # 前一日收盤價
25
+ 'return_t-1', # 前一日報酬率
26
+ 'return_t-5', # 過去 5 日累積報酬率
27
+ 'MA5_close', # 5 日移動平均價
28
+ 'volatility_5d', # 5 日報酬標準差
29
+ 'volume_ratio_5d', # 今日成交量 ÷ 5 日均量
30
+ 'MACD_diff', # MACD - signal
31
+ 'dji_return_t-1', # 前一日道瓊指數報酬率
32
+ 'sox_return_t-1', # 前一日費半指數報酬率
33
+ 'NEWS' # 新聞情緒分數
34
+ ]
35
+
36
+ # 【新增】輸出目標對應表
37
+ self.output_targets = {
38
+ 1: 'Change_pct_t1_pred', # 1天後漲幅%
39
+ 5: 'Change_pct_t5_pred', # 5天後漲幅%
40
+ 10: 'Change_pct_t10_pred', # 10天後漲幅%
41
+ 20: 'Change_pct_t20_pred' # 20天後漲幅%
42
+ }
43
+
44
+ print("XGBoost 模型預測器初始化完成")
45
+ print("輸出格式:漲幅百分比 (1日, 5日, 10日, 20日)")
46
 
47
+ def load_model(self, model_path):
48
+ """
49
+ 載入預訓練的 XGBoost 模型
50
+
51
+ Args:
52
+ model_path (str): 模型檔案路徑 (.json 格式)
53
+
54
+ Returns:
55
+ bool: 是否成功載入
56
+ """
57
+ try:
58
+ # 檢查模型檔案是否存在
59
+ if not os.path.exists(model_path):
60
+ print(f"錯誤:找不到模型檔案 {model_path}")
61
+ return False
62
+
63
+ # 載入 XGBoost 模型
64
+ self.model = xgb.XGBRegressor()
65
+ self.model.load_model(model_path)
66
+
67
+ print(f"成功載入模型:{model_path}")
68
+ print(f"預期特徵數量:{len(self.feature_columns)}")
69
+
70
+ return True
71
+
72
+ except Exception as e:
73
+ print(f"載入模型時發生錯誤:{e}")
74
+ return False
75
 
76
+ def load_scaler(self, scaler_path):
77
+ """
78
+ 載入特徵標準化器
79
+
80
+ Args:
81
+ scaler_path (str): 標準化器檔案路徑 (.pkl 格式)
82
+
83
+ Returns:
84
+ bool: 是否成功載入
85
+ """
86
+ try:
87
+ if os.path.exists(scaler_path):
88
+ self.scaler = joblib.load(scaler_path)
89
+ print(f"成功載入標準化器:{scaler_path}")
90
+ return True
91
+ else:
92
+ print(f"警告:找不到標準化器檔案 {scaler_path}")
93
+ print("將使用預設標準化器")
94
+ self.scaler = StandardScaler()
95
+ return False
96
 
97
+ except Exception as e:
98
+ print(f"載入標準化器時發生錯誤:{e}")
99
+ self.scaler = StandardScaler()
100
+ return False
101
+
102
+ def preprocess_features(self, input_df):
103
+ """
104
+ 預處理輸入特徵
105
+
106
+ Args:
107
+ input_df (pd.DataFrame): 輸入特徵 DataFrame
108
+
109
+ Returns:
110
+ pd.DataFrame: 預處理後的特徵
111
+ """
112
  try:
113
+ # 確保輸入包含所有必要特徵
114
+ missing_features = [f for f in self.feature_columns if f not in input_df.columns]
115
+ if missing_features:
116
+ print(f"警告:缺少以下特徵:{missing_features}")
117
+ # 用 0 填補缺少的特徵
118
+ for feature in missing_features:
119
+ input_df[feature] = 0
120
+
121
+ # 按照預期順序重新排列特徵
122
+ input_df = input_df[self.feature_columns]
123
+
124
+ # 處理 NaN 值
125
+ input_df = input_df.fillna(0)
126
+
127
+ # 如果有標準化器,進行標準化
128
+ if self.scaler is not None:
129
+ try:
130
+ # 嘗試使用已訓練的標準化器
131
+ scaled_features = self.scaler.transform(input_df)
132
+ input_df = pd.DataFrame(scaled_features,
133
+ columns=input_df.columns,
134
+ index=input_df.index)
135
+ except Exception as scaler_error:
136
+ print(f"標準化過程發生錯誤:{scaler_error}")
137
+ print("跳過標準化步驟")
138
+
139
+ return input_df
140
+
141
  except Exception as e:
142
+ print(f"特徵預處理時發生錯誤:{e}")
143
+ return input_df
144
 
145
  def predict(self, model_name, input_df):
146
+ """
147
+ 進行股價漲幅預測
 
 
148
 
149
+ Args:
150
+ model_name (str): 模型名稱(用於載入對應模型)
151
+ input_df (pd.DataFrame): 輸入特徵
152
+
153
+ Returns:
154
+ dict: 預測結果,包含各時間點的漲幅百分比
155
+ """
156
+ try:
157
+ # 載入模型(如果尚未載入)
158
+ if self.model is None:
159
+ model_path = f"{model_name}.json"
160
+ if not self.load_model(model_path):
161
+ return None
162
+
163
+ # 載入標準化器(如果存在)
164
+ if self.scaler is None:
165
+ scaler_path = f"{model_name}_scaler.pkl"
166
+ self.load_scaler(scaler_path)
167
+
168
+ # 預處理特徵
169
+ processed_df = self.preprocess_features(input_df.copy())
170
+
171
+ # 進行預測
172
+ predictions = self.model.predict(processed_df)
173
+
174
+ # 【重要修改】將預測結果格式化為漲幅百分比
175
+ if predictions.ndim == 1:
176
+ # 如果只有一個輸出,假設是 1 日預測
177
+ result = {
178
+ 'Change_pct_t1_pred': float(predictions[0])
179
+ }
180
+ else:
181
+ # 多輸出情況:1日, 5日, 10日, 20日
182
+ result = {
183
+ 'Change_pct_t1_pred': float(predictions[0][0]) if len(predictions[0]) > 0 else 0.0,
184
+ 'Change_pct_t5_pred': float(predictions[0][1]) if len(predictions[0]) > 1 else 0.0,
185
+ 'Change_pct_t10_pred': float(predictions[0][2]) if len(predictions[0]) > 2 else 0.0,
186
+ 'Change_pct_t20_pred': float(predictions[0][3]) if len(predictions[0]) > 3 else 0.0
187
+ }
188
+
189
+ # 輸出預測結果摘要
190
+ print("=== 漲幅預測結果 ===")
191
+ for key, value in result.items():
192
+ days = key.split('_')[2][1:] # 提取天數
193
+ direction = "↗️ 上漲" if value > 0