AITEST / model_predictor.py
AlanRex's picture
Upload 2 files
e658165 verified
raw
history blame
1.71 kB
import xgboost as xgb
import pandas as pd
class XGBoostModel:
# 使用類別變數儲存所有可用的模型名稱及其對應的檔案名稱
MODELS = {
'xgboost_model': 'xgboost_model.json'
}
def __init__(self, default_model='xgboost_model'):
# 建立物件時,自動載入預設模型
self.current_model_name = default_model
self.model = self._load_model(self.current_model_name)
def _load_model(self, model_name):
if model_name not in self.MODELS:
raise ValueError(f"找不到模型 '{model_name}'。可用的模型名稱:{list(self.MODELS.keys())}")
filename = self.MODELS[model_name]
try:
# 建立一個新的 XGBoost 模型實例
model = xgb.XGBRegressor()
# 使用 XGBoost 內建的 load_model 方法載入檔案
model.load_model(filename)
return model
except Exception as e:
raise FileNotFoundError(f"無法在本地找到或載入模型檔案 '{filename}':{e}")
def predict(self, model_name, input_df):
# 如果請求的模型名稱與目前載入的不同,則動態載入
if model_name != self.current_model_name:
self.model = self._load_model(model_name)
self.current_model_name = model_name
# 進行預測
predictions = self.model.predict(input_df)
# 將預測結果轉換為字典
result = {
'Close_t0_pred': predictions[0][0],
'Close_t5_pred': predictions[0][1],
'Close_t10_pred': predictions[0][2],
'Close_t20_pred': predictions[0][3]
}
return result