Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files- .gitattributes +1 -0
- model_predictor.py +45 -0
- xgboost_model.json +3 -0
.gitattributes
CHANGED
|
@@ -37,3 +37,4 @@ stock_lstm_model_v2.keras filter=lfs diff=lfs merge=lfs -text
|
|
| 37 |
9CE6ABB0E688BCE5A5B3E69920220912-20250909.xlsx filter=lfs diff=lfs merge=lfs -text
|
| 38 |
期末專案輸入資料20220912-20250909.xlsx filter=lfs diff=lfs merge=lfs -text
|
| 39 |
taiwan_stock_predictor.keras filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 37 |
9CE6ABB0E688BCE5A5B3E69920220912-20250909.xlsx filter=lfs diff=lfs merge=lfs -text
|
| 38 |
期末專案輸入資料20220912-20250909.xlsx filter=lfs diff=lfs merge=lfs -text
|
| 39 |
taiwan_stock_predictor.keras filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
xgboost_model.json filter=lfs diff=lfs merge=lfs -text
|
model_predictor.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import xgboost as xgb
|
| 2 |
+
import pandas as pd
|
| 3 |
+
|
| 4 |
+
class XGBoostModel:
|
| 5 |
+
# 使用類別變數儲存所有可用的模型名稱及其對應的檔案名稱
|
| 6 |
+
MODELS = {
|
| 7 |
+
'xgboost_model': 'xgboost_model.json'
|
| 8 |
+
}
|
| 9 |
+
|
| 10 |
+
def __init__(self, default_model='xgboost_model'):
|
| 11 |
+
# 建立物件時,自動載入預設模型
|
| 12 |
+
self.current_model_name = default_model
|
| 13 |
+
self.model = self._load_model(self.current_model_name)
|
| 14 |
+
|
| 15 |
+
def _load_model(self, model_name):
|
| 16 |
+
if model_name not in self.MODELS:
|
| 17 |
+
raise ValueError(f"找不到模型 '{model_name}'。可用的模型名稱:{list(self.MODELS.keys())}")
|
| 18 |
+
|
| 19 |
+
filename = self.MODELS[model_name]
|
| 20 |
+
try:
|
| 21 |
+
# 建立一個新的 XGBoost 模型實例
|
| 22 |
+
model = xgb.XGBRegressor()
|
| 23 |
+
# 使用 XGBoost 內建的 load_model 方法載入檔案
|
| 24 |
+
model.load_model(filename)
|
| 25 |
+
return model
|
| 26 |
+
except Exception as e:
|
| 27 |
+
raise FileNotFoundError(f"無法在本地找到或載入模型檔案 '{filename}':{e}")
|
| 28 |
+
|
| 29 |
+
def predict(self, model_name, input_df):
|
| 30 |
+
# 如果請求的模型名稱與目前載入的不同,則動態載入
|
| 31 |
+
if model_name != self.current_model_name:
|
| 32 |
+
self.model = self._load_model(model_name)
|
| 33 |
+
self.current_model_name = model_name
|
| 34 |
+
|
| 35 |
+
# 進行預測
|
| 36 |
+
predictions = self.model.predict(input_df)
|
| 37 |
+
|
| 38 |
+
# 將預測結果轉換為字典
|
| 39 |
+
result = {
|
| 40 |
+
'Close_t0_pred': predictions[0][0],
|
| 41 |
+
'Close_t5_pred': predictions[0][1],
|
| 42 |
+
'Close_t10_pred': predictions[0][2],
|
| 43 |
+
'Close_t20_pred': predictions[0][3]
|
| 44 |
+
}
|
| 45 |
+
return result
|
xgboost_model.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:483a6c4c35047206bd2c8360e76b611512ac71cd336045f448b5d42f37277248
|
| 3 |
+
size 10692988
|