Spaces:
Sleeping
Sleeping
Delete predictor_logic.py
Browse files- predictor_logic.py +0 -65
predictor_logic.py
DELETED
|
@@ -1,65 +0,0 @@
|
|
| 1 |
-
import pandas as pd
|
| 2 |
-
import xgboost as xgb
|
| 3 |
-
import os
|
| 4 |
-
|
| 5 |
-
# 定義一個類別來封裝模型載入和預測邏輯
|
| 6 |
-
class StockPredictor:
|
| 7 |
-
"""
|
| 8 |
-
這個類別負責載入模型和資料,並提供預測方法。
|
| 9 |
-
"""
|
| 10 |
-
def __init__(self, model_path="xgboost_model.json", data_path="期末專案輸入資料20220912-20250909.xlsx"):
|
| 11 |
-
"""
|
| 12 |
-
初始化預測器物件,載入模型和資料集。
|
| 13 |
-
可接受無引數,自動帶入預設路徑檔名。
|
| 14 |
-
"""
|
| 15 |
-
if not os.path.exists(model_path):
|
| 16 |
-
raise FileNotFoundError(f"找不到模型檔案: {model_path}")
|
| 17 |
-
if not os.path.exists(data_path):
|
| 18 |
-
raise FileNotFoundError(f"找不到資料檔案: {data_path}")
|
| 19 |
-
|
| 20 |
-
self.model = xgb.XGBRegressor()
|
| 21 |
-
self.model.load_model(model_path)
|
| 22 |
-
self.historical_df = pd.read_excel(data_path)
|
| 23 |
-
|
| 24 |
-
# 確保 '日期' 欄位是 datetime 格式
|
| 25 |
-
self.historical_df['日期'] = pd.to_datetime(self.historical_df['日期'], format='%Y%m%d')
|
| 26 |
-
|
| 27 |
-
# 確保欄位名稱正確
|
| 28 |
-
self.historical_df.rename(columns={
|
| 29 |
-
'加權指\n數開盤': 'Open',
|
| 30 |
-
'加權指\n數最高': 'High',
|
| 31 |
-
'加權指\n數最低': 'Low',
|
| 32 |
-
'加權指\n數收盤': 'Close',
|
| 33 |
-
'加權指數\n成交量': 'Volume'
|
| 34 |
-
}, inplace=True)
|
| 35 |
-
|
| 36 |
-
def predict(self, date_str):
|
| 37 |
-
"""
|
| 38 |
-
輸入一個日期字串,自動抓取前一筆資料進行預測。
|
| 39 |
-
返回預測結果。
|
| 40 |
-
"""
|
| 41 |
-
# 將輸入日期轉換為 datetime 物件
|
| 42 |
-
pred_date = pd.to_datetime(date_str)
|
| 43 |
-
|
| 44 |
-
# 尋找使用者輸入日期在資料集中的位置
|
| 45 |
-
current_row_index = self.historical_df[self.historical_df['日期'] == pred_date].index
|
| 46 |
-
|
| 47 |
-
if current_row_index.empty:
|
| 48 |
-
raise ValueError("找不到您輸入日期所對應的資料。")
|
| 49 |
-
|
| 50 |
-
# 取得前一筆資料的索引
|
| 51 |
-
prev_data_index = current_row_index[0] - 1
|
| 52 |
-
|
| 53 |
-
if prev_data_index < 0:
|
| 54 |
-
raise ValueError("資料集沒有前一筆資料可以進行預測。")
|
| 55 |
-
|
| 56 |
-
# 抓取前一筆資料
|
| 57 |
-
selected_row = self.historical_df.iloc[[prev_data_index]]
|
| 58 |
-
|
| 59 |
-
# 準備模型輸入
|
| 60 |
-
input_data = selected_row[['Open', 'High', 'Low', 'Volume']].values
|
| 61 |
-
|
| 62 |
-
# 進行預測,模型需要二維陣列的輸入
|
| 63 |
-
predicted_price = self.model.predict(input_data.reshape(1, -1))[0]
|
| 64 |
-
|
| 65 |
-
return predicted_price
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|