Spaces:
Sleeping
Sleeping
Update model_predictor.py
Browse files
- model_predictor.py +73 -73
model_predictor.py
CHANGED
|
@@ -159,81 +159,81 @@ class XGBoostModel:
|
|
| 159 |
return input_df
|
| 160 |
|
| 161 |
def predict(self, model_name, input_df):
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
try:
|
| 166 |
-
# 載入模型(如果尚未載入)
|
| 167 |
-
if self.model is None:
|
| 168 |
-
model_path = f"{model_name}.json"
|
| 169 |
-
if not self.load_model(model_path):
|
| 170 |
-
return None
|
| 171 |
-
|
| 172 |
-
# 先做基本的特徵預處理(會補缺失欄位為 0,但不重新排序)
|
| 173 |
-
processed_df = input_df.copy()
|
| 174 |
-
processed_df = self.preprocess_features(processed_df.copy())
|
| 175 |
-
|
| 176 |
-
# 嘗試從已載入的 xgboost 模型中取得訓練時的 feature names
|
| 177 |
-
expected_features = None
|
| 178 |
try:
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
expected_features = None
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
if
|
| 190 |
-
|
| 191 |
-
for f in
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
|
| 238 |
except Exception as e:
|
| 239 |
print(f"預測過程中發生錯誤:{e}")
|
|
|
|
| 159 |
return input_df
|
| 160 |
|
| 161 |
def predict(self, model_name, input_df):
|
| 162 |
+
"""
|
| 163 |
+
進行股價漲幅預測(已加入自動對齊模型 feature_names 的邏輯)
|
| 164 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
try:
|
| 166 |
+
# 載入模型(如果尚未載入)
|
| 167 |
+
if self.model is None:
|
| 168 |
+
model_path = f"{model_name}.json"
|
| 169 |
+
if not self.load_model(model_path):
|
| 170 |
+
return None
|
| 171 |
+
|
| 172 |
+
# 先做基本的特徵預處理(會補缺失欄位為 0,但不重新排序)
|
| 173 |
+
processed_df = input_df.copy()
|
| 174 |
+
processed_df = self.preprocess_features(processed_df.copy())
|
| 175 |
+
|
| 176 |
+
# 嘗試從已載入的 xgboost 模型中取得訓練時的 feature names
|
| 177 |
expected_features = None
|
| 178 |
+
try:
|
| 179 |
+
booster = self.model.get_booster()
|
| 180 |
+
expected_features = getattr(booster, "feature_names", None)
|
| 181 |
+
except Exception:
|
| 182 |
+
expected_features = None
|
| 183 |
+
|
| 184 |
+
if expected_features:
|
| 185 |
+
# 檢查缺失或多餘欄位
|
| 186 |
+
missing = [f for f in expected_features if f not in processed_df.columns]
|
| 187 |
+
extra = [f for f in processed_df.columns if f not in expected_features]
|
| 188 |
+
|
| 189 |
+
if missing:
|
| 190 |
+
print(f"警告:模型期待以下特徵但輸入缺失,將以 0 補齊: {missing}")
|
| 191 |
+
for f in missing:
|
| 192 |
+
processed_df[f] = 0.0
|
| 193 |
+
|
| 194 |
+
if extra:
|
| 195 |
+
print(f"注意:輸入含有模型未使用的額外特徵,將忽略: {extra}")
|
| 196 |
+
|
| 197 |
+
# 依模型期待欄位順序重排(且只保留 expected_features)
|
| 198 |
+
processed_df = processed_df[expected_features]
|
| 199 |
+
else:
|
| 200 |
+
# 如果模型沒有記錄 feature_names,退回到 class 裡預設的 feature_columns(如有)
|
| 201 |
+
processed_df = processed_df[self.feature_columns]
|
| 202 |
+
|
| 203 |
+
# DEBUG 訊息
|
| 204 |
+
print("=== 模型輸入特徵檢查(對齊後) ===")
|
| 205 |
+
print(f"輸入形狀: {processed_df.shape}")
|
| 206 |
+
print("前5個特徵值:")
|
| 207 |
+
for i, col in enumerate(processed_df.columns[:5]):
|
| 208 |
+
print(f" {col}: {processed_df[col].iloc[0]:.6f}")
|
| 209 |
+
|
| 210 |
+
# 進行預測
|
| 211 |
+
predictions = self.model.predict(processed_df)
|
| 212 |
+
print(f"原始預測輸出形狀: {getattr(predictions, 'shape', str(type(predictions)))}")
|
| 213 |
+
print(f"原始預測值: {predictions}")
|
| 214 |
+
|
| 215 |
+
# 處理多輸出或單輸出
|
| 216 |
+
if getattr(predictions, 'ndim', 1) == 1:
|
| 217 |
+
# 單輸出
|
| 218 |
+
result = {'Change_pct_t1_pred': float(predictions[0])}
|
| 219 |
+
else:
|
| 220 |
+
result = {}
|
| 221 |
+
target_keys = ['Change_pct_t1_pred', 'Change_pct_t5_pred',
|
| 222 |
+
'Change_pct_t10_pred', 'Change_pct_t20_pred']
|
| 223 |
+
for i, key in enumerate(target_keys):
|
| 224 |
+
if i < predictions.shape[1]:
|
| 225 |
+
result[key] = float(predictions[0][i])
|
| 226 |
+
else:
|
| 227 |
+
result[key] = 0.0
|
| 228 |
+
|
| 229 |
+
# 列印摘要
|
| 230 |
+
print("=== 漲幅預測結果 ===")
|
| 231 |
+
for key, value in result.items():
|
| 232 |
+
days = key.split('_')[2][1:]
|
| 233 |
+
direction = "上漲" if value > 0 else "下跌"
|
| 234 |
+
print(f" {days}日後預測: {value:+.2f}% ({direction})")
|
| 235 |
+
|
| 236 |
+
return result
|
| 237 |
|
| 238 |
except Exception as e:
|
| 239 |
print(f"預測過程中發生錯誤:{e}")
|