Spaces:
Sleeping
Sleeping
Update model_predictor.py
Browse files- model_predictor.py +31 -47
model_predictor.py
CHANGED
|
@@ -160,7 +160,10 @@ class XGBoostModel:
|
|
| 160 |
|
| 161 |
def predict(self, model_name, input_df):
|
| 162 |
"""
|
| 163 |
-
|
|
|
|
|
|
|
|
|
|
| 164 |
"""
|
| 165 |
try:
|
| 166 |
# 載入模型(如果尚未載入)
|
|
@@ -168,72 +171,53 @@ class XGBoostModel:
|
|
| 168 |
model_path = f"{model_name}.json"
|
| 169 |
if not self.load_model(model_path):
|
| 170 |
return None
|
| 171 |
-
|
| 172 |
-
#
|
| 173 |
-
processed_df = input_df.copy()
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
# 嘗試從已載入的 xgboost 模型中取得訓練時的 feature names
|
| 177 |
-
expected_features = None
|
| 178 |
-
try:
|
| 179 |
-
booster = self.model.get_booster()
|
| 180 |
-
expected_features = getattr(booster, "feature_names", None)
|
| 181 |
-
except Exception:
|
| 182 |
-
expected_features = None
|
| 183 |
-
|
| 184 |
-
if expected_features:
|
| 185 |
-
# 檢查缺失或多餘欄位
|
| 186 |
-
missing = [f for f in expected_features if f not in processed_df.columns]
|
| 187 |
-
extra = [f for f in processed_df.columns if f not in expected_features]
|
| 188 |
-
|
| 189 |
-
if missing:
|
| 190 |
-
print(f"警告:模型期待以下特徵但輸入缺失,將以 0 補齊: {missing}")
|
| 191 |
-
for f in missing:
|
| 192 |
-
processed_df[f] = 0.0
|
| 193 |
-
|
| 194 |
-
if extra:
|
| 195 |
-
print(f"注意:輸入含有模型未使用的額外特徵,將忽略: {extra}")
|
| 196 |
-
|
| 197 |
-
# 依模型期待欄位順序重排(且只保留 expected_features)
|
| 198 |
-
processed_df = processed_df[expected_features]
|
| 199 |
-
else:
|
| 200 |
-
# 如果模型沒有記錄 feature_names,退回到 class 裡預設的 feature_columns(如有)
|
| 201 |
-
processed_df = processed_df[self.feature_columns]
|
| 202 |
-
|
| 203 |
-
# DEBUG 訊息
|
| 204 |
-
print("=== 模型輸入特徵檢查(對齊後) ===")
|
| 205 |
print(f"輸入形狀: {processed_df.shape}")
|
| 206 |
print("前5個特徵值:")
|
| 207 |
for i, col in enumerate(processed_df.columns[:5]):
|
| 208 |
print(f" {col}: {processed_df[col].iloc[0]:.6f}")
|
| 209 |
-
|
| 210 |
# 進行預測
|
| 211 |
predictions = self.model.predict(processed_df)
|
| 212 |
-
print(f"原始預測輸出形狀: {
|
| 213 |
print(f"原始預測值: {predictions}")
|
| 214 |
-
|
| 215 |
-
#
|
| 216 |
-
if
|
| 217 |
-
#
|
| 218 |
-
result = {
|
|
|
|
|
|
|
| 219 |
else:
|
|
|
|
| 220 |
result = {}
|
| 221 |
-
target_keys = ['Change_pct_t1_pred', 'Change_pct_t5_pred',
|
| 222 |
-
|
|
|
|
| 223 |
for i, key in enumerate(target_keys):
|
| 224 |
if i < predictions.shape[1]:
|
| 225 |
result[key] = float(predictions[0][i])
|
| 226 |
else:
|
| 227 |
result[key] = 0.0
|
| 228 |
-
|
| 229 |
-
#
|
| 230 |
print("=== 漲幅預測結果 ===")
|
| 231 |
for key, value in result.items():
|
| 232 |
days = key.split('_')[2][1:]
|
| 233 |
direction = "上漲" if value > 0 else "下跌"
|
| 234 |
print(f" {days}日後預測: {value:+.2f}% ({direction})")
|
| 235 |
-
|
| 236 |
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
|
| 238 |
except Exception as e:
|
| 239 |
print(f"預測過程中發生錯誤:{e}")
|
|
|
|
| 160 |
|
| 161 |
def predict(self, model_name, input_df):
|
| 162 |
"""
|
| 163 |
+
進行股價漲幅預測
|
| 164 |
+
|
| 165 |
+
Returns:
|
| 166 |
+
dict: 預測結果,包含各時間點的漲幅百分比
|
| 167 |
"""
|
| 168 |
try:
|
| 169 |
# 載入模型(如果尚未載入)
|
|
|
|
| 171 |
model_path = f"{model_name}.json"
|
| 172 |
if not self.load_model(model_path):
|
| 173 |
return None
|
| 174 |
+
|
| 175 |
+
# 預處理特徵
|
| 176 |
+
processed_df = self.preprocess_features(input_df.copy())
|
| 177 |
+
|
| 178 |
+
print("=== 模型輸入特徵檢查 ===")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
print(f"輸入形狀: {processed_df.shape}")
|
| 180 |
print("前5個特徵值:")
|
| 181 |
for i, col in enumerate(processed_df.columns[:5]):
|
| 182 |
print(f" {col}: {processed_df[col].iloc[0]:.6f}")
|
| 183 |
+
|
| 184 |
# 進行預測
|
| 185 |
predictions = self.model.predict(processed_df)
|
| 186 |
+
print(f"原始預測輸出形狀: {predictions.shape}")
|
| 187 |
print(f"原始預測值: {predictions}")
|
| 188 |
+
|
| 189 |
+
# 【修正】處理多輸出預測結果
|
| 190 |
+
if predictions.ndim == 1:
|
| 191 |
+
# 單輸出情況 - 只有一個時間點的預測
|
| 192 |
+
result = {
|
| 193 |
+
'Change_pct_t1_pred': float(predictions[0])
|
| 194 |
+
}
|
| 195 |
else:
|
| 196 |
+
# 多輸出情況:[t1, t5, t10, t20] - 對應訓練模型的四個輸出
|
| 197 |
result = {}
|
| 198 |
+
target_keys = ['Change_pct_t1_pred', 'Change_pct_t5_pred',
|
| 199 |
+
'Change_pct_t10_pred', 'Change_pct_t20_pred']
|
| 200 |
+
|
| 201 |
for i, key in enumerate(target_keys):
|
| 202 |
if i < predictions.shape[1]:
|
| 203 |
result[key] = float(predictions[0][i])
|
| 204 |
else:
|
| 205 |
result[key] = 0.0
|
| 206 |
+
|
| 207 |
+
# 輸出預測結果摘要
|
| 208 |
print("=== 漲幅預測結果 ===")
|
| 209 |
for key, value in result.items():
|
| 210 |
days = key.split('_')[2][1:]
|
| 211 |
direction = "上漲" if value > 0 else "下跌"
|
| 212 |
print(f" {days}日後預測: {value:+.2f}% ({direction})")
|
| 213 |
+
|
| 214 |
return result
|
| 215 |
+
|
| 216 |
+
except Exception as e:
|
| 217 |
+
print(f"預測過程中發生錯誤:{e}")
|
| 218 |
+
import traceback
|
| 219 |
+
traceback.print_exc()
|
| 220 |
+
return None
|
| 221 |
|
| 222 |
except Exception as e:
|
| 223 |
print(f"預測過程中發生錯誤:{e}")
|