AlanRex commited on
Commit
25d39de
·
verified ·
1 Parent(s): b3374d4

Upload 2 files

Browse files
Files changed (2) hide show
  1. model_predictor.py +45 -0
  2. xgboost_model.json +0 -0
model_predictor.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import xgboost as xgb
2
+ import pandas as pd
3
+
4
+ class XGBoostModel:
5
+ # 使用類別變數儲存所有可用的模型名稱及其對應的檔案名稱
6
+ MODELS = {
7
+ 'xgboost_model': 'xgboost_model.json'
8
+ }
9
+
10
+ def __init__(self, default_model='xgboost_model'):
11
+ # 建立物件時,自動載入預設模型
12
+ self.current_model_name = default_model
13
+ self.model = self._load_model(self.current_model_name)
14
+
15
+ def _load_model(self, model_name):
16
+ if model_name not in self.MODELS:
17
+ raise ValueError(f"找不到模型 '{model_name}'。可用的模型名稱:{list(self.MODELS.keys())}")
18
+
19
+ filename = self.MODELS[model_name]
20
+ try:
21
+ # 建立一個新的 XGBoost 模型實例
22
+ model = xgb.XGBRegressor()
23
+ # 使用 XGBoost 內建的 load_model 方法載入檔案
24
+ model.load_model(filename)
25
+ return model
26
+ except Exception as e:
27
+ raise FileNotFoundError(f"無法在本地找到或載入模型檔案 '{filename}':{e}")
28
+
29
+ def predict(self, model_name, input_df):
30
+ # 如果請求的模型名稱與目前載入的不同,則動態載入
31
+ if model_name != self.current_model_name:
32
+ self.model = self._load_model(model_name)
33
+ self.current_model_name = model_name
34
+
35
+ # 進行預測
36
+ predictions = self.model.predict(input_df)
37
+
38
+ # 將預測結果轉換為字典
39
+ result = {
40
+ 'Close_t0_pred': predictions[0][0],
41
+ 'Close_t5_pred': predictions[0][1],
42
+ 'Close_t10_pred': predictions[0][2],
43
+ 'Close_t20_pred': predictions[0][3]
44
+ }
45
+ return result
xgboost_model.json ADDED
The diff for this file is too large to render. See raw diff