AlanRex commited on
Commit
8a5febd
·
verified ·
1 Parent(s): fea8870

Update model_predictor.py

Browse files
Files changed (1) hide show
  1. model_predictor.py +80 -58
model_predictor.py CHANGED
@@ -159,65 +159,87 @@ class XGBoostModel:
159
  return input_df
160
 
161
  def predict(self, model_name, input_df):
162
- """
163
- 進行股價漲幅預測
164
-
165
- Returns:
166
- dict: 預測結果,包含各時間點的漲幅百分比
167
- """
 
 
 
 
 
 
 
 
 
 
168
  try:
169
- # 載入模型(如果尚未載入)
170
- if self.model is None:
171
- model_path = f"{model_name}.json"
172
- if not self.load_model(model_path):
173
- return None
174
-
175
- # 預處理特徵
176
- processed_df = self.preprocess_features(input_df.copy())
177
-
178
- print("=== 模型輸入特徵檢查 ===")
179
- print(f"輸入形狀: {processed_df.shape}")
180
- print("前5個特徵值:")
181
- for i, col in enumerate(processed_df.columns[:5]):
182
- print(f" {col}: {processed_df[col].iloc[0]:.6f}")
183
-
184
- # 進行預測
185
- predictions = self.model.predict(processed_df)
186
- print(f"原始預測輸出形狀: {predictions.shape}")
187
- print(f"原始預測值: {predictions}")
188
-
189
- # 【修正】處理多輸出預測結果
190
- if predictions.ndim == 1:
191
- # 單輸出情況 - 只有一個時間點的預測
192
- result = {
193
- 'Change_pct_t1_pred': float(predictions[0])
194
- }
195
- else:
196
- # 多輸出情況:[t1, t5, t10, t20] - 對應訓練模型的四個輸出
197
- result = {}
198
- target_keys = ['Change_pct_t1_pred', 'Change_pct_t5_pred',
199
- 'Change_pct_t10_pred', 'Change_pct_t20_pred']
200
-
201
- for i, key in enumerate(target_keys):
202
- if i < predictions.shape[1]:
203
- result[key] = float(predictions[0][i])
204
- else:
205
- result[key] = 0.0
206
-
207
- # 輸出預測結果摘要
208
- print("=== 漲幅預測結果 ===")
209
- for key, value in result.items():
210
- days = key.split('_')[2][1:]
211
- direction = "上漲" if value > 0 else "下跌"
212
- print(f" {days}日後預測: {value:+.2f}% ({direction})")
213
-
214
- return result
215
-
216
- except Exception as e:
217
- print(f"預測過程中發生錯誤:{e}")
218
- import traceback
219
- traceback.print_exc()
220
- return None
 
 
 
 
 
 
 
 
 
 
 
 
221
 
222
  def predict_single_timeframe(self, model_name, input_df, days):
223
  """預測特定時間框架的漲幅"""
 
159
  return input_df
160
 
161
  def predict(self, model_name, input_df):
162
+ """
163
+ 進行股價漲幅預測(已加入自動對齊模型 feature_names 的邏輯)
164
+ """
165
+ try:
166
+ # 載入模型(如果尚未載入)
167
+ if self.model is None:
168
+ model_path = f"{model_name}.json"
169
+ if not self.load_model(model_path):
170
+ return None
171
+
172
+ # 先做基本的特徵預處理(會補缺失欄位為 0,但不重新排序)
173
+ processed_df = input_df.copy()
174
+ processed_df = self.preprocess_features(processed_df.copy())
175
+
176
+ # 嘗試從已載入的 xgboost 模型中取得訓練時的 feature names
177
+ expected_features = None
178
  try:
179
+ booster = self.model.get_booster()
180
+ expected_features = getattr(booster, "feature_names", None)
181
+ except Exception:
182
+ expected_features = None
183
+
184
+ if expected_features:
185
+ # 檢查缺失或多餘欄位
186
+ missing = [f for f in expected_features if f not in processed_df.columns]
187
+ extra = [f for f in processed_df.columns if f not in expected_features]
188
+
189
+ if missing:
190
+ print(f"警告:模型期待以下特徵但輸入缺失,將以 0 補齊: {missing}")
191
+ for f in missing:
192
+ processed_df[f] = 0.0
193
+
194
+ if extra:
195
+ print(f"注意:輸入含有模型未使用的額外特徵,將忽略: {extra}")
196
+
197
+ # 依模型期待欄位順序重排(且只保留 expected_features)
198
+ processed_df = processed_df[expected_features]
199
+ else:
200
+ # 如果模型沒有記錄 feature_names,退回到 class 裡預設的 feature_columns(如有)
201
+ processed_df = processed_df[self.feature_columns]
202
+
203
+ # DEBUG 訊息
204
+ print("=== 模型輸入特徵檢查(對齊後) ===")
205
+ print(f"輸入形狀: {processed_df.shape}")
206
+ print("前5個特徵值:")
207
+ for i, col in enumerate(processed_df.columns[:5]):
208
+ print(f" {col}: {processed_df[col].iloc[0]:.6f}")
209
+
210
+ # 進行預測
211
+ predictions = self.model.predict(processed_df)
212
+ print(f"原始預測輸出形狀: {getattr(predictions, 'shape', str(type(predictions)))}")
213
+ print(f"原始預測值: {predictions}")
214
+
215
+ # 處理多輸出或單輸出
216
+ if getattr(predictions, 'ndim', 1) == 1:
217
+ # 單輸出
218
+ result = {'Change_pct_t1_pred': float(predictions[0])}
219
+ else:
220
+ result = {}
221
+ target_keys = ['Change_pct_t1_pred', 'Change_pct_t5_pred',
222
+ 'Change_pct_t10_pred', 'Change_pct_t20_pred']
223
+ for i, key in enumerate(target_keys):
224
+ if i < predictions.shape[1]:
225
+ result[key] = float(predictions[0][i])
226
+ else:
227
+ result[key] = 0.0
228
+
229
+ # 列印摘要
230
+ print("=== 漲幅預測結果 ===")
231
+ for key, value in result.items():
232
+ days = key.split('_')[2][1:]
233
+ direction = "上漲" if value > 0 else "下跌"
234
+ print(f" {days}日後預測: {value:+.2f}% ({direction})")
235
+
236
+ return result
237
+
238
+ except Exception as e:
239
+ print(f"預測過程中發生錯誤:{e}")
240
+ import traceback
241
+ traceback.print_exc()
242
+ return None
243
 
244
  def predict_single_timeframe(self, model_name, input_df, days):
245
  """預測特定時間框架的漲幅"""