AlanRex committed on
Commit
804d0b4
·
verified ·
1 Parent(s): 9cb94a7

Update model_predictor.py

Browse files
Files changed (1) hide show
  1. model_predictor.py +31 -47
model_predictor.py CHANGED
@@ -160,7 +160,10 @@ class XGBoostModel:
160
 
161
  def predict(self, model_name, input_df):
162
  """
163
- 進行股價漲幅預測(已加入自動對齊模型 feature_names 的邏輯)
 
 
 
164
  """
165
  try:
166
  # 載入模型(如果尚未載入)
@@ -168,72 +171,53 @@ class XGBoostModel:
168
  model_path = f"{model_name}.json"
169
  if not self.load_model(model_path):
170
  return None
171
-
172
- # 先做基本的特徵預處理(會補缺失欄位為 0,但不重新排序)
173
- processed_df = input_df.copy()
174
- processed_df = self.preprocess_features(processed_df.copy())
175
-
176
- # 嘗試從已載入的 xgboost 模型中取得訓練時的 feature names
177
- expected_features = None
178
- try:
179
- booster = self.model.get_booster()
180
- expected_features = getattr(booster, "feature_names", None)
181
- except Exception:
182
- expected_features = None
183
-
184
- if expected_features:
185
- # 檢查缺失或多餘欄位
186
- missing = [f for f in expected_features if f not in processed_df.columns]
187
- extra = [f for f in processed_df.columns if f not in expected_features]
188
-
189
- if missing:
190
- print(f"警告:模型期待以下特徵但輸入缺失,將以 0 補齊: {missing}")
191
- for f in missing:
192
- processed_df[f] = 0.0
193
-
194
- if extra:
195
- print(f"注意:輸入含有模型未使用的額外特徵,將忽略: {extra}")
196
-
197
- # 依模型期待欄位順序重排(且只保留 expected_features)
198
- processed_df = processed_df[expected_features]
199
- else:
200
- # 如果模型沒有記錄 feature_names,退回到 class 裡預設的 feature_columns(如有)
201
- processed_df = processed_df[self.feature_columns]
202
-
203
- # DEBUG 訊息
204
- print("=== 模型輸入特徵檢查(對齊後) ===")
205
  print(f"輸入形狀: {processed_df.shape}")
206
  print("前5個特徵值:")
207
  for i, col in enumerate(processed_df.columns[:5]):
208
  print(f" {col}: {processed_df[col].iloc[0]:.6f}")
209
-
210
  # 進行預測
211
  predictions = self.model.predict(processed_df)
212
- print(f"原始預測輸出形狀: {getattr(predictions, 'shape', str(type(predictions)))}")
213
  print(f"原始預測值: {predictions}")
214
-
215
- # 處理多輸出或單輸出
216
- if getattr(predictions, 'ndim', 1) == 1:
217
- # 單輸出
218
- result = {'Change_pct_t1_pred': float(predictions[0])}
 
 
219
  else:
 
220
  result = {}
221
- target_keys = ['Change_pct_t1_pred', 'Change_pct_t5_pred',
222
- 'Change_pct_t10_pred', 'Change_pct_t20_pred']
 
223
  for i, key in enumerate(target_keys):
224
  if i < predictions.shape[1]:
225
  result[key] = float(predictions[0][i])
226
  else:
227
  result[key] = 0.0
228
-
229
- # 列印摘要
230
  print("=== 漲幅預測結果 ===")
231
  for key, value in result.items():
232
  days = key.split('_')[2][1:]
233
  direction = "上漲" if value > 0 else "下跌"
234
  print(f" {days}日後預測: {value:+.2f}% ({direction})")
235
-
236
  return result
 
 
 
 
 
 
237
 
238
  except Exception as e:
239
  print(f"預測過程中發生錯誤:{e}")
 
160
 
161
  def predict(self, model_name, input_df):
162
  """
163
+ 進行股價漲幅預測
164
+
165
+ Returns:
166
+ dict: 預測結果,包含各時間點的漲幅百分比
167
  """
168
  try:
169
  # 載入模型(如果尚未載入)
 
171
  model_path = f"{model_name}.json"
172
  if not self.load_model(model_path):
173
  return None
174
+
175
+ # 預處理特徵
176
+ processed_df = self.preprocess_features(input_df.copy())
177
+
178
+ print("=== 模型輸入特徵檢查 ===")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  print(f"輸入形狀: {processed_df.shape}")
180
  print("前5個特徵值:")
181
  for i, col in enumerate(processed_df.columns[:5]):
182
  print(f" {col}: {processed_df[col].iloc[0]:.6f}")
183
+
184
  # 進行預測
185
  predictions = self.model.predict(processed_df)
186
+ print(f"原始預測輸出形狀: {predictions.shape}")
187
  print(f"原始預測值: {predictions}")
188
+
189
+ # 【修正】處理多輸出預測結果
190
+ if predictions.ndim == 1:
191
+ # 單輸出情況 - 只有一個時間點的預測
192
+ result = {
193
+ 'Change_pct_t1_pred': float(predictions[0])
194
+ }
195
  else:
196
+ # 多輸出情況:[t1, t5, t10, t20] - 對應訓練模型的四個輸出
197
  result = {}
198
+ target_keys = ['Change_pct_t1_pred', 'Change_pct_t5_pred',
199
+ 'Change_pct_t10_pred', 'Change_pct_t20_pred']
200
+
201
  for i, key in enumerate(target_keys):
202
  if i < predictions.shape[1]:
203
  result[key] = float(predictions[0][i])
204
  else:
205
  result[key] = 0.0
206
+
207
+ # 輸出預測結果摘要
208
  print("=== 漲幅預測結果 ===")
209
  for key, value in result.items():
210
  days = key.split('_')[2][1:]
211
  direction = "上漲" if value > 0 else "下跌"
212
  print(f" {days}日後預測: {value:+.2f}% ({direction})")
213
+
214
  return result
215
+
216
+ except Exception as e:
217
+ print(f"預測過程中發生錯誤:{e}")
218
+ import traceback
219
+ traceback.print_exc()
220
+ return None
221
 
222
  except Exception as e:
223
  print(f"預測過程中發生錯誤:{e}")