Spaces:
Running
Running
Update model_predictor.py
Browse files- model_predictor.py +9 -11
model_predictor.py
CHANGED
|
@@ -1,14 +1,13 @@
|
|
|
|
|
| 1 |
import xgboost as xgb
|
| 2 |
import pandas as pd
|
| 3 |
|
| 4 |
class XGBoostModel:
|
| 5 |
-
# 使用類別變數儲存所有可用的模型名稱及其對應的檔案名稱
|
| 6 |
MODELS = {
|
| 7 |
'xgboost_model': 'xgboost_model.json'
|
| 8 |
}
|
| 9 |
|
| 10 |
def __init__(self, default_model='xgboost_model'):
|
| 11 |
-
# 建立物件時,自動載入預設模型
|
| 12 |
self.current_model_name = default_model
|
| 13 |
self.model = self._load_model(self.current_model_name)
|
| 14 |
|
|
@@ -18,28 +17,27 @@ class XGBoostModel:
|
|
| 18 |
|
| 19 |
filename = self.MODELS[model_name]
|
| 20 |
try:
|
| 21 |
-
# 建立一個新的 XGBoost 模型實例
|
| 22 |
model = xgb.XGBRegressor()
|
| 23 |
-
# 使用 XGBoost 內建的 load_model 方法載入檔案
|
| 24 |
model.load_model(filename)
|
| 25 |
return model
|
| 26 |
except Exception as e:
|
| 27 |
raise FileNotFoundError(f"無法在本地找到或載入模型檔案 '{filename}':{e}")
|
| 28 |
|
| 29 |
def predict(self, model_name, input_df):
|
| 30 |
-
# 如果請求的模型名稱與目前載入的不同,則動態載入
|
| 31 |
if model_name != self.current_model_name:
|
| 32 |
self.model = self._load_model(model_name)
|
| 33 |
self.current_model_name = model_name
|
| 34 |
|
| 35 |
-
# 進行預測
|
| 36 |
predictions = self.model.predict(input_df)
|
| 37 |
|
| 38 |
-
#
|
|
|
|
|
|
|
|
|
|
| 39 |
result = {
|
| 40 |
-
'Close_t0_pred': predictions[0]
|
| 41 |
-
'Close_t5_pred': predictions[
|
| 42 |
-
'Close_t10_pred': predictions[
|
| 43 |
-
'Close_t20_pred': predictions[
|
| 44 |
}
|
| 45 |
return result
|
|
|
|
| 1 |
+
# model_predictor.py (修正版)
|
| 2 |
import xgboost as xgb
|
| 3 |
import pandas as pd
|
| 4 |
|
| 5 |
class XGBoostModel:
|
|
|
|
| 6 |
MODELS = {
|
| 7 |
'xgboost_model': 'xgboost_model.json'
|
| 8 |
}
|
| 9 |
|
| 10 |
def __init__(self, default_model='xgboost_model'):
|
|
|
|
| 11 |
self.current_model_name = default_model
|
| 12 |
self.model = self._load_model(self.current_model_name)
|
| 13 |
|
|
|
|
| 17 |
|
| 18 |
filename = self.MODELS[model_name]
|
| 19 |
try:
|
|
|
|
| 20 |
model = xgb.XGBRegressor()
|
|
|
|
| 21 |
model.load_model(filename)
|
| 22 |
return model
|
| 23 |
except Exception as e:
|
| 24 |
raise FileNotFoundError(f"無法在本地找到或載入模型檔案 '{filename}':{e}")
|
| 25 |
|
| 26 |
def predict(self, model_name, input_df):
|
|
|
|
| 27 |
if model_name != self.current_model_name:
|
| 28 |
self.model = self._load_model(model_name)
|
| 29 |
self.current_model_name = model_name
|
| 30 |
|
|
|
|
| 31 |
predictions = self.model.predict(input_df)
|
| 32 |
|
| 33 |
+
# 【【【主要修正點】】】
|
| 34 |
+
# 根據錯誤日誌推斷,當輸入只有一筆時,`predictions` 是一個一維陣列,
|
| 35 |
+
# 例如 [pred_t0, pred_t5, pred_t10, pred_t20]。
|
| 36 |
+
# 我們應該用 predictions[0], predictions[1] 的方式來取值,而不是 predictions[0][0]。
|
| 37 |
result = {
|
| 38 |
+
'Close_t0_pred': predictions[0],
|
| 39 |
+
'Close_t5_pred': predictions[1],
|
| 40 |
+
'Close_t10_pred': predictions[2],
|
| 41 |
+
'Close_t20_pred': predictions[3]
|
| 42 |
}
|
| 43 |
return result
|