AlanRex commited on
Commit
7fd9c4d
·
verified ·
1 Parent(s): f5573c5

Update model_predictor.py

Browse files
Files changed (1) hide show
  1. model_predictor.py +67 -12
model_predictor.py CHANGED
@@ -1,13 +1,15 @@
1
- # model_predictor.py (修正版)
2
  import xgboost as xgb
3
  import pandas as pd
 
4
 
5
  class XGBoostModel:
 
6
  MODELS = {
7
  'xgboost_model': 'xgboost_model.json'
8
  }
9
 
10
  def __init__(self, default_model='xgboost_model'):
 
11
  self.current_model_name = default_model
12
  self.model = self._load_model(self.current_model_name)
13
 
@@ -17,27 +19,80 @@ class XGBoostModel:
17
 
18
  filename = self.MODELS[model_name]
19
  try:
 
20
  model = xgb.XGBRegressor()
 
21
  model.load_model(filename)
22
  return model
23
  except Exception as e:
24
  raise FileNotFoundError(f"無法在本地找到或載入模型檔案 '{filename}':{e}")
25
 
26
  def predict(self, model_name, input_df):
 
27
  if model_name != self.current_model_name:
28
  self.model = self._load_model(model_name)
29
  self.current_model_name = model_name
30
 
 
31
  predictions = self.model.predict(input_df)
32
-
33
- # 【【【主要修正點】】】
34
- # 根據錯誤日誌推斷,當輸入只有一筆時,`predictions` 是一個一維陣列,
35
- # 例如 [pred_t0, pred_t5, pred_t10, pred_t20]。
36
- # 我們應該用 predictions[0], predictions[1] 的方式來取值,而不是 predictions[0][0]。
37
- result = {
38
- 'Close_t0_pred': predictions[0],
39
- 'Close_t5_pred': predictions[1],
40
- 'Close_t10_pred': predictions[2],
41
- 'Close_t20_pred': predictions[3]
42
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  return result
 
 
1
  import xgboost as xgb
2
  import pandas as pd
3
+ import numpy as np
4
 
5
  class XGBoostModel:
6
+ # 使用類別變數儲存所有可用的模型名稱及其對應的檔案名稱
7
  MODELS = {
8
  'xgboost_model': 'xgboost_model.json'
9
  }
10
 
11
  def __init__(self, default_model='xgboost_model'):
12
+ # 建立物件時,自動載入預設模型
13
  self.current_model_name = default_model
14
  self.model = self._load_model(self.current_model_name)
15
 
 
19
 
20
  filename = self.MODELS[model_name]
21
  try:
22
+ # 建立一個新的 XGBoost 模型實例
23
  model = xgb.XGBRegressor()
24
+ # 使用 XGBoost 內建的 load_model 方法載入檔案
25
  model.load_model(filename)
26
  return model
27
  except Exception as e:
28
  raise FileNotFoundError(f"無法在本地找到或載入模型檔案 '{filename}':{e}")
29
 
30
  def predict(self, model_name, input_df):
31
+ # 如果請求的模型名稱與目前載入的不同,則動態載入
32
  if model_name != self.current_model_name:
33
  self.model = self._load_model(model_name)
34
  self.current_model_name = model_name
35
 
36
+ # 進行預測
37
  predictions = self.model.predict(input_df)
38
+
39
+ # 調試:印出預測結果的形狀和內容
40
+ print(f"預測結果形狀: {predictions.shape}")
41
+ print(f"預測結果類型: {type(predictions)}")
42
+ print(f"預測結果內容: {predictions}")
43
+
44
+ # 處理不同的輸出格式
45
+ try:
46
+ # 情況1: 如果是二維陣列且有4個預測值 (原始期望格式)
47
+ if len(predictions.shape) == 2 and predictions.shape[1] == 4:
48
+ result = {
49
+ 'Close_t0_pred': float(predictions[0][0]),
50
+ 'Close_t5_pred': float(predictions[0][1]),
51
+ 'Close_t10_pred': float(predictions[0][2]),
52
+ 'Close_t20_pred': float(predictions[0][3])
53
+ }
54
+
55
+ # 情況2: 如果是一維陣列且有4個預測值
56
+ elif len(predictions.shape) == 1 and len(predictions) == 4:
57
+ result = {
58
+ 'Close_t0_pred': float(predictions[0]),
59
+ 'Close_t5_pred': float(predictions[1]),
60
+ 'Close_t10_pred': float(predictions[2]),
61
+ 'Close_t20_pred': float(predictions[3])
62
+ }
63
+
64
+ # 情況3: 如果只有一個預測值(單一輸出模型)
65
+ elif len(predictions.shape) == 1 and len(predictions) == 1:
66
+ # 假設這個預測值代表最近期的預測,其他用相同值
67
+ pred_value = float(predictions[0])
68
+ result = {
69
+ 'Close_t0_pred': pred_value,
70
+ 'Close_t5_pred': pred_value,
71
+ 'Close_t10_pred': pred_value,
72
+ 'Close_t20_pred': pred_value
73
+ }
74
+
75
+ # 情況4: 如果是標量(單一數值)
76
+ elif np.isscalar(predictions):
77
+ pred_value = float(predictions)
78
+ result = {
79
+ 'Close_t0_pred': pred_value,
80
+ 'Close_t5_pred': pred_value,
81
+ 'Close_t10_pred': pred_value,
82
+ 'Close_t20_pred': pred_value
83
+ }
84
+
85
+ else:
86
+ # 其他情況:嘗試使用第一個預測值
87
+ pred_value = float(predictions.flatten()[0])
88
+ result = {
89
+ 'Close_t0_pred': pred_value,
90
+ 'Close_t5_pred': pred_value,
91
+ 'Close_t10_pred': pred_value,
92
+ 'Close_t20_pred': pred_value
93
+ }
94
+
95
+ except (IndexError, TypeError) as e:
96
+ raise ValueError(f"無法解析模型輸出格式。預測結果: {predictions}, 錯誤: {e}")
97
+
98
  return result