AlanRex commited on
Commit
53bb8fb
·
verified ·
1 Parent(s): b1a597f

Update model_predictor.py

Browse files
Files changed (1) hide show
  1. model_predictor.py +9 -11
model_predictor.py CHANGED
@@ -1,14 +1,13 @@
 
1
  import xgboost as xgb
2
  import pandas as pd
3
 
4
  class XGBoostModel:
5
- # 使用類別變數儲存所有可用的模型名稱及其對應的檔案名稱
6
  MODELS = {
7
  'xgboost_model': 'xgboost_model.json'
8
  }
9
 
10
  def __init__(self, default_model='xgboost_model'):
11
- # 建立物件時,自動載入預設模型
12
  self.current_model_name = default_model
13
  self.model = self._load_model(self.current_model_name)
14
 
@@ -18,28 +17,27 @@ class XGBoostModel:
18
 
19
  filename = self.MODELS[model_name]
20
  try:
21
- # 建立一個新的 XGBoost 模型實例
22
  model = xgb.XGBRegressor()
23
- # 使用 XGBoost 內建的 load_model 方法載入檔案
24
  model.load_model(filename)
25
  return model
26
  except Exception as e:
27
  raise FileNotFoundError(f"無法在本地找到或載入模型檔案 '{filename}':{e}")
28
 
29
  def predict(self, model_name, input_df):
30
- # 如果請求的模型名稱與目前載入的不同,則動態載入
31
  if model_name != self.current_model_name:
32
  self.model = self._load_model(model_name)
33
  self.current_model_name = model_name
34
 
35
- # 進行預測
36
  predictions = self.model.predict(input_df)
37
 
38
- # 將預測結果轉換為字典
 
 
 
39
  result = {
40
- 'Close_t0_pred': predictions[0][0],
41
- 'Close_t5_pred': predictions[0][1],
42
- 'Close_t10_pred': predictions[0][2],
43
- 'Close_t20_pred': predictions[0][3]
44
  }
45
  return result
 
1
+ # model_predictor.py (修正版)
2
  import xgboost as xgb
3
  import pandas as pd
4
 
5
  class XGBoostModel:
 
6
  MODELS = {
7
  'xgboost_model': 'xgboost_model.json'
8
  }
9
 
10
  def __init__(self, default_model='xgboost_model'):
 
11
  self.current_model_name = default_model
12
  self.model = self._load_model(self.current_model_name)
13
 
 
17
 
18
  filename = self.MODELS[model_name]
19
  try:
 
20
  model = xgb.XGBRegressor()
 
21
  model.load_model(filename)
22
  return model
23
  except Exception as e:
24
  raise FileNotFoundError(f"無法在本地找到或載入模型檔案 '{filename}':{e}")
25
 
26
  def predict(self, model_name, input_df):
 
27
  if model_name != self.current_model_name:
28
  self.model = self._load_model(model_name)
29
  self.current_model_name = model_name
30
 
 
31
  predictions = self.model.predict(input_df)
32
 
33
+ # 【【【主要修正點】】】
34
+ # 根據錯誤日誌推斷,當輸入只有一筆時,`predictions` 是一個一維陣列,
35
+ # 例如 [pred_t0, pred_t5, pred_t10, pred_t20]。
36
+ # 我們應該用 predictions[0], predictions[1] 的方式來取值,而不是 predictions[0][0]。
37
  result = {
38
+ 'Close_t0_pred': predictions[0],
39
+ 'Close_t5_pred': predictions[1],
40
+ 'Close_t10_pred': predictions[2],
41
+ 'Close_t20_pred': predictions[3]
42
  }
43
  return result