AlanRex commited on
Commit
43505ce
·
verified ·
1 Parent(s): 996c927

Update model_predictor.py

Browse files
Files changed (1) hide show
  1. model_predictor.py +58 -175
model_predictor.py CHANGED
@@ -10,196 +10,79 @@ import pickle
10
  import joblib
11
 
12
  class XGBoostModel:
13
- def __init__(self):
14
- """
15
- 初始化 XGBoost 模型預測器
16
-
17
- 【重要更新】
18
- - 模型現在輸出漲幅百分比而非絕對價格
19
- - 支援 1日、5日、10日、20日的漲幅預測
20
- """
21
- self.model = None
22
- self.scaler = None
23
- self.feature_columns = [
24
- 'close', # 前一日收盤價
25
- 'return_t-1', # 前一日報酬率
26
- 'return_t-5', # 過去 5 日累積報酬率
27
- 'MA5_close', # 5 日移動平均價
28
- 'volatility_5d', # 5 日報酬標準差
29
- 'volume_ratio_5d', # 今日成交量 ÷ 5 日均量
30
- 'MACD_diff', # MACD - signal
31
- 'dji_return_t-1', # 前一日道瓊指數報酬率
32
- 'sox_return_t-1', # 前一日費半指數報酬率
33
- 'NEWS' # 新聞情緒分數
34
- ]
35
-
36
- # 【新增】輸出目標對應表
37
- self.output_targets = {
38
- 1: 'Change_pct_t1_pred', # 1天後漲幅%
39
- 5: 'Change_pct_t5_pred', # 5天後漲幅%
40
- 10: 'Change_pct_t10_pred', # 10天後漲幅%
41
- 20: 'Change_pct_t20_pred' # 20天後漲幅%
42
- }
43
-
44
- print("XGBoost 模型預測器初始化完成")
45
- print("輸出格式:漲幅百分比 (1日, 5日, 10日, 20日)")
46
 
47
- def load_model(self, model_path):
48
  """
49
- 載入預訓練的 XGBoost 模型
50
-
51
- Args:
52
- model_path (str): 模型檔案路徑 (.json 格式)
53
-
54
- Returns:
55
- bool: 是否成功載入
56
  """
57
- try:
58
- # 檢查模型檔案是否存在
59
- if not os.path.exists(model_path):
60
- print(f"錯誤:找不到模型檔案 {model_path}")
61
- return False
62
-
63
- # 載入 XGBoost 模型
64
- self.model = xgb.XGBRegressor()
65
- self.model.load_model(model_path)
66
-
67
- print(f"成功載入模型:{model_path}")
68
- print(f"預期特徵數量:{len(self.feature_columns)}")
69
-
70
- return True
71
-
72
- except Exception as e:
73
- print(f"載入模型時發生錯誤:{e}")
74
- return False
75
 
76
- def load_scaler(self, scaler_path):
77
  """
78
- 載入特徵標準化器
79
-
80
- Args:
81
- scaler_path (str): 標準化器檔案路徑 (.pkl 格式)
82
-
83
- Returns:
84
- bool: 是否成功載入
85
  """
86
- try:
87
- if os.path.exists(scaler_path):
88
- self.scaler = joblib.load(scaler_path)
89
- print(f"成功載入標準化器:{scaler_path}")
90
- return True
91
- else:
92
- print(f"警告:找不到標準化器檔案 {scaler_path}")
93
- print("將使用預設標準化器")
94
- self.scaler = StandardScaler()
95
- return False
96
 
97
- except Exception as e:
98
- print(f"載入標準化器時發生錯誤:{e}")
99
- self.scaler = StandardScaler()
100
- return False
101
-
102
- def preprocess_features(self, input_df):
103
- """
104
- 預處理輸入特徵
105
-
106
- Args:
107
- input_df (pd.DataFrame): 輸入特徵 DataFrame
108
-
109
- Returns:
110
- pd.DataFrame: 預處理後的特徵
111
- """
112
  try:
113
- # 確保輸入包含所有必要特徵
114
- missing_features = [f for f in self.feature_columns if f not in input_df.columns]
115
- if missing_features:
116
- print(f"警告:缺少以下特徵:{missing_features}")
117
- # 用 0 填補缺少的特徵
118
- for feature in missing_features:
119
- input_df[feature] = 0
120
-
121
- # 按照預期順序重新排列特徵
122
- input_df = input_df[self.feature_columns]
123
-
124
- # 處理 NaN 值
125
- input_df = input_df.fillna(0)
126
-
127
- # 如果有標準化器,進行標準化
128
- if self.scaler is not None:
129
- try:
130
- # 嘗試使用已訓練的標準化器
131
- scaled_features = self.scaler.transform(input_df)
132
- input_df = pd.DataFrame(scaled_features,
133
- columns=input_df.columns,
134
- index=input_df.index)
135
- except Exception as scaler_error:
136
- print(f"標準化過程發生錯誤:{scaler_error}")
137
- print("跳過標準化步驟")
138
-
139
- return input_df
140
-
141
  except Exception as e:
142
- print(f"特徵預處理時發生錯誤:{e}")
143
- return input_df
144
 
145
  def predict(self, model_name, input_df):
146
  """
147
- 進行股價漲幅預測
148
-
149
  Args:
150
- model_name (str): 模型名稱(用於載入對應模型)
151
- input_df (pd.DataFrame): 輸入特徵
152
-
153
  Returns:
154
- dict: 預測結果,包含各時間點的漲幅百分比
 
155
  """
156
- try:
157
- # 載入模型(如果尚未載入)
158
- if self.model is None:
159
- model_path = f"{model_name}.json"
160
- if not self.load_model(model_path):
161
- return None
162
-
163
- # 載入標準化器(如果存在)
164
- if self.scaler is None:
165
- scaler_path = f"{model_name}_scaler.pkl"
166
- self.load_scaler(scaler_path)
167
-
168
- # 預處理特徵
169
- processed_df = self.preprocess_features(input_df.copy())
170
-
171
- # 進行預測
172
- predictions = self.model.predict(processed_df)
173
-
174
- # 【重要修改】將預測結果格式化為漲幅百分比
175
- if predictions.ndim == 1:
176
- # 如果只有一個輸出,假設是 1 日預測
177
- result = {
178
- 'Change_pct_t1_pred': float(predictions[0])
179
- }
180
- else:
181
- # 多輸出情況:1日, 5日, 10日, 20日
182
- result = {
183
- 'Change_pct_t1_pred': float(predictions[0][0]) if len(predictions[0]) > 0 else 0.0,
184
- 'Change_pct_t5_pred': float(predictions[0][1]) if len(predictions[0]) > 1 else 0.0,
185
- 'Change_pct_t10_pred': float(predictions[0][2]) if len(predictions[0]) > 2 else 0.0,
186
- 'Change_pct_t20_pred': float(predictions[0][3]) if len(predictions[0]) > 3 else 0.0
187
- }
188
-
189
- # 輸出預測結果摘要
190
- print("=== 漲幅預測結果 ===")
191
- for key, value in result.items():
192
- days = key.split('_')[2][1:] # 提取天數
193
- direction = "上漲" if value > 0 else "下跌"
194
- print(f" {days}日後預測: {value:+.2f}% ({direction})")
195
-
196
- return result
197
-
198
- except Exception as e:
199
- print(f"預測過程中發生錯誤:{e}")
200
- import traceback
201
- traceback.print_exc()
202
- return None
203
 
204
  def predict_single_timeframe(self, model_name, input_df, days):
205
  """
 
10
  import joblib
11
 
12
  class XGBoostModel:
13
+ """
14
+ 用於載入和使用預先訓練好的 XGBoost 模型的類別。
15
+ """
16
+ # 使用類別變數儲存所有可用的模型名稱及其對應的檔案名稱
17
+ MODELS = {
18
+ 'xgboost_model': 'xgboost_model.json'
19
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
+ def __init__(self, default_model='xgboost_model'):
22
  """
23
+ 初始化時自動載入預設模型。
 
 
 
 
 
 
24
  """
25
+ self.current_model_name = default_model
26
+ self.model = self._load_model(self.current_model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
+ def _load_model(self, model_name):
29
  """
30
+ 從檔案載入 XGBoost 模型。
 
 
 
 
 
 
31
  """
32
+ if model_name not in self.MODELS:
33
+ raise ValueError(f"找不到模型 '{model_name}'。可用的模型名稱:{list(self.MODELS.keys())}")
 
 
 
 
 
 
 
 
34
 
35
+ filename = self.MODELS[model_name]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  try:
37
+ # 建立一個新的 XGBoost 模型實例
38
+ model = xgb.XGBRegressor()
39
+ # 使用 XGBoost 內建的 load_model 方法載入檔案
40
+ model.load_model(filename)
41
+ print(f"成功載入模型檔案: {filename}")
42
+ return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  except Exception as e:
44
+ raise FileNotFoundError(f"無法在本地找到或載入模型檔案 '{filename}':{e}")
 
45
 
46
  def predict(self, model_name, input_df):
47
  """
48
+ 使用載入的模型進行預測。
49
+
50
  Args:
51
+ model_name (str): 要使用的模型名稱。
52
+ input_df (pd.DataFrame): 包含特徵數據的 DataFrame,應只有一筆紀錄。
53
+
54
  Returns:
55
+ dict: 包含四個預測目標的預測結果字典。
56
+ {'Change_pct_t1_pred': float, 'Change_pct_t5_pred': float, ...}
57
  """
58
+ # 如果請求的模型名稱與目前載入的不同,則動態載入
59
+ if model_name != self.current_model_name:
60
+ self.model = self._load_model(model_name)
61
+ self.current_model_name = model_name
62
+
63
+ # 進行預測
64
+ # model.predict 會回傳一個 numpy 陣列,形狀為 (n_samples, n_targets)
65
+ # 在我們的案例中,n_samples=1, n_targets=4
66
+ predictions = self.model.predict(input_df)
67
+
68
+ # 【【核心修正】】
69
+ # 您的模型是多輸出模型,預測結果是一個包含4個值的陣列。
70
+ # 我們需要將這個陣列轉換為一個包含各預測目標的字典,以便 app.py 使用。
71
+ # predictions[0] 會取得第一筆樣本的所有預測值 (一個有4個元素的陣列)
72
+ if predictions.ndim == 2 and predictions.shape[0] > 0:
73
+ pred_values = predictions[0]
74
+ elif predictions.ndim == 1:
75
+ pred_values = predictions
76
+ else:
77
+ raise ValueError("預測結果的格式不符合預期。")
78
+
79
+ result = {
80
+ 'Change_pct_t1_pred': pred_values[0],
81
+ 'Change_pct_t5_pred': pred_values[1],
82
+ 'Change_pct_t10_pred': pred_values[2],
83
+ 'Change_pct_t20_pred': pred_values[3]
84
+ }
85
+ return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  def predict_single_timeframe(self, model_name, input_df, days):
88
  """