AlanRex commited on
Commit
9cb94a7
·
verified ·
1 Parent(s): 8a5febd

Update model_predictor.py

Browse files
Files changed (1) hide show
  1. model_predictor.py +73 -73
model_predictor.py CHANGED
@@ -159,81 +159,81 @@ class XGBoostModel:
159
  return input_df
160
 
161
  def predict(self, model_name, input_df):
162
- """
163
- 進行股價漲幅預測(已加入自動對齊模型 feature_names 的邏輯)
164
- """
165
- try:
166
- # 載入模型(如果尚未載入)
167
- if self.model is None:
168
- model_path = f"{model_name}.json"
169
- if not self.load_model(model_path):
170
- return None
171
-
172
- # 先做基本的特徵預處理(會補缺失欄位為 0,但不重新排序)
173
- processed_df = input_df.copy()
174
- processed_df = self.preprocess_features(processed_df.copy())
175
-
176
- # 嘗試從已載入的 xgboost 模型中取得訓練時的 feature names
177
- expected_features = None
178
  try:
179
- booster = self.model.get_booster()
180
- expected_features = getattr(booster, "feature_names", None)
181
- except Exception:
 
 
 
 
 
 
 
 
182
  expected_features = None
183
-
184
- if expected_features:
185
- # 檢查缺失或多餘欄位
186
- missing = [f for f in expected_features if f not in processed_df.columns]
187
- extra = [f for f in processed_df.columns if f not in expected_features]
188
-
189
- if missing:
190
- print(f"警告:模型期待以下特徵但輸入缺失,將以 0 補齊: {missing}")
191
- for f in missing:
192
- processed_df[f] = 0.0
193
-
194
- if extra:
195
- print(f"注意:輸入含有模型未使用的額外特徵,將忽略: {extra}")
196
-
197
- # 依模型期待欄位順序重排(且只保留 expected_features)
198
- processed_df = processed_df[expected_features]
199
- else:
200
- # 如果模型沒有記錄 feature_names,退回到 class 裡預設的 feature_columns(如有)
201
- processed_df = processed_df[self.feature_columns]
202
-
203
- # DEBUG 訊息
204
- print("=== 模型輸入特徵檢查(對齊後) ===")
205
- print(f"輸入形狀: {processed_df.shape}")
206
- print("前5個特徵值:")
207
- for i, col in enumerate(processed_df.columns[:5]):
208
- print(f" {col}: {processed_df[col].iloc[0]:.6f}")
209
-
210
- # 進行預測
211
- predictions = self.model.predict(processed_df)
212
- print(f"原始預測輸出形狀: {getattr(predictions, 'shape', str(type(predictions)))}")
213
- print(f"原始預測值: {predictions}")
214
-
215
- # 處理多輸出或單輸出
216
- if getattr(predictions, 'ndim', 1) == 1:
217
- # 單輸出
218
- result = {'Change_pct_t1_pred': float(predictions[0])}
219
- else:
220
- result = {}
221
- target_keys = ['Change_pct_t1_pred', 'Change_pct_t5_pred',
222
- 'Change_pct_t10_pred', 'Change_pct_t20_pred']
223
- for i, key in enumerate(target_keys):
224
- if i < predictions.shape[1]:
225
- result[key] = float(predictions[0][i])
226
- else:
227
- result[key] = 0.0
228
-
229
- # 列印摘要
230
- print("=== 漲幅預測結果 ===")
231
- for key, value in result.items():
232
- days = key.split('_')[2][1:]
233
- direction = "上漲" if value > 0 else "下跌"
234
- print(f" {days}日後預測: {value:+.2f}% ({direction})")
235
-
236
- return result
 
 
 
 
 
237
 
238
  except Exception as e:
239
  print(f"預測過程中發生錯誤:{e}")
 
159
  return input_df
160
 
161
  def predict(self, model_name, input_df):
162
+ """
163
+ 進行股價漲幅預測(已加入自動對齊模型 feature_names 的邏輯)
164
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  try:
166
+ # 載入模型(如果尚未載入)
167
+ if self.model is None:
168
+ model_path = f"{model_name}.json"
169
+ if not self.load_model(model_path):
170
+ return None
171
+
172
+ # 先做基本的特徵預處理(會補缺失欄位為 0,但不重新排序)
173
+ processed_df = input_df.copy()
174
+ processed_df = self.preprocess_features(processed_df.copy())
175
+
176
+ # 嘗試從已載入的 xgboost 模型中取得訓練時的 feature names
177
  expected_features = None
178
+ try:
179
+ booster = self.model.get_booster()
180
+ expected_features = getattr(booster, "feature_names", None)
181
+ except Exception:
182
+ expected_features = None
183
+
184
+ if expected_features:
185
+ # 檢查缺失或多餘欄位
186
+ missing = [f for f in expected_features if f not in processed_df.columns]
187
+ extra = [f for f in processed_df.columns if f not in expected_features]
188
+
189
+ if missing:
190
+ print(f"警告:模型期待以下特徵但輸入缺失,將以 0 補齊: {missing}")
191
+ for f in missing:
192
+ processed_df[f] = 0.0
193
+
194
+ if extra:
195
+ print(f"注意:輸入含有模型未使用的額外特徵,將忽略: {extra}")
196
+
197
+ # 依模型期待欄位順序重排(且只保留 expected_features)
198
+ processed_df = processed_df[expected_features]
199
+ else:
200
+ # 如果模型沒有記錄 feature_names,退回到 class 裡預設的 feature_columns(如有)
201
+ processed_df = processed_df[self.feature_columns]
202
+
203
+ # DEBUG 訊息
204
+ print("=== 模型輸入特徵檢查(對齊後) ===")
205
+ print(f"輸入形狀: {processed_df.shape}")
206
+ print("前5個特徵值:")
207
+ for i, col in enumerate(processed_df.columns[:5]):
208
+ print(f" {col}: {processed_df[col].iloc[0]:.6f}")
209
+
210
+ # 進行預測
211
+ predictions = self.model.predict(processed_df)
212
+ print(f"原始預測輸出形狀: {getattr(predictions, 'shape', str(type(predictions)))}")
213
+ print(f"原始預測值: {predictions}")
214
+
215
+ # 處理多輸出或單輸出
216
+ if getattr(predictions, 'ndim', 1) == 1:
217
+ # 單輸出
218
+ result = {'Change_pct_t1_pred': float(predictions[0])}
219
+ else:
220
+ result = {}
221
+ target_keys = ['Change_pct_t1_pred', 'Change_pct_t5_pred',
222
+ 'Change_pct_t10_pred', 'Change_pct_t20_pred']
223
+ for i, key in enumerate(target_keys):
224
+ if i < predictions.shape[1]:
225
+ result[key] = float(predictions[0][i])
226
+ else:
227
+ result[key] = 0.0
228
+
229
+ # 列印摘要
230
+ print("=== 漲幅預測結果 ===")
231
+ for key, value in result.items():
232
+ days = key.split('_')[2][1:]
233
+ direction = "上漲" if value > 0 else "下跌"
234
+ print(f" {days}日後預測: {value:+.2f}% ({direction})")
235
+
236
+ return result
237
 
238
  except Exception as e:
239
  print(f"預測過程中發生錯誤:{e}")