AlanRex commited on
Commit
e658165
·
verified ·
1 Parent(s): 9e3eb59

Upload 2 files

Browse files
Files changed (3) hide show
  1. .gitattributes +1 -0
  2. model_predictor.py +45 -0
  3. xgboost_model.json +3 -0
.gitattributes CHANGED
@@ -37,3 +37,4 @@ stock_lstm_model_v2.keras filter=lfs diff=lfs merge=lfs -text
37
  9CE6ABB0E688BCE5A5B3E69920220912-20250909.xlsx filter=lfs diff=lfs merge=lfs -text
38
  期末專案輸入資料20220912-20250909.xlsx filter=lfs diff=lfs merge=lfs -text
39
  taiwan_stock_predictor.keras filter=lfs diff=lfs merge=lfs -text
 
 
37
  9CE6ABB0E688BCE5A5B3E69920220912-20250909.xlsx filter=lfs diff=lfs merge=lfs -text
38
  期末專案輸入資料20220912-20250909.xlsx filter=lfs diff=lfs merge=lfs -text
39
  taiwan_stock_predictor.keras filter=lfs diff=lfs merge=lfs -text
40
+ xgboost_model.json filter=lfs diff=lfs merge=lfs -text
model_predictor.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import xgboost as xgb
2
+ import pandas as pd
3
+
4
+ class XGBoostModel:
5
+ # 使用類別變數儲存所有可用的模型名稱及其對應的檔案名稱
6
+ MODELS = {
7
+ 'xgboost_model': 'xgboost_model.json'
8
+ }
9
+
10
+ def __init__(self, default_model='xgboost_model'):
11
+ # 建立物件時,自動載入預設模型
12
+ self.current_model_name = default_model
13
+ self.model = self._load_model(self.current_model_name)
14
+
15
+ def _load_model(self, model_name):
16
+ if model_name not in self.MODELS:
17
+ raise ValueError(f"找不到模型 '{model_name}'。可用的模型名稱:{list(self.MODELS.keys())}")
18
+
19
+ filename = self.MODELS[model_name]
20
+ try:
21
+ # 建立一個新的 XGBoost 模型實例
22
+ model = xgb.XGBRegressor()
23
+ # 使用 XGBoost 內建的 load_model 方法載入檔案
24
+ model.load_model(filename)
25
+ return model
26
+ except Exception as e:
27
+ raise FileNotFoundError(f"無法在本地找到或載入模型檔案 '{filename}':{e}")
28
+
29
+ def predict(self, model_name, input_df):
30
+ # 如果請求的模型名稱與目前載入的不同,則動態載入
31
+ if model_name != self.current_model_name:
32
+ self.model = self._load_model(model_name)
33
+ self.current_model_name = model_name
34
+
35
+ # 進行預測
36
+ predictions = self.model.predict(input_df)
37
+
38
+ # 將預測結果轉換為字典
39
+ result = {
40
+ 'Close_t0_pred': predictions[0][0],
41
+ 'Close_t5_pred': predictions[0][1],
42
+ 'Close_t10_pred': predictions[0][2],
43
+ 'Close_t20_pred': predictions[0][3]
44
+ }
45
+ return result
xgboost_model.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:483a6c4c35047206bd2c8360e76b611512ac71cd336045f448b5d42f37277248
3
+ size 10692988